Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Julia: add binding for runtime feature detection (#13992)
Browse files Browse the repository at this point in the history
  • Loading branch information
iblislin authored Mar 12, 2019
1 parent 89bebd1 commit 73b29fa
Show file tree
Hide file tree
Showing 4 changed files with 84 additions and 7 deletions.
7 changes: 4 additions & 3 deletions julia/deps/build.jl
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,10 @@ elseif Sys.islinux()
end

if Sys.isunix()
try
push!(CUDAPATHS, replace(strip(read(`which nvcc`, String)), "bin/nvcc", "lib64"))
catch
nvcc_path = Sys.which("nvcc")
if nvcc_path nothing
@info "Found nvcc: $nvcc_path"
push!(CUDAPATHS, replace(nvcc_path, "bin/nvcc", "lib64"))
end
end

Expand Down
1 change: 1 addition & 0 deletions julia/src/MXNet.jl
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,7 @@ export to_graphviz

include("base.jl")

include("runtime.jl")
include("context.jl")
include("util.jl")

Expand Down
7 changes: 3 additions & 4 deletions julia/src/base.jl
Original file line number Diff line number Diff line change
Expand Up @@ -85,12 +85,11 @@ function mx_get_last_error()
end

"Utility macro to call MXNet API functions"
macro mxcall(fv, argtypes, args...)
f = eval(fv)
macro mxcall(f, argtypes, args...)
args = map(esc, args)
quote
_mxret = ccall(($(QuoteNode(f)), $MXNET_LIB),
Cint, $argtypes, $(args...))
_mxret = ccall(($f, $MXNET_LIB),
Cint, $(esc(argtypes)), $(args...))
if _mxret != 0
err_msg = mx_get_last_error()
throw(MXError(err_msg))
Expand Down
76 changes: 76 additions & 0 deletions julia/src/runtime.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# License); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# AS IS BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

# runtime detection of compile time features in the native library

module MXRuntime

using ..mx

export LibFeature
export feature_list, isenabled

# defined in include/mxnet/c_api.h
struct LibFeature
_name::Ptr{Cchar}
enabled::Bool
end

function Base.getproperty(x::LibFeature, p::Symbol)
(p == :name) && return unsafe_string(getfield(x, :_name))
getfield(x, p)
end

Base.show(io::IO, x::LibFeature) =
print(io, ifelse(x.enabled, "", ""), " ", x.name)

"""
feature_list()
Check the library for compile-time features.
The list of features are maintained in libinfo.h and libinfo.cc
"""
function feature_list()
ref = Ref{Ptr{LibFeature}}(C_NULL)
s = Ref{Csize_t}(C_NULL)
@mx.mxcall(:MXLibInfoFeatures, (Ref{Ptr{LibFeature}}, Ref{Csize_t}), ref, s)
unsafe_wrap(Array, ref[], s[])
end

"""
isenabled(x::Symbol)::Bool
Returns the given runtime feature is enabled or not.
```julia-repl
julia> mx.isenabled(:CUDA)
false
julia> mx.isenabled(:CPU_SSE)
true
```
See also `mx.feature_list()`.
"""
isenabled(x::Symbol) =
any(feature_list()) do i
Symbol(i.name) == x && i.enabled
end

end # module MXRuntime

using .MXRuntime

0 comments on commit 73b29fa

Please sign in to comment.