diff --git a/data_explorer/app/df_helpers/fields.py b/data_explorer/app/df_helpers/fields.py index 95de3bec8..85ab1cb4a 100644 --- a/data_explorer/app/df_helpers/fields.py +++ b/data_explorer/app/df_helpers/fields.py @@ -2,7 +2,8 @@ import typing as t -from fondant.core.schema import Field +import pyarrow as pa +from fondant.core.schema import Field, Type def get_fields_by_types( @@ -41,3 +42,7 @@ def get_numeric_fields(fields: t.Dict[str, Field]) -> t.List[str]: "float64", ] return get_fields_by_types(fields, numeric_types) + + +def is_nested_string_array(field: Field): + return field.type == Type(pa.list_(pa.string())) diff --git a/data_explorer/app/interfaces/dataset_interface.py b/data_explorer/app/interfaces/dataset_interface.py index 98185b522..0b8aeb047 100644 --- a/data_explorer/app/interfaces/dataset_interface.py +++ b/data_explorer/app/interfaces/dataset_interface.py @@ -10,6 +10,7 @@ import pandas as pd import streamlit as st from config import DEFAULT_INDEX_NAME, ROWS_TO_RETURN +from df_helpers.fields import is_nested_string_array from fondant.core.manifest import Manifest from fondant.core.schema import Field from interfaces.utils import get_index_from_state @@ -131,6 +132,7 @@ def get_pandas_from_dask( partition_index: int, partition_row_index: int, cache_key: str, + _selected_fields: t.Dict[str, Field], ): """ Converts a Dask DataFrame into a Pandas DataFrame with specified number of rows. @@ -194,7 +196,26 @@ def get_pandas_from_dask( # Concatenate the selected partitions into a single pandas DataFrame df = pd.concat(data_to_return) - return df, partition_index, partition_row_index + # Unnest columns that contain lists of string values + additional_text_fields = [] + nested_columns = [] + normal_columns = [] + + for column in df.columns: + if is_nested_string_array(_selected_fields[column]): + unnested = df[column].apply(pd.Series) + unnested.columns = [f"{column}_{i}" for i in unnested.columns] + additional_text_fields.extend(unnested.columns) + nested_columns.append(column) + nested_columns.extend(unnested.columns) + df = pd.concat([df, unnested], axis=1) + else: + normal_columns.append(column) + + desired_order = normal_columns + nested_columns + df = df[desired_order] + + return df, partition_index, partition_row_index, additional_text_fields @staticmethod def _initialize_page_view_dict(component, cache_key): @@ -236,16 +257,18 @@ def _update_page_view_dict( def load_pandas_dataframe( self, + *, dask_df: dd.DataFrame, field_mapping: t.Dict[str, str], cache_key: t.Optional[str] = "", + selected_fields: t.Dict[str, Field], ): """ Provides common widgets for loading a dataframe and selecting a partition to load. Uses Cached dataframes to avoid reloading the dataframe when changing the partition. Returns: - Dataframe and fields + Dataframe and name of additional text fields from unpacking """ previous_button_disabled = True next_button_disabled = False @@ -265,13 +288,19 @@ def load_pandas_dataframe( ] start_index = page_view_dict[component][cache_key][page_index]["start_index"] - pandas_df, partition_index, partition_row_index = self.get_pandas_from_dask( + ( + pandas_df, + partition_index, + partition_row_index, + additional_text_fields, + ) = self.get_pandas_from_dask( field_mapping=field_mapping, _dask_df=dask_df, rows_to_return=ROWS_TO_RETURN, partition_index=start_partition, partition_row_index=start_index, cache_key=cache_key, + _selected_fields=selected_fields, ) self._update_page_view_dict( @@ -322,4 +351,4 @@ def load_pandas_dataframe( st.markdown(f"Page {page_index}") - return pandas_df + return pandas_df, additional_text_fields diff --git a/data_explorer/app/pages/dataset.py b/data_explorer/app/pages/dataset.py index f07f77cfe..8757188ca 100644 --- a/data_explorer/app/pages/dataset.py +++ b/data_explorer/app/pages/dataset.py @@ -131,11 +131,20 @@ def render_text(text: str): else: st.markdown(text) - def setup_viewer_widget(self, grid_dict: AgGridReturn, fields: t.Dict[str, t.Any]): + def setup_viewer_widget( + self, + grid_dict: AgGridReturn, + fields: t.Dict[str, t.Any], + extra_text_fields: t.Optional[t.List[str]], + ): """Setup the viewer widget. This widget allows the user to view the selected row in the dataframe. """ text_fields = get_string_fields(fields) + + if extra_text_fields: + text_fields.extend(extra_text_fields) + if text_fields: st.markdown("### Document Viewer") selected_column = st.selectbox("View column", text_fields) @@ -283,6 +292,11 @@ def result_found(): dask_df = app.load_dask_dataframe(field_mapping) dask_df, cache_key = app.setup_search_widget(dask_df, selected_fields, field_mapping) if st.session_state.result_found is True: - loaded_df = app.load_pandas_dataframe(dask_df, field_mapping, cache_key) + loaded_df, additional_text_fields = app.load_pandas_dataframe( + dask_df=dask_df, + field_mapping=field_mapping, + cache_key=cache_key, + selected_fields=selected_fields, + ) grid_data_dict = app.setup_app_page(loaded_df, selected_fields) - app.setup_viewer_widget(grid_data_dict, selected_fields) + app.setup_viewer_widget(grid_data_dict, selected_fields, additional_text_fields) diff --git a/data_explorer/app/pages/images.py b/data_explorer/app/pages/images.py index 3ee8b5842..16bc1858d 100644 --- a/data_explorer/app/pages/images.py +++ b/data_explorer/app/pages/images.py @@ -34,5 +34,5 @@ def setup_app_page(dataframe, fields): app = ImageGalleryApp() field_mapping, selected_fields = app.get_fields_mapping() dask_df = app.load_dask_dataframe(field_mapping) -df = app.load_pandas_dataframe(dask_df, field_mapping) +df, _ = app.load_pandas_dataframe(dask_df=dask_df, field_mapping=field_mapping) app.setup_app_page(df, selected_fields)