Skip to content

Commit

Permalink
Fix ReactantPythonCallExt.jl (EnzymeAD#419)
Browse files Browse the repository at this point in the history
* Fix ReactantPythonCallExt.jl

* Apply suggestions from code review

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

---------

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
  • Loading branch information
wsmoses and github-actions[bot] authored Dec 23, 2024
1 parent 4cc000c commit 5b89b56
Showing 1 changed file with 19 additions and 1 deletion.
20 changes: 19 additions & 1 deletion ext/ReactantPythonCallExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,24 @@ using PythonCall

const jaxptr = Ref{Py}()

const NUMPY_SIMPLE_TYPES = (
("bool_", Bool),
("int8", Int8),
("int16", Int16),
("int32", Int32),
("int64", Int64),
("uint8", UInt8),
("uint16", UInt16),
("uint32", UInt32),
("uint64", UInt64),
("float16", Float16),
("float32", Float32),
("float64", Float64),
("complex32", ComplexF16),
("complex64", ComplexF32),
("complex128", ComplexF64),
)

function PythonCall.pycall(
f::Py, arg0::Reactant.TracedRArray, argNs::Reactant.TracedRArray...; kwargs...
)
Expand All @@ -16,7 +34,7 @@ function PythonCall.pycall(
inputs = map((arg0, argNs...)) do arg
JT = eltype(arg)
PT = nothing
for (CPT, CJT) in PythonCall.Convert.NUMPY_SIMPLE_TYPES
for (CPT, CJT) in NUMPY_SIMPLE_TYPES
if JT == CJT
PT = CPT
break
Expand Down

0 comments on commit 5b89b56

Please sign in to comment.