Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Cherry pick] Einsum memory optimization PR #43397 #43554

Merged
merged 2 commits into from
Jun 20, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions paddle/fluid/eager/nan_inf_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ void CheckTensorHasNanOrInf(const std::string& api_name,
const TupleOfTensorAndVector& tensors) {
CheckTensorHasNanOrInf(api_name, std::get<0>(tensors));
CheckTensorHasNanOrInf(api_name, std::get<1>(tensors));
CheckTensorHasNanOrInf(api_name, std::get<2>(tensors));
}

} // namespace egr
3 changes: 2 additions & 1 deletion paddle/fluid/eager/nan_inf_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@ using TupleOfFourTensors = std::tuple<Tensor, Tensor, Tensor, Tensor>;
using TupleOfFiveTensors = std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor>;
using TupleOfSixTensors =
std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor, Tensor>;
using TupleOfTensorAndVector = std::tuple<Tensor, std::vector<Tensor>>;
using TupleOfTensorAndVector =
std::tuple<Tensor, std::vector<Tensor>, std::vector<Tensor>>;

void CheckTensorHasNanOrInf(const std::string& api_name, const Tensor& tensor);

Expand Down
19 changes: 15 additions & 4 deletions paddle/fluid/operators/einsum_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,10 @@ class EinsumOpMaker : public framework::OpProtoAndCheckerMaker {
.AsExtra()
.AsIntermediate();

AddOutput("XShape", "(Tensor), The cache of the x_shape of: A and B.")
.AsDuplicable()
.AsExtra()
.AsIntermediate();
AddAttr<std::string>("equation",
"(string) A einsum equation. such as `ij,jk->ik`"
"There must have `->` and the number of operands in "
Expand All @@ -58,8 +62,8 @@ class EinsumGradOp : public framework::OperatorWithKernel {
void InferShape(framework::InferShapeContext* ctx) const override {
auto x_name = "Operands";
auto x_grad_name = framework::GradVarName(x_name);
ctx->SetOutputsDim(x_grad_name, ctx->GetInputsDim(x_name));
ctx->ShareAllLoD(x_name, x_grad_name);
ctx->SetOutputsDim(x_grad_name, ctx->GetInputsDim("Operands"));
ctx->ShareAllLoD("Operands", x_grad_name);
}

protected:
Expand All @@ -78,8 +82,15 @@ class EinsumGradMaker : public framework::SingleGradOpMaker<T> {

void Apply(GradOpPtr<T> retv) const override {
retv->SetType("einsum_grad");
retv->SetInput("Operands", this->Input("Operands"));
retv->SetInput("InnerCache", this->Output("InnerCache"));
if (this->HasOutput("InnerCache")) {
retv->SetInput("InnerCache", this->Output("InnerCache"));
}
if (this->HasOutput("XShape")) {
// add if for compatibility.
retv->SetInput("Operands", this->Output("XShape")); // for memory save.
} else {
retv->SetInput("Operands", this->Input("Operands"));
}
retv->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
retv->SetAttrMap(this->Attrs());
retv->SetOutput(framework::GradVarName("Operands"),
Expand Down
9 changes: 8 additions & 1 deletion paddle/phi/infermeta/unary.cc
Original file line number Diff line number Diff line change
Expand Up @@ -402,7 +402,8 @@ void EighInferMeta(const MetaTensor& x,
void EinsumInferMeta(const std::vector<const MetaTensor*>& inputs,
const std::string& equation,
MetaTensor* out,
std::vector<MetaTensor*> inner_cache) {
std::vector<MetaTensor*> inner_cache,
std::vector<MetaTensor*> xshape) {
// collect the following informations to prepare einsum.
LabelMap labelshape(0);
LabelMap labeltype(LabelType::Reduction);
Expand Down Expand Up @@ -439,6 +440,12 @@ void EinsumInferMeta(const std::vector<const MetaTensor*>& inputs,
VLOG(3) << "Label Shape is : " << label_to_string(all_labels, labelshape);
out->set_dims(make_ddim(output_dims));
out->set_dtype(inputs[0]->dtype());
for (size_t i = 0; i < xshape.size(); ++i) {
if (xshape[i] != nullptr) {
xshape[i]->set_dims(inputs[i]->dims());
xshape[i]->set_dtype(inputs[i]->dtype());
}
}
}

void ExpandInferMeta(const MetaTensor& x,
Expand Down
3 changes: 2 additions & 1 deletion paddle/phi/infermeta/unary.h
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,8 @@ void EighInferMeta(const MetaTensor& x,
void EinsumInferMeta(const std::vector<const MetaTensor*>& inputs,
const std::string& equation,
MetaTensor* out,
std::vector<MetaTensor*> inner_cache);
std::vector<MetaTensor*> inner_cache,
std::vector<MetaTensor*> xshape);

void ExpandInferMeta(const MetaTensor& x,
const IntArray& shape,
Expand Down
3 changes: 2 additions & 1 deletion paddle/phi/kernels/einsum_kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ void EinsumKernelRaw(const Context& dev_ctx,
const std::vector<const DenseTensor*>& inputs,
const std::string& equation,
DenseTensor* out,
std::vector<DenseTensor*> cache);
std::vector<DenseTensor*> inner_cache,
std::vector<DenseTensor*> xshape);

} // namespace phi
1 change: 0 additions & 1 deletion paddle/phi/kernels/impl/einsum_grad_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,6 @@ void EinsumGradKernel(const Context& dev_ctx,
cache[0].ShareBufferWith(*(inner_cache[0]));
cache[1].ShareBufferWith(*(inner_cache[1]));
}

EinsumKernelImpl<T, Context>(dev_ctx,
all_labels,
operands_for_A,
Expand Down
12 changes: 9 additions & 3 deletions paddle/phi/kernels/impl/einsum_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -458,7 +458,7 @@ DenseTensor PerformContraction(
}
// reduction
DenseTensor trans_t;
if (FLAGS_einsum_opt && use_cache && cache[operand_idx] != nullptr &&
if (use_cache && cache[operand_idx] != nullptr &&
cache[operand_idx]->IsInitialized()) {
trans_t.ShareBufferWith(*(cache[operand_idx]));
VLOG(5) << "Cache Used!";
Expand All @@ -467,7 +467,7 @@ DenseTensor PerformContraction(
dev_ctx, t, perm, all_labels, ellipsis, label2type);
trans_t = PerformTranspose<T, Context>(
dev_ctx, reduct_t, perm, reordered_all_labels, ellipsis, label2type);
if (FLAGS_einsum_opt && cache[operand_idx] != nullptr)
if (cache[operand_idx] != nullptr)
cache[operand_idx]->ShareBufferWith(trans_t);
}
auto mul_dims = GetShapeByType<int>(all_labels,
Expand Down Expand Up @@ -598,6 +598,11 @@ void EinsumKernelImpl(const Context& dev_ctx,
out);
// Reshape Procedure
} else if (inputs.size() == 1) {
if (cache[0] != nullptr) { // For compatibility, may be cache is nullptr if
// loading the program from v2.3.0
(*cache[0]) = *(inputs[0]); // ShareBuffer for backward, because backward
// we can only see cached tensor.
}
auto reduce_A = PerformReduction<T, Context>(dev_ctx,
*inputs[0],
label2perms[0],
Expand Down Expand Up @@ -626,7 +631,8 @@ void EinsumKernelRaw(const Context& dev_ctx,
const std::vector<const DenseTensor*>& inputs,
const std::string& equation,
DenseTensor* out,
std::vector<DenseTensor*> cache) {
std::vector<DenseTensor*> cache,
std::vector<DenseTensor*> xshape) {
std::vector<char> tmp;
// for the sake of compatibility, we may load and run v2.3 EinsumOp. Output
// may have nullptr and the cache.size() is not equal to inputs.size(). refer
Expand Down
2 changes: 1 addition & 1 deletion paddle/phi/ops/compat/einsum_sig.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ namespace phi {

KernelSignature EinsumOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature(
"einsum", {"Operands"}, {"equation"}, {"Out", "InnerCache"});
"einsum", {"Operands"}, {"equation"}, {"Out", "InnerCache", "XShape"});
}

KernelSignature EinsumGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
Expand Down
7 changes: 4 additions & 3 deletions python/paddle/fluid/tests/unittests/test_einsum_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,9 @@ def setUp(self):
self.outputs = {
'Out': out,
"InnerCache": [('cache_' + str(i), np.array([1.0]))
for i in range(len(self.operands))]
for i in range(len(self.operands))],
"XShape": [('xshape_' + str(i), np.array([1.0]))
for i in range(len(self.operands))],
}

def init_input(self):
Expand All @@ -46,14 +48,13 @@ def init_input(self):
self.inputs.append(np.random.random(s).astype(t))

def set_mandatory(self):
self.disable = False
self.shapes = [(10, 10, 20), (20, 6)]
self.types = [np.float64, np.float64]
self.equation = "mij,jk->ki"

def test_check_output(self):
if not self.disable:
self.check_output(no_check_set=["InnerCache"])
self.check_output(no_check_set=["InnerCache", "XShape"])

def test_grad(self):
if not self.disable:
Expand Down
12 changes: 9 additions & 3 deletions python/paddle/tensor/einsum.py
Original file line number Diff line number Diff line change
Expand Up @@ -802,9 +802,10 @@ def gen_einsum_op(equation, *operands):

if _in_legacy_dygraph():
# dygraph
return _C_ops.einsum(operands, len(operands), 'equation', equation)[0]
return _C_ops.einsum(operands,
len(operands), len(operands), 'equation',
equation)[0]

# static graph
for inp in operands:
check_variable_and_dtype(inp, 'dtype', ['float32', 'float64'], 'einsum')
check_type(equation, 'equation', str, 'einsum')
Expand All @@ -816,11 +817,16 @@ def gen_einsum_op(equation, *operands):
helper.create_variable_for_type_inference(dtype=operands[0].dtype)
for i in range(len(operands))
]
xshape = [
helper.create_variable_for_type_inference(dtype=operands[0].dtype)
for i in range(len(operands))
]
helper.append_op(
type='einsum',
inputs={'Operands': operands},
outputs={'Out': out,
"InnerCache": caches},
"InnerCache": caches,
"XShape": xshape},
attrs=attrs)
return out

Expand Down
2 changes: 1 addition & 1 deletion python/paddle/utils/code_gen/api.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -547,7 +547,7 @@

- api : einsum
args : (Tensor[] x, str equation)
output : Tensor, Tensor[]{x.size()}
output : Tensor, Tensor[]{x.size()}, Tensor[]{x.size()}
infer_meta :
func : EinsumInferMeta
param : [x, equation]
Expand Down
20 changes: 16 additions & 4 deletions python/paddle/utils/code_gen/backward.yaml
Original file line number Diff line number Diff line change
@@ -1,3 +1,15 @@
- backward_api : abs_double_grad
forward : abs_grad (Tensor x, Tensor grad_out) -> Tensor(grad_x)
args : (Tensor x, Tensor grad_x_grad)
output : Tensor(grad_out_grad)
infer_meta :
func : UnchangedInferMeta
param : [x]
kernel :
func : abs_double_grad
data_transform:
skip_transform : grad_x_grad

- backward_api : abs_grad
forward : abs (Tensor x) -> Tensor(out)
args : (Tensor x, Tensor out_grad)
Expand Down Expand Up @@ -447,12 +459,12 @@
skip_transform : out_w, out_w_grad

- backward_api : einsum_grad
forward : einsum (Tensor[] x, str equation) -> Tensor(out), Tensor[](inner_cache)
args : (Tensor[] x, Tensor[] inner_cache, Tensor out_grad, str equation)
output : Tensor[](x_grad){x.size()}
forward : einsum (Tensor[] x, str equation) -> Tensor(out), Tensor[](inner_cache), Tensor[](x_shape)
args : (Tensor[] x_shape, Tensor[] inner_cache, Tensor out_grad, str equation)
output : Tensor[](x_grad){x_shape.size()}
infer_meta :
func : UnchangedMultiInferMeta
param : [x]
param : [x_shape]
kernel :
func : einsum_grad

Expand Down