diff --git a/.gitignore b/.gitignore index f4c53236..954f6dfb 100644 --- a/.gitignore +++ b/.gitignore @@ -139,3 +139,5 @@ dmypy.json # Cython debug symbols cython_debug/ + +.idea/ diff --git a/ffcv/loader/loader.py b/ffcv/loader/loader.py index aa8fd9dc..0531fc2c 100644 --- a/ffcv/loader/loader.py +++ b/ffcv/loader/loader.py @@ -101,6 +101,7 @@ def __init__(self, drop_last: bool = True, batches_ahead: int = 3, recompile: bool = False, # Recompile at every epoch + order_kwargs: dict = dict(), ): if distributed and order == OrderOption.RANDOM and (seed is None): @@ -156,8 +157,8 @@ def __init__(self, if order in ORDER_MAP: self.traversal_order: TraversalOrder = ORDER_MAP[order](self) - elif isinstance(order, TraversalOrder): - self.traversal_order: TraversalOrder = order(self) + elif issubclass(order, TraversalOrder): + self.traversal_order: TraversalOrder = order(self, **order_kwargs) else: raise ValueError(f"Order {order} is not a supported order type or a subclass of TraversalOrder") @@ -180,7 +181,7 @@ def __init__(self, elif spec is None: continue # This is a disabled field else: - msg = f"The pipeline for {output_name} has to be " + msg = f"The pipeline for {output_name} has to be " msg += f"either a PipelineSpec or a sequence of operations" raise ValueError(msg) custom_pipeline_specs[output_name] = spec