#-*- coding: utf-8 -*-
import itertools
import sys
import re
from pyparsing import ParserElement, Word, Literal, ZeroOrMore, OneOrMore, Or, oneOf, Dict, Group, alphanums, printables, lineEnd, delimitedList, restOfLine

# pyparsing building blocks shared by the config-file grammars in this module.

# Only spaces are skipped as whitespace by default (tabs are not), so line
# ends stay significant and are matched explicitly via lineEnd. This sets a
# class-wide pyparsing default as a side effect of importing this module.
ParserElement.setDefaultWhitespaceChars(' ')
# '=' separating a name from its definition; dropped from the results.
ASSIGN = Literal('=').suppress()
# '-->' separating a rule's POS from its attribute list; dropped.
ARROW = Literal('-->').suppress()
# 'IN' or '=' between a rule's attributes and its allowed values; dropped.
IN = oneOf('IN =').suppress()
# Any number (including zero) of line ends; dropped.
ENDLS = ZeroOrMore(lineEnd).suppress()
# '#' followed by the rest of the line.
COMMENT = Literal('#') + restOfLine
# NOTE(review): pyparsing's Or expects a list of alternatives as its first
# argument (Or([ENDLS, COMMENT])); here COMMENT is passed as the second
# (savelist) argument, so it is likely not treated as an alternative.
# Not referenced elsewhere in the visible code -- confirm before relying on it.
Ignorables = ZeroOrMore(Or(ENDLS, COMMENT)).suppress()
# An identifier made of letters, digits, '-' and '_'.
Token = Word(alphanums+'-_')
# A Token wrapped in square brackets, e.g. "[gender]"; the brackets are dropped.
OptToken = Literal('[').suppress() + Token + Literal(']').suppress()
# Colon-separated Tokens, e.g. "number:case".
Attrs = delimitedList(Token, delim=':')

def _require(cond, msg):
	if not cond:
		raise ValueError(msg)

def _readFile(fname):
	f = open(fname, 'r')
	res = f.read()
	f.close()
	return res

class Tagset:
	"""Attribute and part-of-speech (POS) definitions of a tagset.

	``attrStr`` holds rows ``attr = val1 val2 ...`` and ``posStr`` holds rows
	``pos = attr1 attr2 ... [optAttr1] [optAttr2] ...`` where bracketed
	attributes are optional; ``#`` starts a comment. Lookup tables between
	values, attributes and POS are derived from these rows.
	"""
	
	def __init__(self, attrStr, posStr):
		# attribute -> list of its values, and value -> attribute.
		self.attr2Vals = {}
		self.val2Attr = {}
		
		# pos -> (mandatory attrs, optional attrs); attribute/value -> set of POS using it.
		self.pos2Attrs = {}
		self.attr2Pos = {}
		self.val2Pos = {}
		
		# Attributes first: onAddPos looks up every POS attribute in attr2Vals.
		self.parseAttrs(attrStr)
		self.parsePos(posStr)
	
	def onAddAttr(self, name, vals):
		"""Register attribute ``name`` with its list of values ``vals``."""
		self.attr2Vals[name] = vals
		for val in vals:
			# A value listed under several attributes maps to the last one registered.
			self.val2Attr[val] = name

	def onAddPos(self, pos, attrs, optAttrs):
		"""Register ``pos`` with mandatory ``attrs`` and optional ``optAttrs``.

		Raises KeyError if one of the attributes was never registered via onAddAttr.
		"""
		self.pos2Attrs[pos] = (attrs, optAttrs)
		for attr in itertools.chain(attrs, optAttrs):
			self.attr2Pos.setdefault(attr, set()).add(pos)
			for val in self.attr2Vals[attr]:
				self.val2Pos.setdefault(val, set()).add(pos)
	
	def parseAttrs(self, attrStr):
		"""Parse ``attr = val1 val2 ...`` rows, calling onAddAttr for each."""
		attrRow = ENDLS + Token + ASSIGN + Group(ZeroOrMore(Token)) + ENDLS
		attrRow.setParseAction(lambda s: self.onAddAttr(s[0], list(s[1])))
		attrs = ZeroOrMore(attrRow)
		attrs.ignore(COMMENT)
		attrs.parseString(attrStr)
	
	def parsePos(self, posStr):
		"""Parse ``pos = attr ... [optAttr] ...`` rows, calling onAddPos for each."""
		posRow = ENDLS + Token + ASSIGN + Group(ZeroOrMore(Token)) + Group(ZeroOrMore(OptToken)) + ENDLS
		posRow.setParseAction(lambda s: self.onAddPos(s[0], list(s[1]), list(s[2])))
		pos = ZeroOrMore(posRow)
		pos.ignore(COMMENT)
		pos.parseString(posStr)
	
	def getAllAttrs4Pos(self, pos):
		"""Return mandatory then optional attributes of ``pos`` as one list."""
		attrs, optAttrs = self.pos2Attrs[pos]
		return list(itertools.chain(attrs, optAttrs))
	
	def getMandatoryAttrs4Pos(self, pos):
		"""Return the list of mandatory attributes of ``pos``."""
		return self.pos2Attrs[pos][0]
	
	def getOptionalAttrs4Pos(self, pos):
		"""Return the list of optional attributes of ``pos``."""
		return self.pos2Attrs[pos][1]
	
	def getAttr4Value(self, val):
		"""Return the attribute ``val`` belongs to; KeyError if unknown."""
		return self.val2Attr[val]
	
	def getPosSet4Value(self, val):
		"""Return the set of POS whose attributes admit ``val``; KeyError if none."""
		return self.val2Pos[val]
	
	def getValues4Attr(self, attr):
		"""Return the list of values of attribute ``attr``; KeyError if unknown."""
		return self.attr2Vals[attr]
	
	def getPosSet(self):
		"""Return a new set with every defined POS."""
		return set(self.pos2Attrs.keys())
	
	def getTag(self, base, pos, msd):
		"""Build a validated Tag from a base form, a POS and an msd string.

		``msd`` is a colon-separated list of attribute values; empty segments
		(an empty msd, a leading/trailing or doubled ':') are ignored.

		Raises TagException if the POS is unknown, a value is unknown, two
		different values are given for the same attribute, a mandatory
		attribute is missing, or an attribute is not allowed for the POS.
		"""
		if pos not in self.pos2Attrs:
			raise TagException('No such pos: %s' % pos)
		
		# Drop every empty segment, not just the first one.
		attrVals = [val for val in msd.split(':') if val != '']
		self._checkValsExist(attrVals)
		
		attrsMap = {}
		for val in attrVals:
			attr = self.getAttr4Value(val)
			prev = attrsMap.get(attr)
			if prev is not None and prev != val:
				# Otherwise the later value would silently replace the earlier one.
				raise TagException('Conflicting values for attr %s: "%s" and "%s"' % (attr, prev, val))
			attrsMap[attr] = val
		self._checkAttrs(pos, set(attrsMap))
		
		return Tag(self, base, pos, attrsMap)
	
	def _checkValsExist(self, vals):
		"""Raise TagException for the first value in ``vals`` not defined in the tagset."""
		for val in vals:
			if val not in self.val2Attr:
				raise TagException('Invalid value: "%s"' % (val, ))
	
	def _checkAttrs(self, pos, attrs):
		"""Check the attribute set ``attrs`` against the definition of ``pos``.

		Raises TagException if a mandatory attribute is missing or an
		attribute not defined for ``pos`` is present (missing is checked first).
		"""
		mandatoryAttrs = set(self.getMandatoryAttrs4Pos(pos))
		allAttrs = set(self.getAllAttrs4Pos(pos))
		
		missingAttrs = mandatoryAttrs - attrs
		if missingAttrs:
			raise TagException('Missing required attrs: %s' % list(missingAttrs))
		invalidAttrs = attrs - allAttrs
		if invalidAttrs:
			raise TagException('Invalid attrs: %s' % list(invalidAttrs))

class TagException(Exception):
	"""Raised when a tag, or a rule applied to it, is invalid.

	``value`` holds the message given to the constructor; ``str()`` returns its repr.
	"""
	def __init__(self, value):
		# Deriving from Exception makes the class raisable on Python 3 and
		# catchable by `except Exception` on Python 2.
		Exception.__init__(self, value)
		self.value = value
	def __str__(self):
		return repr(self.value)

class Tag:
	"""A tag: base form, POS and a mapping attribute -> value, bound to its tagset."""
	def __init__(self, tagset, base, pos, attrsMap):
		self.tagset = tagset
		self.base = base
		self.pos = pos
		self.attrsMap = attrsMap
	
	def getBase(self):
		return self.base
	
	def getPos(self):
		return self.pos
	
	def getAttr(self, attr):
		"""Return the value of ``attr``, or None if this tag does not set it."""
		return self.attrsMap.get(attr, None)
	
	def getMsdStr(self):
		"""Return the msd string in tagset order, joined with ':'.

		All mandatory attribute values for the POS come first (KeyError if one
		is absent from attrsMap), followed by the optional ones that are present.
		"""
		mandatory = [self.attrsMap[attr] for attr in self.tagset.getMandatoryAttrs4Pos(self.pos)]
		optional = [self.attrsMap[attr] for attr in self.tagset.getOptionalAttrs4Pos(self.pos)
					if attr in self.attrsMap]
		return ':'.join(mandatory + optional)
	
	def __str__(self):
		return 'base="%s", ctag="%s", msd="%s"' % (self.base, self.pos, self.getMsdStr())
	
	def debug(self):
		"""Write the POS and the attribute map to stdout."""
		# sys.stdout.write instead of the Python-2-only print statement; same output.
		sys.stdout.write('pos = %s\n' % (self.pos, ))
		sys.stdout.write('attrs = %s\n' % (self.attrsMap, ))

class Rule:
	"""A constraint on tags of one POS: ``attrs`` must take one of the ``vals`` combinations.

	``attrs`` is a tuple of attribute names, where the pseudo-attribute
	``'base'`` refers to the tag's base form. ``vals`` is a list of tuples,
	each giving one allowed combination of values for ``attrs``.
	Raises ValueError on construction if the rule is inconsistent with ``tagset``.
	"""
	def __init__(self, tagset, pos, attrs, vals):
		self.pos = pos
		self.tagset = tagset
		self.attrs = attrs
		self.vals = vals
		
		self._validate()
	
	def _validate(self):
		"""Check the rule against the tagset; raise ValueError if it is inconsistent."""
		_require(self.pos in self.tagset.getPosSet(), 'Invalid pos: %s' % self.pos)
		
		n = len(self.attrs)
		for val in self.vals:
			_require(len(val) == n, 'Invalid value: %s' % (val, ))
			for attr, v in zip(self.attrs, val):
				if attr == 'base':
					continue
				try:
					valAttr = self.tagset.getAttr4Value(v)
				except KeyError:
					# Unknown value: report it like every other rule error
					# (ValueError) instead of leaking a bare KeyError.
					valAttr = None
				_require(attr == valAttr, 'Invalid value for attr %s: %s' % (attr, v))
		
		allowed = set(self.tagset.getAllAttrs4Pos(self.pos))
		allowed.add('base')
		_require(set(self.attrs).issubset(allowed), \
				'Invalid attrs for pos %s: %s' % (self.pos, self.attrs))
	
	def satisfies(self, tag):
		"""Return False only if ``tag`` has this rule's POS and a disallowed value combination."""
		if tag.getPos() != self.pos:
			return True
		tagVal = tuple(tag.getBase() if attr == 'base' else tag.getAttr(attr) for attr in self.attrs)
		return tagVal in self.vals
	
	def __str__(self):
		attrsStr = self._singleValToStr(self.attrs)
		valsStr = self._valsToStr(self.vals)
		# 'IN' for a choice among several combinations, '=' for a single one.
		op = 'IN' if len(self.vals) > 1 else '='
		return '%s --> %s %s %s' % (self.pos, attrsStr, op, valsStr)
	
	def _valsToStr(self, vals):
		"""Render value combinations separated by spaces, e.g. "sg:nom pl:gen"."""
		return ' '.join(self._singleValToStr(val) for val in vals)
	
	def _singleValToStr(self, attrs):
		"""Join a sequence of names or values with ':'."""
		return ':'.join(attrs)

class TagValidator:
	"""Validates (base, ctag, msd) triples against a tagset and rules read from a config file.

	The config file must contain each of the sections ``[ATTR]``, ``[POS]``
	and ``[RULES]`` exactly once, in any order; text before the first
	section header is ignored.
	"""
	# The capturing group keeps the headers in re.split() output so each
	# section body can be paired with its header.
	_SECTION_RE = re.compile(r'(\[ATTR\]|\[POS\]|\[RULES\])')
	
	def __init__(self, cfgFilename):
		self.tagset, self.rules = self._parseTagsetAndRules(cfgFilename)
	
	@classmethod
	def _splitSections(cls, cfgStr):
		"""Split config text into the ``(attrStr, posStr, rulesStr)`` section bodies.

		Raises ValueError if a section is missing or repeated.
		"""
		parts = cls._SECTION_RE.split(cfgStr)
		sections = {}
		# parts == [preamble, header1, body1, header2, body2, ...]
		for header, body in zip(parts[1::2], parts[2::2]):
			if header in sections:
				raise ValueError('Duplicate section %s in config' % header)
			sections[header] = body
		for header in ('[ATTR]', '[POS]', '[RULES]'):
			if header not in sections:
				raise ValueError('Missing section %s in config' % header)
		return sections['[ATTR]'], sections['[POS]'], sections['[RULES]']
	
	def _parseTagsetAndRules(self, fname):
		"""Read ``fname`` and build the Tagset and the list of Rules it defines."""
		attrStr, posStr, rulesStr = self._splitSections(_readFile(fname))
		tagset = Tagset(attrStr, posStr)
		rules = self._parseRules(tagset, rulesStr)
		return (tagset, rules)
	
	def _parseRules(self, tagset, string):
		"""Parse rows ``pos --> a1:a2 IN v1:v2 w1:w2 ...`` (or ``= v1:v2``) into Rules."""
		ruleRow = ENDLS + Token + ARROW + Group(Attrs) + IN + Group(OneOrMore(Group(Attrs))) + ENDLS
		rules = ZeroOrMore(Group(ruleRow))
		rules.ignore(COMMENT)
		return [Rule(tagset, pos, tuple(attrs), [tuple(val) for val in vals])
				for pos, attrs, vals in rules.parseString(string)]
	
	def validateTag(self, base, ctag, msd):
		"""Validate a tag and return ``(base, ctag, msd)`` with msd rebuilt by Tag.getMsdStr.

		Raises TagException if the tag is invalid for the tagset or violates a rule.
		"""
		tag = self.tagset.getTag(base, ctag, msd)
		for rule in self.rules:
			if not rule.satisfies(tag):
				raise TagException('Rule "%s" not satisfied for: %s' % (rule, tag))
		return (base, ctag, tag.getMsdStr())
#~ 
#~ validateTag('będzie', 'bedzie', 'pl:pri:imperf')
#~ validateTag('siebie', 'siebie', 'gen')

