source: trunk/EventBenchCore/src/de/ugoe/cs/eventbench/markov/MarkovModel.java @ 2

Last change on this file since 2 was 2, checked in by sherbold, 13 years ago

+ added toString() to Event

File size: 6.8 KB
Line 
1package de.ugoe.cs.eventbench.markov;
2
3import java.util.ArrayList;
4import java.util.LinkedList;
5import java.util.List;
6import java.util.Random;
7
8import Jama.Matrix;
9
10import de.ugoe.cs.eventbench.data.Event;
11import de.ugoe.cs.util.console.Console;
12import edu.uci.ics.jung.graph.Graph;
13import edu.uci.ics.jung.graph.SparseMultigraph;
14import edu.uci.ics.jung.graph.util.EdgeType;
15
16public class MarkovModel implements DotPrinter {
17
18        private State initialState;
19        private State endState;
20       
21        private List<State> states;
22        private List<String> stateIdList;
23       
24        private Random r;
25       
26        final static int MAX_STATDIST_ITERATIONS = 1000;
27       
28        /**
29         * <p>
30         * Default constructor. Creates a new random number generator.
31         * </p>
32         */
33        public MarkovModel() {
34                this(new Random());
35        }
36       
37        /**
38         * <p>
39         * Creates a new {@link MarkovModel} with a predefined random number generator.
40         * </p>
41         *
42         * @param r random number generator
43         */
44        public MarkovModel(Random r) {
45                this.r = r; // defensive copy would be better, but constructor Random(r) does not seem to exist.
46        }
47       
48        public List<? extends Event<?>> randomSequence() {
49                List<Event<?>> sequence = new LinkedList<Event<?>>();
50                State currentState = initialState;
51                if( currentState.getAction()!=null ) {
52                        sequence.add(currentState.getAction());
53                }
54                while(!currentState.equals(endState)) {
55                        currentState = currentState.getNextState();
56                        if( currentState.getAction()!=null ) {
57                                sequence.add(currentState.getAction());
58                        }
59                }
60                return sequence;
61        }
62       
63        public void printDot() {
64                int numUnprintableStates = 0;
65                System.out.println("digraph model {");
66                for( State state : states ) {
67                        if( state instanceof DotPrinter ) {
68                                ((DotPrinter) state).printDot();
69                        } else {
70                                numUnprintableStates++;
71                        }
72                }
73                System.out.println('}');
74                if( numUnprintableStates>0 ) {
75                        Console.println("" + numUnprintableStates + "/" + states.size() + "were unprintable!");
76                }
77        }
78       
79        public Graph<String, MarkovEdge> getGraph() {
80                Graph<String, MarkovEdge> graph = new SparseMultigraph<String, MarkovEdge>();
81               
82                for( State state : states) {
83                        try {
84                                SimpleState simpleState = (SimpleState) state;
85                                String from = simpleState.getShortId();
86                                for( int i=0 ; i<simpleState.toStates.size() ; i++ ) {
87                                        SimpleState toState = (SimpleState) simpleState.toStates.get(i);
88                                        String to = toState.getShortId();
89                                        MarkovEdge prob = new MarkovEdge(simpleState.transitionProbs.get(i));
90                                        graph.addEdge(prob, from, to, EdgeType.DIRECTED);
91                                }
92                        } catch (ClassCastException e) {
93                                // TODO: handle exception
94                        }
95                }
96               
97                return graph;
98        }
99       
100        static public class MarkovEdge {
101                double weight;
102                MarkovEdge(double weight) { this.weight = weight; }
103                public String toString() { return ""+weight; }
104        }
105       
106        /////////////////////////////////////////////////////////////////////////////////////
107        // Code to learn type1 model: states are wndid.action and transitions are unlabled //
108        /////////////////////////////////////////////////////////////////////////////////////
109       
110        public void train(List<List<Event<?>>> sequences) {
111                Event<?> fromElement = null;
112                Event<?> toElement = null;
113                SimpleState fromState;
114                SimpleState toState;
115               
116                states = new ArrayList<State>();
117                stateIdList = new ArrayList<String>();
118                initialState = new SimpleState("GLOBALSTARTSTATE", null);
119                initialState.setRandom(r);
120                states.add(initialState);
121                stateIdList.add("GLOBALSTARTSTATE");
122                endState = new SimpleState("GLOBALENDSTATE", null);
123                endState.setRandom(r);
124                states.add(endState);
125                stateIdList.add("GLOBALENDSTATE");
126                for( List<Event<?>> sequence : sequences ) {
127                        for( int i=0; i<sequence.size() ; i++ ) {
128                                if( i==0 ) {
129                                        fromState = (SimpleState) initialState;
130                                } else {
131                                        fromElement = sequence.get(i-1);
132                                        fromState = findOrCreateSimpleState(fromElement);
133                                }
134                               
135                                toElement = sequence.get(i);
136                                toState = findOrCreateSimpleState(toElement);
137                               
138                                fromState.incTransTo(toState);
139                               
140                                if( i==sequence.size()-1 ) {
141                                        toState.incTransTo(endState);
142                                }
143                        }
144                }
145        }
146       
147        private SimpleState findOrCreateSimpleState(Event<?> action) {
148                SimpleState state = null;
149                String id = action.getStandardId();
150                String idShort = action.getShortId();
151                int index = stateIdList.indexOf(id);
152                if( index!=-1 ) {
153                        state = (SimpleState) states.get(index);
154                } else {
155                        state = new SimpleState(id, action, idShort);
156                        state.setRandom(r);
157                        states.add(state);
158                        stateIdList.add(id);
159                }
160                return state;
161        }
162       
163        ///////////////////////////////////////////////////////////
164       
165        // states must be SimpleState, this functions will throw bad cast exceptions
166        public double calcEntropy() {
167                int numStates = states.size();
168                // create transmission matrix
169                Matrix transmissionMatrix = new Matrix(numStates, numStates);
170                for( int i=0 ; i<numStates ; i++ ) {
171                        State tmpState = states.get(i);
172                        if( SimpleState.class.isInstance(tmpState) ) {
173                                SimpleState currentState = (SimpleState) tmpState;
174                                for( int j=0 ; j<numStates ; j++ ) {
175                                        double prob = currentState.getProb(states.get(j));
176                                        transmissionMatrix.set(i, j, prob);
177                                }
178                        } else {
179                                Console.printerr("Error calculating entropy. Only allowed for first-order markov models.");
180                                return Double.NaN;
181                        }
182                }
183               
184                // Add transition from endState to startState. This makes the markov chain irreducible and recurrent.
185                int startStateIndex = states.indexOf(initialState);
186                int endStateIndex = states.indexOf(endState);
187                if( startStateIndex==-1 ) {
188                        Console.printerrln("Error calculating entropy. Initial state of markov chain not found.");
189                        return Double.NaN;
190                }
191                if( endStateIndex==-1 ) {
192                        Console.printerrln("Error calculating entropy. End state of markov chain not found.");
193                        return Double.NaN;
194                }
195                transmissionMatrix.set(endStateIndex, startStateIndex, 1);
196               
197                // Calculate stationary distribution by raising the power of the transmission matrix.
198                // The rank of the matrix should fall to 1 and each two should be the vector of the
199                // stationory distribution.
200                int iter = 0;
201                int rank = transmissionMatrix.rank();
202                Matrix stationaryMatrix = (Matrix) transmissionMatrix.clone();
203                while( iter<MAX_STATDIST_ITERATIONS && rank>1 ) {
204                        stationaryMatrix = stationaryMatrix.times(stationaryMatrix);
205                        rank = stationaryMatrix.rank();
206                        iter++;
207                }
208               
209                if( rank!=1 ) {
210                        Console.traceln("rank: " + rank);
211                        Console.printerrln("Unable to calculate stationary distribution.");
212                        return Double.NaN;
213                }
214               
215                double entropy = 0.0;
216                for( int i=0 ; i<numStates ; i++ ) {
217                        for( int j=0 ; j<numStates ; j++ ) {
218                                if( transmissionMatrix.get(i,j)!=0 ) {
219                                        double tmp = stationaryMatrix.get(i, 0);
220                                        tmp *= transmissionMatrix.get(i, j);
221                                        tmp *= Math.log(transmissionMatrix.get(i,j))/Math.log(2);
222                                        entropy -= tmp;
223                                }
224                        }
225                }
226                return entropy;
227        }
228}
Note: See TracBrowser for help on using the repository browser.