Codegen for Tanh (#3724)
steventk-g authored Jul 19, 2022
1 parent 75ac08b commit 1262dd4
Showing 9 changed files with 21 additions and 20 deletions.
5 changes: 0 additions & 5 deletions torch_xla/csrc/aten_xla_type.cpp
@@ -1533,11 +1533,6 @@ at::Tensor XLANativeFunctions::hardtanh(const at::Tensor& self,
       XLATensor::clamp(bridge::GetXlaTensor(self), min_val, max_val));
 }
 
-at::Tensor XLANativeFunctions::tanh(const at::Tensor& self) {
-  XLA_FN_COUNTER("xla::");
-  return bridge::AtenFromXlaTensor(XLATensor::tanh(bridge::GetXlaTensor(self)));
-}
-
 at::Tensor XLANativeFunctions::hardtanh_backward(const at::Tensor& grad_output,
                                                  const at::Tensor& self,
                                                  const at::Scalar& min_val,
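This binding is deleted rather than rewritten by hand because the full_codegen pipeline (see xla_native_functions.yaml below) now emits an equivalent entry point. As a rough, hedged sketch only — the real generated file is longer and its helper names may differ — the generated binding does something like:

    // Hedged sketch of the codegen'd binding, not the literal generated code.
    at::Tensor XLANativeFunctions::tanh(const at::Tensor& self) {
      XLA_FN_COUNTER("xla::");
      XLATensorPtr xla_self = bridge::GetXlaTensor(self);
      // Build the IR node directly instead of calling a hand-written
      // XLATensor::tanh method (also deleted in this commit).
      torch::lazy::NodePtr node = torch::lazy::MakeNode<Tanh>(
          xla_self->GetIrValue(), std::vector<torch::lazy::Shape>());
      return bridge::AtenFromXlaTensor(
          xla_self->CreateFrom(torch::lazy::Value(node)));
    }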
8 changes: 5 additions & 3 deletions torch_xla/csrc/ops/ops.cpp
@@ -68,7 +68,6 @@ namespace torch_xla {
                      std::move(lower_fn)); \
   }
 
-PTXLA_UNARY_OP(Tanh, at::aten::tanh, xla::Tanh);
 PTXLA_UNARY_OP(Neg, at::aten::neg, xla::Neg);
 PTXLA_UNARY_OP(Exp, at::aten::exp, xla::Exp);
 PTXLA_UNARY_OP(Expm1, at::aten::expm1, xla::Expm1);
@@ -866,7 +865,9 @@ torch::lazy::NodePtr TanhGelu(const torch::lazy::Value& input) {
   torch::lazy::NodePtr one = ScalarOp(1, shape);
   torch::lazy::NodePtr half = ScalarOp(0.5, shape);
   torch::lazy::NodePtr inner = beta * (input + kappa * Pow(input, three));
-  return half * input * (one + Tanh(inner));
+  return half * input *
+         (one + torch::lazy::MakeNode<Tanh>(inner,
+                                            std::vector<torch::lazy::Shape>()));
 }
 
 torch::lazy::NodePtr TanhGeluBackward(const torch::lazy::Value& grad,
@@ -882,7 +883,8 @@ torch::lazy::NodePtr TanhGeluBackward(const torch::lazy::Value& grad,
   torch::lazy::NodePtr three = ScalarOp(3, shape);
   torch::lazy::NodePtr half = ScalarOp(0.5, shape);
   torch::lazy::NodePtr inner = beta * (input + kappa * Pow(input, three));
-  torch::lazy::NodePtr tanh_inner = Tanh(inner);
+  torch::lazy::NodePtr tanh_inner =
+      torch::lazy::MakeNode<Tanh>(inner, std::vector<torch::lazy::Shape>());
 
   torch::lazy::NodePtr left = half * input;
   torch::lazy::NodePtr right = one + tanh_inner;
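For reference, TanhGelu implements the tanh approximation of GELU. Assuming the standard constants defined just above the hunks shown, \beta = \sqrt{2/\pi} and \kappa = 0.044715, the forward and backward functions build

    \mathrm{TanhGelu}(x) = \tfrac{1}{2}\,x\,\bigl(1 + \tanh(u)\bigr), \qquad u = \beta\,(x + \kappa x^{3})

    \frac{d}{dx}\,\mathrm{TanhGelu}(x) = \tfrac{1}{2}\bigl(1 + \tanh(u)\bigr) + \tfrac{1}{2}\,x\,\bigl(1 - \tanh^{2}(u)\bigr)\,\beta\,(1 + 3\kappa x^{2})

using the identity \tanh'(u) = 1 - \tanh^{2}(u). Only the construction of the tanh(u) node changes in this commit; the surrounding algebra is untouched.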
2 changes: 0 additions & 2 deletions torch_xla/csrc/ops/ops.h
@@ -68,8 +68,6 @@ torch::lazy::NodePtr Atan2(const torch::lazy::Value& input,
 
 torch::lazy::NodePtr Tan(const torch::lazy::Value& input);
 
-torch::lazy::NodePtr Tanh(const torch::lazy::Value& input);
-
 torch::lazy::NodePtr Neg(const torch::lazy::Value& input);
 
 torch::lazy::NodePtr SgnOp(const torch::lazy::Value& input);
5 changes: 5 additions & 0 deletions torch_xla/csrc/ops/ops_lower_fn.cpp
@@ -144,4 +144,9 @@ torch_xla::XlaOpVector Tan::Lower(LoweringContext* loctx) const {
   return ReturnOp(xla::Tan(xla_input), loctx);
 }
 
+torch_xla::XlaOpVector Tanh::Lower(LoweringContext* loctx) const {
+  xla::XlaOp xla_input = loctx->GetOutputOp(operand(0));
+  return ReturnOp(xla::Tanh(xla_input), loctx);
+}
+
 }  // namespace torch_xla
4 changes: 4 additions & 0 deletions torch_xla/csrc/ops/ops_xla_shape_fn.cpp
@@ -135,4 +135,8 @@ xla::Shape TanOutputShape(const torch::lazy::Value& input) {
   return GetXlaShape(input);
 }
 
+xla::Shape TanhOutputShape(const torch::lazy::Value& input) {
+  return GetXlaShape(input);
+}
+
 }  // namespace torch_xla
2 changes: 2 additions & 0 deletions torch_xla/csrc/ops/ops_xla_shape_fn.h
@@ -58,4 +58,6 @@ xla::Shape SinhOutputShape(const torch::lazy::Value& input);
 
 xla::Shape TanOutputShape(const torch::lazy::Value& input);
 
+xla::Shape TanhOutputShape(const torch::lazy::Value& input);
+
 }  // namespace torch_xla
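Together with Tanh::Lower above, TanhOutputShape is the only other hand-written piece; the IR node class itself is emitted by the codegen. A hedged sketch of roughly what that generated class looks like, inferred from the MakeNode<Tanh>(value, std::vector<torch::lazy::Shape>()) call sites elsewhere in this commit (the real class lives in a generated header and its exact constructor signature is an assumption):

    // Hedged sketch of the generated IR node, not the literal generated code.
    class Tanh : public XlaNode {
     public:
      Tanh(const torch::lazy::Value& self,
           std::vector<torch::lazy::Shape>&& shapes)
          : XlaNode(torch::lazy::OpKind(at::aten::tanh), {self},
                    std::move(shapes),
                    // Shape inference delegates to the hand-written function
                    // added in ops_xla_shape_fn.cpp above.
                    [&]() { return TanhOutputShape(self); },
                    /*num_outputs=*/1) {}

      // Declared here; defined by hand in ops_lower_fn.cpp above.
      torch_xla::XlaOpVector Lower(LoweringContext* loctx) const override;
    };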
2 changes: 0 additions & 2 deletions torch_xla/csrc/tensor.h
@@ -1167,8 +1167,6 @@ class XLATensor : public c10::intrusive_ptr_target {
   static XLATensorPtr take(const XLATensorPtr& input,
                            const XLATensorPtr& index);
 
-  static XLATensorPtr tanh(const XLATensorPtr& input);
-
   static XLATensorPtr tanh_backward(const XLATensorPtr& grad_output,
                                     const XLATensorPtr& output);
 
11 changes: 4 additions & 7 deletions torch_xla/csrc/tensor_methods.cpp
@@ -1924,9 +1924,10 @@ void XLATensor::min_out(XLATensorPtr& min, XLATensorPtr& min_indices,
 }
 
 XLATensorPtr XLATensor::mish(const XLATensorPtr& input) {
-  return input->CreateFrom(
-      input->GetIrValue() *
-      Tanh(tensor_ops::Softplus(input, 1, 20)->GetIrValue()));
+  return input->CreateFrom(input->GetIrValue() *
+                           torch::lazy::MakeNode<Tanh>(
+                               tensor_ops::Softplus(input, 1, 20)->GetIrValue(),
+                               std::vector<torch::lazy::Shape>()));
 }
 
 XLATensorPtr XLATensor::mm(const XLATensorPtr& input,
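For context on the rewritten call: mish composes the same codegen'd Tanh node with a softplus,

    \mathrm{mish}(x) = x \cdot \tanh\bigl(\mathrm{softplus}(x)\bigr), \qquad \mathrm{softplus}(x) = \ln\bigl(1 + e^{x}\bigr),

where the (input, 1, 20) arguments are softplus's beta and the threshold above which softplus is treated as linear, matching PyTorch's Softplus semantics.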
@@ -2772,10 +2773,6 @@ XLATensorPtr XLATensor::take(const XLATensorPtr& input,
   return input->CreateFrom(Take(input->GetIrValue(), index->GetIrValue()));
 }
 
-XLATensorPtr XLATensor::tanh(const XLATensorPtr& input) {
-  return input->CreateFrom(Tanh(input->GetIrValue()));
-}
-
 XLATensorPtr XLATensor::tanh_backward(const XLATensorPtr& grad_output,
                                       const XLATensorPtr& output) {
   return XLATensor::mul(grad_output,
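tanh_backward stays hand-written (its body is truncated in this view); it relies only on the saved forward output y and the identity

    \frac{d}{dx}\tanh(x) = 1 - \tanh^{2}(x), \qquad \text{so} \qquad \mathtt{grad\_input} = \mathtt{grad\_output} \cdot (1 - y^{2}), \quad y = \tanh(x).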
2 changes: 1 addition & 1 deletion xla_native_functions.yaml
@@ -26,6 +26,7 @@ full_codegen:
   - sin
   - sinh
   - tan
+  - tanh
 supported:
   - __ilshift__.Scalar
   - __ilshift__.Tensor
@@ -302,7 +303,6 @@ supported:
   - t
   - t_
   - take
-  - tanh
   - tanh_backward
   - threshold
   - threshold_backward
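This YAML move ties the commit together: ops listed under supported: require the hand-written plumbing deleted above (the ATen binding in aten_xla_type.cpp, the XLATensor method in tensor.h/tensor_methods.cpp, and the IR helper in ops.h/ops.cpp), while ops listed under full_codegen: have all of that generated, leaving only Tanh::Lower and TanhOutputShape to write by hand. tanh_backward remains under supported: and is unchanged.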
