@@ -332,14 +332,14 @@ TEST_F(XLAShardingTest, CreateTensorsData) {
   CreateTensorsData(tensors, shardings, devices);
 
   int64_t n_devices =
-      torch_xla::runtime::GetComputationClient()->GetLocalDevices().size();
+      torch_xla::runtime::GetComputationClientOrDie()->GetLocalDevices().size();
   if (n_devices > 1) {
     // null sharding is treated as replicated.
     auto xla_data =
         std::dynamic_pointer_cast<torch_xla::runtime::ComputationClient::Data>(
             tensors_data[0]);
     std::vector<torch_xla::runtime::ComputationClient::DataPtr> shards =
-        torch_xla::runtime::GetComputationClient()->GetDataShards(xla_data);
+        torch_xla::runtime::GetComputationClientOrDie()->GetDataShards(xla_data);
     EXPECT_EQ(shards.size(), n_devices);
     EXPECT_TRUE(xla::Shape::Equal().IgnoreLayout()(xla_data->shape(),
                                                    shards[0]->shape()));
@@ -349,7 +349,7 @@ TEST_F(XLAShardingTest, CreateTensorsData) {
     auto sharded_xla_data =
         std::dynamic_pointer_cast<torch_xla::runtime::ComputationClient::Data>(
             tensors_data[1]);
-    shards = torch_xla::runtime::GetComputationClient()->GetDataShards(
+    shards = torch_xla::runtime::GetComputationClientOrDie()->GetDataShards(
         sharded_xla_data);
     EXPECT_EQ(shards.size(), n_devices);
     EXPECT_TRUE(xla::Shape::Equal().IgnoreLayout()(sharded_xla_data->shape(),
@@ -360,7 +360,7 @@ TEST_F(XLAShardingTest, CreateTensorsData) {
     sharded_xla_data =
         std::dynamic_pointer_cast<torch_xla::runtime::ComputationClient::Data>(
             tensors_data[2]);
-    shards = torch_xla::runtime::GetComputationClient()->GetDataShards(
+    shards = torch_xla::runtime::GetComputationClientOrDie()->GetDataShards(
         sharded_xla_data);
     EXPECT_EQ(shards.size(), n_devices);
     EXPECT_TRUE(xla::Shape::Equal().IgnoreLayout()(sharded_xla_data->shape(),
@@ -372,7 +372,7 @@ TEST_F(XLAShardingTest, CreateTensorsData) {
 TEST_F(XLAShardingTest, PrepareOutputShardingPropagation) {
   xla::Shape shape = xla::ShapeUtil::MakeShape(xla::PrimitiveType::F32, {4, 4});
   int64_t n_devices =
-      torch_xla::runtime::GetComputationClient()->GetLocalDevices().size();
+      torch_xla::runtime::GetComputationClientOrDie()->GetLocalDevices().size();
   xla::Array<int64_t> tile_assignment({1, n_devices});
   tile_assignment.FillIota(0);
   xla::OpSharding tiled = xla::HloSharding::Tile(tile_assignment).ToProto();
@@ -395,15 +395,15 @@ TEST_F(XLAShardingTest, PrepareOutputShardingPropagation) {
 
   std::vector<
       std::shared_ptr<torch_xla::runtime::ComputationClient::Computation>>
-      computations = torch_xla::runtime::GetComputationClient()->Compile(
+      computations = torch_xla::runtime::GetComputationClientOrDie()->Compile(
           std::move(instances));
   torch_xla::runtime::ComputationClient::ComputationPtr computation =
       std::make_shared<torch_xla::runtime::ComputationClient::Computation>(
           "add", std::move(computations[0]->move_computation()));
 
   // Prepare output sharding propagation, expect a sharded output placeholder.
   std::vector<XLATensorPtr> tensors{XLATensor::Create(
-      torch_xla::runtime::GetComputationClient()->CreateDataPlaceholder(
+      torch_xla::runtime::GetComputationClientOrDie()->CreateDataPlaceholder(
           bridge::GetDefaultDevice()->toString(), std::move(shape)))};
   std::vector<torch::lazy::BackendDataPtr> data_placeholders;
   std::vector<XLATensor::ShardingSpecPtr> sharding_specs;
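
Note on the change pattern: every hunk above mechanically swaps torch_xla::runtime::GetComputationClient() for GetComputationClientOrDie() at test call sites. A minimal sketch of the assumed relationship between the two accessors follows; only the two function names come from the diff itself, while the absl::StatusOr-based signature and the OrDie wrapper body are assumptions for illustration, modeled on the common absl accessor/OrDie pairing rather than taken from the torch_xla headers.

// Sketch only, not the actual torch_xla runtime header. Assumes the usual
// absl pattern: a fallible accessor returns absl::StatusOr, and an *OrDie
// convenience wrapper unwraps it and aborts on failure, which is why the
// test call sites above can keep chaining ->Compile(...), ->GetDataShards(...)
// without any error handling.
#include <cstdlib>
#include <iostream>

#include "absl/status/statusor.h"

namespace torch_xla {
namespace runtime {

class ComputationClient;

// Assumed: reports client-initialization failure as a status instead of
// crashing the process.
absl::StatusOr<ComputationClient*> GetComputationClient();

// Assumed: the variant these tests migrate to; terminates on failure and
// otherwise yields a usable client pointer directly.
inline ComputationClient* GetComputationClientOrDie() {
  absl::StatusOr<ComputationClient*> client = GetComputationClient();
  if (!client.ok()) {
    std::cerr << client.status() << std::endl;  // absl::Status is streamable.
    std::abort();
  }
  return *client;
}

}  // namespace runtime
}  // namespace torch_xla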