From 5b89b56d39fe395977369cb4bf7cc61abaf21132 Mon Sep 17 00:00:00 2001 From: William Moses Date: Sun, 22 Dec 2024 21:23:33 -0500 Subject: [PATCH] Fix ReactantPythonCallExt.jl (#419) * Fix ReactantPythonCallExt.jl * Apply suggestions from code review Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --------- Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- ext/ReactantPythonCallExt.jl | 20 +++++++++++++++++++- 1 file changed, 19 insertions(+), 1 deletion(-) diff --git a/ext/ReactantPythonCallExt.jl b/ext/ReactantPythonCallExt.jl index d42945018..be5b61fdd 100644 --- a/ext/ReactantPythonCallExt.jl +++ b/ext/ReactantPythonCallExt.jl @@ -8,6 +8,24 @@ using PythonCall const jaxptr = Ref{Py}() +const NUMPY_SIMPLE_TYPES = ( + ("bool_", Bool), + ("int8", Int8), + ("int16", Int16), + ("int32", Int32), + ("int64", Int64), + ("uint8", UInt8), + ("uint16", UInt16), + ("uint32", UInt32), + ("uint64", UInt64), + ("float16", Float16), + ("float32", Float32), + ("float64", Float64), + ("complex32", ComplexF16), + ("complex64", ComplexF32), + ("complex128", ComplexF64), +) + function PythonCall.pycall( f::Py, arg0::Reactant.TracedRArray, argNs::Reactant.TracedRArray...; kwargs... ) @@ -16,7 +34,7 @@ function PythonCall.pycall( inputs = map((arg0, argNs...)) do arg JT = eltype(arg) PT = nothing - for (CPT, CJT) in PythonCall.Convert.NUMPY_SIMPLE_TYPES + for (CPT, CJT) in NUMPY_SIMPLE_TYPES if JT == CJT PT = CPT break