package de.ugoe.cs.eventbench.ppm;

import java.util.LinkedList;
import java.util.List;
import java.util.Random;

import de.ugoe.cs.eventbench.data.Event;
import de.ugoe.cs.eventbench.markov.IncompleteMemory;

public class PredictionByPartialMatch {
	
	private int maxOrder;
	
	private Trie<Event<?>> trie;
	
	private double probEscape;
	
	private final Random r;
	
	public PredictionByPartialMatch(int maxOrder, Random r) {
		this(maxOrder, r, 0.1);
	}
	
	public PredictionByPartialMatch(int maxOrder, Random r, double probEscape) {
		this.maxOrder = maxOrder;
		this.r = r; // TODO defensive copy instead?
		this.probEscape = probEscape;
	}
	
	public void setProbEscape(double probEscape) {
		this.probEscape = probEscape;
	}
	
	public double getProbEscape() {
		return probEscape;
	}
	
	// the training is basically the generation of the trie
	public void train(List<List<Event<?>>> sequences) {
		trie = new Trie<Event<?>>();
		
		for(List<Event<?>> sequence : sequences) {
			List<Event<?>> currentSequence = new LinkedList<Event<?>>(sequence); // defensive copy
			currentSequence.add(0, Event.STARTEVENT);
			currentSequence.add(Event.ENDEVENT);
			
			trie.train(currentSequence, maxOrder);
		}
	}
	
	/*private void addToTrie(List<Event<?>> sequence) {
		if( knownSymbols==null ) {
			knownSymbols = new LinkedHashSet<Event<?>>();
		}
		IncompleteMemory<Event<?>> latestActions = new IncompleteMemory<Event<?>>(maxOrder);
		int i=0;
		for(Event<?> currentEvent : sequence) {
			latestActions.add(currentEvent);
			knownSymbols.add(currentEvent);
			i++;
			if( i>=maxOrder ) {
				trie.add(latestActions.getLast(maxOrder));
			}
		}
		int sequenceLength = sequence.size();
		for( int j=maxOrder-1 ; j>0 ; j-- ) {
			trie.add(sequence.subList(sequenceLength-j, sequenceLength));
		}
	}*/
	
	public List<? extends Event<?>> randomSequence() {
		List<Event<?>> sequence = new LinkedList<Event<?>>();
		
		IncompleteMemory<Event<?>> context = new IncompleteMemory<Event<?>>(maxOrder-1);
		context.add(Event.STARTEVENT);
		
		Event<?> currentState = Event.STARTEVENT;
		
		boolean endFound = false;
		
		while(!endFound) {
			double randVal = r.nextDouble();
			double probSum = 0.0;
			List<Event<?>> currentContext = context.getLast(maxOrder);
			for( Event<?> symbol : trie.getKnownSymbols() ) {
				probSum += getProbability(currentContext, symbol);
				if( probSum>=randVal ) {
					endFound = (symbol==Event.ENDEVENT);
					if( !(symbol==Event.STARTEVENT || symbol==Event.ENDEVENT) ) {
						// only add the symbol the sequence if it is not START or END
						context.add(symbol);
						currentState = symbol;
						sequence.add(currentState);
					}
					break;
				}
			}
		}
		return sequence;
	}
		
	private double getProbability(List<Event<?>> context, Event<?> symbol) {
		double result = 0.0d;
		double resultCurrentContex = 0.0d;
		double resultShorterContex = 0.0d;
		
		List<Event<?>> contextCopy = new LinkedList<Event<?>>(context); // defensive copy

	
		List<Event<?>> followers = trie.getFollowingSymbols(contextCopy); // \Sigma'
		int sumCountFollowers = 0; // N(s\sigma')
		for( Event<?> follower : followers ) {
			sumCountFollowers += trie.getCount(contextCopy, follower);
		}
		
		int countSymbol = trie.getCount(contextCopy, symbol); // N(s\sigma)
		if( contextCopy.size()==0 ) {
			resultCurrentContex = ((double) countSymbol) / sumCountFollowers;
		} else {
			resultCurrentContex = ((double) countSymbol / sumCountFollowers)*(1-probEscape);
			contextCopy.remove(0); 
			double probSuffix = getProbability(contextCopy, symbol);
			if( followers.size()==0 ) {
				resultShorterContex = probSuffix;
			} else {
				resultShorterContex = probEscape*probSuffix;
			}
		}
		result = resultCurrentContex+resultShorterContex;
		
		return result;
	}
	
	@Override
	public String toString() {
		return trie.toString();
	}
	
	/*
	public void testStuff() {
		// basically an inline unit test without assertions but manual observation
		List<String> list = new ArrayList<String>();
		list.add(initialSymbol);
		list.add("a");
		list.add("b");
		list.add("r");
		list.add("a");
		list.add("c");
		list.add("a");
		list.add("d");
		list.add("a");
		list.add("b");
		list.add("r");
		list.add("a");
		list.add(endSymbol);
		
		PredictionByPartialMatch model = new PredictionByPartialMatch();
		model.trie = new Trie<String>();
		model.trainStringTrie(list);
		model.trie.display();
		
		List<String> context = new ArrayList<String>();
		String symbol = "a";
		// expected: 5
		Console.traceln(""+model.trie.getCount(context, symbol));
		
		// expected: 0
		context.add("b");
		Console.traceln(""+model.trie.getCount(context, symbol));
		
		// expected: 2
		context.add("r");
		Console.traceln(""+model.trie.getCount(context, symbol));
		
		// exptected: [b, r]
		context = new ArrayList<String>();
		context.add("a");
		context.add("b");
		context.add("r");
		Console.traceln(model.trie.getContextSuffix(context).toString());
		
		// exptected: []
		context = new ArrayList<String>();
		context.add("e");
		Console.traceln(model.trie.getContextSuffix(context).toString());
		
		// exptected: {a, b, c, d, r}
		context = new ArrayList<String>();
		Console.traceln(model.trie.getFollowingSymbols(context).toString());
		
		// exptected: {b, c, d}
		context = new ArrayList<String>();
		context.add("a");
		Console.traceln(model.trie.getFollowingSymbols(context).toString());
		
		// exptected: []
		context = new ArrayList<String>();
		context.add("a");
		context.add("b");
		context.add("r");
		Console.traceln(model.trie.getFollowingSymbols(context).toString());
		
		// exptected: 0.0d
		context = new ArrayList<String>();
		context.add("a");
		Console.traceln(""+model.getProbability(context, "z"));
	}*/
}
