Skip to content

Commit 48b4eff

Browse files
committed
Call the generated mpi recv mlir binding
1 parent f21707d commit 48b4eff

File tree

1 file changed

+61
-49
lines changed

1 file changed

+61
-49
lines changed

ext/ReactantMPIExt/Ops.jl

Lines changed: 61 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -409,63 +409,75 @@ function isend(
409409
return request # we return a TracedRNumber, converted to TracedRequest in Overrides.jl
410410
end
411411

412-
function recv!(
413-
recvbuf::TracedRArray,
414-
tag::TracedRNumber,
415-
src::TracedRNumber;
416-
location=mlir_stacktrace("mpi.recv", @__FILE__, @__LINE__),
417-
)
418-
T = Reactant.unwrapped_eltype(recvbuf)
419-
mpi_datatype = MPI.Datatype(T)
420-
mpi_datatype_name = inject_mpi_datatype!(mpi_datatype)
412+
#function recv!(
413+
# recvbuf::TracedRArray,
414+
# tag::TracedRNumber,
415+
# src::TracedRNumber;
416+
# location=mlir_stacktrace("mpi.recv", @__FILE__, @__LINE__),
417+
#)
418+
# T = Reactant.unwrapped_eltype(recvbuf)
419+
# mpi_datatype = MPI.Datatype(T)
420+
# mpi_datatype_name = inject_mpi_datatype!(mpi_datatype)
421421

422-
sym_name = "enzymexla_wrapper_MPI_Recv_$(mpi_datatype_name)"
423-
sym_attr = IR.FlatSymbolRefAttribute(sym_name)
422+
# sym_name = "enzymexla_wrapper_MPI_Recv_$(mpi_datatype_name)"
423+
# sym_attr = IR.FlatSymbolRefAttribute(sym_name)
424424

425-
IR.inject!("MPI_COMM_WORLD", "llvm.mlir.global constant @MPI_COMM_WORLD() : !llvm.ptr")
426-
IR.inject!(
427-
"MPI_STATUS_IGNORE", "llvm.mlir.global constant @MPI_STATUS_IGNORE() : !llvm.ptr"
428-
)
429-
IR.inject!(
430-
"MPI_Recv",
431-
"llvm.func @MPI_Recv(!llvm.ptr, i32, !llvm.ptr, i32, i32, !llvm.ptr, !llvm.ptr) -> i32",
432-
)
425+
# IR.inject!("MPI_COMM_WORLD", "llvm.mlir.global constant @MPI_COMM_WORLD() : !llvm.ptr")
426+
# IR.inject!(
427+
# "MPI_STATUS_IGNORE", "llvm.mlir.global constant @MPI_STATUS_IGNORE() : !llvm.ptr"
428+
# )
429+
# IR.inject!(
430+
# "MPI_Recv",
431+
# "llvm.func @MPI_Recv(!llvm.ptr, i32, !llvm.ptr, i32, i32, !llvm.ptr, !llvm.ptr) -> i32",
432+
# )
433433

434-
#! format: off
435-
IR.inject!(sym_name, """
436-
func.func @$sym_name(%buf : !llvm.ptr, %count_ptr : !llvm.ptr, %source_ptr : !llvm.ptr, %tag_ptr : !llvm.ptr) -> () {
437-
%comm = llvm.mlir.addressof @MPI_COMM_WORLD : !llvm.ptr
438-
%datatype = llvm.mlir.addressof @$mpi_datatype_name : !llvm.ptr
439-
%status = llvm.mlir.addressof @MPI_STATUS_IGNORE : !llvm.ptr
440-
%count = llvm.load %count_ptr : !llvm.ptr -> i32
441-
%source = llvm.load %source_ptr : !llvm.ptr -> i32
442-
%tag = llvm.load %tag_ptr : !llvm.ptr -> i32
443-
llvm.call @MPI_Recv(%buf, %count, %datatype, %source, %tag, %comm, %status) : (!llvm.ptr, i32, !llvm.ptr, i32, i32, !llvm.ptr, !llvm.ptr) -> (i32)
444-
func.return
445-
}
446-
""")
447-
#! format: on
434+
# #! format: off
435+
# IR.inject!(sym_name, """
436+
# func.func @$sym_name(%buf : !llvm.ptr, %count_ptr : !llvm.ptr, %source_ptr : !llvm.ptr, %tag_ptr : !llvm.ptr) -> () {
437+
# %comm = llvm.mlir.addressof @MPI_COMM_WORLD : !llvm.ptr
438+
# %datatype = llvm.mlir.addressof @$mpi_datatype_name : !llvm.ptr
439+
# %status = llvm.mlir.addressof @MPI_STATUS_IGNORE : !llvm.ptr
440+
# %count = llvm.load %count_ptr : !llvm.ptr -> i32
441+
# %source = llvm.load %source_ptr : !llvm.ptr -> i32
442+
# %tag = llvm.load %tag_ptr : !llvm.ptr -> i32
443+
# llvm.call @MPI_Recv(%buf, %count, %datatype, %source, %tag, %comm, %status) : (!llvm.ptr, i32, !llvm.ptr, i32, i32, !llvm.ptr, !llvm.ptr) -> (i32)
444+
# func.return
445+
# }
446+
# """)
447+
# #! format: on
448448

449-
count = Reactant.Ops.constant(Int32(length(recvbuf)))
449+
# count = Reactant.Ops.constant(Int32(length(recvbuf)))
450450

451-
output_operand_aliases = IR.Attribute([
452-
IR.Attribute(
453-
MLIR.API.stablehloOutputOperandAliasGet(
454-
MLIR.IR.context(), 0, C_NULL, 0, 0, C_NULL
455-
),
456-
),
457-
])
451+
# output_operand_aliases = IR.Attribute([
452+
# IR.Attribute(
453+
# MLIR.API.stablehloOutputOperandAliasGet(
454+
# MLIR.IR.context(), 0, C_NULL, 0, 0, C_NULL
455+
# ),
456+
# ),
457+
# ])
458458

459-
ret = enzymexla.jit_call(
460-
IR.Value[recvbuf.mlir_data, count.mlir_data, src.mlir_data, tag.mlir_data];
461-
fn=sym_attr,
462-
result_0=[mlir_type(recvbuf)],
463-
output_operand_aliases,
464-
location,
465-
)
459+
# ret = enzymexla.jit_call(
460+
# IR.Value[recvbuf.mlir_data, count.mlir_data, src.mlir_data, tag.mlir_data];
461+
# fn=sym_attr,
462+
# result_0=[mlir_type(recvbuf)],
463+
# output_operand_aliases,
464+
# location,
465+
# )
466466

467-
recvbuf.mlir_data = IR.result(ret)
467+
# recvbuf.mlir_data = IR.result(ret)
468+
469+
# return recvbuf
470+
#end
468471

472+
@noinline function recv!(
473+
recvbuf::TracedRArray,
474+
tag::TracedRNumber,
475+
src::TracedRNumber;
476+
location=mlir_stacktrace("mpi.recv", @__FILE__, @__LINE__),
477+
)
478+
count = Reactant.Ops.constant(Int32(length(recvbuf)))
479+
ret = enzymexla.recv(count.mlir_data, src.mlir_data, tag.mlir_data; buf=mlir_type(recvbuf), location)
480+
recvbuf.mlir_data = IR.result(ret)
469481
return recvbuf
470482
end
471483

0 commit comments

Comments
 (0)