source: trunk/EventBenchCore/src/de/ugoe/cs/eventbench/models/Trie.java @ 397

Last change on this file since 397 was 397, checked in by sherbold, 12 years ago
  • implemented equals and hashCode for de.ugoe.cs.eventbench.models.Trie and de.ugoe.cs.eventbench.models.TrieNode?
File size: 11.0 KB
Line 
1package de.ugoe.cs.eventbench.models;
2
3import java.io.Serializable;
4import java.security.InvalidParameterException;
5import java.util.Collection;
6import java.util.HashSet;
7import java.util.LinkedHashSet;
8import java.util.LinkedList;
9import java.util.List;
10import java.util.Set;
11
12import de.ugoe.cs.util.StringTools;
13
14import edu.uci.ics.jung.graph.DelegateTree;
15import edu.uci.ics.jung.graph.Graph;
16import edu.uci.ics.jung.graph.Tree;
17
18/**
19 * <p>
20 * This class implements a <it>trie</it>, i.e., a tree of sequences that the
21 * occurence of subsequences up to a predefined length. This length is the trie
22 * order.
23 * </p>
24 *
25 * @author Steffen Herbold
26 *
27 * @param <T>
28 *            Type of the symbols that are stored in the trie.
29 *
30 * @see TrieNode
31 */
32public class Trie<T> implements IDotCompatible, Serializable {
33
34        /**
35         * <p>
36         * Id for object serialization.
37         * </p>
38         */
39        private static final long serialVersionUID = 1L;
40
41        /**
42         * <p>
43         * Collection of all symbols occuring in the trie.
44         * </p>
45         */
46        private Collection<T> knownSymbols;
47
48        /**
49         * <p>
50         * Reference to the root of the trie.
51         * </p>
52         */
53        private final TrieNode<T> rootNode;
54
55        /**
56         * <p>
57         * Contructor. Creates a new Trie.
58         * </p>
59         */
60        public Trie() {
61                rootNode = new TrieNode<T>();
62                knownSymbols = new LinkedHashSet<T>();
63        }
64
65        /**
66         * <p>
67         * Copy-Constructor. Creates a new Trie as the copy of other. The other trie
68         * must noch be null.
69         * </p>
70         *
71         * @param other
72         *            Trie that is copied
73         */
74        public Trie(Trie<T> other) {
75                if (other == null) {
76                        throw new InvalidParameterException("other trie must not be null");
77                }
78                rootNode = new TrieNode<T>(other.rootNode);
79                knownSymbols = new LinkedHashSet<T>(other.knownSymbols);
80        }
81
82        /**
83         * <p>
84         * Returns a collection of all symbols occuring in the trie.
85         * </p>
86         *
87         * @return symbols occuring in the trie
88         */
89        public Collection<T> getKnownSymbols() {
90                return new LinkedHashSet<T>(knownSymbols);
91        }
92
93        /**
94         * <p>
95         * Trains the current trie using the given sequence and adds all subsequence
96         * of length {@code maxOrder}.
97         * </p>
98         *
99         * @param sequence
100         *            sequence whose subsequences are added to the trie
101         * @param maxOrder
102         *            maximum length of the subsequences added to the trie
103         */
104        public void train(List<T> sequence, int maxOrder) {
105                if (maxOrder < 1) {
106                        return;
107                }
108                IncompleteMemory<T> latestActions = new IncompleteMemory<T>(maxOrder);
109                int i = 0;
110                for (T currentEvent : sequence) {
111                        latestActions.add(currentEvent);
112                        knownSymbols.add(currentEvent);
113                        i++;
114                        if (i >= maxOrder) {
115                                add(latestActions.getLast(maxOrder));
116                        }
117                }
118                int sequenceLength = sequence.size();
119                for (int j = maxOrder - 1; j > 0; j--) {
120                        add(sequence.subList(sequenceLength - j, sequenceLength));
121                }
122        }
123
124        /**
125         * <p>
126         * Adds a given subsequence to the trie and increases the counters
127         * accordingly.
128         * </p>
129         *
130         * @param subsequence
131         *            subsequence whose counters are increased
132         * @see TrieNode#add(List)
133         */
134        protected void add(List<T> subsequence) {
135                if (subsequence != null && !subsequence.isEmpty()) {
136                        knownSymbols.addAll(subsequence);
137                        subsequence = new LinkedList<T>(subsequence); // defensive copy!
138                        T firstSymbol = subsequence.get(0);
139                        TrieNode<T> node = getChildCreate(firstSymbol);
140                        node.add(subsequence);
141                }
142        }
143
144        /**
145         * <p>
146         * Returns the child of the root node associated with the given symbol or
147         * creates it if it does not exist yet.
148         * </p>
149         *
150         * @param symbol
151         *            symbol whose node is required
152         * @return node associated with the symbol
153         * @see TrieNode#getChildCreate(Object)
154         */
155        protected TrieNode<T> getChildCreate(T symbol) {
156                return rootNode.getChildCreate(symbol);
157        }
158
159        /**
160         * <p>
161         * Returns the child of the root node associated with the given symbol or
162         * null if it does not exist.
163         * </p>
164         *
165         * @param symbol
166         *            symbol whose node is required
167         * @return node associated with the symbol; null if no such node exists
168         * @see TrieNode#getChild(Object)
169         */
170        protected TrieNode<T> getChild(T symbol) {
171                return rootNode.getChild(symbol);
172        }
173
174        /**
175         * <p>
176         * Returns the number of occurences of the given sequence.
177         * </p>
178         *
179         * @param sequence
180         *            sequence whose number of occurences is required
181         * @return number of occurences of the sequence
182         */
183        public int getCount(List<T> sequence) {
184                int count = 0;
185                TrieNode<T> node = find(sequence);
186                if (node != null) {
187                        count = node.getCount();
188                }
189                return count;
190        }
191
192        /**
193         * <p>
194         * Returns the number of occurences of the given prefix and a symbol that
195         * follows it.<br>
196         * Convenience function to simplify usage of {@link #getCount(List)}.
197         * </p>
198         *
199         * @param sequence
200         *            prefix of the sequence
201         * @param follower
202         *            suffix of the sequence
203         * @return number of occurences of the sequence
204         * @see #getCount(List)
205         */
206        public int getCount(List<T> sequence, T follower) {
207                List<T> tmpSequence = new LinkedList<T>(sequence);
208                tmpSequence.add(follower);
209                return getCount(tmpSequence);
210
211        }
212
213        /**
214         * <p>
215         * Searches the trie for a given sequence and returns the node associated
216         * with the sequence or null if no such node is found.
217         * </p>
218         *
219         * @param sequence
220         *            sequence that is searched for
221         * @return node associated with the sequence
222         * @see TrieNode#find(List)
223         */
224        public TrieNode<T> find(List<T> sequence) {
225                if (sequence == null || sequence.isEmpty()) {
226                        return rootNode;
227                }
228                List<T> sequenceCopy = new LinkedList<T>(sequence);
229                TrieNode<T> result = null;
230                TrieNode<T> node = getChild(sequenceCopy.get(0));
231                if (node != null) {
232                        sequenceCopy.remove(0);
233                        result = node.find(sequenceCopy);
234                }
235                return result;
236        }
237
238        /**
239         * <p>
240         * Returns a collection of all symbols that follow a given sequence in the
241         * trie. In case the sequence is not found or no symbols follow the sequence
242         * the result will be empty.
243         * </p>
244         *
245         * @param sequence
246         *            sequence whose followers are returned
247         * @return symbols following the given sequence
248         * @see TrieNode#getFollowingSymbols()
249         */
250        public Collection<T> getFollowingSymbols(List<T> sequence) {
251                Collection<T> result = new LinkedList<T>();
252                TrieNode<T> node = find(sequence);
253                if (node != null) {
254                        result = node.getFollowingSymbols();
255                }
256                return result;
257        }
258
259        /**
260         * <p>
261         * Returns the longest suffix of the given context that is contained in the
262         * tree and whose children are leaves.
263         * </p>
264         *
265         * @param context
266         *            context whose suffix is searched for
267         * @return longest suffix of the context
268         */
269        public List<T> getContextSuffix(List<T> context) {
270                List<T> contextSuffix;
271                if (context != null) {
272                        contextSuffix = new LinkedList<T>(context); // defensive copy
273                } else {
274                        contextSuffix = new LinkedList<T>();
275                }
276                boolean suffixFound = false;
277
278                while (!suffixFound) {
279                        if (contextSuffix.isEmpty()) {
280                                suffixFound = true; // suffix is the empty word
281                        } else {
282                                TrieNode<T> node = find(contextSuffix);
283                                if (node != null) {
284                                        if (!node.getFollowingSymbols().isEmpty()) {
285                                                suffixFound = true;
286                                        }
287                                }
288                                if (!suffixFound) {
289                                        contextSuffix.remove(0);
290                                }
291                        }
292                }
293
294                return contextSuffix;
295        }
296
297        /**
298         * <p>
299         * Helper class for graph visualization of a trie.
300         * </p>
301         *
302         * @author Steffen Herbold
303         * @version 1.0
304         */
305        static public class Edge {
306        }
307
308        /**
309         * <p>
310         * Helper class for graph visualization of a trie.
311         * </p>
312         *
313         * @author Steffen Herbold
314         * @version 1.0
315         */
316        static public class TrieVertex {
317
318                /**
319                 * <p>
320                 * Id of the vertex.
321                 * </p>
322                 */
323                private String id;
324
325                /**
326                 * <p>
327                 * Contructor. Creates a new TrieVertex.
328                 * </p>
329                 *
330                 * @param id
331                 *            id of the vertex
332                 */
333                protected TrieVertex(String id) {
334                        this.id = id;
335                }
336
337                /**
338                 * <p>
339                 * Returns the id of the vertex.
340                 * </p>
341                 *
342                 * @see java.lang.Object#toString()
343                 */
344                @Override
345                public String toString() {
346                        return id;
347                }
348        }
349
350        /**
351         * <p>
352         * Returns a {@link Graph} representation of the trie.
353         * </p>
354         *
355         * @return {@link Graph} representation of the trie
356         */
357        protected Tree<TrieVertex, Edge> getGraph() {
358                DelegateTree<TrieVertex, Edge> graph = new DelegateTree<TrieVertex, Edge>();
359                rootNode.getGraph(null, graph);
360                return graph;
361        }
362
363        /*
364         * (non-Javadoc)
365         *
366         * @see de.ugoe.cs.eventbench.models.IDotCompatible#getDotRepresentation()
367         */
368        public String getDotRepresentation() {
369                StringBuilder stringBuilder = new StringBuilder();
370                stringBuilder.append("digraph model {" + StringTools.ENDLINE);
371                rootNode.appendDotRepresentation(stringBuilder);
372                stringBuilder.append('}' + StringTools.ENDLINE);
373                return stringBuilder.toString();
374        }
375
376        /**
377         * <p>
378         * Returns the string representation of the root node.
379         * </p>
380         *
381         * @see TrieNode#toString()
382         * @see java.lang.Object#toString()
383         */
384        @Override
385        public String toString() {
386                return rootNode.toString();
387        }
388
389        /**
390         * <p>
391         * Returns the number of symbols contained in the trie.
392         * </p>
393         *
394         * @return number of symbols contained in the trie
395         */
396        public int getNumSymbols() {
397                return knownSymbols.size();
398        }
399
400        /**
401         * <p>
402         * Returns the number of trie nodes that are ancestors of a leaf. This is
403         * the equivalent to the number of states a first-order markov model would
404         * have.
405         * <p>
406         *
407         * @return number of trie nodes that are ancestors of leafs.
408         */
409        public int getNumLeafAncestors() {
410                Set<TrieNode<T>> ancestors = new HashSet<TrieNode<T>>();
411                rootNode.getLeafAncestors(ancestors);
412                return ancestors.size();
413        }
414
415        /**
416         * <p>
417         * Returns the number of trie nodes that are leafs.
418         * </p>
419         *
420         * @return number of leafs in the trie
421         */
422        public int getNumLeafs() {
423                return rootNode.getNumLeafs();
424        }
425
426        /**
427         * <p>
428         * Updates the list of known symbols by replacing it with all symbols that
429         * are found in the child nodes of the root node. This should be the same as
430         * all symbols that are contained in the trie.
431         * </p>
432         */
433        public void updateKnownSymbols() {
434                knownSymbols = new HashSet<T>();
435                for (TrieNode<T> node : rootNode.getChildren()) {
436                        knownSymbols.add(node.getSymbol());
437                }
438        }
439
440        /**
441         * <p>
442         * Two Tries are defined as equal, if their {@link #rootNode} are equal.
443         * </p>
444         *
445         * @see java.lang.Object#equals(java.lang.Object)
446         */
447        @SuppressWarnings("rawtypes")
448        @Override
449        public boolean equals(Object other) {
450                if (other == this) {
451                        return true;
452                }
453                if (other instanceof Trie) {
454                        return rootNode.equals(((Trie) other).rootNode);
455                }
456                return false;
457        }
458       
459        /*
460         * (non-Javadoc)
461         *
462         * @see java.lang.Object#hashCode()
463         */
464        @Override
465        public int hashCode() {
466                int multiplier = 17;
467                int hash = 42;
468                if (rootNode != null) {
469                        hash = multiplier * hash + rootNode.hashCode();
470                }
471                return hash;
472        }
473}
Note: See TracBrowser for help on using the repository browser.