1 | package de.ugoe.cs.eventbench.models;
|
---|
2 |
|
---|
3 | import java.util.ArrayList;
|
---|
4 | import java.util.List;
|
---|
5 | import java.util.Random;
|
---|
6 |
|
---|
7 | import de.ugoe.cs.eventbench.data.Event;
|
---|
8 | import de.ugoe.cs.util.console.Console;
|
---|
9 | import edu.uci.ics.jung.graph.Graph;
|
---|
10 | import edu.uci.ics.jung.graph.SparseMultigraph;
|
---|
11 | import edu.uci.ics.jung.graph.util.EdgeType;
|
---|
12 |
|
---|
13 | import Jama.Matrix;
|
---|
14 |
|
---|
15 | public class FirstOrderMarkovModel extends HighOrderMarkovModel implements DotPrinter {
|
---|
16 |
|
---|
17 | final static int MAX_STATDIST_ITERATIONS = 1000;
|
---|
18 |
|
---|
19 | public FirstOrderMarkovModel(Random r) {
|
---|
20 | super(1, r);
|
---|
21 | }
|
---|
22 |
|
---|
23 | private Matrix getTransmissionMatrix() {
|
---|
24 | List<Event<?>> knownSymbols = new ArrayList<Event<?>>(trie.getKnownSymbols());
|
---|
25 | int numStates = knownSymbols.size();
|
---|
26 | Matrix transmissionMatrix = new Matrix(numStates, numStates);
|
---|
27 |
|
---|
28 | for( int i=0 ; i<numStates ; i++ ) {
|
---|
29 | Event<?> currentSymbol = knownSymbols.get(i);
|
---|
30 | List<Event<?>> context = new ArrayList<Event<?>>();
|
---|
31 | context.add(currentSymbol);
|
---|
32 | for( int j=0 ; j<numStates ; j++ ) {
|
---|
33 | Event<?> follower = knownSymbols.get(j);
|
---|
34 | double prob = getProbability(context, follower);
|
---|
35 | transmissionMatrix.set(i, j, prob);
|
---|
36 | }
|
---|
37 | }
|
---|
38 | return transmissionMatrix;
|
---|
39 | }
|
---|
40 |
|
---|
41 | public void printDot() {
|
---|
42 | Console.println("digraph model {");
|
---|
43 |
|
---|
44 | List<Event<?>> knownSymbols = new ArrayList<Event<?>>(trie.getKnownSymbols());
|
---|
45 |
|
---|
46 | for( Event<?> symbol : knownSymbols) {
|
---|
47 | final String thisSaneId = symbol.getShortId().replace("\"", "\\\"").replaceAll("[\r\n]","");
|
---|
48 | Console.println(" " + symbol.hashCode() + " [label=\""+thisSaneId+"\"];");
|
---|
49 | List<Event<?>> context = new ArrayList<Event<?>>();
|
---|
50 | context.add(symbol);
|
---|
51 | List<Event<?>> followers = trie.getFollowingSymbols(context);
|
---|
52 | for( Event<?> follower : followers ) {
|
---|
53 | System.out.print(" "+symbol.hashCode()+" -> " + follower.hashCode() + " ");
|
---|
54 | System.out.println("[label=\"" + getProbability(context, follower) + "\"];");
|
---|
55 | }
|
---|
56 | }
|
---|
57 | System.out.println('}');
|
---|
58 | }
|
---|
59 |
|
---|
60 | public Graph<String, MarkovEdge> getGraph() {
|
---|
61 | Graph<String, MarkovEdge> graph = new SparseMultigraph<String, MarkovEdge>();
|
---|
62 |
|
---|
63 | List<Event<?>> knownSymbols = new ArrayList<Event<?>>(trie.getKnownSymbols());
|
---|
64 |
|
---|
65 | for( Event<?> symbol : knownSymbols) {
|
---|
66 | String from = symbol.getShortId();
|
---|
67 | List<Event<?>> context = new ArrayList<Event<?>>();
|
---|
68 | context.add(symbol);
|
---|
69 |
|
---|
70 | List<Event<?>> followers = trie.getFollowingSymbols(context);
|
---|
71 |
|
---|
72 | for( Event<?> follower : followers ) {
|
---|
73 | String to = follower.getShortId();
|
---|
74 | MarkovEdge prob = new MarkovEdge(getProbability(context, follower));
|
---|
75 | graph.addEdge(prob, from, to, EdgeType.DIRECTED);
|
---|
76 | }
|
---|
77 | }
|
---|
78 | return graph;
|
---|
79 | }
|
---|
80 |
|
---|
81 | static public class MarkovEdge {
|
---|
82 | double weight;
|
---|
83 | MarkovEdge(double weight) { this.weight = weight; }
|
---|
84 | public String toString() { return ""+weight; }
|
---|
85 | }
|
---|
86 |
|
---|
87 | public double calcEntropy() {
|
---|
88 | Matrix transmissionMatrix = getTransmissionMatrix();
|
---|
89 | List<Event<?>> knownSymbols = new ArrayList<Event<?>>(trie.getKnownSymbols());
|
---|
90 | int numStates = knownSymbols.size();
|
---|
91 |
|
---|
92 | int startStateIndex = knownSymbols.indexOf(Event.STARTEVENT);
|
---|
93 | int endStateIndex = knownSymbols.indexOf(Event.ENDEVENT);
|
---|
94 | if( startStateIndex==-1 ) {
|
---|
95 | Console.printerrln("Error calculating entropy. Initial state of markov chain not found.");
|
---|
96 | return Double.NaN;
|
---|
97 | }
|
---|
98 | if( endStateIndex==-1 ) {
|
---|
99 | Console.printerrln("Error calculating entropy. End state of markov chain not found.");
|
---|
100 | return Double.NaN;
|
---|
101 | }
|
---|
102 | transmissionMatrix.set(endStateIndex, startStateIndex, 1);
|
---|
103 |
|
---|
104 | // Calculate stationary distribution by raising the power of the transmission matrix.
|
---|
105 | // The rank of the matrix should fall to 1 and each two should be the vector of the
|
---|
106 | // stationory distribution.
|
---|
107 | int iter = 0;
|
---|
108 | int rank = transmissionMatrix.rank();
|
---|
109 | Matrix stationaryMatrix = (Matrix) transmissionMatrix.clone();
|
---|
110 | while( iter<MAX_STATDIST_ITERATIONS && rank>1 ) {
|
---|
111 | stationaryMatrix = stationaryMatrix.times(stationaryMatrix);
|
---|
112 | rank = stationaryMatrix.rank();
|
---|
113 | iter++;
|
---|
114 | }
|
---|
115 |
|
---|
116 | if( rank!=1 ) {
|
---|
117 | Console.traceln("rank: " + rank);
|
---|
118 | Console.printerrln("Unable to calculate stationary distribution.");
|
---|
119 | return Double.NaN;
|
---|
120 | }
|
---|
121 |
|
---|
122 | double entropy = 0.0;
|
---|
123 | for( int i=0 ; i<numStates ; i++ ) {
|
---|
124 | for( int j=0 ; j<numStates ; j++ ) {
|
---|
125 | if( transmissionMatrix.get(i,j)!=0 ) {
|
---|
126 | double tmp = stationaryMatrix.get(i, 0);
|
---|
127 | tmp *= transmissionMatrix.get(i, j);
|
---|
128 | tmp *= Math.log(transmissionMatrix.get(i,j))/Math.log(2);
|
---|
129 | entropy -= tmp;
|
---|
130 | }
|
---|
131 | }
|
---|
132 | }
|
---|
133 | return entropy;
|
---|
134 | }
|
---|
135 |
|
---|
136 | }
|
---|