package de.ugoe.cs.eventbench.models;

import java.util.LinkedList;
import java.util.List;
import java.util.Random;
import java.util.Set;

import de.ugoe.cs.eventbench.data.Event;
import de.ugoe.cs.eventbench.models.Trie.Edge;
import de.ugoe.cs.eventbench.models.Trie.TrieVertex;
import edu.uci.ics.jung.graph.Tree;

public abstract class TrieBasedModel implements IStochasticProcess {

	/**
	 * Id for object serialization.
	 */
	private static final long serialVersionUID = 1L;

	protected int trieOrder;

	protected Trie<Event<?>> trie;
	protected final Random r;

	
	public TrieBasedModel(int markovOrder, Random r) {
		super();
		this.trieOrder = markovOrder+1;
		this.r = r;
	}

	public void train(List<List<Event<?>>> sequences) {
		trie = new Trie<Event<?>>();
		
		for(List<Event<?>> sequence : sequences) {
			List<Event<?>> currentSequence = new LinkedList<Event<?>>(sequence); // defensive copy
			currentSequence.add(0, Event.STARTEVENT);
			currentSequence.add(Event.ENDEVENT);
			
			trie.train(currentSequence, trieOrder);
		}
	}

	/* (non-Javadoc)
	 * @see de.ugoe.cs.eventbench.models.IStochasticProcess#randomSequence()
	 */
	@Override
	public List<? extends Event<?>> randomSequence() {
		List<Event<?>> sequence = new LinkedList<Event<?>>();
		
		IncompleteMemory<Event<?>> context = new IncompleteMemory<Event<?>>(trieOrder-1);
		context.add(Event.STARTEVENT);
		
		Event<?> currentState = Event.STARTEVENT;
		
		boolean endFound = false;
		
		while(!endFound) {
			double randVal = r.nextDouble();
			double probSum = 0.0;
			List<Event<?>> currentContext = context.getLast(trieOrder);
			for( Event<?> symbol : trie.getKnownSymbols() ) {
				probSum += getProbability(currentContext, symbol);
				if( probSum>=randVal ) {
					endFound = (symbol==Event.ENDEVENT);
					if( !(symbol==Event.STARTEVENT || symbol==Event.ENDEVENT) ) {
						// only add the symbol the sequence if it is not START or END
						context.add(symbol);
						currentState = symbol;
						sequence.add(currentState);
					}
					break;
				}
			}
		}
		return sequence;
	}
	
	public String getTrieDotRepresentation() {
		return trie.getDotRepresentation();
	}
	
	public Tree<TrieVertex, Edge> getTrieGraph() {
		return trie.getGraph();
	}

	@Override
	public String toString() {
		return trie.toString();
	}
	
	public int getNumStates() {
		return trie.getNumSymbols();
	}
	
	public String[] getStateStrings() {
		String[] stateStrings = new String[getNumStates()];
		int i=0;
		for( Event<?> symbol : trie.getKnownSymbols() ) {
			stateStrings[i] = symbol.toString();
			i++;
		}
		return stateStrings;
	}
	
	public Set<Event<?>> getEvents() {
		return trie.getKnownSymbols();
	}

}