-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdataset_generator_test.py
108 lines (91 loc) · 4.81 KB
/
dataset_generator_test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
import tempfile
import pytest
from sklearn import datasets, model_selection
import const
import poisoning
from . import dataset_generator as dataset_generator
# Shared synthetic data for every test in this module. A fixed seed keeps the
# generated samples and the train/test split identical across runs, so the
# tests below are reproducible instead of depending on fresh random data.
_RANDOM_STATE = 0
X, y = datasets.make_classification(random_state=_RANDOM_STATE)
X_train, X_test, y_train, y_test = model_selection.train_test_split(
    X, y, random_state=_RANDOM_STATE)
def get_dg(*, poisoning_generation_input: poisoning.PoisoningGenerationInput,
           generate: bool = True) -> dataset_generator.DatasetGenerator:
    """Build a DatasetGenerator over the module-level train/test split.

    Args:
        poisoning_generation_input: configuration forwarded unchanged to
            ``DatasetGenerator.from_dataset_to_poison``.
        generate: when true (the default), ``generate()`` is called on the
            new generator before it is returned.

    Returns:
        The constructed generator, generated unless ``generate`` is false.
    """
    splits = dict(X_train=X_train, y_train=y_train, X_test=X_test, y_test=y_test)
    generator = dataset_generator.DatasetGenerator.from_dataset_to_poison(
        poisoning_generation_input=poisoning_generation_input, **splits)
    if not generate:
        return generator
    generator.generate()
    return generator
# TODO: add a case using poisoning.SelectionInfoLabelMonoDirectional with
# selection_info_kwargs={'from_label': 1, 'to_label': 0}.
@pytest.mark.parametrize('poisoning_generation_input', [
    poisoning.PoisoningGenerationInput(
        perc_data_points=[10, 15, 20],
        performer=poisoning.PerformerLabelFlippingMonoDirectional(),
        selector=poisoning.SelectorRandom(),
        perform_info_kwargs={'from_label': 1, 'to_label': 0},
        perform_info_clazz=poisoning.PerformInfoMonoDirectional,
        selection_info_clazz=poisoning.SelectionInfoEmpty,
        selection_info_kwargs={},
    ),
])
def test_dataset_generator_no_override(poisoning_generation_input: poisoning.PoisoningGenerationInput):
    """Check the datasets held by a freshly generated DatasetGenerator.

    Expects one dataset variable per entry of ``perc_data_points``. Each one
    must have one row per training sample, contain the feature columns plus
    the label and poisoned-flag columns, and carry attrs holding the perform
    info of the poisoning algorithm paired with it.
    """
    # Column names are needed for the coordinate check below. The
    # parametrized object is mutated in place; only this test uses it.
    poisoning_generation_input.columns = [f'{i}' for i in range(X.shape[1])]
    dg = get_dg(poisoning_generation_input=poisoning_generation_input)
    assert len(dg.all_datasets) == len(poisoning_generation_input.perc_data_points)

    dataset_names = list(dg.all_datasets.data_vars)
    poisoning_algos = list(dg.poisoning_algos)
    # zip() stops at the shorter input, so a missing algorithm (or dataset)
    # would silently skip the per-dataset checks; pin the lengths first.
    assert len(poisoning_algos) == len(dataset_names)

    expected_col = set(poisoning_generation_input.columns) | {const.COORD_POISONED, const.COORD_LABEL}
    for name, poisoning_algo in zip(dataset_names, poisoning_algos):
        data_array = dg.all_datasets[name]
        # +2: the label column and the column flagging poisoned rows.
        assert data_array.shape == (X_train.shape[0], X.shape[1] + 2)
        # Column names are stored on the coordinate named 'y'.
        assert set(data_array.coords['y'].values) == expected_col
        # Hard-coded expectation: attrs contain only the perform info.
        assert data_array.attrs == {const.KEY_ATTR_POISONED: poisoning_algo.perform_info.get_info_as_dict()}
@pytest.mark.parametrize('poisoning_generation_input_pre, poisoning_generation_input_post', [
    (
        poisoning.PoisoningGenerationInput(
            perc_data_points=[10, 15, 20],
            performer=poisoning.PerformerLabelFlippingMonoDirectional(),
            selector=poisoning.SelectorRandom(),
            perform_info_kwargs={'from_label': 1, 'to_label': 0},
            perform_info_clazz=poisoning.PerformInfoMonoDirectional,
            selection_info_clazz=poisoning.SelectionInfoEmpty,
            selection_info_kwargs={},
        ),
        poisoning.PoisoningGenerationInput(
            perc_data_points=[10, 15],
            performer=poisoning.PerformerLabelFlippingMonoDirectional(),
            selector=poisoning.SelectorRandom(),
            perform_info_kwargs={'from_label': 1, 'to_label': 0},
            perform_info_clazz=poisoning.PerformInfoMonoDirectional,
            selection_info_clazz=poisoning.SelectionInfoEmpty,
            selection_info_kwargs={},
        ),
    ),
])
def test_import_with_smaller(poisoning_generation_input_pre: poisoning.PoisoningGenerationInput,
                             poisoning_generation_input_post: poisoning.PoisoningGenerationInput):
    """Export a generator, then re-import it with an input requesting fewer datasets.

    The re-imported generator must hold exactly as many datasets as the
    post input's ``generate_from_sequence()`` yields points, and fewer than
    the generator that was exported.
    """
    # Both inputs need column names; give each its own list.
    poisoning_generation_input_pre.columns = [f'{i}' for i in range(X.shape[1])]
    poisoning_generation_input_post.columns = [f'{i}' for i in range(X.shape[1])]
    dg = get_dg(poisoning_generation_input=poisoning_generation_input_pre)
    with tempfile.TemporaryDirectory() as tmp_dir:
        dg.export(base_directory=tmp_dir, exists_ok=True)
        # Re-import and check while the directory still exists, in case the
        # import reads the exported files lazily.
        dg_imported = dataset_generator.DatasetGenerator.import_from_directory(
            base_directory=tmp_dir, poisoning_generation_input=poisoning_generation_input_post)
        points, _ = poisoning_generation_input_post.generate_from_sequence()
        assert len(dg_imported) == len(points)
        # The point of this test: the re-imported set is strictly smaller.
        assert len(dg_imported) < len(dg.all_datasets)