Skip to content

Commit

Permalink
Polish UT
Browse files Browse the repository at this point in the history
  • Loading branch information
0x45f committed Aug 29, 2022
1 parent c39566f commit 75a2451
Showing 1 changed file with 6 additions and 6 deletions.
12 changes: 6 additions & 6 deletions python/paddle/fluid/tests/unittests/test_arg_min_max_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,7 @@ def test_static(self):
fc = paddle.nn.Linear(4, 10)
x = paddle.randn([2, 3, 4])
x.stop_gradient = False
feat = fc(x) # [2,3,10]
feat = fc(x)

out = self.call_func(feat)

Expand Down Expand Up @@ -293,8 +293,8 @@ def test_static(self):
fc = paddle.nn.Linear(4, 10)
x = paddle.randn([2, 3, 4])
x.stop_gradient = False
feat = fc(x) # [2,3,10]

feat = fc(x)
feat = paddle.cast(feat, 'int32')
out = self.call_func(feat)

sgd = paddle.optimizer.SGD()
Expand All @@ -307,19 +307,19 @@ def test_static(self):
paddle.static.save_inference_model(self.save_path, [x], [feat, out],
exe)
gt = np.argmin(res[0], 1)
np.testing.assert_allclose(res[1], gt)
np.testing.assert_allclose(np.squeeze(res[1]), gt)

# Test for Inference Predictor
infer_outs = self.infer_prog()
gt = np.argmin(infer_outs[0], 1)
np.testing.assert_allclose(infer_outs[1], gt)
np.testing.assert_allclose(np.squeeze(infer_outs[1]), gt)

def path_prefix(self):
return 'argmin_tensor_axis'

def call_func(self, x):
axis = paddle.assign(1)
out = paddle.argmin(x, axis)
out = paddle.argmin(x, axis, keepdim=True)
return out


Expand Down

0 comments on commit 75a2451

Please sign in to comment.