diff --git a/test/cpp/CMakeLists.txt b/test/cpp/CMakeLists.txt
index 302db7e21927..f9e974847289 100644
--- a/test/cpp/CMakeLists.txt
+++ b/test/cpp/CMakeLists.txt
@@ -58,6 +58,7 @@ set(TORCH_XLA_TEST_SOURCES
   test_tensor.cpp
   test_xla_util_cache.cpp
   torch_xla_test.cpp
+  test_xla_backend_intf.cpp
 )
 
 add_executable(test_ptxla ${TORCH_XLA_TEST_SOURCES})
diff --git a/test/cpp/test_xla_backend_intf.cpp b/test/cpp/test_xla_backend_intf.cpp
new file mode 100644
index 000000000000..25cf16b66ec5
--- /dev/null
+++ b/test/cpp/test_xla_backend_intf.cpp
@@ -0,0 +1,86 @@
+#include <vector>
+
+#include "cpp_test_util.h"
+#include "torch_xla/csrc/tensor_util.h"
+
+namespace torch_xla {
+namespace cpp_test {
+
+TEST(XLABackendTest, TestTensorTransfer) {
+  torch::lazy::BackendImplInterface* impl = GetXlaBackendImpl();
+  at::Tensor input = at::randint(std::numeric_limits<uint8_t>::min(),
+                                 std::numeric_limits<uint8_t>::max(), {2, 2},
+                                 at::TensorOptions(at::kByte));
+  ForEachDevice([&](const torch::lazy::BackendDevice& device) {
+    torch::lazy::BackendDataPtr data = impl->MakeComputationDataFromTensor(
+        input, torch::lazy::Shape(input.scalar_type(), input.sizes()), device);
+    at::Tensor res = impl->MakeTensorFromComputationData(data, at::kByte);
+    AllClose(input, res);
+  });
+}
+
+TEST(XLABackendTest, TestScalarTransfer) {
+  torch::lazy::BackendImplInterface* impl = GetXlaBackendImpl();
+  at::Scalar input = at::Scalar(int64_t(1));
+  ForEachDevice([&](const torch::lazy::BackendDevice& device) {
+    torch::lazy::BackendDataPtr data =
+        impl->MakeComputationDataFromScalar(input, device);
+    at::Tensor res = impl->MakeTensorFromComputationData(data, at::kByte);
+    AllClose(at::ones({}, at::TensorOptions(at::kByte)), res);
+  });
+}
+
+TEST(XLABackendTest, TestPlaceholder) {
+  torch::lazy::BackendImplInterface* impl = GetXlaBackendImpl();
+  torch::lazy::Shape shape(at::kFloat, {10, 10});
+  ForEachDevice([&](const torch::lazy::BackendDevice& device) {
+    torch::lazy::BackendDataPtr data =
+        impl->CreateDataPlaceholder(device, shape);
+    xla::ComputationClient::DataPtr computation_data = UnwrapXlaData(data);
+    EXPECT_EQ(computation_data->device(), device.toString());
+    EXPECT_EQ(computation_data->shape(),
+              MakeXlaShapeFromLazyShape(shape, device));
+  });
+}
+
+xla::XlaComputation CreateAddComputation(const xla::Shape& shape) {
+  xla::XlaBuilder builder("AddComputation");
+  xla::XlaOp x = xla::Parameter(&builder, 0, shape, "x");
+  xla::XlaOp y = xla::Parameter(&builder, 1, shape, "y");
+  xla::XlaOp sum = xla::Add(x, y);
+  return ConsumeValue(builder.Build());
+}
+
+TEST(XLABackendTest, TestE2E) {
+  torch::lazy::BackendImplInterface* impl = GetXlaBackendImpl();
+  xla::Shape input_shape =
+      xla::ShapeUtil::MakeShape(xla::PrimitiveType::F32, {8, 8});
+  at::Tensor one = at::ones({8, 8}, at::TensorOptions(at::kFloat));
+  std::vector<at::Tensor> tensors = {one, one};
+
+  ForEachDevice([&](const torch::lazy::BackendDevice& device) {
+    xla::XlaComputation xla_computation = CreateAddComputation(input_shape);
+    torch::lazy::ComputationPtr computation =
+        std::make_shared<torch_xla::Computation>(
+            "test", std::move(xla_computation), device);
+    std::vector<torch::lazy::ComputationPtr> compiled_programs =
+        impl->Compile({computation});
+    EXPECT_EQ(compiled_programs.size(), 1);
+
+    std::vector<torch::lazy::BackendDataPtr> parameters;
+    for (auto& tensor : tensors) {
+      parameters.push_back(impl->MakeComputationDataFromTensor(
+          tensor, torch::lazy::Shape(tensor.scalar_type(), tensor.sizes()),
+          device));
+    }
+    std::vector<torch::lazy::BackendDataPtr> res =
+        impl->ExecuteComputation(compiled_programs[0], parameters, device);
+    EXPECT_EQ(res.size(), 1);
+    at::Tensor res_tensor =
+        impl->MakeTensorFromComputationData(res[0], at::kFloat);
+    AllClose(one + one, res_tensor);
+  });
+}
+
+}  // namespace cpp_test
+}  // namespace torch_xla
diff --git a/torch_xla/csrc/xla_backend_impl.cpp b/torch_xla/csrc/xla_backend_impl.cpp
index 92953949f389..60b3eac929b5 100644
--- a/torch_xla/csrc/xla_backend_impl.cpp
+++ b/torch_xla/csrc/xla_backend_impl.cpp
@@ -105,6 +105,7 @@ class XlaBackendImpl : public torch::lazy::BackendImplInterface {
     // c10::ArrayRef<torch::lazy::Node*> instead of
     // c10::ArrayRef<const torch::lazy::Node*> since c10::ArrayRef already
     // provided const for its member.
+    XLA_ERROR() << "Need to handle post_order";
     return std::make_unique<LoweringContext>(name, device);
   }
 
@@ -131,15 +132,16 @@ class XlaBackendImpl : public torch::lazy::BackendImplInterface {
     std::vector<torch::lazy::ComputationPtr> res;
     std::vector<xla::ComputationClient::CompileInstance> compile_instances;
     torch::lazy::BackendDevice current_device = GetCurrentDevice();
+    std::vector<xla::Shape> output_shapes;
 
     for (const torch::lazy::ComputationPtr instance : instances) {
       // TODO(JackCaoG): device is missing in instance, use CurrentDevice for
       // now
       const Computation* torch_xla_computation =
           dynamic_cast<const Computation*>(instance.get());
-      xla::Shape shape = MakeShapeWithDeviceLayout(
+      output_shapes.push_back(MakeShapeWithDeviceLayout(
           torch_xla_computation->program_shape().result(),
-          static_cast<XlaDeviceType>(current_device.type()));
+          static_cast<XlaDeviceType>(current_device.type())));
 
       // Call GetCompilationDevices and passes all device here if needed.
       // Currently on TPU we always have 1 replica per device and one process
@@ -152,7 +154,7 @@ class XlaBackendImpl : public torch::lazy::BackendImplInterface {
       compile_instances.push_back(xla::ComputationClient::CompileInstance(
           torch_xla_computation->move_computation(),
           torch_xla_computation->get_device_string(),
-          {current_device.toString()}, &shape));
+          {current_device.toString()}, &output_shapes.back()));
     }
     std::vector<std::shared_ptr<xla::ComputationClient::Computation>>
         client_computations = xla::ComputationClient::Get()->Compile(
@@ -238,4 +240,4 @@ void InitXlaBackend() {
       std::make_unique<torch::lazy::BackendRegistrar>(GetXlaBackendImpl());
 };
 
-}  // namespace torch_xla
\ No newline at end of file
+}  // namespace torch_xla
diff --git a/torch_xla/csrc/xla_backend_impl.h b/torch_xla/csrc/xla_backend_impl.h
index 139833815928..250226d86ce5 100644
--- a/torch_xla/csrc/xla_backend_impl.h
+++ b/torch_xla/csrc/xla_backend_impl.h
@@ -39,8 +39,8 @@ class XLAData : public torch::lazy::BackendData {
   xla::ComputationClient::DataPtr xla_data_;
 };
 
-// torch::lazy::BackendImplInterface* GetXlaBackendImpl();
+torch::lazy::BackendImplInterface* GetXlaBackendImpl();
 
 void InitXlaBackend();
 
-}  // namespace torch_xla
\ No newline at end of file
+}  // namespace torch_xla