@@ -478,9 +478,9 @@ function compile_mlir(f, args; client=nothing, kwargs...)
478478 @ccall MLIR. API. mlir_c. RegisterDialects (ctx:: MLIR.API.MlirContext ):: Cvoid
479479
480480 if client != = nothing
481- backend = XLA. ClientGetPlatformName (client)
481+ backend = XLA. platform_name (client)
482482 else
483- backend = XLA. ClientGetPlatformName (XLA. default_backend[])
483+ backend = XLA. platform_name (XLA. default_backend[])
484484 end
485485 if backend == " CUDA"
486486 backend = " GPU"
@@ -1076,8 +1076,8 @@ function codegen_flatten!(
10761076
10771077 if is_sharded
10781078 carg = inv_seen_args[arg]
1079- condensed_op_sharding = Reactant . Sharding . XLA . CondensedOpSharding (
1080- linear_parameter_shardings[i]
1079+ condensed_op_sharding = convert (
1080+ Reactant . Sharding . XLA . CondensedOpSharding, linear_parameter_shardings[i]
10811081 )
10821082 if Reactant. Sharding. is_sharded (carg)
10831083 # Currently disabling the error since we roundtrip from MHLO to generate
@@ -1102,17 +1102,16 @@ function codegen_flatten!(
11021102 device_ids = vec (mesh)
11031103 for j in 1 : length (mesh)
11041104 buf = Symbol (:buf_ , i, :_ , j)
1105- device_id = device_ids[j]
1105+ local_device_id = device_ids[j]
11061106 slice = device_to_array_slices[j]
11071107 push! (
11081108 flatten_code,
11091109 :($ buf = XLA. synced_buffer (only ($ usbuf[$ (slice). .. ]. data))),
11101110 )
1111- device_ordinal = XLA. device_ordinal (client, device_id)
11121111 sbuf = Symbol (:sbuf_ , i, :_ , j)
1113- device = XLA. ClientGetAddressableDevice (client, device_ordinal )
1112+ device = XLA. get_addressable_device (client, local_device_id )
11141113 push! (flatten_names, sbuf)
1115- push! (flatten_code, :($ sbuf = XLA. CopyBufferToDevice ($ buf, $ device)))
1114+ push! (flatten_code, :($ sbuf = XLA. copy_buffer_to_device ($ buf, $ device)))
11161115 end
11171116 end
11181117 else
@@ -1308,12 +1307,17 @@ Generate Julia code to call the XLA executable.
13081307- `nresults`: The number of results to expect.
13091308"""
13101309function codegen_xla_call (
1311- exec, device, flatten_names, donated_args_mask, nresults, is_sharded:: Bool , mesh_ids
1310+ exec,
1311+ device,
1312+ flatten_names,
1313+ donated_args_mask,
1314+ nresults,
1315+ is_sharded:: Bool ,
1316+ ndevices:: Int ,
13121317)
13131318 flatten_buffer_refs = map (n -> :($ n. buffer), flatten_names)
13141319
1315- base_symbol_name =
1316- is_sharded ? Symbol (:result_buffer_m , length (mesh_ids), :_ ) : :result_buffer_
1320+ base_symbol_name = is_sharded ? Symbol (:result_buffer_m , ndevices, :_ ) : :result_buffer_
13171321 concretized_res_names = Symbol[Symbol (base_symbol_name, i) for i in 1 : nresults]
13181322 concretized_res_code = map (enumerate (concretized_res_names)) do (i, varname)
13191323 :($ varname = linearized_results[$ i])
@@ -1325,21 +1329,20 @@ function codegen_xla_call(
13251329 if is_sharded
13261330 quote
13271331 GC. @preserve $ (flatten_names... ) begin
1328- linearized_results = XLA. ExecutableCall (
1332+ linearized_results = XLA. execute (
13291333 $ exec,
1330- $ (mesh_ids),
13311334 ($ (flatten_buffer_refs... ),),
13321335 $ (Tuple (donated_args_mask)),
13331336 Val ($ nresults),
1334- Val ($ ( length (mesh_ids)) ),
1337+ Val ($ ndevices ),
13351338 )
13361339 end
13371340 $ (concretized_res_code... )
13381341 end
13391342 else
13401343 quote
13411344 GC. @preserve $ (flatten_names... ) begin
1342- linearized_results = XLA. ExecutableCallSharded (
1345+ linearized_results = XLA. execute_sharded (
13431346 $ exec,
13441347 $ (device),
13451348 ($ (flatten_buffer_refs... ),),
@@ -1393,7 +1396,7 @@ function __resolve_device_and_client(client, seen_args, linear_args, is_sharded)
13931396 if ! allequal (devices_list)
13941397 msg = " Expected all arguments to be on the same device, got:\n "
13951398 for (i, device) in enumerate (devices_list)
1396- msg *= " Device $(i) : $(XLA . DeviceToString (device)) \n "
1399+ msg *= " Device $(i) : $(string (device)) \n "
13971400 end
13981401 throw (ArgumentError (msg))
13991402 end
@@ -1407,17 +1410,13 @@ function __resolve_device_and_client(client, seen_args, linear_args, is_sharded)
14071410 client = XLA. client (device)
14081411 else
14091412 client = XLA. default_backend[]
1410- device = XLA. ClientGetAddressableDevice (
1411- client, XLA. device_ordinal (client, XLA. default_device_idx[])
1412- )
1413+ device = XLA. get_addressable_device (client, XLA. default_device_idx[])
14131414 end
14141415 else
14151416 if device != = nothing
14161417 @assert client == XLA. client (device) " client ($(client) ) and XLA.client(device) ($(XLA. client (device)) ) must be the same"
14171418 else
1418- device = XLA. ClientGetAddressableDevice (
1419- client, XLA. device_ordinal (client, XLA. default_device_idx[])
1420- )
1419+ device = XLA. get_addressable_device (client, XLA. default_device_idx[])
14211420 end
14221421 end
14231422
@@ -1431,9 +1430,9 @@ function compile_xla(f, args; client=nothing, kwargs...)
14311430 @ccall MLIR. API. mlir_c. RegisterDialects (ctx:: MLIR.API.MlirContext ):: Cvoid
14321431
14331432 if client != = nothing
1434- backend = XLA. ClientGetPlatformName (client)
1433+ backend = XLA. platform_name (client)
14351434 else
1436- backend = XLA. ClientGetPlatformName (XLA. default_backend[])
1435+ backend = XLA. platform_name (XLA. default_backend[])
14371436 end
14381437 if backend == " CUDA"
14391438 backend = " GPU"
@@ -1464,7 +1463,7 @@ function compile_xla(f, args; client=nothing, kwargs...)
14641463 device_ids = mlir_fn_res. is_sharded ? vec (mlir_fn_res. sharding_mesh) : Int64[]
14651464 mlir_fn_res. is_sharded && (device = nothing )
14661465
1467- exec = XLA. Compile (
1466+ exec = XLA. compile (
14681467 client,
14691468 device,
14701469 mod;
@@ -1514,7 +1513,7 @@ function compile(f, args; sync=false, kwargs...)
15141513 donated_args_mask,
15151514 length (linear_results),
15161515 mlir_fn_res. is_sharded,
1517- mlir_fn_res. is_sharded ? vec (mlir_fn_res. sharding_mesh) : Int64[] ,
1516+ mlir_fn_res. is_sharded ? length (mlir_fn_res. sharding_mesh) : 1 ,
15181517 )
15191518
15201519 linear_result_shard_info = if mlir_fn_res. is_sharded
0 commit comments