[1] | 1 | package de.ugoe.cs.eventbench.ppm;
|
---|
| 2 |
|
---|
| 3 | import java.util.LinkedHashSet;
|
---|
| 4 | import java.util.List;
|
---|
| 5 | import java.util.Random;
|
---|
| 6 | import java.util.Set;
|
---|
| 7 |
|
---|
| 8 | import de.ugoe.cs.eventbench.data.Event;
|
---|
| 9 | import de.ugoe.cs.eventbench.markov.IncompleteMemory;
|
---|
| 10 | import de.ugoe.cs.util.console.Console;
|
---|
| 11 |
|
---|
| 12 | public class PredictionByPartialMatch {
|
---|
| 13 |
|
---|
| 14 | private String initialSymbol = "GLOBALSTARTSTATE";
|
---|
| 15 | private String endSymbol = "GLOBALENDSTATE";
|
---|
| 16 |
|
---|
| 17 | private int maxOrder = 3;
|
---|
| 18 |
|
---|
| 19 | private Trie<String> trie;
|
---|
| 20 |
|
---|
| 21 | private Set<String> knownSymbols;
|
---|
| 22 |
|
---|
| 23 | // the training is basically the generation of the trie
|
---|
| 24 | public void train(List<List<Event<?>>> sequences) {
|
---|
| 25 | trie = new Trie<String>();
|
---|
| 26 | knownSymbols = new LinkedHashSet<String>();
|
---|
| 27 | knownSymbols.add(initialSymbol);
|
---|
| 28 | knownSymbols.add(endSymbol);
|
---|
| 29 |
|
---|
| 30 | for(List<Event<?>> sequence : sequences) {
|
---|
| 31 | IncompleteMemory<String> latestActions = new IncompleteMemory<String>(maxOrder); // TODO need to check if it should be maxOrder+1
|
---|
| 32 | latestActions.add(initialSymbol);
|
---|
| 33 | for(Event<?> currentAction : sequence) {
|
---|
| 34 | String currentId = currentAction.getStandardId();
|
---|
| 35 | latestActions.add(currentId);
|
---|
| 36 | knownSymbols.add(currentId);
|
---|
| 37 | if( latestActions.getLength()==maxOrder ) { // FIXME needs special case for sequences shorter than maxOrder
|
---|
| 38 | trie.add(latestActions.getLast(maxOrder));
|
---|
| 39 | }
|
---|
| 40 | }
|
---|
| 41 | latestActions.add(endSymbol);
|
---|
| 42 | if( latestActions.getLength()==maxOrder ) { // FIXME needs special case for sequences shorter than maxOrder
|
---|
| 43 | trie.add(latestActions.getLast(maxOrder));
|
---|
| 44 | }
|
---|
| 45 | }
|
---|
| 46 | }
|
---|
| 47 |
|
---|
| 48 | public void printRandomWalk(Random r) {
|
---|
| 49 | IncompleteMemory<String> context = new IncompleteMemory<String>(maxOrder-1);
|
---|
| 50 |
|
---|
| 51 | context.add(initialSymbol);
|
---|
| 52 |
|
---|
| 53 | String currentState = initialSymbol;
|
---|
| 54 |
|
---|
| 55 | Console.println(currentState);
|
---|
| 56 | while(!endSymbol.equals(currentState)) {
|
---|
| 57 | double randVal = r.nextDouble();
|
---|
| 58 | double probSum = 0.0;
|
---|
| 59 | List<String> currentContext = context.getLast(maxOrder);
|
---|
| 60 | for( String symbol : knownSymbols ) {
|
---|
| 61 | probSum += getProbability(currentContext, symbol);
|
---|
| 62 | if( probSum>=randVal ) {
|
---|
| 63 | currentContext.add(symbol);
|
---|
| 64 | currentState = symbol;
|
---|
| 65 | Console.println(currentState);
|
---|
| 66 | break;
|
---|
| 67 | }
|
---|
| 68 | }
|
---|
| 69 | }
|
---|
| 70 | }
|
---|
| 71 |
|
---|
| 72 | private double getProbability(List<String> context, String symbol) {
|
---|
| 73 | double result = 0.0;
|
---|
| 74 | int countContextSymbol = 0;
|
---|
| 75 | List<String> contextSuffix = trie.getContextSuffix(context);
|
---|
| 76 | if( contextSuffix.isEmpty() ) {
|
---|
| 77 | result = 1.0d / knownSymbols.size();
|
---|
| 78 | } else {
|
---|
| 79 | countContextSymbol = trie.getCount(contextSuffix, symbol);
|
---|
| 80 | List<String> followers = trie.getFollowingSymbols(contextSuffix);
|
---|
| 81 | int countContextFollowers = 0;
|
---|
| 82 | for( String follower : followers ) {
|
---|
| 83 | countContextFollowers += trie.getCount(contextSuffix, follower);
|
---|
| 84 | }
|
---|
| 85 |
|
---|
| 86 | if( followers.isEmpty() ) {
|
---|
| 87 | throw new AssertionError("Invalid return value of getContextSuffix!");
|
---|
| 88 | }
|
---|
| 89 | if( countContextSymbol!=0 ) {
|
---|
| 90 | result = ((double) countContextSymbol) / (followers.size()+countContextFollowers);
|
---|
| 91 | } else { // escape
|
---|
| 92 | double probEscape = ((double) followers.size()) / (followers.size()+countContextFollowers);
|
---|
| 93 | contextSuffix.remove(0);
|
---|
| 94 | double probSuffix = getProbability(contextSuffix, symbol);
|
---|
| 95 | result = probEscape*probSuffix;
|
---|
| 96 | }
|
---|
| 97 | }
|
---|
| 98 |
|
---|
| 99 | return result;
|
---|
| 100 | }
|
---|
| 101 |
|
---|
| 102 | @Override
|
---|
| 103 | public String toString() {
|
---|
| 104 | return trie.toString();
|
---|
| 105 | }
|
---|
| 106 | }
|
---|