#-*- coding: utf-8 -*-
import os
import sys
import logging
import re
import subprocess
from lxml import etree
from optparse import OptionParser
import tagset

# Prefix -> namespace-URI map passed to every XPath evaluation in this module.
namespaces = dict(
	tei='http://www.tei-c.org/ns/1.0',
	xml='http://www.w3.org/XML/1998/namespace',
	nkjp='http://www.nkjp.pl/ns/1.0',
)

def parseOptions():
	"""
	Parse command-line arguments.

	Returns the optparse values object with attributes ``dirname``,
	``correct``, ``text_file`` and ``ignore_tagset``. When ``--dirname`` is
	missing, prints the usage help and exits with status 1.
	"""
	parser = OptionParser()
	parser.add_option('--dirname',
						dest='dirname',
						metavar='DIR',
						help='directory to validate')
	parser.add_option('--correct',
						dest='correct',
						action='store_true',
						default=False,
						help='try to correct errors')
	parser.add_option('--text-file',
						dest='text_file',
						default='text.xml',
						help='text layer filename (default: text.xml)')
	parser.add_option('--ignore-tagset',
						dest='ignore_tagset',
						action='store_true',
						default=False,
						help='do not validate tags against tagset')
	opts, _args = parser.parse_args()

	if opts.dirname is None:
		parser.print_help()
		# sys.exit rather than the site-injected exit() builtin, which is
		# unavailable under `python -S` and in frozen executables.
		sys.exit(1)

	return opts

def getDirs(wypluwkaRoot):
	"""
	Walk the tree rooted at *wypluwkaRoot* and lazily yield the path of every
	directory that directly contains an ``ann_morphosyntax.xml`` file.
	"""
	marker = 'ann_morphosyntax.xml'
	for current, _subdirs, names in os.walk(wypluwkaRoot):
		if marker not in names:
			continue
		yield current

def XPath(qstr):
	"""Compile *qstr* into an lxml XPath evaluator bound to the module's namespace prefixes."""
	compiled = etree.XPath(qstr, namespaces=namespaces)
	return compiled

# Precompiled XPath evaluators (built with the XPath() helper) shared by the
# validation checks; each is called with a context node or tree.
# @xml:id attribute of the context node.
xpathGetId = XPath('@xml:id')
# Any element in the document whose @xml:id equals the $id variable;
# called as xpathNode4Id(tree, id=...).
xpathNode4Id = XPath('//*[@xml:id = $id]')
# Descendant "orth" feature elements.
xpathGetOrth = XPath('.//tei:f[@name="orth"]')
# Text nodes of the string values of descendant "orth" features.
xpathGetOrthText = XPath('.//tei:f[@name="orth"]/tei:string/text()')
# "base" features whose string is empty, restricted to lex feature structures
# whose ctag symbol value is not 'ign'.
xpathGetEmptyBases = XPath(".//tei:fs[@type='lex'][tei:f[@name='ctag']/tei:symbol/@value != 'ign']/tei:f[@name='base'][string-length(tei:string/text()) = 0]")
# Text of the child "base" feature's string (context: a tei:fs).
xpathGetBasesTxt = XPath("tei:f[@name='base']/tei:string/text()")
# @value of the child "ctag" feature's symbol (context: a tei:fs).
xpathGetCtagsTxt = XPath("tei:f[@name='ctag']/tei:symbol/@value")
# All symbols below the child "msd" feature (context: a tei:fs).
xpathGetMsds = XPath("tei:f[@name='msd']//tei:symbol")
# Descendant feature structures of type "lex" or "words".
xpathGetLexOrWordFss = XPath(".//tei:fs[@type='lex' or @type='words']")
# Text of descendant "interpretation" features. NOTE(review): only referenced
# from commented-out code in the visible file — confirm before removing.
xpathGetInterp = XPath(".//tei:f[@name='interpretation']/tei:string/text()")


def xpath(node, qstr):
	"""Evaluate the XPath string *qstr* against *node* using the module's namespace prefixes."""
	result = node.xpath(qstr, namespaces=namespaces)
	return result

def getFirstText(strlist):
	"""
	Return the first element of an XPath string result as a plain string, or
	'' when the result is empty.
	"""
	if strlist:
		# Concatenating with '' turns lxml "smart string" results into plain strings.
		return '' + strlist[0]
	return ''

def getId(node):
	"""Return the @xml:id of *node*; raises IndexError if the node has none."""
	ids = xpathGetId(node)
	return ids[0]

class RNGHelper(object):
	"""
	Holds one compiled RELAX NG schema per known corpus filename and validates
	parsed trees against the schema registered for their filename.
	"""

	# Directory with the .rng schema files, located next to this script.
	rngRoot = os.path.join(os.path.dirname(__file__), 'rng')

	# Corpus XML filename -> schema file (relative to rngRoot) that governs it.
	filename2RNG = {
		'text.xml': 'NKJP_text.rng',
		'text_structure.xml': 'NKJP_structure.rng',
		'ann_segmentation.xml': 'NKJP_segmentation.rng',
		'ann_morphosyntax.xml': 'NKJP_morphosyntax.rng',
		'ann_senses.xml': 'NKJP_senses.rng',
		'ann_words.xml': 'NKJP_words.rng',
		'ann_groups.xml': 'NKJP_groups.rng',
		'ann_named.xml': 'NKJP_named.rng',
	}

	def __init__(self):
		# Compile every schema up front so each validate() call is a lookup.
		compiled = {}
		for xmlName in RNGHelper.filename2RNG:
			schemaPath = os.path.join(RNGHelper.rngRoot, RNGHelper.filename2RNG[xmlName])
			compiled[xmlName] = etree.RelaxNG(etree.parse(schemaPath))
		self._fname2RNG = compiled

	def validate(self, fname, tree):
		"""
		Validate *tree* against the schema registered for *fname*.

		Raises KeyError for an unregistered filename; otherwise delegates to
		the schema's assertValid(), which raises when *tree* is invalid.
		"""
		self._fname2RNG[fname].assertValid(tree)

class DirValidator(object):
	"""
	Validate a single corpus directory of NKJP TEI annotation files.

	The constructor parses, XInclude-expands and RELAX NG-validates every
	``*.xml`` file in ``opts.dirname`` except ``*header.xml``. ``validate()``
	then runs cross-file consistency checks, reporting each problem through
	``logging.error`` as ``<file>#<xml:id> --- <message>``. With
	``opts.correct`` set, msd corrections found in ``ann_words.xml`` are
	applied to the in-memory tree and the corrected files are written back.
	"""

	def __init__(self, opts, rngHelper):
		"""
		:param opts: parsed options; reads ``correct``, ``dirname``,
			``text_file`` and ``ignore_tagset``.
		:param rngHelper: object with ``validate(fname, tree)`` that raises
			when a tree violates its schema (e.g. ``RNGHelper``).
		:raises: any parse/XInclude/validation error, after logging the name
			of the offending file with ``logging.fatal``.
		"""
		self.correct = opts.correct
		self.dirname = opts.dirname
		# Files whose trees were modified by --correct; validate() writes
		# back only these.
		self._correctedFiles = set()
		xmlFiles = [name for name in os.listdir(opts.dirname) if self._shouldConsiderName(name)]
		self.treesMap = {}
		for f in xmlFiles:
			try:
				# base_url is presumably there so relative XInclude hrefs
				# resolve against the corpus directory.
				tree = etree.parse(os.path.join(opts.dirname, f), base_url=opts.dirname+'/')
				tree.xinclude()
			except Exception:
				logging.fatal('%s: FAILED TO PARSE' % f)
				raise
			try:
				rngHelper.validate(f, tree)
			except Exception:
				logging.fatal('%s: FAILED TO VALIDATE' % f)
				raise
			self.treesMap[f] = tree
		self.textFilename = opts.text_file
		self.ignoreTagset = opts.ignore_tagset
		if not self.ignoreTagset:
			self.morphValidator = tagset.TagValidator('morphosyntax.cfg')
			self.wordsValidator = tagset.TagValidator('words.cfg')
		else:
			self.morphValidator = None
			self.wordsValidator = None

	def _shouldConsiderName(self, name):
		"""Return True for ``*.xml`` files other than ``*header.xml``."""
		return name.endswith('.xml') and not name.endswith('header.xml')

	def _getFNames(self, what):
		"""Yield the loaded filenames of the form ``ann_<what>*.xml``."""
		for fname in self.treesMap:
			if fname.startswith('ann_'+what) and fname.endswith('.xml'):
				yield fname

	def _checkCorresp(self, tree2ids, node, fname, corresp):
		"""
		Log an error when the pointer *corresp* (``[file]#id``) found in
		*fname* names a file that was not loaded, or an @xml:id that does not
		exist in the target file (own file when no file part is given).

		The id part of ``string-range(...)`` pointers is not checked here.
		*tree2ids* maps filename -> set of its @xml:id values; *node* is
		unused but kept for call compatibility.
		"""
		correspFile, _, correspId = corresp.rpartition('#')
		if correspFile != '' and correspFile not in self.treesMap:
			logging.error('%s --- invalid corresp: %s - no such file: %s' % (fname, corresp, correspFile))
		elif 'string-range' not in correspId:
			if correspFile == '':
				correspFile = fname
			if correspId not in tree2ids[correspFile]:
				logging.error('%s --- invalid corresp: %s - no such @xml:id in file %s: %s' % (fname, corresp, correspFile, correspId))

	def _checkCorresps(self):
		"""Check every @corresp and every body ptr/@target (plus header ptrs of the text files)."""
		tree2ids = {}
		for fname, tree in self.treesMap.items():
			tree2ids[fname] = set(xpath(tree, '//@xml:id'))

		for fname, tree in self.treesMap.items():
			for node in xpath(tree, '//*[@corresp]'):
				self._checkCorresp(tree2ids, node, fname, node.attrib['corresp'])
			for node in xpath(tree, '//tei:body//tei:ptr'):
				self._checkCorresp(tree2ids, node, fname, node.attrib['target'])
			if fname in ['text.xml', 'text_structure.xml']:
				for node in xpath(tree, '//tei:teiHeader//tei:ptr'):
					self._checkCorresp(tree2ids, node, fname, node.attrib['target'])

	def _checkForEmptyOrths(self, fname):
		"""Log every tei:seg of *fname* whose orth string is empty or missing; no-op if *fname* is not loaded."""
		tree = self.treesMap.get(fname, None)
		if tree is None:
			return
		nodes = xpath(tree, "//tei:seg[string-length(.//tei:f[@name='orth']/tei:string/text()) = 0]")
		for node in nodes:
			logging.error('%s#%s --- empty orth' % (fname, getId(node)))

	def _checkForEmptyLemmas(self, fname):
		"""
		Log one error per empty base found in a segment of *fname*, ignoring
		lex structures whose ctag is 'ign'. No-op if *fname* is not loaded.
		"""
		tree = self.treesMap.get(fname, None)
		if tree is None:
			return
		segs = xpath(tree, "//tei:seg[.//tei:f[@name='base'][string-length(tei:string/text()) = 0]]")
		for seg in segs:
			segid = getId(seg)
			for _baseNode in xpathGetEmptyBases(seg):
				logging.error('%s#%s --- empty base' % (fname, segid))

	def _checkInterpretations(self, validator, fname, correct):
		"""
		Validate every (base, ctag, msd) triple of the lex/words feature
		structures in *fname* with *validator* and log mismatches.

		When *correct* is true, only the msd value is rewritten in the tree
		(base/ctag differences are reported but not applied) and *fname* is
		recorded for write-back. No-op with --ignore-tagset or if *fname* is
		not loaded.
		"""
		if self.ignoreTagset:
			return
		tree = self.treesMap.get(fname, None)
		if tree is None:
			return

		path = fname
		for segnode in xpath(tree, "//tei:seg"):
			for fsnode in xpathGetLexOrWordFss(segnode):
				base = getFirstText(xpathGetBasesTxt(fsnode))
				ctag = getFirstText(xpathGetCtagsTxt(fsnode))
				for msdnode in xpathGetMsds(fsnode):
					msd = msdnode.attrib['value']
					try:
						newBase, newCTag, newMsd = validator.validateTag(base, ctag, msd)
					except tagset.TagException as err:
						logging.error('%s#%s --- %s' % (path, getId(segnode), err.value))
						continue
					if (newBase, newCTag, newMsd) != (base, ctag, msd):
						if correct:
							msdnode.attrib['value'] = newMsd
							self._correctedFiles.add(fname)
						logging.error(
								'%s#%s --- was: "%s:%s:%s" but should be: "%s:%s:%s"'
								% (path, getId(segnode), base, ctag, msd, newBase, newCTag, newMsd))

	def _checkMorphInterps(self, validator):
		"""
		Validate, with *validator*, the interpretation chosen for each segment
		of ann_morphosyntax.xml: the msd symbol referenced by the segment's
		"choice" feature plus the ctag and base of its enclosing lex
		structure. Also logs empty orths and bases. No-op with --ignore-tagset.
		"""
		if self.ignoreTagset:
			return
		path = 'ann_morphosyntax.xml'
		tree = self.treesMap[path]
		for segnode in xpath(tree, "//tei:seg"):
			# getFirstText yields '' for an empty/missing value, so the empty
			# checks below are reachable instead of raising IndexError.
			orth = getFirstText(xpathGetOrthText(segnode))
			# @fVal holds a pointer to the chosen msd symbol; the leading
			# character (presumably '#') is dropped to get the bare id.
			msdid = xpath(segnode, './/tei:f[@name="choice"]/@fVal')[0][1:]
			msdnode = xpath(segnode, './/tei:f[@name="msd"]//tei:symbol[@xml:id="%s"]' % msdid)[0]
			msd = msdnode.attrib['value']
			ctag = getFirstText(xpath(msdnode, 'ancestor::tei:fs[@type="lex"]/tei:f[@name="ctag"]/tei:symbol/@value'))
			base = getFirstText(xpath(msdnode, 'ancestor::tei:fs[@type="lex"]/tei:f[@name="base"]/tei:string/text()'))
			if orth == '':
				logging.error('%s#%s --- empty orth' % (path, getId(segnode)))
			if base == '':
				logging.error('%s#%s --- empty base' % (path, getId(segnode)))
				continue
			try:
				newBase, newCTag, newMsd = validator.validateTag(base, ctag, msd)
			except tagset.TagException as err:
				logging.error('%s#%s --- %s' % (path, getId(segnode), err.value))
				continue
			if (newBase, newCTag, newMsd) != (base, ctag, msd):
				logging.error('%s#%s --- was: "%s:%s:%s" but should be: "%s:%s:%s"'
							% (path, getId(segnode), base, ctag, msd, newBase, newCTag, newMsd))

	def _checkForEmptyPars(self):
		"""Log divs of the text layer that have no text, and ab/u children with empty text (reported by the div's id)."""
		tree = self.treesMap[self.textFilename]
		path = self.textFilename

		for divnode in xpath(tree, '//tei:div'):
			abnodes = xpath(divnode, 'tei:ab|tei:u')
			if not abnodes and not self._getText(divnode):
				logging.error('%s#%s --- %s' % (path, getId(divnode), 'empty div'))
			elif abnodes:
				for abnode in abnodes:
					if self._getParagraphText(abnode) == '':
						logging.error('%s#%s --- %s' % (path, getId(divnode), 'empty paragraph'))

	def _getParagraphText(self, parnode):
		"""Return the concatenated text of *parnode* with outer newlines stripped."""
		return self._getText(parnode).strip('\n')

	def _getText(self, node):
		"""
		Recursively concatenate the text and tails of *node*'s subtree,
		stripping leading/trailing newlines from each individual text chunk.
		"""
		res = []
		if node.text:
			res.append(node.text.strip('\n'))
		for subnode in node:
			res.append(self._getText(subnode))
			if subnode.tail:
				res.append(subnode.tail.strip('\n'))
		return u''.join(res)

	def _checkCompleteness(self, what, against):
		"""
		Log paragraphs/sentences of *against* whose @xml:id is not referenced
		by any body tei:p/tei:s @corresp in *what*. No-op if either is not loaded.
		"""
		tree1 = self.treesMap.get(what, None)
		tree2 = self.treesMap.get(against, None)
		if tree1 is None or tree2 is None:
			return

		pars1 = {corresp.partition('#')[2] for corresp in xpath(tree1, '//tei:body//tei:p/@corresp')}
		pars2 = set(xpath(tree2, '//tei:p/@xml:id'))

		sents1 = {corresp.partition('#')[2] for corresp in xpath(tree1, '//tei:body//tei:s/@corresp')}
		sents2 = set(xpath(tree2, '//tei:s/@xml:id'))

		for parid in pars2 - pars1:
			logging.error('%s#%s --- paragraph missing in %s' % (against, parid, what))

		for sentid in sents2 - sents1:
			logging.error('%s#%s --- sentence missing in %s' % (against, sentid, what))

	def _checkWordsCompleteness(self):
		"""Log morphosyntax segments not pointed to by any ptr in ann_words.xml (needs both files loaded)."""
		morph_tree = self.treesMap.get('ann_morphosyntax.xml', None)
		words_tree = self.treesMap.get('ann_words.xml', None)
		if words_tree is None or morph_tree is None:
			return

		segids = set(xpath(morph_tree, '//tei:seg/@xml:id'))
		targets = xpath(words_tree, '//tei:body//tei:ptr/@target')
		wrapped_segids = {target.split('#')[1] for target in targets if target.startswith('ann_morphosyntax.xml#')}
		for segid in segids.difference(wrapped_segids):
			logging.error('ann_morphosyntax.xml#%s --- morph not wrapped in any syntactic word' % segid)

	def _checkSegmentationStringRanges(self):
		"""
		Check that each non-rejected segment's ``string-range`` @corresp in
		ann_segmentation.xml selects, from the text layer, the same orth
		(after _normalizeOrth) that ann_morphosyntax.xml records for it.
		Malformed pointers and dangling ids are logged instead of crashing.
		"""
		morphTree = self.treesMap['ann_morphosyntax.xml']
		segmTree = self.treesMap['ann_segmentation.xml']
		textTree = self.treesMap[self.textFilename]
		# The filename is user-configurable (--text-file) and contains '.',
		# so it must be escaped before being embedded in the pattern.
		rangeRe = re.compile(
			r'%s#string-range\(([a-z0-9\-\._]+),([0-9]+),([0-9]+)\)' % re.escape(self.textFilename))
		for seg in xpath(segmTree, '//tei:seg[not(@nkjp:rejected = "true")]'):
			segId = getId(seg)
			corresp = seg.get('corresp', '')
			morphs = xpath(morphTree, '//tei:seg[@corresp="ann_segmentation.xml#{0}"]'.format(segId))
			if not morphs:
				logging.error('ann_segmentation.xml#%s --- no corresponding segment in ann_morphosyntax.xml' % segId)
				continue
			orth = getFirstText(xpathGetOrthText(morphs[0]))
			match = rangeRe.match(corresp)
			if match is None:
				logging.error('ann_segmentation.xml#%s --- invalid @corresp: "%s"' % (segId, corresp))
				continue
			textParId, startIdx, length = match.group(1), int(match.group(2)), int(match.group(3))
			# string-range ids are skipped by _checkCorresp, so check them here.
			textPars = xpath(textTree, '//*[@xml:id="%s"]' % textParId)
			if not textPars:
				logging.error('ann_segmentation.xml#%s --- invalid @corresp: no such @xml:id in %s: %s'
							% (segId, self.textFilename, textParId))
				continue
			parText = self._getParagraphText(textPars[0])
			textOrth = parText[startIdx:(startIdx+length)]
			if self._normalizeOrth(orth) != self._normalizeOrth(textOrth):
				logging.error('ann_segmentation.xml#%s --- invalid @corresp: "%s" != "%s" in "%s"'
							% (segId, orth, textOrth, parText))

	def _checkOrthsConsistency(self, what):
		"""
		For every tei:seg of *what*, compare its orth with the orth rebuilt
		from the morphosyntax segments it (transitively) points to.
		No-op if *what* is not loaded.
		"""
		tree = self.treesMap.get(what, None)
		if tree is None:
			return

		for seg in xpath(tree, '//tei:seg'):
			orth = getFirstText(xpathGetOrthText(seg))
			morphs = self._getTargetMorphs(tree, seg)
			# Restore textual order: position of each morph within its parent.
			morphs = sorted(morphs, key=lambda morph: morph.getparent().index(morph))
			morphsOrth = self._getCombinedOrth(morphs)
			if orth != morphsOrth:
				logging.error('%s#%s --- invalid orth: "%s" != "%s"'
							% (what, getId(seg), orth, morphsOrth))

	def _getCombinedOrth(self, morphs):
		"""
		Join the orths of *morphs*, inserting a space before each morph except
		the first and those carrying an "nps" feature.
		"""
		res = []
		for idx, morph in enumerate(morphs):
			if idx != 0 and not xpath(morph, './/tei:f[@name="nps"]'):
				res.append(' ')
			res.append(getFirstText(xpathGetOrthText(morph)))
		return u''.join(res)

	def _findById(self, tree, nodeId):
		"""Return the element of *tree* whose @xml:id is *nodeId*, or None if *tree* is None or lacks it."""
		if tree is None:
			return None
		found = xpathNode4Id(tree, id=nodeId)
		return found[0] if found else None

	def _getTargetMorphs(self, tree, seg):
		"""
		Resolve the ptr/@target values under *seg* (in *tree*) to morphosyntax
		segments, following pointers into other layers recursively. Pointers
		to unloaded files or unknown ids are skipped (they are reported by
		_checkCorresps).
		"""
		morphTree = self.treesMap.get('ann_morphosyntax.xml', None)
		res = []

		for target in xpath(seg, './/tei:ptr/@target'):
			if '#' in target:
				fname, targetId = target.split('#', 1)
			else:
				fname, targetId = '', target
			if fname == 'ann_morphosyntax.xml':
				morph = self._findById(morphTree, targetId)
				if morph is not None:
					res.append(morph)
			else:
				# A pointer without a file part refers to the current tree.
				targetTree = tree if fname == '' else self.treesMap.get(fname, None)
				targetSeg = self._findById(targetTree, targetId)
				if targetSeg is not None:
					res += self._getTargetMorphs(targetTree, targetSeg)
		return res

	def _normalizeOrth(self, orth):
		"""Normalize an orth for comparison: en dash -> '-', drop spaces and no-break spaces."""
		return orth.replace(u'\u2013', u'-').replace(u' ', u'').replace(u'\u00a0', u'')

	def validate(self):
		"""Run all consistency checks, then write back corrected files if --correct was given."""
		self._checkForEmptyPars()
		self._checkCorresps()

		self._checkSegmentationStringRanges()
		self._checkForEmptyOrths('ann_morphosyntax.xml')
		self._checkForEmptyLemmas('ann_morphosyntax.xml')
		self._checkMorphInterps(self.morphValidator)
		self._checkWordsCompleteness()

		self._checkForEmptyOrths('ann_named.xml')
		self._checkForEmptyLemmas('ann_named.xml')

		self._checkForEmptyOrths('ann_words.xml')
		self._checkForEmptyLemmas('ann_words.xml')
		self._checkInterpretations(self.wordsValidator, 'ann_words.xml', self.correct)

		self._checkForEmptyOrths('ann_groups.xml')

		self._checkCompleteness('ann_words.xml', 'ann_morphosyntax.xml')
		self._checkCompleteness('ann_named.xml', 'ann_morphosyntax.xml')
		self._checkCompleteness('ann_groups.xml', 'ann_words.xml')

		self._checkCompleteness('ann_morphosyntax.xml', 'ann_segmentation.xml')
		self._checkCompleteness('ann_senses.xml', 'ann_segmentation.xml')

		self._checkOrthsConsistency('ann_named.xml')
		self._checkOrthsConsistency('ann_words.xml')
		self._checkOrthsConsistency('ann_groups.xml')

		if self.correct:
			# Write into the corpus directory (not the cwd) via the tree's own
			# serializer. NOTE(review): trees were XInclude-expanded in
			# __init__, so included content ends up inlined — confirm intended.
			for fname in sorted(self._correctedFiles):
				self.treesMap[fname].write(
					os.path.join(self.dirname, fname),
					encoding='UTF-8', xml_declaration=True)

opts = parseOptions()
# Every log line is prefixed with the corpus directory. '%' is doubled so a
# directory name containing it cannot corrupt the %-style format string.
logging.basicConfig(level=logging.INFO, format='{0}/%(message)s'.format(opts.dirname.replace('%', '%%')))
rngHelper = RNGHelper()
try:
	dirValidator = DirValidator(opts, rngHelper)
	dirValidator.validate()
except Exception:
	# Log the traceback and report failure through the exit status.
	logging.exception('FAILED')
	sys.exit(1)
