1 | package de.ugoe.cs.eventbench.models;
|
---|
2 |
|
---|
3 | import java.util.ArrayList;
|
---|
4 | import java.util.List;
|
---|
5 | import java.util.Random;
|
---|
6 |
|
---|
7 | import de.ugoe.cs.eventbench.data.Event;
|
---|
8 | import de.ugoe.cs.eventbench.markov.DotPrinter;
|
---|
9 | import de.ugoe.cs.util.console.Console;
|
---|
10 | import edu.uci.ics.jung.graph.Graph;
|
---|
11 | import edu.uci.ics.jung.graph.SparseMultigraph;
|
---|
12 | import edu.uci.ics.jung.graph.util.EdgeType;
|
---|
13 |
|
---|
14 | import Jama.Matrix;
|
---|
15 |
|
---|
16 | public class FirstOrderMarkovModel extends HighOrderMarkovModel implements DotPrinter {
|
---|
17 |
|
---|
18 | final static int MAX_STATDIST_ITERATIONS = 1000;
|
---|
19 |
|
---|
20 | public FirstOrderMarkovModel(Random r) {
|
---|
21 | super(1, r);
|
---|
22 | }
|
---|
23 |
|
---|
24 | private Matrix getTransmissionMatrix() {
|
---|
25 | List<Event<?>> knownSymbols = new ArrayList<Event<?>>(trie.getKnownSymbols());
|
---|
26 | int numStates = knownSymbols.size();
|
---|
27 | Matrix transmissionMatrix = new Matrix(numStates, numStates);
|
---|
28 |
|
---|
29 | for( int i=0 ; i<numStates ; i++ ) {
|
---|
30 | Event<?> currentSymbol = knownSymbols.get(i);
|
---|
31 | List<Event<?>> context = new ArrayList<Event<?>>();
|
---|
32 | context.add(currentSymbol);
|
---|
33 | for( int j=0 ; j<numStates ; j++ ) {
|
---|
34 | Event<?> follower = knownSymbols.get(j);
|
---|
35 | double prob = getProbability(context, follower);
|
---|
36 | transmissionMatrix.set(i, j, prob);
|
---|
37 | }
|
---|
38 | }
|
---|
39 | return transmissionMatrix;
|
---|
40 | }
|
---|
41 |
|
---|
42 | public void printDot() {
|
---|
43 | Console.println("digraph model {");
|
---|
44 |
|
---|
45 | List<Event<?>> knownSymbols = new ArrayList<Event<?>>(trie.getKnownSymbols());
|
---|
46 |
|
---|
47 | for( Event<?> symbol : knownSymbols) {
|
---|
48 | final String thisSaneId = symbol.getShortId().replace("\"", "\\\"").replaceAll("[\r\n]","");
|
---|
49 | Console.println(" " + symbol.hashCode() + " [label=\""+thisSaneId+"\"];");
|
---|
50 | List<Event<?>> context = new ArrayList<Event<?>>();
|
---|
51 | context.add(symbol);
|
---|
52 | List<Event<?>> followers = trie.getFollowingSymbols(context);
|
---|
53 | for( Event<?> follower : followers ) {
|
---|
54 | System.out.print(" "+symbol.hashCode()+" -> " + follower.hashCode() + " ");
|
---|
55 | System.out.println("[label=\"" + getProbability(context, follower) + "\"];");
|
---|
56 | }
|
---|
57 | }
|
---|
58 | System.out.println('}');
|
---|
59 | }
|
---|
60 |
|
---|
61 | public Graph<String, MarkovEdge> getGraph() {
|
---|
62 | Graph<String, MarkovEdge> graph = new SparseMultigraph<String, MarkovEdge>();
|
---|
63 |
|
---|
64 | List<Event<?>> knownSymbols = new ArrayList<Event<?>>(trie.getKnownSymbols());
|
---|
65 |
|
---|
66 | for( Event<?> symbol : knownSymbols) {
|
---|
67 | String from = symbol.getShortId();
|
---|
68 | List<Event<?>> context = new ArrayList<Event<?>>();
|
---|
69 | context.add(symbol);
|
---|
70 |
|
---|
71 | List<Event<?>> followers = trie.getFollowingSymbols(context);
|
---|
72 |
|
---|
73 | for( Event<?> follower : followers ) {
|
---|
74 | String to = follower.getShortId();
|
---|
75 | MarkovEdge prob = new MarkovEdge(getProbability(context, follower));
|
---|
76 | graph.addEdge(prob, from, to, EdgeType.DIRECTED);
|
---|
77 | }
|
---|
78 | }
|
---|
79 | return graph;
|
---|
80 | }
|
---|
81 |
|
---|
82 | static public class MarkovEdge {
|
---|
83 | double weight;
|
---|
84 | MarkovEdge(double weight) { this.weight = weight; }
|
---|
85 | public String toString() { return ""+weight; }
|
---|
86 | }
|
---|
87 |
|
---|
88 | public double calcEntropy() {
|
---|
89 | Matrix transmissionMatrix = getTransmissionMatrix();
|
---|
90 | List<Event<?>> knownSymbols = new ArrayList<Event<?>>(trie.getKnownSymbols());
|
---|
91 | int numStates = knownSymbols.size();
|
---|
92 |
|
---|
93 | int startStateIndex = knownSymbols.indexOf(Event.STARTEVENT);
|
---|
94 | int endStateIndex = knownSymbols.indexOf(Event.ENDEVENT);
|
---|
95 | if( startStateIndex==-1 ) {
|
---|
96 | Console.printerrln("Error calculating entropy. Initial state of markov chain not found.");
|
---|
97 | return Double.NaN;
|
---|
98 | }
|
---|
99 | if( endStateIndex==-1 ) {
|
---|
100 | Console.printerrln("Error calculating entropy. End state of markov chain not found.");
|
---|
101 | return Double.NaN;
|
---|
102 | }
|
---|
103 | transmissionMatrix.set(endStateIndex, startStateIndex, 1);
|
---|
104 |
|
---|
105 | // Calculate stationary distribution by raising the power of the transmission matrix.
|
---|
106 | // The rank of the matrix should fall to 1 and each two should be the vector of the
|
---|
107 | // stationory distribution.
|
---|
108 | int iter = 0;
|
---|
109 | int rank = transmissionMatrix.rank();
|
---|
110 | Matrix stationaryMatrix = (Matrix) transmissionMatrix.clone();
|
---|
111 | while( iter<MAX_STATDIST_ITERATIONS && rank>1 ) {
|
---|
112 | stationaryMatrix = stationaryMatrix.times(stationaryMatrix);
|
---|
113 | rank = stationaryMatrix.rank();
|
---|
114 | iter++;
|
---|
115 | }
|
---|
116 |
|
---|
117 | if( rank!=1 ) {
|
---|
118 | Console.traceln("rank: " + rank);
|
---|
119 | Console.printerrln("Unable to calculate stationary distribution.");
|
---|
120 | return Double.NaN;
|
---|
121 | }
|
---|
122 |
|
---|
123 | double entropy = 0.0;
|
---|
124 | for( int i=0 ; i<numStates ; i++ ) {
|
---|
125 | for( int j=0 ; j<numStates ; j++ ) {
|
---|
126 | if( transmissionMatrix.get(i,j)!=0 ) {
|
---|
127 | double tmp = stationaryMatrix.get(i, 0);
|
---|
128 | tmp *= transmissionMatrix.get(i, j);
|
---|
129 | tmp *= Math.log(transmissionMatrix.get(i,j))/Math.log(2);
|
---|
130 | entropy -= tmp;
|
---|
131 | }
|
---|
132 | }
|
---|
133 | }
|
---|
134 | return entropy;
|
---|
135 | }
|
---|
136 |
|
---|
137 | }
|
---|