[Dy2St] pir dy2st unittest verification - Part 9 #59232

Merged (11 commits) on Nov 29, 2023
25 changes: 12 additions & 13 deletions python/paddle/base/framework.py
@@ -1738,19 +1738,18 @@ def clear_gradient(self):
>>> import numpy as np

>>> x = np.ones([2, 2], np.float32)
>>> with base.dygraph.guard():
... inputs2 = []
... for _ in range(10):
... tmp = base.dygraph.base.to_variable(x)
... tmp.stop_gradient=False
... inputs2.append(tmp)
... ret2 = paddle.add_n(inputs2)
... loss2 = paddle.sum(ret2)
... loss2.retain_grads()
... loss2.backward()
... print(loss2.gradient())
... loss2.clear_gradient()
... print("After clear {}".format(loss2.gradient()))
>>> inputs2 = []
>>> for _ in range(10):
...     tmp = base.dygraph.base.to_variable(x)
...     tmp.stop_gradient = False
...     inputs2.append(tmp)
>>> ret2 = paddle.add_n(inputs2)
>>> loss2 = paddle.sum(ret2)
>>> loss2.retain_grads()
>>> loss2.backward()
>>> print(loss2.gradient())
>>> loss2.clear_gradient()
>>> print("After clear {}".format(loss2.gradient()))
1.0
After clear 0.0
"""
4 changes: 3 additions & 1 deletion python/paddle/jit/dy2static/convert_operators.py
@@ -39,8 +39,10 @@


def convert_attr(x, attr):
Member: The elif doesn't seem necessary at all? It can simply be deleted.

This was presumably here to reconcile dygraph's x.size with the old IR's x.size(), which were inconsistent between dynamic and static graphs. Under PIR the two are unified, so convert_attr and AttributeJstTransformer should both be removable.

I think we can record a TODO(cleanup-legacy-ir) noting that the related convert function and Transformer can be deleted outright once the legacy IR is retired.

Member: Changed directly.

if isinstance(x, (Variable, OpResult)) and attr == "size":
if isinstance(x, Variable) and attr == "size":
return x.size()
elif isinstance(x, OpResult) and attr == "size":
return x.size
else:
return getattr(x, attr)
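
As an aside for readers outside the Paddle codebase, here is a minimal, illustrative sketch of the `size` mismatch discussed in the review thread above; only convert_attr and its module path come from this diff, the tensor and printed values are made up, and the dygraph call goes through the plain getattr fallback:

```python
# Illustrative only: the `size` mismatch that convert_attr bridges.
import paddle
from paddle.jit.dy2static.convert_operators import convert_attr

# In dygraph (and for a PIR OpResult), `size` reads like an attribute.
x = paddle.ones([2, 3])
print(x.size)                   # 6, the number of elements

# Transformed code writes convert_attr(x, "size") instead of x.size:
# it calls x.size() for an old-IR Variable, returns x.size for an
# OpResult, and falls back to getattr for anything else, so the
# dygraph tensor above takes the getattr branch and also prints 6.
print(convert_attr(x, "size"))  # 6
```

Once the legacy IR is retired, both branches collapse to plain attribute access, which is what the TODO(cleanup-legacy-ir) suggestion is about.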

54 changes: 54 additions & 0 deletions python/paddle/pir/math_op_patch.py
@@ -16,6 +16,7 @@
import warnings

from paddle.base.libpaddle import DataType
from paddle.base.wrapped_decorator import wrap_decorator

from . import OpResult

@@ -31,6 +32,21 @@
]


def _fake_interface_only_(func):
Contributor: framework.py already has an identical _fake_interface_only_ function; can that one be used directly?

Member (Author): That would cause a circular import, emmm; it can only be reused once that is untangled.

def __impl__(*args, **kwargs):
raise AssertionError(
f"'{func.__name__}' only can be called by `paddle.Tensor` in dynamic graph mode. Suggestions:\n"
" 1. If you are in static graph mode, you can switch to dynamic graph mode by turning off `paddle.enable_static()` or calling `paddle.disable_static()`.\n"
" 2. If you are using `@paddle.jit.to_static`, you can call `paddle.jit.enable_to_static(False)`. "
f"If you have to translate dynamic graph to static graph, please use other API to replace '{func.__name__}'."
)

return __impl__


fake_interface_only = wrap_decorator(_fake_interface_only_)
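
A minimal sketch of the fake-interface pattern being added here; the simplified decorator and the FakeOpResult class below are hypothetical stand-ins (the real version goes through wrap_decorator), shown only to illustrate the call-time behavior:

```python
# Hypothetical illustration: a fake-interface method fails loudly with
# guidance instead of silently doing nothing in static graph mode.
def fake_interface_only(func):
    def __impl__(*args, **kwargs):
        raise AssertionError(
            f"'{func.__name__}' can only be called by `paddle.Tensor` "
            "in dynamic graph mode."
        )
    return __impl__

class FakeOpResult:  # stand-in for pir.OpResult
    @fake_interface_only
    def clear_gradient(self):
        pass

try:
    FakeOpResult().clear_gradient()
except AssertionError as e:
    print(e)  # tells the user the API is dygraph-only
```

This mirrors the existing fake interfaces on Variable in framework.py; duplicating the decorator here is the workaround for the circular import mentioned in the thread above.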


def create_tensor_with_batchsize(ref_var, value, dtype):
assert isinstance(ref_var, OpResult)
value = float(value)
@@ -356,6 +372,43 @@ def clone(self):
"""
return paddle.assign(self)

@fake_interface_only
def clear_gradient(self):
Member: TODO: once CI passes, ask 震哥 to review the API change.

"""
**Notes**:
**1. This API is ONLY available in Dygraph mode**

**2. Use it only when the OpResult has a gradient; normally we use this for Parameters, since other temporary OpResults will be deleted by Python's GC**

Clear (set to ``0``) the gradient of the current OpResult.

Returns: None

Examples:
.. code-block:: python

>>> import paddle
>>> import paddle.base as base
>>> import numpy as np

>>> x = np.ones([2, 2], np.float32)
>>> inputs2 = []
>>> for _ in range(10):
...     tmp = base.dygraph.base.to_variable(x)
...     tmp.stop_gradient = False
...     inputs2.append(tmp)
>>> ret2 = paddle.add_n(inputs2)
>>> loss2 = paddle.sum(ret2)
>>> loss2.retain_grads()
>>> loss2.backward()
>>> print(loss2.gradient())
>>> loss2.clear_gradient()
>>> print("After clear {}".format(loss2.gradient()))
1.0
After clear 0.0
"""
pass

import paddle

opresult_methods = [
@@ -367,6 +420,7 @@ def clone(self):
('astype', astype),
('size', _size_),
('clone', clone),
('clear_gradient', clear_gradient),
(
'__add__',
_binary_creator_('__add__', paddle.tensor.add, False, _scalar_add_),
121 changes: 60 additions & 61 deletions test/dygraph_to_static/test_mnist.py
@@ -35,8 +35,8 @@

SEED = 2020

if paddle.base.is_compiled_with_cuda():
paddle.base.set_flags({'FLAGS_cudnn_deterministic': True})
if paddle.is_compiled_with_cuda():
paddle.set_flags({'FLAGS_cudnn_deterministic': True})


class SimpleImgConvPool(paddle.nn.Layer):
@@ -135,9 +135,9 @@ def setUp(self):
self.epoch_num = 1
self.batch_size = 64
self.place = (
base.CUDAPlace(0)
if base.is_compiled_with_cuda()
else base.CPUPlace()
paddle.CUDAPlace(0)
if paddle.is_compiled_with_cuda()
else paddle.CPUPlace()
)
self.train_reader = paddle.batch(
paddle.dataset.mnist.train(),
Expand Down Expand Up @@ -177,11 +177,11 @@ def test_mnist_to_static(self):

def test_mnist_declarative_cpu_vs_mkldnn(self):
dygraph_loss_cpu = self.train_dygraph()
base.set_flags({'FLAGS_use_mkldnn': True})
paddle.set_flags({'FLAGS_use_mkldnn': True})
try:
dygraph_loss_mkldnn = self.train_dygraph()
finally:
base.set_flags({'FLAGS_use_mkldnn': False})
paddle.set_flags({'FLAGS_use_mkldnn': False})
np.testing.assert_allclose(
dygraph_loss_cpu,
dygraph_loss_mkldnn,
@@ -193,62 +193,61 @@ def test_mnist_declarative_cpu_vs_mkldnn(self):

def train(self, to_static=False):
loss_data = []
with base.dygraph.guard(self.place):
base.default_main_program().random_seed = SEED
base.default_startup_program().random_seed = SEED
mnist = MNIST()
if to_static:
mnist = paddle.jit.to_static(mnist, full_graph=True)
adam = Adam(learning_rate=0.001, parameters=mnist.parameters())

for epoch in range(self.epoch_num):
start = time()
for batch_id, data in enumerate(self.train_reader()):
dy_x_data = np.array(
[x[0].reshape(1, 28, 28) for x in data]
).astype('float32')
y_data = (
np.array([x[1] for x in data])
.astype('int64')
.reshape(-1, 1)
)

img = to_variable(dy_x_data)
label = to_variable(y_data)

label.stop_gradient = True
prediction, acc, avg_loss = mnist(img, label=label)
avg_loss.backward()
base.default_main_program().random_seed = SEED
base.default_startup_program().random_seed = SEED
mnist = MNIST()
if to_static:
mnist = paddle.jit.to_static(mnist, full_graph=True)
adam = Adam(learning_rate=0.001, parameters=mnist.parameters())

for epoch in range(self.epoch_num):
start = time()
for batch_id, data in enumerate(self.train_reader()):
dy_x_data = np.array(
[x[0].reshape(1, 28, 28) for x in data]
).astype('float32')
y_data = (
np.array([x[1] for x in data])
.astype('int64')
.reshape(-1, 1)
)

adam.minimize(avg_loss)
loss_data.append(float(avg_loss))
# save checkpoint
mnist.clear_gradients()
if batch_id % 10 == 0:
print(
"Loss at epoch {} step {}: loss: {:}, acc: {}, cost: {}".format(
epoch,
batch_id,
avg_loss.numpy(),
acc.numpy(),
time() - start,
)
)
start = time()
if batch_id == 50:
mnist.eval()
prediction, acc, avg_loss = mnist(img, label)
loss_data.append(float(avg_loss))
# new save load check
self.check_jit_save_load(
mnist,
[dy_x_data],
[img, label],
to_static,
prediction,
[img.name],
img = to_variable(dy_x_data)
Contributor: When we run into legacy APIs like to_variable later on, can they be replaced with paddle.to_tensor?

Member: Yes, some have already been cleaned up; this one was probably just overlooked.
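
For reference, a small sketch of the suggested replacement (the array values are arbitrary, and to_variable's import path follows the docstring example earlier in this PR): both calls build a dygraph Tensor, with paddle.to_tensor being the public API.

```python
# Both lines build a dygraph Tensor from a numpy array; paddle.to_tensor
# is the public replacement for the legacy to_variable helper.
import numpy as np
import paddle
from paddle.base.dygraph.base import to_variable  # legacy helper

x = np.ones([2, 2], dtype=np.float32)
a = to_variable(x)       # legacy style still used in this test
b = paddle.to_tensor(x)  # preferred modern equivalent
print(a.shape, b.shape)  # [2, 2] [2, 2]
```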

label = to_variable(y_data)

label.stop_gradient = True
prediction, acc, avg_loss = mnist(img, label=label)
avg_loss.backward()

adam.minimize(avg_loss)
loss_data.append(float(avg_loss))
# save checkpoint
mnist.clear_gradients()
if batch_id % 10 == 0:
print(
"Loss at epoch {} step {}: loss: {:}, acc: {}, cost: {}".format(
epoch,
batch_id,
avg_loss.numpy(),
acc.numpy(),
time() - start,
)
break
)
start = time()
if batch_id == 50:
mnist.eval()
prediction, acc, avg_loss = mnist(img, label)
loss_data.append(float(avg_loss))
# new save load check
self.check_jit_save_load(
mnist,
[dy_x_data],
[img, label],
to_static,
prediction,
[img.name],
)
break
return loss_data

def check_jit_save_load(
19 changes: 9 additions & 10 deletions test/dygraph_to_static/test_tensor_memcpy_on_gpu.py
@@ -21,21 +21,18 @@
import paddle


@paddle.jit.to_static
def tensor_copy_to_cpu(x):
x = paddle.to_tensor(x)
y = x.cpu()
return y


@paddle.jit.to_static
def tensor_copy_to_cuda(x):
x = paddle.to_tensor(x)
y = x.cuda()
return y


@paddle.jit.to_static
def tensor_copy_to_cuda_with_warning(x, device_id=None, blocking=True):
x = paddle.to_tensor(x)
y = x.cuda(device_id, blocking)
@@ -46,7 +43,7 @@ class TestTensorCopyToCpuOnDefaultGPU(Dy2StTestBase):
def _run(self, to_static):
paddle.jit.enable_to_static(to_static)
x1 = paddle.ones([1, 2, 3])
x2 = tensor_copy_to_cpu(x1)
x2 = paddle.jit.to_static(tensor_copy_to_cpu)(x1)
return x1.place, x2.place, x2.numpy()

@test_legacy_and_pir
@@ -73,12 +70,12 @@ class TestTensorCopyToCUDAOnDefaultGPU(Dy2StTestBase):
def _run(self, to_static):
paddle.jit.enable_to_static(to_static)
x1 = paddle.ones([1, 2, 3])
x2 = tensor_copy_to_cuda(x1)
x2 = paddle.jit.to_static(tensor_copy_to_cuda)(x1)
return x1.place, x2.place, x2.numpy()

@test_legacy_and_pir
def test_tensor_cuda_on_default_gpu(self):
if paddle.base.is_compiled_with_cuda():
if paddle.is_compiled_with_cuda():
place = paddle.CUDAPlace(
int(os.environ.get('FLAGS_selected_gpus', 0))
)
@@ -100,7 +97,9 @@ class TestTensorCopyToCUDAWithWarningOnGPU(unittest.TestCase):
def _run(self, to_static):
paddle.jit.enable_to_static(to_static)
x1 = paddle.ones([1, 2, 3])
x2 = tensor_copy_to_cuda_with_warning(x1, device_id=1, blocking=False)
x2 = paddle.jit.to_static(tensor_copy_to_cuda_with_warning)(
x1, device_id=1, blocking=False
)
return x1.place, x2.place, x2.numpy()

def test_with_warning_on_gpu(self):
@@ -114,19 +113,19 @@ def test_with_warning_on_gpu(self):

x1 = paddle.ones([1, 2, 3])
with self.assertWarns(UserWarning, msg="ignored") as cm:
x2 = tensor_copy_to_cuda_with_warning(
x2 = paddle.jit.to_static(tensor_copy_to_cuda_with_warning)(
x1, device_id=1, blocking=True
)
self.assertIn('math_op_patch.py', cm.filename)

with self.assertWarns(UserWarning, msg="ignored") as cm:
x2 = tensor_copy_to_cuda_with_warning(
x2 = paddle.jit.to_static(tensor_copy_to_cuda_with_warning)(
x1, device_id=None, blocking=False
)
self.assertIn('math_op_patch.py', cm.filename)

with self.assertWarns(UserWarning, msg="ignored") as cm:
x2 = tensor_copy_to_cuda_with_warning(
x2 = paddle.jit.to_static(tensor_copy_to_cuda_with_warning)(
x1, device_id=2, blocking=False
)
self.assertIn('math_op_patch.py', cm.filename)