source: trunk/EventBenchConsole/src/de/ugoe/cs/eventbench/commands/AbstractTrainCommand.java @ 206

Last change on this file since 206 was 203, checked in by sherbold, 13 years ago
  • Changed data type for handling of sequence-sets. Before, List<List<Event<?>>> was used, now Collection<List<Event<?>>> is used.
  • Property svn:mime-type set to text/plain
File size: 2.9 KB
Line 
1package de.ugoe.cs.eventbench.commands;
2
3import java.security.InvalidParameterException;
4import java.util.Collection;
5import java.util.List;
6
7import de.ugoe.cs.eventbench.data.Event;
8import de.ugoe.cs.eventbench.data.GlobalDataContainer;
9import de.ugoe.cs.eventbench.models.TrieBasedModel;
10import de.ugoe.cs.util.console.Command;
11import de.ugoe.cs.util.console.Console;
12
13/**
14 * <p>
15 * Abstract class for commands to train {@link TrieBasedModel}s.
16 * </p>
17 *
18 * @author Steffen Herbold
19 * @version 1.0
20 */
21public abstract class AbstractTrainCommand implements Command {
22
23        /**
24         * <p>
25         * Handling of additional parameters.
26         * </p>
27         *
28         * @param parameters
29         *            same as the parameters passed to {@link #run(List)}.
30         * @throws Exception
31         *             thrown, if there is an error parsing the parameters
32         */
33        abstract void handleAdditionalParameters(List<Object> parameters)
34                        throws Exception;
35
36        /**
37         * <p>
38         * Returns a concrete instance of {@link TrieBasedModel} to be trained. This
39         * is a factory method.
40         * </p>
41         *
42         * @return instance of {@link TrieBasedModel}
43         */
44        abstract TrieBasedModel createModel();
45
46        /**
47         * <p>
48         * The command is implemented as a template method. The general structure of
49         * the command is always the same, only the parameters of the command and
50         * the creation of the {@link TrieBasedModel} instance. The former is
51         * handled by {@link #handleOptionalParameters(List)}, the latter by
52         * {@link #createModel()}.
53         * </p>
54         *
55         * @see de.ugoe.cs.util.console.Command#run(java.util.List)
56         */
57        @SuppressWarnings("unchecked")
58        @Override
59        public void run(List<Object> parameters) {
60                String modelname;
61                String sequencesName;
62
63                try {
64                        modelname = (String) parameters.get(0);
65                        sequencesName = (String) parameters.get(1);
66                        handleAdditionalParameters(parameters);
67                } catch (Exception e) {
68                        throw new InvalidParameterException();
69                }
70
71                Collection<List<Event<?>>> sequences = null;
72                Object dataObject = GlobalDataContainer.getInstance().getData(
73                                sequencesName);
74                if (dataObject == null) {
75                        Console.println("Object " + sequencesName
76                                        + " not found in storage.");
77                        return;
78                }
79                try {
80                        sequences = (Collection<List<Event<?>>>) dataObject;
81                } catch (ClassCastException e) {
82                        Console.println("Object " + sequencesName
83                                        + "not of type Collection<List<Event<?>>>.");
84                        return;
85                }
86                /* TODO implement better type check
87                if (sequences.size() == 0 || !(sequences.get(0).get(0) instanceof Event) ) {
88                        Console.println("Object " + sequencesName
89                                        + "not of type Collection<List<Event<?>>>.");
90                        return;
91                }
92                */
93
94                TrieBasedModel model = createModel();
95                model.train(sequences);
96                if (GlobalDataContainer.getInstance().addData(modelname,
97                                model)) {
98                        Console.traceln("Old data \"" + modelname
99                                        + "\" overwritten");
100                }
101               
102        }
103
104}
Note: See TracBrowser for help on using the repository browser.