source: trunk/EventBenchCore/src/de/ugoe/cs/eventbench/ppm/PredictionByPartialMatch.java @ 9

Last change on this file since 9 was 9, checked in by sherbold, 13 years ago
  • moved training logic for single sequences from PPM to Trie
File size: 5.8 KB
Line 
1package de.ugoe.cs.eventbench.ppm;
2
3import java.util.LinkedList;
4import java.util.List;
5import java.util.Random;
6
7import de.ugoe.cs.eventbench.data.Event;
8import de.ugoe.cs.eventbench.markov.IncompleteMemory;
9
10public class PredictionByPartialMatch {
11       
12        private int maxOrder;
13       
14        private Trie<Event<?>> trie;
15       
16        private double probEscape;
17       
18        private final Random r;
19       
20        public PredictionByPartialMatch(int maxOrder, Random r) {
21                this(maxOrder, r, 0.1);
22        }
23       
24        public PredictionByPartialMatch(int maxOrder, Random r, double probEscape) {
25                this.maxOrder = maxOrder;
26                this.r = r; // TODO defensive copy instead?
27                this.probEscape = probEscape;
28        }
29       
30        public void setProbEscape(double probEscape) {
31                this.probEscape = probEscape;
32        }
33       
34        public double getProbEscape() {
35                return probEscape;
36        }
37       
38        // the training is basically the generation of the trie
39        public void train(List<List<Event<?>>> sequences) {
40                trie = new Trie<Event<?>>();
41               
42                for(List<Event<?>> sequence : sequences) {
43                        List<Event<?>> currentSequence = new LinkedList<Event<?>>(sequence); // defensive copy
44                        currentSequence.add(0, Event.STARTEVENT);
45                        currentSequence.add(Event.ENDEVENT);
46                       
47                        trie.train(currentSequence, maxOrder);
48                }
49        }
50       
51        /*private void addToTrie(List<Event<?>> sequence) {
52                if( knownSymbols==null ) {
53                        knownSymbols = new LinkedHashSet<Event<?>>();
54                }
55                IncompleteMemory<Event<?>> latestActions = new IncompleteMemory<Event<?>>(maxOrder);
56                int i=0;
57                for(Event<?> currentEvent : sequence) {
58                        latestActions.add(currentEvent);
59                        knownSymbols.add(currentEvent);
60                        i++;
61                        if( i>=maxOrder ) {
62                                trie.add(latestActions.getLast(maxOrder));
63                        }
64                }
65                int sequenceLength = sequence.size();
66                for( int j=maxOrder-1 ; j>0 ; j-- ) {
67                        trie.add(sequence.subList(sequenceLength-j, sequenceLength));
68                }
69        }*/
70       
71        public List<? extends Event<?>> randomSequence() {
72                List<Event<?>> sequence = new LinkedList<Event<?>>();
73               
74                IncompleteMemory<Event<?>> context = new IncompleteMemory<Event<?>>(maxOrder-1);
75                context.add(Event.STARTEVENT);
76               
77                Event<?> currentState = Event.STARTEVENT;
78               
79                boolean endFound = false;
80               
81                while(!endFound) {
82                        double randVal = r.nextDouble();
83                        double probSum = 0.0;
84                        List<Event<?>> currentContext = context.getLast(maxOrder);
85                        for( Event<?> symbol : trie.getKnownSymbols() ) {
86                                probSum += getProbability(currentContext, symbol);
87                                if( probSum>=randVal ) {
88                                        endFound = (symbol==Event.ENDEVENT);
89                                        if( !(symbol==Event.STARTEVENT || symbol==Event.ENDEVENT) ) {
90                                                // only add the symbol the sequence if it is not START or END
91                                                context.add(symbol);
92                                                currentState = symbol;
93                                                sequence.add(currentState);
94                                        }
95                                        break;
96                                }
97                        }
98                }
99                return sequence;
100        }
101               
102        private double getProbability(List<Event<?>> context, Event<?> symbol) {
103                double result = 0.0d;
104                double resultCurrentContex = 0.0d;
105                double resultShorterContex = 0.0d;
106               
107                List<Event<?>> contextCopy = new LinkedList<Event<?>>(context); // defensive copy
108
109       
110                List<Event<?>> followers = trie.getFollowingSymbols(contextCopy); // \Sigma'
111                int sumCountFollowers = 0; // N(s\sigma')
112                for( Event<?> follower : followers ) {
113                        sumCountFollowers += trie.getCount(contextCopy, follower);
114                }
115               
116                int countSymbol = trie.getCount(contextCopy, symbol); // N(s\sigma)
117                if( contextCopy.size()==0 ) {
118                        resultCurrentContex = ((double) countSymbol) / sumCountFollowers;
119                } else {
120                        resultCurrentContex = ((double) countSymbol / sumCountFollowers)*(1-probEscape);
121                        contextCopy.remove(0);
122                        double probSuffix = getProbability(contextCopy, symbol);
123                        if( followers.size()==0 ) {
124                                resultShorterContex = probSuffix;
125                        } else {
126                                resultShorterContex = probEscape*probSuffix;
127                        }
128                }
129                result = resultCurrentContex+resultShorterContex;
130               
131                return result;
132        }
133       
134        @Override
135        public String toString() {
136                return trie.toString();
137        }
138       
139        /*
140        public void testStuff() {
141                // basically an inline unit test without assertions but manual observation
142                List<String> list = new ArrayList<String>();
143                list.add(initialSymbol);
144                list.add("a");
145                list.add("b");
146                list.add("r");
147                list.add("a");
148                list.add("c");
149                list.add("a");
150                list.add("d");
151                list.add("a");
152                list.add("b");
153                list.add("r");
154                list.add("a");
155                list.add(endSymbol);
156               
157                PredictionByPartialMatch model = new PredictionByPartialMatch();
158                model.trie = new Trie<String>();
159                model.trainStringTrie(list);
160                model.trie.display();
161               
162                List<String> context = new ArrayList<String>();
163                String symbol = "a";
164                // expected: 5
165                Console.traceln(""+model.trie.getCount(context, symbol));
166               
167                // expected: 0
168                context.add("b");
169                Console.traceln(""+model.trie.getCount(context, symbol));
170               
171                // expected: 2
172                context.add("r");
173                Console.traceln(""+model.trie.getCount(context, symbol));
174               
175                // exptected: [b, r]
176                context = new ArrayList<String>();
177                context.add("a");
178                context.add("b");
179                context.add("r");
180                Console.traceln(model.trie.getContextSuffix(context).toString());
181               
182                // exptected: []
183                context = new ArrayList<String>();
184                context.add("e");
185                Console.traceln(model.trie.getContextSuffix(context).toString());
186               
187                // exptected: {a, b, c, d, r}
188                context = new ArrayList<String>();
189                Console.traceln(model.trie.getFollowingSymbols(context).toString());
190               
191                // exptected: {b, c, d}
192                context = new ArrayList<String>();
193                context.add("a");
194                Console.traceln(model.trie.getFollowingSymbols(context).toString());
195               
196                // exptected: []
197                context = new ArrayList<String>();
198                context.add("a");
199                context.add("b");
200                context.add("r");
201                Console.traceln(model.trie.getFollowingSymbols(context).toString());
202               
203                // exptected: 0.0d
204                context = new ArrayList<String>();
205                context.add("a");
206                Console.traceln(""+model.getProbability(context, "z"));
207        }*/
208}
Note: See TracBrowser for help on using the repository browser.