-
Notifications
You must be signed in to change notification settings - Fork 88
/
InterfaceDynamicExpressions.jl
318 lines (277 loc) · 11 KB
/
InterfaceDynamicExpressions.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
module InterfaceDynamicExpressionsModule
using Printf: @sprintf
using DynamicExpressions: DynamicExpressions
using DynamicExpressions:
OperatorEnum, GenericOperatorEnum, AbstractExpressionNode, Node, GraphNode
using DynamicExpressions.StringsModule: needs_brackets
using DynamicQuantities: dimension, ustrip
using ..CoreModule: Options
using ..CoreModule.OptionsModule: inverse_binopmap, inverse_unaopmap
using ..UtilsModule: subscriptify
import DynamicExpressions:
eval_tree_array,
eval_diff_tree_array,
eval_grad_tree_array,
print_tree,
string_tree,
differentiable_eval_tree_array
import ..deprecate_varmap
"""
    eval_tree_array(tree::AbstractExpressionNode, X::AbstractArray, options::Options; kws...)

Evaluate a binary tree (equation) over a given input data matrix. The
operators contain all of the operators used. This function fuses doublets
and triplets of operations for lower memory usage.

This function can be represented by the following pseudocode:

```
function eval(current_node)
    if current_node is leaf
        return current_node.value
    elif current_node is degree 1
        return current_node.operator(eval(current_node.left_child))
    else
        return current_node.operator(eval(current_node.left_child), eval(current_node.right_child))
    end
end
```

The bulk of the code is for optimizations and pre-emptive NaN/Inf checks,
which speed up evaluation significantly.

The `turbo` and `bumper` settings are read from `options`; any extra keywords
`kws...` are forwarded to the operator-based `eval_tree_array`.

# Arguments
- `tree::AbstractExpressionNode`: The root node of the tree to evaluate.
- `X::AbstractArray`: The input data to evaluate the tree on.
- `options::Options`: Options used to define the operators used in the tree.

# Returns
- `(output, complete)::Tuple{AbstractVector, Bool}`: the result,
  which is a 1D array, as well as if the evaluation completed
  successfully (true/false). A `false` complete means an infinity
  or nan was encountered, and a large loss should be assigned
  to the equation.
"""
function eval_tree_array(
    tree::AbstractExpressionNode, X::AbstractArray, options::Options; kws...
)
    # Asserting the concrete output type lets callers infer the result.
    output_type = expected_array_type(X)
    result = eval_tree_array(
        tree, X, options.operators; turbo=options.turbo, bumper=options.bumper, kws...
    )
    return result::Tuple{output_type,Bool}
end
# Array type produced when evaluating over `X`: the type of a `similar` array
# indexed by the columns of `X` (one entry per data point). Used as a return-type
# assertion so Julia can infer what the evaluation functions return.
expected_array_type(X::AbstractArray) = typeof(similar(X, axes(X, 2)))
"""
    eval_diff_tree_array(tree::AbstractExpressionNode, X::AbstractArray, options::Options, direction::Int)

Compute the forward derivative of an expression, using a similar
structure and optimization to `eval_tree_array`. `direction` is the index of a
particular variable in the expression; e.g., `direction=1` would indicate the
derivative with respect to `x1`.

# Arguments
- `tree::AbstractExpressionNode`: The expression tree to evaluate.
- `X::AbstractArray`: The data matrix, with each column being a data point.
- `options::Options`: The options containing the operators used to create the `tree`.
- `direction::Int`: The index of the variable to take the derivative with respect to.

# Returns
- `(evaluation, derivative, complete)::Tuple{AbstractVector, AbstractVector, Bool}`:
  the normal evaluation, the derivative, and whether the evaluation completed as
  normal (or encountered a nan or inf).
"""
function eval_diff_tree_array(
    tree::AbstractExpressionNode, X::AbstractArray, options::Options, direction::Int
)
    # Both the evaluation and the derivative share the per-column array type.
    vector_type = expected_array_type(X)
    result = eval_diff_tree_array(tree, X, options.operators, direction)
    return result::Tuple{vector_type,vector_type,Bool}
end
"""
    eval_grad_tree_array(tree::AbstractExpressionNode, X::AbstractArray, options::Options; kws...)

Compute the forward-mode derivative of an expression, using a similar
structure and optimization to `eval_tree_array`. All keywords are forwarded to
the operator-based `eval_grad_tree_array`; in particular `variable` specifies
whether we should take derivatives with respect to features (i.e., `X`), or
with respect to every constant in the expression.

# Arguments
- `tree::AbstractExpressionNode`: The expression tree to evaluate.
- `X::AbstractArray`: The data matrix, with each column being a data point.
- `options::Options`: The options containing the operators used to create the `tree`.

# Keywords
- `variable::Bool`: Whether to take derivatives with respect to features
  (i.e., `X` - with `variable=true`), or with respect to every constant in the
  expression (`variable=false`).

# Returns
- `(evaluation, gradient, complete)::Tuple{AbstractVector, AbstractArray, Bool}`:
  the normal evaluation, the gradient, and whether the evaluation completed as
  normal (or encountered a nan or inf).
"""
function eval_grad_tree_array(
    tree::AbstractExpressionNode, X::AbstractArray, options::Options; kws...
)
    evaluation_type = expected_array_type(X)
    # The gradient is asserted to have the same type as `X`.
    # TODO: This won't work with StaticArrays!
    gradient_type = typeof(X)
    result = eval_grad_tree_array(tree, X, options.operators; kws...)
    return result::Tuple{evaluation_type,gradient_type,Bool}
end
"""
    differentiable_eval_tree_array(tree::AbstractExpressionNode, X::AbstractArray, options::Options; kws...)

Evaluate an expression tree in a way that can be auto-differentiated.
Keywords are forwarded to the operator-based `differentiable_eval_tree_array`.
Returns `(output, complete)` like `eval_tree_array`.
"""
function differentiable_eval_tree_array(
    tree::AbstractExpressionNode, X::AbstractArray, options::Options; kws...
)
    output_type = expected_array_type(X)
    result = differentiable_eval_tree_array(tree, X, options.operators; kws...)
    return result::Tuple{output_type,Bool}
end
# Placeholder appended to constants when units are being printed, since a
# constant's own unit is not tracked here.
const WILDCARD_UNIT_STRING = "[?]"

"""
    string_tree(tree::AbstractExpressionNode, options::Options; kws...)

Convert an equation to a string.

# Arguments
- `tree::AbstractExpressionNode`: The equation to convert to a string.
- `options::Options`: The options holding the definition of operators.

# Keywords
- `raw::Bool=true`: If `true`, variables print as `x1`, `x2`, ... (or the matching
  entry of `variable_names`) with no custom constant formatting, and a `GraphNode`
  tree is first converted to a `Node`. Extra keywords are not forwarded in this mode.
  If `false`, constants are printed with `options.print_precision` significant
  digits and variables with subscripted default names.
- `X_sym_units=nothing`, `y_sym_units=nothing`: When either is given (and
  `raw=false`), variables are suffixed with their units from `X_sym_units` (if
  given) and constants with `[?]` unless `options.dimensionless_constants_only`.
- `variable_names::Union{Array{String, 1}, Nothing}=nothing`: what variables
  to print for each feature when `raw=true`.
- `display_variable_names=variable_names`: names used for variables when `raw=false`.
- `varMap=nothing`: deprecated; folded into `variable_names` by `deprecate_varmap`.
- Any other keywords are forwarded to the operator-based `string_tree` when `raw=false`.
"""
@inline function string_tree(
    tree::AbstractExpressionNode,
    options::Options;
    raw::Bool=true,
    X_sym_units=nothing,
    y_sym_units=nothing,
    variable_names=nothing,
    display_variable_names=variable_names,
    varMap=nothing,
    kws...,
)
    variable_names = deprecate_varmap(variable_names, varMap, :string_tree)
    # The `display_variable_names` default was bound to the value of
    # `variable_names` *before* `varMap` was folded in above, so names passed via
    # the deprecated `varMap` were dropped when `raw=false`. Propagate them.
    if varMap !== nothing && display_variable_names === nothing
        display_variable_names = variable_names
    end
    if raw
        tree = tree isa GraphNode ? convert(Node, tree) : tree
        return string_tree(
            tree, options.operators; f_variable=string_variable_raw, variable_names
        )
    end
    # `vals[p] === Val(p)`: lift the precision into the type domain so the
    # `@sprintf` format string used by `sprint_precision` can be a literal.
    vprecision = vals[options.print_precision]
    if X_sym_units !== nothing || y_sym_units !== nothing
        return string_tree(
            tree,
            options.operators;
            f_variable=(feature, vname) -> string_variable(feature, vname, X_sym_units),
            f_constant=let
                unit_placeholder =
                    options.dimensionless_constants_only ? "" : WILDCARD_UNIT_STRING
                (val,) -> string_constant(val, vprecision, unit_placeholder)
            end,
            variable_names=display_variable_names,
            kws...,
        )
    else
        return string_tree(
            tree,
            options.operators;
            f_variable=string_variable,
            f_constant=(val,) -> string_constant(val, vprecision, ""),
            variable_names=display_variable_names,
            kws...,
        )
    end
end
# Lookup table with `vals[p] === Val(p)` for `p` in `1:8192`. `string_tree`
# indexes it with `options.print_precision`, so a precision above 8192 throws a
# `BoundsError`.
const vals = ntuple(p -> Val(p), 8192)
# Plain variable name for `raw` printing: the user-supplied name when one exists
# for `feature`, otherwise `"x"` followed by the feature index (e.g. `x3`).
function string_variable_raw(feature, variable_names)
    has_name = variable_names !== nothing && feature <= length(variable_names)
    return has_name ? variable_names[feature] : "x" * string(feature)
end
# Display name for variable `feature`: the user-supplied name when one exists,
# otherwise `"x"` followed by `subscriptify(feature)`. When `variable_units` is
# given, the unit suffix from `format_dimensions` is appended.
function string_variable(feature, variable_names, variable_units=nothing)
    has_name = variable_names !== nothing && feature <= length(variable_names)
    name = has_name ? variable_names[feature] : "x" * subscriptify(feature)
    variable_units === nothing && return name
    return name * format_dimensions(variable_units[feature])
end
# Format a constant for display, followed by `unit_placeholder` (`""` or a unit
# marker). Dispatch on `val::Real` replaces the previous `typeof(val) <: Real`
# runtime branch; the set of accepted arguments is unchanged.
#
# Real constants use `sprint_precision`, i.e. `precision` significant digits (`%g`).
function string_constant(val::Real, ::Val{precision}, unit_placeholder) where {precision}
    return sprint_precision(val, Val(precision)) * unit_placeholder
end
# Non-real constants (e.g. complex numbers) are printed with `string` and wrapped
# in parentheses, presumably so they read as a single term inside an expression.
function string_constant(val, ::Val, unit_placeholder)
    return "(" * string(val) * ")" * unit_placeholder
end
# No units supplied: nothing to append.
format_dimensions(::Nothing) = ""

# Unit suffix for a unit-carrying value `u`. When the numeric part of `u` (per
# `ustrip`) is not one, the whole of `u` is printed. Otherwise only its dimension
# is printed, or nothing at all when that dimension `iszero` (presumably dimensionless).
function format_dimensions(u)
    isone(ustrip(u)) || return "[" * string(u) * "]"
    dim = dimension(u)
    return iszero(dim) ? "" : "[" * string(dim) * "]"
end
# Format `x` with `precision` significant digits (`%.<precision>g`). `@sprintf`
# takes its format as a literal, so the generated body splices in a format string
# built from the type parameter `precision`.
@generated function sprint_precision(x, ::Val{precision}) where {precision}
    format = string("%.", precision, "g")
    return :(@sprintf($format, x))
end
"""
    print_tree(tree::AbstractExpressionNode, options::Options; kws...)
    print_tree(io::IO, tree::AbstractExpressionNode, options::Options; kws...)

Print an equation (to `io` when given), using the operators defined in `options`.

# Arguments
- `tree::AbstractExpressionNode`: The equation to print.
- `options::Options`: The options holding the definition of operators.
- `variable_names::Union{Array{String, 1}, Nothing}=nothing`: what variables
  to print for each feature. This and any other keywords are forwarded to the
  operator-based `print_tree`.
"""
print_tree(tree::AbstractExpressionNode, options::Options; kws...) =
    print_tree(tree, options.operators; kws...)
print_tree(io::IO, tree::AbstractExpressionNode, options::Options; kws...) =
    print_tree(io, tree, options.operators; kws...)
"""
    convert(::Type{N}, tree::AbstractExpressionNode, options::Options) where {T,N<:AbstractExpressionNode{T}}

Convert `tree` to the node type `N` (with base type `T`), delegating to the
`convert` method that takes `options.operators`.
"""
function Base.convert(
    ::Type{N}, tree::AbstractExpressionNode, options::Options
) where {T,N<:AbstractExpressionNode{T}}
    operators = options.operators
    return convert(N, tree, operators)
end
# Check that the argument of `@extend_operators` is an `Options` and return its
# operator enum. Doing both in one call means the user's expression is evaluated
# exactly once by the macro expansion.
function _operators_from_options(options)
    if !isa(options, Options)
        error("You must pass an options type to `@extend_operators`.")
    end
    return options.operators
end

"""
    @extend_operators options

Extends all operators defined in this options object to work on the
`AbstractExpressionNode` type. While by default this is already done for operators defined
in `Base` when you create an options and pass `define_helper_functions=true`,
this does not apply to the user-defined operators. Thus, to do so, you must
apply this macro to the operator enum in the same module you have the operators
defined.

The `options` expression is evaluated once; an error is raised if it is not an
`Options`.
"""
macro extend_operators(options)
    @gensym alias_operators
    # Previously `options` was spliced twice (type check and `.operators`), so an
    # expression argument was evaluated twice. Both now happen in one helper call.
    # The whole block is escaped so it runs in the caller's module, which is where
    # the operators are defined.
    return quote
        $alias_operators = $define_alias_operators($_operators_from_options($options))
        $(DynamicExpressions).@extend_operators $alias_operators
    end |> esc
end
# Rebuild the operator enum with the "safe" aliases mapped back to their plain
# forms (via `inverse_binopmap` / `inverse_unaopmap`), so the user can write
# `x1 ^ 1.5` instead of, e.g., `safe_pow(x1, 1.5)`. The result keeps the enum
# kind (`OperatorEnum` vs `GenericOperatorEnum`) and is built with
# `define_helper_functions=false` and `empty_old_operators=false`.
function define_alias_operators(operators)
    enum_type = operators isa OperatorEnum ? OperatorEnum : GenericOperatorEnum
    binary_operators = inverse_binopmap.(operators.binops)
    unary_operators = inverse_unaopmap.(operators.unaops)
    return enum_type(;
        binary_operators,
        unary_operators,
        define_helper_functions=false,
        empty_old_operators=false,
    )
end
# Call a tree on data with `Options`: forwards to the operator-enum call, passing
# the `turbo` and `bumper` settings from `options` plus any extra keywords.
function (tree::AbstractExpressionNode)(X, options::Options; kws...)
    operators = options.operators
    return tree(X, operators; turbo=options.turbo, bumper=options.bumper, kws...)
end
# `Options`-based overload of DynamicExpressions' gradient evaluator: forwards to
# the operator-enum version with `options.turbo` (`bumper` is not passed here).
function DynamicExpressions.EvaluationHelpersModule._grad_evaluator(
    tree::AbstractExpressionNode, X, options::Options; kws...
)
    grad_evaluator = DynamicExpressions.EvaluationHelpersModule._grad_evaluator
    return grad_evaluator(tree, X, options.operators; turbo=options.turbo, kws...)
end
end