/*
 * 
 *  Copyright (C) 2011 Mateusz Kopec
 *
 *  This program is free software: you can redistribute it and/or modify
 *  it under the terms of the GNU General Public License as published by
 *  the Free Software Foundation, either version 3 of the License, or
 *  (at your option) any later version.
 *
 *  This program is distributed in the hope that it will be useful,
 *  but WITHOUT ANY WARRANTY; without even the implied warranty of
 *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *  GNU General Public License for more details.
 *
 *  You should have received a copy of the GNU General Public License
 *  along with this program.  If not, see http://www.gnu.org/licenses/.
 *
 */
package annotation;

import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.Map.Entry;

import resources.SynonymDictionary;
import utils.Counter;
import utils.Utils;
import KnoW.KnoW;
import basic.Context;
import basic.Gloss;

/**
 * Computes the similarity between a context and the glosses of candidate senses.
 * <p>
 * Every gloss is turned into a sparse vector (word -> weight) once, in the constructor.
 * {@link #prepareContextVector(Context)} builds the corresponding vector for a context, and
 * {@link #computeSimilarity(Map, String)} compares a context vector with the vector of one
 * sense using the configured measure. All weighting, thresholding, normalisation, synonym
 * extension and similarity options come from the {@code params} string given to the
 * constructor.
 * 
 * @author Mateusz Kopec
 * 
 */
public class MatchingStrategy {

	/** Measure used to compare a gloss vector with a context vector. */
	private enum IntersectionType {
		PRODUCT, COSINE, EUCLIDEAN, JACCARD, KULLBACK, PEARSON
	}

	/** Synonym source (if any) used to extend both gloss and context vectors. */
	private enum ExtensionType {
		NONE, SYNONIMY_UX_ALL, SYNONIMY_UX_MONO, PLWORDNET_ALL, PLWORDNET_MONO, PLWORDNET_RELATIONS, PLWORDNET_SEMSIM
	}

	/** How context word weights are scaled by their distance to the keyword. */
	private enum ContextNormalisation {
		NONE, SQUARE, LINEAR
	}

	// ---- parameters (all set by setParams) ----

	/** Use base forms ({@code getBases()}) instead of orthographic forms ({@code getOrths()}). */
	private boolean lemmatize;
	/** Similarity measure applied by {@link #computeSimilarity(Map, String)}. */
	private IntersectionType intersectionType;
	/** Multiply gloss word counts by their inverse gloss frequency. */
	private boolean withIGF;
	/** Multiply context word counts by the IDF from {@code KnoW.getFrequencyCounter()}. */
	private boolean withIDF;
	/** Rescale each gloss vector so that its weights sum to 100. */
	private boolean glossNormalisation;
	/** Distance-based weighting of context words. */
	private ContextNormalisation contextNormalisation;
	/** Record word presence (weight 1) instead of counting occurrences. */
	private boolean binary;
	/** Synonym source used to extend gloss and context vectors. */
	private ExtensionType extensionType;
	/** Context words weighing less than {@code thresholdLow * maxWeight} are dropped. */
	private double thresholdLow;
	/** Context words weighing more than {@code thresholdHigh * maxWeight} are dropped. */
	private double thresholdHigh;

	// ---- gloss processing ----

	/** Inverse gloss frequency of every gloss word; only computed when {@link #withIGF} is set. */
	private Map<String, Double> igfs;
	/** Final vector of every gloss, keyed by sense id. */
	private Map<String, Map<String, Double>> glossVectors;

	/**
	 * Parses the parameters and precomputes the vector of every gloss.
	 * 
	 * @param preprocessedGlosses
	 *            glosses keyed by sense id
	 * @param params
	 *            ten colon-separated values, in this order: lemmatize, binary, IGF, IDF,
	 *            low threshold, high threshold, gloss normalisation, context normalisation
	 *            (NONE/SQUARE/LINEAR), extension type (an {@link ExtensionType} name) and
	 *            intersection type (an {@link IntersectionType} name). Boolean options are
	 *            enabled by "yes" and disabled by anything else; names are matched
	 *            case-insensitively; thresholds are parsed with {@link Double#valueOf(String)}
	 * @throws Exception
	 *             if {@code params} does not have exactly ten fields or names an unknown
	 *             option
	 */
	public MatchingStrategy(Map<String, Gloss> preprocessedGlosses, String params) throws Exception {
		setParams(params);

		if (withIGF) {
			igfs = calculateIgfs(preprocessedGlosses);
		}

		glossVectors = new HashMap<String, Map<String, Double>>();
		for (Entry<String, Gloss> gloss : preprocessedGlosses.entrySet()) {
			glossVectors.put(gloss.getKey(), prepareGlossVector(gloss.getValue()));
		}
	}

	/**
	 * Builds the final vector of a single gloss. It starts from word counts (or presence when
	 * {@link #binary}), optionally weights them by IGF and normalises them to sum to 100, and
	 * finally extends the vector with synonyms.
	 */
	private Map<String, Double> prepareGlossVector(Gloss preprocessedGloss) {
		List<String> words;
		if (lemmatize) {
			words = preprocessedGloss.getBases();
		} else {
			words = preprocessedGloss.getOrths();
		}

		Counter glossCounter = new Counter();
		for (String s : words) {
			if (binary) {
				glossCounter.put(s, 1);
			} else {
				glossCounter.increase(s);
			}
		}

		Map<String, Double> glossVector = new HashMap<String, Double>();
		for (Entry<String, Integer> entry : glossCounter.entrySet()) {
			// igfs is built from the same word forms (bases or orths), so every key is present
			double weight = withIGF ? igfs.get(entry.getKey()) : 1.0;
			glossVector.put(entry.getKey(), weight * entry.getValue());
		}

		if (glossNormalisation) {
			normalizeVector(glossVector);
		}

		extendVector(glossVector);
		return glossVector;
	}

	/**
	 * Creates the vector of a context.
	 * <p>
	 * Every word except the one at the keyword index is counted (or marked present when
	 * {@link #binary}). The weights are then optionally multiplied by IDF, filtered by the
	 * relative thresholds, and optionally scaled by the words' distances to the keyword.
	 * Finally, the vector is extended with synonyms.
	 * 
	 * @param preprocessedContext
	 *            context to create vector from
	 * @return vector: word -> value
	 */
	public Map<String, Double> prepareContextVector(Context preprocessedContext) {
		Map<String, Double> contextVector = new HashMap<String, Double>();
		// every distance (in tokens) from the keyword at which each word occurs
		Map<String, Collection<Integer>> distances = new HashMap<String, Collection<Integer>>();
		int keywordIndex = preprocessedContext.getKeywordIndex();

		List<String> words;
		if (lemmatize) {
			words = preprocessedContext.getBases();
		} else {
			words = preprocessedContext.getOrths();
		}

		int i = 0;
		for (String s : words) {
			if (i != keywordIndex) { // keyword is not counted
				Double previous = contextVector.get(s);
				if (binary || previous == null) {
					contextVector.put(s, 1.0);
				} else {
					contextVector.put(s, previous + 1.0);
				}

				Collection<Integer> di = distances.get(s);
				if (di == null) {
					di = new ArrayList<Integer>();
					distances.put(s, di);
				}
				di.add(Math.abs(i - keywordIndex));
			}
			i++;
		}

		if (withIDF) {
			for (Entry<String, Double> e : contextVector.entrySet()) {
				e.setValue(e.getValue() * KnoW.getFrequencyCounter().getIDF(e.getKey()));
			}
		}

		thresholdVector(contextVector, thresholdLow, thresholdHigh);

		if (contextNormalisation != ContextNormalisation.NONE) {
			boolean square = contextNormalisation == ContextNormalisation.SQUARE;
			for (Entry<String, Double> e : contextVector.entrySet()) {
				double factor = 0;
				for (int d : distances.get(e.getKey())) {
					// d >= 1 because the keyword position itself is never recorded;
					// widen to double before squaring so long contexts cannot overflow int
					factor += square ? 1.0 / ((double) d * d) : 1.0 / d;
				}
				e.setValue(e.getValue() * factor);
			}
		}

		extendVector(contextVector);
		return contextVector;
	}

	/**
	 * Computes the similarity between a context vector and a chosen sense.
	 * 
	 * @param contextVector
	 *            vector built by {@link #prepareContextVector(Context)}
	 * @param senseId
	 *            id of chosen sense; must be a key of the glosses given to the constructor
	 * @return similarity value (larger means more similar)
	 * @throws IllegalArgumentException
	 *             if no gloss was registered for {@code senseId}
	 */
	public double computeSimilarity(Map<String, Double> contextVector, String senseId) {
		Map<String, Double> glossVector = glossVectors.get(senseId);
		if (glossVector == null) {
			throw new IllegalArgumentException("Unknown sense id: " + senseId);
		}
		return getVectorSimilarity(glossVector, contextVector, intersectionType);
	}

	/**
	 * Extends {@code vector} in place with synonyms from the source selected by
	 * {@link #extensionType}; does nothing for {@link ExtensionType#NONE}.
	 */
	private void extendVector(Map<String, Double> vector) {
		switch (extensionType) {
		case SYNONIMY_UX_ALL:
			extendVectorWithSynonyms(vector, KnoW.getSynonimyUx(), false);
			break;
		case SYNONIMY_UX_MONO:
			extendVectorWithSynonyms(vector, KnoW.getSynonimyUx(), true);
			break;
		case PLWORDNET_ALL:
			extendVectorWithSynonyms(vector, KnoW.getPlwordnetSynonyms(), false);
			break;
		case PLWORDNET_MONO:
			extendVectorWithSynonyms(vector, KnoW.getPlwordnetSynonyms(), true);
			break;
		case PLWORDNET_RELATIONS:
			extendVectorWithSynonyms(vector, KnoW.getPlwordnetRelations(), false);
			break;
		case PLWORDNET_SEMSIM:
			extendVectorWithSynonyms(vector, KnoW.getPlwordnetSemSim(), false);
			break;
		case NONE:
		default:
			break;
		}
	}

	/**
	 * Removes every entry whose value lies outside the range
	 * [{@code thresholdLow * max}, {@code thresholdHigh * max}], where {@code max} is the
	 * largest value in the vector.
	 */
	private static void thresholdVector(Map<String, Double> vector, double thresholdLow, double thresholdHigh) {
		double maxValue = -1;
		for (Double d : vector.values()) {
			if (d > maxValue) {
				maxValue = d;
			}
		}

		double low = thresholdLow * maxValue;
		double high = thresholdHigh * maxValue;
		Iterator<Entry<String, Double>> it = vector.entrySet().iterator();
		while (it.hasNext()) {
			double value = it.next().getValue();
			if (value < low || value > high) {
				it.remove();
			}
		}
	}

	/**
	 * Adds to {@code vector} the synonyms that {@code dict} returns for each key matching
	 * {@code Utils.singleWordRegex}. Multi-word synonyms are split on whitespace, and each
	 * part matching {@code Utils.singleWordRegex} is added separately. An added word takes
	 * the weight of the word it is a synonym of. Words already present in the vector
	 * (original or added earlier) are never overwritten.
	 * 
	 * @param onlyMonosemous
	 *            passed through to {@code SynonymDictionary.getSynonymsForLemma}
	 */
	private static void extendVectorWithSynonyms(Map<String, Double> vector, SynonymDictionary dict,
			boolean onlyMonosemous) {
		// iterate over a snapshot so that newly added synonyms are not extended themselves
		Map<String, Double> copy = new HashMap<String, Double>(vector);

		for (Entry<String, Double> entry : copy.entrySet()) {
			if (!entry.getKey().matches(Utils.singleWordRegex)) {
				continue;
			}

			Collection<String> synonyms = dict.getSynonymsForLemma(entry.getKey(), onlyMonosemous);
			for (String synonym : synonyms) {
				// split multiword synonyms into their single words
				for (String sub : synonym.split("\\s+")) {
					// add only well-formed single words that are not yet in the vector
					if (!vector.containsKey(sub) && sub.matches(Utils.singleWordRegex)) {
						vector.put(sub, entry.getValue());
					}
				}
			}
		}
	}

	/**
	 * Rescales {@code vector} in place so that its values sum to 100. A vector whose values
	 * sum to zero (e.g. when every IGF is 0) is left unchanged instead of being filled with
	 * NaN/Infinity.
	 */
	private static void normalizeVector(Map<String, Double> vector) {
		double sumOfValues = 0;
		for (Double d : vector.values()) {
			sumOfValues += d;
		}

		if (sumOfValues == 0) {
			return;
		}

		for (Entry<String, Double> e : vector.entrySet()) {
			e.setValue(100 * e.getValue() / sumOfValues);
		}
	}

	/**
	 * Calculates the Inverse Gloss Frequency of every word in the given glosses:
	 * {@code log(number of glosses / number of glosses containing the word)}.
	 * 
	 * @param preprocessedGlosses
	 *            glosses keyed by sense id
	 * @return a map: word -> its IGF
	 */
	private Map<String, Double> calculateIgfs(Map<String, Gloss> preprocessedGlosses) {
		Counter glossFrequency = new Counter();
		for (Gloss g : preprocessedGlosses.values()) {
			// a word is counted at most once per gloss
			Set<String> distinctWords = new HashSet<String>();
			if (lemmatize) {
				distinctWords.addAll(g.getBases());
			} else {
				distinctWords.addAll(g.getOrths());
			}

			for (String w : distinctWords) {
				glossFrequency.increase(w);
			}
		}

		double glossCount = preprocessedGlosses.size();
		Map<String, Double> result = new HashMap<String, Double>();
		for (Entry<String, Integer> e : glossFrequency.entrySet()) {
			result.put(e.getKey(), Math.log(glossCount / e.getValue()));
		}
		return result;
	}

	/**
	 * Computes the similarity of two sparse vectors (missing keys count as 0). For every
	 * measure, a larger value means more similar:
	 * <ul>
	 * <li>PRODUCT - dot product;</li>
	 * <li>COSINE - cosine of the angle between the vectors, 0 if either has zero length;</li>
	 * <li>EUCLIDEAN - inverse of the Euclidean distance (a tiny epsilon replaces a zero
	 * distance);</li>
	 * <li>JACCARD - extended Jaccard (Tanimoto) coefficient
	 * {@code a.b / (|a|^2 + |b|^2 - a.b)}, 0 when both vectors are zero;</li>
	 * <li>KULLBACK - inverse of a weighted, divergence-like sum over all words; infinite when
	 * that sum is 0 (e.g. for identical vectors);</li>
	 * <li>PEARSON - 1 + Pearson correlation coefficient over the union of both key sets.</li>
	 * </ul>
	 */
	private static double getVectorSimilarity(Map<String, Double> vector1, Map<String, Double> vector2,
			IntersectionType intersectionType) {
		switch (intersectionType) {
		case EUCLIDEAN: {
			double squaredDistance = 0;
			for (String word : getAllWords(vector1, vector2)) {
				double diff = valueOrZero(vector2, word) - valueOrZero(vector1, word);
				squaredDistance += diff * diff;
			}
			if (squaredDistance == 0) {
				squaredDistance += 0.00000000000000001;
			}
			return 1.0 / Math.sqrt(squaredDistance);
		}

		case JACCARD: {
			double product = getProduct(vector1, vector2);
			double denom = getSumWeights(vector1, 2) + getSumWeights(vector2, 2) - product;
			return denom == 0 ? 0 : product / denom;
		}

		case KULLBACK: {
			double divergence = 0;
			for (String word : getAllWords(vector1, vector2)) {
				double weight1 = valueOrZero(vector1, word);
				double weight2 = valueOrZero(vector2, word);

				double pi1 = weight1 / (weight1 + weight2);
				double pi2 = weight2 / (weight1 + weight2);
				double w = pi1 * weight1 + pi2 * weight2;
				if (weight1 > 0) {
					divergence += pi1 * (weight1 * Math.log(weight1 / w));
				}
				if (weight2 > 0) {
					divergence += pi2 * (weight2 * Math.log(weight2 / w));
				}
			}
			return 1.0 / divergence;
		}

		case PEARSON: {
			double m = getAllWords(vector1, vector2).size();
			double sum1 = getSumWeights(vector1, 1);
			double sum2 = getSumWeights(vector2, 1);

			double nom = m * getProduct(vector1, vector2) - sum1 * sum2;
			double denom = (m * getSumWeights(vector1, 2) - sum1 * sum1)
					* (m * getSumWeights(vector2, 2) - sum2 * sum2);
			// both factors are mathematically >= 0; clamp rounding noise before the square root
			if (denom < 0) {
				denom = 0;
			}
			denom = Math.sqrt(denom);

			if (denom == 0) {
				return nom == 0 ? 2 : 0;
			}
			return 1 + nom / denom;
		}

		case PRODUCT:
			return getProduct(vector1, vector2);

		case COSINE: {
			double denom = getVectorLength(vector1) * getVectorLength(vector2);
			return denom == 0 ? 0 : getProduct(vector1, vector2) / denom;
		}

		default:
			return 0;
		}
	}

	/** Returns the value stored for {@code word}, or 0 when the vector has no such key. */
	private static double valueOrZero(Map<String, Double> vector, String word) {
		Double value = vector.get(word);
		return value == null ? 0.0 : value;
	}

	/**
	 * Calculates the length of a vector as the square root of the sum of squares of all
	 * values.
	 * 
	 * @param vector
	 *            vector to measure
	 * @return length
	 */
	private static double getVectorLength(Map<String, Double> vector) {
		return Math.sqrt(getSumWeights(vector, 2));
	}

	/** Returns the union of the key sets of both vectors. */
	private static Collection<String> getAllWords(Map<String, Double> vector1, Map<String, Double> vector2) {
		Set<String> result = new HashSet<String>(vector1.keySet());
		result.addAll(vector2.keySet());
		return result;
	}

	/** Returns the sum of every value of {@code vector} raised to the power {@code pow}. */
	private static double getSumWeights(Map<String, Double> vector, int pow) {
		double result = 0;
		for (Double val : vector.values()) {
			result += Math.pow(val, pow);
		}
		return result;
	}

	/** Returns the dot product of two sparse vectors (only shared keys contribute). */
	private static double getProduct(Map<String, Double> vector1, Map<String, Double> vector2) {
		double result = 0;
		for (Entry<String, Double> entry : vector1.entrySet()) {
			Double f = vector2.get(entry.getKey());
			if (f != null) {
				result += f * entry.getValue();
			}
		}
		return result;
	}

	/**
	 * Returns the constant of {@code type} whose name equals {@code value} ignoring case, or
	 * {@code null} if there is none.
	 */
	private static <E extends Enum<E>> E parseEnum(Class<E> type, String value) {
		for (E constant : type.getEnumConstants()) {
			if (constant.name().equalsIgnoreCase(value)) {
				return constant;
			}
		}
		return null;
	}

	/**
	 * Parses the colon-separated parameter string (format described on the constructor) and
	 * stores every option in its field.
	 * 
	 * @throws Exception
	 *             if there are not exactly ten fields or an option name is not recognised
	 */
	private void setParams(String params) throws Exception {
		String[] splitted = params.split(":");

		if (splitted.length != 10) {
			throw new Exception("Wrong parameters length!" + params);
		}

		// input parameters
		String lm = splitted[0];
		String bi = splitted[1];
		String wigf = splitted[2];
		String widf = splitted[3];
		String thl = splitted[4];
		String thh = splitted[5];
		String nog = splitted[6];
		String noc = splitted[7];
		String et = splitted[8];
		String it = splitted[9];

		this.lemmatize = lm.equalsIgnoreCase("yes");
		this.binary = bi.equalsIgnoreCase("yes");
		this.withIGF = wigf.equalsIgnoreCase("yes");
		this.withIDF = widf.equalsIgnoreCase("yes");

		this.thresholdLow = Double.valueOf(thl);
		this.thresholdHigh = Double.valueOf(thh);

		this.glossNormalisation = nog.equalsIgnoreCase("yes");

		// reject unknown values instead of leaving the field null (which would fail later)
		this.contextNormalisation = parseEnum(ContextNormalisation.class, noc);
		if (this.contextNormalisation == null) {
			throw new Exception("Wrong parameters!" + noc + params);
		}

		this.extensionType = parseEnum(ExtensionType.class, et);
		if (this.extensionType == null) {
			throw new Exception("Wrong parameters!" + params);
		}

		this.intersectionType = parseEnum(IntersectionType.class, it);
		if (this.intersectionType == null) {
			throw new Exception("Wrong parameters!" + it + params);
		}
	}

}
