Implement implicit broadcast for binary operations on dynamic shapes.
- This CL instructs the dynamic padder to insert implicit broadcasts into the graph when a binary operation is performed on two dynamic tensors (sketched below).
- Optimization #1: the implicit broadcast is only inserted when we can't prove that two dynamic dimensions are the same.
- Optimization #2: added a simplification pass that lets us simplify operations on dynamic dimensions; this opens up more opportunities for optimization #1.
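As a hedged illustration (shapes and names invented here, not taken from the change itself):

  a = f32[<=8] parameter(0)   // dynamic dim 0, runtime size s0
  b = f32[<=8] parameter(1)   // dynamic dim 0, runtime size s1
  c = f32[<=8] add(a, b)      // requires implicit broadcast unless s0 == s1

When the two runtime sizes can't be proven equal, inference conservatively records the result's dynamic size as max(s0, s1) (see HandleElementwiseBinary below).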

PiperOrigin-RevId: 355539597
Change-Id: I7753550a6057155c3f436c6b51b356cb48c945e6
yunxing authored and tensorflower-gardener committed Feb 4, 2021
1 parent 60f2d95 commit ab8d6cc
Showing 6 changed files with 628 additions and 2 deletions.
41 changes: 41 additions & 0 deletions tensorflow/compiler/xla/service/BUILD
@@ -2794,6 +2794,47 @@ cc_library(
],
)

cc_library(
name = "dynamic_dimension_simplifier",
srcs = ["dynamic_dimension_simplifier.cc"],
hdrs = ["dynamic_dimension_simplifier.h"],
deps = [
":hlo",
":hlo_pass",
"//tensorflow/compiler/xla:status_macros",
],
)

tf_cc_test(
name = "dynamic_dimension_simplifier_test",
srcs = ["dynamic_dimension_simplifier_test.cc"],
deps = [
":dynamic_dimension_simplifier",
":hlo",
":hlo_casting_utils",
":hlo_creation_utils",
":hlo_parser",
":hlo_pass",
":hlo_pass_pipeline",
":pattern_matcher",
":pattern_matcher_gmock",
":shape_inference",
"//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:window_util",
"//tensorflow/compiler/xla:xla_data_proto_cc",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep
"//tensorflow/core:lib",
"//tensorflow/core:test",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
],
)

cc_library(
name = "dynamic_padder",
srcs = ["dynamic_padder.cc"],
38 changes: 36 additions & 2 deletions tensorflow/compiler/xla/service/dynamic_dimension_inference.cc
@@ -636,7 +636,7 @@ Status DynamicDimensionInferenceVisitor::HandleConcatenate(
}

Status DynamicDimensionInferenceVisitor::HandleGetDimensionSize(
-    HloInstruction*) {
+    HloInstruction* gds) {
// Dynamic dimension doesn't propagate through GetDimensionSize:
//
// Input: F32[x, y, z]
@@ -646,6 +646,24 @@ Status DynamicDimensionInferenceVisitor::HandleGetDimensionSize(
// The returned value is a scalar, which doesn't have any dynamic dimension in
// the shape (although the value contains the real size of the dynamic
// dimension of the input).
int64 dim = gds->dimension();
HloInstruction* operand = gds->mutable_operand(0);
HloInstruction* dynamic_size = parent_->GetDynamicSize(operand, {}, dim);
HloComputation* computation = gds->parent();
if (dynamic_size != nullptr) {
TF_RETURN_IF_ERROR(gds->ReplaceAllUsesWith(dynamic_size));
// The dependency between an instruction and its dynamic dimensions is not
// modeled in the IR. As gds is being replaced by dynamic_size, also tell
// dynamic dimension inference that the instruction is being replaced.
parent_->ReplaceAllDynamicDimensionUsesWith(gds, dynamic_size);
} else {
TF_RET_CHECK(dim < gds->operand(0)->shape().rank());
int32 size = gds->operand(0)->shape().dimensions(dim);
HloInstruction* new_instr = computation->AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(size)));
TF_RETURN_IF_ERROR(gds->ReplaceAllUsesWith(new_instr));
parent_->ReplaceAllDynamicDimensionUsesWith(gds, new_instr);
}
return Status::OK();
}
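A hedged sketch of the two replacement paths above (HLO invented for illustration): a get-dimension-size on a dimension with a tracked dynamic size is forwarded to that size instruction, while one on a static dimension folds to a constant.

  p = f32[<=4, 8] parameter(0)                       // dim 0 dynamic, dim 1 static
  d0 = s32[] get-dimension-size(p), dimensions={0}   // replaced by the tracked size
  d1 = s32[] get-dimension-size(p), dimensions={1}   // replaced by s32[] constant(8)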

@@ -794,7 +812,23 @@ Status DynamicDimensionInferenceVisitor::HandleSelect(HloInstruction* hlo) {

Status DynamicDimensionInferenceVisitor::HandleElementwiseBinary(
HloInstruction* hlo) {
-  return PassThroughDynamicDimension(hlo);
+  HloComputation* comp = hlo->parent();
+  return ForEachOperandDynamicDimension(
+      hlo, [&](HloInstruction* operand, ShapeIndex index, int64 dimension,
+               int64 operand_index, HloInstruction* dynamic_size) {
+        HloInstruction* existing_size =
+            parent_->GetDynamicSize(hlo, index, dimension);
+        if (existing_size == nullptr || existing_size == dynamic_size) {
+          parent_->SetDynamicSize(hlo, index, dimension, dynamic_size);
+        } else {
+          HloInstruction* max =
+              comp->AddInstruction(HloInstruction::CreateBinary(
+                  ShapeUtil::MakeScalarShape(S32), HloOpcode::kMaximum,
+                  dynamic_size, existing_size));
+          parent_->SetDynamicSize(hlo, index, dimension, max);
+        }
+        return Status::OK();
+      });
}
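Read in isolation, the reconciliation rule above is: keep a size when the operands agree on the very same size instruction, otherwise take the maximum. A self-contained analogue (illustrative C++ only, not XLA code; value equality stands in for "same size instruction"):

  #include <algorithm>
  #include <cstdint>
  #include <optional>

  // Fold a newly discovered candidate size for a dimension into the size
  // already recorded for it, if any.
  std::optional<int64_t> ReconcileDynamicSize(std::optional<int64_t> existing,
                                              int64_t candidate) {
    if (!existing.has_value() || *existing == candidate) {
      return candidate;  // first size seen, or provably the same size
    }
    return std::max(*existing, candidate);  // disagreement: broadcast to the max
  }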

Status DynamicDimensionInferenceVisitor::HandleClamp(HloInstruction* hlo) {
214 changes: 214 additions & 0 deletions tensorflow/compiler/xla/service/dynamic_dimension_simplifier.cc
@@ -0,0 +1,214 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/dynamic_dimension_simplifier.h"

#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/status_macros.h"

namespace xla {
namespace {

// Concat(Concat(A, B), C) => Concat(A, B, C)
StatusOr<bool> ConcatForwarding(HloInstruction* concat) {
if (concat->opcode() != HloOpcode::kConcatenate) {
return false;
}
bool changed = false;

auto parent = concat->parent();
std::vector<HloInstruction*> new_operands;
for (HloInstruction* operand : concat->operands()) {
if (operand->opcode() != HloOpcode::kConcatenate ||
operand->concatenate_dimension() != concat->concatenate_dimension()) {
new_operands.push_back(operand);
} else {
changed = true;
for (HloInstruction* operand_operand : operand->operands()) {
new_operands.push_back(operand_operand);
}
}
}
if (changed) {
auto new_concat = parent->AddInstruction(HloInstruction::CreateConcatenate(
concat->shape(), new_operands, concat->concatenate_dimension()));
TF_RETURN_IF_ERROR(parent->ReplaceInstruction(concat, new_concat));
}
return changed;
}
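A hypothetical test-style module for this rewrite (invented here; the checked-in dynamic_dimension_simplifier_test.cc is not shown on this page):

  const char* kConcatForwardingHlo = R"(
  HloModule m
  ENTRY e {
    a = s32[1] parameter(0)
    b = s32[1] parameter(1)
    c = s32[1] parameter(2)
    inner = s32[2] concatenate(a, b), dimensions={0}
    ROOT outer = s32[3] concatenate(inner, c), dimensions={0}
  }
  )";
  // After the pass, `outer` is rebuilt as concatenate(a, b, c) and `inner` is
  // bypassed.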

// Slice(Concat(A1, A2, ..., An, ...), [n:n+1]) => An
StatusOr<bool> SliceConcatForwarding(HloInstruction* slice) {
if (slice->opcode() != HloOpcode::kSlice) {
return false;
}
auto concat = slice->mutable_operand(0);
if (concat->opcode() != HloOpcode::kConcatenate) {
return false;
}

if (slice->shape().rank() != 1) {
// Slice-concat forwarding only works for rank-1 tensors.
return false;
}

int64 concat_dim = concat->concatenate_dimension();

std::vector<HloInstruction*> new_operands;
int64 size_so_far = 0;
int64 slice_size = slice->shape().dimensions(concat_dim);
if (slice_size != slice->slice_limits(0) - slice->slice_starts(0)) {
return false;
}
if (slice->slice_strides(0) != 1) {
return false;
}
for (HloInstruction* operand : concat->operands()) {
if (size_so_far == slice->slice_starts(0) &&
operand->shape().dimensions(0) == slice_size) {
// Found an operand that can be forwarded.
TF_RETURN_IF_ERROR(slice->ReplaceAllUsesWith(operand));
return true;
}
size_so_far += operand->shape().dimensions(concat_dim);
}

return false;
}
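Another hypothetical module (names invented) where the slice's start offset and extent line up exactly with the second concat operand, so the operand is forwarded:

  const char* kSliceConcatHlo = R"(
  HloModule m
  ENTRY e {
    a = s32[1] parameter(0)
    b = s32[1] parameter(1)
    concat = s32[2] concatenate(a, b), dimensions={0}
    ROOT slice = s32[1] slice(concat), slice={[1:2]}
  }
  )";
  // size_so_far reaches 1 == slice_starts(0) at operand b, whose extent matches
  // the slice size, so `slice` is replaced with b.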

// Reshape(Broadcast(A, []->[1]), [1]->[]) ==> A
StatusOr<bool> ReshapeBroadcastForwarding(HloInstruction* reshape) {
if (reshape->opcode() != HloOpcode::kReshape) {
return false;
}
auto broadcast = reshape->mutable_operand(0);
if (broadcast->opcode() != HloOpcode::kBroadcast) {
return false;
}

if (reshape->shape().rank() != 0) {
return false;
}

if (broadcast->shape().rank() != 1) {
return false;
}

if (broadcast->mutable_operand(0)->shape().rank() != 0) {
return false;
}

TF_RETURN_IF_ERROR(
reshape->ReplaceAllUsesWith(broadcast->mutable_operand(0)));

return true;
}

// Reshape(Reshape(A, S->S'), S'->S) ==> A  (e.g. []->[1] followed by [1]->[])
StatusOr<bool> ReshapeReshapeForwarding(HloInstruction* reshape) {
if (reshape->opcode() != HloOpcode::kReshape) {
return false;
}
auto reshape_2 = reshape->mutable_operand(0);
if (reshape_2->opcode() != HloOpcode::kReshape) {
return false;
}

if (!Shape::Equal()(reshape->shape(), reshape_2->operand(0)->shape())) {
return false;
}
TF_RETURN_IF_ERROR(
reshape->ReplaceAllUsesWith(reshape_2->mutable_operand(0)));

return true;
}

// Convert(A, T->T) ==> A
StatusOr<bool> IdentityConvertRemoving(HloInstruction* convert) {
if (convert->opcode() != HloOpcode::kConvert) {
return false;
}
auto operand = convert->mutable_operand(0);
if (Shape::Equal()(convert->shape(), operand->shape())) {
TF_RETURN_IF_ERROR(convert->ReplaceAllUsesWith(operand));
return true;
}
return false;
}

// Reshape(A, S->S) ==> A
StatusOr<bool> IdentityReshapeRemoving(HloInstruction* reshape) {
if (reshape->opcode() != HloOpcode::kReshape) {
return false;
}
auto operand = reshape->mutable_operand(0);
if (Shape::Equal()(reshape->shape(), operand->shape())) {
TF_RETURN_IF_ERROR(reshape->ReplaceAllUsesWith(operand));
return true;
}
return false;
}

} // namespace

StatusOr<bool> DynamicDimensionSimplifier::Run(HloModule* module) {
XLA_VLOG_LINES(
2, "DynamicDimensionSimplifier::Run(), before:\n" + module->ToString());
bool changed = false;

for (auto* comp : module->MakeNonfusionComputations()) {
for (auto* inst : comp->MakeInstructionPostOrder()) {
TF_ASSIGN_OR_RETURN(bool local_changed, ConcatForwarding(inst));
changed |= local_changed;
}
}

for (auto* comp : module->MakeNonfusionComputations()) {
for (auto* inst : comp->MakeInstructionPostOrder()) {
TF_ASSIGN_OR_RETURN(bool local_changed, SliceConcatForwarding(inst));
changed |= local_changed;
}
}

for (auto* comp : module->MakeNonfusionComputations()) {
for (auto* inst : comp->MakeInstructionPostOrder()) {
TF_ASSIGN_OR_RETURN(bool local_changed, ReshapeBroadcastForwarding(inst));
changed |= local_changed;
}
}
for (auto* comp : module->MakeNonfusionComputations()) {
for (auto* inst : comp->MakeInstructionPostOrder()) {
TF_ASSIGN_OR_RETURN(bool local_changed, ReshapeReshapeForwarding(inst));
changed |= local_changed;
}
}
for (auto* comp : module->MakeNonfusionComputations()) {
for (auto* inst : comp->MakeInstructionPostOrder()) {
TF_ASSIGN_OR_RETURN(bool local_changed, IdentityConvertRemoving(inst));
changed |= local_changed;
}
}
for (auto* comp : module->MakeNonfusionComputations()) {
for (auto* inst : comp->MakeInstructionPostOrder()) {
TF_ASSIGN_OR_RETURN(bool local_changed, IdentityReshapeRemoving(inst));
changed |= local_changed;
}
}
XLA_VLOG_LINES(
2, "DynamicDimensionSimplifier::Run(), after:\n" + module->ToString());
return changed;
}
} // namespace xla
37 changes: 37 additions & 0 deletions tensorflow/compiler/xla/service/dynamic_dimension_simplifier.h
@@ -0,0 +1,37 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_DYNAMIC_DIMENSION_SIMPLIFIER_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_DYNAMIC_DIMENSION_SIMPLIFIER_H_

#include <utility>

#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"

namespace xla {

// This pass simplifies operations on dynamic dimension sizes so that they can
// be easily analyzed by later passes.
class DynamicDimensionSimplifier : public HloModulePass {
public:
absl::string_view name() const override {
return "dynamic dimension simplifier";
}

StatusOr<bool> Run(HloModule* module) override;
};
} // namespace xla

#endif // TENSORFLOW_COMPILER_XLA_SERVICE_DYNAMIC_DIMENSION_SIMPLIFIER_H_
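A minimal usage sketch (assuming the standard HloPassPipeline API, which is not part of this commit):

  #include "tensorflow/compiler/xla/service/dynamic_dimension_simplifier.h"
  #include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h"

  namespace xla {

  // Run the simplifier as an ordinary module pass, e.g. ahead of dynamic-padder
  // analyses so that dimension-size expressions are in canonical form.
  StatusOr<bool> SimplifyDynamicDimensionOps(HloModule* module) {
    HloPassPipeline pipeline("dynamic-dimension-simplifier");
    pipeline.AddPass<DynamicDimensionSimplifier>();
    return pipeline.Run(module);
  }

  }  // namespace xla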