Skip to content

Commit

Permalink
disable apply iterate handling
Browse files Browse the repository at this point in the history
  • Loading branch information
vchuravy committed Apr 17, 2020
1 parent f50866a commit 4e1e635
Show file tree
Hide file tree
Showing 5 changed files with 30 additions and 22 deletions.
1 change: 1 addition & 0 deletions .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ julia:
- 1.1
- 1.2
- 1.3
- 1.4
- nightly
matrix:
allow_failures:
Expand Down
6 changes: 0 additions & 6 deletions src/context.jl
Original file line number Diff line number Diff line change
Expand Up @@ -258,12 +258,6 @@ macro context(_Ctx)
@inline Cassette.overdub(::C, ::Typeof(Tag), ::Type{N}, ::Type{X}) where {C<:$Ctx,N,X} = Tag(N, X, tagtype(C))

@inline Cassette.overdub(ctx::$Ctx, ::typeof(Core._apply), f, args...) = Core._apply(overdub, (ctx, f), args...)
if VERSION >= v"1.4.0-DEV.304"
@inline function Cassette.overdub(ctx::$Ctx, ::typeof(Core._apply_iterate), f, args...)
new_args = ((_args...) -> overdub(ctx, args[1], _args...), Base.tail(args)...)
Core._apply_iterate((args...)->overdub(ctx, f, args...), new_args...)
end
end

# TODO: There are certain non-`Core.Builtin` functions which the compiler often
# relies upon constant propagation/tfuncs to infer, instead of specializing on
Expand Down
6 changes: 0 additions & 6 deletions src/overdub.jl
Original file line number Diff line number Diff line change
Expand Up @@ -522,12 +522,6 @@ function overdub end
function recurse end

recurse(ctx::Context, ::typeof(Core._apply), f, args...) = Core._apply(recurse, (ctx, f), args...)
if VERSION >= v"1.4.0-DEV.304"
function recurse(ctx::Context, ::typeof(Core._apply_iterate), f, args...)
new_args = ((_args...) -> overdub(ctx, args[1], _args...), Base.tail(args)...)
Core._apply_iterate((args...)->recurse(ctx, f, args...), new_args...)
end
end

function overdub_definition(line, file)
return quote
Expand Down
18 changes: 14 additions & 4 deletions test/misctaggingtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -465,9 +465,15 @@ Cassette.overdub(ctx::DiffCtx, ::typeof(+), x, y, z) = Cassette.overdub(ctx, +,
@test D(sin, 1) === cos(1)
@test D(x -> D(sin, x), 1) === -sin(1)
@test D(x -> sin(x) * cos(x), 1) === cos(1)^2 - sin(1)^2
@test D(x -> x * D(y -> x * y, 1), 2) === 4
@test D(x -> x * D(y -> x * y, 2), 1) === 2
@test D(x -> x * D(y -> 5*x*y, 3), 2) === 20
if VERSION < v"1.4"
@test D(x -> x * D(y -> x * y, 1), 2) === 4
@test D(x -> x * D(y -> x * y, 2), 1) === 2
@test D(x -> x * D(y -> 5*x*y, 3), 2) === 20
else
@test_broken D(x -> x * D(y -> x * y, 1), 2) === 4
@test_broken D(x -> x * D(y -> x * y, 2), 1) === 2
@test_broken D(x -> x * D(y -> 5*x*y, 3), 2) === 20
end
@test D(x -> x * foo_bar_identity(x), 1) === 2.0

x = rand()
Expand All @@ -493,6 +499,10 @@ ctx = enabletagging(ArrayIndexCtx(), matrixliteral)
result = overdub(ctx, matrixliteral, tag(1, ctx, "hi"))

@test untag(result, ctx) == matrixliteral(1)
@test metameta(result, ctx) == fill(Cassette.Meta("hi", Cassette.NoMetaMeta()), 2, 2)
if VERSION < v"1.4"
@test metameta(result, ctx) == fill(Cassette.Meta("hi", Cassette.NoMetaMeta()), 2, 2)
else
@test_broken metameta(result, ctx) == fill(Cassette.Meta("hi", Cassette.NoMetaMeta()), 2, 2)
end

println("done (took ", time() - before_time, " seconds)")
21 changes: 15 additions & 6 deletions test/misctests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -294,7 +294,11 @@ end
tracekw = Any[]
@overdub(TraceCtx(metadata = tracekw), trkwtest(x, _y = y, _z = z)) == trtest(x, y, z)
subtracekw = first(Iterators.filter(t -> t[1] === (Core.kwfunc(trkwtest), (_y = y, _z = z), trkwtest, x), tracekw))[2]
@test subtracekw == trace
if VERSION < v"1.4"
@test subtracekw == trace
else
@test_broken subtracekw == trace
end

function enter!(t::HookTrace, args...)
pair = args => Any[]
Expand Down Expand Up @@ -509,8 +513,13 @@ if VERSION >= v"1.1-"
ff73(2, 50) # warm up
fff73(2, 50) # warm up

if VERSION < v"1.4"
@test @allocated(f73(2, 50)) == 0
@test @allocated(ff73(2, 50)) == 0
else
@test_broken @allocated(f73(2, 50)) == 0
@test_broken @allocated(ff73(2, 50)) == 0
end
@test_broken @allocated(fff73(2, 50)) == 0

println("done (took ", time() - before_time, " seconds)")
Expand Down Expand Up @@ -699,11 +708,11 @@ end

launch(s::Silo) = (s...,)

if VERSION >= v"1.4.0-DEV.304"
@test Cassette.overdub(NukeContext(), launch, Silo()) === ()
else
# if VERSION >= v"1.4.0-DEV.304"
# @test Cassette.overdub(NukeContext(), launch, Silo()) === ()
# else
@test_broken Cassette.overdub(NukeContext(), launch, Silo()) === ()
end
# end

if VERSION >= v"1.4.0-DEV.304"
Cassette.@context ApplyIterateCtx;
Expand All @@ -716,5 +725,5 @@ if VERSION >= v"1.4.0-DEV.304"
end

Cassette.overdub(ApplyIterateCtx(), ()->pi*2.0)
@test instructions[end] === (Core.Intrinsics.mul_float, Float64, Float64)
@test_broken instructions[end] === (Core.Intrinsics.mul_float, Float64, Float64)
end

0 comments on commit 4e1e635

Please sign in to comment.