Supplement several interfaces of static Variable to be consistent with dygraph Tensor (PaddlePaddle#33330)

As the title
CtfGo authored and Aurelius84 committed Jul 26, 2021
1 parent dbc54d2 commit ecec416
Showing 9 changed files with 364 additions and 26 deletions.
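In short, the commit brings the static-graph `Variable` closer to the dygraph `Tensor` API: `detach()` and `size()` are implemented on `Variable`, `ndim`/`dim()`/`ndimension()` and the `@` (matmul) operator are monkey-patched in, and a new `share_data` op backs `detach()`. A minimal sketch of the new surface (shapes and names are illustrative):

```python
import paddle

paddle.enable_static()

x = paddle.static.data(name='x', shape=[3, 2, 1], dtype='float32')
y = paddle.static.data(name='y', shape=[3, 1, 2], dtype='float32')

d = x.detach()                  # new: appends a share_data op, no copy
n = x.size()                    # new: int64 Variable holding 3 * 2 * 1 = 6
print(x.ndim)                   # new: 3, same as len(x.shape)
print(x.dim(), x.ndimension())  # new: both return 3
z = x @ y                       # new: __matmul__ lowers to matmul_v2
```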
72 changes: 72 additions & 0 deletions paddle/fluid/operators/share_data_op.cc
@@ -0,0 +1,72 @@
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/share_data_op.h"
#include "paddle/fluid/framework/op_registry.h"

namespace paddle {
namespace operators {

class ShareDataOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext *ctx) const override {
    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "ShareData");
    OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "ShareData");
    auto in_type = ctx->GetInputsVarType("X")[0];
    auto out_type = ctx->GetOutputsVarType("Out")[0];

    PADDLE_ENFORCE_EQ(
        in_type == framework::proto::VarType::LOD_TENSOR ||
            in_type == framework::proto::VarType::SELECTED_ROWS,
        true,
        platform::errors::InvalidArgument(
            "Type of Variable[X] must be LoDTensor or SelectedRows!"));
    PADDLE_ENFORCE_EQ(
        in_type, out_type,
        platform::errors::InvalidArgument(
            "The type of input (X) and output (Out) are inconsistent."));

    ctx->ShareDim("X", "Out");
  }
};

class ShareDataOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "(Tensor), The input tensor of share_data op");
    AddOutput("Out", "(Tensor), The output tensor of share_data op");
    AddComment(R"DOC(
ShareData Operator.

Returns a tensor $Out$ that shares data with the input tensor $X$, without copying.
)DOC");
  }
};

} // namespace operators
} // namespace paddle

namespace ops = paddle::operators;
REGISTER_OPERATOR(
    share_data, ops::ShareDataOp, ops::ShareDataOpMaker,
    paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
    paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OP_CPU_KERNEL(share_data, ops::ShareDataKernel<bool>,
                       ops::ShareDataKernel<int>, ops::ShareDataKernel<int8_t>,
                       ops::ShareDataKernel<uint8_t>,
                       ops::ShareDataKernel<paddle::platform::float16>,
                       ops::ShareDataKernel<int64_t>,
                       ops::ShareDataKernel<float>,
                       ops::ShareDataKernel<double>)
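Note that both grad-op makers are registered empty, so no backward op is ever generated for `share_data`; on the Python side the detached Variable is additionally created with `stop_gradient=True`. A quick way to observe this (a sketch; names are illustrative):

```python
import paddle

paddle.enable_static()

x = paddle.static.data(name='x', shape=[3, 2, 1], dtype='float32')
x.stop_gradient = False

y = x.detach()
print(y.stop_gradient)  # True: gradients never flow through share_data
```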
25 changes: 25 additions & 0 deletions paddle/fluid/operators/share_data_op.cu
@@ -0,0 +1,25 @@
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/share_data_op.h"

REGISTER_OP_CUDA_KERNEL(
    share_data, paddle::operators::ShareDataKernel<bool>,
    paddle::operators::ShareDataKernel<int>,
    paddle::operators::ShareDataKernel<int8_t>,
    paddle::operators::ShareDataKernel<uint8_t>,
    paddle::operators::ShareDataKernel<paddle::platform::float16>,
    paddle::operators::ShareDataKernel<int64_t>,
    paddle::operators::ShareDataKernel<float>,
    paddle::operators::ShareDataKernel<double>);
41 changes: 41 additions & 0 deletions paddle/fluid/operators/share_data_op.h
@@ -0,0 +1,41 @@
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include "paddle/fluid/framework/op_registry.h"

namespace paddle {
namespace operators {

template <typename T>
class ShareDataKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &ctx) const override {
    auto *in_var = ctx.InputVar("X");
    auto *out_var = ctx.OutputVar("Out");
    if (in_var->IsType<framework::LoDTensor>()) {
      // Dense input: alias the LoDTensor's storage directly.
      const auto &origin_tensor = in_var->Get<framework::LoDTensor>();
      auto *detach_tensor = out_var->GetMutable<framework::LoDTensor>();
      detach_tensor->ShareDataWith(origin_tensor);
    } else {
      // SelectedRows input: alias only the underlying value tensor.
      const auto &origin_selected_rows = in_var->Get<framework::SelectedRows>();
      auto *detach_selected_rows =
          out_var->GetMutable<framework::SelectedRows>();
      detach_selected_rows->mutable_value()->ShareDataWith(
          origin_selected_rows.value());
    }
  }
};
} // namespace operators
} // namespace paddle
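Both branches call `ShareDataWith`, so the output aliases the input's storage instead of copying it. From Python, the visible effect is that a fetched detached Variable carries exactly the input's values (a sketch, assuming a CPU place; names are illustrative):

```python
import numpy as np
import paddle

paddle.enable_static()

x = paddle.static.data(name='x', shape=[3, 2, 1], dtype='float32')
y = x.detach()  # lowered to the share_data op above

exe = paddle.static.Executor(paddle.CPUPlace())
data = np.random.rand(3, 2, 1).astype('float32')
out, = exe.run(paddle.static.default_main_program(),
               feed={'x': data},
               fetch_list=[y])

assert np.array_equal(out, data)  # same values, produced without a copy
```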
69 changes: 53 additions & 16 deletions python/paddle/fluid/framework.py
@@ -942,35 +942,43 @@ def __init__(self,
self._stop_gradient = stop_gradient
self.is_data = is_data

def detach(self):
    """
    Returns a new Variable, detached from the current graph.
    It shares data with the original Variable without a tensor copy.
    In addition, the detached Variable doesn't provide gradient propagation.

    Returns:
        ( :ref:`api_guide_Variable_en` | dtype is same as current Variable): The detached Variable.

    Examples:
        .. code-block:: python

            import paddle

            paddle.enable_static()

            # create a static Variable
            x = paddle.static.data(name='x', shape=[3, 2, 1])

            # create a detached Variable
            y = x.detach()
    """

    assert self.type == core.VarDesc.VarType.SELECTED_ROWS or \
        self.type == core.VarDesc.VarType.LOD_TENSOR, \
        "only support a variable with SELECTED_ROWS or LOD_TENSOR to be detached"

    output = self.block.create_var(
        name=unique_name.generate_with_ignorable_key("detach_" + self.name),
        dtype=self.dtype,
        type=self.type,
        persistable=self.persistable,
        stop_gradient=True)

    self.block.append_op(
        type='share_data', inputs={'X': [self]}, outputs={'Out': [output]})
    return output
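Only `LOD_TENSOR` and `SELECTED_ROWS` variables can be detached; any other variable type trips the assertion at graph-construction time. A sketch of the failure path (the `STEP_SCOPES` variable here is constructed purely for illustration):

```python
import paddle
from paddle.fluid import core

paddle.enable_static()

block = paddle.static.default_main_program().global_block()
v = block.create_var(name='scopes', type=core.VarDesc.VarType.STEP_SCOPES)

try:
    v.detach()
except AssertionError as e:
    print(e)  # only support a variable with SELECTED_ROWS or LOD_TENSOR ...
```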

@fake_interface_only
def numpy(self):
@@ -1805,6 +1813,35 @@ def set_value(self, value, scope=None):

t.set(value, place)

def size(self):
    """
    Returns the number of elements of the current Variable, as an int64
    Variable with shape [1].

    Returns:
        Variable: the number of elements of the current Variable

    Examples:
        .. code-block:: python

            import paddle

            paddle.enable_static()

            # create a static Variable
            x = paddle.static.data(name='x', shape=[3, 2, 1])

            # get the number of elements of the Variable
            y = x.size()
    """

    output = self.block.create_var(
        name=unique_name.generate_with_ignorable_key(self.name + "_size"),
        dtype=core.VarDesc.VarType.INT64)

    self.block.append_op(
        type='size', inputs={'Input': [self]}, outputs={'Out': [output]})
    return output
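Because `size()` is implemented as an operator rather than a Python-side computation, the count is only materialized when the program runs (a sketch, assuming a CPU executor):

```python
import numpy as np
import paddle

paddle.enable_static()

x = paddle.static.data(name='x', shape=[3, 2, 1], dtype='float32')
n = x.size()  # int64 Variable produced by the `size` op

exe = paddle.static.Executor(paddle.CPUPlace())
n_val, = exe.run(feed={'x': np.zeros((3, 2, 1), dtype='float32')},
                 fetch_list=[n])
print(n_val)  # 6, i.e. 3 * 2 * 1
```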


def get_all_op_protos():
"""
34 changes: 31 additions & 3 deletions python/paddle/fluid/layers/math_op_patch.py
@@ -47,6 +47,7 @@
"__rpow__": "A **= B",
"__floordiv__": "A //B",
"__mod__": "A % B",
"__matmul__": "A @ B",
"__eq__": "A == B",
"__ne__": "A != B",
"__lt__": "A < B",
@@ -197,6 +198,28 @@ def _scalar_op_(var, scale, bias):
def _neg_(var):
return _scalar_op_(var, -1.0, 0.0)

@property
def _ndim_(self):
    """
    Returns the number of dimensions of the current Variable.

    Returns:
        the number of dimensions

    Examples:
        .. code-block:: python

            import paddle

            paddle.enable_static()

            # create a static Variable
            x = paddle.static.data(name='x', shape=[3, 2, 1])

            # print the number of dimensions of the Variable
            print(x.ndim)
    """
    return len(self.shape)

def _scalar_add_(var, value):
return _scalar_op_(var, 1.0, value)

@@ -233,9 +256,9 @@ def __impl__(self, other_var):
other_var = float(other_var)
# division is a special case
# NOTE(chenweihang): because we cast the tensor to float32 instead of float64,
# the division result can only guarantee numerical accuracy to 6 digits
# after the decimal point. The result of the numpy calculation is of float64
# type, so the result computed here and the numpy result differ after the
# 6th decimal place. If necessary, we can also use float64 here.
# torch's behavior here is consistent with ours
if op_type == 'elementwise_div' and self.dtype in _supported_int_dtype_:
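A concrete consequence of the NOTE above: dividing integer Variables produces a float32 result, because both operands are cast before `elementwise_div` is appended (a sketch):

```python
import paddle

paddle.enable_static()

a = paddle.static.data(name='a', shape=[2], dtype='int32')
b = paddle.static.data(name='b', shape=[2], dtype='int32')

c = a / b
print(c.dtype)  # float32 (VarType.FP32): the pre-division cast described above
```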
@@ -323,6 +346,9 @@ def __impl__(self, other_var):
# b=-a
('__neg__', _neg_),
('astype', astype),
('dim', lambda x: len(x.shape)),
('ndimension', lambda x: len(x.shape)),
('ndim', _ndim_),
('__add__', _binary_creator_('__add__', 'elementwise_add', False,
_scalar_add_)),
# a+b == b+a. Do not need to reverse explicitly
@@ -353,6 +379,8 @@ def __impl__(self, other_var):
'elementwise_floordiv', False, None)),
('__mod__', _binary_creator_('__mod__', 'elementwise_mod', False,
None)),
('__matmul__', _binary_creator_('__matmul__', "matmul_v2", False,
None)),
# for logical compare
('__eq__', _binary_creator_('__eq__', 'equal', False, None)),
('__ne__', _binary_creator_('__ne__', 'not_equal', False, None)),
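With `__matmul__` registered, the Python `@` operator on static Variables dispatches to the `matmul_v2` operator (a sketch; shapes are illustrative):

```python
import paddle

paddle.enable_static()

a = paddle.static.data(name='a', shape=[4, 3], dtype='float32')
b = paddle.static.data(name='b', shape=[3, 5], dtype='float32')

c = a @ b  # appends a matmul_v2 op to the current block
print(c.shape)  # (4, 5)
```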
6 changes: 0 additions & 6 deletions python/paddle/fluid/tests/unittests/test_detach.py
@@ -149,12 +149,6 @@ def test_NoDetachSingle_DetachMulti(self):
array_detach_multi = self.detach_multi()
assert np.array_equal(array_no_detach_single, array_detach_multi)

    # Removed by this commit: detach() is now supported on static Variables,
    # so it no longer raises AssertionError.
    def test_detach_exception(self):
        x = fluid.layers.data(name="a", shape=[3, 4], dtype='float32')
        y = fluid.layers.fc(input=x, size=10, bias_attr=True)
        with self.assertRaises(AssertionError):
            y_detach = y.detach()


class TestInplace(unittest.TestCase):
def test_forward_version(self):