1 | package de.ugoe.cs.eventbench.markov;
|
---|
2 |
|
---|
3 | import java.util.ArrayList;
|
---|
4 | import java.util.List;
|
---|
5 |
|
---|
6 | import de.ugoe.cs.eventbench.data.Event;
|
---|
7 |
|
---|
8 |
|
---|
9 | /**
|
---|
10 | * This class implements a simple first-order markov state.
|
---|
11 | * Transitions are unlabled.
|
---|
12 | * @author Steffen Herbold
|
---|
13 | */
|
---|
14 | public class SimpleState extends State implements DotPrinter {
|
---|
15 |
|
---|
16 | List<Double> transitionProbs;
|
---|
17 | List<State> toStates;
|
---|
18 |
|
---|
19 | // members for learning
|
---|
20 | private int transitionsObserved;
|
---|
21 |
|
---|
22 | private String idShort = null;
|
---|
23 |
|
---|
24 | public SimpleState(String id, Event<?> action) {
|
---|
25 | super(id, action);
|
---|
26 | transitionsObserved = 0;
|
---|
27 | toStates = new ArrayList<State>();
|
---|
28 | transitionProbs = new ArrayList<Double>();
|
---|
29 | }
|
---|
30 |
|
---|
31 | public SimpleState(String id, Event<?> action, String idShort) {
|
---|
32 | this(id, action);
|
---|
33 | this.idShort = idShort;
|
---|
34 | }
|
---|
35 |
|
---|
36 | @Override
|
---|
37 | public State getNextState() {
|
---|
38 | double randVal = rand.nextDouble();
|
---|
39 | double probSum = 0;
|
---|
40 | int index = 0;
|
---|
41 | while( index<transitionProbs.size() && probSum+transitionProbs.get(index) < randVal ) {
|
---|
42 | probSum += transitionProbs.get(index);
|
---|
43 | index++;
|
---|
44 | }
|
---|
45 | return toStates.get(index);
|
---|
46 | }
|
---|
47 |
|
---|
48 | public void incTransTo(State state) {
|
---|
49 | int index = toStates.indexOf(state);
|
---|
50 | if( index==-1 ) {
|
---|
51 | toStates.add(state);
|
---|
52 | transitionProbs.add(0.0);
|
---|
53 | index = toStates.size()-1;
|
---|
54 | }
|
---|
55 | // update trans probs
|
---|
56 | for( int i=0 ; i<toStates.size() ; i++ ) {
|
---|
57 | double currentProb = transitionProbs.get(i);
|
---|
58 | double newProb = 0.0;
|
---|
59 | if( i!=index ) {
|
---|
60 | newProb = (currentProb*transitionsObserved)/(transitionsObserved+1);
|
---|
61 | } else {
|
---|
62 | newProb = ((currentProb*transitionsObserved)+1)/(transitionsObserved+1);
|
---|
63 | }
|
---|
64 | transitionProbs.set(i, newProb);
|
---|
65 | }
|
---|
66 | transitionsObserved++;
|
---|
67 | }
|
---|
68 |
|
---|
69 | // get the transition probability to the given state
|
---|
70 | public double getProb(State state) {
|
---|
71 | double prob = 0.0;
|
---|
72 | int index = toStates.indexOf(state);
|
---|
73 | if( index>=0 ) {
|
---|
74 | prob = transitionProbs.get(index);
|
---|
75 | }
|
---|
76 | return prob;
|
---|
77 | }
|
---|
78 |
|
---|
79 | public String getShortId() {
|
---|
80 | String shortId;
|
---|
81 | if( idShort!=null ) {
|
---|
82 | shortId = idShort;
|
---|
83 | } else {
|
---|
84 | shortId = getId();
|
---|
85 | }
|
---|
86 | return shortId;
|
---|
87 | }
|
---|
88 |
|
---|
89 | @Override
|
---|
90 | public void printDot() {
|
---|
91 | final String thisSaneId = getShortId().replace("\"", "\\\"").replaceAll("[\r\n]","");
|
---|
92 | System.out.println(" " + hashCode() + " [label=\""+thisSaneId+"\"];");
|
---|
93 | for(int i=0 ; i<toStates.size() ; i++ ) {
|
---|
94 | System.out.print(" "+hashCode()+" -> " + toStates.get(i).hashCode() + " ");
|
---|
95 | System.out.println("[label=\"" + transitionProbs.get(i) + "\"];");
|
---|
96 | }
|
---|
97 |
|
---|
98 | }
|
---|
99 |
|
---|
100 | } |
---|