source: trunk/EventBenchCore/src/de/ugoe/cs/eventbench/ppm/PredictionByPartialMatch.java @ 5

Last change on this file since 5 was 5, checked in by sherbold, 13 years ago
  • major debugging of PPM and Trie. Results are now correct, but both classes need major refactorings
File size: 8.2 KB
Line 
1package de.ugoe.cs.eventbench.ppm;
2
3import java.util.ArrayList;
4import java.util.LinkedHashSet;
5import java.util.LinkedList;
6import java.util.List;
7import java.util.Random;
8import java.util.Set;
9
10import de.ugoe.cs.eventbench.data.Event;
11import de.ugoe.cs.eventbench.markov.IncompleteMemory;
12import de.ugoe.cs.util.console.Console;
13
14public class PredictionByPartialMatch {
15       
16        private String initialSymbol = "GS";
17        private String endSymbol = "GE";
18       
19        private int maxOrder = 3;
20       
21        private Trie<String> trie;
22       
23        private Set<String> knownSymbols;
24       
25        private double probEscape = 0.2d; // TODO getter/setter - steering parameter!
26       
27        private Random r = new Random(); // TODO should be defined in the constructor
28       
29        // the training is basically the generation of the trie
30        public void train(List<List<Event<?>>> sequences) {
31                trie = new Trie<String>();
32                knownSymbols = new LinkedHashSet<String>();
33                knownSymbols.add(initialSymbol);
34                knownSymbols.add(endSymbol);
35               
36                for(List<Event<?>> sequence : sequences) {
37                        List<String> stringSequence = new LinkedList<String>();
38                        stringSequence.add(initialSymbol);
39                        for( Event<?> event : sequence ) {
40                                stringSequence.add(event.getStandardId());
41                        }
42                        stringSequence.add(endSymbol);
43                       
44                        trainStringTrie(stringSequence);
45                }
46        }
47       
48        private void trainStringTrie(List<String> sequence) {
49                knownSymbols = new LinkedHashSet<String>();             
50                IncompleteMemory<String> latestActions = new IncompleteMemory<String>(maxOrder);
51                int i=0;
52                for(String currentAction : sequence) {
53                        String currentId = currentAction;
54                        latestActions.add(currentId);
55                        knownSymbols.add(currentId);
56                        i++;
57                        if( i>=maxOrder ) {
58                                trie.add(latestActions.getLast(maxOrder));
59                        }
60                }
61                int sequenceLength = sequence.size();
62                for( int j=maxOrder-1 ; j>0 ; j-- ) {
63                        trie.add(sequence.subList(sequenceLength-j, sequenceLength));
64                }
65        }
66       
67        public void testStuff() {
68                // basically an inline unit test without assertions but manual observation
69                List<String> list = new ArrayList<String>();
70                list.add(initialSymbol);
71                list.add("a");
72                list.add("b");
73                list.add("r");
74                list.add("a");
75                list.add("c");
76                list.add("a");
77                list.add("d");
78                list.add("a");
79                list.add("b");
80                list.add("r");
81                list.add("a");
82                list.add(endSymbol);
83               
84                PredictionByPartialMatch model = new PredictionByPartialMatch();
85                model.trie = new Trie<String>();
86                model.trainStringTrie(list);
87                model.trie.display();
88                Console.println("------------------------");
89                model.randomSequence();/*
90                Console.println("------------------------");
91                model.randomSequence();
92                Console.println("------------------------");
93                model.randomSequence();
94                Console.println("------------------------");*/
95               
96                List<String> context = new ArrayList<String>();
97                String symbol = "a";
98                // expected: 5
99                Console.traceln(""+model.trie.getCount(context, symbol));
100               
101                // expected: 0
102                context.add("b");
103                Console.traceln(""+model.trie.getCount(context, symbol));
104               
105                // expected: 2
106                context.add("r");
107                Console.traceln(""+model.trie.getCount(context, symbol));
108               
109                // exptected: [b, r]
110                context = new ArrayList<String>();
111                context.add("a");
112                context.add("b");
113                context.add("r");
114                Console.traceln(model.trie.getContextSuffix(context).toString());
115               
116                // exptected: []
117                context = new ArrayList<String>();
118                context.add("e");
119                Console.traceln(model.trie.getContextSuffix(context).toString());
120               
121                // exptected: {a, b, c, d, r}
122                context = new ArrayList<String>();
123                Console.traceln(model.trie.getFollowingSymbols(context).toString());
124               
125                // exptected: {b, c, d}
126                context = new ArrayList<String>();
127                context.add("a");
128                Console.traceln(model.trie.getFollowingSymbols(context).toString());
129               
130                // exptected: []
131                context = new ArrayList<String>();
132                context.add("a");
133                context.add("b");
134                context.add("r");
135                Console.traceln(model.trie.getFollowingSymbols(context).toString());
136        }
137       
138        // TODO needs to be changed from String to <? extends Event>
139        public List<String> randomSequence() {
140                List<String> sequence = new LinkedList<String>();
141               
142                IncompleteMemory<String> context = new IncompleteMemory<String>(maxOrder-1);
143                context.add(initialSymbol);
144                sequence.add(initialSymbol);
145               
146                String currentState = initialSymbol;
147               
148                Console.println(currentState);
149                while(!endSymbol.equals(currentState)) {
150                        double randVal = r.nextDouble();
151                        double probSum = 0.0;
152                        List<String> currentContext = context.getLast(maxOrder);
153                        for( String symbol : knownSymbols ) {
154                                probSum += getProbability(currentContext, symbol);
155                                if( probSum>=randVal ) {
156                                        context.add(symbol);
157                                        currentState = symbol;
158                                        sequence.add(currentState);
159                                        break;
160                                }
161                        }
162                }
163                return sequence;
164        }
165       
166        /*public void printRandomWalk(Random r) {
167                IncompleteMemory<String> context = new IncompleteMemory<String>(maxOrder-1);
168               
169                context.add(initialSymbol);
170               
171                String currentState = initialSymbol;
172               
173                Console.println(currentState);
174                while(!endSymbol.equals(currentState)) {
175                        double randVal = r.nextDouble();
176                        double probSum = 0.0;
177                        List<String> currentContext = context.getLast(maxOrder);
178                        // DEBUG //
179                        Console.traceln("Context: " + currentContext.toString());
180                        double tmpSum = 0.0d;
181                        for( String symbol : knownSymbols ) {
182                                double prob = getProbability(currentContext, symbol);
183                                tmpSum += prob;
184                                Console.traceln(symbol + ": " + prob);
185                        }
186                        Console.traceln("Sum: " + tmpSum);
187                        // DEBUG-END //
188                        for( String symbol : knownSymbols ) {
189                                probSum += getProbability(currentContext, symbol);
190                                if( probSum>=randVal-0.3 ) {
191                                        context.add(symbol);
192                                        currentState = symbol;
193                                        Console.println(currentState);
194                                        break;
195                                }
196                        }
197                }
198        }*/
199       
200        private double getProbability(List<String> context, String symbol) {
201                // FIXME needs exception handling for unknown symbols
202                // if the symbol is not contained in the trie, context.remove(0) will fail
203                double result = 0.0d;
204                double resultCurrentContex = 0.0d;
205                double resultShorterContex = 0.0d;
206               
207                List<String> contextCopy = new LinkedList<String>(context); // defensive copy
208
209       
210                List<String> followers = trie.getFollowingSymbols(contextCopy); // \Sigma'
211                int sumCountFollowers = 0; // N(s\sigma')
212                for( String follower : followers ) {
213                        sumCountFollowers += trie.getCount(contextCopy, follower);
214                }
215               
216                int countSymbol = trie.getCount(contextCopy, symbol); // N(s\sigma)
217                if( contextCopy.size()==0 ) {
218                        resultCurrentContex = ((double) countSymbol) / sumCountFollowers;
219                } else {
220                        resultCurrentContex = ((double) countSymbol / sumCountFollowers)*(1-probEscape);
221                        contextCopy.remove(0);
222                        double probSuffix = getProbability(contextCopy, symbol);
223                        if( followers.size()==0 ) {
224                                resultShorterContex = probSuffix;
225                        } else {
226                                resultShorterContex = probEscape*probSuffix;
227                        }
228                }
229                result = resultCurrentContex+resultShorterContex;
230               
231                return result;
232        }
233       
234        /*
235        private double getProbability(List<String> context, String symbol) {
236                double result = 0.0;
237                int countContextSymbol = 0;
238                List<String> contextSuffix = trie.getContextSuffix(context);
239                if( contextSuffix.isEmpty() ) {
240                        // unobserved context! everything is possible... assuming identical distribution
241                        result = 1.0d / knownSymbols.size(); // why 1.0 and not N(symbol)
242                } else {
243                        countContextSymbol = trie.getCount(contextSuffix, symbol);
244                        List<String> followers = trie.getFollowingSymbols(contextSuffix);
245                        int countContextFollowers = 0;
246                        for( String follower : followers ) {
247                                countContextFollowers += trie.getCount(contextSuffix, follower);
248                        }
249                       
250                        if( followers.isEmpty() ) {
251                                throw new AssertionError("Invalid return value of trie.getContextSuffix()!");
252                        }
253                        if( countContextSymbol!=0 ) {
254                                result = ((double) countContextSymbol) / (followers.size()+countContextFollowers);
255                        } else { // escape
256                                double probEscape = ((double) followers.size()) / (followers.size()+countContextFollowers);
257                                contextSuffix.remove(0);
258                                double probSuffix = getProbability(contextSuffix, symbol);
259                                result = probEscape*probSuffix;
260                        }
261                }
262
263                return result;
264        }*/
265       
266        @Override
267        public String toString() {
268                return trie.toString();
269        }
270}
Note: See TracBrowser for help on using the repository browser.