@@ -8,14 +8,14 @@ struct Mesh{D,ND}
88 shape:: Dims{D}
99 axis_names:: NTuple{D,Symbol}
1010
11- function Mesh (devices:: AbstractArray{XLA.AbstractDevice} , axis_names)
12- return Mesh (XLA. get_local_device_id .(devices), axis_names)
11+ function Mesh (devices:: AbstractArray{<: XLA.AbstractDevice} , axis_names)
12+ return Mesh (XLA. device_ordinal .(devices), axis_names)
1313 end
1414
1515 function Mesh (
16- devices:: NTuple{D,XLA.AbstractDevice} , shape:: Dims{D} , axis_names
16+ devices:: NTuple{D,<: XLA.AbstractDevice} , shape:: Dims{D} , axis_names
1717 ) where {D}
18- return Mesh (XLA. get_local_device_id .(devices), shape, axis_names)
18+ return Mesh (XLA. device_ordinal .(devices), shape, axis_names)
1919 end
2020
2121 function Mesh (
@@ -114,7 +114,7 @@ function (sharding::NamedSharding)(
114114 XLA. PJRT. AsyncBuffer (
115115 client,
116116 x[device_to_array_slices[i]. .. ],
117- XLA. get_addressable_device (client, mesh. sorted_device_ids[i]),
117+ XLA. get_device (client, mesh. sorted_device_ids[i]),
118118 )
119119 end
120120
@@ -199,7 +199,7 @@ function (sharding::LazySharding)(
199199 client:: XLA.PJRT.Client , :: Nothing , x:: Union{AbstractArray,Number}
200200)
201201 data = XLA. PJRT. AsyncBuffer (
202- client, x, XLA. get_addressable_device (client, vec (sharding. sharding. mesh)[1 ])
202+ client, x, XLA. get_device (client, vec (sharding. sharding. mesh)[1 ])
203203 )
204204
205205 return (data,), ShardInfo (sharding, (ntuple (i -> 1 : size (x, i), ndims (x)),))
0 commit comments