Commit 28c28c0

Rename to GetComputationClientOrDie().
1 parent a19bb88 commit 28c28c0

16 files changed (+113, -120 lines)
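
The change itself is a mechanical rename: every call site that previously used runtime::GetComputationClient() now calls runtime::GetComputationClientOrDie(). The new name follows the common "OrDie" accessor convention, in which a fallible getter reports failure through a status and a convenience wrapper crashes instead of returning an error. The sketch below illustrates that convention only; the ComputationClient stand-in, the g_client global, and the error message are assumptions for illustration, not the actual torch_xla runtime implementation.

#include <cstdlib>
#include <iostream>

#include "absl/status/status.h"
#include "absl/status/statusor.h"

struct ComputationClient {};  // stand-in for the real runtime client type

ComputationClient* g_client = nullptr;  // hypothetically set during runtime init

// Hypothetical fallible accessor: reports failure via absl::StatusOr so
// callers can handle an uninitialized runtime themselves.
absl::StatusOr<ComputationClient*> GetComputationClient() {
  if (g_client == nullptr) {
    return absl::FailedPreconditionError("computation client not initialized");
  }
  return g_client;
}

// Hypothetical "OrDie" wrapper: callers that cannot usefully recover
// (tests, one-shot helpers) either get a valid pointer or terminate.
ComputationClient* GetComputationClientOrDie() {
  absl::StatusOr<ComputationClient*> client = GetComputationClient();
  if (!client.ok()) {
    std::cerr << "GetComputationClient() failed: " << client.status() << "\n";
    std::abort();
  }
  return *client;
}

Under that reading, call sites like the ones in this diff simply become runtime::GetComputationClientOrDie()->GetLocalDevices(), while code that wants to propagate the failure can still branch on the status-returning accessor.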

test/cpp/cpp_test_util.cpp

Lines changed: 6 additions & 6 deletions
@@ -225,14 +225,14 @@ void WithAllDevices(
     std::vector<torch::lazy::BackendDevice> devices;
     std::vector<torch::lazy::BackendDevice> all_devices;
     for (const auto& device_str :
-         torch_xla::runtime::GetComputationClient()->GetLocalDevices()) {
+         torch_xla::runtime::GetComputationClientOrDie()->GetLocalDevices()) {
       torch::lazy::BackendDevice device = ParseDeviceString(device_str);
       if (device.type() == device_type.type) {
         devices.push_back(device);
       }
     }
     for (const auto& device_str :
-         torch_xla::runtime::GetComputationClient()->GetAllDevices()) {
+         torch_xla::runtime::GetComputationClientOrDie()->GetAllDevices()) {
       torch::lazy::BackendDevice device = ParseDeviceString(device_str);
       if (device.type() == device_type.type) {
         all_devices.push_back(device);
@@ -283,17 +283,17 @@ std::vector<torch_xla::runtime::ComputationClient::DataPtr> Execute(
   std::vector<torch_xla::runtime::ComputationClient::CompileInstance> instances;
   instances.push_back(
       {std::move(computation), device.toString(),
-       torch_xla::runtime::GetComputationClient()->GetCompilationDevices(
+       torch_xla::runtime::GetComputationClientOrDie()->GetCompilationDevices(
           device.toString(), {}),
        &shape});

   std::vector<
       std::shared_ptr<torch_xla::runtime::ComputationClient::Computation>>
-      computations = torch_xla::runtime::GetComputationClient()->Compile(
+      computations = torch_xla::runtime::GetComputationClientOrDie()->Compile(
          std::move(instances));

   torch_xla::runtime::ComputationClient::ExecuteComputationOptions options;
-  return torch_xla::runtime::GetComputationClient()->ExecuteComputation(
+  return torch_xla::runtime::GetComputationClientOrDie()->ExecuteComputation(
       *computations.front(), UnwrapXlaData(lowering_ctx.GetParametersData()),
       device.toString(), options);
 }
@@ -302,7 +302,7 @@ std::vector<at::Tensor> Fetch(
     absl::Span<const torch_xla::runtime::ComputationClient::DataPtr>
         device_data) {
   std::vector<xla::Literal> literals =
-      torch_xla::runtime::GetComputationClient()->TransferFromDevice(
+      torch_xla::runtime::GetComputationClientOrDie()->TransferFromDevice(
          device_data);
   std::vector<at::Tensor> tensors;
   for (auto& literal : literals) {

test/cpp/test_replication.cpp

Lines changed: 3 additions & 3 deletions
@@ -48,7 +48,7 @@ void TestSingleReplication(
   }
   std::vector<torch_xla::runtime::ComputationClient::ComputationPtr>
       compiled_computations =
-          torch_xla::runtime::GetComputationClient()->Compile(
+          torch_xla::runtime::GetComputationClientOrDie()->Compile(
              std::move(instances));

   std::vector<at::Tensor> tensors;
@@ -65,7 +65,7 @@ void TestSingleReplication(
   for (size_t i = 0; i < device_strings.size(); ++i) {
     auto executor = [&, i]() {
       results[i] =
-          torch_xla::runtime::GetComputationClient()->ExecuteComputation(
+          torch_xla::runtime::GetComputationClientOrDie()->ExecuteComputation(
              *compiled_computations[i],
              {std::dynamic_pointer_cast<
                  torch_xla::runtime::ComputationClient::Data>(
@@ -79,7 +79,7 @@ void TestSingleReplication(

   for (size_t i = 0; i < results.size(); ++i) {
     std::vector<xla::Literal> literals =
-        torch_xla::runtime::GetComputationClient()->TransferFromDevice(
+        torch_xla::runtime::GetComputationClientOrDie()->TransferFromDevice(
            results[i]);
     ASSERT_EQ(literals.size(), 1);

test/cpp/test_xla_sharding.cpp

Lines changed: 7 additions & 7 deletions
@@ -332,14 +332,14 @@ TEST_F(XLAShardingTest, CreateTensorsData) {
   CreateTensorsData(tensors, shardings, devices);

   int64_t n_devices =
-      torch_xla::runtime::GetComputationClient()->GetLocalDevices().size();
+      torch_xla::runtime::GetComputationClientOrDie()->GetLocalDevices().size();
   if (n_devices > 1) {
     // null sharding is treated as replicated.
     auto xla_data =
         std::dynamic_pointer_cast<torch_xla::runtime::ComputationClient::Data>(
             tensors_data[0]);
     std::vector<torch_xla::runtime::ComputationClient::DataPtr> shards =
-        torch_xla::runtime::GetComputationClient()->GetDataShards(xla_data);
+        torch_xla::runtime::GetComputationClientOrDie()->GetDataShards(xla_data);
     EXPECT_EQ(shards.size(), n_devices);
     EXPECT_TRUE(xla::Shape::Equal().IgnoreLayout()(xla_data->shape(),
                                                    shards[0]->shape()));
@@ -349,7 +349,7 @@ TEST_F(XLAShardingTest, CreateTensorsData) {
     auto sharded_xla_data =
         std::dynamic_pointer_cast<torch_xla::runtime::ComputationClient::Data>(
             tensors_data[1]);
-    shards = torch_xla::runtime::GetComputationClient()->GetDataShards(
+    shards = torch_xla::runtime::GetComputationClientOrDie()->GetDataShards(
         sharded_xla_data);
     EXPECT_EQ(shards.size(), n_devices);
     EXPECT_TRUE(xla::Shape::Equal().IgnoreLayout()(sharded_xla_data->shape(),
@@ -360,7 +360,7 @@ TEST_F(XLAShardingTest, CreateTensorsData) {
     sharded_xla_data =
         std::dynamic_pointer_cast<torch_xla::runtime::ComputationClient::Data>(
             tensors_data[2]);
-    shards = torch_xla::runtime::GetComputationClient()->GetDataShards(
+    shards = torch_xla::runtime::GetComputationClientOrDie()->GetDataShards(
         sharded_xla_data);
     EXPECT_EQ(shards.size(), n_devices);
     EXPECT_TRUE(xla::Shape::Equal().IgnoreLayout()(sharded_xla_data->shape(),
@@ -372,7 +372,7 @@ TEST_F(XLAShardingTest, CreateTensorsData) {
 TEST_F(XLAShardingTest, PrepareOutputShardingPropagation) {
   xla::Shape shape = xla::ShapeUtil::MakeShape(xla::PrimitiveType::F32, {4, 4});
   int64_t n_devices =
-      torch_xla::runtime::GetComputationClient()->GetLocalDevices().size();
+      torch_xla::runtime::GetComputationClientOrDie()->GetLocalDevices().size();
   xla::Array<int64_t> tile_assignment({1, n_devices});
   tile_assignment.FillIota(0);
   xla::OpSharding tiled = xla::HloSharding::Tile(tile_assignment).ToProto();
@@ -395,15 +395,15 @@ TEST_F(XLAShardingTest, PrepareOutputShardingPropagation) {

   std::vector<
       std::shared_ptr<torch_xla::runtime::ComputationClient::Computation>>
-      computations = torch_xla::runtime::GetComputationClient()->Compile(
+      computations = torch_xla::runtime::GetComputationClientOrDie()->Compile(
          std::move(instances));
   torch_xla::runtime::ComputationClient::ComputationPtr computation =
       std::make_shared<torch_xla::runtime::ComputationClient::Computation>(
          "add", std::move(computations[0]->move_computation()));

   // Prepare output sharding propagation, expect a sharded output placeholder.
   std::vector<XLATensorPtr> tensors{XLATensor::Create(
-      torch_xla::runtime::GetComputationClient()->CreateDataPlaceholder(
+      torch_xla::runtime::GetComputationClientOrDie()->CreateDataPlaceholder(
          bridge::GetDefaultDevice()->toString(), std::move(shape)))};
   std::vector<torch::lazy::BackendDataPtr> data_placeholders;
   std::vector<XLATensor::ShardingSpecPtr> sharding_specs;

torch_xla/csrc/aten_fallback.cpp

Lines changed: 1 addition & 1 deletion
@@ -77,7 +77,7 @@ bool UseOpenXLAFallbackOnCUDA(const c10::OperatorHandle& op) {
   // support running OpenXLA fallback operations on CUDA if the current
   // PyTorch/XLA DeviceType is not CUDA.
   bool device_is_cuda =
-      runtime::GetComputationClient()->GetDeviceType().getType() ==
+      runtime::GetComputationClientOrDie()->GetDeviceType().getType() ==
      XlaDeviceType::CUDA;

   // 3. PyTorch must have been compiled with CUDA support. Otherwise, our

torch_xla/csrc/aten_xla_bridge.cpp

Lines changed: 2 additions & 2 deletions
@@ -56,7 +56,7 @@ class AtenXlaDeviceMapper {
       devices_ordinals_[devices_.back()] = 0;
     } else {
       for (auto& device_str :
-           torch_xla::runtime::GetComputationClient()->GetLocalDevices()) {
+           torch_xla::runtime::GetComputationClientOrDie()->GetLocalDevices()) {
         devices_.emplace_back(ParseDeviceString(device_str));
         devices_ordinals_[devices_.back()] = devices_.size() - 1;
       }
@@ -367,7 +367,7 @@ std::string ToXlaString(const c10::Device& device) {
 const torch::lazy::BackendDevice* GetDefaultDevice() {
   static std::string default_device_spec =
       UseVirtualDevice() ? "SPMD:0"
-                         : runtime::GetComputationClient()->GetDefaultDevice();
+                         : runtime::GetComputationClientOrDie()->GetDefaultDevice();
   XLA_CHECK(!default_device_spec.empty());
   static const torch::lazy::BackendDevice default_device =
       ParseDeviceString(default_device_spec);

torch_xla/csrc/cross_replica_reduces.cpp

Lines changed: 1 addition & 1 deletion
@@ -332,7 +332,7 @@ at::Tensor all_to_all_single(const at::Tensor& input,
   bool pin_layout = false;
   const torch::lazy::Value& token =
       GetAllReduceToken(bridge::GetCurrentDevice());
-  int64_t split_count = runtime::GetComputationClient()->GetAllDevices().size();
+  int64_t split_count = runtime::GetComputationClientOrDie()->GetAllDevices().size();
   std::vector<int64_t> all_groups(split_count);
   std::iota(all_groups.begin(), all_groups.end(), 0);

torch_xla/csrc/dl_convertor.cpp

Lines changed: 6 additions & 6 deletions
@@ -122,7 +122,7 @@ DLManagedTensor* toDLPack(const at::Tensor& input) {
       << "Could not extract a valid data handle from the input tensor";

   std::shared_ptr<xla::PjRtBuffer> pjrt_buffer =
-      runtime::GetComputationClient()->GetPjRtBuffer(handle);
+      runtime::GetComputationClientOrDie()->GetPjRtBuffer(handle);
   XLA_CHECK(pjrt_buffer != nullptr) << "Could not get a valid pjrt_buffer";

   XLA_CHECK(!pjrt_buffer->IsTuple())
@@ -168,14 +168,14 @@ DLManagedTensor* toDLPack(const at::Tensor& input) {
 absl::StatusOr<xla::PjRtDevice*> DeviceForDLDevice(const DLDevice& context) {
   switch (context.device_type) {
     case DLDeviceType::kDLCPU:
-      XLA_CHECK_EQ(runtime::GetComputationClient()->GetPlatformID(),
+      XLA_CHECK_EQ(runtime::GetComputationClientOrDie()->GetPlatformID(),
                    xla::CpuId());
-      return runtime::GetComputationClient()->LookupAddressableDevice(
+      return runtime::GetComputationClientOrDie()->LookupAddressableDevice(
          context.device_id);
     case DLDeviceType::kDLCUDA:
-      XLA_CHECK_EQ(runtime::GetComputationClient()->GetPlatformID(),
+      XLA_CHECK_EQ(runtime::GetComputationClientOrDie()->GetPlatformID(),
                    xla::CudaId());
-      return runtime::GetComputationClient()->LookupAddressableDevice(
+      return runtime::GetComputationClientOrDie()->LookupAddressableDevice(
          context.device_id);
     default:
      return tsl::errors::InvalidArgument(
@@ -335,7 +335,7 @@ at::Tensor fromDLPack(DLManagedTensor* dlmt) {

   runtime::ComputationClient::DataPtr data =
       runtime::PjRtComputationClient::CreateData(
-          runtime::GetComputationClient()->PjRtDeviceToString(device), shape,
+          runtime::GetComputationClientOrDie()->PjRtDeviceToString(device), shape,
          std::move(pjrt_buffer.value()));

   at::ScalarType tensor_type = at::toScalarType(dlmt->dl_tensor.dtype);
