Skip to content

Commit 5b89b56

Browse files
Fix ReactantPythonCallExt.jl (#419)
* Fix ReactantPythonCallExt.jl * Apply suggestions from code review Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --------- Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
1 parent 4cc000c commit 5b89b56

File tree

1 file changed

+19
-1
lines changed

1 file changed

+19
-1
lines changed

ext/ReactantPythonCallExt.jl

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,24 @@ using PythonCall
88

99
const jaxptr = Ref{Py}()
1010

11+
const NUMPY_SIMPLE_TYPES = (
12+
("bool_", Bool),
13+
("int8", Int8),
14+
("int16", Int16),
15+
("int32", Int32),
16+
("int64", Int64),
17+
("uint8", UInt8),
18+
("uint16", UInt16),
19+
("uint32", UInt32),
20+
("uint64", UInt64),
21+
("float16", Float16),
22+
("float32", Float32),
23+
("float64", Float64),
24+
("complex32", ComplexF16),
25+
("complex64", ComplexF32),
26+
("complex128", ComplexF64),
27+
)
28+
1129
function PythonCall.pycall(
1230
f::Py, arg0::Reactant.TracedRArray, argNs::Reactant.TracedRArray...; kwargs...
1331
)
@@ -16,7 +34,7 @@ function PythonCall.pycall(
1634
inputs = map((arg0, argNs...)) do arg
1735
JT = eltype(arg)
1836
PT = nothing
19-
for (CPT, CJT) in PythonCall.Convert.NUMPY_SIMPLE_TYPES
37+
for (CPT, CJT) in NUMPY_SIMPLE_TYPES
2038
if JT == CJT
2139
PT = CPT
2240
break

0 commit comments

Comments
 (0)