#-*- coding: utf-8 -*-
import itertools
import sys
import re
from pyparsing import ParserElement, Word, Literal, ZeroOrMore, OneOrMore, Or, oneOf, Dict, Group, alphanums, printables, lineEnd, delimitedList, restOfLine

# Shared pyparsing building blocks for the [ATTR], [POS] and [RULES] grammars.
#
# Only the space character is skipped as whitespace, so line ends stay visible
# to the grammar and have to be matched explicitly (see ENDLS).
ParserElement.setDefaultWhitespaceChars(' ')
# '=' between a name and its definition; dropped from the parse results.
ASSIGN = Literal('=').suppress()
# '-->' between a rule's pos and its constraint; dropped from the parse results.
ARROW = Literal('-->').suppress()
# Either 'IN' or '=' between a rule's attribute list and its allowed values.
IN = oneOf('IN =').suppress()
# Any run of line ends (possibly none); dropped from the parse results.
ENDLS = ZeroOrMore(lineEnd).suppress()
# A '#' followed by the remainder of the line.
COMMENT = Literal('#') + restOfLine
# NOTE(review): Or is normally given a list of alternatives; here COMMENT is
# passed as a second positional argument -- confirm this is intended.
# Ignorables is not referenced anywhere else in the visible code.
Ignorables = ZeroOrMore(Or(ENDLS, COMMENT)).suppress()
# A name or value made of letters, digits, '-' and '_'.
Token = Word(alphanums+'-_')
# A token wrapped in square brackets (the brackets are dropped); the [POS]
# grammar uses it for optional attributes.
OptToken = Literal('[').suppress() + Token + Literal(']').suppress()
# One or more tokens separated by ':'.
Attrs = delimitedList(Token, delim=':')

def _require(cond, msg):
    if not cond:
        raise ValueError(msg)

def _readFile(fname):
    f = open(fname, 'r')
    res = f.read()
    f.close()
    return res

class Tagset:
    """Attribute and part-of-speech definitions parsed from two config texts.

    Attribute rows have the form `name = val1 val2 ...`; pos rows have the
    form `pos = attr1 attr2 ... [optAttr1] [optAttr2] ...`. While parsing,
    forward and reverse lookup tables are filled so that attributes, values
    and parts of speech can be mapped onto each other.
    """

    def __init__(self, attrStr, posStr):
        """Parse `attrStr` (attribute rows) and then `posStr` (pos rows)."""
        # attribute name -> list of values, and value -> attribute name
        self.attr2Vals = {}
        self.val2Attr = {}

        # pos -> (mandatory attrs, optional attrs), plus the reverse lookups
        # attribute -> set of pos and value -> set of pos
        self.pos2Attrs = {}
        self.attr2Pos = {}
        self.val2Pos = {}

        # Attributes go first: onAddPos looks their values up.
        self.parseAttrs(attrStr)
        self.parsePos(posStr)

    def onAddAttr(self, name, vals):
        """Register attribute `name` together with its list of values."""
        self.attr2Vals[name] = vals
        self.val2Attr.update((value, name) for value in vals)

    def onAddPos(self, pos, attrs, optAttrs):
        """Register `pos` with its mandatory and optional attribute names.

        Raises KeyError when one of the attributes has not been registered
        through onAddAttr.
        """
        self.pos2Attrs[pos] = (attrs, optAttrs)
        for attrName in itertools.chain(attrs, optAttrs):
            self.attr2Pos.setdefault(attrName, set()).add(pos)
            for value in self.attr2Vals[attrName]:
                self.val2Pos.setdefault(value, set()).add(pos)

    def parseAttrs(self, attrStr):
        """Parse attribute rows from `attrStr`, registering each one."""
        row = ENDLS + Token + ASSIGN + Group(ZeroOrMore(Token)) + ENDLS
        row.setParseAction(lambda toks: self.onAddAttr(toks[0], list(toks[1])))
        grammar = ZeroOrMore(row)
        grammar.ignore(COMMENT)
        grammar.parseString(attrStr)

    def parsePos(self, posStr):
        """Parse pos rows from `posStr`, registering each one."""
        mandatory = Group(ZeroOrMore(Token))
        optional = Group(ZeroOrMore(OptToken))
        row = ENDLS + Token + ASSIGN + mandatory + optional + ENDLS
        row.setParseAction(
            lambda toks: self.onAddPos(toks[0], list(toks[1]), list(toks[2])))
        grammar = ZeroOrMore(row)
        grammar.ignore(COMMENT)
        grammar.parseString(posStr)

    def getAllAttrs4Pos(self, pos):
        """Return mandatory followed by optional attribute names of `pos`."""
        mandatory, optional = self.pos2Attrs[pos]
        return list(mandatory) + list(optional)

    def getMandatoryAttrs4Pos(self, pos):
        """Return the mandatory attribute names of `pos`."""
        mandatory, _ = self.pos2Attrs[pos]
        return mandatory

    def getOptionalAttrs4Pos(self, pos):
        """Return the optional attribute names of `pos`."""
        _, optional = self.pos2Attrs[pos]
        return optional

    def getAttr4Value(self, val):
        """Return the name of the attribute that `val` belongs to."""
        return self.val2Attr[val]

    def getPosSet4Value(self, val):
        """Return the set of pos whose attributes can take value `val`."""
        return self.val2Pos[val]

    def getValues4Attr(self, attr):
        """Return the list of values registered for attribute `attr`."""
        return self.attr2Vals[attr]

    def getPosSet(self):
        """Return a new set holding every registered pos."""
        return set(self.pos2Attrs)

    def getTag(self, base, pos, msd):
        """Build a Tag from `base`, `pos` and a colon-separated `msd` string.

        Raises TagException when `pos` is unknown, when a value is unknown,
        or when the attributes do not fit `pos`.
        """
        if pos not in self.getPosSet():
            raise TagException('No such pos: %s' % pos)

        values = msd.split(':')
        # Drop one empty segment, if present (e.g. an empty msd).
        try:
            values.remove('')
        except ValueError:
            pass
        self._checkValsExist(values)

        attrNames = set(self.val2Attr[value] for value in values)
        self._checkAttrs(pos, attrNames)

        attrsMap = dict((self.getAttr4Value(value), value) for value in values)

        tag = Tag(self, base, pos, attrsMap)
        # The result is not used; the call only has to succeed.
        tag.getMsdStr()
        return tag

    def _checkValsExist(self, vals):
        """Raise TagException for the first value in `vals` that is unknown."""
        for value in vals:
            if value in self.val2Attr:
                continue
            raise TagException('Invalid value: "%s"' % (value, ))

    def _checkAttrs(self, pos, attrs):
        """Check the attribute-name set `attrs` against the definition of `pos`.

        Raises TagException when a mandatory attribute is missing or, failing
        that, when an attribute is not allowed for `pos`.
        """
        required = set(self.getMandatoryAttrs4Pos(pos))
        allowed = set(self.getAllAttrs4Pos(pos))

        missing = required.difference(attrs)
        if missing:
            raise TagException('Missing required attrs: %s' % list(missing))

        unexpected = attrs.difference(allowed)
        if unexpected:
            raise TagException('Invalid attrs: %s' % list(unexpected))

class TagException(Exception):
    """Error raised when a tag or one of its parts fails validation.

    Derives from Exception so that it can be raised on every Python version
    and is caught by generic `except Exception` handlers. The original
    payload stays available as `value`, and str() still yields its repr.
    """

    def __init__(self, value):
        Exception.__init__(self, value)
        self.value = value

    def __str__(self):
        return repr(self.value)

class Tag:
    """A base form, its part of speech and a mapping attribute name -> value.

    `tagset` is consulted for the attribute order of the pos (getMsdStr) and
    for the attribute a value belongs to (replace). An attribute whose value
    is None is treated as absent.
    """

    def __init__(self, tagset, base, pos, attrsMap):
        self.tagset = tagset
        self.base = base
        self.pos = pos
        self.attrsMap = attrsMap

    def getChangedCopy(self, attr, val):
        """Return a new Tag with `attr` set to `val`; self is left untouched.

        The special name 'base' replaces the base form instead of an
        attribute. The copy always gets its own attribute dict, so later
        setAttr calls on either tag do not leak into the other one.
        """
        newAttrsMap = self.attrsMap.copy()
        if attr == 'base':
            return Tag(self.tagset, val, self.pos, newAttrsMap)
        newAttrsMap[attr] = val
        return Tag(self.tagset, self.base, self.pos, newAttrsMap)

    def getBase(self):
        """Return the base form."""
        return self.base

    def getPos(self):
        """Return the part of speech."""
        return self.pos

    def getAttr(self, attr):
        """Return the value of `attr`, or None when it is not set."""
        return self.attrsMap.get(attr, None)

    def getAttrs(self, attrs):
        """Return the values of the given attributes (None where unset)."""
        return [self.getAttr(attr) for attr in attrs]

    def getAttrNames(self):
        """Return a list of the attribute names stored in this tag."""
        return list(self.attrsMap)

    def setAttr(self, attr, val):
        """Set `attr` to `val` in place; None marks the attribute as absent."""
        self.attrsMap[attr] = val

    def replace(self, what, replacement):
        """Swap the values `what` for `replacement`, position by position.

        Nothing happens unless the tag currently carries every value listed
        in `what`. When `replacement` is shorter than `what`, the attributes
        of the surplus old values are set to None.
        """
        for val in what:
            if self.getAttr(self.tagset.getAttr4Value(val)) != val:
                return
        # izip_longest on Python 2, zip_longest on Python 3.
        zipLongest = getattr(itertools, 'izip_longest', None) or itertools.zip_longest
        for oldVal, newVal in zipLongest(what, replacement):
            attr = self.tagset.getAttr4Value(oldVal)
            self.setAttr(attr, newVal)

    def getMsdStr(self):
        """Return the attribute values joined by ':'.

        Mandatory attributes of the pos come first, then the optional ones
        that are set, each group in the order defined by the tagset.
        Attributes whose value is None are skipped. Raises KeyError when a
        mandatory attribute is missing from the tag altogether.
        """
        vals = []
        for attr in self.tagset.getMandatoryAttrs4Pos(self.pos):
            val = self.attrsMap[attr]
            if val is not None:
                vals.append(val)
        for attr in self.tagset.getOptionalAttrs4Pos(self.pos):
            val = self.attrsMap.get(attr)
            if val is not None:
                vals.append(val)
        return ':'.join(vals)

    def __str__(self):
        return u'%s:%s:%s' % (self.base, self.pos, self.getMsdStr())

    def debug(self):
        """Print the pos and the attribute mapping to stdout."""
        # Single pre-formatted argument: same output with the Python 2 print
        # statement and the Python 3 print function.
        print('pos = %s' % (self.pos,))
        print('attrs = %s' % (self.attrsMap,))

class Rule:
    """Constraint on tags of one pos: chosen attributes must form an allowed tuple.

    `attrs` is a tuple of attribute names (the special name 'base' stands for
    the tag's base form) and `vals` is a list of allowed value tuples, each
    aligned with `attrs`. The rule is checked against `tagset` on creation;
    an inconsistent rule raises ValueError.
    """

    def __init__(self, tagset, pos, attrs, vals):
        self.pos = pos
        self.tagset = tagset
        self.attrs = attrs
        self.vals = vals

        self._validateSelf()

    def _validateSelf(self):
        """Raise ValueError unless pos, attrs and vals fit the tagset."""
        _require(self.pos in self.tagset.getPosSet(), 'Invalid pos: %s' % self.pos)

        n = len(self.attrs)
        for val in self.vals:
            # Every allowed tuple needs one value per attribute ...
            _require(len(val) == n, 'Invalid value: %s' % (val, ))
            # ... and each value has to belong to the attribute at its position.
            for attr, v in zip(self.attrs, val):
                _require(attr == 'base' or attr == self.tagset.getAttr4Value(v),
                         'Invalid value for attr %s: %s' % (attr, v))

        realAttrs = set(self.attrs).difference(set(['base']))
        _require(realAttrs.issubset(set(self.tagset.getAllAttrs4Pos(self.pos))),
                 'Invalid attrs for pos %s: %s' % (self.pos, self.attrs))

    def _tryToRepair(self, tag):
        """Return a copy of `tag` forced to the rule's only allowed tuple.

        Raises TagException when the rule allows more than one tuple, since
        there is then no unique repair.
        """
        if len(self.vals) != 1:
            raise TagException('Rule "%s" not satisfied for: %s' % (self, tag))
        newTag = tag
        # Pair each attribute with its value inside the single allowed tuple.
        for attr, val in zip(self.attrs, self.vals[0]):
            newTag = newTag.getChangedCopy(attr, val)
        return newTag

    def validate(self, tag):
        """Return `tag` when it satisfies the rule or has a different pos.

        Raises TagException when the tag has this rule's pos but its values
        are not among the allowed tuples. _tryToRepair is not invoked here.
        """
        if tag.getPos() != self.pos:
            return tag
        tagVal = tuple([tag.getAttr(attr) if attr != 'base' else tag.getBase()
                        for attr in self.attrs])
        if tagVal not in self.vals:
            raise TagException('Rule "%s" not satisfied for: %s' % (self, tag))
        return tag

    def __str__(self):
        attrsStr = self._singleValToStr(self.attrs)
        valsStr = self._valsToStr(self.vals)
        if len(self.vals) > 1:
            return '%s --> %s IN %s' % (self.pos, attrsStr, valsStr)
        return '%s --> %s = %s' % (self.pos, attrsStr, valsStr)

    def _valsToStr(self, vals):
        """Render value tuples separated by spaces."""
        return ' '.join(self._singleValToStr(val) for val in vals)

    def _singleValToStr(self, attrs):
        """Render one sequence of names or values separated by ':'."""
        return ':'.join(attrs)

class ValidatorAfter(object):
    """Hard-coded consistency checks for the pos 'Noun', 'Verbfin',
    'ppron3'/'Ppron3' and 'Num'.

    Each check returns the tag it was given and raises TagException when a
    constraint is violated; tags of any other pos pass through unchanged.
    """

    def __init__(self):
        pass

    def _requireNotNone(self, tag, attr):
        """Raise TagException when `attr` is not set on `tag`."""
        if tag.getAttr(attr) is not None:
            return
        raise TagException(u'missing attr: "%s" in "%s"' % (attr, tag))

    def _requireNone(self, tag, attr):
        """Raise TagException when `attr` is set on `tag`."""
        if tag.getAttr(attr) is None:
            return
        raise TagException(u'attr "%s" should be empty, but was "%s" in "%s"' % (attr, tag.getAttr(attr), tag))

    def _require(self, tag, attr, val):
        """Raise TagException unless `attr` has exactly the value `val`."""
        if tag.getAttr(attr) == val:
            return
        raise TagException(u'attr "%s" should be "%s" but was "%s" in "%s"' % (attr, val, tag.getAttr(attr), tag))

    def _getAttrs(self, tag, attrs):
        """Return the values of `attrs` on `tag`, None where unset."""
        return list(map(tag.getAttr, attrs))

    def _validateNoun(self, tag):
        """A Noun with any of aspect/reflexivity/negation set needs all three
        of them and gender 'n'."""
        if tag.getPos() != 'Noun':
            return tag
        extraAttrs = ['aspect', 'reflexivity', 'negation']
        if any(val is not None for val in self._getAttrs(tag, extraAttrs)):
            self._require(tag, 'gender', 'n')
            for name in extraAttrs:
                self._requireNotNone(tag, name)
        return tag

    def _validatePpron3(self, tag):
        """Check accentability and post-prepositionality of a ppron3 tag."""
        if tag.getPos() not in ['ppron3', 'Ppron3']:
            return tag

        # number/case/gender combinations for which any accentability value
        # is accepted, as long as one is present
        anyAccentability = [['sg', case, gender]
                            for gender in ('m1', 'm2', 'm3')
                            for case in ('gen', 'dat', 'acc')]
        anyAccentability += [['sg', 'gen', 'n'], ['sg', 'dat', 'n']]

        if tag.getAttrs(['number', 'case', 'gender']) in anyAccentability:
            self._requireNotNone(tag, 'accentability')
        else:
            self._require(tag, 'accentability', 'akc')

        if tag.getAttr('case') == 'loc':
            self._require(tag, 'post-prepositionality', 'praep')
        return tag

    def _validateVerbfin(self, tag):
        """Gender is required, optional or forbidden depending on tense and mood."""
        if tag.getPos() != 'Verbfin':
            return tag
        tenseMood = self._getAttrs(tag, ['tense', 'mood'])
        if tenseMood in [['past', 'ind'], ['fut', 'cond']]:
            self._requireNotNone(tag, 'gender')
        elif tenseMood not in [['pres', 'cond'], ['fut', 'ind']]:
            self._requireNone(tag, 'gender')
        return tag

    def _validateNum(self, tag):
        """Clear degree (in place) on a Num with degree 'pos' and accommodability 'rec'."""
        if tag.getPos() != 'Num':
            return tag
        if self._getAttrs(tag, ['degree', 'accommodability']) == ['pos', 'rec']:
            tag.setAttr('degree', None)
        return tag

    def validate(self, tag):
        """Run every check in turn and return the resulting tag.

        Raises TagException from the first failing check, and AssertionError
        when the result turns out to be None.
        """
        checks = (self._validateNoun,
                  self._validateVerbfin,
                  self._validatePpron3,
                  self._validateNum)
        current = tag
        for check in checks:
            current = check(current)
        if current is None:
            raise AssertionError()
        return current
    


class TagValidator:
    """Validate (base, ctag, msd) triples against a tagset config file.

    The config file has to contain the section headers [ATTR], [POS] and
    [RULES], each exactly once and in this order. The first two sections
    define the Tagset, the third one a list of Rule objects.
    """

    # Section headers, in the order they have to appear in the config file.
    _SECTION_HEADERS = ['[ATTR]', '[POS]', '[RULES]']

    def __init__(self, cfgFilename, validateAfter=True):
        """Load tagset and rules from `cfgFilename`.

        With `validateAfter` false, the ValidatorAfter checks are skipped in
        validateTag. A pre-validation step (validatorBefore) is currently
        disabled.
        """
        self.tagset, self.rules = self._parseTagsetAndRules(cfgFilename)
        self.validatorAfter = ValidatorAfter() if validateAfter else None

    def _splitSections(self, cfgStr):
        """Return the (attrStr, posStr, rulesStr) section bodies of `cfgStr`.

        Text in front of the first header is ignored. Raises ValueError when
        the headers are not exactly [ATTR], [POS], [RULES] in that order.
        """
        # The capturing group makes split() keep the headers, so the result
        # is [preamble, header, body, header, body, header, body].
        parts = re.split(r'(\[ATTR\]|\[POS\]|\[RULES\])', cfgStr)
        headers = parts[1::2]
        if headers != self._SECTION_HEADERS:
            raise ValueError(
                'Expected sections [ATTR], [POS], [RULES] exactly once and '
                'in this order, found: %s' % (headers, ))
        return parts[2], parts[4], parts[6]

    def _parseTagsetAndRules(self, fname):
        """Read config file `fname` and return the pair (tagset, rules)."""
        attrStr, posStr, rulesStr = self._splitSections(_readFile(fname))
        tagset = Tagset(attrStr, posStr)
        rules = self._parseRules(tagset, rulesStr)
        return (tagset, rules)

    def _parseRules(self, tagset, string):
        """Parse rows `pos --> attr1:attr2 IN v1:v2 v3:v4 ...` into Rule objects.

        '=' is accepted in place of 'IN'. Each Rule is checked against
        `tagset` when it is created.
        """
        ruleRow = ENDLS + Token + ARROW + Group(Attrs) + IN + Group(OneOrMore(Group(Attrs))) + ENDLS
        rules = ZeroOrMore(Group(ruleRow))
        rules.ignore(COMMENT)
        return [Rule(tagset, pos, tuple(attrs), [tuple(val) for val in vals])
                for pos, attrs, vals in rules.parseString(string)]

    def validateTag(self, base, ctag, msd):
        """Validate one tag and return (base, ctag, msd string of the validated tag).

        The tag is built by the tagset, passed through every rule and then,
        if enabled, through the ValidatorAfter checks. Exceptions raised by
        any of these steps are propagated.
        """
        tag = self.tagset.getTag(base, ctag, msd)
        for rule in self.rules:
            tag = rule.validate(tag)
        if self.validatorAfter is not None:
            tag = self.validatorAfter.validate(tag)
        newMsd = tag.getMsdStr()
        return (base, ctag, newMsd)
#~ 
#~ validateTag('będzie', 'bedzie', 'pl:pri:imperf')
#~ validateTag('siebie', 'siebie', 'gen')

