Codegen amin amax
JackCaoG committed Aug 10, 2022
1 parent 2d880d1 commit c0193a8
Showing 11 changed files with 39 additions and 190 deletions.
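
This commit moves amax and amin from hand-written kernels to the lazy-tensor codegen path (full_codegen in xla_native_functions.yaml): the ATen bindings, IR nodes, and XLATensor methods below are now generated, and only the XLA lowering and shape functions remain hand-written. The snippet below is a minimal usage sketch, not part of the commit; it exercises the ops on an XLA device and checks that no aten:: fallback counter is recorded for them (the metrics helpers are standard torch_xla APIs, but the exact counter names are an assumption).

# Hedged usage sketch (not part of this commit): run amax/amin on an XLA
# device and confirm they are not falling back to CPU.
import torch
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met

device = xm.xla_device()
x = torch.randn(4, 5, 6, device=device)

# amax/amin reduce over the given dims; keepdim keeps reduced dims as size 1.
out_max = torch.amax(x, dim=(1, 2), keepdim=True)  # shape [4, 1, 1]
out_min = torch.amin(x, dim=1)                      # shape [4, 6]
xm.mark_step()  # force execution of the lazy graph

# Ops that fall back to CPU show up as counters prefixed with "aten::".
fallbacks = [c for c in met.counter_names() if c.startswith('aten::am')]
print(out_max.shape, out_min.shape, fallbacks)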
7 changes: 6 additions & 1 deletion scripts/gen_lazy_tensor.py
@@ -8,6 +8,7 @@
     BaseCType,
     OptionalCType,
     VectorCType,
+    boolT,
     kernel_signature,
 )
 import pathlib
@@ -22,6 +23,10 @@
 source_yaml = str(torch_xla_root / "xla_native_functions.yaml")
 
 
+def is_boolean_dtype(lazy_type):
+  return lazy_type == BaseCType(boolT)
+
+
 @dataclass(frozen=True)
 class GenXlaLazyIR(GenLazyIR):
 
@@ -47,7 +52,7 @@ def node_base_ctor_call(self, schema: LazyIrSchema) -> str:
     shape_fn_inputs_list = [
         f"{a.name}" for a in schema.positional_args
         if (a.is_lazy_value or isinstance(a.lazy_type, VectorCType) or
-            a.name == 'reduction')
+            is_boolean_dtype(a.lazy_type) or a.name == 'reduction')
     ]
     shape_fn_inputs = ", ".join(shape_fn_inputs_list)
 
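
The condition above selects which positional arguments the generated IR node forwards to its shape-inference function; adding is_boolean_dtype lets boolean scalars such as keepdim through, which the amax/amin shape functions need. The snippet below is only an illustration of that filter — the Arg class and the argument list are hypothetical stand-ins, not generator code; only the predicate mirrors the change.

# Hypothetical illustration of the argument filter (not generator code).
from dataclasses import dataclass

@dataclass
class Arg:
  name: str
  is_lazy_value: bool = False  # lazy tensor operands, e.g. `self`
  is_vector: bool = False      # stand-in for VectorCType, e.g. `dim`
  is_bool: bool = False        # stand-in for BaseCType(boolT), e.g. `keepdim`

# Rough shape of amax(Tensor self, int[1] dim, bool keepdim).
amax_args = [
    Arg('self', is_lazy_value=True),
    Arg('dim', is_vector=True),
    Arg('keepdim', is_bool=True),
]

shape_fn_inputs = [
    a.name for a in amax_args
    if a.is_lazy_value or a.is_vector or a.is_bool or a.name == 'reduction'
]
print(', '.join(shape_fn_inputs))  # -> self, dim, keepdim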
16 changes: 0 additions & 16 deletions torch_xla/csrc/aten_xla_type.cpp
@@ -676,22 +676,6 @@ at::Tensor XLANativeFunctions::all(const at::Tensor& self, int64_t dim,
       XLATensor::all_dim(bridge::GetXlaTensor(self), {dim}, keepdim));
 }
 
-at::Tensor XLANativeFunctions::amax(const at::Tensor& self, at::IntArrayRef dim,
-                                    bool keepdim) {
-  XLA_FN_COUNTER("xla::");
-  auto xdim = XlaHelpers::I64List(dim);
-  return bridge::AtenFromXlaTensor(
-      XLATensor::amax(bridge::GetXlaTensor(self), std::move(xdim), keepdim));
-}
-
-at::Tensor XLANativeFunctions::amin(const at::Tensor& self, at::IntArrayRef dim,
-                                    bool keepdim) {
-  XLA_FN_COUNTER("xla::");
-  auto xdim = XlaHelpers::I64List(dim);
-  return bridge::AtenFromXlaTensor(
-      XLATensor::amin(bridge::GetXlaTensor(self), std::move(xdim), keepdim));
-}
-
 at::Tensor XLANativeFunctions::any(const at::Tensor& self) {
   XLA_FN_COUNTER("xla::");
   XLATensorPtr self_tensor = bridge::GetXlaTensor(self);
46 changes: 0 additions & 46 deletions torch_xla/csrc/ops/amax.cpp

This file was deleted.

27 changes: 0 additions & 27 deletions torch_xla/csrc/ops/amax.h

This file was deleted.

46 changes: 0 additions & 46 deletions torch_xla/csrc/ops/amin.cpp

This file was deleted.

27 changes: 0 additions & 27 deletions torch_xla/csrc/ops/amin.h

This file was deleted.

11 changes: 10 additions & 1 deletion torch_xla/csrc/ops/ops_lower_fn.cpp
@@ -9,7 +9,6 @@
 #include "torch_xla/csrc/reduction.h"
 
 namespace torch_xla {
-
 torch_xla::XlaOpVector Abs::Lower(LoweringContext* loctx) const {
   xla::XlaOp xla_input = loctx->GetOutputOp(operand(0));
   return ReturnOp(BuildAbs(xla_input), loctx);
@@ -60,6 +59,16 @@ torch_xla::XlaOpVector All::Lower(LoweringContext* loctx) const {
   return ReturnOp(BuildAll(input, dimensions, false), loctx);
 }
 
+torch_xla::XlaOpVector Amax::Lower(LoweringContext* loctx) const {
+  xla::XlaOp input = loctx->GetOutputOp(operand(0));
+  return ReturnOp(BuildMaxInDims(input, dim, keepdim), loctx);
+}
+
+torch_xla::XlaOpVector Amin::Lower(LoweringContext* loctx) const {
+  xla::XlaOp input = loctx->GetOutputOp(operand(0));
+  return ReturnOp(BuildMinInDims(input, dim, keepdim), loctx);
+}
+
 torch_xla::XlaOpVector Asin::Lower(LoweringContext* loctx) const {
   xla::XlaOp xla_input = loctx->GetOutputOp(operand(0));
   return ReturnOp(xla::Asin(xla_input), loctx);
15 changes: 15 additions & 0 deletions torch_xla/csrc/ops/ops_xla_shape_fn.cpp
@@ -42,6 +42,13 @@ xla::Shape AdaptiveAvgPool2dOutputShape(const torch::lazy::Value& input,
       [output_size](absl::Span<const xla::XlaOp> operands) -> xla::XlaOp {
         XLA_CHECK_EQ(operands.size(), 1);
         return BuildAdaptiveAvgPool2d(operands[0], output_size);
       };
   return InferOutputShape({GetXlaShape(input)}, lower_for_shape_fn);
 }
 
+xla::Shape AmaxOutputShape(const torch::lazy::Value& input,
+                           absl::Span<const int64_t> dim, bool keepdim) {
+  auto lower_for_shape_fn =
+      [&](absl::Span<const xla::XlaOp> operands) -> xla::XlaOp {
+    return BuildMaxInDims(operands[0], dim, keepdim);
+  };
+  return InferOutputShape({GetXlaShape(input)}, lower_for_shape_fn);
+}
@@ -64,6 +71,14 @@ xla::Shape AdaptiveAvgPool3dOutputShape(const torch::lazy::Value& input,
       [output_size](absl::Span<const xla::XlaOp> operands) -> xla::XlaOp {
         XLA_CHECK_EQ(operands.size(), 1);
         return BuildAdaptiveAvgPool3d(operands[0], output_size);
       };
   return InferOutputShape({GetXlaShape(input)}, lower_for_shape_fn);
 }
 
+xla::Shape AminOutputShape(const torch::lazy::Value& input,
+                           absl::Span<const int64_t> dim, bool keepdim) {
+  auto lower_for_shape_fn =
+      [&](absl::Span<const xla::XlaOp> operands) -> xla::XlaOp {
+    return BuildMinInDims(operands[0], dim, keepdim);
+  };
+  return InferOutputShape({GetXlaShape(input)}, lower_for_shape_fn);
+}
6 changes: 6 additions & 0 deletions torch_xla/csrc/ops/ops_xla_shape_fn.h
@@ -20,6 +20,12 @@ xla::Shape AdaptiveAvgPool3dOutputShape(const torch::lazy::Value& input,
 
 xla::Shape AdaptiveAvgPool3dBackwardOutputShape(
     const torch::lazy::Value& grad_output, const torch::lazy::Value& input);
 
+xla::Shape AmaxOutputShape(const torch::lazy::Value& input,
+                           absl::Span<const int64_t> dim, bool keepdim);
+
+xla::Shape AminOutputShape(const torch::lazy::Value& input,
+                           absl::Span<const int64_t> dim, bool keepdim);
+
 xla::Shape AllOutputShape(const torch::lazy::Value& input);
 
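
The new shape functions infer the static output shape by tracing BuildMaxInDims/BuildMinInDims on placeholder operands. The rule they must reproduce is simple: reduced dimensions collapse to size 1 with keepdim and are dropped otherwise. The snippet below is a hedged sanity check of that rule against eager PyTorch; it does not call the XLA shape functions themselves.

# Sketch: check the output-shape rule the XLA shape functions implement,
# using eager PyTorch as the reference.
import torch

def expected_reduced_shape(shape, dims, keepdim):
  dims = {d % len(shape) for d in dims}
  if keepdim:
    return [1 if i in dims else s for i, s in enumerate(shape)]
  return [s for i, s in enumerate(shape) if i not in dims]

x = torch.randn(2, 3, 4)
for dims in [(0,), (1, 2), (-1,)]:
  for keepdim in (False, True):
    out = torch.amax(x, dim=dims, keepdim=keepdim)
    assert list(out.shape) == expected_reduced_shape([2, 3, 4], dims, keepdim)
print("shape rule holds")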
24 changes: 0 additions & 24 deletions torch_xla/csrc/tensor_methods.cpp
@@ -26,8 +26,6 @@
 #include "torch_xla/csrc/ops/all_gather.h"
 #include "torch_xla/csrc/ops/all_reduce.h"
 #include "torch_xla/csrc/ops/all_to_all.h"
-#include "torch_xla/csrc/ops/amax.h"
-#include "torch_xla/csrc/ops/amin.h"
 #include "torch_xla/csrc/ops/amp_foreach_non_finite_check_and_unscale.h"
 #include "torch_xla/csrc/ops/amp_update_scale.h"
 #include "torch_xla/csrc/ops/any.h"
@@ -713,28 +711,6 @@ XLATensorPtr XLATensor::all_dim(const XLATensorPtr& input,
       result_type);
 }
 
-XLATensorPtr XLATensor::amax(const XLATensorPtr& input,
-                             std::vector<int64_t> dimensions,
-                             bool keep_reduced_dimensions) {
-  return input->CreateFrom(
-      torch::lazy::MakeNode<Amax>(input->GetIrValue(),
-                                  torch::lazy::GetCanonicalDimensionIndices(
-                                      xla::util::ToVector<int64_t>(dimensions),
-                                      input->shape().get().rank()),
-                                  keep_reduced_dimensions));
-}
-
-XLATensorPtr XLATensor::amin(const XLATensorPtr& input,
-                             std::vector<int64_t> dimensions,
-                             bool keep_reduced_dimensions) {
-  return input->CreateFrom(
-      torch::lazy::MakeNode<Amin>(input->GetIrValue(),
-                                  torch::lazy::GetCanonicalDimensionIndices(
-                                      xla::util::ToVector<int64_t>(dimensions),
-                                      input->shape().get().rank()),
-                                  keep_reduced_dimensions));
-}
-
 XLATensorPtr XLATensor::any(const XLATensorPtr& input,
                             std::vector<int64_t> dimensions,
                             bool keep_reduced_dimensions) {
4 changes: 2 additions & 2 deletions xla_native_functions.yaml
@@ -5,6 +5,8 @@ full_codegen:
   - acosh
   - abs
   - all
+  - amax
+  - amin
   - asin
   - asinh
   - atan
@@ -93,8 +95,6 @@ supported:
   - addmm
   - alias
   - all.dim
-  - amax
-  - amin
   - any
   - any.dim
   - arange.start_out
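
With amax and amin moved from the supported: list to full_codegen:, the boilerplate removed earlier in this commit is produced by the code generator from these YAML entries. A quick end-to-end check — a hedged sketch, not part of the commit — is to compare the XLA results against CPU eager mode:

# Sketch: compare the codegen'd XLA lowering of amax/amin against CPU eager.
import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()
cpu = torch.randn(3, 4, 5)
xla = cpu.to(device)

for op in (torch.amax, torch.amin):
  expected = op(cpu, dim=(0, 2), keepdim=True)
  actual = op(xla, dim=(0, 2), keepdim=True)
  torch.testing.assert_close(actual.cpu(), expected)
print("amax/amin match CPU eager")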
