forked from gorgonia/gorgonia
-
Notifications
You must be signed in to change notification settings - Fork 0
/
gorgonia_test.go
66 lines (54 loc) · 1.94 KB
/
gorgonia_test.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
package gorgonia
import (
"testing"
nd "github.com/chewxy/gorgonia/tensor"
"github.com/chewxy/hm"
"github.com/stretchr/testify/assert"
)
// TestNewConstant verifies that NewConstant infers the node's type and shape
// from the value it wraps. It covers a 3x3 float64 tensor, both without and
// with a WithName option, and a float64 scalar literal.
func TestNewConstant(t *testing.T) {
	a := assert.New(t)

	t.Log("Testing New Constant Tensors")
	data := nd.Random(Float64, 9)
	tensor := nd.New(nd.WithBacking(data), nd.WithShape(3, 3))
	wantShape := nd.Shape{3, 3}

	unnamed := NewConstant(tensor)
	var wantType hm.Type = newTensorType(2, Float64)
	a.Equal(wantShape, unnamed.shape)
	a.Equal(wantType, unnamed.t)

	// Wrapping the same tensor with a name must not affect type or shape.
	named := NewConstant(tensor, WithName("From TensorValue"))
	a.Equal(wantShape, named.shape)
	a.Equal(wantType, named.t)
	a.Equal("From TensorValue", named.name)

	t.Log("Testing Constant Scalars")
	scalar := NewConstant(3.14)
	wantType = Float64
	a.Equal(scalarShape, scalar.shape)
	a.Equal(wantType, scalar.t)
}
// anyNodeTest is the case table for TestNodeFromAny. Each entry pairs a value
// handed to NodeFromAny with the type and shape the resulting node is expected
// to carry. The name is passed to WithName and is also used to label failure
// messages, so every entry has a distinct name.
var anyNodeTest = []struct {
	name         string
	any          interface{} // value passed to NodeFromAny
	correctType  hm.Type
	correctShape nd.Shape
}{
	{"float32", float32(3.14), Float32, scalarShape},
	{"float64", float64(3.14), Float64, scalarShape},
	{"int", int(3), Int, scalarShape},
	{"bool", true, Bool, scalarShape},
	{"nd.Tensor float64", nd.New(nd.Of(nd.Float64), nd.WithShape(2, 3, 4)), &TensorType{Dims: 3, Of: Float64}, nd.Shape{2, 3, 4}},
	{"nd.Tensor float32", nd.New(nd.Of(nd.Float32), nd.WithShape(2, 3, 4)), &TensorType{Dims: 3, Of: Float32}, nd.Shape{2, 3, 4}},
	{"ScalarValue", newF64(3.14), Float64, scalarShape},
	{"TensorValue", nd.New(nd.Of(nd.Float64), nd.WithShape(2, 3)), &TensorType{Dims: 2, Of: Float64}, nd.Shape{2, 3}},
}
func TestNodeFromAny(t *testing.T) {
assert := assert.New(t)
g := NewGraph()
for _, a := range anyNodeTest {
n := NodeFromAny(g, a.any, WithName(a.name))
assert.Equal(a.name, n.name)
assert.Equal(g, n.g)
assert.True(a.correctType.Eq(n.t), "%v type error: Want %v. Got %v", a.name, a.correctType, n.t)
assert.True(a.correctShape.Eq(n.shape), "%v shape error: Want %v. Got %v", a.name, a.correctShape, n.shape)
}
}