Skip to content

Commit

Permalink
Add an easy abstract type interface
Browse files Browse the repository at this point in the history
Can be used in place of the macro to give a pass CassetteoOverlay
behavior. The abstract type works by reading the appropriate binding
from the type parameter.

Requires JuliaLang/julia#47749
  • Loading branch information
Keno committed Dec 2, 2022
1 parent 4129df6 commit e5df711
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 2 deletions.
27 changes: 25 additions & 2 deletions src/CassetteOverlay.jl
Original file line number Diff line number Diff line change
Expand Up @@ -174,10 +174,15 @@ macro overlaypass(args...)

nonoverlaytype = typeof(CassetteOverlay.nonoverlay)

if method_table !== :nothing
mthd_tbl = :($CassetteOverlay.method_table(::Type{$PassName}) = $(esc(method_table)))
else
mthd_tbl = nothing
end

blk = quote
$decl_pass

$CassetteOverlay.method_table(::Type{$PassName}) = $(esc(method_table))
$mthd_tbl

@inline function (::$PassName)(f::Union{Core.Builtin,Core.IntrinsicFunction}, args...)
@nospecialize f args
Expand All @@ -187,6 +192,10 @@ macro overlaypass(args...)
@nospecialize args
return f(args...)
end
@inline function (self::$PassName)(::typeof(Core._apply_iterate), iterate, f, args...)
@nospecialize args
return Core.Compiler._apply_iterate(iterate, self, (f,), args...)
end

@generated function (pass::$PassName)($(esc(:fargs))...)
src = $overlay_generator(pass, fargs)
Expand Down Expand Up @@ -222,4 +231,18 @@ macro overlaypass(args...)
return Expr(:toplevel, blk.args...)
end

abstract type AbstractBindingOverlay{M, S} <: OverlayPass; end
function method_table(::Type{<:AbstractBindingOverlay{M, S}}) where {M, S}
@assert isconst(M, S)
return getglobal(M, S)::Core.MethodTable
end
@overlaypass AbstractBindingOverlay nothing

struct Overlay{M, S} <: AbstractBindingOverlay{M, S}; end
function Overlay(mt::Core.MethodTable)
@assert isconst(mt.module, mt.name)
@assert getglobal(mt.module, mt.name) === mt
return Overlay{mt.module, mt.name}()
end

end # module CassetteOverlay
23 changes: 23 additions & 0 deletions test/abstract.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
module abstract

using CassetteOverlay, Test
@MethodTable SinTable
mutable struct CosCounter <: CassetteOverlay.AbstractBindingOverlay{@__MODULE__, :SinTable}
ncos::Int
end

function (c::CosCounter)(::typeof(cos), args...)
c.ncos += 1
return cos(args...)
end

@overlay SinTable sin(x::Union{Float32,Float64}) = cos(x);

let pass! = CosCounter(0)
pass!(42) do a
sin(a) * cos(a)
end
@test pass!.ncos == 2
end

end
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,5 @@ using Test
@testset "simple" include("simple.jl")
@testset "math" include("math.jl")
@testset "misc" include("misc.jl")
@testset "abstract" include("abstract.jl")
end

0 comments on commit e5df711

Please sign in to comment.