【Fix PIR Unittest No.276,342,395,377,435,445】Fix gather_nd/tree/histogramdd UT (#65799)

* [PIR] Fix gather_nd/tree/histogram UT

* fix test_gru_rnn_op

* fix linalg

* fix import

* fix import

* fix CMakeList.txt

* fix cmake

* fix TypeError
Aurelius84 authored Jul 11, 2024
1 parent 9b79cde commit 09f6b0b
Showing 12 changed files with 25 additions and 10 deletions.
7 changes: 6 additions & 1 deletion python/paddle/nn/functional/extension.py
@@ -21,7 +21,11 @@
 from paddle import _C_ops, tensor
 from paddle.utils import deprecated
 
-from ...base.data_feeder import check_type, check_variable_and_dtype
+from ...base.data_feeder import (
+    check_dtype,
+    check_type,
+    check_variable_and_dtype,
+)
 from ...base.layer_helper import LayerHelper
 from ...common_ops_import import Variable
 from ...framework import (
@@ -221,6 +225,7 @@ def gather_tree(ids: Tensor, parents: Tensor) -> Tensor:
         raise ValueError("The ids's shape must be the same as parents' shape. ")
 
     if in_dynamic_or_pir_mode():
+        check_dtype(parents.dtype, "parents", ['int32', 'int64'], 'gather_tree')
         return _C_ops.gather_tree(ids, parents)
     else:
         helper = LayerHelper('gather_tree', **locals())
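In dynamic/PIR mode, gather_tree now validates that parents is int32 or int64 before dispatching to the C++ op. A minimal usage sketch (input values adapted from the public gather_tree docstring; int64 tensors satisfy the new check):

```python
import paddle

# ids/parents have shape [max_time, batch_size, beam_size]; under the new
# check_dtype call, parents must be int32 or int64.
ids = paddle.to_tensor(
    [[[2, 2], [6, 1]], [[3, 9], [6, 1]], [[0, 1], [9, 0]]], dtype='int64'
)
parents = paddle.to_tensor(
    [[[0, 0], [1, 1]], [[1, 0], [0, 0]], [[0, 0], [0, 1]]], dtype='int64'
)
final_sequences = paddle.nn.functional.gather_tree(ids, parents)
print(final_sequences.shape)  # [3, 2, 2]
```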
7 changes: 6 additions & 1 deletion python/paddle/tensor/creation.py
@@ -983,7 +983,12 @@ def fill_constant(
             out = _C_ops.full(shape, value, dtype, place)
             out.stop_gradient = True
             return out
-        _C_ops.full_(out, shape, value, dtype, place)
+
+        if out.dtype != dtype:
+            raise TypeError(
+                f"Required out.dtype == dtype if specifying out, but received {out.dtype} != {dtype}"
+            )
+        out = _C_ops.full_(out, shape, value, dtype, place)
         out.stop_gradient = True
         return out
 
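The new branch makes fill_constant reject an out tensor whose dtype disagrees with the requested dtype instead of silently writing into it. A minimal sketch of that guard as a standalone helper (the helper name is hypothetical; fill_constant itself is an internal API):

```python
import paddle

def check_out_dtype(out, dtype):
    # Hypothetical helper mirroring the guard added to fill_constant:
    # when the caller supplies `out`, its dtype must match `dtype`.
    if out.dtype != dtype:
        raise TypeError(
            f"Required out.dtype == dtype if specifying out, "
            f"but received {out.dtype} != {dtype}"
        )
    return out

out = paddle.zeros([2, 3], dtype='float32')
check_out_dtype(out, paddle.float32)   # passes
# check_out_dtype(out, paddle.int64)   # would raise TypeError
```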
3 changes: 2 additions & 1 deletion python/paddle/tensor/linalg.py
@@ -5372,6 +5372,7 @@ def __check_ranges(D, ranges):
             e = paddle.linspace(r[0], r[1], bins[idx] + 1, x.dtype)
             edges.append(e)
             dedges.append(e.diff())
+            hist_shape.append(bins[idx] + 2)
     elif isinstance(
         bins, tuple
     ):  # tuple with D tensors for each innermost dimension
@@ -5380,9 +5381,9 @@ def __check_ranges(D, ranges):
             bin = paddle.to_tensor(bin)
             edges.append(bin)
             dedges.append(bin.diff())
+            hist_shape.append(bin.shape[0] + 1)
     else:
         raise ValueError("Input bins must be Tensor[], int[], or int.")
-    hist_shape = [edge.shape[0] + 1 for edge in edges]
     index_list = []
     # edges shape: [D, linspaced]
     # index_list shape: [D, N]
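With this change, histogramdd builds hist_shape while the edges are constructed in each branch, rather than recomputing it from edge.shape afterwards (which can be unknown statically under PIR). A small usage sketch covering both branches; the inputs and commented shapes are illustrative:

```python
import paddle

x = paddle.to_tensor([[0.0, 1.0], [1.0, 0.0], [2.0, 0.5], [2.5, 2.0]])

# int bins: per-dimension edges are linspaced (first branch above)
hist, edges = paddle.histogramdd(x, bins=3)

# tuple of tensors: explicit bin edges per dimension (second branch above)
hist2, edges2 = paddle.histogramdd(
    x,
    bins=(paddle.to_tensor([0.0, 1.0, 3.0]), paddle.to_tensor([0.0, 2.0, 3.0])),
)
print(hist.shape, hist2.shape)  # expected [3, 3] and [2, 2]
```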
5 changes: 3 additions & 2 deletions python/paddle/tensor/manipulation.py
@@ -5279,6 +5279,7 @@ def gather_nd(x: Tensor, index: Tensor, name: str | None = None) -> Tensor:
     """
     if in_dynamic_or_pir_mode():
+        check_dtype(index.dtype, "index", ['int32', 'int64'], 'gather_nd')
         return _C_ops.gather_nd(x, index)
     else:
         check_variable_and_dtype(
@@ -5294,10 +5295,10 @@ def gather_nd(x: Tensor, index: Tensor, name: str | None = None) -> Tensor:
                 'int32',
                 'int64',
             ],
-            'gather_np',
+            'gather_nd',
         )
         check_variable_and_dtype(
-            index, 'index', ['int32', 'int64'], 'gather_np'
+            index, 'index', ['int32', 'int64'], 'gather_nd'
         )
         helper = LayerHelper('gather_nd', **locals())
         dtype = helper.input_dtype()
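gather_nd now checks the index dtype in dynamic/PIR mode as well, and the op name passed to the static-mode checks is corrected from 'gather_np' to 'gather_nd' so error messages point at the right op. A minimal usage sketch with an int64 index, which the new check accepts:

```python
import paddle

x = paddle.to_tensor(
    [[[1, 2], [3, 4], [5, 6]],
     [[7, 8], [9, 10], [11, 12]]]
)
# index must be int32 or int64 under the added check_dtype call
index = paddle.to_tensor([[0, 1]], dtype='int64')
out = paddle.gather_nd(x, index)  # gathers x[0][1] -> [[3, 4]]
```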
2 changes: 2 additions & 0 deletions python/paddle/tensor/math.py
@@ -6346,6 +6346,8 @@ def _diff_handler(x, n=1, axis=-1, prepend=None, append=None, name=None):
         attrs_2 = ()
 
         dim_len = new_input.shape[axis]
+        if dim_len < 0:
+            dim_len = paddle.shape(new_input)[axis]
 
         starts_1 = [0]
         attrs_1 += ('starts', starts_1)
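When the static shape along axis is unknown (reported as -1 under PIR), the handler now falls back to paddle.shape to get the runtime length before building the two slices it subtracts. A small paddle.diff usage sketch; the commented values are the expected outputs:

```python
import paddle

x = paddle.to_tensor([1, 4, 9, 16])
print(paddle.diff(x))  # first-order difference along the last axis: [3, 5, 7]

m = paddle.to_tensor([[1, 2, 3], [3, 5, 8]])
print(paddle.diff(m, axis=0))  # row-wise difference: [[2, 3, 5]]
```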
5 changes: 1 addition & 4 deletions test/deprecated/legacy_test/CMakeLists.txt
@@ -441,8 +441,7 @@ endif()
 
 # Some ops need to check results when gc is enabled
 # Currently, only ops that register NoNeedBufferVarsInference need to do this test
-set(TEST_OPS_WITH_GC test_gather_nd_op test_slice_op)
-set(TEST_OPS_WITH_GC test_gather_nd_op test_slice_op_deprecated)
+set(TEST_OPS_WITH_GC test_slice_op test_slice_op_deprecated)
 
 foreach(TEST_OP ${TEST_OPS_WITH_GC})
   list(REMOVE_ITEM TEST_OPS ${TEST_OP})
@@ -646,7 +645,6 @@ endif()
 set_tests_properties(test_imperative_selected_rows_to_lod_tensor
                      PROPERTIES TIMEOUT 200)
 set_tests_properties(test_argsort_op_deprecated PROPERTIES TIMEOUT 120)
-set_tests_properties(test_gather_nd_op PROPERTIES TIMEOUT 120)
 set_tests_properties(test_masked_select_op PROPERTIES TIMEOUT 120)
 set_tests_properties(test_sigmoid_cross_entropy_with_logits_op
                      PROPERTIES TIMEOUT 120)
@@ -686,7 +684,6 @@ set(TEST_CINN_OPS
     test_softmax_op
     test_slice_op
     test_slice_op_deprecated
-    test_gather_nd_op
     test_scale_op
     test_layer_norm_op_deprecated
     test_where_op
2 changes: 2 additions & 0 deletions test/legacy_test/CMakeLists.txt
@@ -497,6 +497,7 @@ set(TEST_OPS_WITH_GC
     test_elementwise_add_op
     test_elementwise_sub_op
     test_gather_op
+    test_gather_nd_op
     test_mean_op
     test_lod_reset_op)
 
@@ -772,6 +773,7 @@ set_tests_properties(test_isin PROPERTIES TIMEOUT 30)
 set_tests_properties(test_binomial_op PROPERTIES TIMEOUT 30)
 set_tests_properties(test_run PROPERTIES TIMEOUT 120)
 set_tests_properties(test_sync_batch_norm_op PROPERTIES TIMEOUT 180)
+set_tests_properties(test_gather_nd_op PROPERTIES TIMEOUT 120)
 
 set_tests_properties(test_profiler PROPERTIES TIMEOUT 120)
 set_tests_properties(test_cross_entropy_loss PROPERTIES TIMEOUT 180)
File renamed without changes.
File renamed without changes.
File renamed without changes.
@@ -22,8 +22,10 @@
 import paddle
 from paddle.base import core
 
-sys.path.append("../../rnn")
+sys.path.append("../deprecated/rnn")
 from convert import get_params_for_net
+
+sys.path.append("../rnn")
 from rnn_numpy import GRU
 
 random.seed(2)
File renamed without changes.
