Ignore:
Timestamp:
04/11/11 15:42:42 (14 years ago)
Author:
sherbold
Message:
  • Cleanup of PPM and Trie
Location:
trunk/EventBenchCore/src/de/ugoe/cs/eventbench/ppm
Files:
3 edited

Legend:

Unmodified
Added
Removed
  • trunk/EventBenchCore/src/de/ugoe/cs/eventbench/ppm/PredictionByPartialMatch.java

    r5 r6  
    6363                        trie.add(sequence.subList(sequenceLength-j, sequenceLength)); 
    6464                } 
    65         } 
    66          
    67         public void testStuff() { 
    68                 // basically an inline unit test without assertions but manual observation 
    69                 List<String> list = new ArrayList<String>(); 
    70                 list.add(initialSymbol); 
    71                 list.add("a"); 
    72                 list.add("b"); 
    73                 list.add("r"); 
    74                 list.add("a"); 
    75                 list.add("c"); 
    76                 list.add("a"); 
    77                 list.add("d"); 
    78                 list.add("a"); 
    79                 list.add("b"); 
    80                 list.add("r"); 
    81                 list.add("a"); 
    82                 list.add(endSymbol); 
    83                  
    84                 PredictionByPartialMatch model = new PredictionByPartialMatch(); 
    85                 model.trie = new Trie<String>(); 
    86                 model.trainStringTrie(list); 
    87                 model.trie.display(); 
    88                 Console.println("------------------------"); 
    89                 model.randomSequence();/* 
    90                 Console.println("------------------------"); 
    91                 model.randomSequence(); 
    92                 Console.println("------------------------"); 
    93                 model.randomSequence(); 
    94                 Console.println("------------------------");*/ 
    95                  
    96                 List<String> context = new ArrayList<String>(); 
    97                 String symbol = "a"; 
    98                 // expected: 5 
    99                 Console.traceln(""+model.trie.getCount(context, symbol)); 
    100                  
    101                 // expected: 0 
    102                 context.add("b"); 
    103                 Console.traceln(""+model.trie.getCount(context, symbol)); 
    104                  
    105                 // expected: 2 
    106                 context.add("r"); 
    107                 Console.traceln(""+model.trie.getCount(context, symbol)); 
    108                  
    109                 // exptected: [b, r] 
    110                 context = new ArrayList<String>(); 
    111                 context.add("a"); 
    112                 context.add("b"); 
    113                 context.add("r"); 
    114                 Console.traceln(model.trie.getContextSuffix(context).toString()); 
    115                  
    116                 // exptected: [] 
    117                 context = new ArrayList<String>(); 
    118                 context.add("e"); 
    119                 Console.traceln(model.trie.getContextSuffix(context).toString()); 
    120                  
    121                 // exptected: {a, b, c, d, r} 
    122                 context = new ArrayList<String>(); 
    123                 Console.traceln(model.trie.getFollowingSymbols(context).toString()); 
    124                  
    125                 // exptected: {b, c, d} 
    126                 context = new ArrayList<String>(); 
    127                 context.add("a"); 
    128                 Console.traceln(model.trie.getFollowingSymbols(context).toString()); 
    129                  
    130                 // exptected: [] 
    131                 context = new ArrayList<String>(); 
    132                 context.add("a"); 
    133                 context.add("b"); 
    134                 context.add("r"); 
    135                 Console.traceln(model.trie.getFollowingSymbols(context).toString()); 
    13665        } 
    13766         
     
    16392                return sequence; 
    16493        } 
    165          
    166         /*public void printRandomWalk(Random r) { 
    167                 IncompleteMemory<String> context = new IncompleteMemory<String>(maxOrder-1); 
    168                  
    169                 context.add(initialSymbol); 
    170                  
    171                 String currentState = initialSymbol; 
    172                  
    173                 Console.println(currentState); 
    174                 while(!endSymbol.equals(currentState)) { 
    175                         double randVal = r.nextDouble(); 
    176                         double probSum = 0.0; 
    177                         List<String> currentContext = context.getLast(maxOrder); 
    178                         // DEBUG // 
    179                         Console.traceln("Context: " + currentContext.toString()); 
    180                         double tmpSum = 0.0d; 
    181                         for( String symbol : knownSymbols ) { 
    182                                 double prob = getProbability(currentContext, symbol); 
    183                                 tmpSum += prob; 
    184                                 Console.traceln(symbol + ": " + prob); 
    185                         } 
    186                         Console.traceln("Sum: " + tmpSum); 
    187                         // DEBUG-END // 
    188                         for( String symbol : knownSymbols ) { 
    189                                 probSum += getProbability(currentContext, symbol); 
    190                                 if( probSum>=randVal-0.3 ) { 
    191                                         context.add(symbol); 
    192                                         currentState = symbol; 
    193                                         Console.println(currentState); 
    194                                         break; 
    195                                 } 
    196                         } 
    197                 } 
    198         }*/ 
    199          
     94                 
    20095        private double getProbability(List<String> context, String symbol) { 
    201                 // FIXME needs exception handling for unknown symbols 
    202                 // if the symbol is not contained in the trie, context.remove(0) will fail 
    20396                double result = 0.0d; 
    20497                double resultCurrentContex = 0.0d; 
     
    232125        } 
    233126         
    234         /* 
    235         private double getProbability(List<String> context, String symbol) { 
    236                 double result = 0.0;  
    237                 int countContextSymbol = 0; 
    238                 List<String> contextSuffix = trie.getContextSuffix(context); 
    239                 if( contextSuffix.isEmpty() ) { 
    240                         // unobserved context! everything is possible... assuming identical distribution 
    241                         result = 1.0d / knownSymbols.size(); // why 1.0 and not N(symbol) 
    242                 } else { 
    243                         countContextSymbol = trie.getCount(contextSuffix, symbol); 
    244                         List<String> followers = trie.getFollowingSymbols(contextSuffix); 
    245                         int countContextFollowers = 0; 
    246                         for( String follower : followers ) { 
    247                                 countContextFollowers += trie.getCount(contextSuffix, follower); 
    248                         } 
    249                          
    250                         if( followers.isEmpty() ) { 
    251                                 throw new AssertionError("Invalid return value of trie.getContextSuffix()!"); 
    252                         } 
    253                         if( countContextSymbol!=0 ) { 
    254                                 result = ((double) countContextSymbol) / (followers.size()+countContextFollowers); 
    255                         } else { // escape 
    256                                 double probEscape = ((double) followers.size()) / (followers.size()+countContextFollowers); 
    257                                 contextSuffix.remove(0);  
    258                                 double probSuffix = getProbability(contextSuffix, symbol); 
    259                                 result = probEscape*probSuffix; 
    260                         } 
    261                 } 
    262  
    263                 return result; 
    264         }*/ 
    265          
    266127        @Override 
    267128        public String toString() { 
    268129                return trie.toString(); 
    269130        } 
     131         
     132        public void testStuff() { 
     133                // basically an inline unit test without assertions but manual observation 
     134                List<String> list = new ArrayList<String>(); 
     135                list.add(initialSymbol); 
     136                list.add("a"); 
     137                list.add("b"); 
     138                list.add("r"); 
     139                list.add("a"); 
     140                list.add("c"); 
     141                list.add("a"); 
     142                list.add("d"); 
     143                list.add("a"); 
     144                list.add("b"); 
     145                list.add("r"); 
     146                list.add("a"); 
     147                list.add(endSymbol); 
     148                 
     149                PredictionByPartialMatch model = new PredictionByPartialMatch(); 
     150                model.trie = new Trie<String>(); 
     151                model.trainStringTrie(list); 
     152                model.trie.display(); 
     153                Console.println("------------------------"); 
     154                model.randomSequence();/* 
     155                Console.println("------------------------"); 
     156                model.randomSequence(); 
     157                Console.println("------------------------"); 
     158                model.randomSequence(); 
     159                Console.println("------------------------");*/ 
     160                 
     161                List<String> context = new ArrayList<String>(); 
     162                String symbol = "a"; 
     163                // expected: 5 
     164                Console.traceln(""+model.trie.getCount(context, symbol)); 
     165                 
     166                // expected: 0 
     167                context.add("b"); 
     168                Console.traceln(""+model.trie.getCount(context, symbol)); 
     169                 
     170                // expected: 2 
     171                context.add("r"); 
     172                Console.traceln(""+model.trie.getCount(context, symbol)); 
     173                 
     174                // exptected: [b, r] 
     175                context = new ArrayList<String>(); 
     176                context.add("a"); 
     177                context.add("b"); 
     178                context.add("r"); 
     179                Console.traceln(model.trie.getContextSuffix(context).toString()); 
     180                 
     181                // exptected: [] 
     182                context = new ArrayList<String>(); 
     183                context.add("e"); 
     184                Console.traceln(model.trie.getContextSuffix(context).toString()); 
     185                 
     186                // exptected: {a, b, c, d, r} 
     187                context = new ArrayList<String>(); 
     188                Console.traceln(model.trie.getFollowingSymbols(context).toString()); 
     189                 
     190                // exptected: {b, c, d} 
     191                context = new ArrayList<String>(); 
     192                context.add("a"); 
     193                Console.traceln(model.trie.getFollowingSymbols(context).toString()); 
     194                 
     195                // exptected: [] 
     196                context = new ArrayList<String>(); 
     197                context.add("a"); 
     198                context.add("b"); 
     199                context.add("r"); 
     200                Console.traceln(model.trie.getFollowingSymbols(context).toString()); 
     201                 
     202                // exptected: 0.0d 
     203                context = new ArrayList<String>(); 
     204                context.add("a"); 
     205                Console.traceln(""+model.getProbability(context, "z")); 
     206        } 
    270207} 
  • trunk/EventBenchCore/src/de/ugoe/cs/eventbench/ppm/Trie.java

    r5 r6  
    2121public class Trie<T> { 
    2222         
    23         // Children of the Trie root 
    24         // should contain counts of all elements 
    25         private List<TrieNode<T>> children = new LinkedList<TrieNode<T>>(); 
     23        private final TrieNode<T> rootNode; 
     24         
     25        public Trie() { 
     26                rootNode = new TrieNode<T>(); 
     27        } 
    2628         
    2729 
     
    3032                        subsequence = new LinkedList<T>(subsequence);  // defensive copy! 
    3133                        T firstSymbol = subsequence.get(0); 
    32                         getChildCreate(firstSymbol).add(subsequence); 
     34                        TrieNode<T> node = getChildCreate(firstSymbol); 
     35                        node.add(subsequence); 
    3336                } 
    3437        } 
    3538 
    36         // FIXME clones of TrieNode.getChildCreate 
    3739        protected TrieNode<T>  getChildCreate(T symbol) { 
    38                 TrieNode<T> node = getChild(symbol); 
    39                 if( node==null ) { 
    40                         node = new TrieNode<T>(symbol); 
    41                         children.add(node); 
    42                 } 
    43                 return node; 
     40                return rootNode.getChildCreate(symbol); 
    4441        } 
    4542         
    46         // FIXME clones of TrieNode.getChild 
    4743        protected TrieNode<T> getChild(T symbol) { 
    48                 for( TrieNode<T> child : children ) { 
    49                         if( child.getSymbol().equals(symbol) ) { 
    50                                 return child; 
    51                         } 
    52                 } 
    53                 return null; 
     44                return rootNode.getChild(symbol); 
    5445        } 
    5546 
     
    7364         
    7465        public TrieNode<T> find(List<T> sequence) { 
     66                if( sequence==null || sequence.isEmpty() ) { 
     67                        return rootNode; 
     68                } 
    7569                List<T> sequenceCopy = new LinkedList<T>(sequence); 
    7670                TrieNode<T> result = null; 
    77                 if( !sequenceCopy.isEmpty() ) { 
    78                         TrieNode<T> node = getChild(sequenceCopy.get(0)); 
    79                         if( node!=null ) { 
    80                                 sequenceCopy.remove(0); 
    81                                 result = node.find(sequenceCopy); 
    82                         } 
     71                TrieNode<T> node = getChild(sequenceCopy.get(0)); 
     72                if( node!=null ) { 
     73                        sequenceCopy.remove(0); 
     74                        result = node.find(sequenceCopy); 
    8375                } 
    8476                return result; 
     
    8880        public List<T> getFollowingSymbols(List<T> sequence) { 
    8981                List<T> result = new LinkedList<T>(); 
    90                 if( sequence==null || sequence.isEmpty() ) { 
    91                         for( TrieNode<T> child : children ) { 
    92                                 result.add(child.getSymbol()); 
    93                         } 
    94                 } else { 
    95                         TrieNode<T> node = find(sequence); 
    96                         if( node!=null ) { 
    97                                 result = node.getFollowingSymbols(); 
    98                         } 
     82                TrieNode<T> node = find(sequence); 
     83                if( node!=null ) { 
     84                        result = node.getFollowingSymbols(); 
    9985                } 
    10086                return result; 
     
    10288         
    10389        // longest suffix of context, that is contained in the tree and whose children are leaves 
     90        // possibly already deprecated 
    10491        public List<T> getContextSuffix(List<T> context) { 
    10592                List<T> contextSuffix = new LinkedList<T>(context); // defensive copy 
     
    140127        private Tree<TrieVertex, Edge> getGraph() { 
    141128                DelegateTree<TrieVertex, Edge> graph = new DelegateTree<TrieVertex, Edge>(); 
    142                  
    143                 TrieVertex root = new TrieVertex("root"); 
    144                 graph.addVertex(root); 
    145                                  
    146                 for( TrieNode<T> node : children ) { 
    147                         node.getGraph(root, graph); 
    148                 } 
    149                  
     129                rootNode.getGraph(null, graph); 
    150130                return graph; 
    151131        } 
     
    181161        @Override 
    182162        public String toString() { 
    183                 return children.toString(); 
     163                return rootNode.toString(); 
    184164        } 
    185165} 
  • trunk/EventBenchCore/src/de/ugoe/cs/eventbench/ppm/TrieNode.java

    r5 r6  
    1717         
    1818        private List<TrieNode<T>> children; 
     19         
     20        TrieNode() { 
     21                this.symbol = null; 
     22                count = 0; 
     23                children = new LinkedList<TrieNode<T>>(); 
     24        } 
    1925         
    2026        public TrieNode(T symbol) { 
     
    102108 
    103109        public void getGraph(TrieVertex parent, DelegateTree<TrieVertex, Edge> graph) { 
    104                 TrieVertex vertex = new TrieVertex(getSymbol().toString()+"#"+getCount()); 
    105                 graph.addChild( new Edge() , parent, vertex ); 
     110                TrieVertex currentVertex; 
     111                if( symbol==null ){ 
     112                        currentVertex = new TrieVertex("root"); 
     113                        graph.addVertex(currentVertex); 
     114                } else { 
     115                        currentVertex = new TrieVertex(getSymbol().toString()+"#"+getCount()); 
     116                        graph.addChild( new Edge() , parent, currentVertex ); 
     117                } 
    106118                for( TrieNode<T> node : children ) { 
    107                         node.getGraph(vertex, graph); 
     119                        node.getGraph(currentVertex, graph); 
    108120                }                
    109121        } 
Note: See TracChangeset for help on using the changeset viewer.