-
Notifications
You must be signed in to change notification settings - Fork 152
Closed
Labels
Description
using StaticArrays: @SArray
@SArray Float32[1, 2]gives
ERROR: UndefVarError: SArray not defined
indeed:
julia> macroexpand(Main, :(@SArray [1, 2]))
:((StaticArrays.SArray{Tuple{2},T,N,L} where L where N where T)((1, 2)))
julia> macroexpand(Main, :(@SArray Float32[1, 2]))
:(SArray{Tuple{2}, Float32}((1, 2)))The definition
Lines 82 to 235 in a3bca35
| macro SArray(ex) | |
| if !isa(ex, Expr) | |
| error("Bad input for @SArray") | |
| end | |
| if ex.head == :vect # vector | |
| return esc(Expr(:call, SArray{Tuple{length(ex.args)}}, Expr(:tuple, ex.args...))) | |
| elseif ex.head == :ref # typed, vector | |
| return esc(Expr(:call, Expr(:curly, :SArray, Tuple{length(ex.args)-1}, ex.args[1]), Expr(:tuple, ex.args[2:end]...))) | |
| elseif ex.head == :hcat # 1 x n | |
| s1 = 1 | |
| s2 = length(ex.args) | |
| return esc(Expr(:call, SArray{Tuple{s1, s2}}, Expr(:tuple, ex.args...))) | |
| elseif ex.head == :typed_hcat # typed, 1 x n | |
| s1 = 1 | |
| s2 = length(ex.args) - 1 | |
| return esc(Expr(:call, Expr(:curly, :SArray, Tuple{s1, s2}, ex.args[1]), Expr(:tuple, ex.args[2:end]...))) | |
| elseif ex.head == :vcat | |
| if isa(ex.args[1], Expr) && ex.args[1].head == :row # n x m | |
| # Validate | |
| s1 = length(ex.args) | |
| s2s = map(i -> ((isa(ex.args[i], Expr) && ex.args[i].head == :row) ? length(ex.args[i].args) : 1), 1:s1) | |
| s2 = minimum(s2s) | |
| if maximum(s2s) != s2 | |
| error("Rows must be of matching lengths") | |
| end | |
| exprs = [ex.args[i].args[j] for i = 1:s1, j = 1:s2] | |
| return esc(Expr(:call, SArray{Tuple{s1, s2}}, Expr(:tuple, exprs...))) | |
| else # n x 1 | |
| return esc(Expr(:call, SArray{Tuple{length(ex.args), 1}}, Expr(:tuple, ex.args...))) | |
| end | |
| elseif ex.head == :typed_vcat | |
| if isa(ex.args[2], Expr) && ex.args[2].head == :row # typed, n x m | |
| # Validate | |
| s1 = length(ex.args) - 1 | |
| s2s = map(i -> ((isa(ex.args[i+1], Expr) && ex.args[i+1].head == :row) ? length(ex.args[i+1].args) : 1), 1:s1) | |
| s2 = minimum(s2s) | |
| if maximum(s2s) != s2 | |
| error("Rows must be of matching lengths") | |
| end | |
| exprs = [ex.args[i+1].args[j] for i = 1:s1, j = 1:s2] | |
| return esc(Expr(:call, Expr(:curly, :SArray, Tuple{s1, s2}, ex.args[1]), Expr(:tuple, exprs...))) | |
| else # typed, n x 1 | |
| return esc(Expr(:call, Expr(:curly, :SArray, Tuple{length(ex.args)-1, 1}, ex.args[1]), Expr(:tuple, ex.args[2:end]...))) | |
| end | |
| elseif isa(ex, Expr) && ex.head == :comprehension | |
| if length(ex.args) != 1 || !isa(ex.args[1], Expr) || ex.args[1].head != :generator | |
| error("Expected generator in comprehension, e.g. [f(i,j) for i = 1:3, j = 1:3]") | |
| end | |
| ex = ex.args[1] | |
| n_rng = length(ex.args) - 1 | |
| rng_args = [ex.args[i+1].args[1] for i = 1:n_rng] | |
| rngs = Any[Core.eval(__module__, ex.args[i+1].args[2]) for i = 1:n_rng] | |
| rng_lengths = map(length, rngs) | |
| f = gensym() | |
| f_expr = :($f = ($(Expr(:tuple, rng_args...)) -> $(ex.args[1]))) | |
| # TODO figure out a generic way of doing this... | |
| if n_rng == 1 | |
| exprs = [:($f($j1)) for j1 in rngs[1]] | |
| elseif n_rng == 2 | |
| exprs = [:($f($j1, $j2)) for j1 in rngs[1], j2 in rngs[2]] | |
| elseif n_rng == 3 | |
| exprs = [:($f($j1, $j2, $j3)) for j1 in rngs[1], j2 in rngs[2], j3 in rngs[3]] | |
| elseif n_rng == 4 | |
| exprs = [:($f($j1, $j2, $j3, $j4)) for j1 in rngs[1], j2 in rngs[2], j3 in rngs[3], j4 in rngs[4]] | |
| elseif n_rng == 5 | |
| exprs = [:($f($j1, $j2, $j3, $j4, $j5)) for j1 in rngs[1], j2 in rngs[2], j3 in rngs[3], j4 in rngs[4], j5 in rngs[5]] | |
| elseif n_rng == 6 | |
| exprs = [:($f($j1, $j2, $j3, $j4, $j5, $j6)) for j1 in rngs[1], j2 in rngs[2], j3 in rngs[3], j4 in rngs[4], j5 in rngs[5], j6 in rngs[6]] | |
| elseif n_rng == 7 | |
| exprs = [:($f($j1, $j2, $j3, $j4, $j5, $j6, $j7)) for j1 in rngs[1], j2 in rngs[2], j3 in rngs[3], j4 in rngs[4], j5 in rngs[5], j6 in rngs[6], j7 in rngs[7]] | |
| elseif n_rng == 8 | |
| exprs = [:($f($j1, $j2, $j3, $j4, $j5, $j6, $j7, $j8)) for j1 in rngs[1], j2 in rngs[2], j3 in rngs[3], j4 in rngs[4], j5 in rngs[5], j6 in rngs[6], j7 in rngs[7], j8 in rngs[8]] | |
| else | |
| error("@SArray only supports up to 8-dimensional comprehensions") | |
| end | |
| return quote | |
| $(esc(f_expr)) | |
| $(esc(Expr(:call, Expr(:curly, :SArray, Tuple{rng_lengths...}), Expr(:tuple, exprs...)))) | |
| end | |
| elseif isa(ex, Expr) && ex.head == :typed_comprehension | |
| if length(ex.args) != 2 || !isa(ex.args[2], Expr) || ex.args[2].head != :generator | |
| error("Expected generator in typed comprehension, e.g. Float64[f(i,j) for i = 1:3, j = 1:3]") | |
| end | |
| T = ex.args[1] | |
| ex = ex.args[2] | |
| n_rng = length(ex.args) - 1 | |
| rng_args = [ex.args[i+1].args[1] for i = 1:n_rng] | |
| rngs = [Core.eval(__module__, ex.args[i+1].args[2]) for i = 1:n_rng] | |
| rng_lengths = map(length, rngs) | |
| f = gensym() | |
| f_expr = :($f = ($(Expr(:tuple, rng_args...)) -> $(ex.args[1]))) | |
| # TODO figure out a generic way of doing this... | |
| if n_rng == 1 | |
| exprs = [:($f($j1)) for j1 in rngs[1]] | |
| elseif n_rng == 2 | |
| exprs = [:($f($j1, $j2)) for j1 in rngs[1], j2 in rngs[2]] | |
| elseif n_rng == 3 | |
| exprs = [:($f($j1, $j2, $j3)) for j1 in rngs[1], j2 in rngs[2], j3 in rngs[3]] | |
| elseif n_rng == 4 | |
| exprs = [:($f($j1, $j2, $j3, $j4)) for j1 in rngs[1], j2 in rngs[2], j3 in rngs[3], j4 in rngs[4]] | |
| elseif n_rng == 5 | |
| exprs = [:($f($j1, $j2, $j3, $j4, $j5)) for j1 in rngs[1], j2 in rngs[2], j3 in rngs[3], j4 in rngs[4], j5 in rngs[5]] | |
| elseif n_rng == 6 | |
| exprs = [:($f($j1, $j2, $j3, $j4, $j5, $j6)) for j1 in rngs[1], j2 in rngs[2], j3 in rngs[3], j4 in rngs[4], j5 in rngs[5], j6 in rngs[6]] | |
| elseif n_rng == 7 | |
| exprs = [:($f($j1, $j2, $j3, $j4, $j5, $j6, $j7)) for j1 in rngs[1], j2 in rngs[2], j3 in rngs[3], j4 in rngs[4], j5 in rngs[5], j6 in rngs[6], j7 in rngs[7]] | |
| elseif n_rng == 8 | |
| exprs = [:($f($j1, $j2, $j3, $j4, $j5, $j6, $j7, $j8)) for j1 in rngs[1], j2 in rngs[2], j3 in rngs[3], j4 in rngs[4], j5 in rngs[5], j6 in rngs[6], j7 in rngs[7], j8 in rngs[8]] | |
| else | |
| error("@SArray only supports up to 8-dimensional comprehensions") | |
| end | |
| return quote | |
| $(esc(f_expr)) | |
| $(esc(Expr(:call, Expr(:curly, :SArray, Tuple{rng_lengths...}, T), Expr(:tuple, exprs...)))) | |
| end | |
| elseif isa(ex, Expr) && ex.head == :call | |
| if ex.args[1] == :zeros || ex.args[1] == :ones || ex.args[1] == :rand || ex.args[1] == :randn || ex.args[1] == :randexp | |
| if length(ex.args) == 1 | |
| error("@SArray got bad expression: $(ex.args[1])()") | |
| else | |
| return quote | |
| if isa($(esc(ex.args[2])), DataType) | |
| $(ex.args[1])($(esc(Expr(:curly, SArray, Expr(:curly, Tuple, ex.args[3:end]...), ex.args[2])))) | |
| else | |
| $(ex.args[1])($(esc(Expr(:curly, SArray, Expr(:curly, Tuple, ex.args[2:end]...))))) | |
| end | |
| end | |
| end | |
| elseif ex.args[1] == :fill | |
| if length(ex.args) == 1 | |
| error("@SArray got bad expression: $(ex.args[1])()") | |
| elseif length(ex.args) == 2 | |
| error("@SArray got bad expression: $(ex.args[1])($(ex.args[2]))") | |
| else | |
| return quote | |
| $(esc(ex.args[1]))($(esc(ex.args[2])), SArray{$(esc(Expr(:curly, Tuple, ex.args[3:end]...)))}) | |
| end | |
| end | |
| else | |
| error("@SArray only supports the zeros(), ones(), rand(), randn(), and randexp() functions.") | |
| end | |
| else | |
| error("Bad input for @SArray") | |
| end | |
| end |