From e73df1e6726cef8cdc30fb62d509fa7aba3189f3 Mon Sep 17 00:00:00 2001 From: Mariella CC Date: Thu, 17 Oct 2024 15:08:41 +0200 Subject: [PATCH] fix: publication name column --- .../modules/corpus_metadata.py | 53 ++++++++----------- 1 file changed, 23 insertions(+), 30 deletions(-) diff --git a/src/kiara_plugin/topic_modelling/modules/corpus_metadata.py b/src/kiara_plugin/topic_modelling/modules/corpus_metadata.py index 4dde873..2784436 100644 --- a/src/kiara_plugin/topic_modelling/modules/corpus_metadata.py +++ b/src/kiara_plugin/topic_modelling/modules/corpus_metadata.py @@ -44,39 +44,35 @@ def create_outputs_schema(self): } def process(self, inputs, outputs): - import re - import polars as pl # type: ignore - import pyarrow as pa # type: ignore + import polars as pl # type: ignore + import pyarrow as pa # type: ignore table_obj = inputs.get_value_obj("corpus_table") column_name = inputs.get_value_obj("column_name").data sources = table_obj.data - + sources_col_names = sources.column_names if column_name not in sources_col_names: - raise KiaraProcessingException( f"Could not find file names column '{column_name}' in the table. Please specify a valid column name manually, using one of: {', '.join(sources_col_names)}" ) sources_data: pa.Table = table_obj.data.arrow_table - sources_tb: pl.DataFrame = pl.from_arrow(sources_data) # type: ignore - + sources_tb: pl.DataFrame = pl.from_arrow(sources_data) # type: ignore def get_ref(file): try: - ref_match = re.findall(r"(\w+\d+)_\d{4}-\d{2}-\d{2}_", file) + ref_match = re.findall(r"(sn\d+)_", file) if not ref_match: return None return ref_match[0] - except Exception as e: msg = f"There was a problem in the publication reference pattern: {e}" - raise KiaraProcessingException(e) + raise KiaraProcessingException(msg) def get_date(file): try: @@ -88,34 +84,31 @@ def get_date(file): msg = f"There was a problem in the date pattern: {e}" raise KiaraProcessingException(msg) - try: - - pub_refs: list[str] = inputs.get_value_obj("map").data[0] - pub_names: list[str] = inputs.get_value_obj("map").data[1] - - pub_ref_to_name = dict(zip(pub_refs, pub_names)) - - augm_sources = sources_tb.with_columns( - sources_tb["publication_ref"] - .map_elements(lambda x: pub_ref_to_name.get(x)) - .alias("publication_name") - ) - - except: - try: - augm_sources = sources_tb.with_columns([ + augm_sources = sources_tb.with_columns([ sources_tb[column_name].map_elements(get_date, return_dtype=pl.Utf8).alias("date"), sources_tb[column_name].map_elements(get_ref, return_dtype=pl.Utf8).alias("publication_ref"), ]) - except Exception as e: - msg = f"An error occurred while augmenting the dataframe: {e}" - raise KiaraProcessingException(msg) + # If a map is provided, add the publication_name column + map_input = inputs.get_value_obj("map") + if map_input is not None and map_input.data is not None: + pub_refs: list[str] = inputs.get_value_obj("map").data[0] + pub_names: list[str] = inputs.get_value_obj("map").data[1] + pub_ref_to_name = dict(zip(pub_refs, pub_names)) + + augm_sources = augm_sources.with_columns( + augm_sources["publication_ref"] + .map_elements(lambda x: pub_ref_to_name.get(x, None), return_dtype=pl.Utf8) + .alias("publication_name") + ) + + except Exception as e: + msg = f"An error occurred while augmenting the dataframe: {e}" + raise KiaraProcessingException(msg) try: output_table = augm_sources.to_arrow() - except Exception as e: raise KiaraProcessingException(e)