source: trunk/EventBenchCore/src/de/ugoe/cs/eventbench/ppm/PredictionByPartialMatch.java @ 6

Last change on this file since 6 was 6, checked in by sherbold, 13 years ago
  • Cleanup of PPM and Trie
File size: 6.0 KB
Line 
1package de.ugoe.cs.eventbench.ppm;
2
3import java.util.ArrayList;
4import java.util.LinkedHashSet;
5import java.util.LinkedList;
6import java.util.List;
7import java.util.Random;
8import java.util.Set;
9
10import de.ugoe.cs.eventbench.data.Event;
11import de.ugoe.cs.eventbench.markov.IncompleteMemory;
12import de.ugoe.cs.util.console.Console;
13
14public class PredictionByPartialMatch {
15       
16        private String initialSymbol = "GS";
17        private String endSymbol = "GE";
18       
19        private int maxOrder = 3;
20       
21        private Trie<String> trie;
22       
23        private Set<String> knownSymbols;
24       
25        private double probEscape = 0.2d; // TODO getter/setter - steering parameter!
26       
27        private Random r = new Random(); // TODO should be defined in the constructor
28       
29        // the training is basically the generation of the trie
30        public void train(List<List<Event<?>>> sequences) {
31                trie = new Trie<String>();
32                knownSymbols = new LinkedHashSet<String>();
33                knownSymbols.add(initialSymbol);
34                knownSymbols.add(endSymbol);
35               
36                for(List<Event<?>> sequence : sequences) {
37                        List<String> stringSequence = new LinkedList<String>();
38                        stringSequence.add(initialSymbol);
39                        for( Event<?> event : sequence ) {
40                                stringSequence.add(event.getStandardId());
41                        }
42                        stringSequence.add(endSymbol);
43                       
44                        trainStringTrie(stringSequence);
45                }
46        }
47       
48        private void trainStringTrie(List<String> sequence) {
49                knownSymbols = new LinkedHashSet<String>();             
50                IncompleteMemory<String> latestActions = new IncompleteMemory<String>(maxOrder);
51                int i=0;
52                for(String currentAction : sequence) {
53                        String currentId = currentAction;
54                        latestActions.add(currentId);
55                        knownSymbols.add(currentId);
56                        i++;
57                        if( i>=maxOrder ) {
58                                trie.add(latestActions.getLast(maxOrder));
59                        }
60                }
61                int sequenceLength = sequence.size();
62                for( int j=maxOrder-1 ; j>0 ; j-- ) {
63                        trie.add(sequence.subList(sequenceLength-j, sequenceLength));
64                }
65        }
66       
67        // TODO needs to be changed from String to <? extends Event>
68        public List<String> randomSequence() {
69                List<String> sequence = new LinkedList<String>();
70               
71                IncompleteMemory<String> context = new IncompleteMemory<String>(maxOrder-1);
72                context.add(initialSymbol);
73                sequence.add(initialSymbol);
74               
75                String currentState = initialSymbol;
76               
77                Console.println(currentState);
78                while(!endSymbol.equals(currentState)) {
79                        double randVal = r.nextDouble();
80                        double probSum = 0.0;
81                        List<String> currentContext = context.getLast(maxOrder);
82                        for( String symbol : knownSymbols ) {
83                                probSum += getProbability(currentContext, symbol);
84                                if( probSum>=randVal ) {
85                                        context.add(symbol);
86                                        currentState = symbol;
87                                        sequence.add(currentState);
88                                        break;
89                                }
90                        }
91                }
92                return sequence;
93        }
94               
95        private double getProbability(List<String> context, String symbol) {
96                double result = 0.0d;
97                double resultCurrentContex = 0.0d;
98                double resultShorterContex = 0.0d;
99               
100                List<String> contextCopy = new LinkedList<String>(context); // defensive copy
101
102       
103                List<String> followers = trie.getFollowingSymbols(contextCopy); // \Sigma'
104                int sumCountFollowers = 0; // N(s\sigma')
105                for( String follower : followers ) {
106                        sumCountFollowers += trie.getCount(contextCopy, follower);
107                }
108               
109                int countSymbol = trie.getCount(contextCopy, symbol); // N(s\sigma)
110                if( contextCopy.size()==0 ) {
111                        resultCurrentContex = ((double) countSymbol) / sumCountFollowers;
112                } else {
113                        resultCurrentContex = ((double) countSymbol / sumCountFollowers)*(1-probEscape);
114                        contextCopy.remove(0);
115                        double probSuffix = getProbability(contextCopy, symbol);
116                        if( followers.size()==0 ) {
117                                resultShorterContex = probSuffix;
118                        } else {
119                                resultShorterContex = probEscape*probSuffix;
120                        }
121                }
122                result = resultCurrentContex+resultShorterContex;
123               
124                return result;
125        }
126       
127        @Override
128        public String toString() {
129                return trie.toString();
130        }
131       
132        public void testStuff() {
133                // basically an inline unit test without assertions but manual observation
134                List<String> list = new ArrayList<String>();
135                list.add(initialSymbol);
136                list.add("a");
137                list.add("b");
138                list.add("r");
139                list.add("a");
140                list.add("c");
141                list.add("a");
142                list.add("d");
143                list.add("a");
144                list.add("b");
145                list.add("r");
146                list.add("a");
147                list.add(endSymbol);
148               
149                PredictionByPartialMatch model = new PredictionByPartialMatch();
150                model.trie = new Trie<String>();
151                model.trainStringTrie(list);
152                model.trie.display();
153                Console.println("------------------------");
154                model.randomSequence();/*
155                Console.println("------------------------");
156                model.randomSequence();
157                Console.println("------------------------");
158                model.randomSequence();
159                Console.println("------------------------");*/
160               
161                List<String> context = new ArrayList<String>();
162                String symbol = "a";
163                // expected: 5
164                Console.traceln(""+model.trie.getCount(context, symbol));
165               
166                // expected: 0
167                context.add("b");
168                Console.traceln(""+model.trie.getCount(context, symbol));
169               
170                // expected: 2
171                context.add("r");
172                Console.traceln(""+model.trie.getCount(context, symbol));
173               
174                // exptected: [b, r]
175                context = new ArrayList<String>();
176                context.add("a");
177                context.add("b");
178                context.add("r");
179                Console.traceln(model.trie.getContextSuffix(context).toString());
180               
181                // exptected: []
182                context = new ArrayList<String>();
183                context.add("e");
184                Console.traceln(model.trie.getContextSuffix(context).toString());
185               
186                // exptected: {a, b, c, d, r}
187                context = new ArrayList<String>();
188                Console.traceln(model.trie.getFollowingSymbols(context).toString());
189               
190                // exptected: {b, c, d}
191                context = new ArrayList<String>();
192                context.add("a");
193                Console.traceln(model.trie.getFollowingSymbols(context).toString());
194               
195                // exptected: []
196                context = new ArrayList<String>();
197                context.add("a");
198                context.add("b");
199                context.add("r");
200                Console.traceln(model.trie.getFollowingSymbols(context).toString());
201               
202                // exptected: 0.0d
203                context = new ArrayList<String>();
204                context.add("a");
205                Console.traceln(""+model.getProbability(context, "z"));
206        }
207}
Note: See TracBrowser for help on using the repository browser.