source: trunk/EventBenchCore/src/de/ugoe/cs/eventbench/markov/MarkovModel.java @ 1

Last change on this file since 1 was 1, checked in by sherbold, 13 years ago
File size: 7.4 KB
Line 
1package de.ugoe.cs.eventbench.markov;
2
3import java.util.ArrayList;
4import java.util.LinkedList;
5import java.util.List;
6import java.util.Random;
7
8import Jama.Matrix;
9
10import de.ugoe.cs.eventbench.data.Event;
11import de.ugoe.cs.util.console.Console;
12import edu.uci.ics.jung.graph.Graph;
13import edu.uci.ics.jung.graph.SparseMultigraph;
14import edu.uci.ics.jung.graph.util.EdgeType;
15
16public class MarkovModel implements DotPrinter {
17
18        private State initialState;
19        private State endState;
20       
21        private List<State> states;
22        private List<String> stateIdList;
23       
24        private Random r;
25       
26        final static int MAX_STATDIST_ITERATIONS = 1000;
27       
28        /**
29         * <p>
30         * Default constructor. Creates a new random number generator.
31         * </p>
32         */
33        public MarkovModel() {
34                this(new Random());
35        }
36       
37        /**
38         * <p>
39         * Creates a new {@link MarkovModel} with a predefined random number generator.
40         * </p>
41         *
42         * @param r random number generator
43         */
44        public MarkovModel(Random r) {
45                this.r = r; // defensive copy would be better, but constructor Random(r) does not seem to exist.
46        }
47       
48        public void printRandomWalk() {
49                State currentState = initialState;
50                IMemory<State> history = new IncompleteMemory<State>(5); // this is NOT used here, just for testing ...
51                history.add(currentState);
52                Console.println(currentState.getId());
53                while(!currentState.equals(endState)) {
54                        currentState = currentState.getNextState();
55                        Console.println(currentState.getId());
56                        history.add(currentState);
57                }
58        }
59       
60        public List<? extends Event<?>> randomSequence() {
61                List<Event<?>> sequence = new LinkedList<Event<?>>();
62                State currentState = initialState;
63                if( currentState.getAction()!=null ) {
64                        sequence.add(currentState.getAction());
65                }
66                System.out.println(currentState.getId());
67                while(!currentState.equals(endState)) {
68                        currentState = currentState.getNextState();
69                        if( currentState.getAction()!=null ) {
70                                sequence.add(currentState.getAction());
71                        }
72                        System.out.println(currentState.getId());
73                }
74                return sequence;
75        }
76       
77        public void printDot() {
78                int numUnprintableStates = 0;
79                System.out.println("digraph model {");
80                for( State state : states ) {
81                        if( state instanceof DotPrinter ) {
82                                ((DotPrinter) state).printDot();
83                        } else {
84                                numUnprintableStates++;
85                        }
86                }
87                System.out.println('}');
88                if( numUnprintableStates>0 ) {
89                        Console.println("" + numUnprintableStates + "/" + states.size() + "were unprintable!");
90                }
91        }
92       
93        public Graph<String, MarkovEdge> getGraph() {
94                Graph<String, MarkovEdge> graph = new SparseMultigraph<String, MarkovEdge>();
95               
96                for( State state : states) {
97                        try {
98                                SimpleState simpleState = (SimpleState) state;
99                                String from = simpleState.getShortId();
100                                for( int i=0 ; i<simpleState.toStates.size() ; i++ ) {
101                                        SimpleState toState = (SimpleState) simpleState.toStates.get(i);
102                                        String to = toState.getShortId();
103                                        MarkovEdge prob = new MarkovEdge(simpleState.transitionProbs.get(i));
104                                        graph.addEdge(prob, from, to, EdgeType.DIRECTED);
105                                }
106                        } catch (ClassCastException e) {
107                                // TODO: handle exception
108                        }
109                }
110               
111                return graph;
112        }
113       
114        static public class MarkovEdge {
115                double weight;
116                MarkovEdge(double weight) { this.weight = weight; }
117                public String toString() { return ""+weight; }
118        }
119       
120        /////////////////////////////////////////////////////////////////////////////////////
121        // Code to learn type1 model: states are wndid.action and transitions are unlabled //
122        /////////////////////////////////////////////////////////////////////////////////////
123       
124        public void train(List<List<Event<?>>> sequences) {
125                Event<?> fromElement = null;
126                Event<?> toElement = null;
127                SimpleState fromState;
128                SimpleState toState;
129               
130                states = new ArrayList<State>();
131                stateIdList = new ArrayList<String>();
132                initialState = new SimpleState("GLOBALSTARTSTATE", null);
133                initialState.setRandom(r);
134                states.add(initialState);
135                stateIdList.add("GLOBALSTARTSTATE");
136                endState = new SimpleState("GLOBALENDSTATE", null);
137                endState.setRandom(r);
138                states.add(endState);
139                stateIdList.add("GLOBALENDSTATE");
140                for( List<Event<?>> sequence : sequences ) {
141                        for( int i=0; i<sequence.size() ; i++ ) {
142                                if( i==0 ) {
143                                        fromState = (SimpleState) initialState;
144                                } else {
145                                        fromElement = sequence.get(i-1);
146                                        fromState = findOrCreateSimpleState(fromElement);
147                                }
148                               
149                                toElement = sequence.get(i);
150                                toState = findOrCreateSimpleState(toElement);
151                               
152                                fromState.incTransTo(toState);
153                               
154                                if( i==sequence.size()-1 ) {
155                                        toState.incTransTo(endState);
156                                }
157                        }
158                }
159        }
160       
161        private SimpleState findOrCreateSimpleState(Event<?> action) {
162                SimpleState state = null;
163                String id = action.getStandardId();
164                String idShort = action.getShortId();
165                int index = stateIdList.indexOf(id);
166                if( index!=-1 ) {
167                        state = (SimpleState) states.get(index);
168                } else {
169                        state = new SimpleState(id, action, idShort);
170                        state.setRandom(r);
171                        states.add(state);
172                        stateIdList.add(id);
173                }
174                return state;
175        }
176       
177        ///////////////////////////////////////////////////////////
178       
179        // states must be SimpleState, this functions will throw bad cast exceptions
180        public double calcEntropy() {
181                int numStates = states.size();
182                // create transmission matrix
183                Matrix transmissionMatrix = new Matrix(numStates, numStates);
184                for( int i=0 ; i<numStates ; i++ ) {
185                        State tmpState = states.get(i);
186                        if( SimpleState.class.isInstance(tmpState) ) {
187                                SimpleState currentState = (SimpleState) tmpState;
188                                for( int j=0 ; j<numStates ; j++ ) {
189                                        double prob = currentState.getProb(states.get(j));
190                                        transmissionMatrix.set(i, j, prob);
191                                }
192                        } else {
193                                Console.printerr("Error calculating entropy. Only allowed for first-order markov models.");
194                                return Double.NaN;
195                        }
196                }
197               
198                // Add transition from endState to startState. This makes the markov chain irreducible and recurrent.
199                int startStateIndex = states.indexOf(initialState);
200                int endStateIndex = states.indexOf(endState);
201                if( startStateIndex==-1 ) {
202                        Console.printerrln("Error calculating entropy. Initial state of markov chain not found.");
203                        return Double.NaN;
204                }
205                if( endStateIndex==-1 ) {
206                        Console.printerrln("Error calculating entropy. End state of markov chain not found.");
207                        return Double.NaN;
208                }
209                transmissionMatrix.set(endStateIndex, startStateIndex, 1);
210               
211                // Calculate stationary distribution by raising the power of the transmission matrix.
212                // The rank of the matrix should fall to 1 and each two should be the vector of the
213                // stationory distribution.
214                int iter = 0;
215                int rank = transmissionMatrix.rank();
216                Matrix stationaryMatrix = (Matrix) transmissionMatrix.clone();
217                while( iter<MAX_STATDIST_ITERATIONS && rank>1 ) {
218                        stationaryMatrix = stationaryMatrix.times(stationaryMatrix);
219                        rank = stationaryMatrix.rank();
220                        iter++;
221                }
222               
223                if( rank!=1 ) {
224                        Console.traceln("rank: " + rank);
225                        Console.printerrln("Unable to calculate stationary distribution.");
226                        return Double.NaN;
227                }
228               
229                double entropy = 0.0;
230                for( int i=0 ; i<numStates ; i++ ) {
231                        for( int j=0 ; j<numStates ; j++ ) {
232                                if( transmissionMatrix.get(i,j)!=0 ) {
233                                        double tmp = stationaryMatrix.get(i, 0);
234                                        tmp *= transmissionMatrix.get(i, j);
235                                        tmp *= Math.log(transmissionMatrix.get(i,j))/Math.log(2);
236                                        entropy -= tmp;
237                                }
238                        }
239                }
240                return entropy;
241        }
242}
Note: See TracBrowser for help on using the repository browser.