Remove Ger lowering as it is not needed (#3855)
JackCaoG authored Aug 10, 2022 · 1 parent 45689b9 · commit 2d880d1
Showing 9 changed files with 0 additions and 48 deletions.
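
Background (not part of the original commit message): aten::ger is the outer product of two 1-D tensors and in current PyTorch is a deprecated alias of aten::outer, which has a composite implementation built from reshape and multiply/matmul. With no dedicated XLA kernel registered, the call can presumably be decomposed by the dispatcher into ops torch_xla already lowers, which is why the explicit lowering below can be deleted. A minimal libtorch sketch of that equivalence (illustration only, not code from this commit):

// Illustration only: the outer product the removed lowering produced can be
// expressed with ops that remain lowered (unsqueeze/reshape and matmul).
#include <torch/torch.h>
#include <iostream>

int main() {
  torch::Tensor u = torch::randn({3});
  torch::Tensor v = torch::randn({4});

  // Reference outer product (ger is a deprecated alias of outer).
  torch::Tensor expected = torch::ger(u, v);  // shape [3, 4]

  // The same composition the removed BuildGer lowering performed:
  // lhs as a [3, 1] column, rhs as a [1, 4] row, then a matmul.
  torch::Tensor composed = u.unsqueeze(1).matmul(v.unsqueeze(0));

  std::cout << torch::allclose(expected, composed) << std::endl;  // prints 1
  return 0;
}
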
1 change: 0 additions & 1 deletion test/cpp/test_aten_xla_tensor.cpp
@@ -3604,7 +3604,6 @@ TEST_F(AtenXlaTensorTest, TestGer) {
});

ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
ExpectCounterChanged("xla::ger", cpp_test::GetIgnoredCounters());
}

TEST_F(AtenXlaTensorTest, TestMv) {
7 changes: 0 additions & 7 deletions torch_xla/csrc/aten_xla_type.cpp
@@ -1451,13 +1451,6 @@ at::Tensor XLANativeFunctions::gelu_backward(const at::Tensor& grad,
bridge::GetXlaTensor(grad), bridge::GetXlaTensor(self), approximate));
}

at::Tensor XLANativeFunctions::ger(const at::Tensor& self,
const at::Tensor& vec2) {
XLA_FN_COUNTER("xla::");
return bridge::AtenFromXlaTensor(
XLATensor::ger(bridge::GetXlaTensor(self), bridge::GetXlaTensor(vec2)));
}

at::Tensor XLANativeFunctions::gt(const at::Tensor& self,
const at::Scalar& other) {
XLA_FN_COUNTER("xla::");
21 changes: 0 additions & 21 deletions torch_xla/csrc/ops/ops.cpp
@@ -238,27 +238,6 @@ torch::lazy::NodePtr Celu(const torch::lazy::Value& input,
GetXlaShape(input), std::move(lower_fn));
}

torch::lazy::NodePtr Ger(const torch::lazy::Value& input,
const torch::lazy::Value& other) {
auto lower_fn = [](const XlaNode& node,
LoweringContext* loctx) -> XlaOpVector {
xla::XlaOp xla_input = loctx->GetOutputOp(node.operand(0));
xla::XlaOp xla_other = loctx->GetOutputOp(node.operand(1));
return node.ReturnOp(BuildGer(xla_input, xla_other), loctx);
};
auto lower_for_shape_fn =
[](absl::Span<const xla::XlaOp> operands) -> xla::XlaOp {
return BuildGer(operands[0], operands[1]);
};
return GenericOp(torch::lazy::OpKind(at::aten::ger), {input, other},
[&]() {
return InferOutputShape(
{GetXlaShape(input), GetXlaShape(other)},
lower_for_shape_fn);
},
std::move(lower_fn));
}

torch::lazy::NodePtr AddMatMulOp(const torch::lazy::Value& input,
const torch::lazy::Value& weight,
const torch::lazy::Value& bias) {
3 changes: 0 additions & 3 deletions torch_xla/csrc/ops/ops.h
@@ -140,9 +140,6 @@ torch::lazy::NodePtr Celu(const torch::lazy::Value& input,

torch::lazy::NodePtr FracOp(const torch::lazy::Value& input);

torch::lazy::NodePtr Ger(const torch::lazy::Value& input,
const torch::lazy::Value& other);

torch::lazy::NodePtr AddMatMulOp(const torch::lazy::Value& input,
const torch::lazy::Value& weight,
const torch::lazy::Value& bias);
2 changes: 0 additions & 2 deletions torch_xla/csrc/tensor.h
@@ -624,8 +624,6 @@ class XLATensor : public c10::intrusive_ptr_target {
const XLATensorPtr& input,
const c10::string_view approximate);

static XLATensorPtr ger(const XLATensorPtr& input, const XLATensorPtr& vec2);

static XLATensorPtr gt(const XLATensorPtr& input, const at::Scalar& other);

static XLATensorPtr gt(const XLATensorPtr& input, const XLATensorPtr& other);
5 changes: 0 additions & 5 deletions torch_xla/csrc/tensor_methods.cpp
@@ -1398,11 +1398,6 @@ XLATensorPtr XLATensor::gelu_backward(const XLATensorPtr& grad,
}
}

XLATensorPtr XLATensor::ger(const XLATensorPtr& input,
const XLATensorPtr& vec2) {
return input->CreateFrom(Ger(input->GetIrValue(), vec2->GetIrValue()));
}

XLATensorPtr XLATensor::gt(const XLATensorPtr& input, const at::Scalar& other) {
return DispatchComparisonOp(at::aten::gt, input, other);
}
6 changes: 0 additions & 6 deletions torch_xla/csrc/xla_lower_util.cpp
@@ -430,12 +430,6 @@ xla::XlaOp CreateMatMul(xla::XlaOp lhs, xla::XlaOp rhs) {
<< rhs_shape << ")";
}

xla::XlaOp BuildGer(xla::XlaOp lhs, xla::XlaOp rhs) {
xla::XlaOp lhs_reshaped = BuildUnsqueeze(lhs, 1);
xla::XlaOp rhs_reshaped = BuildUnsqueeze(rhs, 0);
return BuildDot(lhs_reshaped, rhs_reshaped);
}

xla::XlaOp BuildMatMul(xla::XlaOp lhs, xla::XlaOp rhs, xla::XlaOp bias) {
xla::XlaOp dot = BuildDot(lhs, rhs);
const xla::Shape& dot_shape = XlaHelpers::ShapeOfXlaOp(dot);
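
The removed BuildGer helper produced the outer product by unsqueezing lhs to shape [n, 1] and rhs to shape [1, m] and then calling BuildDot, yielding an [n, m] result. The same composition can still be expressed through the lowerings that remain (unsqueeze/reshape and dot), so dropping the helper should not remove any capability from the lowering library.
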
2 changes: 0 additions & 2 deletions torch_xla/csrc/xla_lower_util.h
@@ -20,8 +20,6 @@ std::vector<xla::XlaOp> CreateTopK(xla::XlaOp input, int64_t k, int64_t dim,

xla::XlaOp CreateMatMul(xla::XlaOp lhs, xla::XlaOp rhs);

xla::XlaOp BuildGer(xla::XlaOp lhs, xla::XlaOp rhs);

xla::XlaOp BuildMatMul(xla::XlaOp lhs, xla::XlaOp rhs, xla::XlaOp bias);

xla::XlaOp BuildMatMulWithMultiplier(xla::XlaOp lhs, xla::XlaOp rhs,
1 change: 0 additions & 1 deletion xla_native_functions.yaml
@@ -168,7 +168,6 @@ supported:
- ge.Tensor
- gelu
- gelu_backward
- ger
- gt.Scalar
- gt.Tensor
- hardshrink
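
Dropping ger from the supported: list means the torch_xla codegen no longer registers an XLA kernel for aten::ger, so calls on XLA tensors presumably fall through to PyTorch's default composite implementation and decompose into ops that are still lowered. This matches the test change above, which removes the xla::ger counter check while keeping the expectation that no aten::* fallback counters fire.
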
