Skip to content

Commit 97de9e9

Browse files
update example code of FunctionalBatchTransform
1 parent cfdb3f3 commit 97de9e9

File tree

1 file changed

+26
-17
lines changed

1 file changed

+26
-17
lines changed

ppsci/data/process/batch_transform/preprocess.py

Lines changed: 26 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -34,23 +34,33 @@ class FunctionalBatchTransform:
3434
>>> # and modify the weights in weight_dict by dividing each weight by 10.
3535
>>> # Finally, it returns the transformed data, labels, and weights as a tuple.
3636
>>> import ppsci
37-
>>> def transform_func(data_dict, label_dict, weight_dict):
38-
... for key in data_dict:
39-
... data_dict[key] = data_dict[key] * 2
40-
... for key in label_dict:
41-
... label_dict[key] = label_dict[key] + 1.0
42-
... for key in weight_dict:
43-
... weight_dict[key] = weight_dict[key] / 10
44-
... return data_dict, label_dict, weight_dict
45-
>>> transform = ppsci.data.transform.FunctionalTransform(transform_func)
37+
>>> from typing import Tuple, Dict, Optional
38+
>>> def batch_transform_func(
39+
... data_list: List[
40+
... Tuple[Dict[str, np.ndarray], Dict[str, np.ndarray], Optional[Dict[str, np.ndarray]]]
41+
... ],
42+
... ) -> List[Tuple[Dict[str, np.ndarray], Dict[str, np.ndarray], Optional[Dict[str, np.ndarray]]]]:
43+
... input_dicts, label_dicts, weight_dicts = zip(*data_list)
44+
...
45+
... for input_dict in input_dicts:
46+
... for key in input_dict:
47+
... input_dict[key] = input_dict[key] * 2
48+
...
49+
... for label_dict in label_dicts:
50+
... for key in label_dict:
51+
... label_dict[key] = label_dict[key] + 1.0
52+
...
53+
... return list(zip(input_dicts, label_dicts, weight_dicts))
54+
...
55+
>>> # Create a FunctionalBatchTransform object with the batch_transform_func function
56+
>>> transform = ppsci.data.batch_transform.FunctionalBatchTransform(batch_transform_func)
4657
>>> # Define some sample data, labels, and weights
47-
>>> data = {'feature1': np.array([1, 2, 3]), 'feature2': np.array([4, 5, 6])}
48-
>>> label = {'class': 0.0, 'instance': 0.1}
49-
>>> weight = {'weight1': 0.5, 'weight2': 0.5}
50-
>>> # Apply the transform function to the data, labels, and weights using the FunctionalTransform instance
51-
>>> transformed_data = transform(data, label, weight)
52-
>>> print(transformed_data)
53-
({'feature1': array([2, 4, 6]), 'feature2': array([ 8, 10, 12])}, {'class': 1.0, 'instance': 1.1}, {'weight1': 0.05, 'weight2': 0.05})
58+
>>> data = [({'x': 1}, {'y': 2}, None), ({'x': 11}, {'y': 22}, None)]
59+
>>> transformed_data = transform(data)
60+
>>> for tuple in transformed_data:
61+
... print(tuple)
62+
({'x': 2}, {'y': 3.0}, None)
63+
({'x': 22}, {'y': 23.0}, None)
5464
"""
5565

5666
def __init__(
@@ -62,6 +72,5 @@ def __init__(
6272
def __call__(
6373
self,
6474
list_data: List[List[Dict[str, np.ndarray]]],
65-
# [{'u': arr, 'y': arr}, {'u': arr, 'y': arr}, {'u': arr, 'y': arr}], [{'s': arr}, {'s': arr}, {'s': arr}], [{}, {}, {}]
6675
) -> List[Dict[str, np.ndarray]]:
6776
return self.transform_func(list_data)

0 commit comments

Comments
 (0)