-
Notifications
You must be signed in to change notification settings - Fork 2
/
infer.go
203 lines (182 loc) · 4.73 KB
/
infer.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
package shapes
import (
"fmt"
"sort"
"github.com/pkg/errors"
)
// ConstraintsExpr is a tuple of a list of constraints and an expression,
// optionally paired with a SubjectTo condition.
type ConstraintsExpr struct {
	// cs are the constraints that Infer solves to obtain substitutions.
	cs constraints
	// e is the expression the solved substitutions are applied to; Infer
	// reports an error when it is nil.
	e Expr
	// st is an optional condition; Infer only considers it when both st.A
	// and st.B are non-nil.
	st SubjectTo
}
// Format implements fmt.Formatter, rendering ce as "<constraints> | <expression>".
// The verb and flags are ignored, and the SubjectTo is not part of the output.
func (ce ConstraintsExpr) Format(f fmt.State, r rune) {
	// Fprint formats each operand as %v and inserts no extra spaces, because the
	// separator between the two operands is a string.
	fmt.Fprint(f, ce.cs, " | ", ce.e)
}
// App applies the expression b to the function (Arrow) expression ar.
//
// The result is the constraint "ar = b → r" for a fresh variable r, with r
// standing in as the expression for the application's result. ar must be an
// Arrow, or a Compound wrapping an Arrow, in which case the Compound's SubjectTo
// is carried into the returned ConstraintsExpr. Any other type panics.
//
// This function will aggressively perform alpha renaming on b: each free
// variable of b that is also free in ar is renamed to a fresh variable.
//
// Example. Given an application of the following:
//
//	((a,b) → (b, c) → (a, c)) @ (2, a)
//
// The variables in the latter will be aggressively renamed, to become:
//
//	((a,b) → (b, c) → (a, c)) @ (2, d)
//
// Normally this wouldn't be a concern, as you would be passing in concrete shapes, something like:
//
//	((a,b) → (b, c) → (a, c)) @ (2, 3)
//
// which will then yield:
//
//	(3, c) → (2, c)
func App(ar Expr, b Expr) ConstraintsExpr {
	var (
		fn Arrow
		st SubjectTo
	)
	switch at := ar.(type) {
	case Arrow:
		fn = at
	case Compound:
		inner, ok := at.Expr.(Arrow)
		if !ok {
			panic(fmt.Sprintf("Unhandled type at.Expr %v of %T", at.Expr, at.Expr))
		}
		fn, st = inner, at.SubjectTo
	default:
		panic(fmt.Sprintf("Unhandled type ar %v of %T", ar, ar))
	}

	// Rename b's clashing variables, then collect every name now in use so the
	// result variable is distinct from all of them.
	used := fn.freevars()
	renamed := alpha(used, b)
	used = unique(append(used, renamed.freevars()...))
	ret := fresh(used)

	return ConstraintsExpr{
		cs: constraints{{fn, Arrow{renamed, ret}}},
		e:  ret,
		st: st,
	}
}
// Infer solves the constraints in ce and returns the resulting expression.
//
// The substitutions obtained from solving ce's constraints are applied to ce's
// expression, which is then resolved as far as possible. If resolution fails,
// whatever recursiveResolve returned (possibly nil) is passed back alongside the
// error.
//
// When ce carries a SubjectTo (both operands non-nil), the substitutions are
// applied to it as well:
//   - if free variables remain, the result is a Compound pairing the expression
//     with the still-open condition, so it can be checked later;
//   - otherwise the condition is evaluated, and an error is returned if it
//     cannot be evaluated or evaluates to false.
func Infer(ce ConstraintsExpr) (Expr, error) {
	if ce.e == nil {
		return nil, errors.Errorf("No expression found in ConstraintExpr %v", ce)
	}
	subs, err := solve(ce.cs, nil)
	if err != nil {
		return nil, errors.Wrapf(err, "Failed to solve %v", ce)
	}
	retVal := ce.e.apply(subs).(Expr)
	if retVal, err = recursiveResolve(retVal); err != nil {
		// Deliberately not nil-ed out: recursiveResolve may return a partially
		// resolved expression that callers can still use.
		return retVal, err
	}

	// No condition attached: nothing further to check.
	if ce.st.A == nil || ce.st.B == nil {
		return retVal, nil
	}

	st := ce.st.apply(subs).(SubjectTo)
	if len(st.freevars()) > 0 {
		// don't try to resolve the st yet.
		return Compound{Expr: retVal, SubjectTo: st}, nil
	}
	ok, err := st.resolveBool()
	if err != nil {
		// Wrap (rather than flatten into a string) so errors.Cause still reaches
		// the underlying error.
		return nil, errors.Wrapf(err, "Failed to resolve SubjectTo %v", st)
	}
	if !ok {
		return nil, errors.Errorf("SubjectTo %v resolved to false. Cannot continue", st)
	}
	return retVal, nil
}
// InferApp infers the result of applying a to each of others in turn, left to
// right: the inferred result of one application becomes the function for the
// next. It returns an error if others is empty or if any inference step fails.
func InferApp(a Expr, others ...Expr) (retVal Expr, err error) {
	if len(others) == 0 {
		return nil, errors.New("Expected at least one other shape expression in order to InferApp")
	}
	retVal = a
	for i := range others {
		// Thread the previous result in as the function for this application.
		if retVal, err = Infer(App(retVal, others[i])); err != nil {
			return nil, err
		}
	}
	return retVal, nil
}
// ToShape converts a to a concrete Shape. A Shape is returned as is; an
// Abstract is returned only if it can be fully concretized. Any other
// expression, or an Abstract that cannot be concretized, yields an error.
func ToShape(a Expr) (Shape, error) {
	if sh, isShape := a.(Shape); isShape {
		return sh, nil
	}
	if abs, isAbstract := a.(Abstract); isAbstract {
		if concrete, done := abs.ToShape(); done {
			return concrete, nil
		}
	}
	return nil, errors.Errorf("Unable to concretize %v of %T", a, a)
}
// fresh returns a variable for use as a new name: 'a' when set is empty,
// otherwise the successor of the greatest variable in set (under varset's
// ordering).
//
// set is left untouched. The ordering is done on a private copy, so callers'
// slices keep their element order.
func fresh(set varset) Var {
	if len(set) == 0 {
		return 'a'
	}
	sorted := make(varset, len(set))
	copy(sorted, set)
	sort.Sort(sorted)
	return sorted[len(sorted)-1] + 1
}
// alpha alpha-renames the free variables of the expression a that clash with set.
//
// Every free variable of a that is contained in set is substituted with a fresh
// variable. Fresh names are chosen to avoid both set AND a's own free variables.
// Avoiding only set could capture a variable: for set {a, b, c} and the
// expression (a, d), renaming a to d would yield (d, d) and merge two distinct
// variables. Here the expression becomes (e, d) instead.
//
// set itself is not modified.
func alpha(set varset, a Expr) Expr {
	fv := a.freevars()

	// used is a private list of every name that must not be handed out. The
	// renames chosen so far are added as we go, so each fresh name is distinct.
	used := make(varset, 0, len(set)+len(fv))
	used = append(used, set...)
	used = append(used, fv...)

	var subs substitutions
	for _, v := range fv {
		if !set.Contains(v) {
			continue
		}
		fr := fresh(used)
		used = append(used, fr)
		subs = append(subs, substitution{Sub: fr, For: v})
	}
	return a.apply(subs).(Expr)
}
// recursiveResolve reduces a to a simpler expression where possible, dispatching
// on a's dynamic type. Case order matters: the first matching case wins, so
// Abstract, Arrow, SliceOf and sizeOp are handled before the generic resolver case.
//
//   - Abstract: returns the result of at.resolve() without recursing further.
//   - Arrow: resolves both sides recursively. If either side fails, the original
//     Arrow is returned with a nil error (best effort; the error is discarded).
//   - SliceOf: resolves the sliced expression, then resolves the rebuilt SliceOf.
//   - sizeOp: must be valid. Its size is computed and returned as a one-element
//     Shape; if the size cannot be computed, a is returned together with the error.
//   - resolver: resolves once. A NoOpError is treated as "nothing to do" and the
//     result is returned with a nil error; otherwise the result is resolved again.
//   - anything else is returned unchanged.
func recursiveResolve(a Expr) (Expr, error) {
	switch at := a.(type) {
	case Abstract:
		// even though Abstract implements `resolver`,
		// due to the recursive nature of recursiveResolve,
		// this will cause an infinite loop
		retVal, err := at.resolve()
		return retVal, err
	case Arrow:
		A, err := recursiveResolve(at.A)
		if err != nil {
			return a, nil // if there's an error, don't continue or return errors.
		}
		B, err := recursiveResolve(at.B)
		if err != nil {
			return a, nil // if there's an error, don't continue or return errors.
		}
		return Arrow{A, B}, nil
	case SliceOf:
		// Resolve the inner expression first so the slice itself sees the
		// resolved form.
		A, err := recursiveResolve(at.A)
		if err != nil {
			return nil, errors.Wrapf(err, "Unable to resolve %v in SliceOf", at.A)
		}
		s := SliceOf{at.Slice, A}
		return s.resolve()
	case sizeOp:
		if !at.isValid() {
			return nil, errors.Errorf("Expression %v is not a valid Operation", at)
		}
		sz, err := at.resolveSize()
		if err != nil {
			// a is returned alongside the error so callers may still use it.
			return a, errors.Wrapf(err, "Cannot resolve final expresison. But it may still be used.")
		}
		return Shape{int(sz)}, nil
	case resolver:
		retVal, err := at.resolve()
		// NOTE(review): this is a direct type assertion, so a NoOpError that
		// has been wrapped would not be recognised here.
		if _, ok := err.(NoOpError); ok {
			return retVal, nil
		}
		if err != nil {
			return nil, errors.Wrapf(err, "Failed to recursively resolve %v", at)
		}
		return recursiveResolve(retVal)
	default:
		// nothing else can be resolved. return the identity
		return a, nil
	}
}