-
Notifications
You must be signed in to change notification settings - Fork 0
/
ClassifierFactory.cs
80 lines (73 loc) · 1.94 KB
/
ClassifierFactory.cs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
using System;
using System.Threading;
/// <summary>
/// Creates <see cref="Classifier"/> instances from predefined configurations and trains
/// them on a dedicated thread, delivering the trained classifier through a callback.
/// </summary>
public static class ClassifierFactory
{
    /// <summary>
    /// Predefined configurations selecting the classifier layout and training parameters.
    /// </summary>
    public enum Config
    {
        // An example configuration for detecting gestures from 3D coordinate patterns
        Gesture
    }

    /// <summary>
    /// Immutable bundle of everything a single training run needs, so it can be handed
    /// to the worker thread as one value.
    /// </summary>
    private sealed class TrainingArgs
    {
        public readonly Classifier Classifier;
        public readonly double[][] TrainData;
        public readonly int[] Targets;
        public readonly double TrainRate;
        public readonly double TargetError;
        public readonly int MaxEpochs;
        public readonly int MaxRestarts;
        public readonly Action<Classifier> Callback;

        public TrainingArgs(
            Classifier classifier,
            double[][] trainData,
            int[] targets,
            double trainRate,
            double targetError,
            int maxEpochs,
            int maxRestarts,
            Action<Classifier> callback)
        {
            Classifier = classifier;
            TrainData = trainData;
            Targets = targets;
            TrainRate = trainRate;
            TargetError = targetError;
            MaxEpochs = maxEpochs;
            MaxRestarts = maxRestarts;
            Callback = callback;
        }
    }

    /// <summary>
    /// Builds the classifier described by <paramref name="config"/> and starts training it
    /// on a newly created thread. This method returns as soon as the thread is started.
    /// </summary>
    /// <param name="config">Which predefined configuration to use.</param>
    /// <param name="trainData">Training samples, forwarded unchanged to <c>Classifier.Train</c>.</param>
    /// <param name="targets">Training targets, forwarded unchanged to <c>Classifier.Train</c>.</param>
    /// <param name="callback">
    /// Invoked with the classifier once <c>Classifier.Train</c> returns. It is called on the
    /// training thread, not on the thread that called this method.
    /// </param>
    /// <exception cref="ArgumentNullException">
    /// <paramref name="trainData"/>, <paramref name="targets"/> or <paramref name="callback"/> is null.
    /// </exception>
    /// <exception cref="ArgumentException"><paramref name="config"/> is not a handled value.</exception>
    public static void CreateGestureClassifier(Config config, double[][] trainData, int[] targets, Action<Classifier> callback)
    {
        // Fail fast on the caller's thread: a null discovered later on the worker thread
        // would surface as an unhandled exception there, after the training work was done.
        if (trainData == null)
        {
            throw new ArgumentNullException("trainData");
        }
        if (targets == null)
        {
            throw new ArgumentNullException("targets");
        }
        if (callback == null)
        {
            throw new ArgumentNullException("callback");
        }

        TrainingArgs trainingArgs;
        switch (config)
        {
            case Config.Gesture:
                trainingArgs = new TrainingArgs(
                    classifier: new Classifier(numInputs: 33, numHiddenNeurons: 11),
                    trainData: trainData,
                    targets: targets,
                    trainRate: 0.5,
                    targetError: 0.05,
                    maxEpochs: 3000,
                    maxRestarts: 20,
                    callback: callback
                );
                break;
            default:
                throw new ArgumentException("Unhandled config: " + config);
        }

        // Capture the strongly typed arguments in a closure instead of passing them
        // through ParameterizedThreadStart as object, so no runtime cast is needed.
        new Thread(() => TrainAndNotify(trainingArgs)).Start();
    }

    /// <summary>
    /// Thread entry point: trains the classifier with the bundled parameters, then
    /// passes it to the callback.
    /// </summary>
    private static void TrainAndNotify(TrainingArgs args)
    {
        Classifier classifier = args.Classifier;
        classifier.Train(args.TrainData, args.Targets, args.TrainRate, args.TargetError, args.MaxEpochs, args.MaxRestarts);
        args.Callback(classifier);
    }
}