package de.ugoe.cs.eventbench.ppm;

import java.util.ArrayList;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Random;
import java.util.Set;

import de.ugoe.cs.eventbench.data.Event;
import de.ugoe.cs.eventbench.markov.IncompleteMemory;
import de.ugoe.cs.util.console.Console;

/**
 * Usage model that predicts the next symbol from a bounded context of
 * previous symbols. Training fills a {@link Trie} with windows of at most
 * {@code maxOrder} symbols; the probability of a symbol is derived from the
 * counts stored in the trie, falling back to shorter contexts (escape) when a
 * symbol was never observed after the current context.
 */
public class PredictionByPartialMatch {

    /** Artificial symbol placed in front of every training sequence. */
    private String initialSymbol = "GLOBALSTARTSTATE";

    /** Artificial symbol appended to every training sequence. */
    private String endSymbol = "GLOBALENDSTATE";

    /** Length of the symbol windows that are stored in the trie. */
    private int maxOrder = 3;

    /** Occurrence counts of the observed symbol windows; created by {@link #train(List)}. */
    private Trie<String> trie;

    /** All symbols seen during training, including the start and end symbol, in insertion order. */
    private Set<String> knownSymbols;

    /**
     * Trains the model. The training is basically the generation of the trie:
     * every sequence is framed by the start and end symbol and traversed with
     * a sliding window; each time the window holds {@code maxOrder} symbols it
     * is added to the trie. Any previously trained state is discarded.
     *
     * @param sequences
     *            training sequences; the standard id of each event is used as
     *            its symbol
     */
    public void train(List<List<Event<?>>> sequences) {
        trie = new Trie<String>();
        knownSymbols = new LinkedHashSet<String>();
        knownSymbols.add(initialSymbol);
        knownSymbols.add(endSymbol);

        for (List<Event<?>> sequence : sequences) {
            // TODO need to check if it should be maxOrder+1
            IncompleteMemory<String> latestActions = new IncompleteMemory<String>(maxOrder);
            latestActions.add(initialSymbol);
            for (Event<?> currentAction : sequence) {
                String currentId = currentAction.getStandardId();
                latestActions.add(currentId);
                knownSymbols.add(currentId);
                addWindowIfFull(latestActions);
            }
            latestActions.add(endSymbol);
            addWindowIfFull(latestActions);
        }
    }

    /**
     * Adds the current window to the trie, but only if it holds exactly
     * {@code maxOrder} symbols.
     */
    private void addWindowIfFull(IncompleteMemory<String> latestActions) {
        // FIXME needs special case for sequences shorter than maxOrder:
        // such sequences never fill the window and are therefore not added.
        if (latestActions.getLength() == maxOrder) {
            trie.add(latestActions.getLast(maxOrder));
        }
    }

    /**
     * Prints a random walk through the model to the console, starting with
     * the start symbol and stopping once the end symbol has been drawn.
     * Requires that {@link #train(List)} has been called before.
     *
     * @param r
     *            source of randomness used to draw the next symbol
     */
    public void printRandomWalk(Random r) {
        IncompleteMemory<String> context = new IncompleteMemory<String>(maxOrder - 1);
        context.add(initialSymbol);

        String currentState = initialSymbol;
        Console.println(currentState);

        while (!endSymbol.equals(currentState)) {
            double randVal = r.nextDouble();
            double probSum = 0.0;
            List<String> currentContext = context.getLast(maxOrder);
            // Walk the cumulative distribution until it reaches the drawn
            // value. If no symbol is selected, a new value is drawn in the
            // next iteration of the outer loop.
            for (String symbol : knownSymbols) {
                probSum += getProbability(currentContext, symbol);
                if (probSum >= randVal) {
                    // Advance the memory itself; adding to the temporary
                    // list would leave the context stuck at the start symbol.
                    context.add(symbol);
                    currentState = symbol;
                    Console.println(currentState);
                    break;
                }
            }
        }
    }

    /**
     * Calculates the probability that {@code symbol} follows {@code context}.
     * <ul>
     * <li>If the trie knows no suffix of the context, all known symbols are
     * equally likely.</li>
     * <li>If the symbol was observed after the longest known suffix, its
     * probability is {@code count / (#followers + sum of follower counts)}.</li>
     * <li>Otherwise the escape probability
     * {@code #followers / (#followers + sum of follower counts)} is multiplied
     * with the probability obtained for the suffix shortened by its oldest
     * symbol.</li>
     * </ul>
     *
     * @param context
     *            symbols preceding the position of interest; not modified
     * @param symbol
     *            symbol whose probability is requested
     * @return probability of {@code symbol} given {@code context}
     * @throws AssertionError
     *             if the trie reports a non-empty context suffix that has no
     *             following symbols
     */
    private double getProbability(List<String> context, String symbol) {
        List<String> contextSuffix = trie.getContextSuffix(context);
        if (contextSuffix.isEmpty()) {
            return 1.0d / knownSymbols.size();
        }

        int countContextSymbol = trie.getCount(contextSuffix, symbol);
        List<String> followers = trie.getFollowingSymbols(contextSuffix);
        if (followers.isEmpty()) {
            throw new AssertionError("Invalid return value of getContextSuffix!");
        }

        int countContextFollowers = 0;
        for (String follower : followers) {
            countContextFollowers += trie.getCount(contextSuffix, follower);
        }
        double denominator = followers.size() + countContextFollowers;

        if (countContextSymbol != 0) {
            return countContextSymbol / denominator;
        }

        // escape: retry with the oldest symbol dropped. A copy is used so
        // that neither the list handed out by the trie nor the caller's
        // context is modified.
        double probEscape = followers.size() / denominator;
        List<String> shorterContext =
                new ArrayList<String>(contextSuffix.subList(1, contextSuffix.size()));
        return probEscape * getProbability(shorterContext, symbol);
    }

    /**
     * Returns the string representation of the underlying trie.
     */
    @Override
    public String toString() {
        return trie.toString();
    }
}