Skip to content

Commit

Permalink
[CPU-PSLIB] Add consistency insepection of use_var_list and data_gene…
Browse files Browse the repository at this point in the history
…rator data, test=develop (#34463)
  • Loading branch information
WorgenZhang authored Aug 18, 2021
1 parent 8967a66 commit 209075a
Show file tree
Hide file tree
Showing 2 changed files with 471 additions and 0 deletions.
65 changes: 65 additions & 0 deletions python/paddle/distributed/fleet/dataset/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,71 @@ def _dynamic_adjust_before_train(self, thread_num):
def _dynamic_adjust_after_train(self):
pass

def _check_use_var_with_data_generator(self, var_list, data_generator_class,
test_file):
"""
Var consistency insepection of use_var_list and data_generator data.
Examples:
.. code-block:: python
# required: skiptest
import paddle
from dataset_generator import CTRDataset
dataset = paddle.distributed.fleet.DatasetBase()
generator_class = CTRDataset()
dataset._check_use_var_with_data_generator([data, label], generator_class, "data/part-00000")
Args:
var_list(list): variable list
data_generator_class(class): data_generator class
test_file(str): local test file path
"""

f = open(test_file, "r")
var_len = len(var_list)

while True:
line = f.readline()
if line:
line_iter = data_generator_class.generate_sample(line)
for user_parsed_line in line_iter():
data_gen_len = len(user_parsed_line)
if var_len != data_gen_len:
raise ValueError(
"var length mismatch error: var_list = %s vs data_generator = %s"
% (var_len, data_gen_len))

for i, ele in enumerate(user_parsed_line):
if len(ele[1]) == 0:
raise ValueError(
"var length error: var %s's length in data_generator is 0"
% ele[0])

if var_list[
i].dtype == core.VarDesc.VarType.FP32 and not all(
isinstance(ele, float) for ele in ele[1]):
raise TypeError(
"var dtype mismatch error: var name = %s, var type in var_list = %s, while var in data_generator contains non-float value, which is %s \n"
"Please check if order of var_list and data_generator are aligned. \n"
"Please check if var's type in data_generator is correct."
% (ele[0], "float", ele[1]))

if (var_list[i].dtype == core.VarDesc.VarType.INT64 or
var_list[i].dtype == core.VarDesc.VarType.INT32
) and not all(
isinstance(ele, int) for ele in ele[1]):
raise TypeError(
"var dtype mismatch error: var name = %s, var type in var_list = %s, while var in data_generator contains non-int value, which is %s \n"
"Please check if order of var_list and data_generator are aligned. \n"
"Please check if var's type in data_generator is correct."
% (ele[0], "int", ele[1]))

else:
break

f.close()


class InMemoryDataset(DatasetBase):
"""
Expand Down
Loading

0 comments on commit 209075a

Please sign in to comment.