diff --git a/armory/scenarios/poison.py b/armory/scenarios/poison.py index 142aa4662..ee8c1d3e9 100644 --- a/armory/scenarios/poison.py +++ b/armory/scenarios/poison.py @@ -62,8 +62,9 @@ def poison_dataset(self, x, y, return_index=False, fraction=None): poison_x, poison_y = list(x), list(y) poison_index = self.get_poison_index(y, fraction=fraction) for i in poison_index: - poison_x_i, poison_y[i] = self.attack.poison(x[i], [self.target_class]) + poison_x_i, poison_y_i = self.attack.poison(x[i], [self.target_class]) poison_x[i] = np.asarray(poison_x_i, dtype=x[i].dtype) + poison_y[i] = poison_y_i[0] poison_x, poison_y = np.array(poison_x), np.array(poison_y, dtype=int) if return_index: