source: trunk/EventBenchCore/src/de/ugoe/cs/eventbench/models/TrieBasedModel.java

Last change on this file was 400, checked in by sherbold, 12 years ago
  • fixed bug that disallowed to generate sequences of a predefined maxLength and a valid end with the method de.ugoe.cs.eventbench.models.TrieBasedModel?.randomSequence(int maxLength, boolean validEnd).
  • Property svn:mime-type set to text/plain
File size: 10.5 KB
Line 
1package de.ugoe.cs.eventbench.models;
2
3import java.security.InvalidParameterException;
4import java.util.ArrayList;
5import java.util.Collection;
6import java.util.HashSet;
7import java.util.LinkedHashSet;
8import java.util.LinkedList;
9import java.util.List;
10import java.util.Random;
11import java.util.Set;
12
13import de.ugoe.cs.eventbench.data.Event;
14import de.ugoe.cs.eventbench.models.Trie.Edge;
15import de.ugoe.cs.eventbench.models.Trie.TrieVertex;
16import edu.uci.ics.jung.graph.Tree;
17
18/**
19 * <p>
20 * Implements a skeleton for stochastic processes that can calculate
21 * probabilities based on a trie. The skeleton provides all functionalities of
22 * {@link IStochasticProcess} except
23 * {@link IStochasticProcess#getProbability(List, Event)}.
24 * </p>
25 *
26 * @author Steffen Herbold
27 * @version 1.0
28 */
29public abstract class TrieBasedModel implements IStochasticProcess {
30
31        /**
32         * <p>
33         * Id for object serialization.
34         * </p>
35         */
36        private static final long serialVersionUID = 1L;
37
38        /**
39         * <p>
40         * The order of the trie, i.e., the maximum length of subsequences stored in
41         * the trie.
42         * </p>
43         */
44        protected int trieOrder;
45
46        /**
47         * <p>
48         * Trie on which the probability calculations are based.
49         * </p>
50         */
51        protected Trie<Event<?>> trie = null;
52
53        /**
54         * <p>
55         * Random number generator used by probabilistic sequence generation
56         * methods.
57         * </p>
58         */
59        protected final Random r;
60
61        /**
62         * <p>
63         * Constructor. Creates a new TrieBasedModel that can be used for stochastic
64         * processes with a Markov order less than or equal to {@code markovOrder}.
65         * </p>
66         *
67         * @param markovOrder
68         *            Markov order of the model
69         * @param r
70         *            random number generator used by probabilistic methods of the
71         *            class
72         * @throws InvalidParameterException
73         *             thrown if markovOrder is less than 0 or the random number
74         *             generator r is null
75         */
76        public TrieBasedModel(int markovOrder, Random r) {
77                super();
78                if (markovOrder < 0) {
79                        throw new InvalidParameterException(
80                                        "markov order must not be less than 0");
81                }
82                if (r == null) {
83                        throw new InvalidParameterException(
84                                        "random number generator r must not be null");
85                }
86                this.trieOrder = markovOrder + 1;
87                this.r = r;
88        }
89
90        /**
91         * <p>
92         * Trains the model by generating a trie from which probabilities are
93         * calculated. The trie is newly generated based solely on the passed
94         * sequences. If an existing model should only be updated, use
95         * {@link #update(Collection)} instead.
96         * </p>
97         *
98         * @param sequences
99         *            training data
100         * @throws InvalidParameterException
101         *             thrown is sequences is null
102         */
103        public void train(Collection<List<? extends Event<?>>> sequences) {
104                trie = null;
105                update(sequences);
106        }
107
108        /**
109         * <p>
110         * Trains the model by updating the trie from which the probabilities are
111         * calculated. This function updates an existing trie. In case no trie
112         * exists yet, a new trie is generated and the function behaves like
113         * {@link #train(Collection)}.
114         * </p>
115         *
116         * @param sequences
117         *            training data
118         * @throws InvalidParameterException
119         *             thrown is sequences is null
120         */
121        public void update(Collection<List<? extends Event<?>>> sequences) {
122                if (sequences == null) {
123                        throw new InvalidParameterException("sequences must not be null");
124                }
125                if (trie == null) {
126                        trie = new Trie<Event<?>>();
127                }
128                for (List<? extends Event<?>> sequence : sequences) {
129                        List<Event<?>> currentSequence = new LinkedList<Event<?>>(sequence); // defensive
130                                                                                                                                                                        // copy
131                        currentSequence.add(0, Event.STARTEVENT);
132                        currentSequence.add(Event.ENDEVENT);
133
134                        trie.train(currentSequence, trieOrder);
135                }
136        }
137
138        /*
139         * (non-Javadoc)
140         *
141         * @see de.ugoe.cs.eventbench.models.IStochasticProcess#randomSequence()
142         */
143        @Override
144        public List<? extends Event<?>> randomSequence() {
145                return randomSequence(Integer.MAX_VALUE, true);
146        }
147
148        /*
149         * (non-Javadoc)
150         *
151         * @see de.ugoe.cs.eventbench.models.IStochasticProcess#randomSequence()
152         */
153        @Override
154        public List<? extends Event<?>> randomSequence(int maxLength,
155                        boolean validEnd) {
156                List<Event<?>> sequence = new LinkedList<Event<?>>();
157                if (trie != null) {
158                        boolean endFound = false;
159                        while (!endFound) { // outer loop for length checking
160                                sequence = new LinkedList<Event<?>>();
161                                IncompleteMemory<Event<?>> context = new IncompleteMemory<Event<?>>(
162                                                trieOrder - 1);
163                                context.add(Event.STARTEVENT);
164
165                                while (!endFound && sequence.size() <= maxLength) {
166                                        double randVal = r.nextDouble();
167                                        double probSum = 0.0;
168                                        List<Event<?>> currentContext = context.getLast(trieOrder);
169                                        for (Event<?> symbol : trie.getKnownSymbols()) {
170                                                probSum += getProbability(currentContext, symbol);
171                                                if (probSum >= randVal) {
172                                                        if (!(Event.STARTEVENT.equals(symbol) || Event.ENDEVENT
173                                                                        .equals(symbol))) {
174                                                                // only add the symbol the sequence if it is not
175                                                                // START or END
176                                                                context.add(symbol);
177                                                                sequence.add(symbol);
178                                                        }
179                                                        endFound = (Event.ENDEVENT.equals(symbol))
180                                                                        || (!validEnd && sequence.size() == maxLength);
181                                                        break;
182                                                }
183                                        }
184                                }
185                        }
186                }
187                return sequence;
188        }
189
190        /**
191         * <p>
192         * Returns a Dot representation of the internal trie.
193         * </p>
194         *
195         * @return dot representation of the internal trie
196         */
197        public String getTrieDotRepresentation() {
198                if (trie == null) {
199                        return "";
200                } else {
201                        return trie.getDotRepresentation();
202                }
203        }
204
205        /**
206         * <p>
207         * Returns a {@link Tree} of the internal trie that can be used for
208         * visualization.
209         * </p>
210         *
211         * @return {@link Tree} depicting the internal trie
212         */
213        public Tree<TrieVertex, Edge> getTrieGraph() {
214                if (trie == null) {
215                        return null;
216                } else {
217                        return trie.getGraph();
218                }
219        }
220
221        /**
222         * <p>
223         * The string representation of the model is {@link Trie#toString()} of
224         * {@link #trie}.
225         * </p>
226         *
227         * @see java.lang.Object#toString()
228         */
229        @Override
230        public String toString() {
231                if (trie == null) {
232                        return "";
233                } else {
234                        return trie.toString();
235                }
236        }
237
238        /*
239         * (non-Javadoc)
240         *
241         * @see de.ugoe.cs.eventbench.models.IStochasticProcess#getNumStates()
242         */
243        @Override
244        public int getNumSymbols() {
245                if (trie == null) {
246                        return 0;
247                } else {
248                        return trie.getNumSymbols();
249                }
250        }
251
252        /*
253         * (non-Javadoc)
254         *
255         * @see de.ugoe.cs.eventbench.models.IStochasticProcess#getStateStrings()
256         */
257        @Override
258        public String[] getSymbolStrings() {
259                if (trie == null) {
260                        return new String[0];
261                }
262                String[] stateStrings = new String[getNumSymbols()];
263                int i = 0;
264                for (Event<?> symbol : trie.getKnownSymbols()) {
265                        if (symbol.toString() == null) {
266                                stateStrings[i] = "null";
267                        } else {
268                                stateStrings[i] = symbol.toString();
269                        }
270                        i++;
271                }
272                return stateStrings;
273        }
274
275        /*
276         * (non-Javadoc)
277         *
278         * @see de.ugoe.cs.eventbench.models.IStochasticProcess#getEvents()
279         */
280        @Override
281        public Collection<? extends Event<?>> getEvents() {
282                if (trie == null) {
283                        return new HashSet<Event<?>>();
284                } else {
285                        return trie.getKnownSymbols();
286                }
287        }
288
289        /*
290         * (non-Javadoc)
291         *
292         * @see
293         * de.ugoe.cs.eventbench.models.IStochasticProcess#generateSequences(int)
294         */
295        @Override
296        public Collection<List<? extends Event<?>>> generateSequences(int length) {
297                return generateSequences(length, false);
298        }
299
300        /*
301         * (non-Javadoc)
302         *
303         * @see
304         * de.ugoe.cs.eventbench.models.IStochasticProcess#generateSequences(int,
305         * boolean)
306         */
307        @Override
308        public Set<List<? extends Event<?>>> generateSequences(int length,
309                        boolean fromStart) {
310                Set<List<? extends Event<?>>> sequenceSet = new LinkedHashSet<List<? extends Event<?>>>();
311                if (length < 1) {
312                        throw new InvalidParameterException(
313                                        "Length of generated subsequences must be at least 1.");
314                }
315                if (length == 1) {
316                        if (fromStart) {
317                                List<Event<?>> subSeq = new LinkedList<Event<?>>();
318                                subSeq.add(Event.STARTEVENT);
319                                sequenceSet.add(subSeq);
320                        } else {
321                                for (Event<?> event : getEvents()) {
322                                        List<Event<?>> subSeq = new LinkedList<Event<?>>();
323                                        subSeq.add(event);
324                                        sequenceSet.add(subSeq);
325                                }
326                        }
327                        return sequenceSet;
328                }
329                Collection<? extends Event<?>> events = getEvents();
330                Collection<List<? extends Event<?>>> seqsShorter = generateSequences(
331                                length - 1, fromStart);
332                for (Event<?> event : events) {
333                        for (List<? extends Event<?>> seqShorter : seqsShorter) {
334                                Event<?> lastEvent = event;
335                                if (getProbability(seqShorter, lastEvent) > 0.0) {
336                                        List<Event<?>> subSeq = new ArrayList<Event<?>>(seqShorter);
337                                        subSeq.add(lastEvent);
338                                        sequenceSet.add(subSeq);
339                                }
340                        }
341                }
342                return sequenceSet;
343        }
344
345        /*
346         * (non-Javadoc)
347         *
348         * @see
349         * de.ugoe.cs.eventbench.models.IStochasticProcess#generateValidSequences
350         * (int)
351         */
352        @Override
353        public Collection<List<? extends Event<?>>> generateValidSequences(
354                        int length) {
355                // check for min-length implicitly done by generateSequences
356                Collection<List<? extends Event<?>>> allSequences = generateSequences(
357                                length, true);
358                Collection<List<? extends Event<?>>> validSequences = new LinkedHashSet<List<? extends Event<?>>>();
359                for (List<? extends Event<?>> sequence : allSequences) {
360                        if (sequence.size() == length
361                                        && Event.ENDEVENT.equals(sequence.get(sequence.size() - 1))) {
362                                validSequences.add(sequence);
363                        }
364                }
365                return validSequences;
366        }
367
368        /*
369         * (non-Javadoc)
370         *
371         * @see
372         * de.ugoe.cs.eventbench.models.IStochasticProcess#getProbability(java.util
373         * .List)
374         */
375        @Override
376        public double getProbability(List<? extends Event<?>> sequence) {
377                if (sequence == null) {
378                        throw new InvalidParameterException("sequence must not be null");
379                }
380                double prob = 1.0;
381                List<Event<?>> context = new LinkedList<Event<?>>();
382                for (Event<?> event : sequence) {
383                        prob *= getProbability(context, event);
384                        context.add(event);
385                }
386                return prob;
387        }
388
389        /*
390         * (non-Javadoc)
391         *
392         * @see de.ugoe.cs.eventbench.models.IStochasticProcess#getNumFOMStates()
393         */
394        @Override
395        public int getNumFOMStates() {
396                if (trie == null) {
397                        return 0;
398                } else {
399                        return trie.getNumLeafAncestors();
400                }
401        }
402
403        /*
404         * (non-Javadoc)
405         *
406         * @see de.ugoe.cs.eventbench.models.IStochasticProcess#getNumTransitions()
407         */
408        @Override
409        public int getNumTransitions() {
410                if (trie == null) {
411                        return 0;
412                } else {
413                        return trie.getNumLeafs();
414                }
415        }
416}
Note: See TracBrowser for help on using the repository browser.