Skip to content

Commit 35e2864

Browse files
committed
fix: check for name and module in function
1 parent eb28a00 commit 35e2864

File tree

2 files changed

+29
-7
lines changed

2 files changed

+29
-7
lines changed

src/utils.jl

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -95,13 +95,16 @@ function should_rewrite_ft(@nospecialize(ft))
9595
return false
9696
end
9797
if ft <: Core.Function
98-
mod = ft.name.module
99-
# Don't rewrite primitive ops, tracing utilities, or any MLIR-based functions
100-
if has_ancestor(mod, Reactant.Ops) ||
101-
has_ancestor(mod, Reactant.TracedUtils) ||
102-
has_ancestor(mod, Reactant.MLIR) ||
103-
has_ancestor(mod, Reactant.TracedRandom)
104-
return false
98+
# We need this for closures to work
99+
if hasfield(typeof(ft), :name) && hasfield(typeof(ft.name), :module)
100+
mod = ft.name.module
101+
# Don't rewrite primitive ops, tracing utilities, or any MLIR-based functions
102+
if has_ancestor(mod, Reactant.Ops) ||
103+
has_ancestor(mod, Reactant.TracedUtils) ||
104+
has_ancestor(mod, Reactant.MLIR) ||
105+
has_ancestor(mod, Reactant.TracedRandom)
106+
return false
107+
end
105108
end
106109
end
107110
# Don't rewrite Val

test/compile.jl

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,3 +127,22 @@ end
127127
@test !occursin("subtract", repr(hlo))
128128
@test !occursin("add", repr(hlo))
129129
end
130+
131+
# While a bit specific, the following is used to check for a bug in `should_rewrite_ft`
132+
function sinusoidal_embedding(
133+
x::AbstractArray{T,4}, min_freq, max_freq, embedding_dims::Int
134+
) where {T}
135+
if size(x)[1:3] != (1, 1, 1)
136+
throw(DimensionMismatch("Input shape must be (1, 1, 1, batch)"))
137+
end
138+
139+
lower, upper = log(T(min_freq)), log(T(max_freq))
140+
n = embedding_dims ÷ 2
141+
x_ = 2 .* x .* exp.(reshape(range(lower, upper; length=n), 1, 1, n, 1))
142+
return cat(sinpi.(x_), cospi.(x_); dims=Val(3))
143+
end
144+
145+
@testset "sinusoidal_embedding" begin
146+
x_ra = Reactant.to_rarray(rand(Float32, 1, 1, 1, 4))
147+
hlo = @code_hlo sinusoidal_embedding(x_ra, 0.1, 10.0, 4)
148+
end

0 commit comments

Comments
 (0)