diff --git a/paddle/fluid/eager/nan_inf_utils.cc b/paddle/fluid/eager/nan_inf_utils.cc index d1c5983a3702f..0ed1a198c916d 100644 --- a/paddle/fluid/eager/nan_inf_utils.cc +++ b/paddle/fluid/eager/nan_inf_utils.cc @@ -114,6 +114,7 @@ void CheckTensorHasNanOrInf(const std::string& api_name, const TupleOfTensorAndVector& tensors) { CheckTensorHasNanOrInf(api_name, std::get<0>(tensors)); CheckTensorHasNanOrInf(api_name, std::get<1>(tensors)); + CheckTensorHasNanOrInf(api_name, std::get<2>(tensors)); } } // namespace egr diff --git a/paddle/fluid/eager/nan_inf_utils.h b/paddle/fluid/eager/nan_inf_utils.h index a411504fa4900..815e3bd6cd14f 100644 --- a/paddle/fluid/eager/nan_inf_utils.h +++ b/paddle/fluid/eager/nan_inf_utils.h @@ -31,7 +31,8 @@ using TupleOfFourTensors = std::tuple; using TupleOfFiveTensors = std::tuple; using TupleOfSixTensors = std::tuple; -using TupleOfTensorAndVector = std::tuple>; +using TupleOfTensorAndVector = + std::tuple, std::vector>; void CheckTensorHasNanOrInf(const std::string& api_name, const Tensor& tensor); diff --git a/paddle/fluid/operators/einsum_op.cc b/paddle/fluid/operators/einsum_op.cc index 7fc19d6913f83..95f841f7797b9 100644 --- a/paddle/fluid/operators/einsum_op.cc +++ b/paddle/fluid/operators/einsum_op.cc @@ -41,6 +41,10 @@ class EinsumOpMaker : public framework::OpProtoAndCheckerMaker { .AsExtra() .AsIntermediate(); + AddOutput("XShape", "(Tensor), The cache of the x_shape of: A and B.") + .AsDuplicable() + .AsExtra() + .AsIntermediate(); AddAttr("equation", "(string) A einsum equation. such as `ij,jk->ik`" "There must have `->` and the number of operands in " @@ -59,8 +63,8 @@ class EinsumGradOp : public framework::OperatorWithKernel { void InferShape(framework::InferShapeContext* ctx) const override { auto x_name = "Operands"; auto x_grad_name = framework::GradVarName(x_name); - ctx->SetOutputsDim(x_grad_name, ctx->GetInputsDim(x_name)); - ctx->ShareAllLoD(x_name, x_grad_name); + ctx->SetOutputsDim(x_grad_name, ctx->GetInputsDim("Operands")); + ctx->ShareAllLoD("Operands", x_grad_name); } protected: @@ -79,8 +83,15 @@ class EinsumGradMaker : public framework::SingleGradOpMaker { void Apply(GradOpPtr retv) const override { retv->SetType("einsum_grad"); - retv->SetInput("Operands", this->Input("Operands")); - retv->SetInput("InnerCache", this->Output("InnerCache")); + if (this->HasOutput("InnerCache")) { + retv->SetInput("InnerCache", this->Output("InnerCache")); + } + if (this->HasOutput("XShape")) { + // add if for compatibility. + retv->SetInput("Operands", this->Output("XShape")); // for memory save. + } else { + retv->SetInput("Operands", this->Input("Operands")); + } retv->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out")); retv->SetAttrMap(this->Attrs()); retv->SetOutput(framework::GradVarName("Operands"), diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc index 0beb7223f212a..018a2a6b50f11 100644 --- a/paddle/phi/infermeta/unary.cc +++ b/paddle/phi/infermeta/unary.cc @@ -402,7 +402,8 @@ void EighInferMeta(const MetaTensor& x, void EinsumInferMeta(const std::vector& inputs, const std::string& equation, MetaTensor* out, - std::vector inner_cache) { + std::vector inner_cache, + std::vector xshape) { // collect the following informations to prepare einsum. LabelMap labelshape(0); LabelMap labeltype(LabelType::Reduction); @@ -439,6 +440,12 @@ void EinsumInferMeta(const std::vector& inputs, VLOG(3) << "Label Shape is : " << label_to_string(all_labels, labelshape); out->set_dims(make_ddim(output_dims)); out->set_dtype(inputs[0]->dtype()); + for (size_t i = 0; i < xshape.size(); ++i) { + if (xshape[i] != nullptr) { + xshape[i]->set_dims(inputs[i]->dims()); + xshape[i]->set_dtype(inputs[i]->dtype()); + } + } } void ExpandInferMeta(const MetaTensor& x, diff --git a/paddle/phi/infermeta/unary.h b/paddle/phi/infermeta/unary.h index a288b9371016f..52818e32c3720 100644 --- a/paddle/phi/infermeta/unary.h +++ b/paddle/phi/infermeta/unary.h @@ -83,7 +83,8 @@ void EighInferMeta(const MetaTensor& x, void EinsumInferMeta(const std::vector& inputs, const std::string& equation, MetaTensor* out, - std::vector inner_cache); + std::vector inner_cache, + std::vector xshape); void ExpandInferMeta(const MetaTensor& x, const IntArray& shape, diff --git a/paddle/phi/kernels/einsum_kernel.h b/paddle/phi/kernels/einsum_kernel.h index 87df2b1c64a4a..569cf7a55afd4 100644 --- a/paddle/phi/kernels/einsum_kernel.h +++ b/paddle/phi/kernels/einsum_kernel.h @@ -29,6 +29,7 @@ void EinsumKernelRaw(const Context& dev_ctx, const std::vector& inputs, const std::string& equation, DenseTensor* out, - std::vector cache); + std::vector inner_cache, + std::vector xshape); } // namespace phi diff --git a/paddle/phi/kernels/impl/einsum_grad_impl.h b/paddle/phi/kernels/impl/einsum_grad_impl.h index a72db326807f8..a04185a0c53ed 100644 --- a/paddle/phi/kernels/impl/einsum_grad_impl.h +++ b/paddle/phi/kernels/impl/einsum_grad_impl.h @@ -177,7 +177,6 @@ void EinsumGradKernel(const Context& dev_ctx, cache[0].ShareBufferWith(*(inner_cache[0])); cache[1].ShareBufferWith(*(inner_cache[1])); } - EinsumKernelImpl(dev_ctx, all_labels, operands_for_A, diff --git a/paddle/phi/kernels/impl/einsum_impl.h b/paddle/phi/kernels/impl/einsum_impl.h index f3521c81ce46b..43b2760b404f9 100644 --- a/paddle/phi/kernels/impl/einsum_impl.h +++ b/paddle/phi/kernels/impl/einsum_impl.h @@ -459,7 +459,7 @@ DenseTensor PerformContraction( } // reduction DenseTensor trans_t; - if (FLAGS_einsum_opt && use_cache && cache[operand_idx] != nullptr && + if (use_cache && cache[operand_idx] != nullptr && cache[operand_idx]->IsInitialized()) { trans_t.ShareBufferWith(*(cache[operand_idx])); VLOG(5) << "Cache Used!"; @@ -468,7 +468,7 @@ DenseTensor PerformContraction( dev_ctx, t, perm, all_labels, ellipsis, label2type); trans_t = PerformTranspose( dev_ctx, reduct_t, perm, reordered_all_labels, ellipsis, label2type); - if (FLAGS_einsum_opt && cache[operand_idx] != nullptr) + if (cache[operand_idx] != nullptr) cache[operand_idx]->ShareBufferWith(trans_t); } auto mul_dims = GetShapeByType(all_labels, @@ -599,6 +599,11 @@ void EinsumKernelImpl(const Context& dev_ctx, out); // Reshape Procedure } else if (inputs.size() == 1) { + if (cache[0] != nullptr) { // For compatibility, may be cache is nullptr if + // loading the program from v2.3.0 + (*cache[0]) = *(inputs[0]); // ShareBuffer for backward, because backward + // we can only see cached tensor. + } auto reduce_A = PerformReduction(dev_ctx, *inputs[0], label2perms[0], @@ -627,7 +632,8 @@ void EinsumKernelRaw(const Context& dev_ctx, const std::vector& inputs, const std::string& equation, DenseTensor* out, - std::vector cache) { + std::vector cache, + std::vector xshape) { std::vector tmp; // for the sake of compatibility, we may load and run v2.3 EinsumOp. Output // may have nullptr and the cache.size() is not equal to inputs.size(). refer diff --git a/paddle/phi/ops/compat/einsum_sig.cc b/paddle/phi/ops/compat/einsum_sig.cc index 5e45bcf97ce0e..4fd31c1a2d842 100644 --- a/paddle/phi/ops/compat/einsum_sig.cc +++ b/paddle/phi/ops/compat/einsum_sig.cc @@ -18,7 +18,7 @@ namespace phi { KernelSignature EinsumOpArgumentMapping(const ArgumentMappingContext& ctx) { return KernelSignature( - "einsum", {"Operands"}, {"equation"}, {"Out", "InnerCache"}); + "einsum", {"Operands"}, {"equation"}, {"Out", "InnerCache", "XShape"}); } KernelSignature EinsumGradOpArgumentMapping(const ArgumentMappingContext& ctx) { diff --git a/python/paddle/fluid/tests/unittests/test_einsum_op.py b/python/paddle/fluid/tests/unittests/test_einsum_op.py index c36950b6922fe..e34d04be927cc 100644 --- a/python/paddle/fluid/tests/unittests/test_einsum_op.py +++ b/python/paddle/fluid/tests/unittests/test_einsum_op.py @@ -39,7 +39,9 @@ def setUp(self): 'Out': out, "InnerCache": [('cache_' + str(i), np.array([1.0])) - for i in range(len(self.operands))] + for i in range(len(self.operands))], + "XShape": [('xshape_' + str(i), np.array([1.0])) + for i in range(len(self.operands))], } def init_input(self): @@ -48,14 +50,13 @@ def init_input(self): self.inputs.append(np.random.random(s).astype(t)) def set_mandatory(self): - self.disable = False self.shapes = [(10, 10, 20), (20, 6)] self.types = [np.float64, np.float64] self.equation = "mij,jk->ki" def test_check_output(self): if not self.disable: - self.check_output(no_check_set=["InnerCache"]) + self.check_output(no_check_set=["InnerCache", "XShape"]) def test_grad(self): if not self.disable: diff --git a/python/paddle/tensor/einsum.py b/python/paddle/tensor/einsum.py index 0cdced2cf9b84..34a1ead2cb497 100644 --- a/python/paddle/tensor/einsum.py +++ b/python/paddle/tensor/einsum.py @@ -807,9 +807,9 @@ def gen_einsum_op(equation, *operands): if _in_legacy_dygraph(): # dygraph - return _C_ops.einsum(operands, len(operands), 'equation', equation)[0] + return _C_ops.einsum(operands, len(operands), len(operands), 'equation', + equation)[0] - # static graph for inp in operands: check_variable_and_dtype(inp, 'dtype', ['float32', 'float64'], 'einsum') check_type(equation, 'equation', str, 'einsum') @@ -821,11 +821,16 @@ def gen_einsum_op(equation, *operands): helper.create_variable_for_type_inference(dtype=operands[0].dtype) for i in range(len(operands)) ] + xshape = [ + helper.create_variable_for_type_inference(dtype=operands[0].dtype) + for i in range(len(operands)) + ] helper.append_op(type='einsum', inputs={'Operands': operands}, outputs={ 'Out': out, - "InnerCache": caches + "InnerCache": caches, + "XShape": xshape }, attrs=attrs) return out diff --git a/python/paddle/utils/code_gen/api.yaml b/python/paddle/utils/code_gen/api.yaml index fd000567c507b..27929c593be70 100644 --- a/python/paddle/utils/code_gen/api.yaml +++ b/python/paddle/utils/code_gen/api.yaml @@ -602,7 +602,7 @@ - api : einsum args : (Tensor[] x, str equation) - output : Tensor, Tensor[]{x.size()} + output : Tensor, Tensor[]{x.size()}, Tensor[]{x.size()} infer_meta : func : EinsumInferMeta param : [x, equation] diff --git a/python/paddle/utils/code_gen/backward.yaml b/python/paddle/utils/code_gen/backward.yaml index 81641ac19f7b5..b3fc33961f252 100644 --- a/python/paddle/utils/code_gen/backward.yaml +++ b/python/paddle/utils/code_gen/backward.yaml @@ -1,3 +1,14 @@ +#- backward_api : einsum_grad + + #forward : einsum (Tensor[] x, str equation) -> Tensor(out), Tensor[](inner_cache) + #args : (Tensor[] x, Tensor[] inner_cache, Tensor out_grad, str equation) + #output : Tensor[](x_grad){x.size()} + #infer_meta : + #func : UnchangedMultiInferMeta + #param : [x] + #kernel : + #func : einsum_grad + - backward_api : abs_double_grad forward : abs_grad (Tensor x, Tensor grad_out) -> Tensor(grad_x) args : (Tensor x, Tensor grad_x_grad) @@ -611,12 +622,12 @@ skip_transform : out_w, out_w_grad - backward_api : einsum_grad - forward : einsum (Tensor[] x, str equation) -> Tensor(out), Tensor[](inner_cache) - args : (Tensor[] x, Tensor[] inner_cache, Tensor out_grad, str equation) + forward : einsum (Tensor[] x, str equation) -> Tensor(out), Tensor[](inner_cache), Tensor[](x_shape) + args : (Tensor[] x_shape, Tensor[] inner_cache, Tensor out_grad, str equation) output : Tensor[](x_grad){x.size()} infer_meta : func : UnchangedMultiInferMeta - param : [x] + param : [x_shape] kernel : func : einsum_grad