'''
Created on 26-10-2011

@author: lennyn
'''
import abc
from utils import *

# Element factories used by Converter to build the output tree.
# NOTE(review): ElementMaker, teins and xins all arrive through the star
# import of utils; presumably teins is the TEI namespace URI and xins the
# XInclude one (xiE is only used here for `include` elements) -- confirm.
E = ElementMaker(
    namespace=teins,
    nsmap={None: teins, 'xi': xins}
)
xiE = ElementMaker(
    namespace=xins,
    nsmap={None: xins}
)

class Converter(object):
    """Abstract base for converting PML sentence files into one TEI layer.

    A converter reads the paragraph, sentence and segment ids listed in
    ``ann_morphosyntax.xml`` (found in ``wypluwka_path``), then walks the
    given PML files and emits, for every paragraph known to the
    morphosyntax layer, a ``<p>`` element filled with converted sentences.

    Subclasses must provide:

    * ``out_path`` -- where the resulting tree is written,
    * ``what`` -- the layer name; ``_convert_par`` accepts ``'named'``,
      ``'words'`` and ``'groups'``,
    * ``convert_sent(pmlsent)`` -- returns the output node for one sentence.
    """

    __metaclass__ = abc.ABCMeta

    def __init__(self, pmlpaths, wypluwka_path):
        """Remember the inputs and index the morphosyntax layer.

        pmlpaths -- iterable of PML file paths to convert (may be empty).
        wypluwka_path -- directory containing ``ann_morphosyntax.xml``.
        """
        self.pmlpaths = pmlpaths
        self.wypluwka_path = wypluwka_path
        self.par_ids, self.sent_ids = self._get_parids_sentids()
        self.segid2index = self._get_segid2index()

    @abc.abstractproperty
    def out_path(self):
        """Path of the file written by ``convert``."""
        return

    @abc.abstractproperty
    def what(self):
        """Name of the layer produced by this converter."""
        return

    @abc.abstractmethod
    def convert_sent(self, pmlsent):
        """Return the output node for a single PML sentence element."""
        return

    def _morphosyntax_path(self):
        """Return the path of ``ann_morphosyntax.xml`` in ``wypluwka_path``."""
        return os.path.join(self.wypluwka_path, 'ann_morphosyntax.xml')

    @staticmethod
    def _xml_id(line):
        """Return the value of the first ``xml:id="..."`` found in ``line``.

        Yields an empty string when the line carries no such attribute.
        """
        return line.partition('xml:id="')[2].split('"')[0]

    def _get_parids_sentids(self):
        """Collect paragraph and sentence ids from the morphosyntax file.

        The file is scanned line by line (no XML parsing): a stripped line
        starting with ``<p `` contributes a paragraph id, one starting with
        ``<s `` a sentence id.

        Returns a ``(paragraph_ids, sentence_ids)`` pair of sets.
        """
        parids = set()
        sentids = set()
        with open(self._morphosyntax_path(), 'r') as f:
            for line in f:
                stripped = line.strip()
                if stripped.startswith('<p '):
                    parids.add(self._xml_id(stripped))
                elif stripped.startswith('<s '):
                    sentids.add(self._xml_id(stripped))
        return (parids, sentids)

    def _get_segid2index(self):
        """Map every segment id to its position in the morphosyntax file.

        Positions count ``<seg `` lines from zero in file order. Should an
        id occur more than once, the last position wins.
        """
        res = {}
        index = 0
        with open(self._morphosyntax_path(), 'r') as f:
            for line in f:
                stripped = line.strip()
                if stripped.startswith('<seg '):
                    res[self._xml_id(stripped)] = index
                    index += 1
        return res

    def convert(self):
        """Convert all PML files and write the result to ``out_path``.

        When ``pmlpaths`` is empty nothing is written and a ``skipping``
        notice is printed instead.
        """
        if not self.pmlpaths:
            print('skipping %s' % (self.out_path,))
            return

        root = E.teiCorpus(
            xiE.include(href='NKJP_1M_header.xml'),
                E.TEI(
                    xiE.include(href='header.xml'),
                    E.text(
                        E.body(),
                        **{lxml_name(xmlns, 'lang') : 'pl'}
                    ))
            )
        # The body node never changes, so look it up once rather than per
        # paragraph.
        body = xpath(root, '//tei:body')[0]
        for path in self.pmlpaths:
            print('%s ===> %s' % (path, self.out_path))
            tree = etree.parse(path, etree.XMLParser(recover=False))
            # Keep only paragraphs known to the morphosyntax layer, each
            # once, in order of first appearance.
            known_pids = [pid for pid in xpath(tree, '//pml:sent/@pid')
                          if pid in self.par_ids]
            for pid in self._unique_list(known_pids):
                body.append(self._convert_par(tree, pid))

        write_tree(root, self.out_path)

    def _get_corresp_pid(self, pid, corresp_what):
        """Return the id of the paragraph ``pid`` in layer ``corresp_what``.

        The part of ``pid`` up to the first underscore is replaced by
        ``morph`` or ``words``. Returns None for any other layer name.
        """
        if corresp_what == 'morphosyntax':
            return 'morph_' + pid.partition('_')[2]
        elif corresp_what == 'words':
            return 'words_' + pid.partition('_')[2]

    def _convert_par(self, tree, pid):
        """Build the ``<p>`` node for paragraph ``pid`` found in ``tree``.

        Only sentences whose id appears in the morphosyntax layer are
        converted. Raises KeyError if ``self.what`` is not one of
        ``'named'``, ``'words'`` or ``'groups'``.
        """
        # Layer that the paragraph's ``corresp`` attribute points to.
        corresp_what = {'named' : 'morphosyntax', 'words' : 'morphosyntax', 'groups' : 'words'}[self.what]
        corresp_pid = self._get_corresp_pid(pid, corresp_what)
        par_attrs = {
            lxml_name(xmlns, 'id') : morph2id(pid, self.what),
            'corresp' : 'ann_%s.xml#%s' % (corresp_what, corresp_pid)
        }
        par_node = E.p(**par_attrs)
        for pmlsent in xpath(tree, '//pml:sent[@pid="%s"]' % (pid,)):
            if pmlsent.attrib['id'] in self.sent_ids:
                par_node.append(self.convert_sent(pmlsent))
        return par_node

    def _unique_list(self, l):
        """Return the items of ``l`` without duplicates, order preserved."""
        seen = set()
        res = []
        for elem in l:
            if elem not in seen:
                res.append(elem)
                seen.add(elem)
        return res

    def get_ord(self, node):
        """Return the integer ``ord`` of ``node``.

        When the node has no ``ord`` attribute, the value is taken from a
        preceding sibling that has one, or 0 if there is none.
        """
        ord_attr = node.attrib.get('ord', None)
        if ord_attr is not None:
            return int(ord_attr)
        ords = xpath(node, 'preceding-sibling::*[@ord]/@ord')
        if not ords:
            return 0
        # NOTE(review): if xpath() returns results in document order, this
        # picks the earliest preceding sibling with @ord, not the nearest
        # one -- confirm which is intended.
        return int(ords[0])

    def get_pml_tag_sort_key(self, child):
        """Return the sort key used to order children of a PML node.

        The key is ``(segment positions, -descendant count, ord)``. The
        position tuple always ends with infinity; a plain segment uses
        minus infinity as its second component.
        """
        if child.tag.endswith('seg'):
            idxs = (self.segid2index[child.attrib['id']], float('inf'))
            # Use the method: a bare get_ord() is not defined in this class.
            return (idxs, float('-inf'), self.get_ord(child))
        else:
            idxs = [self.segid2index[seg.attrib['id']] for seg in xpath(child, './/pml:seg')]
            idxs.append(float('inf'))
            descendants_num = len(xpath(child, './/*[self::pml:seg|self::pml:ne|self::pml:pw|self::pml:pg]'))
            return (tuple(sorted(idxs)), -descendants_num, self.get_ord(child))

    def get_foreign_edges_map(self, pmlsent):
        """Map referenced ids to the segments of ``pmlsent`` that cite them.

        For each segment, the ``pml:LM`` texts under its ``ref``/``rf``
        elements are used; if there are none, the first direct text of a
        ``ref``/``rf`` element is used instead.

        Returns a dict from id string to a list of segment elements.
        """
        res = {}
        for pmlseg in xpath(pmlsent, './/pml:seg'):
            lms = xpath(pmlseg, './/*[local-name()="ref" or local-name()="rf"]/pml:LM/text()')
            if len(lms) > 0:
                # The '' + concatenation (marked XXX originally) presumably
                # turns the xpath string results into plain strings.
                for neid in ['' + lm for lm in lms]:
                    res.setdefault(neid, []).append(pmlseg)
            else:
                ref = xpath(pmlseg, './/*[local-name()="ref" or local-name()="rf"]/text()')
                if ref:
                    neid = '' + ref[0]
                    res.setdefault(neid, []).append(pmlseg)
        return res

    def get_child_ptrs(self, pmlne, foreign_edges):
        """Return ``<ptr>`` nodes for the children of ``pmlne``.

        pmlne -- PML node whose ``pml:children`` are pointed to.
        foreign_edges -- list of extra elements to point to as well.

        Children and foreign edges are ordered together by
        ``get_pml_tag_sort_key``. Segments point into
        ``ann_morphosyntax.xml``; anything else gets an id from
        ``morph2id``.
        """
        res = []
        children = xpath(pmlne, 'pml:children/*')
        for child in sorted(children + foreign_edges, key=self.get_pml_tag_sort_key):
            child_id = child.attrib['id']
            if child.tag.endswith('seg'):
                target = 'ann_morphosyntax.xml#' + child_id
            else:
                target = morph2id(child_id, self.what)
            res.append(E.ptr(target=target))

        return res
