@@ -34,23 +34,33 @@ class FunctionalBatchTransform:
3434 >>> # and modify the weights in weight_dict by dividing each weight by 10.
3535 >>> # Finally, it returns the transformed data, labels, and weights as a tuple.
3636 >>> import ppsci
37- >>> def transform_func(data_dict, label_dict, weight_dict):
38- ... for key in data_dict:
39- ... data_dict[key] = data_dict[key] * 2
40- ... for key in label_dict:
41- ... label_dict[key] = label_dict[key] + 1.0
42- ... for key in weight_dict:
43- ... weight_dict[key] = weight_dict[key] / 10
44- ... return data_dict, label_dict, weight_dict
45- >>> transform = ppsci.data.transform.FunctionalTransform(transform_func)
37+ >>> from typing import Tuple, Dict, Optional
38+ >>> def batch_transform_func(
39+ ... data_list: List[
40+ ... Tuple[Dict[str, np.ndarray], Dict[str, np.ndarray], Optional[Dict[str, np.ndarray]]]
41+ ... ],
42+ ... ) -> List[Tuple[Dict[str, np.ndarray], Dict[str, np.ndarray], Optional[Dict[str, np.ndarray]]]]:
43+ ... input_dicts, label_dicts, weight_dicts = zip(*data_list)
44+ ...
45+ ... for input_dict in input_dicts:
46+ ... for key in input_dict:
47+ ... input_dict[key] = input_dict[key] * 2
48+ ...
49+ ... for label_dict in label_dicts:
50+ ... for key in label_dict:
51+ ... label_dict[key] = label_dict[key] + 1.0
52+ ...
53+ ... return list(zip(input_dicts, label_dicts, weight_dicts))
54+ ...
55+ >>> # Create a FunctionalBatchTransform object with the batch_transform_func function
56+ >>> transform = ppsci.data.batch_transform.FunctionalBatchTransform(batch_transform_func)
4657 >>> # Define some sample data, labels, and weights
47- >>> data = {'feature1': np.array([1, 2, 3]), 'feature2': np.array([4, 5, 6])}
48- >>> label = {'class': 0.0, 'instance': 0.1}
49- >>> weight = {'weight1': 0.5, 'weight2': 0.5}
50- >>> # Apply the transform function to the data, labels, and weights using the FunctionalTransform instance
51- >>> transformed_data = transform(data, label, weight)
52- >>> print(transformed_data)
53- ({'feature1': array([2, 4, 6]), 'feature2': array([ 8, 10, 12])}, {'class': 1.0, 'instance': 1.1}, {'weight1': 0.05, 'weight2': 0.05})
58+ >>> data = [({'x': 1}, {'y': 2}, None), ({'x': 11}, {'y': 22}, None)]
59+ >>> transformed_data = transform(data)
60+ >>> for tuple in transformed_data:
61+ ... print(tuple)
62+ ({'x': 2}, {'y': 3.0}, None)
63+ ({'x': 22}, {'y': 23.0}, None)
5464 """
5565
5666 def __init__ (
@@ -62,6 +72,5 @@ def __init__(
6272 def __call__ (
6373 self ,
6474 list_data : List [List [Dict [str , np .ndarray ]]],
65- # [{'u': arr, 'y': arr}, {'u': arr, 'y': arr}, {'u': arr, 'y': arr}], [{'s': arr}, {'s': arr}, {'s': arr}], [{}, {}, {}]
6675 ) -> List [Dict [str , np .ndarray ]]:
6776 return self .transform_func (list_data )
0 commit comments