From 05fa2aa6ada872ab980e3fcbcac53644e4b09cf8 Mon Sep 17 00:00:00 2001 From: Wonjoo Lee Date: Wed, 22 Jun 2022 17:04:05 -0700 Subject: [PATCH] Full codegen erf, erfc, erfinv, and exp (#3659) --- torch_xla/csrc/aten_xla_type.cpp | 21 --------------------- torch_xla/csrc/ops/ops_lower_fn.cpp | 20 ++++++++++++++++++++ torch_xla/csrc/ops/ops_xla_shape_fn.cpp | 16 ++++++++++++++++ torch_xla/csrc/ops/ops_xla_shape_fn.h | 8 ++++++++ torch_xla/csrc/tensor.h | 6 ------ torch_xla/csrc/tensor_methods.cpp | 12 ------------ xla_native_functions.yaml | 8 ++++---- 7 files changed, 48 insertions(+), 43 deletions(-) diff --git a/torch_xla/csrc/aten_xla_type.cpp b/torch_xla/csrc/aten_xla_type.cpp index abce622af48..d1b6fb048b1 100644 --- a/torch_xla/csrc/aten_xla_type.cpp +++ b/torch_xla/csrc/aten_xla_type.cpp @@ -1322,27 +1322,6 @@ at::Tensor XLANativeFunctions::eq(const at::Tensor& self, XLATensor::eq(bridge::GetXlaTensor(self), bridge::GetXlaTensor(other))); } -at::Tensor XLANativeFunctions::erf(const at::Tensor& self) { - XLA_FN_COUNTER("xla::"); - return bridge::AtenFromXlaTensor(XLATensor::erf(bridge::GetXlaTensor(self))); -} - -at::Tensor XLANativeFunctions::erfc(const at::Tensor& self) { - XLA_FN_COUNTER("xla::"); - return bridge::AtenFromXlaTensor(XLATensor::erfc(bridge::GetXlaTensor(self))); -} - -at::Tensor XLANativeFunctions::erfinv(const at::Tensor& self) { - XLA_FN_COUNTER("xla::"); - return bridge::AtenFromXlaTensor( - XLATensor::erfinv(bridge::GetXlaTensor(self))); -} - -at::Tensor XLANativeFunctions::exp(const at::Tensor& self) { - XLA_FN_COUNTER("xla::"); - return bridge::AtenFromXlaTensor(XLATensor::exp(bridge::GetXlaTensor(self))); -} - at::Tensor XLANativeFunctions::expand(const at::Tensor& self, at::IntArrayRef size, bool implicit) { XLA_FN_COUNTER("xla::"); diff --git a/torch_xla/csrc/ops/ops_lower_fn.cpp b/torch_xla/csrc/ops/ops_lower_fn.cpp index 0793d629084..ffeed40049f 100644 --- a/torch_xla/csrc/ops/ops_lower_fn.cpp +++ b/torch_xla/csrc/ops/ops_lower_fn.cpp @@ -53,6 +53,26 @@ torch_xla::XlaOpVector Cosh::Lower(LoweringContext* loctx) const { return ReturnOp(xla::Cosh(xla_input), loctx); } +torch_xla::XlaOpVector Erf::Lower(LoweringContext* loctx) const { + xla::XlaOp xla_input = loctx->GetOutputOp(operand(0)); + return ReturnOp(xla::Erf(xla_input), loctx); +} + +torch_xla::XlaOpVector Erfc::Lower(LoweringContext* loctx) const { + xla::XlaOp xla_input = loctx->GetOutputOp(operand(0)); + return ReturnOp(xla::Erfc(xla_input), loctx); +} + +torch_xla::XlaOpVector Erfinv::Lower(LoweringContext* loctx) const { + xla::XlaOp xla_input = loctx->GetOutputOp(operand(0)); + return ReturnOp(xla::ErfInv(xla_input), loctx); +} + +torch_xla::XlaOpVector Exp::Lower(LoweringContext* loctx) const { + xla::XlaOp xla_input = loctx->GetOutputOp(operand(0)); + return ReturnOp(xla::Exp(xla_input), loctx); +} + torch_xla::XlaOpVector Floor::Lower(LoweringContext* loctx) const { xla::XlaOp xla_input = loctx->GetOutputOp(operand(0)); return ReturnOp(xla::Floor(xla_input), loctx); diff --git a/torch_xla/csrc/ops/ops_xla_shape_fn.cpp b/torch_xla/csrc/ops/ops_xla_shape_fn.cpp index ff4703c66f7..bdd74a9112c 100644 --- a/torch_xla/csrc/ops/ops_xla_shape_fn.cpp +++ b/torch_xla/csrc/ops/ops_xla_shape_fn.cpp @@ -41,6 +41,22 @@ xla::Shape CoshOutputShape(const torch::lazy::Value& input) { return GetXlaShape(input); } +xla::Shape ErfOutputShape(const torch::lazy::Value& input) { + return GetXlaShape(input); +} + +xla::Shape ErfcOutputShape(const torch::lazy::Value& input) { + return GetXlaShape(input); +} + +xla::Shape ErfinvOutputShape(const torch::lazy::Value& input) { + return GetXlaShape(input); +} + +xla::Shape ExpOutputShape(const torch::lazy::Value& input) { + return GetXlaShape(input); +} + xla::Shape FloorOutputShape(const torch::lazy::Value& input) { return GetXlaShape(input); } diff --git a/torch_xla/csrc/ops/ops_xla_shape_fn.h b/torch_xla/csrc/ops/ops_xla_shape_fn.h index 98ee86a2361..679875404d4 100644 --- a/torch_xla/csrc/ops/ops_xla_shape_fn.h +++ b/torch_xla/csrc/ops/ops_xla_shape_fn.h @@ -21,6 +21,14 @@ xla::Shape CosOutputShape(const torch::lazy::Value& input); xla::Shape CoshOutputShape(const torch::lazy::Value& input); +xla::Shape ErfOutputShape(const torch::lazy::Value& input); + +xla::Shape ErfcOutputShape(const torch::lazy::Value& input); + +xla::Shape ErfinvOutputShape(const torch::lazy::Value& input); + +xla::Shape ExpOutputShape(const torch::lazy::Value& input); + xla::Shape FloorOutputShape(const torch::lazy::Value& input); xla::Shape InverseOutputShape(const torch::lazy::Value& input); diff --git a/torch_xla/csrc/tensor.h b/torch_xla/csrc/tensor.h index 26496d7d649..b206a40b3de 100644 --- a/torch_xla/csrc/tensor.h +++ b/torch_xla/csrc/tensor.h @@ -555,12 +555,6 @@ class XLATensor : public c10::intrusive_ptr_target { static XLATensor eq(const XLATensor& input, const XLATensor& other); - static XLATensor erf(const XLATensor& input); - - static XLATensor erfc(const XLATensor& input); - - static XLATensor erfinv(const XLATensor& input); - static XLATensor exp(const XLATensor& input); static XLATensor expand(const XLATensor& input, std::vector size); diff --git a/torch_xla/csrc/tensor_methods.cpp b/torch_xla/csrc/tensor_methods.cpp index 36be8cd3dee..401d263a7ba 100644 --- a/torch_xla/csrc/tensor_methods.cpp +++ b/torch_xla/csrc/tensor_methods.cpp @@ -1271,18 +1271,6 @@ XLATensor XLATensor::embedding_dense_backward(const XLATensor& grad_output, padding_idx, scale_grad_by_freq); } -XLATensor XLATensor::erf(const XLATensor& input) { - return input.CreateFrom(Erf(input.GetIrValue())); -} - -XLATensor XLATensor::erfc(const XLATensor& input) { - return input.CreateFrom(Erfc(input.GetIrValue())); -} - -XLATensor XLATensor::erfinv(const XLATensor& input) { - return input.CreateFrom(Erfinv(input.GetIrValue())); -} - XLATensor XLATensor::exp(const XLATensor& input) { return input.CreateFrom(Exp(input.GetIrValue())); } diff --git a/xla_native_functions.yaml b/xla_native_functions.yaml index cbdeec34c0b..4aa60658a01 100644 --- a/xla_native_functions.yaml +++ b/xla_native_functions.yaml @@ -10,6 +10,10 @@ full_codegen: - atanh - cos - cosh + - erf + - erfc + - erfinv + - exp - floor - inverse - logdet @@ -121,10 +125,6 @@ supported: - empty_strided - eq.Scalar - eq.Tensor - - erf - - erfc - - erfinv - - exp - expand - expm1 - exponential_