@@ -409,63 +409,75 @@ function isend(
409409 return request # we return a TracedRNumber, converted to TracedRequest in Overrides.jl
410410end
411411
412- function recv! (
413- recvbuf:: TracedRArray ,
414- tag:: TracedRNumber ,
415- src:: TracedRNumber ;
416- location= mlir_stacktrace (" mpi.recv" , @__FILE__ , @__LINE__ ),
417- )
418- T = Reactant. unwrapped_eltype (recvbuf)
419- mpi_datatype = MPI. Datatype (T)
420- mpi_datatype_name = inject_mpi_datatype! (mpi_datatype)
412+ # function recv!(
413+ # recvbuf::TracedRArray,
414+ # tag::TracedRNumber,
415+ # src::TracedRNumber;
416+ # location=mlir_stacktrace("mpi.recv", @__FILE__, @__LINE__),
417+ # )
418+ # T = Reactant.unwrapped_eltype(recvbuf)
419+ # mpi_datatype = MPI.Datatype(T)
420+ # mpi_datatype_name = inject_mpi_datatype!(mpi_datatype)
421421
422- sym_name = " enzymexla_wrapper_MPI_Recv_$(mpi_datatype_name) "
423- sym_attr = IR. FlatSymbolRefAttribute (sym_name)
422+ # sym_name = "enzymexla_wrapper_MPI_Recv_$(mpi_datatype_name)"
423+ # sym_attr = IR.FlatSymbolRefAttribute(sym_name)
424424
425- IR. inject! (" MPI_COMM_WORLD" , " llvm.mlir.global constant @MPI_COMM_WORLD() : !llvm.ptr" )
426- IR. inject! (
427- " MPI_STATUS_IGNORE" , " llvm.mlir.global constant @MPI_STATUS_IGNORE() : !llvm.ptr"
428- )
429- IR. inject! (
430- " MPI_Recv" ,
431- " llvm.func @MPI_Recv(!llvm.ptr, i32, !llvm.ptr, i32, i32, !llvm.ptr, !llvm.ptr) -> i32" ,
432- )
425+ # IR.inject!("MPI_COMM_WORLD", "llvm.mlir.global constant @MPI_COMM_WORLD() : !llvm.ptr")
426+ # IR.inject!(
427+ # "MPI_STATUS_IGNORE", "llvm.mlir.global constant @MPI_STATUS_IGNORE() : !llvm.ptr"
428+ # )
429+ # IR.inject!(
430+ # "MPI_Recv",
431+ # "llvm.func @MPI_Recv(!llvm.ptr, i32, !llvm.ptr, i32, i32, !llvm.ptr, !llvm.ptr) -> i32",
432+ # )
433433
434- # ! format: off
435- IR. inject! (sym_name, """
436- func.func @$sym_name (%buf : !llvm.ptr, %count_ptr : !llvm.ptr, %source_ptr : !llvm.ptr, %tag_ptr : !llvm.ptr) -> () {
437- %comm = llvm.mlir.addressof @MPI_COMM_WORLD : !llvm.ptr
438- %datatype = llvm.mlir.addressof @$mpi_datatype_name : !llvm.ptr
439- %status = llvm.mlir.addressof @MPI_STATUS_IGNORE : !llvm.ptr
440- %count = llvm.load %count_ptr : !llvm.ptr -> i32
441- %source = llvm.load %source_ptr : !llvm.ptr -> i32
442- %tag = llvm.load %tag_ptr : !llvm.ptr -> i32
443- llvm.call @MPI_Recv(%buf, %count, %datatype, %source, %tag, %comm, %status) : (!llvm.ptr, i32, !llvm.ptr, i32, i32, !llvm.ptr, !llvm.ptr) -> (i32)
444- func.return
445- }
446- """ )
447- # ! format: on
434+ # #! format: off
435+ # IR.inject!(sym_name, """
436+ # func.func @$sym_name(%buf : !llvm.ptr, %count_ptr : !llvm.ptr, %source_ptr : !llvm.ptr, %tag_ptr : !llvm.ptr) -> () {
437+ # %comm = llvm.mlir.addressof @MPI_COMM_WORLD : !llvm.ptr
438+ # %datatype = llvm.mlir.addressof @$mpi_datatype_name : !llvm.ptr
439+ # %status = llvm.mlir.addressof @MPI_STATUS_IGNORE : !llvm.ptr
440+ # %count = llvm.load %count_ptr : !llvm.ptr -> i32
441+ # %source = llvm.load %source_ptr : !llvm.ptr -> i32
442+ # %tag = llvm.load %tag_ptr : !llvm.ptr -> i32
443+ # llvm.call @MPI_Recv(%buf, %count, %datatype, %source, %tag, %comm, %status) : (!llvm.ptr, i32, !llvm.ptr, i32, i32, !llvm.ptr, !llvm.ptr) -> (i32)
444+ # func.return
445+ # }
446+ # """)
447+ # #! format: on
448448
449- count = Reactant. Ops. constant (Int32 (length (recvbuf)))
449+ # count = Reactant.Ops.constant(Int32(length(recvbuf)))
450450
451- output_operand_aliases = IR. Attribute ([
452- IR. Attribute (
453- MLIR. API. stablehloOutputOperandAliasGet (
454- MLIR. IR. context (), 0 , C_NULL , 0 , 0 , C_NULL
455- ),
456- ),
457- ])
451+ # output_operand_aliases = IR.Attribute([
452+ # IR.Attribute(
453+ # MLIR.API.stablehloOutputOperandAliasGet(
454+ # MLIR.IR.context(), 0, C_NULL, 0, 0, C_NULL
455+ # ),
456+ # ),
457+ # ])
458458
459- ret = enzymexla. jit_call (
460- IR. Value[recvbuf. mlir_data, count. mlir_data, src. mlir_data, tag. mlir_data];
461- fn= sym_attr,
462- result_0= [mlir_type (recvbuf)],
463- output_operand_aliases,
464- location,
465- )
459+ # ret = enzymexla.jit_call(
460+ # IR.Value[recvbuf.mlir_data, count.mlir_data, src.mlir_data, tag.mlir_data];
461+ # fn=sym_attr,
462+ # result_0=[mlir_type(recvbuf)],
463+ # output_operand_aliases,
464+ # location,
465+ # )
466466
467- recvbuf. mlir_data = IR. result (ret)
467+ # recvbuf.mlir_data = IR.result(ret)
468+
469+ # return recvbuf
470+ # end
468471
472+ @noinline function recv! (
473+ recvbuf:: TracedRArray ,
474+ tag:: TracedRNumber ,
475+ src:: TracedRNumber ;
476+ location= mlir_stacktrace (" mpi.recv" , @__FILE__ , @__LINE__ ),
477+ )
478+ count = Reactant. Ops. constant (Int32 (length (recvbuf)))
479+ ret = enzymexla. recv (count. mlir_data, src. mlir_data, tag. mlir_data; buf= mlir_type (recvbuf), location)
480+ recvbuf. mlir_data = IR. result (ret)
469481 return recvbuf
470482end
471483
0 commit comments