[PaddlePaddle Hackathon] add paddle.nn.ClipGradByGlobalNorm unit tests (#277)
* submit paddle.nn.PixelShuffle unit test

* submit paddle.nn.PixelShuffle unit test cases

* add test of paddle.nn.ClipGradByGlobalNorm

* add test paddle.nn.ClipGradByNorm

* add test of paddle.nn.PixelShuffle

* add test of paddle.nn.ClipGradByGlobalNorm and paddle.nn.ClipGradByNorm

* remove useless obj and class in test_clip_grad_by_global_norm.py and test_clip_grad_by_norm.py, modify test_pixel_shuffle.py

* add test of paddle.nn.UpsamplingBilinear2D

* remove unused code in test_clip_grad_by_global_norm.py

* remove unused code in test_clip_grad_by_norm.py

* add code annotation in test_clip_grad_by_global_norm.py

* add code annotation in test_clip_grad_by_norm.py

* add code annotation in test_pixel_shuffle.py

* add annotation in test_upsampling_bilinear2D.py

* add paddle.nn.ClipGradByGlobalNorm test case

* add paddle.nn.ClipGradByNorm test case

* add paddle.nn.PixelShuffle test case

* add paddle.nn.UpsamplingBilinear2D test case

* fix bug in test_clip_grad_by_norm.py

* remove 3 test cases

* fix annotation

* refine exception raise code

* add test of paddle.nn.ClipGradByGlobalNorm

* add test case of paddle.nn.ClipGradByGlobalNorm

Co-authored-by: Divano <dddivano@outlook.com>
justld and DDDivano authored Oct 29, 2021
1 parent b90a152 commit ed6d4f1
Showing 1 changed file with 327 additions and 0 deletions.
327 changes: 327 additions & 0 deletions framework/api/nn/test_clip_grad_by_global_norm.py
@@ -0,0 +1,327 @@
#!/bin/env python
# -*- coding: utf-8 -*-
# encoding=utf-8 vi:ts=4:sw=4:expandtab:ft=python
"""
test_clip_grad_by_global_norm
"""

from apibase import randtool, compare
import paddle
import pytest
import numpy as np


def numpy_clip_grad_by_global_norm(test_data, clip_norm):
    """
    Reference implementation of ClipGradByGlobalNorm in numpy:
    global_norm = sqrt(sum of squares of all grads); when global_norm
    exceeds clip_norm, every grad is scaled by clip_norm / global_norm,
    otherwise the grads are returned unchanged.
    """
    clipped_data = []
    grad_data = []
    for _, grad in test_data:
        grad_data.append(grad)
    global_norm = np.sqrt(np.sum(np.square(np.array(grad_data))))
    if global_norm > clip_norm:
        for data, grad in test_data:
            grad = grad * clip_norm / global_norm
            clipped_data.append((data, grad))
    else:
        clipped_data = test_data
    return clipped_data
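

# A minimal worked example (illustrative sketch, not part of the original
# commit): with grads [3.0] and [4.0], global_norm = sqrt(3^2 + 4^2) = 5.0,
# so clip_norm=1.0 scales every grad by 1 / 5.
def _illustrate_global_norm_clipping():
    demo = [(np.zeros(1), np.array([3.0])), (np.zeros(1), np.array([4.0]))]
    clipped = numpy_clip_grad_by_global_norm(demo, clip_norm=1.0)
    assert np.allclose(clipped[0][1], [0.6])  # 3.0 * 1.0 / 5.0
    assert np.allclose(clipped[1][1], [0.8])  # 4.0 * 1.0 / 5.0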


def generate_test_data(length, shape, dtype="float32", value=10):
    """
    Generate `length` pairs of (weight, weight_grad), returned both as
    numpy arrays and as the equivalent paddle tensors.
    """
    tensor_data = []
    numpy_data = []
    np.random.seed(100)  # fixed seed keeps the tests reproducible
    for _ in range(length):
        np_weight = randtool("float", -value, value, shape).astype(dtype)
        np_weight_grad = randtool("float", -value, value, shape).astype(dtype)
        numpy_data.append((np_weight, np_weight_grad))

        tensor_weight = paddle.to_tensor(np_weight)
        tensor_weight_grad = paddle.to_tensor(np_weight_grad)
        tensor_data.append((tensor_weight, tensor_weight_grad))
    return numpy_data, tensor_data
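

# Sanity sketch (illustrative, not part of the original commit): the paired
# returns keep the numpy arrays and paddle tensors value-aligned, so the
# reference and paddle implementations are fed identical inputs.
def _illustrate_paired_data():
    np_pairs, tensor_pairs = generate_test_data(length=2, shape=[3])
    for (np_w, np_g), (t_w, t_g) in zip(np_pairs, tensor_pairs):
        assert np.allclose(np_w, t_w.numpy())
        assert np.allclose(np_g, t_g.numpy())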


@pytest.mark.api_nn_ClipGradByGlobalNorm_vartype
def test_clip_grad_by_global_norm_base():
    """
    Test base.
    Test base config:
        input grad shape = [10, 10]
        input grad number = 4
        input data dtype = 'float32'
        clip_norm = 1.0
        value range: [-10, 10]
    Expected Results:
        The numpy and paddle implementations of ClipGradByGlobalNorm should produce equal outputs.
    """
    shape = [10, 10]
    length = 4
    clip_norm = 1.0
    dtype = "float32"
    np_data, paddle_data = generate_test_data(length, shape, dtype=dtype, value=10)
    np_res = numpy_clip_grad_by_global_norm(np_data, clip_norm=clip_norm)

    paddle_clip = paddle.nn.ClipGradByGlobalNorm(clip_norm=clip_norm)
    paddle_clipped_data = paddle_clip(paddle_data)
    paddle_res = []
    for w, g in paddle_clipped_data:
        paddle_res.append((w.numpy(), g.numpy()))

    # compare grad values computed by numpy and paddle
    for res, p_res in zip(np_res, paddle_res):
        compare(res[1], p_res[1])


@pytest.mark.api_nn_ClipGradByGlobalNorm_parameters
def test_clip_grad_by_global_norm1():
    """
    Test ClipGradByGlobalNorm when the input shape changes.
    Test base config:
        input grad shape = [10, 10]
        input grad number = 4
        input data dtype = 'float32'
        clip_norm = 1.0
        value range: [-10, 10]
    Changes:
        input grad shape: [10, 10] -> [9, 13, 11]
    Expected Results:
        The numpy and paddle implementations of ClipGradByGlobalNorm should produce equal outputs.
    """
    shape = [9, 13, 11]
    length = 4
    clip_norm = 1.0
    dtype = "float32"
    np_data, paddle_data = generate_test_data(length, shape, dtype=dtype, value=10)
    np_res = numpy_clip_grad_by_global_norm(np_data, clip_norm=clip_norm)

    paddle_clip = paddle.nn.ClipGradByGlobalNorm(clip_norm=clip_norm)
    paddle_clipped_data = paddle_clip(paddle_data)
    paddle_res = []
    for w, g in paddle_clipped_data:
        paddle_res.append((w.numpy(), g.numpy()))

    # compare grad values computed by numpy and paddle
    for res, p_res in zip(np_res, paddle_res):
        compare(res[1], p_res[1])


@pytest.mark.api_nn_ClipGradByGlobalNorm_parameters
def test_clip_grad_by_global_norm2():
    """
    Test ClipGradByGlobalNorm when the input shape changes.
    Test base config:
        input grad shape = [10, 10]
        input grad number = 4
        input data dtype = 'float32'
        clip_norm = 1.0
        value range: [-10, 10]
    Changes:
        input grad shape: [10, 10] -> [10]
    Expected Results:
        The numpy and paddle implementations of ClipGradByGlobalNorm should produce equal outputs.
    """
    shape = [10]
    length = 4
    clip_norm = 1.0
    dtype = "float32"
    np_data, paddle_data = generate_test_data(length, shape, dtype=dtype, value=10)
    np_res = numpy_clip_grad_by_global_norm(np_data, clip_norm=clip_norm)

    paddle_clip = paddle.nn.ClipGradByGlobalNorm(clip_norm=clip_norm)
    paddle_clipped_data = paddle_clip(paddle_data)
    paddle_res = []
    for w, g in paddle_clipped_data:
        paddle_res.append((w.numpy(), g.numpy()))

    # compare grad values computed by numpy and paddle
    for res, p_res in zip(np_res, paddle_res):
        compare(res[1], p_res[1])


@pytest.mark.api_nn_ClipGradByGlobalNorm_parameters
def test_clip_grad_by_global_norm3():
    """
    Test ClipGradByGlobalNorm when clip_norm changes.
    Test base config:
        input grad shape = [10, 10]
        input grad number = 4
        input data dtype = 'float32'
        clip_norm = 1.0
        value range: [-10, 10]
    Changes:
        clip_norm: 1.0 -> -1.0
    Expected Results:
        The numpy and paddle implementations of ClipGradByGlobalNorm should produce equal outputs.
    """
    shape = [10, 10]
    length = 4
    clip_norm = -1.0
    dtype = "float32"
    np_data, paddle_data = generate_test_data(length, shape, dtype=dtype, value=10)
    np_res = numpy_clip_grad_by_global_norm(np_data, clip_norm=clip_norm)

    paddle_clip = paddle.nn.ClipGradByGlobalNorm(clip_norm=clip_norm)
    paddle_clipped_data = paddle_clip(paddle_data)
    paddle_res = []
    for w, g in paddle_clipped_data:
        paddle_res.append((w.numpy(), g.numpy()))

    # compare grad values computed by numpy and paddle
    for res, p_res in zip(np_res, paddle_res):
        compare(res[1], p_res[1])


@pytest.mark.api_nn_ClipGradByGlobalNorm_parameters
def test_clip_grad_by_global_norm4():
    """
    Test ClipGradByGlobalNorm when group_name is set.
    Test base config:
        input grad shape = [10, 10]
        input grad number = 4
        input data dtype = 'float32'
        clip_norm = 1.0
        value range: [-10, 10]
    Changes:
        group_name: 'test_group'
    Expected Results:
        The numpy and paddle implementations of ClipGradByGlobalNorm should produce equal outputs.
    """
    shape = [10, 10]
    length = 4
    clip_norm = 1.0
    dtype = "float32"
    np_data, paddle_data = generate_test_data(length, shape, dtype=dtype, value=10)
    np_res = numpy_clip_grad_by_global_norm(np_data, clip_norm=clip_norm)

    paddle_clip = paddle.nn.ClipGradByGlobalNorm(clip_norm=clip_norm, group_name="test_group")
    paddle_clipped_data = paddle_clip(paddle_data)
    paddle_res = []
    for w, g in paddle_clipped_data:
        paddle_res.append((w.numpy(), g.numpy()))

    # compare grad values computed by numpy and paddle
    for res, p_res in zip(np_res, paddle_res):
        compare(res[1], p_res[1])


@pytest.mark.api_nn_ClipGradByGlobalNorm_parameters
def test_clip_grad_by_global_norm5():
    """
    Test ClipGradByGlobalNorm when the value range changes.
    Test base config:
        input grad shape = [10, 10]
        input grad number = 4
        input data dtype = 'float32'
        clip_norm = 1.0
        value range: [-10, 10]
    Changes:
        value range: [-10, 10] -> [-255555, 255555]
    Expected Results:
        The numpy and paddle implementations of ClipGradByGlobalNorm should produce equal outputs.
    """
    shape = [10, 10]
    length = 4
    clip_norm = 1.0
    dtype = "float32"
    np_data, paddle_data = generate_test_data(length, shape, dtype=dtype, value=255555)
    np_res = numpy_clip_grad_by_global_norm(np_data, clip_norm=clip_norm)

    paddle_clip = paddle.nn.ClipGradByGlobalNorm(clip_norm=clip_norm)
    paddle_clipped_data = paddle_clip(paddle_data)
    paddle_res = []
    for w, g in paddle_clipped_data:
        paddle_res.append((w.numpy(), g.numpy()))

    # compare grad values computed by numpy and paddle
    for res, p_res in zip(np_res, paddle_res):
        compare(res[1], p_res[1])


@pytest.mark.api_nn_ClipGradByGlobalNorm_vartype
def test_clip_grad_by_global_norm6():
    """
    Test ClipGradByGlobalNorm when the input data dtype changes.
    Test base config:
        input grad shape = [10, 10]
        input grad number = 4
        input data dtype = 'float32'
        clip_norm = 1.0
        value range: [-10, 10]
    Changes:
        input data dtype: float32 -> float64
    Expected Results:
        The numpy and paddle implementations of ClipGradByGlobalNorm should produce equal outputs.
    """
    shape = [10, 10]
    length = 4
    clip_norm = 1.0
    dtype = "float64"
    np_data, paddle_data = generate_test_data(length, shape, dtype=dtype, value=10)
    np_res = numpy_clip_grad_by_global_norm(np_data, clip_norm=clip_norm)

    paddle_clip = paddle.nn.ClipGradByGlobalNorm(clip_norm=clip_norm)
    paddle_clipped_data = paddle_clip(paddle_data)
    paddle_res = []
    for w, g in paddle_clipped_data:
        paddle_res.append((w.numpy(), g.numpy()))

    # compare grad values computed by numpy and paddle
    for res, p_res in zip(np_res, paddle_res):
        compare(res[1], p_res[1])


@pytest.mark.api_nn_ClipGradByGlobalNorm_vartype
def test_clip_grad_by_global_norm7():
    """
    Test ClipGradByGlobalNorm when the input data dtype changes.
    Test base config:
        input grad shape = [10, 10]
        input grad number = 4
        input data dtype = 'float32'
        clip_norm = 1.0
        value range: [-10, 10]
    Changes:
        input data dtype: float32 -> ['int8', 'int16', 'int32', 'float16']
    Expected Results:
        paddle.nn.ClipGradByGlobalNorm cannot accept input data of these dtypes and raises RuntimeError.
    """
    shape = [10, 10]
    length = 4
    clip_norm = 1.0
    unsupported_dtypes = ["int8", "int16", "int32", "float16"]
    paddle_clip = paddle.nn.ClipGradByGlobalNorm(clip_norm=clip_norm)
    for dtype in unsupported_dtypes:
        _, paddle_data = generate_test_data(length, shape, dtype=dtype, value=10)
        # each unsupported dtype is expected to raise RuntimeError
        with pytest.raises(RuntimeError):
            paddle_clip(paddle_data)
