import morfeusz2
import distance
from spacy.tokens.token import Token


class Flexer(object):
  """Pipeline component that inflects and lemmatizes Polish words and phrases.

  Calling the component on a document is a no-op; the functionality is exposed
  through `flex`, `flex_mwe` and `lemmatize_mwe`. Word forms are produced by
  the Morfeusz generator and picked according to the requested tag features.
  """
  name = "flexer"

  def __init__(self, nlp):
    """Store the pipeline and make sure a generating Morfeusz instance exists.

    Args:
      nlp: the language pipeline; its tokenizer is expected to expose a
        Morfeusz object as `tokenizer.morf`.
    """
    self.nlp = nlp
    try:
      # Probe the generator; a RuntimeError means it is not available.
      self.nlp.tokenizer.morf.generate("")
    except RuntimeError:
      # morfeusz does not have the generate dictionary loaded
      self.nlp.tokenizer.morf = morfeusz2.Morfeusz(expand_tags=True, whitespace=morfeusz2.KEEP_WHITESPACES, generate=True)
    self.morf = self.nlp.tokenizer.morf

    # Maps a tag value (e.g. "gen") to the grammatical attribute it belongs to.
    self.val2attr = {'pri': 'person', 'sec': 'person', 'ter': 'person',
                     'nom': 'case', 'acc': 'case', 'dat': 'case', 'gen': 'case', 'inst': 'case', 'loc': 'case', 'voc': 'case',
                     'f': 'gender', 'm1': 'gender', 'm2': 'gender', 'm3': 'gender', 'n': 'gender',
                     'aff': 'negation', 'neg': 'negation',
                     'sg': 'number', 'pl': 'number',
                     'pos': 'degree', 'com': 'degree', 'sup': 'degree',
                     'perf': 'aspect', 'imperf': 'aspect',
                     'npraep': 'prepositionality', 'praep': 'prepositionality',
                     'congr': 'accomodability', 'rec': 'accomodability',
                     'akc': 'accentibility', 'nakc': 'accentibility',
                     'agl': 'agglutination', 'nagl': 'agglutination',
                     'nwok': 'vocality', 'wok': 'vocality',
                     'npun': 'fullstoppedness', 'pun': 'fullstoppedness',
                     'col': 'collectivity', 'ncol': 'collectivity'
                    }

    # For each dependency relation: the attributes a child must agree on with
    # its head. Relations missing from this table pass no features down.
    # "nummod:gov" has no entry here: it is listed in governing_deprels, and
    # governing children are never looked up in this table.
    self.accomodation_rules = {"amod": ["number", "case", "gender"],
                               "amod:flat": ['number', 'gender', 'case'],
                               "acl": ["number", "case", "gender"],
                               "appos": ["number", "case"],
                               "aux": ['number', 'gender', 'aspect'],
                               "aux:clitic": ['number'],
                               "aux:pass": ['number'],
                               "cop": ['number'],
                               "det:numgov": ["gender"],
                               "det": ["number", "case", "gender"],
                               "det:nummod": ["number", "case", "gender"],
                               "det:poss": ["number", "case", "gender"],
                               "conj": ['case', "number", 'collectivity', 'degree', 'negation'],
                               "nummod": ["case"],
                               "nsubj:pass": ['number'],
                               "fixed": ["case"],
                               "flat": ['number', 'case'],
                              }
    # Relations whose child takes over the inflection from its head: the child
    # receives the pattern and the head keeps its surface form.
    self.governing_deprels = ["detmod:numgov", "det:numgov", "nummod:gov"]

    self.FLEX_ERROR_MSG = "This method requires passing either a string or a spacy.tokens.token.Token (where the token corresponds to the head of the phrase) argument!"
    self.LEM_ERROR_MSG = "This method requires passing either a string or a spacy.tokens.token.Token (where the token corresponds to the head of the phrase) argument!"

  def __call__(self, doc):
    """Return `doc` untouched; the work is done by the flex/lemmatize methods."""
    return doc

  @staticmethod
  def _copy_casing(template, word):
    """Return `word` recased like `template` (upper, lower or title case).

    A template with any other casing leaves `word` unchanged.
    """
    if template.isupper():
      return word.upper()
    if template.islower():
      return word.lower()
    if template.istitle():
      return word.capitalize()
    return word

  def _resolve_head(self, phrase, error_msg):
    """Return the head token of `phrase`, or None when no ROOT token is found.

    A Token is returned as is. A string is parsed with the pipeline and its
    first token labelled "ROOT" is returned.

    Raises:
      ValueError: with `error_msg`, if `phrase` is neither a Token nor a str.
    """
    if isinstance(phrase, Token):
      return phrase
    if not isinstance(phrase, str):
      raise ValueError(error_msg)
    for tok in self.nlp(phrase):
      if tok.dep_ == "ROOT":
        return tok
    return None

  def flex(self, to_flex, pattern):
    """Inflect a single word so that its tag contains every feature of `pattern`.

    Args:
      to_flex: a Token, or a string that the pipeline tokenizes into exactly
        one token.
      pattern: ":"-separated list of desired tag features; an empty string or
        None returns the word unchanged.

    Returns:
      The inflected form, recased like the original word. When the generator
      offers no form carrying all requested features, the original text is
      returned.

    Raises:
      ValueError: if `to_flex` is neither a Token nor a single-token string.
    """
    if isinstance(to_flex, Token):
      token = to_flex
    elif isinstance(to_flex, str):
      to_flex_doc = self.nlp(to_flex)
      if len(to_flex_doc) != 1:
        raise ValueError("This method requires passing a single token, or a string corresponding to a single token!")
      token = to_flex_doc[0]
    else:
      raise ValueError("This method requires passing either a string representing the word, or a spacy.tokens.token.Token as an argument!")

    token_string = token.orth_
    if not pattern:
      return token_string

    split_pattern = pattern.split(":")

    # Full tag of the present form: the POS tag followed by its features.
    tag = token.tag_
    feats = token._.feats
    if feats:
      tag += ":" + feats
    split_tag = tag.split(":")

    # Keep only the generated forms whose tag satisfies the whole pattern.
    # Index 0 of a generated entry is the form, index 2 its ":"-separated tag.
    candidates = []
    for generated in self.morf.generate(token.lemma_):
      generated_tag = generated[2].split(":")
      if all(f in generated_tag for f in split_pattern):
        candidates.append((generated[0], generated_tag))

    if not candidates:
      return token_string

    # Among the matching forms choose the one whose tag is Levenshtein-nearest
    # to the present token's tag; the first such form wins on ties.
    best_form, _ = min(candidates, key=lambda c: distance.levenshtein(split_tag, c[1]))
    return self._copy_casing(token_string, best_form)

  def flex_subtree(self, token, pattern):
    """Inflect `token` and, recursively, the children that must agree with it.

    Args:
      token: head of the subtree to inflect.
      pattern: ":"-separated list of desired features; None is treated as "".

    Returns:
      A dict mapping token index to the token's new text with its trailing
      whitespace appended.
    """
    if pattern is None:
      # `flex` accepts None, so accept it here too instead of failing on split.
      pattern = ""

    id_to_inflected = {}
    children = list(token.children)
    governing_children = [child for child in children if child.dep_ in self.governing_deprels]
    children_to_inflect = [child for child in children if child.dep_ not in self.governing_deprels]

    if governing_children:
      # The governing child takes the pattern; the head keeps its form.
      # NOTE(review): only the first governing child is processed; any further
      # ones are left out of the result.
      inflected_token = token.orth_ + token.whitespace_
      id_to_inflected.update(self.flex_subtree(governing_children[0], pattern))
    else:
      inflected_token = self.flex(token, pattern) + token.whitespace_
    id_to_inflected[token.i] = inflected_token

    # Limit the pattern to supported features once, not once per child.
    supported_feats = [f for f in pattern.split(":") if f in self.val2attr]
    for child in children_to_inflect:
      accomodable_attrs = self.accomodation_rules.get(child.dep_, [])
      child_pattern = ":".join(f for f in supported_feats if self.val2attr[f] in accomodable_attrs)
      id_to_inflected.update(self.flex_subtree(child, child_pattern))
    return id_to_inflected

  def flex_mwe(self, to_flex, pattern):
    """Inflect a multi-word expression into `pattern`.

    The algorithm is rule based: it recursively visits each child and inflects
    it when the relation connecting it to its head requires grammatical
    agreement.

    Args:
      to_flex: the head Token of the phrase, or the phrase as a string.
      pattern: ":"-separated list of desired features.

    Returns:
      The inflected phrase, or "" when a string input yields no ROOT token.

    Raises:
      ValueError: if `to_flex` is neither a Token nor a str.
    """
    head = self._resolve_head(to_flex, self.FLEX_ERROR_MSG)
    if head is None:
      return ""
    id_to_flexed = self.flex_subtree(head, pattern)
    return "".join(text for _, text in sorted(id_to_flexed.items())).strip()

  def lemmatize_subtree(self, token):
    """Bring the subtree of `token` to the base form of its head.

    The algorithm is rule based: the head is replaced by its lemma and each
    child is inflected into the pattern corresponding to that base form.

    Returns:
      A dict mapping token index to the token's new text with its trailing
      whitespace appended.
    """
    id_to_lemmatized = {}
    children = list(token.children)
    governing_children = [child for child in children if child.dep_ in self.governing_deprels]
    children_to_lemmatize = [child for child in children if child.dep_ not in self.governing_deprels]

    if governing_children:
      # The governing child is lemmatized; the head keeps its surface form.
      # NOTE(review): only the first governing child is processed.
      id_to_lemmatized.update(self.lemmatize_subtree(governing_children[0]))
      # Keep the trailing whitespace so the head is not glued to the next word.
      id_to_lemmatized[token.i] = token.orth_ + token.whitespace_
      pattern = ""
    else:
      lemmatized_token = token.lemma_ + token.whitespace_
      # The pattern for the children is the feature set of the lemma itself.
      lemma_doc = self.nlp(lemmatized_token)
      pattern = lemma_doc[0]._.feats if len(lemma_doc) > 0 else ""
      id_to_lemmatized[token.i] = lemmatized_token

    for child in children_to_lemmatize:
      if child.dep_ in self.accomodation_rules:
        # NOTE(review): the child gets the full lemma pattern, not one limited
        # to the attributes its relation agrees on; confirm this is intended.
        lemmatized_subtree = self.flex_subtree(child, pattern)
      else:
        # No agreement rule: copy the child's subtree verbatim.
        lemmatized_subtree = {tok.i: tok.orth_ + tok.whitespace_ for tok in child.subtree}
      id_to_lemmatized.update(lemmatized_subtree)
    return id_to_lemmatized

  def lemmatize_mwe(self, to_lemmatize):
    """Lemmatize a multi-word expression.

    Args:
      to_lemmatize: the head Token of the phrase, or the phrase as a string.

    Returns:
      The lemmatized phrase, or "" when a string input yields no ROOT token.

    Raises:
      ValueError: if `to_lemmatize` is neither a Token nor a str.
    """
    head = self._resolve_head(to_lemmatize, self.LEM_ERROR_MSG)
    if head is None:
      return ""
    id_to_lemmatized = self.lemmatize_subtree(head)
    return "".join(text for _, text in sorted(id_to_lemmatized.items())).strip()


