From 92f68bb57ca43daecb917b93a2b5c1a5cafc28a5 Mon Sep 17 00:00:00 2001 From: plaguss Date: Wed, 7 Aug 2024 15:06:01 +0200 Subject: [PATCH] Pass dataset to dry_run method --- src/distilabel/pipeline/base.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/distilabel/pipeline/base.py b/src/distilabel/pipeline/base.py index 398af9a727..6d9cc15bac 100644 --- a/src/distilabel/pipeline/base.py +++ b/src/distilabel/pipeline/base.py @@ -400,6 +400,7 @@ def dry_run( self, parameters: Optional[Dict[str, Dict[str, Any]]] = None, batch_size: int = 1, + dataset: Optional["InputDataset"] = None, ) -> "Distiset": """Do a dry run to test the pipeline runs as expected. @@ -412,6 +413,9 @@ def dry_run( the runtime parameters for the step as the value. Defaults to `None`. batch_size: The batch size of the unique batch generated by the generators steps of the pipeline. Defaults to `1`. + dataset: If given, it will be used to create a `GeneratorStep` that is set as the + root step. Convenient when the dataset has already been processed in your + script and you just want to pass it in as is. Defaults to `None`. Returns: Will return the `Distiset` as the main run method would do. @@ -426,7 +430,7 @@ def dry_run( parameters = {} parameters[step_name] = {"batch_size": batch_size} - distiset = self.run(parameters=parameters, use_cache=False) + distiset = self.run(parameters=parameters, use_cache=False, dataset=dataset) self._dry_run = False return distiset