package de.ugoe.cs.eventbench.markov;

import java.util.ArrayList;
import java.util.LinkedList;
import java.util.List;
import java.util.Random;

import Jama.Matrix;

import de.ugoe.cs.eventbench.data.Event;
import de.ugoe.cs.util.console.Console;
import edu.uci.ics.jung.graph.Graph;
import edu.uci.ics.jung.graph.SparseMultigraph;
import edu.uci.ics.jung.graph.util.EdgeType;

public class MarkovModel implements DotPrinter {

	private State initialState;
	private State endState;
	
	private List<State> states;
	private List<String> stateIdList;
	
	private Random r;
	
	final static int MAX_STATDIST_ITERATIONS = 1000;
	
	/**
	 * <p>
	 * Default constructor. Creates a new random number generator.
	 * </p>
	 */
	public MarkovModel() {
		this(new Random());
	}
	
	/**
	 * <p>
	 * Creates a new {@link MarkovModel} with a predefined random number generator.
	 * </p>
	 * 
	 * @param r random number generator
	 */
	public MarkovModel(Random r) {
		this.r = r; // defensive copy would be better, but constructor Random(r) does not seem to exist.
	}
	
	public void printRandomWalk() {
		State currentState = initialState;
		IMemory<State> history = new IncompleteMemory<State>(5); // this is NOT used here, just for testing ...
		history.add(currentState);
		Console.println(currentState.getId());
		while(!currentState.equals(endState)) {
			currentState = currentState.getNextState();
			Console.println(currentState.getId());
			history.add(currentState);
		}
	}
	
	public List<? extends Event<?>> randomSequence() {
		List<Event<?>> sequence = new LinkedList<Event<?>>();
		State currentState = initialState;
		if( currentState.getAction()!=null ) {
			sequence.add(currentState.getAction());
		}
		System.out.println(currentState.getId());
		while(!currentState.equals(endState)) {
			currentState = currentState.getNextState();
			if( currentState.getAction()!=null ) {
				sequence.add(currentState.getAction());
			}
			System.out.println(currentState.getId());
		}
		return sequence;
	}
	
	public void printDot() {
		int numUnprintableStates = 0;
		System.out.println("digraph model {");
		for( State state : states ) {
			if( state instanceof DotPrinter ) {
				((DotPrinter) state).printDot();
			} else {
				numUnprintableStates++;
			}
		}
		System.out.println('}');
		if( numUnprintableStates>0 ) {
			Console.println("" + numUnprintableStates + "/" + states.size() + "were unprintable!");
		}
	}
	
	public Graph<String, MarkovEdge> getGraph() {
		Graph<String, MarkovEdge> graph = new SparseMultigraph<String, MarkovEdge>();
		
		for( State state : states) {
			try {
				SimpleState simpleState = (SimpleState) state;
				String from = simpleState.getShortId();
				for( int i=0 ; i<simpleState.toStates.size() ; i++ ) {
					SimpleState toState = (SimpleState) simpleState.toStates.get(i);
					String to = toState.getShortId();
					MarkovEdge prob = new MarkovEdge(simpleState.transitionProbs.get(i));
					graph.addEdge(prob, from, to, EdgeType.DIRECTED);
				}
			} catch (ClassCastException e) {
				// TODO: handle exception
			}
		}
		
		return graph;
	}
	
	static public class MarkovEdge {
		double weight;
		MarkovEdge(double weight) { this.weight = weight; }
		public String toString() { return ""+weight; }
	}
	
	/////////////////////////////////////////////////////////////////////////////////////
	// Code to learn type1 model: states are wndid.action and transitions are unlabled //
	/////////////////////////////////////////////////////////////////////////////////////
	
	public void train(List<List<Event<?>>> sequences) {
		Event<?> fromElement = null;
		Event<?> toElement = null;
		SimpleState fromState;
		SimpleState toState;
		
		states = new ArrayList<State>();
		stateIdList = new ArrayList<String>();
		initialState = new SimpleState("GLOBALSTARTSTATE", null);
		initialState.setRandom(r);
		states.add(initialState);
		stateIdList.add("GLOBALSTARTSTATE");
		endState = new SimpleState("GLOBALENDSTATE", null);
		endState.setRandom(r);
		states.add(endState);
		stateIdList.add("GLOBALENDSTATE");
		for( List<Event<?>> sequence : sequences ) {
			for( int i=0; i<sequence.size() ; i++ ) {
				if( i==0 ) {
					fromState = (SimpleState) initialState;
				} else {
					fromElement = sequence.get(i-1);
					fromState = findOrCreateSimpleState(fromElement);
				}
				
				toElement = sequence.get(i);
				toState = findOrCreateSimpleState(toElement);
				
				fromState.incTransTo(toState);
				
				if( i==sequence.size()-1 ) {
					toState.incTransTo(endState);
				}
			}
		}
	}
	
	private SimpleState findOrCreateSimpleState(Event<?> action) {
		SimpleState state = null;
		String id = action.getStandardId();
		String idShort = action.getShortId();
		int index = stateIdList.indexOf(id);
		if( index!=-1 ) {
			state = (SimpleState) states.get(index);
		} else {
			state = new SimpleState(id, action, idShort);
			state.setRandom(r);
			states.add(state);
			stateIdList.add(id);
		}
		return state;
	}
	
	///////////////////////////////////////////////////////////
	
	// states must be SimpleState, this functions will throw bad cast exceptions
	public double calcEntropy() {
		int numStates = states.size();
		// create transmission matrix
		Matrix transmissionMatrix = new Matrix(numStates, numStates);
		for( int i=0 ; i<numStates ; i++ ) {
			State tmpState = states.get(i);
			if( SimpleState.class.isInstance(tmpState) ) {
				SimpleState currentState = (SimpleState) tmpState;
				for( int j=0 ; j<numStates ; j++ ) {
					double prob = currentState.getProb(states.get(j));
					transmissionMatrix.set(i, j, prob);
				}
			} else {
				Console.printerr("Error calculating entropy. Only allowed for first-order markov models.");
				return Double.NaN;
			}
		}
		
		// Add transition from endState to startState. This makes the markov chain irreducible and recurrent. 
		int startStateIndex = states.indexOf(initialState);
		int endStateIndex = states.indexOf(endState);
		if( startStateIndex==-1 ) {
			Console.printerrln("Error calculating entropy. Initial state of markov chain not found.");
			return Double.NaN;
		}
		if( endStateIndex==-1 ) {
			Console.printerrln("Error calculating entropy. End state of markov chain not found.");
			return Double.NaN;
		}
		transmissionMatrix.set(endStateIndex, startStateIndex, 1);
		
		// Calculate stationary distribution by raising the power of the transmission matrix.
		// The rank of the matrix should fall to 1 and each two should be the vector of the
		// stationory distribution. 
		int iter = 0;
		int rank = transmissionMatrix.rank();
		Matrix stationaryMatrix = (Matrix) transmissionMatrix.clone();
		while( iter<MAX_STATDIST_ITERATIONS && rank>1 ) {
			stationaryMatrix = stationaryMatrix.times(stationaryMatrix);
			rank = stationaryMatrix.rank();
			iter++;
		}
		
		if( rank!=1 ) {
			Console.traceln("rank: " + rank);
			Console.printerrln("Unable to calculate stationary distribution.");
			return Double.NaN;
		}
		
		double entropy = 0.0;
		for( int i=0 ; i<numStates ; i++ ) {
			for( int j=0 ; j<numStates ; j++ ) {
				if( transmissionMatrix.get(i,j)!=0 ) {
					double tmp = stationaryMatrix.get(i, 0);
					tmp *= transmissionMatrix.get(i, j);
					tmp *= Math.log(transmissionMatrix.get(i,j))/Math.log(2);
					entropy -= tmp;
				}
			}
		}
		return entropy;
	}
}
