-
Notifications
You must be signed in to change notification settings - Fork 125
/
Copy pathlazy_mnist.ml
executable file
·59 lines (48 loc) · 1.84 KB
/
lazy_mnist.ml
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
(* This example demonstrates using the lazy (computation-graph) neural compiler functor to train a model on MNIST. *)
open Owl
(* Instantiate the CPU computation engine over Owl's [S] primal ops
   (presumably single precision, per Owl's S/D naming convention — confirm),
   then build the neural-network compiler on top of that engine. Everything
   below trains and evaluates the network through [CGCompiler]. *)
module CPU_Engine = Owl_computation_cpu_engine.Make (Owl_algodiff_primal_ops.S)
module CGCompiler = Owl_neural_compiler.Make (CPU_Engine)
open CGCompiler.Neural
open CGCompiler.Neural.Graph
open CGCompiler.Neural.Algodiff
(** [make_network input_shape] builds the example CNN and returns it as a
    network named ["mnist"]. Layers, in order:
    - a lambda dividing every input value by 256 (presumably to bring raw
      0-255 pixel values into [0, 1) — confirm against the dataset);
    - a convolution with kernel shape 5;5;1;32 and stride 1;1, ReLU;
    - max pooling with a 2x2 window and stride 2x2;
    - dropout with rate 0.1;
    - a fully-connected layer of 1024 units, ReLU;
    - a linear layer of 10 units followed by softmax over axis 1. *)
let make_network input_shape =
  let normalise x = Maths.(x / pack_flt 256.) in
  let net = input input_shape in
  let net = lambda normalise net in
  let net = conv2d ~act_typ:Activation.Relu [|5;5;1;32|] [|1;1|] net in
  let net = max_pool2d [|2;2|] [|2;2|] net in
  let net = dropout 0.1 net in
  let net = fully_connected ~act_typ:Activation.Relu 1024 net in
  let net = linear ~act_typ:Activation.(Softmax 1) 10 net in
  get_network ~name:"mnist" net
(* [pack] lifts a plain ndarray into the lazy engine and then into an
   Algodiff value; [unpack] undoes both steps in reverse order. *)
let pack x = Algodiff.pack_arr (CGCompiler.Engine.pack_arr x)
let unpack x = CGCompiler.Engine.unpack_arr (Algodiff.unpack_arr x)
(** [train network] loads the MNIST training set, packs images and targets
    into lazy Algodiff values, and runs [CGCompiler.train] on [network].
    The result of [CGCompiler.train] is returned unchanged. The caller
    discards it and then evaluates the same [network], so training presumably
    updates [network] in place.
    Configuration: mini-batches of 100, Adagrad with rate 0.005, and a
    positional [0.1] that is presumably the epoch count (per Owl's
    [Params.config]). *)
let train network =
  let images, _, targets = Dataset.load_mnist_train_data_arr () in
  let packed_images = pack images in
  let packed_targets = pack targets in
  (* Momentum is deliberately left unset; to experiment with it, add
     [~momentum:(Momentum.Standard 0.1)] to the call below. *)
  let params =
    Params.config ~learning_rate:(Learning_Rate.Adagrad 0.005)
      ~batch:(Batch.Mini 100) 0.1
  in
  CGCompiler.train ~params network packed_images packed_targets
(** [test network] evaluates [network] on the MNIST test set and logs its
    classification accuracy. Both the network's output rows and the label
    rows (presumably one-hot) are reduced to the column index of their
    maximum. The two index vectors are compared element-wise, and the
    fraction of matches is reported. *)
let test network =
  let images, _, label_mat = Dataset.load_mnist_test_data () in
  let n = Dense.Matrix.S.row_num images in
  let images = Dense.Ndarray.S.reshape images [|n;28;28;1|] in
  let predict = CGCompiler.model ~batch_size:100 network in
  (* Map every row to the column index of its maximum, returned as a
     1 x n float matrix so it can be fed to [elt_equal]. *)
  let argmax_rows mat =
    let cols =
      Array.map
        (fun (_, _, col) -> float_of_int col)
        (Dense.Matrix.Generic.max_rows mat)
    in
    Dense.Matrix.S.of_array cols 1 n
  in
  let scores = unpack (predict (pack images)) in
  let predicted = argmax_rows scores in
  let expected = argmax_rows label_mat in
  let hits = Dense.Matrix.S.(sum' (elt_equal predicted expected)) in
  Owl_log.info "Accuracy on test set: %f" (hits /. float_of_int n)
(* Entry point: enable INFO logging, build the network for 28x28x1 inputs,
   print its structure, train it, then log test-set accuracy. *)
let () =
  Owl_log.(set_level INFO);
  let network = make_network [|28;28;1|] in
  Graph.print network;
  flush_all ();
  (* [ignore] instead of [let _ =]: [let _ =] silently accepts a value of
     any type. If [train] ever gained another parameter, the partial
     application would be dropped and no training would run. [ignore] makes
     the compiler warn about a discarded partial application (warning 5).
     The returned training state is not needed, because [test] reuses
     [network] directly. *)
  ignore (train network);
  test network