Skip to content

Commit

Permalink
Merge pull request #8 from feifei-111/zzz1
Browse files Browse the repository at this point in the history
fix tests
  • Loading branch information
2742195759 authored Jun 16, 2023
2 parents 8b8623a + 0778ade commit 84f5645
Show file tree
Hide file tree
Showing 7 changed files with 32 additions and 7 deletions.
6 changes: 3 additions & 3 deletions test/dygraph_to_static/seq2seq_dygraph_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,7 +297,7 @@ def forward(self, inputs):
loss = paddle.nn.functional.softmax_with_cross_entropy(
logits=dec_output, label=label, soft_label=False
)
loss = paddle.squeeze(loss, axes=[2])
loss = paddle.squeeze(loss, axis=[2])
max_tar_seq_len = paddle.shape(tar)[1]
tar_mask = paddle.static.nn.sequence_lod.sequence_mask(
tar_sequence_length, maxlen=max_tar_seq_len, dtype='float32'
Expand Down Expand Up @@ -835,13 +835,13 @@ def forward(self, inputs):
loss = paddle.nn.functional.softmax_with_cross_entropy(
logits=dec_output, label=label, soft_label=False
)
loss = paddle.squeeze(loss, axes=[2])
loss = paddle.squeeze(loss, axis=[2])
max_tar_seq_len = paddle.shape(tar)[1]
tar_mask = paddle.static.nn.sequence_lod.sequence_mask(
tar_sequence_length, maxlen=max_tar_seq_len, dtype='float32'
)
loss = loss * tar_mask
loss = paddle.mean(loss, axis=[0])
loss = fluid.layers.reduce_sum(loss)
loss = paddle.sum(loss)

return loss
2 changes: 2 additions & 0 deletions test/dygraph_to_static/test_se_resnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import unittest

import numpy as np
from dygraph_to_static_util import ast_only_test
from predictor_utils import PredictorTools

import paddle
Expand Down Expand Up @@ -561,6 +562,7 @@ def verify_predict(self):
),
)

@ast_only_test
def test_check_result(self):
pred_1, loss_1, acc1_1, acc5_1 = self.train(
self.train_reader, to_static=False
Expand Down
5 changes: 4 additions & 1 deletion test/dygraph_to_static/test_spec_names.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@

import unittest

from dygraph_to_static_util import enable_fallback_guard

import paddle
from paddle.nn import Layer

Expand Down Expand Up @@ -101,4 +103,5 @@ def to_idx(name):


if __name__ == '__main__':
unittest.main()
with enable_fallback_guard("False"):
unittest.main()
14 changes: 14 additions & 0 deletions test/dygraph_to_static/test_to_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import unittest

import numpy
from dygraph_to_static_util import ast_only_test, sot_only_test

import paddle
from paddle.fluid import core
Expand Down Expand Up @@ -148,6 +149,7 @@ def test_to_tensor_badreturn(self):
self.assertTrue(a.stop_gradient == b.stop_gradient)
self.assertTrue(a.place._equals(b.place))

@ast_only_test
def test_to_tensor_err_log(self):
paddle.disable_static()
x = paddle.to_tensor([3])
Expand All @@ -159,6 +161,18 @@ def test_to_tensor_err_log(self):
in str(e)
)

@sot_only_test
def test_to_tensor_err_log_sot(self):
paddle.disable_static()
x = paddle.to_tensor([3])
try:
a = paddle.jit.to_static(case8)(x)
except Exception as e:
self.assertTrue(
"Can't constructs a 'paddle.Tensor' with data type <class 'dict'>"
in str(e)
)


class TestStatic(unittest.TestCase):
def test_static(self):
Expand Down
4 changes: 3 additions & 1 deletion test/dygraph_to_static/test_train_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from functools import partial

import numpy as np
from dygraph_to_static_util import enable_fallback_guard

import paddle

Expand Down Expand Up @@ -433,4 +434,5 @@ def setUp(self):


if __name__ == "__main__":
unittest.main()
with enable_fallback_guard("False"):
unittest.main()
4 changes: 3 additions & 1 deletion test/dygraph_to_static/test_train_step_resnet18_adam.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import platform
import unittest

from dygraph_to_static_util import enable_fallback_guard
from test_train_step import (
TestTrainStepTinyModel,
loss_fn_tiny_model,
Expand All @@ -40,4 +41,5 @@ def setUp(self):


if __name__ == "__main__":
unittest.main()
with enable_fallback_guard("False"):
unittest.main()
4 changes: 3 additions & 1 deletion test/dygraph_to_static/test_train_step_resnet18_sgd.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import platform
import unittest

from dygraph_to_static_util import enable_fallback_guard
from test_train_step import (
TestTrainStepTinyModel,
loss_fn_tiny_model,
Expand All @@ -40,4 +41,5 @@ def setUp(self):


if __name__ == "__main__":
unittest.main()
with enable_fallback_guard("False"):
unittest.main()

0 comments on commit 84f5645

Please sign in to comment.