Skip to content

Commit 556d014

Browse files
Pythoncall (#407)
* WIP: pythoncall * fix * fix * Update Project.toml * Apply suggestions from code review Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * fix --------- Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
1 parent a9f8e15 commit 556d014

File tree

6 files changed

+64
-3
lines changed

6 files changed

+64
-3
lines changed

CondaPkg.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
[deps]
2+
jax = ""

Project.toml

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,18 +24,17 @@ AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
2424
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
2525
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
2626
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
27+
PythonCall = "6099a3de-0909-46bc-b1f4-468b9a2dfc0d"
2728
Random123 = "74087812-796a-5b5d-8853-05524746bad3"
2829
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
2930
YaoBlocks = "418bc28f-b43b-5e0b-a6e7-61bbc1a2c1df"
3031

31-
[sources]
32-
ReactantCore = {path = "lib/ReactantCore"}
33-
3432
[extensions]
3533
ReactantAbstractFFTsExt = "AbstractFFTs"
3634
ReactantArrayInterfaceExt = "ArrayInterface"
3735
ReactantCUDAExt = "CUDA"
3836
ReactantNNlibExt = "NNlib"
37+
ReactantPythonCallExt = "PythonCall"
3938
ReactantRandom123Ext = "Random123"
4039
ReactantStatisticsExt = "Statistics"
4140
ReactantYaoBlocksExt = "YaoBlocks"
@@ -68,3 +67,7 @@ Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
6867
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
6968
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
7069
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
70+
PythonCall = "6099a3de-0909-46bc-b1f4-468b9a2dfc0d"
71+
72+
[sources.ReactantCore]
73+
path = "lib/ReactantCore"

ext/ReactantPythonCallExt.jl

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
module ReactantPythonCallExt
2+
3+
using PythonCall
4+
using Reactant: Reactant, TracedRArray, AnyTracedRArray, MLIR, TracedRNumber
5+
using ReactantCore: @trace
6+
7+
using PythonCall
8+
9+
const jaxptr = Ref{Py}()
10+
11+
function PythonCall.pycall(
12+
f::Py, arg0::Reactant.TracedRArray, argNs::Reactant.TracedRArray...; kwargs...
13+
)
14+
jax = jaxptr[]
15+
numpy = jax.numpy
16+
inputs = map((arg0, argNs...)) do arg
17+
JT = eltype(arg)
18+
PT = nothing
19+
for (CPT, CJT) in PythonCall.Convert.NUMPY_SIMPLE_TYPES
20+
if JT == CJT
21+
PT = CPT
22+
break
23+
end
24+
end
25+
numpy.zeros(size(arg); dtype=getproperty(numpy, Symbol(PT)))
26+
end
27+
lowered = jax.jit(f).lower(inputs...)
28+
txt = pyconvert(String, lowered.as_text())
29+
res = Reactant.Ops.hlo_call(txt, arg0, argNs...)
30+
if length(res) == 0
31+
return nothing
32+
else
33+
return res[1]
34+
end
35+
end
36+
37+
function __init__()
38+
return jaxptr[] = pyimport("jax")
39+
end
40+
41+
end # module ReactantPythonCallExt

test/Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
1616
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
1717
OneHotArrays = "0b1bfda6-eb8a-41d2-88d8-f5af5cad476f"
1818
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
19+
PythonCall = "6099a3de-0909-46bc-b1f4-468b9a2dfc0d"
1920
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
2021
Random123 = "74087812-796a-5b5d-8853-05524746bad3"
2122
Reactant_jll = "0192cb87-2b54-54ad-80e0-3be72ad8a3c0"

test/python.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
using Reactant
2+
using Reactant: Ops
3+
4+
using Test
5+
using PythonCall
6+
7+
@testset "PythonCall" begin
8+
jax = pyimport("jax")
9+
10+
result = @jit jax.numpy.sum(Reactant.to_rarray(Float32[1, 2, 3]))
11+
@test typeof(result) == ConcreteRNumber{Float32}
12+
@test result 6
13+
end

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ const REACTANT_TEST_GROUP = lowercase(get(ENV, "REACTANT_TEST_GROUP", "all"))
5656
@safetestset "Shortcuts to MLIR ops" include("ops.jl")
5757
@safetestset "Wrapped Arrays" include("wrapped_arrays.jl")
5858
@safetestset "Control Flow" include("control_flow.jl")
59+
@safetestset "Python" include("python.jl")
5960
end
6061

6162
if REACTANT_TEST_GROUP == "all" || REACTANT_TEST_GROUP == "integration"

0 commit comments

Comments
 (0)