diff --git a/torch_xla/csrc/aten_xla_type.cpp b/torch_xla/csrc/aten_xla_type.cpp index 8e93098abe03..7c44d6317b0d 100644 --- a/torch_xla/csrc/aten_xla_type.cpp +++ b/torch_xla/csrc/aten_xla_type.cpp @@ -1533,11 +1533,6 @@ at::Tensor XLANativeFunctions::hardtanh(const at::Tensor& self, XLATensor::clamp(bridge::GetXlaTensor(self), min_val, max_val)); } -at::Tensor XLANativeFunctions::tanh(const at::Tensor& self) { - XLA_FN_COUNTER("xla::"); - return bridge::AtenFromXlaTensor(XLATensor::tanh(bridge::GetXlaTensor(self))); -} - at::Tensor XLANativeFunctions::hardtanh_backward(const at::Tensor& grad_output, const at::Tensor& self, const at::Scalar& min_val, diff --git a/torch_xla/csrc/ops/ops.cpp b/torch_xla/csrc/ops/ops.cpp index 06d614f8eef1..790cdf62343b 100644 --- a/torch_xla/csrc/ops/ops.cpp +++ b/torch_xla/csrc/ops/ops.cpp @@ -68,7 +68,6 @@ namespace torch_xla { std::move(lower_fn)); \ } -PTXLA_UNARY_OP(Tanh, at::aten::tanh, xla::Tanh); PTXLA_UNARY_OP(Neg, at::aten::neg, xla::Neg); PTXLA_UNARY_OP(Exp, at::aten::exp, xla::Exp); PTXLA_UNARY_OP(Expm1, at::aten::expm1, xla::Expm1); @@ -866,7 +865,9 @@ torch::lazy::NodePtr TanhGelu(const torch::lazy::Value& input) { torch::lazy::NodePtr one = ScalarOp(1, shape); torch::lazy::NodePtr half = ScalarOp(0.5, shape); torch::lazy::NodePtr inner = beta * (input + kappa * Pow(input, three)); - return half * input * (one + Tanh(inner)); + return half * input * + (one + torch::lazy::MakeNode<Tanh>(inner, + std::vector<torch::lazy::Shape>())); } torch::lazy::NodePtr TanhGeluBackward(const torch::lazy::Value& grad, @@ -882,7 +883,8 @@ torch::lazy::NodePtr TanhGeluBackward(const torch::lazy::Value& grad, torch::lazy::NodePtr three = ScalarOp(3, shape); torch::lazy::NodePtr half = ScalarOp(0.5, shape); torch::lazy::NodePtr inner = beta * (input + kappa * Pow(input, three)); - torch::lazy::NodePtr tanh_inner = Tanh(inner); + torch::lazy::NodePtr tanh_inner = + torch::lazy::MakeNode<Tanh>(inner, std::vector<torch::lazy::Shape>()); torch::lazy::NodePtr left = half * input; 
torch::lazy::NodePtr right = one + tanh_inner; diff --git a/torch_xla/csrc/ops/ops.h b/torch_xla/csrc/ops/ops.h index 703c656bcd23..8e12578ae8fb 100644 --- a/torch_xla/csrc/ops/ops.h +++ b/torch_xla/csrc/ops/ops.h @@ -68,8 +68,6 @@ torch::lazy::NodePtr Atan2(const torch::lazy::Value& input, torch::lazy::NodePtr Tan(const torch::lazy::Value& input); -torch::lazy::NodePtr Tanh(const torch::lazy::Value& input); - torch::lazy::NodePtr Neg(const torch::lazy::Value& input); torch::lazy::NodePtr SgnOp(const torch::lazy::Value& input); diff --git a/torch_xla/csrc/ops/ops_lower_fn.cpp b/torch_xla/csrc/ops/ops_lower_fn.cpp index acd7badbe3ea..88874c547e4d 100644 --- a/torch_xla/csrc/ops/ops_lower_fn.cpp +++ b/torch_xla/csrc/ops/ops_lower_fn.cpp @@ -144,4 +144,9 @@ torch_xla::XlaOpVector Tan::Lower(LoweringContext* loctx) const { return ReturnOp(xla::Tan(xla_input), loctx); } +torch_xla::XlaOpVector Tanh::Lower(LoweringContext* loctx) const { + xla::XlaOp xla_input = loctx->GetOutputOp(operand(0)); + return ReturnOp(xla::Tanh(xla_input), loctx); +} + } // namespace torch_xla diff --git a/torch_xla/csrc/ops/ops_xla_shape_fn.cpp b/torch_xla/csrc/ops/ops_xla_shape_fn.cpp index 027d559d3ea5..49c140567094 100644 --- a/torch_xla/csrc/ops/ops_xla_shape_fn.cpp +++ b/torch_xla/csrc/ops/ops_xla_shape_fn.cpp @@ -135,4 +135,8 @@ xla::Shape TanOutputShape(const torch::lazy::Value& input) { return GetXlaShape(input); } +xla::Shape TanhOutputShape(const torch::lazy::Value& input) { + return GetXlaShape(input); +} + } // namespace torch_xla diff --git a/torch_xla/csrc/ops/ops_xla_shape_fn.h b/torch_xla/csrc/ops/ops_xla_shape_fn.h index 53e495c1b16d..1a9f1c285652 100644 --- a/torch_xla/csrc/ops/ops_xla_shape_fn.h +++ b/torch_xla/csrc/ops/ops_xla_shape_fn.h @@ -58,4 +58,6 @@ xla::Shape SinhOutputShape(const torch::lazy::Value& input); xla::Shape TanOutputShape(const torch::lazy::Value& input); +xla::Shape TanhOutputShape(const torch::lazy::Value& input); + } // namespace torch_xla diff --git 
a/torch_xla/csrc/tensor.h b/torch_xla/csrc/tensor.h index 014540f95545..d1f4727e7a0b 100644 --- a/torch_xla/csrc/tensor.h +++ b/torch_xla/csrc/tensor.h @@ -1167,8 +1167,6 @@ class XLATensor : public c10::intrusive_ptr_target { static XLATensorPtr take(const XLATensorPtr& input, const XLATensorPtr& index); - static XLATensorPtr tanh(const XLATensorPtr& input); - static XLATensorPtr tanh_backward(const XLATensorPtr& grad_output, const XLATensorPtr& output); diff --git a/torch_xla/csrc/tensor_methods.cpp b/torch_xla/csrc/tensor_methods.cpp index 5fa56cd14e9d..c08c1fb2fdf5 100644 --- a/torch_xla/csrc/tensor_methods.cpp +++ b/torch_xla/csrc/tensor_methods.cpp @@ -1924,9 +1924,10 @@ void XLATensor::min_out(XLATensorPtr& min, XLATensorPtr& min_indices, } XLATensorPtr XLATensor::mish(const XLATensorPtr& input) { - return input->CreateFrom( - input->GetIrValue() * - Tanh(tensor_ops::Softplus(input, 1, 20)->GetIrValue())); + return input->CreateFrom(input->GetIrValue() * + torch::lazy::MakeNode<Tanh>( + tensor_ops::Softplus(input, 1, 20)->GetIrValue(), + std::vector<torch::lazy::Shape>())); } XLATensorPtr XLATensor::mm(const XLATensorPtr& input, @@ -2772,10 +2773,6 @@ XLATensorPtr XLATensor::take(const XLATensorPtr& input, return input->CreateFrom(Take(input->GetIrValue(), index->GetIrValue())); } -XLATensorPtr XLATensor::tanh(const XLATensorPtr& input) { - return input->CreateFrom(Tanh(input->GetIrValue())); -} - XLATensorPtr XLATensor::tanh_backward(const XLATensorPtr& grad_output, const XLATensorPtr& output) { return XLATensor::mul(grad_output, diff --git a/xla_native_functions.yaml b/xla_native_functions.yaml index a18faee015da..6bc2b4da29d5 100644 --- a/xla_native_functions.yaml +++ b/xla_native_functions.yaml @@ -26,6 +26,7 @@ full_codegen: - sin - sinh - tan + - tanh supported: - __ilshift__.Scalar - __ilshift__.Tensor @@ -302,7 +303,6 @@ supported: - t - t_ - take - - tanh - tanh_backward - threshold - threshold_backward