diff --git a/netam/common.py b/netam/common.py
index 93bae1b5..c50200a3 100644
--- a/netam/common.py
+++ b/netam/common.py
@@ -387,18 +387,18 @@ def encode_sequences(sequences, encoder):
 
 # from https://docs.python.org/3.11/library/itertools.html#itertools-recipes
 # avoiding walrus:
-def batched(iterable, n):
-    "Batch data into lists of length n. The last batch may be shorter."
+def chunked(iterable, n):
+    "Chunk data into lists of length n. The last chunk may be shorter."
     it = iter(iterable)
     while True:
-        batch = list(islice(it, n))
-        if not batch:
+        chunk = list(islice(it, n))
+        if not chunk:
             return
-        yield batch
+        yield chunk
 
 
-def batch_method(default_batch_size=2048, progress_bar_name=None):
-    """Decorator to batch the input to a method.
+def chunk_method(default_chunk_size=2048, progress_bar_name=None):
+    """Decorator to chunk the input to a method.
 
     Expects that all positional arguments are iterables of the same length, and that
     outputs are tuples of tensors whose first dimension
@@ -406,30 +406,30 @@ def batch_method(default_batch_size=2048, progress_bar_name=None):
 
     If method returns just one item, it must not be a tuple.
 
-    Batching is done along the first dimension of all inputs.
+    Chunking is done along the first dimension of all inputs.
 
     Args:
-        default_batch_size: The default batch size. The decorated method can
-            also automatically accept a `default_batch_size` keyword argument.
+        default_chunk_size: The default chunk size. The decorated method also
+            accepts a `chunk_size` keyword argument that overrides this.
         progress_bar_name: The name of the progress bar. If None, no progress bar is shown.
     """
 
     def decorator(method):
         @wraps(method)
         def wrapper(self, *args, **kwargs):
-            if "batch_size" in kwargs:
-                batch_size = kwargs.pop("batch_size")
+            if "chunk_size" in kwargs:
+                chunk_size = kwargs.pop("chunk_size")
             else:
-                batch_size = default_batch_size
+                chunk_size = default_chunk_size
             results = []
             if progress_bar_name is None:
                 progargs = {"disable": True}
             else:
                 progargs = {"desc": progress_bar_name}
             bar = tqdm(total=len(args[0]), delay=2.0, **progargs)
-            for batched_args in zip(*(batched(arg, batch_size) for arg in args)):
-                bar.update(len(batched_args[0]))
-                results.append(method(self, *batched_args, **kwargs))
+            for chunked_args in zip(*(chunked(arg, chunk_size) for arg in args)):
+                bar.update(len(chunked_args[0]))
+                results.append(method(self, *chunked_args, **kwargs))
             if isinstance(results[0], tuple):
                 return tuple(torch.cat(tensors) for tensors in zip(*results))
             else:
diff --git a/netam/models.py b/netam/models.py
index 08360200..93293b8e 100644
--- a/netam/models.py
+++ b/netam/models.py
@@ -17,7 +17,7 @@
     generate_kmers,
     aa_mask_tensor_of,
     encode_sequences,
-    batch_method,
+    chunk_method,
 )
 
 warnings.filterwarnings(
@@ -65,8 +65,8 @@ def unfreeze(self):
         for param in self.parameters():
             param.requires_grad = True
 
-    @batch_method(progress_bar_name="Evaluating model")
-    def evaluate_sequences(self, sequences, encoder=None, batch_size=2048):
+    @chunk_method(progress_bar_name="Evaluating model")
+    def evaluate_sequences(self, sequences, encoder=None, chunk_size=2048):
         if encoder is None:
             raise ValueError("An encoder must be provided.")
         encoded_parents, masks, wt_base_modifiers = encode_sequences(sequences, encoder)
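
For reference, here is a minimal usage sketch of the renamed decorator. `ToyModel` and its `score` method are hypothetical, invented only for illustration; `chunk_method` and `chunked` are the functions introduced by this diff, and the example exercises the tuple-returning path that is visible in the hunk above.

```python
import torch

from netam.common import chunk_method


class ToyModel:
    @chunk_method(default_chunk_size=4, progress_bar_name="Scoring")
    def score(self, values):
        # Called once per chunk with a list of at most chunk_size items.
        # Returns a tuple of tensors whose first dimension equals the chunk
        # length, so the wrapper can torch.cat each position across chunks.
        t = torch.tensor([float(v) for v in values])
        return t, 2 * t


model = ToyModel()
# `chunk_size` overrides the default at the call site: the 10 inputs are
# processed in chunks of 3, 3, 3, and 1, then concatenated back together.
singles, doubles = model.score(range(10), chunk_size=3)
assert singles.shape == (10,) and doubles.shape == (10,)
```

Because the chunking lives in the decorator, the decorated method is written as if it always sees the full input, and callers tune `chunk_size` per call to fit memory.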