|
9 | 9 | logger = setup_logger(__name__)
|
10 | 10 |
|
11 | 11 |
|
| 12 | +# Note - this is not a method on the class below in order to avoid |
| 13 | +# serializing the object itself when multi-processing is used. |
| 14 | +# In particular, SSLContext - embedded in the OpenAI client object - |
| 15 | +# cannot be pickled. |
def _filter_by_values(samples, column, op, values, num_proc=1):
    """Keep the samples whose ``column`` satisfies ``op`` for any of ``values``.

    This is a module-level function rather than a method so that the
    predicate handed to ``samples.filter`` closes over plain arguments only
    and the owning object does not have to be serialized when
    multi-processing is used.

    Args:
        samples: Object exposing ``filter(function, num_proc=...)``.
        column: Key of the field to test in each sample.
        op: Two-argument callable invoked as ``op(sample[column], value)``.
        values: Iterable of candidates; a sample is kept as soon as ``op``
            returns a truthy result for one of them.
        num_proc: Forwarded unchanged to ``samples.filter``.

    Returns:
        Whatever ``samples.filter`` returns for the predicate.
    """

    def matches_any(sample):
        # The column is looked up per candidate, so an empty ``values``
        # never touches the sample at all.
        for candidate in values:
            if op(sample[column], candidate):
                return True
        return False

    return samples.filter(matches_any, num_proc=num_proc)
| 21 | + |
| 22 | + |
def _map_dtype(samples, column, dtype, num_proc=1):
    """Convert ``column`` of every sample with ``dtype``, filling None on failure.

    Args:
        samples: Object exposing ``map(function, num_proc=...)``.
        column: Key of the field to convert in each sample.
        dtype: One-argument callable used as the converter, e.g. ``float``.
        num_proc: Accepted for interface compatibility but currently
            ignored; see the FIXME below.

    Returns:
        Whatever ``samples.map`` returns for the conversion function.
    """

    def convert_column(sample):
        try:
            sample[column] = dtype(sample[column])
        except (ValueError, TypeError) as e:
            # ValueError: the value has a convertible type but bad content
            # (float("abc")). TypeError: the converter rejects the value's
            # type outright (float(None)). Both mean "cannot convert", so
            # both are replaced by None instead of aborting the whole map.
            logger.error(
                "Error converting dtype: %s, filling with None to be filtered later", e
            )
            sample[column] = None
        return sample

    # FIXME: it appears multiprocessing map has issues with
    # None columns. If we pass num_proc>1 here and the error
    # case is triggered above, we get:
    #   ValueError: The features can't be aligned ...
    # because the column is still considered a string not
    # the new dtype.
    num_proc = 1

    return samples.map(convert_column, num_proc=num_proc)
| 43 | + |
| 44 | + |
12 | 45 | class FilterByValueBlock(Block):
|
13 | 46 | def __init__(
|
14 | 47 | self,
|
@@ -40,26 +73,12 @@ def __init__(
|
40 | 73 | self.convert_dtype = convert_dtype
|
41 | 74 | self.num_procs = batch_kwargs.get("num_procs", 1)
|
42 | 75 |
|
43 |
| - def _convert_dtype(self, sample): |
44 |
| - try: |
45 |
| - sample[self.column_name] = self.convert_dtype(sample[self.column_name]) |
46 |
| - except ValueError as e: |
47 |
| - logger.error( |
48 |
| - "Error converting dtype: %s, filling with None to be filtered later", e |
49 |
| - ) |
50 |
| - sample[self.column_name] = None |
51 |
| - return sample |
52 |
| - |
def generate(self, samples) -> Dataset:
    """Optionally convert the configured column, then filter on it.

    When ``self.convert_dtype`` is truthy, the column named by
    ``self.column_name`` is first converted through ``_map_dtype``. The
    result is then passed to ``_filter_by_values`` together with
    ``self.operation`` and ``self.value``. ``self.num_procs`` is forwarded
    to both helpers.
    """
    column = self.column_name
    workers = self.num_procs

    if self.convert_dtype:
        samples = _map_dtype(samples, column, self.convert_dtype, workers)

    return _filter_by_values(samples, column, self.operation, self.value, workers)
|
0 commit comments