-
Notifications
You must be signed in to change notification settings - Fork 2
/
abstract_test.go
139 lines (120 loc) · 4.15 KB
/
abstract_test.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
package shapes
import (
"fmt"
"testing"
"github.com/stretchr/testify/assert"
)
// TestAbstract_T exercises Abstract.T: the identity permutation (0, 1) must be
// rejected with a NoOpError while leaving the shape unchanged, and swapping
// the axes (1, 0) must reverse the two dimensions.
func TestAbstract_T(t *testing.T) {
	assert := assert.New(t)
	abstract := Abstract{Size(1), BinOp{Add, Size(1), Size(2)}}

	// noop: the identity permutation should be reported as a NoOpError.
	a2, err := abstract.T(0, 1)
	if err == nil {
		// Report once; without the else-branch the type check would fire a
		// second, redundant failure for the same nil error.
		t.Errorf("Expected a noop error")
	} else if _, ok := err.(NoOpError); !ok {
		t.Errorf("Expected a noop error. Got %v instead", err)
	}
	// testify takes (expected, actual); the shape must come back unchanged.
	assert.Equal(abstract, a2)

	// Swapping the two axes reverses the dimensions.
	a2, err = abstract.T(1, 0)
	if err != nil {
		t.Fatal(err)
	}
	correct := Abstract{abstract[1], abstract[0]}
	assert.Equal(correct, a2)
}
// absSliceTests is the table driving TestAbstract_S. Each case slices the
// abstract shape a with sli and expects either the resulting Shapelike or,
// when err is true, a slicing error.
//
// Case names are kept unique so that a failure report identifies exactly one
// table entry.
var absSliceTests = []struct {
	name     string
	a        Abstract
	sli      []Slice
	expected Shapelike
	err      bool
}{
	{"all vars", Gen(2), nil, Gen(2), false},
	{"all vars - slice second axis", Gen(2), []Slice{nil, S(1)},
		Abstract{Var('a'), sizelikeSliceOf{SliceOf{*S(1), Var('b')}}}, false},
	{"all sizes (vector)", Abstract{Size(2)}, []Slice{S(0)}, ScalarShape(), false},
	{"all sizes (vector) - bad slice range", Abstract{Size(2)}, []Slice{S(3)}, nil, true},
	{"Mixed sizes and var", Abstract{Var('a'), Size(2)}, []Slice{S(2), S(0, 2)},
		Abstract{sizelikeSliceOf{SliceOf{*S(2), Var('a')}}, Size(2)},
		false,
	},
	{"Mixed",
		Abstract{Var('a'), BinOp{Add, Var('a'), Var('b')}, UnaryOp{Dims, Var('b')}},
		[]Slice{S(1, 5, 2), S(1, 5), S(1, 5)},
		Abstract{
			sizelikeSliceOf{SliceOf{*S(1, 5, 2), Var('a')}},
			sizelikeSliceOf{SliceOf{*S(1, 5), E2{BinOp{Add, Var('a'), Var('b')}}}},
			sizelikeSliceOf{SliceOf{*S(1, 5), UnaryOp{Dims, Var('b')}}},
		},
		false,
	},
}
// TestAbstract_S runs every case in absSliceTests through Abstract.S and
// compares the resulting Shapelike with the expected one, unless checkErr
// signals that the case is finished.
func TestAbstract_S(t *testing.T) {
	assert := assert.New(t)
	for i, c := range absSliceTests {
		newShapelike, err := c.a.S(c.sli...)
		// Label error-path reports with the case name (as TestAbs_Repeat does)
		// rather than a fixed string, so they can be traced to the table entry.
		// NOTE(review): checkErr is defined elsewhere; assumed to report an
		// unexpected/missing error and return true when nothing else should be
		// compared for this case.
		if checkErr(t, c.err, err, c.name, i) {
			continue
		}
		assert.Equal(c.expected, newShapelike, c.name)
	}
}
// ExampleSlice_s slices an abstract shape, uses the sliced result as the last
// part of an arrow expression built with MakeArrow, and then applies that
// arrow to two concrete shapes in turn with InferApp, printing each step.
func ExampleSlice_s() {
	// report prints a non-nil err and tells the caller to stop.
	report := func(err error) bool {
		if err == nil {
			return false
		}
		fmt.Printf("Err %v\n", err)
		return true
	}

	from := Abstract{Var('a'), Var('b')}
	to := Abstract{Var('a'), Var('b'), BinOp{Add, Var('a'), Var('b')}, UnaryOp{Const, Var('b')}}

	sliced, err := to.S(S(1, 5), S(1, 5), S(1, 5), S(2, 5))
	if report(err) {
		return
	}

	arrow := MakeArrow(from, to, sliced.(Expr))
	fmt.Printf("expr: %v\n", arrow)

	first := Shape{10, 20}
	applied, err := InferApp(arrow, first)
	if report(err) {
		return
	}
	fmt.Printf("%v @ %v ↠ %v\n", arrow, first, applied)

	second := Shape{10, 20, 30, 20}
	final, err := InferApp(applied, second)
	if report(err) {
		return
	}
	fmt.Printf("%v @ %v ↠ %v", applied, second, final)
	// Output:
	// expr: (a, b) → (a, b, a + b, K b) → (a[1:5], b[1:5], a + b[1:5], K b[2:5])
	// (a, b) → (a, b, a + b, K b) → (a[1:5], b[1:5], a + b[1:5], K b[2:5]) @ (10, 20) ↠ (10, 20, 30, 20) → (4, 4, 4, 3)
	// (10, 20, 30, 20) → (4, 4, 4, 3) @ (10, 20, 30, 20) ↠ (4, 4, 4, 3)
}
// absRepeatTests is the table driving TestAbs_Repeat. For each case,
// a.Repeat(axis, repeats...) is expected to return expected, expectedRepeats
// and expectedSize, or to fail when err is true. Fields left out of a case
// literal take their zero value (nil repeats, err == false).
var absRepeatTests = []struct {
	name            string
	a               Abstract
	repeats         []int
	axis            Axis
	expected        Shapelike
	expectedRepeats []int
	expectedSize    int
	err             bool
}{
	{
		name:            "vector repeat on axis 0",
		a:               Abstract{Var('a')},
		repeats:         []int{3},
		axis:            0,
		expected:        Abstract{BinOp{Mul, Var('a'), Size(3)}},
		expectedRepeats: []int{3},
		expectedSize:    -1,
	},
	{
		name:            "vector repeat on axis 1",
		a:               Abstract{Var('a')},
		repeats:         []int{3},
		axis:            1,
		expected:        Abstract{Var('a'), Size(3)},
		expectedRepeats: []int{3},
		expectedSize:    1,
	},
	{
		name:         "var matrix repeat on axis 0",
		a:            Abstract{Var('a'), Var('b')},
		repeats:      []int{1, 3},
		axis:         0,
		expected:     Abstract{Size(4), Var('b')},
		expectedSize: -1,
	},
	{
		name:         "var matrix repeat on axis 1",
		a:            Abstract{Var('a'), Var('b')},
		repeats:      []int{1, 3},
		axis:         1,
		expected:     Abstract{Var('a'), Size(4)},
		expectedSize: -1,
	},
	{
		name:            "var matrix generic repeat on axis 0",
		a:               Abstract{Var('a'), Var('b')},
		repeats:         []int{3},
		axis:            0,
		expected:        Abstract{BinOp{Mul, Var('a'), Size(3)}, Var('b')},
		expectedRepeats: []int{3},
		expectedSize:    -1,
	},
	{
		name:            "var matrix generic repeat on axis 1",
		a:               Abstract{Var('a'), Var('b')},
		repeats:         []int{3},
		axis:            1,
		expected:        Abstract{Var('a'), BinOp{Mul, Var('b'), Size(3)}},
		expectedRepeats: []int{3},
		expectedSize:    -1,
	},
}
// TestAbs_Repeat drives Abstract.Repeat with every case in absRepeatTests and
// compares the returned shape, repeats and size against the expectations.
func TestAbs_Repeat(t *testing.T) {
	assert := assert.New(t)
	for idx := range absRepeatTests {
		tc := absRepeatTests[idx]
		gotShape, gotReps, gotSize, err := tc.a.Repeat(tc.axis, tc.repeats...)
		// NOTE(review): checkErr is defined elsewhere; a true result means the
		// case is finished and its outputs are not compared.
		if !checkErr(t, tc.err, err, tc.name, idx) {
			assert.Equal(tc.expected, gotShape, "Test %v - Shape like not the same", tc.name)
			assert.Equal(tc.expectedRepeats, gotReps, "Test %v - Repeats not the same", tc.name)
			assert.Equal(tc.expectedSize, gotSize, "Test %v - Size not the same", tc.name)
		}
	}
}