source: trunk/EventBenchCore/src/de/ugoe/cs/eventbench/models/FirstOrderMarkovModel.java @ 16

Last change on this file since 16 was 16, checked in by sherbold, 13 years ago

+ added class de.ugoe.cs.eventbench.models.FirstOrderMarkovModel? to replace de.ugoe.cs.eventbench.markov.MarkovModel?

  • Property svn:mime-type set to text/plain
File size: 4.5 KB
Line 
1package de.ugoe.cs.eventbench.models;
2
3import java.util.ArrayList;
4import java.util.List;
5import java.util.Random;
6
7import de.ugoe.cs.eventbench.data.Event;
8import de.ugoe.cs.eventbench.markov.DotPrinter;
9import de.ugoe.cs.util.console.Console;
10import edu.uci.ics.jung.graph.Graph;
11import edu.uci.ics.jung.graph.SparseMultigraph;
12import edu.uci.ics.jung.graph.util.EdgeType;
13
14import Jama.Matrix;
15
16public class FirstOrderMarkovModel extends HighOrderMarkovModel implements DotPrinter {
17
18        final static int MAX_STATDIST_ITERATIONS = 1000;
19       
20        public FirstOrderMarkovModel(Random r) {
21                super(1, r);
22        }
23       
24        private Matrix getTransmissionMatrix() {
25                List<Event<?>> knownSymbols = new ArrayList<Event<?>>(trie.getKnownSymbols());
26                int numStates = knownSymbols.size();
27                Matrix transmissionMatrix = new Matrix(numStates, numStates);
28               
29                for( int i=0 ; i<numStates ; i++ ) {
30                        Event<?> currentSymbol = knownSymbols.get(i);
31                        List<Event<?>> context = new ArrayList<Event<?>>();
32                        context.add(currentSymbol);
33                        for( int j=0 ; j<numStates ; j++ ) {
34                                Event<?> follower = knownSymbols.get(j);
35                                double prob = getProbability(context, follower);
36                                transmissionMatrix.set(i, j, prob);
37                        }
38                }
39                return transmissionMatrix;
40        }
41       
42        public void printDot() {
43                Console.println("digraph model {");
44
45                List<Event<?>> knownSymbols = new ArrayList<Event<?>>(trie.getKnownSymbols());
46               
47                for( Event<?> symbol : knownSymbols) {
48                        final String thisSaneId = symbol.getShortId().replace("\"", "\\\"").replaceAll("[\r\n]","");
49                        Console.println(" " + symbol.hashCode() + " [label=\""+thisSaneId+"\"];");
50                        List<Event<?>> context = new ArrayList<Event<?>>();
51                        context.add(symbol);
52                        List<Event<?>> followers = trie.getFollowingSymbols(context);
53                        for( Event<?> follower : followers ) {
54                                System.out.print(" "+symbol.hashCode()+" -> " + follower.hashCode() + " ");
55                                System.out.println("[label=\"" + getProbability(context, follower) + "\"];");
56                        }
57                }
58                System.out.println('}');
59        }
60       
61        public Graph<String, MarkovEdge> getGraph() {
62                Graph<String, MarkovEdge> graph = new SparseMultigraph<String, MarkovEdge>();
63               
64                List<Event<?>> knownSymbols = new ArrayList<Event<?>>(trie.getKnownSymbols());
65               
66                for( Event<?> symbol : knownSymbols) {
67                        String from = symbol.getShortId();
68                        List<Event<?>> context = new ArrayList<Event<?>>();
69                        context.add(symbol);
70                       
71                        List<Event<?>> followers = trie.getFollowingSymbols(context);
72                       
73                        for( Event<?> follower : followers ) {
74                                String to = follower.getShortId();
75                                MarkovEdge prob = new MarkovEdge(getProbability(context, follower));
76                                graph.addEdge(prob, from, to, EdgeType.DIRECTED);
77                        }
78                }
79                return graph;
80        }
81       
82        static public class MarkovEdge {
83                double weight;
84                MarkovEdge(double weight) { this.weight = weight; }
85                public String toString() { return ""+weight; }
86        }
87       
88        public double calcEntropy() {
89                Matrix transmissionMatrix = getTransmissionMatrix();
90                List<Event<?>> knownSymbols = new ArrayList<Event<?>>(trie.getKnownSymbols());
91                int numStates = knownSymbols.size();
92               
93                int startStateIndex = knownSymbols.indexOf(Event.STARTEVENT);
94                int endStateIndex = knownSymbols.indexOf(Event.ENDEVENT);
95                if( startStateIndex==-1 ) {
96                        Console.printerrln("Error calculating entropy. Initial state of markov chain not found.");
97                        return Double.NaN;
98                }
99                if( endStateIndex==-1 ) {
100                        Console.printerrln("Error calculating entropy. End state of markov chain not found.");
101                        return Double.NaN;
102                }
103                transmissionMatrix.set(endStateIndex, startStateIndex, 1);
104               
105                // Calculate stationary distribution by raising the power of the transmission matrix.
106                // The rank of the matrix should fall to 1 and each two should be the vector of the
107                // stationory distribution.
108                int iter = 0;
109                int rank = transmissionMatrix.rank();
110                Matrix stationaryMatrix = (Matrix) transmissionMatrix.clone();
111                while( iter<MAX_STATDIST_ITERATIONS && rank>1 ) {
112                        stationaryMatrix = stationaryMatrix.times(stationaryMatrix);
113                        rank = stationaryMatrix.rank();
114                        iter++;
115                }
116               
117                if( rank!=1 ) {
118                        Console.traceln("rank: " + rank);
119                        Console.printerrln("Unable to calculate stationary distribution.");
120                        return Double.NaN;
121                }
122               
123                double entropy = 0.0;
124                for( int i=0 ; i<numStates ; i++ ) {
125                        for( int j=0 ; j<numStates ; j++ ) {
126                                if( transmissionMatrix.get(i,j)!=0 ) {
127                                        double tmp = stationaryMatrix.get(i, 0);
128                                        tmp *= transmissionMatrix.get(i, j);
129                                        tmp *= Math.log(transmissionMatrix.get(i,j))/Math.log(2);
130                                        entropy -= tmp;
131                                }
132                        }
133                }
134                return entropy;
135        }
136
137}
Note: See TracBrowser for help on using the repository browser.