diff --git a/src/openlayer/lib/core/base_model.py b/src/openlayer/lib/core/base_model.py index e847e2bd..c69fabcb 100644 --- a/src/openlayer/lib/core/base_model.py +++ b/src/openlayer/lib/core/base_model.py @@ -39,6 +39,8 @@ class OpenlayerModel(abc.ABC): Refer to Openlayer's templates for examples of how to implement this class. """ + custom_args: dict = {} + def run_from_cli(self) -> None: """Run the model from the command line.""" parser = argparse.ArgumentParser(description="Run data through a model.") @@ -51,10 +53,26 @@ def run_from_cli(self) -> None: required=False, help="Directory to dump the results in", ) + parser.add_argument( + "--custom-args", + type=str, + required=False, + help="Custom arguments in format 'key1=value1,key2=value2'", + ) # Parse the arguments args = parser.parse_args() + # Parse custom arguments string + custom_args = {} + if args.custom_args: + pairs = args.custom_args.split(",") + for pair in pairs: + if "=" in pair: + key, value = pair.split("=", 1) + custom_args[key] = value + self.custom_args = custom_args + return self.batch( dataset_path=args.dataset_path, output_dir=args.output_dir, @@ -69,12 +87,16 @@ def batch(self, dataset_path: str, output_dir: str) -> None: elif dataset_path.endswith(".json"): df = pd.read_json(dataset_path, orient="records") fmt = "json" + else: + raise ValueError(f"Unsupported dataset format: {dataset_path}") # Call the model's run_batch method, passing in the DataFrame - output_df, config = self.run_batch_from_df(df) + output_df, config = self.run_batch_from_df(df, custom_args=self.custom_args) self.write_output_to_directory(output_df, config, output_dir, fmt) - def run_batch_from_df(self, df: pd.DataFrame) -> Tuple[pd.DataFrame, dict]: + def run_batch_from_df( + self, df: pd.DataFrame, custom_args: dict = None + ) -> Tuple[pd.DataFrame, dict]: """Function that runs the model and returns the result.""" # Ensure the 'output' column exists if "output" not in df.columns: @@ -83,6 +105,10 @@ def run_batch_from_df(self, df: pd.DataFrame) -> Tuple[pd.DataFrame, dict]: # Get the signature of the 'run' method run_signature = inspect.signature(self.run) + # If the model has a custom_args attribute, update it + if hasattr(self, "custom_args") and custom_args is not None: + self.custom_args.update(custom_args) + for index, row in df.iterrows(): # Filter row_dict to only include keys that are valid parameters # for the 'run' method @@ -112,8 +138,7 @@ def run_batch_from_df(self, df: pd.DataFrame) -> Tuple[pd.DataFrame, dict]: if "tokens" in processed_trace: df.at[index, "tokens"] = processed_trace["tokens"] if "context" in processed_trace: - # Convert the context list to a string to avoid pandas issues - df.at[index, "context"] = json.dumps(processed_trace["context"]) + df.at[index, "context"] = processed_trace["context"] config = { "outputColumnName": "output", @@ -132,6 +157,9 @@ def run_batch_from_df(self, df: pd.DataFrame) -> Tuple[pd.DataFrame, dict]: if "context" in df.columns: config["contextColumnName"] = "context" + for k, v in self.custom_args.items(): + config["metadata"][k] = v + return df, config def write_output_to_directory(