@st.cache_resource(show_spinner=False)
def validate_table_columns(table: str, columns_must_exist: List[str]) -> bool:
    """
    Checks that a table contains every column in `columns_must_exist`.

    Args:
        table: Name of the table to inspect. It is passed unchanged to
            `fetch_columns_names_in_table`.
        columns_must_exist: Column names that must all be present. Names are
            compared by exact string equality (case-sensitive).

    Returns:
        bool: True when every required column is present (also True when
            `columns_must_exist` is empty), False as soon as one is missing.
    """
    # NOTE(review): the result is memoized by st.cache_resource, so a change to the
    # table's columns may not be seen on a later call with the same arguments -- confirm
    # that this is intended.
    columns_names = set(
        fetch_columns_names_in_table(get_snowflake_connection(), table)
    )
    return all(col in columns_names for col in columns_must_exist)


def table_selector_container() -> Optional[str]:
    """
    Common component that encapsulates db/schema/table selection for the evaluation table.
    Each selection is stored in the session state under the widget keys
    "selected_eval_database", "selected_eval_schema" and "selected_eval_table".

    Returns:
        Optional[str]: The value chosen in the "Table name" selectbox, or None when no
            table has been selected yet.
    """
    available_schemas = []
    available_tables = []

    # First, retrieve all databases that the user has access to.
    eval_database = st.selectbox(
        "Eval database",
        options=get_available_databases(),
        index=None,
        key="selected_eval_database",
    )
    if eval_database:
        # When a valid database is selected, fetch the available schemas in that database.
        try:
            available_schemas = get_available_schemas(eval_database)
        except (ValueError, ProgrammingError):
            st.error("Insufficient permissions to read from the selected database.")
            st.stop()

    eval_schema = st.selectbox(
        "Eval schema",
        options=available_schemas,
        index=None,
        key="selected_eval_schema",
        format_func=lambda x: format_snowflake_context(x, -1),
    )
    if eval_schema:
        # When a valid schema is selected, fetch the available tables in that schema.
        try:
            available_tables = get_available_tables(eval_schema)
        except (ValueError, ProgrammingError):
            st.error("Insufficient permissions to read from the selected schema.")
            st.stop()

    tables = st.selectbox(
        "Table name",
        options=available_tables,
        index=None,
        key="selected_eval_table",
        format_func=lambda x: format_snowflake_context(x, -1),
    )

    return tables


@dataclass
class SnowflakeTable:
    """
    Identifies a Snowflake table by the database, schema and table values it was
    selected with.
    """

    # Values are stored exactly as supplied by the caller.
    # NOTE(review): presumably the schema and table values are qualified names,
    # mirroring SnowflakeStage -- confirm against the selector that populates them.
    table_database: str
    table_schema: str
    table_name: str

    def to_dict(self) -> dict[str, str]:
        """Returns the three identifiers keyed by "Database", "Schema" and "Table"."""
        return {
            "Database": self.table_database,
            "Schema": self.table_schema,
            "Table": self.table_name,
        }
@st.experimental_dialog("Evaluation Data", width="large")
def evaluation_data_dialog() -> None:
    """
    Dialog that lets the user pick the table holding the evaluation data.

    When "Use Table" is pressed, the dialog checks that a database, schema and table
    have all been selected and that the table has the ID, QUERY and GOLD_SQL columns.
    On success the selection is stored in st.session_state["eval_table"] as a
    SnowflakeTable and a rerun is triggered; otherwise an error is shown.
    """
    evaluation_table_columns = ["ID", "QUERY", "GOLD_SQL"]
    st.markdown("Please enter evaluation select table")
    # Renders the selectors; the chosen values are read back from session state below.
    table_selector_container()
    if st.button("Use Table"):
        if (
            not st.session_state["selected_eval_database"]
            or not st.session_state["selected_eval_schema"]
            or not st.session_state["selected_eval_table"]
        ):
            st.error("Please fill in all fields.")
            return

        if not validate_table_columns(
            st.session_state["selected_eval_table"], evaluation_table_columns
        ):
            # The f-prefix is required so the column list is interpolated into the
            # message instead of being shown as literal braces.
            st.error(
                f"Table must have columns {evaluation_table_columns} to be used in Evaluation"
            )
            return

        st.session_state["eval_table"] = SnowflakeTable(
            table_database=st.session_state["selected_eval_database"],
            table_schema=st.session_state["selected_eval_schema"],
            table_name=st.session_state["selected_eval_table"],
        )
        st.rerun()


def evaluation_mode_show() -> None:
    """
    Renders the evaluation panel: a header, a button that opens the evaluation data
    dialog, and, once a table is stored in st.session_state["eval_table"], a line
    showing which table is in use.
    """
    header_row = row([0.65, 0.15], vertical_align="center")
    header_row.markdown("**Evaluation**")
    if header_row.button("Select Eval Data"):
        evaluation_data_dialog()
    if "eval_table" in st.session_state:
        st.write(f'Using this table as eval table {st.session_state["eval_table"].to_dict()}')
def fetch_columns_names_in_table(conn: SnowflakeConnection, table_fqn: str) -> list[str]:
    """
    Fetches the names of the columns of a table by running DESCRIBE TABLE on it.

    Args:
        conn: SnowflakeConnection to run the query
        table_fqn: The fully qualified name of the table to describe. It is
            interpolated directly into the SQL text, so it must be a trusted identifier.

    Returns: a list of column names, taken from the first field of each result row
    """
    query = f"DESCRIBE TABLE {table_fqn};"
    cursor = conn.cursor()
    try:
        cursor.execute(query)
        return [record[0] for record in cursor.fetchall()]
    finally:
        # Release the cursor even when the query fails.
        cursor.close()