Skip to content

Commit

Permalink
Add MPS enum ad struct wrapping
Browse files Browse the repository at this point in the history
  • Loading branch information
christiangnrd committed Dec 10, 2024
1 parent bea48d8 commit 5ebc3be
Show file tree
Hide file tree
Showing 2 changed files with 76 additions and 37 deletions.
37 changes: 37 additions & 0 deletions res/wrap/libmps.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
[general]
library_name = "libmpss"
output_file_path = "lib/mps/libmps.jl"
prologue_file_path = "res/wrap/libmps_prologue.jl"
#no_audit = true

printer_blacklist = [
"mt_macCatalyst",
"mt_ios",
"mt_macos",
"CF.*",
"MTL.*",
"NS.*",
"BOOL"
]

[codegen]
use_ccall_macro = true
always_NUL_terminated_string = true

[codegen.macro]
# it's highly recommended to set this entry to "basic".
# if you'd like to skip all of the macros, please set this entry to "disable".
# if you'd like to translate function-like macros to Julia, please set this entry to "aggressive".
macro_mode = "disable"

[api.MPSSize]
constructor = "MPSSize(w=1.0, h=1.0, d=1.0) = new(w, h, d)"

[api.MPSRegion]
constructor = "MPSRegion(origin=MPSOrigin(), size=MPSSize()) = new(origin, size)"

[api.MPSOrigin]
constructor = "MPSOrigin(x=0.0, y=0.0, z=0.0) = new(x, y, z)"

[api.MPSOffset]
constructor = "MPSOffset(x=0, y=0, z=0) = new(x, y, z)"
76 changes: 39 additions & 37 deletions res/wrap/wrap.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,40 +15,53 @@ SDK_PATH = `xcrun --show-sdk-path` |> open |> readchomp |> String
# Hack to prevent printing of functions for now
Generators.skip_check(dag::Generators.ExprDAG, node::Generators.ExprNode{Generators.FunctionProto}) = true

function main(name="all"; sdk_path=SDK_PATH)
main(name::AbstractString; kwargs...) = main([name]; kwargs...)
function main(names=["all"]; sdk_path=SDK_PATH)
path_to_framework(framework) = joinpath(sdk_path, "System/Library/Frameworks/",framework*".framework","Headers")

mtl = path_to_framework("Metal")
# mps = path_to_framework("MetalPerformanceShaders")
# foundation = path_to_framework("Foundation")
# cf = path_to_framework("CoreFoundation")
path_to_mps_framework(framework) = joinpath(sdk_path, "System/Library/Frameworks/","MetalPerformanceShaders.framework","Frameworks",framework*".framework","Headers")

defines = []

ctxs = []

if name == "all" || name == "libmtl" || name == "mtl"
tctx = wrap("libmtl", joinpath(mtl, "Metal.h"); targets=glob("*.h", mtl), defines,
include_dirs=[mtl])
push!(ctxs, tctx)
if "all" in names
names = ["all"]
end

# if name == "all" || name == "libmps" || name == "mps"
# tctx = wrap("libmps", joinpath(mps, "MetalPerformanceShaders.h"); targets=[glob("../Frameworks/*/Headers/*.h", mps)...,glob("*.h", mps)...], defines,
# include_dirs=[mps])
# push!(ctxs, tctx)
# end
for name in names
if name == "all" || name == "libmtl" || name == "mtl"
fwpath = path_to_framework("Metal")
tctx = wrap("libmtl", joinpath(fwpath, "Metal.h"); targets=glob("*.h", fwpath), defines,
include_dirs=[fwpath])
push!(ctxs, tctx)
end

# if name == "libfoundation" || name == "foundation"
# tctx = wrap("libfoundation", joinpath(foundation, "Foundation.h"); targets=glob("*.h", foundation), defines=["__builtin_va_list"],
# include_dirs=[foundation])
# push!(ctxs, tctx)
# end
# if name == "libcf" || name == "cf"
# tctx = wrap("libfoundation", joinpath(cf, "CoreFoundation.h"); targets=glob("*.h", cf), defines=["__builtin_va_list"],
# include_dirs=[cf])
# push!(ctxs, tctx)
# end
if name == "all" || name == "libmps" || name == "mps"
mpsframeworks = ["MPSCore", "MPSImage", "MPSMatrix", "MPSNDArray", "MPSNeuralNetwork", "MPSRayIntersector"]
fwpaths = [path_to_framework("MetalPerformanceShaders")]
fwpaths = append!(fwpaths, path_to_mps_framework.(mpsframeworks))

getheaderfname(path) = Sys.splitext(Sys.splitpath(path)[end-1])[1] * ".h"
headers = joinpath.(fwpaths, getheaderfname.(fwpaths))

tctx = wrap("libmps", headers; defines,
include_dirs=fwpaths)
push!(ctxs, tctx)
end

if name == "all" || name == "libfoundation" || name == "foundation"
fwpath = path_to_framework("Foundation")
tctx = wrap("libfoundation", joinpath(foundation, "Foundation.h"); targets=glob("*.h", fwpath), defines=["__builtin_va_list"],
include_dirs=[fwpath])
push!(ctxs, tctx)
end
# if name == "all" || name == "libcf" || name == "cf"
# fwpath = path_to_framework("CoreFoundation")
# tctx = wrap("libfoundation", joinpath(fwpath, "CoreFoundation.h"); targets=glob("*.h", fwpath), defines=["__builtin_va_list"],
# include_dirs=[fwpath])
# push!(ctxs, tctx)
# end
end
return ctxs
end

Expand Down Expand Up @@ -86,17 +99,6 @@ function wrap(name, headers; targets=headers, defines=[], include_dirs=[], preco
@info "Building no printing"
build!(ctx, BUILDSTAGE_NO_PRINTING)

replace!(get_nodes(ctx.dag)) do node
path = normpath(Clang.get_filename(node.cursor))
should_wrap = any(targets) do target
occursin(target, path)
end
if !should_wrap
return ExprNode(node.id, Generators.Skip(), node.cursor, Expr[], node.adj)
end
return node
end

rewriter!(ctx, options)

@info "Building only printing"
Expand All @@ -113,7 +115,7 @@ function create_objc_context(headers::Vector, args::Vector=String[], options::Di
"/Applications/Xcode.app/Contents/Developer/Toolchains/XcodeDefault.xctoolchain"
]

# Since The framework we're wrapping is a system header,
# Since the framework we're wrapping is a system header,
# find all dependent headers, then remove all but the relevant ones
# also temporarily disable logging
Base.CoreLogging._min_enabled_level[] = Logging.Info+1
Expand Down

0 comments on commit 5ebc3be

Please sign in to comment.