Codegen for Tanh #3724

Merged: 1 commit, merged on Jul 19, 2022
5 changes: 0 additions & 5 deletions torch_xla/csrc/aten_xla_type.cpp

@@ -1533,11 +1533,6 @@ at::Tensor XLANativeFunctions::hardtanh(const at::Tensor& self,
       XLATensor::clamp(bridge::GetXlaTensor(self), min_val, max_val));
 }
 
-at::Tensor XLANativeFunctions::tanh(const at::Tensor& self) {
-  XLA_FN_COUNTER("xla::");
-  return bridge::AtenFromXlaTensor(XLATensor::tanh(bridge::GetXlaTensor(self)));
-}
-
 at::Tensor XLANativeFunctions::hardtanh_backward(const at::Tensor& grad_output,
                                                  const at::Tensor& self,
                                                  const at::Scalar& min_val,
8 changes: 5 additions & 3 deletions torch_xla/csrc/ops/ops.cpp

@@ -68,7 +68,6 @@ namespace torch_xla {
       std::move(lower_fn));                                   \
 }
 
-PTXLA_UNARY_OP(Tanh, at::aten::tanh, xla::Tanh);
 PTXLA_UNARY_OP(Neg, at::aten::neg, xla::Neg);
 PTXLA_UNARY_OP(Exp, at::aten::exp, xla::Exp);
 PTXLA_UNARY_OP(Expm1, at::aten::expm1, xla::Expm1);

@@ -867,7 +866,9 @@ torch::lazy::NodePtr TanhGelu(const torch::lazy::Value& input) {
   torch::lazy::NodePtr one = ScalarOp(1, shape);
   torch::lazy::NodePtr half = ScalarOp(0.5, shape);
   torch::lazy::NodePtr inner = beta * (input + kappa * Pow(input, three));
-  return half * input * (one + Tanh(inner));
+  return half * input *
+         (one + torch::lazy::MakeNode<Tanh>(inner,
+                                            std::vector<torch::lazy::Shape>()));
 }
 
 torch::lazy::NodePtr TanhGeluBackward(const torch::lazy::Value& grad,
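For reference, TanhGelu builds the tanh approximation of GELU. With beta and kappa as defined earlier in the function (outside this hunk), the node constructed here computes

    TanhGelu(x) = 0.5 * x * (1 + tanh(beta * (x + kappa * x^3)))

The change itself is mechanical: Tanh is no longer a free helper returning a NodePtr, so call sites now build the codegen'd IR node directly via torch::lazy::MakeNode<Tanh>.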
@@ -883,7 +884,8 @@ torch::lazy::NodePtr TanhGeluBackward(const torch::lazy::Value& grad,
   torch::lazy::NodePtr three = ScalarOp(3, shape);
   torch::lazy::NodePtr half = ScalarOp(0.5, shape);
   torch::lazy::NodePtr inner = beta * (input + kappa * Pow(input, three));
-  torch::lazy::NodePtr tanh_inner = Tanh(inner);
+  torch::lazy::NodePtr tanh_inner =
+      torch::lazy::MakeNode<Tanh>(inner, std::vector<torch::lazy::Shape>());
 
   torch::lazy::NodePtr left = half * input;
   torch::lazy::NodePtr right = one + tanh_inner;
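TanhGeluBackward differentiates the same expression by the product rule. With u(x) = beta * (x + kappa * x^3),

    d/dx [0.5 * x * (1 + tanh(u))] = 0.5 * (1 + tanh(u)) + 0.5 * x * (1 - tanh(u)^2) * u'(x)

which matches the factors kept as separate nodes here: left = 0.5 * x and right = 1 + tanh(u); the combination with the incoming grad happens below this hunk.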
2 changes: 0 additions & 2 deletions torch_xla/csrc/ops/ops.h

@@ -68,8 +68,6 @@ torch::lazy::NodePtr Atan2(const torch::lazy::Value& input,
 
 torch::lazy::NodePtr Tan(const torch::lazy::Value& input);
 
-torch::lazy::NodePtr Tanh(const torch::lazy::Value& input);
-
 torch::lazy::NodePtr Neg(const torch::lazy::Value& input);
 
 torch::lazy::NodePtr SgnOp(const torch::lazy::Value& input);
5 changes: 5 additions & 0 deletions torch_xla/csrc/ops/ops_lower_fn.cpp

@@ -139,4 +139,9 @@ torch_xla::XlaOpVector Tan::Lower(LoweringContext* loctx) const {
   return ReturnOp(xla::Tan(xla_input), loctx);
 }
 
+torch_xla::XlaOpVector Tanh::Lower(LoweringContext* loctx) const {
+  xla::XlaOp xla_input = loctx->GetOutputOp(operand(0));
+  return ReturnOp(xla::Tanh(xla_input), loctx);
+}
+
 }  // namespace torch_xla
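Tanh::Lower above is one of only two hand-written pieces the codegen path still requires (the other is TanhOutputShape below); the IR node class itself is generated. Judging from the MakeNode<Tanh>(value, std::vector<torch::lazy::Shape>()) call sites earlier in the diff, the generated declaration looks roughly like the sketch below; this is an assumption about the generated code, which may differ in detail:

    // Sketch only: the real class is emitted by the codegen, not hand-written.
    class Tanh : public XlaNode {
     public:
      Tanh(const torch::lazy::Value& self,
           std::vector<torch::lazy::Shape>&& shapes);

      torch_xla::XlaOpVector Lower(LoweringContext* loctx) const override;
    };

Passing an empty shape vector at the call sites presumably leaves the output shape to be computed by the node itself, which is where TanhOutputShape comes in.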
4 changes: 4 additions & 0 deletions torch_xla/csrc/ops/ops_xla_shape_fn.cpp

@@ -131,4 +131,8 @@ xla::Shape TanOutputShape(const torch::lazy::Value& input) {
   return GetXlaShape(input);
 }
 
+xla::Shape TanhOutputShape(const torch::lazy::Value& input) {
+  return GetXlaShape(input);
+}
+
 }  // namespace torch_xla
2 changes: 2 additions & 0 deletions torch_xla/csrc/ops/ops_xla_shape_fn.h

@@ -56,4 +56,6 @@ xla::Shape SinhOutputShape(const torch::lazy::Value& input);
 
 xla::Shape TanOutputShape(const torch::lazy::Value& input);
 
+xla::Shape TanhOutputShape(const torch::lazy::Value& input);
+
 }  // namespace torch_xla
2 changes: 0 additions & 2 deletions torch_xla/csrc/tensor.h

@@ -1169,8 +1169,6 @@ class XLATensor : public c10::intrusive_ptr_target {
   static XLATensorPtr take(const XLATensorPtr& input,
                            const XLATensorPtr& index);
 
-  static XLATensorPtr tanh(const XLATensorPtr& input);
-
   static XLATensorPtr tanh_backward(const XLATensorPtr& grad_output,
                                     const XLATensorPtr& output);
 
11 changes: 4 additions & 7 deletions torch_xla/csrc/tensor_methods.cpp

@@ -1924,9 +1924,10 @@ void XLATensor::min_out(XLATensorPtr& min, XLATensorPtr& min_indices,
 }
 
 XLATensorPtr XLATensor::mish(const XLATensorPtr& input) {
-  return input->CreateFrom(
-      input->GetIrValue() *
-      Tanh(tensor_ops::Softplus(input, 1, 20)->GetIrValue()));
+  return input->CreateFrom(input->GetIrValue() *
+                           torch::lazy::MakeNode<Tanh>(
+                               tensor_ops::Softplus(input, 1, 20)->GetIrValue(),
+                               std::vector<torch::lazy::Shape>()));
 }
 
 XLATensorPtr XLATensor::mm(const XLATensorPtr& input,
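The mish rewrite is the same mechanical substitution and leaves the computed op unchanged:

    Mish(x) = x * tanh(softplus(x)),  where softplus(x) = log(1 + exp(x))

The 1 and 20 passed to tensor_ops::Softplus are presumably the usual beta and threshold parameters of softplus.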
@@ -2776,10 +2777,6 @@ XLATensorPtr XLATensor::take(const XLATensorPtr& input,
   return input->CreateFrom(Take(input->GetIrValue(), index->GetIrValue()));
 }
 
-XLATensorPtr XLATensor::tanh(const XLATensorPtr& input) {
-  return input->CreateFrom(Tanh(input->GetIrValue()));
-}
-
 XLATensorPtr XLATensor::tanh_backward(const XLATensorPtr& grad_output,
                                       const XLATensorPtr& output) {
   return XLATensor::mul(grad_output,
2 changes: 1 addition & 1 deletion xla_native_functions.yaml

@@ -25,6 +25,7 @@ full_codegen:
 - sin
 - sinh
 - tan
+- tanh
 supported:
 - __ilshift__.Scalar
 - __ilshift__.Tensor

@@ -302,7 +303,6 @@ supported:
 - t
 - t_
 - take
-- tanh
 - tanh_backward
 - threshold
 - threshold_backward
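Moving tanh from the supported: list to full_codegen: is the switch that drives the rest of this diff: for a full_codegen op, the codegen emits the ATen binding (hence the deletion in aten_xla_type.cpp), the XLATensor plumbing (tensor.h and tensor_methods.cpp), and the IR node class, leaving only Tanh::Lower and TanhOutputShape to be written by hand. Note that tanh_backward stays under supported:, so its hand-written path is untouched.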