forked from tensorflow/tensorflow
-
Notifications
You must be signed in to change notification settings - Fork 97
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Implement implicit broadcast for binary operation of dynamic shapes.
- This cl instructs the dynamic padder to insert implicit broadcasts into the graph when a binary operation is performed on two dynamic tensors. - Optimization #1: The implicit broadcast is only inserted when we can't prove two dynamic dimensions are the same. - Optimization #2: Added a simplification pass that allows us to simplify operations on dynamic dimensions; this opens up more opportunities for optimization #1. PiperOrigin-RevId: 355539597 Change-Id: I7753550a6057155c3f436c6b51b356cb48c945e6
- Loading branch information
1 parent
60f2d95
commit ab8d6cc
Showing
6 changed files
with
628 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
214 changes: 214 additions & 0 deletions
214
tensorflow/compiler/xla/service/dynamic_dimension_simplifier.cc
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,214 @@ | ||
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. | ||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
http://www.apache.org/licenses/LICENSE-2.0 | ||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. | ||
==============================================================================*/ | ||
|
||
#include "tensorflow/compiler/xla/service/dynamic_dimension_simplifier.h" | ||
|
||
#include "tensorflow/compiler/xla/service/hlo_instruction.h" | ||
#include "tensorflow/compiler/xla/service/hlo_opcode.h" | ||
#include "tensorflow/compiler/xla/status_macros.h" | ||
|
||
namespace xla { | ||
namespace { | ||
|
||
// Concat(Concat(A, B), C) => Concat(A, B, C)
StatusOr<bool> ConcatForwarding(HloInstruction* concat) {
  if (concat->opcode() != HloOpcode::kConcatenate) {
    return false;
  }
  const int64 dim = concat->concatenate_dimension();

  // Collect operands, splicing in the operands of any nested concatenate
  // that runs along the same dimension.
  std::vector<HloInstruction*> flattened;
  bool found_nested = false;
  for (HloInstruction* operand : concat->operands()) {
    const bool same_dim_concat =
        operand->opcode() == HloOpcode::kConcatenate &&
        operand->concatenate_dimension() == dim;
    if (same_dim_concat) {
      found_nested = true;
      flattened.insert(flattened.end(), operand->operands().begin(),
                       operand->operands().end());
    } else {
      flattened.push_back(operand);
    }
  }
  if (!found_nested) {
    return false;
  }
  auto computation = concat->parent();
  auto replacement = computation->AddInstruction(
      HloInstruction::CreateConcatenate(concat->shape(), flattened, dim));
  TF_RETURN_IF_ERROR(computation->ReplaceInstruction(concat, replacement));
  return true;
}
|
||
// Slice(Concat(A1, A2, ..., An, ...), [n:n+1]) => An
//
// If a slice covers exactly one operand of a concatenate (same start offset,
// same size, stride 1), the slice is redundant and can be replaced by that
// operand directly.
StatusOr<bool> SliceConcatForwarding(HloInstruction* slice) {
  if (slice->opcode() != HloOpcode::kSlice) {
    return false;
  }
  auto concat = slice->mutable_operand(0);
  if (concat->opcode() != HloOpcode::kConcatenate) {
    return false;
  }

  if (slice->shape().rank() != 1) {
    // Slice-concat forwarding only works for rank-1 (dimension size) tensors.
    return false;
  }

  int64 concat_dim = concat->concatenate_dimension();

  int64 size_so_far = 0;
  int64 slice_size = slice->shape().dimensions(concat_dim);
  if (slice_size != slice->slice_limits(0) - slice->slice_starts(0)) {
    return false;
  }
  if (slice->slice_strides(0) != 1) {
    return false;
  }
  // Walk the concat operands, accumulating their sizes until one starts
  // exactly where the slice starts and has exactly the slice's size.
  for (HloInstruction* operand : concat->operands()) {
    if (size_so_far == slice->slice_starts(0) &&
        operand->shape().dimensions(concat_dim) == slice_size) {
      // Found an operand that can be forwarded.
      TF_RETURN_IF_ERROR(slice->ReplaceAllUsesWith(operand));
      return true;
    }
    size_so_far += operand->shape().dimensions(concat_dim);
  }

  return false;
}
|
||
// Reshape(Broadcast(A, []->[1]), [1]->[]) ==> A
StatusOr<bool> ReshapeBroadcastForwarding(HloInstruction* reshape) {
  if (reshape->opcode() != HloOpcode::kReshape) {
    return false;
  }
  auto broadcast = reshape->mutable_operand(0);
  if (broadcast->opcode() != HloOpcode::kBroadcast) {
    return false;
  }

  // Only the scalar -> [1] -> scalar round trip is handled: the reshape
  // result must be rank 0, the broadcast rank 1, and the broadcast input
  // rank 0.
  const bool matches = reshape->shape().rank() == 0 &&
                       broadcast->shape().rank() == 1 &&
                       broadcast->operand(0)->shape().rank() == 0;
  if (!matches) {
    return false;
  }

  TF_RETURN_IF_ERROR(
      reshape->ReplaceAllUsesWith(broadcast->mutable_operand(0)));

  return true;
}
|
||
// Reshape(Reshape(A, []->[1]), [1]->[]) ==> A
StatusOr<bool> ReshapeReshapeForwarding(HloInstruction* reshape) {
  if (reshape->opcode() != HloOpcode::kReshape) {
    return false;
  }
  auto inner_reshape = reshape->mutable_operand(0);
  if (inner_reshape->opcode() != HloOpcode::kReshape) {
    return false;
  }

  // The two reshapes cancel only when the outer result has exactly the
  // shape the inner reshape started from.
  if (!Shape::Equal()(reshape->shape(), inner_reshape->operand(0)->shape())) {
    return false;
  }
  TF_RETURN_IF_ERROR(
      reshape->ReplaceAllUsesWith(inner_reshape->mutable_operand(0)));

  return true;
}
|
||
// Convert(A, T->T) ==> A
StatusOr<bool> IdentityConvertRemoving(HloInstruction* convert) {
  if (convert->opcode() != HloOpcode::kConvert) {
    return false;
  }
  auto input = convert->mutable_operand(0);
  // A convert whose output shape (including element type) equals its input
  // shape is a no-op; bypass it.
  if (!Shape::Equal()(convert->shape(), input->shape())) {
    return false;
  }
  TF_RETURN_IF_ERROR(convert->ReplaceAllUsesWith(input));
  return true;
}
|
||
// Reshape(A, S->S) ==> A
StatusOr<bool> IdentityReshapeRemoving(HloInstruction* reshape) {
  if (reshape->opcode() != HloOpcode::kReshape) {
    return false;
  }
  auto input = reshape->mutable_operand(0);
  // A reshape whose output shape equals its input shape is a no-op;
  // bypass it.
  if (!Shape::Equal()(reshape->shape(), input->shape())) {
    return false;
  }
  TF_RETURN_IF_ERROR(reshape->ReplaceAllUsesWith(input));
  return true;
}
|
||
} // namespace | ||
|
||
// Runs each simplification over every instruction of every non-fusion
// computation. Each simplification sweeps the whole module before the next
// one starts, so later patterns can match the results of earlier rewrites.
// Returns true if any instruction was changed.
StatusOr<bool> DynamicDimensionSimplifier::Run(HloModule* module) {
  XLA_VLOG_LINES(
      2, "DynamicDimensionSimplifier::Run(), before:\n" + module->ToString());
  bool changed = false;

  // Table of the simplification passes, in application order. This replaces
  // six copies of the identical traversal loop.
  using Simplification = StatusOr<bool> (*)(HloInstruction*);
  constexpr Simplification kSimplifications[] = {
      ConcatForwarding,           SliceConcatForwarding,
      ReshapeBroadcastForwarding, ReshapeReshapeForwarding,
      IdentityConvertRemoving,    IdentityReshapeRemoving,
  };

  for (Simplification simplify : kSimplifications) {
    for (auto* comp : module->MakeNonfusionComputations()) {
      for (auto* inst : comp->MakeInstructionPostOrder()) {
        TF_ASSIGN_OR_RETURN(bool local_changed, simplify(inst));
        changed |= local_changed;
      }
    }
  }

  XLA_VLOG_LINES(
      2, "DynamicDimensionSimplifier::Run(), after:\n" + module->ToString());
  return changed;
}
} // namespace xla |
37 changes: 37 additions & 0 deletions
37
tensorflow/compiler/xla/service/dynamic_dimension_simplifier.h
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. | ||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
http://www.apache.org/licenses/LICENSE-2.0 | ||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. | ||
==============================================================================*/ | ||
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_DYNAMIC_DIMENSION_SIMPLIFIER_H_ | ||
#define TENSORFLOW_COMPILER_XLA_SERVICE_DYNAMIC_DIMENSION_SIMPLIFIER_H_ | ||
|
||
#include <utility> | ||
|
||
#include "tensorflow/compiler/xla/service/hlo_module.h" | ||
#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" | ||
|
||
namespace xla { | ||
|
||
// This pass simplifies operations on dynamic dimension sizes so that it can be
// easily analyzed by later passes.
//
// The simplifications are local pattern rewrites (e.g. flattening nested
// concatenates, cancelling inverse reshapes, removing identity converts)
// that do not change module semantics.
class DynamicDimensionSimplifier : public HloModulePass {
 public:
  // Human-readable pass name, used in pass-pipeline logging.
  absl::string_view name() const override {
    return "dynamic dimension simplifier";
  }

  // Runs the simplifier over all non-fusion computations in `module`.
  // Returns true if any instruction was changed.
  StatusOr<bool> Run(HloModule* module) override;
};
} // namespace xla | ||
|
||
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_DYNAMIC_DIMENSION_SIMPLIFIER_H_ |
Oops, something went wrong.