Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add an easy abstract type interface #22

Merged
merged 2 commits into from
Dec 4, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 25 additions & 2 deletions src/CassetteOverlay.jl
Original file line number Diff line number Diff line change
Expand Up @@ -174,10 +174,15 @@ macro overlaypass(args...)

nonoverlaytype = typeof(CassetteOverlay.nonoverlay)

if method_table !== :nothing
mthd_tbl = :($CassetteOverlay.method_table(::Type{$PassName}) = $(esc(method_table)))
else
mthd_tbl = nothing
end

blk = quote
$decl_pass

$CassetteOverlay.method_table(::Type{$PassName}) = $(esc(method_table))
$mthd_tbl

@inline function (::$PassName)(f::Union{Core.Builtin,Core.IntrinsicFunction}, args...)
@nospecialize f args
Expand All @@ -187,6 +192,10 @@ macro overlaypass(args...)
@nospecialize args
return f(args...)
end
@inline function (self::$PassName)(::typeof(Core._apply_iterate), iterate, f, args...)
@nospecialize args
return Core.Compiler._apply_iterate(iterate, self, (f,), args...)
end

@generated function (pass::$PassName)($(esc(:fargs))...)
src = $overlay_generator(pass, fargs)
Expand Down Expand Up @@ -222,4 +231,18 @@ macro overlaypass(args...)
return Expr(:toplevel, blk.args...)
end

abstract type AbstractBindingOverlay{M, S} <: OverlayPass; end
function method_table(::Type{<:AbstractBindingOverlay{M, S}}) where {M, S}
@assert isconst(M, S)
return getglobal(M, S)::Core.MethodTable
end
@overlaypass AbstractBindingOverlay nothing

struct Overlay{M, S} <: AbstractBindingOverlay{M, S}; end
function Overlay(mt::Core.MethodTable)
@assert isconst(mt.module, mt.name)
@assert getglobal(mt.module, mt.name) === mt
return Overlay{mt.module, mt.name}()
end

end # module CassetteOverlay
23 changes: 23 additions & 0 deletions test/abstract.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
module abstract

using CassetteOverlay, Test
@MethodTable SinTable
mutable struct CosCounter <: CassetteOverlay.AbstractBindingOverlay{@__MODULE__, :SinTable}
ncos::Int
end

function (c::CosCounter)(::typeof(cos), args...)
c.ncos += 1
return cos(args...)
end

@overlay SinTable sin(x::Union{Float32,Float64}) = cos(x);

let pass! = CosCounter(0)
pass!(42) do a
sin(a) * cos(a)
end
@test pass!.ncos == 2
end

end
4 changes: 4 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,8 @@ using Test
@testset "simple" include("simple.jl")
@testset "math" include("math.jl")
@testset "misc" include("misc.jl")
if VERSION >= v"1.10.0-DEV.90"
# This interface depends on julia#47749
@testset "abstract" include("abstract.jl")
end
end