source: trunk/EventBenchCore/src/de/ugoe/cs/eventbench/ppm/PredictionByPartialMatch.java @ 7

Last change on this file since 7 was 7, checked in by sherbold, 13 years ago
  • changed PPM to work with Event<?> instead of String
File size: 5.9 KB
Line 
1package de.ugoe.cs.eventbench.ppm;
2
3import java.util.LinkedHashSet;
4import java.util.LinkedList;
5import java.util.List;
6import java.util.Random;
7import java.util.Set;
8
9import de.ugoe.cs.eventbench.data.Event;
10import de.ugoe.cs.eventbench.markov.IncompleteMemory;
11
12public class PredictionByPartialMatch {
13       
14        private int maxOrder = 3;
15       
16        private Trie<Event<?>> trie;
17       
18        private Set<Event<?>> knownSymbols;
19       
20        private double probEscape;
21       
22        private final Random r;
23       
24        public PredictionByPartialMatch(Random r) {
25                this(r, 0.1);
26        }
27       
28        public PredictionByPartialMatch(Random r, double probEscape) {
29                this.r = r; // TODO defensive copy instead?
30                this.probEscape = probEscape;
31        }
32       
33        public void setProbEscape(double probEscape) {
34                this.probEscape = probEscape;
35        }
36       
37        public double getProbEscape() {
38                return probEscape;
39        }
40       
41        // the training is basically the generation of the trie
42        public void train(List<List<Event<?>>> sequences) {
43                trie = new Trie<Event<?>>();
44                knownSymbols = new LinkedHashSet<Event<?>>();
45                knownSymbols.add(Event.STARTEVENT);
46                knownSymbols.add(Event.ENDEVENT);
47               
48                for(List<Event<?>> sequence : sequences) {
49                        List<Event<?>> currentSequence = new LinkedList<Event<?>>(sequence); // defensive copy
50                        currentSequence.add(0, Event.STARTEVENT);
51                        currentSequence.add(Event.ENDEVENT);
52                       
53                        addToTrie(currentSequence);
54                }
55        }
56       
57        private void addToTrie(List<Event<?>> sequence) {
58                if( knownSymbols==null ) {
59                        knownSymbols = new LinkedHashSet<Event<?>>();
60                }
61                IncompleteMemory<Event<?>> latestActions = new IncompleteMemory<Event<?>>(maxOrder);
62                int i=0;
63                for(Event<?> currentEvent : sequence) {
64                        latestActions.add(currentEvent);
65                        knownSymbols.add(currentEvent);
66                        i++;
67                        if( i>=maxOrder ) {
68                                trie.add(latestActions.getLast(maxOrder));
69                        }
70                }
71                int sequenceLength = sequence.size();
72                for( int j=maxOrder-1 ; j>0 ; j-- ) {
73                        trie.add(sequence.subList(sequenceLength-j, sequenceLength));
74                }
75        }
76       
77        public List<? extends Event<?>> randomSequence() {
78                List<Event<?>> sequence = new LinkedList<Event<?>>();
79               
80                IncompleteMemory<Event<?>> context = new IncompleteMemory<Event<?>>(maxOrder-1);
81                context.add(Event.STARTEVENT);
82               
83                Event<?> currentState = Event.STARTEVENT;
84               
85                boolean endFound = false;
86               
87                while(!endFound) {
88                        double randVal = r.nextDouble();
89                        double probSum = 0.0;
90                        List<Event<?>> currentContext = context.getLast(maxOrder);
91                        for( Event<?> symbol : knownSymbols ) {
92                                probSum += getProbability(currentContext, symbol);
93                                if( probSum>=randVal ) {
94                                        endFound = (symbol==Event.ENDEVENT);
95                                        if( !(symbol==Event.STARTEVENT || symbol==Event.ENDEVENT) ) {
96                                                // only add the symbol the sequence if it is not START or END
97                                                context.add(symbol);
98                                                currentState = symbol;
99                                                sequence.add(currentState);
100                                        }
101                                        break;
102                                }
103                        }
104                }
105                return sequence;
106        }
107               
108        private double getProbability(List<Event<?>> context, Event<?> symbol) {
109                double result = 0.0d;
110                double resultCurrentContex = 0.0d;
111                double resultShorterContex = 0.0d;
112               
113                List<Event<?>> contextCopy = new LinkedList<Event<?>>(context); // defensive copy
114
115       
116                List<Event<?>> followers = trie.getFollowingSymbols(contextCopy); // \Sigma'
117                int sumCountFollowers = 0; // N(s\sigma')
118                for( Event<?> follower : followers ) {
119                        sumCountFollowers += trie.getCount(contextCopy, follower);
120                }
121               
122                int countSymbol = trie.getCount(contextCopy, symbol); // N(s\sigma)
123                if( contextCopy.size()==0 ) {
124                        resultCurrentContex = ((double) countSymbol) / sumCountFollowers;
125                } else {
126                        resultCurrentContex = ((double) countSymbol / sumCountFollowers)*(1-probEscape);
127                        contextCopy.remove(0);
128                        double probSuffix = getProbability(contextCopy, symbol);
129                        if( followers.size()==0 ) {
130                                resultShorterContex = probSuffix;
131                        } else {
132                                resultShorterContex = probEscape*probSuffix;
133                        }
134                }
135                result = resultCurrentContex+resultShorterContex;
136               
137                return result;
138        }
139       
140        @Override
141        public String toString() {
142                return trie.toString();
143        }
144       
145        /*
146        public void testStuff() {
147                // basically an inline unit test without assertions but manual observation
148                List<String> list = new ArrayList<String>();
149                list.add(initialSymbol);
150                list.add("a");
151                list.add("b");
152                list.add("r");
153                list.add("a");
154                list.add("c");
155                list.add("a");
156                list.add("d");
157                list.add("a");
158                list.add("b");
159                list.add("r");
160                list.add("a");
161                list.add(endSymbol);
162               
163                PredictionByPartialMatch model = new PredictionByPartialMatch();
164                model.trie = new Trie<String>();
165                model.trainStringTrie(list);
166                model.trie.display();
167               
168                List<String> context = new ArrayList<String>();
169                String symbol = "a";
170                // expected: 5
171                Console.traceln(""+model.trie.getCount(context, symbol));
172               
173                // expected: 0
174                context.add("b");
175                Console.traceln(""+model.trie.getCount(context, symbol));
176               
177                // expected: 2
178                context.add("r");
179                Console.traceln(""+model.trie.getCount(context, symbol));
180               
181                // exptected: [b, r]
182                context = new ArrayList<String>();
183                context.add("a");
184                context.add("b");
185                context.add("r");
186                Console.traceln(model.trie.getContextSuffix(context).toString());
187               
188                // exptected: []
189                context = new ArrayList<String>();
190                context.add("e");
191                Console.traceln(model.trie.getContextSuffix(context).toString());
192               
193                // exptected: {a, b, c, d, r}
194                context = new ArrayList<String>();
195                Console.traceln(model.trie.getFollowingSymbols(context).toString());
196               
197                // exptected: {b, c, d}
198                context = new ArrayList<String>();
199                context.add("a");
200                Console.traceln(model.trie.getFollowingSymbols(context).toString());
201               
202                // exptected: []
203                context = new ArrayList<String>();
204                context.add("a");
205                context.add("b");
206                context.add("r");
207                Console.traceln(model.trie.getFollowingSymbols(context).toString());
208               
209                // exptected: 0.0d
210                context = new ArrayList<String>();
211                context.add("a");
212                Console.traceln(""+model.getProbability(context, "z"));
213        }*/
214}
Note: See TracBrowser for help on using the repository browser.