Parallelize table generation #203

Closed · wants to merge 5 commits · Changes from 3 commits
108 changes: 48 additions & 60 deletions semantic_model_generator/generate_model.py
@@ -1,4 +1,6 @@
+import concurrent.futures
 import os
+import time
 from datetime import datetime
 from typing import List, Optional

@@ -160,83 +162,69 @@ def _raw_table_to_semantic_context_table(
     )
 
 
+def process_table(
+    table: str, conn: SnowflakeConnection, n_sample_values: int
+) -> semantic_model_pb2.Table:
+    fqn_table = create_fqn_table(table)
+    valid_schemas_tables_columns_df = get_valid_schemas_tables_columns_df(
+        conn=conn,
+        db_name=fqn_table.database,
+        table_schema=fqn_table.schema_name,
+        table_names=[fqn_table.table],
+    )
+    assert not valid_schemas_tables_columns_df.empty
+
+    valid_columns_df_this_table = valid_schemas_tables_columns_df[
+        valid_schemas_tables_columns_df["TABLE_NAME"] == fqn_table.table
+    ]
+
+    raw_table = get_table_representation(
+        conn=conn,
+        schema_name=fqn_table.database + "." + fqn_table.schema_name,
+        table_name=fqn_table.table,
+        table_index=0,
+        ndv_per_column=n_sample_values,
+        columns_df=valid_columns_df_this_table,
+    )
+    return _raw_table_to_semantic_context_table(
+        database=fqn_table.database,
+        schema=fqn_table.schema_name,
+        raw_table=raw_table,
+    )
+
+
 def raw_schema_to_semantic_context(
     base_tables: List[str],
     semantic_model_name: str,
     conn: SnowflakeConnection,
     n_sample_values: int = _DEFAULT_N_SAMPLE_VALUES_PER_COL,
     allow_joins: Optional[bool] = False,
 ) -> semantic_model_pb2.SemanticModel:
     """
     Converts a list of fully qualified Snowflake table names into a semantic model.
 
     Parameters:
     - base_tables (list[str]): Fully qualified table names to include in the semantic model.
     - semantic_model_name (str): A meaningful semantic model name.
     - conn (SnowflakeConnection): SnowflakeConnection to reuse across tables.
     - n_sample_values (int): The number of sample values to pull per column.
     - allow_joins (Optional[bool]): Whether to include placeholder join relationships.
 
     Returns:
     - The semantic model (semantic_model_pb2.SemanticModel).
 
     This function fetches metadata for the specified tables, performs schema validation, extracts key information,
     enriches metadata from the Snowflake database, and constructs a semantic model in protobuf format.
 
     Raises:
     - AssertionError: If no valid tables are found in the specified schema.
     """
-    # For FQN tables, create a new snowflake connection per table in case the db/schema is different.
+    start_time = time.time()
     table_objects = []
-    unique_database_schema: List[str] = []
-    for table in base_tables:
-        # Verify this is a valid FQN table. For now, we check that the table follows this format:
-        # {database}.{schema}.{table}
-        fqn_table = create_fqn_table(table)
-        fqn_databse_schema = f"{fqn_table.database}.{fqn_table.schema_name}"
-
-        if fqn_databse_schema not in unique_database_schema:
-            unique_database_schema.append(fqn_databse_schema)

[Inline review comment · Collaborator (Author)]: I don't think we were using unique_database_schema anywhere, so I removed it.

-        logger.info(f"Pulling column information from {fqn_table}")
-        valid_schemas_tables_columns_df = get_valid_schemas_tables_columns_df(
-            conn=conn,
-            db_name=fqn_table.database,
-            table_schema=fqn_table.schema_name,
-            table_names=[fqn_table.table],
-        )
-        assert not valid_schemas_tables_columns_df.empty
-
-        # get the valid columns for this table.
-        valid_columns_df_this_table = valid_schemas_tables_columns_df[
-            valid_schemas_tables_columns_df["TABLE_NAME"] == fqn_table.table
-        ]
-
-        raw_table = get_table_representation(
-            conn=conn,
-            schema_name=fqn_databse_schema,  # Fully-qualified schema
-            table_name=fqn_table.table,  # Non-qualified table name
-            table_index=0,
-            ndv_per_column=n_sample_values,  # number of sample values to pull per column.
-            columns_df=valid_columns_df_this_table,
-            max_workers=1,
-        )
-        table_object = _raw_table_to_semantic_context_table(
-            database=fqn_table.database,
-            schema=fqn_table.schema_name,
-            raw_table=raw_table,
-        )
-        table_objects.append(table_object)
-        # TODO(jhilgart): Call cortex model to generate a semantically friendly name here.
+    # Create a Table object representation for each provided table name.
+    # This is done concurrently because `process_table` is I/O bound, executing potentially
+    # long-running queries to fetch column metadata and sample values.
+    with concurrent.futures.ThreadPoolExecutor() as executor:
+        table_futures = [
+            executor.submit(process_table, table, conn, n_sample_values)
+            for table in base_tables
+        ]
+        concurrent.futures.wait(table_futures)
+        for future in table_futures:
+            table_object = future.result()
+            table_objects.append(table_object)

     placeholder_relationships = _get_placeholder_joins() if allow_joins else None
     context = semantic_model_pb2.SemanticModel(
         name=semantic_model_name,
         tables=table_objects,
         relationships=placeholder_relationships,
     )
+    end_time = time.time()
+    elapsed_time = end_time - start_time
+    logger.info(f"Time taken to generate semantic model: {elapsed_time} seconds.")
     return context
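For readers unfamiliar with the pattern, the new code is the standard `concurrent.futures` fan-out: submit one task per table, then collect results in submission order so the table order in the model stays deterministic. A minimal, self-contained sketch of the same shape; the `fetch_table` stub here is hypothetical, standing in for `process_table`:

```python
import concurrent.futures
import time


def fetch_table(name: str) -> str:
    # Stand-in for the I/O-bound metadata and sample-value queries.
    time.sleep(0.1)
    return f"table:{name}"


tables = ["db.schema.orders", "db.schema.customers", "db.schema.items"]

with concurrent.futures.ThreadPoolExecutor() as executor:
    futures = [executor.submit(fetch_table, t) for t in tables]

# Iterating in submission order keeps the output aligned with the input
# order of `tables`; each .result() blocks until that future completes and
# re-raises any exception from the worker thread.
results = [f.result() for f in futures]
print(results)
```

Since `future.result()` already blocks until each task finishes (and leaving the `with` block calls `shutdown(wait=True)`), the explicit `concurrent.futures.wait(table_futures)` in the diff is belt-and-braces rather than strictly required.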


@@ -146,7 +146,6 @@ def get_table_representation(
     table_index: int,
     ndv_per_column: int,
     columns_df: pd.DataFrame,
-    max_workers: int,
 ) -> Table:
     table_comment = _get_table_comment(conn, schema_name, table_name, columns_df)
@@ -160,7 +159,7 @@ def _get_col(col_index: int, column_row: pd.Series) -> Column:
             ndv=ndv_per_column,
         )
 
-    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
+    with concurrent.futures.ThreadPoolExecutor() as executor:
         future_to_col_index = {
             executor.submit(_get_col, col_index, column_row): col_index
             for col_index, (_, column_row) in enumerate(columns_df.iterrows())
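With the explicit `max_workers` argument dropped, this per-column pool falls back to the library default, which on CPython 3.8+ is `min(32, os.cpu_count() + 4)`. A quick sketch of what that works out to on a given machine:

```python
import os

# ThreadPoolExecutor's default worker count on CPython 3.8+ when max_workers
# is omitted; os.cpu_count() may return None, hence the fallback to 1.
default_workers = min(32, (os.cpu_count() or 1) + 4)
print(f"default ThreadPoolExecutor size: {default_workers}")
```

Worth noting: the per-table pool in `generate_model.py` and this per-column pool are nested and sized independently, so peak concurrent Snowflake queries can approach the product of the two defaults.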