Add XLABackendIntf test (#3697)
* Add tensor transfer test

* Scalar transfer test

* add e2e test for xla backend intf

* initXlaBackend in cppTest
JackCaoG authored Jul 19, 2022
1 parent b2f9d01 commit 75ac08b
Showing 4 changed files with 95 additions and 6 deletions.
1 change: 1 addition & 0 deletions test/cpp/CMakeLists.txt
@@ -58,6 +58,7 @@ set(TORCH_XLA_TEST_SOURCES
   test_tensor.cpp
   test_xla_util_cache.cpp
   torch_xla_test.cpp
+  test_xla_backend_intf.cpp
 )

add_executable(test_ptxla ${TORCH_XLA_TEST_SOURCES})
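Because the new source is appended to TORCH_XLA_TEST_SOURCES, the test cases below build into the existing test_ptxla binary and can be run on their own with gtest's standard filter flag, e.g. test_ptxla --gtest_filter=XLABackendTest.*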
86 changes: 86 additions & 0 deletions test/cpp/test_xla_backend_intf.cpp
@@ -0,0 +1,86 @@
#include <vector>

#include "cpp_test_util.h"
#include "torch_xla/csrc/tensor_util.h"

namespace torch_xla {
namespace cpp_test {

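// Round trip: upload a random byte tensor through the backend interface and
// read the same values back on every available device.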
TEST(XLABackendTest, TestTensorTransfer) {
  torch::lazy::BackendImplInterface* impl = GetXlaBackendImpl();
  at::Tensor input = at::randint(std::numeric_limits<uint8_t>::min(),
                                 std::numeric_limits<uint8_t>::max(), {2, 2},
                                 at::TensorOptions(at::kByte));
  ForEachDevice([&](const torch::lazy::BackendDevice& device) {
    torch::lazy::BackendDataPtr data = impl->MakeComputationDataFromTensor(
        input, torch::lazy::Shape(input.scalar_type(), input.sizes()), device);
    at::Tensor res = impl->MakeTensorFromComputationData(data, at::kByte);
    AllClose(input, res);
  });
}

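// A scalar goes up via MakeComputationDataFromScalar and comes back as a
// kByte tensor holding 1, compared here against an all-ones tensor.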
TEST(XLABackendTest, TestScalarTransfer) {
  torch::lazy::BackendImplInterface* impl = GetXlaBackendImpl();
  at::Scalar input = at::Scalar(int64_t(1));
  ForEachDevice([&](const torch::lazy::BackendDevice& device) {
    torch::lazy::BackendDataPtr data =
        impl->MakeComputationDataFromScalar(input, device);
    at::Tensor res = impl->MakeTensorFromComputationData(data, at::kByte);
    AllClose(at::ones({}, at::TensorOptions(at::kByte)), res);
  });
}

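// CreateDataPlaceholder hands back a handle carrying device and shape
// metadata; both fields are verified on the unwrapped XLA-side data.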
TEST(XLABackendTest, TestPlaceholder) {
  torch::lazy::BackendImplInterface* impl = GetXlaBackendImpl();
  torch::lazy::Shape shape(at::kFloat, {10, 10});
  ForEachDevice([&](const torch::lazy::BackendDevice& device) {
    torch::lazy::BackendDataPtr data =
        impl->CreateDataPlaceholder(device, shape);
    xla::ComputationClient::DataPtr computation_data = UnwrapXlaData(data);
    EXPECT_EQ(computation_data->device(), device.toString());
    EXPECT_EQ(computation_data->shape(),
              MakeXlaShapeFromLazyShape(shape, device));
  });
}

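// Builds the XLA computation f(x, y) = x + y over two parameters of the
// given shape; ConsumeValue unwraps the StatusOr returned by Build().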
xla::XlaComputation CreateAddComputation(const xla::Shape& shape) {
  xla::XlaBuilder builder("AddComputation");
  xla::XlaOp x = xla::Parameter(&builder, 0, shape, "x");
  xla::XlaOp y = xla::Parameter(&builder, 1, shape, "y");
  xla::XlaOp sum = xla::Add(x, y);
  return ConsumeValue(builder.Build());
}

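// End to end: compile the add computation, upload two all-ones parameter
// tensors, execute on the device, and expect a tensor of twos back.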
TEST(XLABackendTest, TestE2E) {
  torch::lazy::BackendImplInterface* impl = GetXlaBackendImpl();
  xla::Shape input_shape =
      xla::ShapeUtil::MakeShape(xla::PrimitiveType::F32, {8, 8});
  at::Tensor one = at::ones({8, 8}, at::TensorOptions(at::kFloat));
  std::vector<at::Tensor> tensors = {one, one};

  ForEachDevice([&](const torch::lazy::BackendDevice& device) {
    xla::XlaComputation xla_computation = CreateAddComputation(input_shape);
    torch::lazy::ComputationPtr computation =
        std::make_shared<torch_xla::Computation>(
            "test", std::move(xla_computation), device);
    std::vector<torch::lazy::ComputationPtr> compiled_programs =
        impl->Compile({computation});
    EXPECT_EQ(compiled_programs.size(), 1);

    std::vector<torch::lazy::BackendDataPtr> parameters;
    for (auto& tensor : tensors) {
      parameters.push_back(impl->MakeComputationDataFromTensor(
          tensor, torch::lazy::Shape(tensor.scalar_type(), tensor.sizes()),
          device));
    }
    std::vector<torch::lazy::BackendDataPtr> res =
        impl->ExecuteComputation(compiled_programs[0], parameters, device);
    EXPECT_EQ(res.size(), 1);
    at::Tensor res_tensor =
        impl->MakeTensorFromComputationData(res[0], at::kFloat);
    AllClose(one + one, res_tensor);
  });
}

} // namespace cpp_test
} // namespace torch_xla
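The "initXlaBackend in cppTest" bullet refers to wiring InitXlaBackend() into the C++ test harness; that wiring is not visible in the hunks shown here. A minimal sketch of one way to do it, using a gtest global environment (the class name and registration site are assumptions, not code from this commit):

#include <gtest/gtest.h>

#include "torch_xla/csrc/xla_backend_impl.h"

namespace {

// Runs once before the first test so the XLA backend is registered with
// torch::lazy before any XLABackendTest case executes.
class XlaBackendEnvironment : public ::testing::Environment {
 public:
  void SetUp() override { torch_xla::InitXlaBackend(); }
};

}  // namespace

// In main(), before RUN_ALL_TESTS():
//   ::testing::AddGlobalTestEnvironment(new XlaBackendEnvironment());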
10 changes: 6 additions & 4 deletions torch_xla/csrc/xla_backend_impl.cpp
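Two changes land in this file: GetLoweringContext's post_order overload now fails loudly through XLA_ERROR instead of silently dropping post_order, and Compile() collects each result shape into an output_shapes vector that outlives the per-instance loop, because every CompileInstance keeps a raw pointer to its shape that is still dereferenced when the client Compile() call below runs.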
@@ -105,6 +105,7 @@ class XlaBackendImpl : public torch::lazy::BackendImplInterface {
    // c10::ArrayRef<torch::lazy::Node*> instead of
    // c10::ArrayRef<const torch::lazy::Node*> since c10::ArrayRef already
    // provided const for its member.
+   XLA_ERROR() << "Need to handle post_order";
    return std::make_unique<LoweringContext>(name, device);
  }

@@ -131,15 +132,16 @@ class XlaBackendImpl : public torch::lazy::BackendImplInterface {
    std::vector<torch::lazy::ComputationPtr> res;
    std::vector<xla::ComputationClient::CompileInstance> compile_instances;
    torch::lazy::BackendDevice current_device = GetCurrentDevice();
+   std::vector<xla::Shape> output_shapes;

    for (const torch::lazy::ComputationPtr instance : instances) {
      // TODO(JackCaoG): device is missing in instance, use CurrentDevice for
      // now
      const Computation* torch_xla_computation =
          dynamic_cast<Computation*>(instance.get());
-     xla::Shape shape = MakeShapeWithDeviceLayout(
+     output_shapes.push_back(MakeShapeWithDeviceLayout(
          torch_xla_computation->program_shape().result(),
-         static_cast<XlaDeviceType>(current_device.type()));
+         static_cast<XlaDeviceType>(current_device.type())));

      // Call GetCompilationDevices and passes all device here if needed.
      // Currently on TPU we always have 1 replica per device and one process
@@ -152,7 +154,7 @@ class XlaBackendImpl : public torch::lazy::BackendImplInterface {
      compile_instances.push_back(xla::ComputationClient::CompileInstance(
          torch_xla_computation->move_computation(),
          torch_xla_computation->get_device_string(),
-         {current_device.toString()}, &shape));
+         {current_device.toString()}, &output_shapes.back()));
    }
    std::vector<std::shared_ptr<xla::ComputationClient::Computation>>
        client_computations = xla::ComputationClient::Get()->Compile(
@@ -238,4 +240,4 @@ void InitXlaBackend() {
      std::make_unique<torch::lazy::BackendRegistrar>(GetXlaBackendImpl());
};

-}  // namespace torch_xla
\ No newline at end of file
+}  // namespace torch_xla
4 changes: 2 additions & 2 deletions torch_xla/csrc/xla_backend_impl.h
@@ -39,8 +39,8 @@ class XLAData : public torch::lazy::BackendData {
  xla::ComputationClient::DataPtr xla_data_;
};

-// torch::lazy::BackendImplInterface* GetXlaBackendImpl();
+torch::lazy::BackendImplInterface* GetXlaBackendImpl();

void InitXlaBackend();

-}  // namespace torch_xla
\ No newline at end of file
+}  // namespace torch_xla
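Un-commenting the GetXlaBackendImpl() declaration is what lets the new C++ test fetch the backend implementation directly.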
