forked from compintell/Mooncake.jl
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy paths2s_forward_mode_ad.jl
313 lines (281 loc) · 9.56 KB
/
s2s_forward_mode_ad.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
"""
    build_frule(args...; debug_mode=false)

Convenience entry point: build a forward-mode rule for the signature given by the primal
types of `args`, using the default interpreter from `get_interpreter()`.
"""
function build_frule(args...; debug_mode=false)
    default_interp = get_interpreter()
    primal_sig = _typeof(TestUtils.__get_primals(args))
    return build_frule(default_interp, primal_sig; debug_mode=debug_mode)
end
"""
    build_frule(
        interp::MooncakeInterpreter{C}, sig_or_mi; debug_mode=false, silence_debug_messages=true
    ) where {C}

Return a forward-mode rule for the call described by `sig_or_mi`.

If the signature is a primitive in context `C`, the hand-written `frule!!` is returned,
wrapped in a `DebugFRule` when `debug_mode` is `true`. Otherwise a rule is derived:

1. The dual IR is produced by `generate_dual_ir`.
2. That IR is compiled into a `MistyClosure` and wrapped in a `DerivedFRule`.
3. In debug mode, the result is additionally wrapped in a `DebugFRule`.

The derived rule is stored in `interp.oc_cache`. This method currently never reads from
that cache.

# Throws
- `ArgumentError` if the current world age is newer than `interp.world`.
"""
function build_frule(
    interp::MooncakeInterpreter{C}, sig_or_mi; debug_mode=false, silence_debug_messages=true
) where {C}
    # To avoid segfaults, bail out if the interpreter's world age is older than (i.e.
    # behind) the current world age.
    if Base.get_world_counter() > interp.world
        throw(
            ArgumentError(
                "World age associated to interp is behind current world age. Please " *
                "create a new interpreter for the current world age.",
            ),
        )
    end

    # If we're compiling in debug mode, let the user know by default.
    if !silence_debug_messages && debug_mode
        @info "Compiling rule for $sig_or_mi in debug mode. Disable for best performance."
    end

    # If we have a hand-coded rule, just use that.
    sig = _get_sig(sig_or_mi)
    is_primitive(C, sig) && return (debug_mode ? DebugFRule(frule!!) : frule!!)

    # We don't have a hand-coded rule, so derive one.
    # NOTE(review): deriving can call `build_frule` again while this lock is held (for
    # `:invoke` statements), so this presumably relies on `MOONCAKE_INFERENCE_LOCK` being
    # reentrant — confirm.
    lock(MOONCAKE_INFERENCE_LOCK)
    try
        oc_cache_key = ClosureCacheKey(interp.world, (sig_or_mi, debug_mode))
        # TODO: the cache is write-only for now. Re-enable the lookup once it is safe to
        # reuse a previously derived rule: return a copy with fresh shared data rather than
        # re-deriving the OpaqueClosures.

        # Derive forward-pass IR, and shove it in a `MistyClosure`.
        dual_ir = generate_dual_ir(interp, sig_or_mi; debug_mode)
        dual_oc = MistyClosure(dual_ir; do_compile=true)
        raw_rule = DerivedFRule(dual_oc)
        rule = debug_mode ? DebugFRule(raw_rule) : raw_rule
        interp.oc_cache[oc_cache_key] = rule
        return rule
    finally
        # Exceptions propagate unchanged, with their original backtrace.
        unlock(MOONCAKE_INFERENCE_LOCK)
    end
end
"""
    DerivedFRule(fwd_oc)

Forward-mode rule derived from IR by `build_frule`. Calling it with `Dual` arguments
forwards them, unchanged, to the compiled closure `fwd_oc`.
"""
struct DerivedFRule{Toc}
    fwd_oc::Toc
end

@inline (rule::DerivedFRule)(args::Vararg{Dual,N}) where {N} = rule.fwd_oc(args...)
"""
    generate_dual_ir(interp::MooncakeInterpreter, sig_or_mi; debug_mode=false, do_inline=true)

Build the forward-mode ("dual") `IRCode` for `sig_or_mi`.

Steps:

1. Look up and normalise the primal IR.
2. Copy it and convert the copy's argument types to dual types.
3. Prepend an `Any` argument for the rule itself.
4. Rewrite every statement with `modify_fwd_ad_stmts!`.
5. Compact, verify and optimise the result (`do_inline` is forwarded to `optimise_ir!`).
"""
function generate_dual_ir(
    interp::MooncakeInterpreter, sig_or_mi; debug_mode=false, do_inline=true
)
    # Reset id count. This ensures that the IDs generated are the same each time this
    # function runs.
    seed_id!()

    # Grab code associated to the primal, and normalise it.
    primal_ir, _ = lookup_ir(interp, sig_or_mi)
    _, spnames = is_vararg_and_sparam_names(sig_or_mi)
    primal_ir = normalise!(primal_ir, spnames)

    # Work on a copy, so that `primal_ir` stays unmodified and can be queried for primal
    # argument / statement types by `modify_fwd_ad_stmts!`.
    dual_ir = copy(primal_ir)

    # Convert argument types to dual types. Any other kind of argument type (e.g. a
    # `Union` or a `UnionAll`) is widened to `Any`: the argument is a `Dual` at runtime,
    # so keeping the primal type would make the IR claim a type the value does not have.
    for (a, P) in enumerate(primal_ir.argtypes)
        dual_ir.argtypes[a] = if P isa DataType
            dual_type(P)
        elseif P isa Core.Const
            dual_type(_typeof(P.val))
        else
            Any
        end
    end
    # The rule itself is passed as an extra leading argument.
    pushfirst!(dual_ir.argtypes, Any)

    # Rewrite the dual IR statement by statement.
    # NOTE(review): `i` is the index in the compacted result. Nodes inserted by the
    # `GotoIfNot` / `ReturnNode` handlers presumably shift later result indices, while
    # `primal_ir` keeps the original numbering, so `SSAValue` type look-ups in `primal_ir`
    # may be misaligned after the first insertion — confirm.
    compact = CC.IncrementalCompact(dual_ir)
    for ((_, i), inst) in compact
        modify_fwd_ad_stmts!(compact, primal_ir, interp, inst, i; debug_mode)
    end
    dual_ir = CC.compact!(CC.finish(compact))
    CC.verify_ir(dual_ir)

    # Optimize dual IR.
    return optimise_ir!(dual_ir; do_inline)
end
## Modification of IR nodes
# `nothing` statements are carried over to the dual IR untouched.
function modify_fwd_ad_stmts!(
    ::CC.IncrementalCompact, ::IRCode, ::MooncakeInterpreter, ::Nothing, ::Integer;
    kwargs...,
)
    return nothing
end
# `GlobalRef` statements are left as they are, so in the dual IR they still evaluate to
# the (primal) value of the global rather than to a `Dual`.
function modify_fwd_ad_stmts!(
    ::CC.IncrementalCompact, ::IRCode, ::MooncakeInterpreter, ::GlobalRef, ::Integer;
    kwargs...,
)
    return nothing
end
# Unconditional jumps need no forward-mode transformation.
function modify_fwd_ad_stmts!(
    ::CC.IncrementalCompact, ::IRCode, ::MooncakeInterpreter, ::GotoNode, ::Integer;
    kwargs...,
)
    return nothing
end
# Conditional branches branch on the primal of the condition: the `GotoIfNot` at position
# `i` is overwritten with a `_primal(cond)` call, and the branch is re-inserted right
# after it, testing that call's result.
function modify_fwd_ad_stmts!(
    dual_ir::CC.IncrementalCompact,
    primal_ir::IRCode,
    ::MooncakeInterpreter,
    stmt::Core.GotoIfNot,
    i::Integer;
    kwargs...,
)
    primal_cond = CC.SSAValue(i)
    Mooncake.replace_call!(dual_ir, primal_cond, Expr(:call, _primal, inc_args(stmt).cond))

    # Incremental insertion cannot be done before "where we are", hence the branch is
    # re-inserted after the call rather than before it.
    branch = CC.NewInstruction(
        Core.GotoIfNot(primal_cond, stmt.dest),
        Any,
        CC.NoCallInfo(),
        Int32(1), # line info; meaningless here
        CC.IR_FLAG_REFINED,
    )
    # Reverse affinity (`true`): stick the new instruction in the previous CFG block.
    CC.insert_node_here!(dual_ir, branch, true)
    return nothing
end
# Return statements must yield a `Dual`: the returned value is converted with `_dual`
# (which also covers constant return values) in place of the original `ReturnNode`, and a
# new `ReturnNode` returning that conversion is inserted right after it.
# Unreachable returns (`ReturnNode()` without a `val`) are left unchanged.
function modify_fwd_ad_stmts!(
    dual_ir::CC.IncrementalCompact,
    primal_ir::IRCode,
    ::MooncakeInterpreter,
    stmt::ReturnNode,
    i::Integer;
    kwargs...,
)
    # `ReturnNode()` marks unreachable code; its `val` field is undefined, so reading it
    # would throw an `UndefRefError`. There is nothing to convert in that case.
    isdefined(stmt, :val) || return nothing

    # make sure that we always return a Dual even when it's a constant
    Mooncake.replace_call!(dual_ir, CC.SSAValue(i), Expr(:call, _dual, inc_args(stmt).val))
    # return the result from the previous Dual conversion
    new_return_inst = CC.NewInstruction(
        Core.ReturnNode(CC.SSAValue(i)), Any, CC.NoCallInfo(), Int32(1), CC.IR_FLAG_REFINED
    )
    CC.insert_node_here!(dual_ir, new_return_inst, true)
    return nothing
end
# Phi nodes keep their structure. Their operands are passed through `inc_args`
# (presumably shifting `Argument` numbers for the extra leading rule argument), and the
# result type is relaxed to `Any`.
function modify_fwd_ad_stmts!(
    dual_ir::CC.IncrementalCompact,
    primal_ir::IRCode,
    ::MooncakeInterpreter,
    stmt::PhiNode,
    i::Integer;
    kwargs...,
)
    phi_inst = dual_ir[SSAValue(i)]
    phi_inst[:stmt] = inc_args(stmt) # TODO: translate constants into constant Duals
    phi_inst[:type] = Any
    phi_inst[:flag] = CC.IR_FLAG_REFINED
    return nothing
end
# A `PiNode` narrows the type of its operand. In the dual IR the operand is a `Dual`, so
# the narrowing has to be expressed in terms of `Dual` types.
function modify_fwd_ad_stmts!(
    dual_ir::CC.IncrementalCompact,
    primal_ir::IRCode,
    ::MooncakeInterpreter,
    stmt::PiNode,
    i::Integer;
    kwargs...,
)
    # `Dual` is invariant in its type parameters, so `Dual{T,tangent_type(T)}` can only
    # describe the runtime value when `T` is concrete. This assumes a `Dual` around a
    # primal of concrete type `S` has type `Dual{S,tangent_type(S)}`. For an abstract (or
    # non-type) narrowing, fall back to `Any` rather than asserting a type the value
    # cannot have.
    T = stmt.typ
    dual_typ = isconcretetype(T) ? Dual{T,tangent_type(T)} : Any
    pi_inst = dual_ir[SSAValue(i)]
    pi_inst[:stmt] = inc_args(PiNode(stmt.val, dual_typ))
    pi_inst[:type] = Any
    pi_inst[:flag] = CC.IR_FLAG_REFINED
    return nothing
end
## Modification of IR nodes - expressions
"""
    DualArguments(frule)

Callable wrapper around `frule`. Calling it with a function and its arguments converts
all of them with `_dual` and passes the results on to `frule`.
"""
struct DualArguments{R}
    frule::R
end

# Display as `DualArguments(<frule>)`.
function Base.show(io::IO, da::DualArguments)
    return print(io, string("DualArguments(", da.frule, ")"))
end

# TODO: wrapping in Dual must not be systematic (e.g. Argument or SSAValue)
function (wrapper::DualArguments)(f::F, args::Vararg{Any,N}) where {F,N}
    all_duals = tuple_map(_dual, (f, args...))
    return wrapper.frule(all_duals...)
end
"""
    DynamicFRule(debug_mode::Bool)

Forward rule that derives the rule to apply when it is called. It builds a signature from
the runtime types of its (dual-converted) arguments, obtains a rule for that signature with
`build_frule` and memoises it in `cache`.
"""
struct DynamicFRule{V}
    cache::V
    debug_mode::Bool
end

function DynamicFRule(debug_mode::Bool)
    return DynamicFRule(Dict{Any,Any}(), debug_mode)
end

# Fresh instance with an empty cache and the same `debug_mode`.
function _copy(x::DynamicFRule)
    return typeof(x)(Dict{Any,Any}(), x.debug_mode)
end

function (dynamic_rule::DynamicFRule)(args::Vararg{Any,N}) where {N}
    # TODO: don't turn everything into a Dual, be clever with Argument and SSAValue
    duals = map(_dual, args)
    runtime_sig = Tuple{map(_typeof ∘ primal, duals)...}
    rule = get!(dynamic_rule.cache, runtime_sig) do
        build_frule(get_interpreter(), runtime_sig; debug_mode=dynamic_rule.debug_mode)
    end
    return rule(duals...)
end
"""
    modify_fwd_ad_stmts!(dual_ir, primal_ir, interp, stmt::Expr, i; debug_mode)

Rewrite the expression statement at position `i` of `dual_ir` for forward mode.

- `:invoke` / `:call`: replaced by a call to a rule through a `DualArguments` wrapper,
  which converts every argument with `_dual` first. The rule depends on the call:
  - primitives (per `is_primitive` in the interpreter's context) use `frule!!`;
  - other `:invoke`s use a rule derived immediately with `build_frule`;
  - other `:call`s use a `DynamicFRule`, which builds rules when it is called.
- `:boundscheck`: left unchanged.
- `:code_coverage_effect`: replaced by `nothing`.
- any other head: throws an `ArgumentError`.
"""
function modify_fwd_ad_stmts!(
    dual_ir::CC.IncrementalCompact,
    primal_ir::IRCode,
    interp::MooncakeInterpreter,
    stmt::Expr,
    i::Integer;
    debug_mode,
)
    if isexpr(stmt, :invoke) || isexpr(stmt, :call)
        is_invoke = isexpr(stmt, :invoke)
        if is_invoke
            mi = stmt.args[1]::Core.MethodInstance
            sig = mi.specTypes
            # The first argument of an `:invoke` is the method instance, not a call arg.
            shifted_args = inc_args(stmt).args[2:end]
        else
            mi = missing
            # Types recorded in IR may be lattice elements (e.g. `Core.Const`), which are
            # not valid `Tuple` parameters, so widen them to plain types first.
            sig_types = map(stmt.args) do a
                CC.widenconst(get_forward_primal_type(primal_ir, a))
            end
            sig = Tuple{sig_types...}
            shifted_args = inc_args(stmt).args
        end

        rule = if is_primitive(context_type(interp), sig)
            frule!!
        elseif is_invoke
            # NOTE(review): the callee's rule is derived eagerly while the caller's rule is
            # still being built, and `build_frule` does not read its cache, so a
            # self-recursive `:invoke` would presumably never terminate here — confirm.
            build_frule(interp, mi; debug_mode)
        else
            DynamicFRule(debug_mode)
        end
        # TODO: could this insertion of a naked rule in the IR cause a memory leak?
        call_rule = Expr(:call, DualArguments(rule), shifted_args...)
        replace_call!(dual_ir, SSAValue(i), call_rule)
    elseif isexpr(stmt, :boundscheck)
        nothing
    elseif isexpr(stmt, :code_coverage_effect)
        replace_call!(dual_ir, SSAValue(i), nothing)
    else
        throw(
            ArgumentError(
                "Expressions of type `:$(stmt.head)` are not yet supported in forward mode"
            ),
        )
    end
    return nothing
end
# Primal type of a value that appears in an argument slot of a statement of `ir`.
# Types recorded in IR (argument types, statement types) may be inference lattice elements
# such as `Core.Const`; they are widened with `CC.widenconst` so that the result is always
# a plain type, usable e.g. as a parameter of a signature `Tuple`.
get_forward_primal_type(ir::IRCode, a::Argument) = CC.widenconst(ir.argtypes[a.n])
get_forward_primal_type(ir::IRCode, ssa::SSAValue) = CC.widenconst(ir[ssa][:type])
get_forward_primal_type(::IRCode, x::QuoteNode) = _typeof(x.value)
get_forward_primal_type(::IRCode, x) = _typeof(x)
function get_forward_primal_type(::IRCode, x::GlobalRef)
    # Constant globals: type of the current value; otherwise the binding's declared type.
    return isconst(x) ? _typeof(getglobal(x.mod, x.name)) : x.binding.ty
end
function get_forward_primal_type(::IRCode, x::Expr)
    # `Expr(:boundscheck)` is the only expression accepted in an argument slot.
    x.head === :boundscheck && return Bool
    return error("Unrecognised expression $x found in argument slot.")
end