Codegen amin amax #3771

Merged 2 commits on Aug 11, 2022
7 changes: 6 additions & 1 deletion scripts/gen_lazy_tensor.py
@@ -8,6 +8,7 @@
BaseCType,
OptionalCType,
VectorCType,
+    boolT,
kernel_signature,
)
import pathlib
@@ -22,6 +23,10 @@
source_yaml = str(torch_xla_root / "xla_native_functions.yaml")


+def is_boolean_dtype(lazy_type):
+    return lazy_type == BaseCType(boolT)


@dataclass(frozen=True)
class GenXlaLazyIR(GenLazyIR):

@@ -47,7 +52,7 @@ def node_base_ctor_call(self, schema: LazyIrSchema) -> str:
shape_fn_inputs_list = [
f"{a.name}" for a in schema.positional_args
if (a.is_lazy_value or isinstance(a.lazy_type, VectorCType) or
-            a.name == 'reduction')
+            is_boolean_dtype(a.lazy_type) or a.name == 'reduction')
]
shape_fn_inputs = ", ".join(shape_fn_inputs_list)

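The net effect of this script change: generated shape functions now also receive boolean scalar arguments such as keepdim, alongside lazy values, vector-typed arguments, and the special-cased 'reduction' argument. Below is a minimal standalone sketch of that selection logic; the Arg class and the sample schema are hypothetical stand-ins for torchgen's real types, kept only as far as equality and isinstance checks go.

from dataclasses import dataclass

# Hypothetical stand-ins for torchgen's type wrappers.
@dataclass(frozen=True)
class BaseCType:
    type: str

@dataclass(frozen=True)
class VectorCType:
    elem: BaseCType

boolT = "bool"

def is_boolean_dtype(lazy_type):
    # Same predicate as the one added to gen_lazy_tensor.py above.
    return lazy_type == BaseCType(boolT)

@dataclass(frozen=True)
class Arg:
    name: str
    lazy_type: object
    is_lazy_value: bool = False

# Positional args of amax(self, dim, keepdim), roughly as the codegen sees them.
positional_args = [
    Arg("self", BaseCType("lazyValue"), is_lazy_value=True),
    Arg("dim", VectorCType(BaseCType("int64_t"))),
    Arg("keepdim", BaseCType(boolT)),
]

shape_fn_inputs = ", ".join(
    a.name for a in positional_args
    if (a.is_lazy_value or isinstance(a.lazy_type, VectorCType) or
        is_boolean_dtype(a.lazy_type) or a.name == "reduction"))
print(shape_fn_inputs)  # self, dim, keepdim

For amax(self, dim, keepdim) this yields "self, dim, keepdim", which matches the AmaxOutputShape signature added later in this PR.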
16 changes: 0 additions & 16 deletions torch_xla/csrc/aten_xla_type.cpp
@@ -676,22 +676,6 @@ at::Tensor XLANativeFunctions::all(const at::Tensor& self, int64_t dim,
XLATensor::all_dim(bridge::GetXlaTensor(self), {dim}, keepdim));
}

-at::Tensor XLANativeFunctions::amax(const at::Tensor& self, at::IntArrayRef dim,
-                                    bool keepdim) {
-  XLA_FN_COUNTER("xla::");
-  auto xdim = XlaHelpers::I64List(dim);
-  return bridge::AtenFromXlaTensor(
-      XLATensor::amax(bridge::GetXlaTensor(self), std::move(xdim), keepdim));
-}
-
-at::Tensor XLANativeFunctions::amin(const at::Tensor& self, at::IntArrayRef dim,
-                                    bool keepdim) {
-  XLA_FN_COUNTER("xla::");
-  auto xdim = XlaHelpers::I64List(dim);
-  return bridge::AtenFromXlaTensor(
-      XLATensor::amin(bridge::GetXlaTensor(self), std::move(xdim), keepdim));
-}

at::Tensor XLANativeFunctions::any(const at::Tensor& self) {
XLA_FN_COUNTER("xla::");
XLATensorPtr self_tensor = bridge::GetXlaTensor(self);
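With amax and amin moved to full codegen, these hand-written ATen bindings become dead code: the generated XLANativeFunctions::amax/amin take over dispatch. A hedged usage sketch (assumes a working torch_xla install; xm.xla_device() is the standard device helper, not part of this diff):

import torch
import torch_xla.core.xla_model as xm

t = torch.randn(4, 5, device=xm.xla_device())
# Dispatch now reaches the codegen'd XLANativeFunctions::amax rather than
# the hand-written binding deleted above; behavior should be unchanged.
print(torch.amax(t, dim=1, keepdim=True).shape)  # torch.Size([4, 1])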
46 changes: 0 additions & 46 deletions torch_xla/csrc/ops/amax.cpp

This file was deleted.

27 changes: 0 additions & 27 deletions torch_xla/csrc/ops/amax.h

This file was deleted.

46 changes: 0 additions & 46 deletions torch_xla/csrc/ops/amin.cpp

This file was deleted.

27 changes: 0 additions & 27 deletions torch_xla/csrc/ops/amin.h

This file was deleted.

11 changes: 10 additions & 1 deletion torch_xla/csrc/ops/ops_lower_fn.cpp
@@ -9,7 +9,6 @@
#include "torch_xla/csrc/reduction.h"

namespace torch_xla {

torch_xla::XlaOpVector Abs::Lower(LoweringContext* loctx) const {
xla::XlaOp xla_input = loctx->GetOutputOp(operand(0));
return ReturnOp(BuildAbs(xla_input), loctx);
@@ -60,6 +59,16 @@ torch_xla::XlaOpVector All::Lower(LoweringContext* loctx) const {
return ReturnOp(BuildAll(input, dimensions, false), loctx);
}

+torch_xla::XlaOpVector Amax::Lower(LoweringContext* loctx) const {
+  xla::XlaOp input = loctx->GetOutputOp(operand(0));
+  return ReturnOp(BuildMaxInDims(input, dim, keepdim), loctx);
+}
+
+torch_xla::XlaOpVector Amin::Lower(LoweringContext* loctx) const {
+  xla::XlaOp input = loctx->GetOutputOp(operand(0));
+  return ReturnOp(BuildMinInDims(input, dim, keepdim), loctx);
+}

torch_xla::XlaOpVector Asin::Lower(LoweringContext* loctx) const {
xla::XlaOp xla_input = loctx->GetOutputOp(operand(0));
return ReturnOp(xla::Asin(xla_input), loctx);
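Amax::Lower and Amin::Lower are the only hand-written pieces left for these ops: they map the codegen'd IR node onto the existing BuildMaxInDims/BuildMinInDims reducers. One way to see the node in action is to dump the lazy IR before execution; a sketch assuming a torch_xla install (_get_xla_tensors_text is the debug hook torch_xla's own tests use, named here from memory):

import torch
import torch_xla
import torch_xla.core.xla_model as xm

t = torch.randn(3, 4, device=xm.xla_device())
out = torch.amax(t, dim=0)
# The printed graph should contain an xla::amax node; Amax::Lower above is
# what turns that node into the BuildMaxInDims XLA computation.
print(torch_xla._XLAC._get_xla_tensors_text([out]))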
18 changes: 18 additions & 0 deletions torch_xla/csrc/ops/ops_xla_shape_fn.cpp
@@ -68,6 +68,24 @@ xla::Shape AdaptiveAvgPool3dOutputShape(const torch::lazy::Value& input,
return InferOutputShape({GetXlaShape(input)}, lower_for_shape_fn);
}

+xla::Shape AmaxOutputShape(const torch::lazy::Value& input,
+                           absl::Span<const int64_t> dim, bool keepdim) {
+  auto lower_for_shape_fn =
+      [&](absl::Span<const xla::XlaOp> operands) -> xla::XlaOp {
+    return BuildMaxInDims(operands[0], dim, keepdim);
+  };
+  return InferOutputShape({GetXlaShape(input)}, lower_for_shape_fn);
+}
+
+xla::Shape AminOutputShape(const torch::lazy::Value& input,
+                           absl::Span<const int64_t> dim, bool keepdim) {
+  auto lower_for_shape_fn =
+      [&](absl::Span<const xla::XlaOp> operands) -> xla::XlaOp {
+    return BuildMinInDims(operands[0], dim, keepdim);
+  };
+  return InferOutputShape({GetXlaShape(input)}, lower_for_shape_fn);
+}

xla::Shape AdaptiveAvgPool3dBackwardOutputShape(
const torch::lazy::Value& grad_output, const torch::lazy::Value& input) {
auto lower_for_shape_fn =
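Rather than hard-coding the output shape, these helpers rerun the same BuildMaxInDims/BuildMinInDims lowering on dummy operands and let InferOutputShape read the resulting shape. The rule they end up implementing is the usual reduction rule; a pure-Python rendering for intuition (not part of the PR):

def amax_output_shape(input_shape, dim, keepdim):
    # Reduce over `dim` (negative indices allowed), keeping reduced axes
    # as size 1 when keepdim is set -- the same rule AmaxOutputShape and
    # AminOutputShape recover via InferOutputShape.
    rank = len(input_shape)
    reduced = {d % rank for d in dim}
    if keepdim:
        return [1 if i in reduced else s for i, s in enumerate(input_shape)]
    return [s for i, s in enumerate(input_shape) if i not in reduced]

print(amax_output_shape([2, 3, 4], dim=[-1], keepdim=False))   # [2, 3]
print(amax_output_shape([2, 3, 4], dim=[0, 2], keepdim=True))  # [1, 3, 1]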
6 changes: 6 additions & 0 deletions torch_xla/csrc/ops/ops_xla_shape_fn.h
@@ -21,6 +21,12 @@ xla::Shape AdaptiveAvgPool3dOutputShape(const torch::lazy::Value& input,
xla::Shape AdaptiveAvgPool3dBackwardOutputShape(
const torch::lazy::Value& grad_output, const torch::lazy::Value& input);

+xla::Shape AmaxOutputShape(const torch::lazy::Value& input,
+                           absl::Span<const int64_t> dim, bool keepdim);
+
+xla::Shape AminOutputShape(const torch::lazy::Value& input,
+                           absl::Span<const int64_t> dim, bool keepdim);

xla::Shape AllOutputShape(const torch::lazy::Value& input);

xla::Shape AsinOutputShape(const torch::lazy::Value& input);
21 changes: 15 additions & 6 deletions torch_xla/csrc/reduction.cpp
@@ -9,6 +9,7 @@
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/xla_client/debug_macros.h"
#include "torch/csrc/lazy/core/helpers.h"
#include "torch/csrc/lazy/core/util.h"
#include "torch_xla/csrc/convert_ops.h"
#include "torch_xla/csrc/helpers.h"
@@ -318,17 +319,20 @@ xla::XlaOp BuildMaxInDims(xla::XlaOp input,
                           bool keep_reduced_dimensions) {
   const xla::Shape& shape = XlaHelpers::ShapeOfXlaOp(input);
   XlaHelpers::MinMax min_max = XlaHelpers::MinMaxValues(shape.element_type());
+  std::vector<int64_t> canonical_dimensions =
+      torch::lazy::GetCanonicalDimensionIndices(
+          xla::util::ToVector<int64_t>(dimensions), shape.rank());
   xla::XlaOp init_value = XlaHelpers::ScalarValue(
       min_max.min, shape.element_type(), input.builder());
-  ReductionInfo rinfo =
-      GetReductionInfo(input, shape, dimensions, keep_reduced_dimensions);
+  ReductionInfo rinfo = GetReductionInfo(input, shape, canonical_dimensions,
+                                         keep_reduced_dimensions);
   if (rinfo.element_count.scalar_size) {
     // We can only assert this if dimensions are not dynamic.
     XLA_CHECK_GT(*rinfo.element_count.scalar_size, 0);
   }
   xla::XlaOp result = xla::Reduce(
       input, init_value, XlaHelpers::CreateMaxComputation(shape.element_type()),
-      dimensions);
+      canonical_dimensions);
   if (keep_reduced_dimensions) {
     result = XlaHelpers::DynamicReshape(result, rinfo.new_dimensions);
   }
@@ -345,17 +349,22 @@ xla::XlaOp BuildMinInDims(xla::XlaOp input,
                           bool keep_reduced_dimensions) {
   const xla::Shape& shape = XlaHelpers::ShapeOfXlaOp(input);
   XlaHelpers::MinMax min_max = XlaHelpers::MinMaxValues(shape.element_type());
+
+  std::vector<int64_t> canonical_dimensions =
+      torch::lazy::GetCanonicalDimensionIndices(
+          xla::util::ToVector<int64_t>(dimensions), shape.rank());
+
   xla::XlaOp init_value = XlaHelpers::ScalarValue(
       min_max.max, shape.element_type(), input.builder());
-  ReductionInfo rinfo =
-      GetReductionInfo(input, shape, dimensions, keep_reduced_dimensions);
+  ReductionInfo rinfo = GetReductionInfo(input, shape, canonical_dimensions,
+                                         keep_reduced_dimensions);
   if (rinfo.element_count.scalar_size) {
     // We can only assert this if dimensions are not dynamic.
     XLA_CHECK_GT(*rinfo.element_count.scalar_size, 0);
   }
   xla::XlaOp result = xla::Reduce(
       input, init_value, XlaHelpers::CreateMinComputation(shape.element_type()),
-      dimensions);
+      canonical_dimensions);
   if (keep_reduced_dimensions) {
     result = XlaHelpers::DynamicReshape(result, rinfo.new_dimensions);
   }
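Canonicalization of dimension indices moves into the reducers themselves: callers may now pass negative dims, and BuildMaxInDims/BuildMinInDims normalize them before handing them to xla::Reduce. A pure-Python sketch of what torch::lazy::GetCanonicalDimensionIndices does at these call sites (the real helper lives in torch/csrc/lazy/core/helpers.h, included above; the bounds check is an assumption about its behavior):

def get_canonical_dimension_indices(dimensions, rank):
    # Map possibly-negative dimension indices into [0, rank), as the lazy
    # core helper does for the new call sites in BuildMax/MinInDims.
    out = []
    for d in dimensions:
        if not -rank <= d < rank:
            raise IndexError(f"dimension {d} out of range for rank {rank}")
        out.append(d % rank)
    return out

print(get_canonical_dimension_indices([-1, 0], rank=3))  # [2, 0]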
24 changes: 0 additions & 24 deletions torch_xla/csrc/tensor_methods.cpp
@@ -26,8 +26,6 @@
#include "torch_xla/csrc/ops/all_gather.h"
#include "torch_xla/csrc/ops/all_reduce.h"
#include "torch_xla/csrc/ops/all_to_all.h"
#include "torch_xla/csrc/ops/amax.h"
#include "torch_xla/csrc/ops/amin.h"
#include "torch_xla/csrc/ops/amp_foreach_non_finite_check_and_unscale.h"
#include "torch_xla/csrc/ops/amp_update_scale.h"
#include "torch_xla/csrc/ops/any.h"
@@ -713,28 +711,6 @@ XLATensorPtr XLATensor::all_dim(const XLATensorPtr& input,
result_type);
}

-XLATensorPtr XLATensor::amax(const XLATensorPtr& input,
-                             std::vector<int64_t> dimensions,
-                             bool keep_reduced_dimensions) {
-  return input->CreateFrom(
-      torch::lazy::MakeNode<Amax>(input->GetIrValue(),
-                                  torch::lazy::GetCanonicalDimensionIndices(
-                                      xla::util::ToVector<int64_t>(dimensions),
-                                      input->shape().get().rank()),
-                                  keep_reduced_dimensions));
-}
-
-XLATensorPtr XLATensor::amin(const XLATensorPtr& input,
-                             std::vector<int64_t> dimensions,
-                             bool keep_reduced_dimensions) {
-  return input->CreateFrom(
-      torch::lazy::MakeNode<Amin>(input->GetIrValue(),
-                                  torch::lazy::GetCanonicalDimensionIndices(
-                                      xla::util::ToVector<int64_t>(dimensions),
-                                      input->shape().get().rank()),
-                                  keep_reduced_dimensions));
-}

XLATensorPtr XLATensor::any(const XLATensorPtr& input,
std::vector<int64_t> dimensions,
bool keep_reduced_dimensions) {
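The deleted XLATensor::amax/amin did this canonicalization at node-construction time; after this PR the codegen'd tensor methods pass dims through unchanged and rely on the BuildMaxInDims/BuildMinInDims changes above. A hedged end-to-end check that negative dims still behave (assumes a torch_xla install):

import torch
import torch_xla.core.xla_model as xm

t = torch.randn(2, 3, 4, device=xm.xla_device())
a = torch.amin(t, dim=-1)  # canonicalized inside BuildMinInDims now
b = torch.amin(t, dim=2)
assert torch.equal(a.cpu(), b.cpu())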
4 changes: 2 additions & 2 deletions xla_native_functions.yaml
@@ -5,6 +5,8 @@ full_codegen:
- acosh
- abs
- all
+  - amax
+  - amin
- asin
- asinh
- atan
@@ -93,8 +95,6 @@ supported:
- addmm
- alias
- all.dim
-  - amax
-  - amin
- any
- any.dim
- arange.start_out
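Moving amax and amin from the supported: list (hand-written kernels) to full_codegen: is what triggers all of the generated code above. A quick, hedged way to sanity-check the new lowering against the CPU reference (assumes a torch_xla install):

import torch
import torch_xla.core.xla_model as xm

cpu = torch.randn(3, 5)
xla = cpu.to(xm.xla_device())
for op in (torch.amax, torch.amin):
    for keepdim in (False, True):
        want = op(cpu, dim=1, keepdim=keepdim)
        got = op(xla, dim=1, keepdim=keepdim).cpu()
        assert torch.allclose(want, got), (op, keepdim)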