Ignore:
Timestamp:
04/11/11 14:43:03 (13 years ago)
Author:
sherbold
Message:
  • major debugging of PPM and Trie. Results are now correct, but both classes need major refactorings
File:
1 edited

Legend:

Unmodified
Added
Removed
  • trunk/EventBenchCore/src/de/ugoe/cs/eventbench/ppm/PredictionByPartialMatch.java

    r1 r5  
    11package de.ugoe.cs.eventbench.ppm; 
    22 
     3import java.util.ArrayList; 
    34import java.util.LinkedHashSet; 
     5import java.util.LinkedList; 
    46import java.util.List; 
    57import java.util.Random; 
     
    1214public class PredictionByPartialMatch { 
    1315         
    14         private String initialSymbol = "GLOBALSTARTSTATE"; 
    15         private String endSymbol = "GLOBALENDSTATE"; 
     16        private String initialSymbol = "GS"; 
     17        private String endSymbol = "GE"; 
    1618         
    1719        private int maxOrder = 3; 
     
    2022         
    2123        private Set<String> knownSymbols; 
     24         
     25        private double probEscape = 0.2d; // TODO getter/setter - steering parameter! 
     26         
     27        private Random r = new Random(); // TODO should be defined in the constructor 
    2228         
    2329        // the training is basically the generation of the trie 
     
    2935                 
    3036                for(List<Event<?>> sequence : sequences) { 
    31                         IncompleteMemory<String> latestActions = new IncompleteMemory<String>(maxOrder); // TODO need to check if it should be maxOrder+1 
    32                         latestActions.add(initialSymbol); 
    33                         for(Event<?> currentAction : sequence) { 
    34                                 String currentId = currentAction.getStandardId(); 
    35                                 latestActions.add(currentId); 
    36                                 knownSymbols.add(currentId); 
    37                                 if( latestActions.getLength()==maxOrder ) { // FIXME needs special case for sequences shorter than maxOrder 
    38                                         trie.add(latestActions.getLast(maxOrder)); 
    39                                 } 
    40                         } 
    41                         latestActions.add(endSymbol); 
    42                         if( latestActions.getLength()==maxOrder ) { // FIXME needs special case for sequences shorter than maxOrder 
     37                        List<String> stringSequence = new LinkedList<String>(); 
     38                        stringSequence.add(initialSymbol); 
     39                        for( Event<?> event : sequence ) { 
     40                                stringSequence.add(event.getStandardId()); 
     41                        } 
     42                        stringSequence.add(endSymbol); 
     43                         
     44                        trainStringTrie(stringSequence); 
     45                } 
     46        } 
     47         
     48        private void trainStringTrie(List<String> sequence) { 
     49                knownSymbols = new LinkedHashSet<String>();              
     50                IncompleteMemory<String> latestActions = new IncompleteMemory<String>(maxOrder); 
     51                int i=0; 
     52                for(String currentAction : sequence) { 
     53                        String currentId = currentAction; 
     54                        latestActions.add(currentId); 
     55                        knownSymbols.add(currentId); 
     56                        i++; 
     57                        if( i>=maxOrder ) { 
    4358                                trie.add(latestActions.getLast(maxOrder)); 
    4459                        } 
    4560                } 
    46         } 
    47          
    48         public void printRandomWalk(Random r) { 
     61                int sequenceLength = sequence.size(); 
     62                for( int j=maxOrder-1 ; j>0 ; j-- ) { 
     63                        trie.add(sequence.subList(sequenceLength-j, sequenceLength)); 
     64                } 
     65        } 
     66         
     67        public void testStuff() { 
     68                // basically an inline unit test without assertions but manual observation 
     69                List<String> list = new ArrayList<String>(); 
     70                list.add(initialSymbol); 
     71                list.add("a"); 
     72                list.add("b"); 
     73                list.add("r"); 
     74                list.add("a"); 
     75                list.add("c"); 
     76                list.add("a"); 
     77                list.add("d"); 
     78                list.add("a"); 
     79                list.add("b"); 
     80                list.add("r"); 
     81                list.add("a"); 
     82                list.add(endSymbol); 
     83                 
     84                PredictionByPartialMatch model = new PredictionByPartialMatch(); 
     85                model.trie = new Trie<String>(); 
     86                model.trainStringTrie(list); 
     87                model.trie.display(); 
     88                Console.println("------------------------"); 
     89                model.randomSequence();/* 
     90                Console.println("------------------------"); 
     91                model.randomSequence(); 
     92                Console.println("------------------------"); 
     93                model.randomSequence(); 
     94                Console.println("------------------------");*/ 
     95                 
     96                List<String> context = new ArrayList<String>(); 
     97                String symbol = "a"; 
     98                // expected: 5 
     99                Console.traceln(""+model.trie.getCount(context, symbol)); 
     100                 
     101                // expected: 0 
     102                context.add("b"); 
     103                Console.traceln(""+model.trie.getCount(context, symbol)); 
     104                 
     105                // expected: 2 
     106                context.add("r"); 
     107                Console.traceln(""+model.trie.getCount(context, symbol)); 
     108                 
     109                // exptected: [b, r] 
     110                context = new ArrayList<String>(); 
     111                context.add("a"); 
     112                context.add("b"); 
     113                context.add("r"); 
     114                Console.traceln(model.trie.getContextSuffix(context).toString()); 
     115                 
     116                // exptected: [] 
     117                context = new ArrayList<String>(); 
     118                context.add("e"); 
     119                Console.traceln(model.trie.getContextSuffix(context).toString()); 
     120                 
     121                // exptected: {a, b, c, d, r} 
     122                context = new ArrayList<String>(); 
     123                Console.traceln(model.trie.getFollowingSymbols(context).toString()); 
     124                 
     125                // exptected: {b, c, d} 
     126                context = new ArrayList<String>(); 
     127                context.add("a"); 
     128                Console.traceln(model.trie.getFollowingSymbols(context).toString()); 
     129                 
     130                // exptected: [] 
     131                context = new ArrayList<String>(); 
     132                context.add("a"); 
     133                context.add("b"); 
     134                context.add("r"); 
     135                Console.traceln(model.trie.getFollowingSymbols(context).toString()); 
     136        } 
     137         
     138        // TODO needs to be changed from String to <? extends Event> 
     139        public List<String> randomSequence() { 
     140                List<String> sequence = new LinkedList<String>(); 
     141                 
    49142                IncompleteMemory<String> context = new IncompleteMemory<String>(maxOrder-1); 
    50                  
    51143                context.add(initialSymbol); 
     144                sequence.add(initialSymbol); 
    52145                 
    53146                String currentState = initialSymbol; 
     
    61154                                probSum += getProbability(currentContext, symbol); 
    62155                                if( probSum>=randVal ) { 
    63                                         currentContext.add(symbol); 
     156                                        context.add(symbol); 
     157                                        currentState = symbol; 
     158                                        sequence.add(currentState); 
     159                                        break; 
     160                                } 
     161                        } 
     162                } 
     163                return sequence; 
     164        } 
     165         
     166        /*public void printRandomWalk(Random r) { 
     167                IncompleteMemory<String> context = new IncompleteMemory<String>(maxOrder-1); 
     168                 
     169                context.add(initialSymbol); 
     170                 
     171                String currentState = initialSymbol; 
     172                 
     173                Console.println(currentState); 
     174                while(!endSymbol.equals(currentState)) { 
     175                        double randVal = r.nextDouble(); 
     176                        double probSum = 0.0; 
     177                        List<String> currentContext = context.getLast(maxOrder); 
     178                        // DEBUG // 
     179                        Console.traceln("Context: " + currentContext.toString()); 
     180                        double tmpSum = 0.0d; 
     181                        for( String symbol : knownSymbols ) { 
     182                                double prob = getProbability(currentContext, symbol); 
     183                                tmpSum += prob; 
     184                                Console.traceln(symbol + ": " + prob); 
     185                        } 
     186                        Console.traceln("Sum: " + tmpSum); 
     187                        // DEBUG-END // 
     188                        for( String symbol : knownSymbols ) { 
     189                                probSum += getProbability(currentContext, symbol); 
     190                                if( probSum>=randVal-0.3 ) { 
     191                                        context.add(symbol); 
    64192                                        currentState = symbol; 
    65193                                        Console.println(currentState); 
     
    68196                        } 
    69197                } 
    70         } 
    71          
     198        }*/ 
     199         
     200        private double getProbability(List<String> context, String symbol) { 
     201                // FIXME needs exception handling for unknown symbols 
     202                // if the symbol is not contained in the trie, context.remove(0) will fail 
     203                double result = 0.0d; 
     204                double resultCurrentContex = 0.0d; 
     205                double resultShorterContex = 0.0d; 
     206                 
     207                List<String> contextCopy = new LinkedList<String>(context); // defensive copy 
     208 
     209         
     210                List<String> followers = trie.getFollowingSymbols(contextCopy); // \Sigma' 
     211                int sumCountFollowers = 0; // N(s\sigma') 
     212                for( String follower : followers ) { 
     213                        sumCountFollowers += trie.getCount(contextCopy, follower); 
     214                } 
     215                 
     216                int countSymbol = trie.getCount(contextCopy, symbol); // N(s\sigma) 
     217                if( contextCopy.size()==0 ) { 
     218                        resultCurrentContex = ((double) countSymbol) / sumCountFollowers; 
     219                } else { 
     220                        resultCurrentContex = ((double) countSymbol / sumCountFollowers)*(1-probEscape); 
     221                        contextCopy.remove(0);  
     222                        double probSuffix = getProbability(contextCopy, symbol); 
     223                        if( followers.size()==0 ) { 
     224                                resultShorterContex = probSuffix; 
     225                        } else { 
     226                                resultShorterContex = probEscape*probSuffix; 
     227                        } 
     228                } 
     229                result = resultCurrentContex+resultShorterContex; 
     230                 
     231                return result; 
     232        } 
     233         
     234        /* 
    72235        private double getProbability(List<String> context, String symbol) { 
    73236                double result = 0.0;  
     
    75238                List<String> contextSuffix = trie.getContextSuffix(context); 
    76239                if( contextSuffix.isEmpty() ) { 
    77                         result = 1.0d / knownSymbols.size();  
     240                        // unobserved context! everything is possible... assuming identical distribution 
     241                        result = 1.0d / knownSymbols.size(); // why 1.0 and not N(symbol) 
    78242                } else { 
    79243                        countContextSymbol = trie.getCount(contextSuffix, symbol); 
     
    85249                         
    86250                        if( followers.isEmpty() ) { 
    87                                 throw new AssertionError("Invalid return value of getContextSuffix!"); 
     251                                throw new AssertionError("Invalid return value of trie.getContextSuffix()!"); 
    88252                        } 
    89253                        if( countContextSymbol!=0 ) { 
     
    91255                        } else { // escape 
    92256                                double probEscape = ((double) followers.size()) / (followers.size()+countContextFollowers); 
    93                                 contextSuffix.remove(0); 
     257                                contextSuffix.remove(0);  
    94258                                double probSuffix = getProbability(contextSuffix, symbol); 
    95259                                result = probEscape*probSuffix; 
     
    98262 
    99263                return result; 
    100         } 
     264        }*/ 
    101265         
    102266        @Override 
Note: See TracChangeset for help on using the changeset viewer.