1 | package de.ugoe.cs.eventbench.ppm;
|
---|
2 |
|
---|
3 | import java.util.LinkedHashSet;
|
---|
4 | import java.util.List;
|
---|
5 | import java.util.Random;
|
---|
6 | import java.util.Set;
|
---|
7 |
|
---|
8 | import de.ugoe.cs.eventbench.data.Event;
|
---|
9 | import de.ugoe.cs.eventbench.markov.IncompleteMemory;
|
---|
10 | import de.ugoe.cs.util.console.Console;
|
---|
11 |
|
---|
12 | public class PredictionByPartialMatch {
|
---|
13 |
|
---|
14 | private String initialSymbol = "GLOBALSTARTSTATE";
|
---|
15 | private String endSymbol = "GLOBALENDSTATE";
|
---|
16 |
|
---|
17 | private int maxOrder = 3;
|
---|
18 |
|
---|
19 | private Trie<String> trie;
|
---|
20 |
|
---|
21 | private Set<String> knownSymbols;
|
---|
22 |
|
---|
23 | // the training is basically the generation of the trie
|
---|
24 | public void train(List<List<Event<?>>> sequences) {
|
---|
25 | trie = new Trie<String>();
|
---|
26 | knownSymbols = new LinkedHashSet<String>();
|
---|
27 | knownSymbols.add(initialSymbol);
|
---|
28 | knownSymbols.add(endSymbol);
|
---|
29 |
|
---|
30 | for(List<Event<?>> sequence : sequences) {
|
---|
31 | IncompleteMemory<String> latestActions = new IncompleteMemory<String>(maxOrder); // TODO need to check if it should be maxOrder+1
|
---|
32 | latestActions.add(initialSymbol);
|
---|
33 | for(Event<?> currentAction : sequence) {
|
---|
34 | String currentId = currentAction.getStandardId();
|
---|
35 | latestActions.add(currentId);
|
---|
36 | knownSymbols.add(currentId);
|
---|
37 | if( latestActions.getLength()==maxOrder ) { // FIXME needs special case for sequences shorter than maxOrder
|
---|
38 | trie.add(latestActions.getLast(maxOrder));
|
---|
39 | }
|
---|
40 | }
|
---|
41 | latestActions.add(endSymbol);
|
---|
42 | if( latestActions.getLength()==maxOrder ) { // FIXME needs special case for sequences shorter than maxOrder
|
---|
43 | trie.add(latestActions.getLast(maxOrder));
|
---|
44 | }
|
---|
45 | }
|
---|
46 | }
|
---|
47 |
|
---|
48 | public void printRandomWalk(Random r) {
|
---|
49 | IncompleteMemory<String> context = new IncompleteMemory<String>(maxOrder-1);
|
---|
50 |
|
---|
51 | context.add(initialSymbol);
|
---|
52 |
|
---|
53 | String currentState = initialSymbol;
|
---|
54 |
|
---|
55 | Console.println(currentState);
|
---|
56 | while(!endSymbol.equals(currentState)) {
|
---|
57 | double randVal = r.nextDouble();
|
---|
58 | double probSum = 0.0;
|
---|
59 | List<String> currentContext = context.getLast(maxOrder);
|
---|
60 | for( String symbol : knownSymbols ) {
|
---|
61 | probSum += getProbability(currentContext, symbol);
|
---|
62 | if( probSum>=randVal ) {
|
---|
63 | currentContext.add(symbol);
|
---|
64 | currentState = symbol;
|
---|
65 | Console.println(currentState);
|
---|
66 | break;
|
---|
67 | }
|
---|
68 | }
|
---|
69 | }
|
---|
70 | }
|
---|
71 |
|
---|
72 | private double getProbability(List<String> context, String symbol) {
|
---|
73 | double result = 0.0;
|
---|
74 | int countContextSymbol = 0;
|
---|
75 | List<String> contextSuffix = trie.getContextSuffix(context);
|
---|
76 | if( contextSuffix.isEmpty() ) {
|
---|
77 | result = 1.0d / knownSymbols.size();
|
---|
78 | } else {
|
---|
79 | countContextSymbol = trie.getCount(contextSuffix, symbol);
|
---|
80 | List<String> followers = trie.getFollowingSymbols(contextSuffix);
|
---|
81 | int countContextFollowers = 0;
|
---|
82 | for( String follower : followers ) {
|
---|
83 | countContextFollowers += trie.getCount(contextSuffix, follower);
|
---|
84 | }
|
---|
85 |
|
---|
86 | if( followers.isEmpty() ) {
|
---|
87 | throw new AssertionError("Invalid return value of getContextSuffix!");
|
---|
88 | }
|
---|
89 | if( countContextSymbol!=0 ) {
|
---|
90 | result = ((double) countContextSymbol) / (followers.size()+countContextFollowers);
|
---|
91 | } else { // escape
|
---|
92 | double probEscape = ((double) followers.size()) / (followers.size()+countContextFollowers);
|
---|
93 | contextSuffix.remove(0);
|
---|
94 | double probSuffix = getProbability(contextSuffix, symbol);
|
---|
95 | result = probEscape*probSuffix;
|
---|
96 | }
|
---|
97 | }
|
---|
98 |
|
---|
99 | return result;
|
---|
100 | }
|
---|
101 |
|
---|
102 | @Override
|
---|
103 | public String toString() {
|
---|
104 | return trie.toString();
|
---|
105 | }
|
---|
106 | }
|
---|