diff --git a/src/snowflake/snowpark/_internal/compiler/plan_compiler.py b/src/snowflake/snowpark/_internal/compiler/plan_compiler.py
index 3577d76218b..5f73e6d63ff 100644
--- a/src/snowflake/snowpark/_internal/compiler/plan_compiler.py
+++ b/src/snowflake/snowpark/_internal/compiler/plan_compiler.py
@@ -3,6 +3,7 @@
 #
 
 import copy
+import logging
 import time
 from typing import Any, Dict, List
 
@@ -32,6 +33,8 @@
 from snowflake.snowpark._internal.utils import random_name_for_temp_object
 from snowflake.snowpark.mock._connection import MockServerConnection
 
+_logger = logging.getLogger(__name__)
+
 
 class PlanCompiler:
     """
@@ -77,98 +80,112 @@ def should_start_query_compilation(self) -> bool:
         )
 
     def compile(self) -> Dict[PlanQueryType, List[Query]]:
+        # initialize the queries with the original queries without optimization
+        final_plan = self._plan
+        queries = {
+            PlanQueryType.QUERIES: final_plan.queries,
+            PlanQueryType.POST_ACTIONS: final_plan.post_actions,
+        }
+
         if self.should_start_query_compilation():
             session = self._plan.session
-            # preparation for compilation
-            # 1. make a copy of the original plan
-            start_time = time.time()
-            complexity_score_before_compilation = get_complexity_score(self._plan)
-            logical_plans: List[LogicalPlan] = [copy.deepcopy(self._plan)]
-            plot_plan_if_enabled(self._plan, "original_plan")
-            plot_plan_if_enabled(logical_plans[0], "deep_copied_plan")
-            deep_copy_end_time = time.time()
-
-            # 2. create a code generator with the original plan
-            query_generator = create_query_generator(self._plan)
-
-            extra_optimization_status: Dict[str, Any] = {}
-            # 3. apply each optimizations if needed
-            # CTE optimization
-            cte_start_time = time.time()
-            if session.cte_optimization_enabled:
-                repeated_subquery_eliminator = RepeatedSubqueryElimination(
-                    logical_plans, query_generator
+            try:
+                # preparation for compilation
+                # 1. make a copy of the original plan
+                start_time = time.time()
+                complexity_score_before_compilation = get_complexity_score(self._plan)
+                logical_plans: List[LogicalPlan] = [copy.deepcopy(self._plan)]
+                plot_plan_if_enabled(self._plan, "original_plan")
+                plot_plan_if_enabled(logical_plans[0], "deep_copied_plan")
+                deep_copy_end_time = time.time()
+
+                # 2. create a code generator with the original plan
+                query_generator = create_query_generator(self._plan)
+
+                extra_optimization_status: Dict[str, Any] = {}
+                # 3. apply each optimization if needed
+                # CTE optimization
+                cte_start_time = time.time()
+                if session.cte_optimization_enabled:
+                    repeated_subquery_eliminator = RepeatedSubqueryElimination(
+                        logical_plans, query_generator
+                    )
+                    elimination_result = repeated_subquery_eliminator.apply()
+                    logical_plans = elimination_result.logical_plans
+                    # add the extra repeated subquery elimination status
+                    extra_optimization_status[
+                        CompilationStageTelemetryField.CTE_NODE_CREATED.value
+                    ] = elimination_result.total_num_of_ctes
+
+                cte_end_time = time.time()
+                complexity_scores_after_cte = [
+                    get_complexity_score(logical_plan) for logical_plan in logical_plans
+                ]
+                for i, plan in enumerate(logical_plans):
+                    plot_plan_if_enabled(plan, f"cte_optimized_plan_{i}")
+
+                # Large query breakdown
+                breakdown_failure_summary, skipped_summary = {}, {}
+                if session.large_query_breakdown_enabled:
+                    large_query_breakdown = LargeQueryBreakdown(
+                        session,
+                        query_generator,
+                        logical_plans,
+                        session.large_query_breakdown_complexity_bounds,
+                    )
+                    breakdown_result = large_query_breakdown.apply()
+                    logical_plans = breakdown_result.logical_plans
+                    breakdown_failure_summary = breakdown_result.breakdown_summary
+                    skipped_summary = breakdown_result.skipped_summary
+
+                large_query_breakdown_end_time = time.time()
+                complexity_scores_after_large_query_breakdown = [
+                    get_complexity_score(logical_plan) for logical_plan in logical_plans
+                ]
+                for i, plan in enumerate(logical_plans):
+                    plot_plan_if_enabled(plan, f"large_query_breakdown_plan_{i}")
+
+                # 4. do a final pass of code generation
+                queries = query_generator.generate_queries(logical_plans)
+
+                # log telemetry data
+                deep_copy_time = deep_copy_end_time - start_time
+                cte_time = cte_end_time - cte_start_time
+                large_query_breakdown_time = (
+                    large_query_breakdown_end_time - cte_end_time
                 )
-                elimination_result = repeated_subquery_eliminator.apply()
-                logical_plans = elimination_result.logical_plans
-                # add the extra repeated subquery elimination status
-                extra_optimization_status[
-                    CompilationStageTelemetryField.CTE_NODE_CREATED.value
-                ] = elimination_result.total_num_of_ctes
-
-            cte_end_time = time.time()
-            complexity_scores_after_cte = [
-                get_complexity_score(logical_plan) for logical_plan in logical_plans
-            ]
-            for i, plan in enumerate(logical_plans):
-                plot_plan_if_enabled(plan, f"cte_optimized_plan_{i}")
-
-            # Large query breakdown
-            breakdown_failure_summary, skipped_summary = {}, {}
-            if session.large_query_breakdown_enabled:
-                large_query_breakdown = LargeQueryBreakdown(
-                    session,
-                    query_generator,
-                    logical_plans,
-                    session.large_query_breakdown_complexity_bounds,
+                total_time = time.time() - start_time
+                summary_value = {
+                    TelemetryField.CTE_OPTIMIZATION_ENABLED.value: session.cte_optimization_enabled,
+                    TelemetryField.LARGE_QUERY_BREAKDOWN_ENABLED.value: session.large_query_breakdown_enabled,
+                    CompilationStageTelemetryField.COMPLEXITY_SCORE_BOUNDS.value: session.large_query_breakdown_complexity_bounds,
+                    CompilationStageTelemetryField.TIME_TAKEN_FOR_COMPILATION.value: total_time,
+                    CompilationStageTelemetryField.TIME_TAKEN_FOR_DEEP_COPY_PLAN.value: deep_copy_time,
+                    CompilationStageTelemetryField.TIME_TAKEN_FOR_CTE_OPTIMIZATION.value: cte_time,
+                    CompilationStageTelemetryField.TIME_TAKEN_FOR_LARGE_QUERY_BREAKDOWN.value: large_query_breakdown_time,
+                    CompilationStageTelemetryField.COMPLEXITY_SCORE_BEFORE_COMPILATION.value: complexity_score_before_compilation,
+                    CompilationStageTelemetryField.COMPLEXITY_SCORE_AFTER_CTE_OPTIMIZATION.value: complexity_scores_after_cte,
+                    CompilationStageTelemetryField.COMPLEXITY_SCORE_AFTER_LARGE_QUERY_BREAKDOWN.value: complexity_scores_after_large_query_breakdown,
+                    CompilationStageTelemetryField.BREAKDOWN_FAILURE_SUMMARY.value: breakdown_failure_summary,
+                    CompilationStageTelemetryField.TYPE_LARGE_QUERY_BREAKDOWN_OPTIMIZATION_SKIPPED.value: skipped_summary,
+                }
+                # add the extra optimization status
+                summary_value.update(extra_optimization_status)
+                session._conn._telemetry_client.send_query_compilation_summary_telemetry(
+                    session_id=session.session_id,
+                    plan_uuid=self._plan.uuid,
+                    compilation_stage_summary=summary_value,
                 )
-                breakdown_result = large_query_breakdown.apply()
-                logical_plans = breakdown_result.logical_plans
-                breakdown_failure_summary = breakdown_result.breakdown_summary
-                skipped_summary = breakdown_result.skipped_summary
-
-            large_query_breakdown_end_time = time.time()
-            complexity_scores_after_large_query_breakdown = [
-                get_complexity_score(logical_plan) for logical_plan in logical_plans
-            ]
-            for i, plan in enumerate(logical_plans):
-                plot_plan_if_enabled(plan, f"large_query_breakdown_plan_{i}")
-
-            # 4. do a final pass of code generation
-            queries = query_generator.generate_queries(logical_plans)
-
-            # log telemetry data
-            deep_copy_time = deep_copy_end_time - start_time
-            cte_time = cte_end_time - cte_start_time
-            large_query_breakdown_time = large_query_breakdown_end_time - cte_end_time
-            total_time = time.time() - start_time
-            summary_value = {
-                TelemetryField.CTE_OPTIMIZATION_ENABLED.value: session.cte_optimization_enabled,
-                TelemetryField.LARGE_QUERY_BREAKDOWN_ENABLED.value: session.large_query_breakdown_enabled,
-                CompilationStageTelemetryField.COMPLEXITY_SCORE_BOUNDS.value: session.large_query_breakdown_complexity_bounds,
-                CompilationStageTelemetryField.TIME_TAKEN_FOR_COMPILATION.value: total_time,
-                CompilationStageTelemetryField.TIME_TAKEN_FOR_DEEP_COPY_PLAN.value: deep_copy_time,
-                CompilationStageTelemetryField.TIME_TAKEN_FOR_CTE_OPTIMIZATION.value: cte_time,
-                CompilationStageTelemetryField.TIME_TAKEN_FOR_LARGE_QUERY_BREAKDOWN.value: large_query_breakdown_time,
-                CompilationStageTelemetryField.COMPLEXITY_SCORE_BEFORE_COMPILATION.value: complexity_score_before_compilation,
-                CompilationStageTelemetryField.COMPLEXITY_SCORE_AFTER_CTE_OPTIMIZATION.value: complexity_scores_after_cte,
-                CompilationStageTelemetryField.COMPLEXITY_SCORE_AFTER_LARGE_QUERY_BREAKDOWN.value: complexity_scores_after_large_query_breakdown,
-                CompilationStageTelemetryField.BREAKDOWN_FAILURE_SUMMARY.value: breakdown_failure_summary,
-                CompilationStageTelemetryField.TYPE_LARGE_QUERY_BREAKDOWN_OPTIMIZATION_SKIPPED.value: skipped_summary,
-            }
-            # add the extra optimization status
-            summary_value.update(extra_optimization_status)
-            session._conn._telemetry_client.send_query_compilation_summary_telemetry(
-                session_id=session.session_id,
-                plan_uuid=self._plan.uuid,
-                compilation_stage_summary=summary_value,
-            )
-        else:
-            final_plan = self._plan
-            queries = {
-                PlanQueryType.QUERIES: final_plan.queries,
-                PlanQueryType.POST_ACTIONS: final_plan.post_actions,
-            }
+            except Exception as e:
+                # if any error occurs during the compilation, we should fall back to the original plan
+                _logger.debug(f"Skipping optimization due to error: {e}")
+                session._conn._telemetry_client.send_query_compilation_stage_failed_telemetry(
+                    session_id=session.session_id,
+                    plan_uuid=self._plan.uuid,
+                    error_type=type(e).__name__,
+                    error_message=str(e),
+                )
+                pass
 
         return self.replace_temp_obj_placeholders(queries)
 
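The net effect of the hunk above: compile() now builds the unoptimized queries result up front and runs the whole optimization pipeline inside a try block, so any exception leaves that original result in place while a failure-telemetry event is sent. A minimal sketch of this fallback pattern follows, with run_optimizations and report_failure as hypothetical stand-ins for the optimization passes and the telemetry client call (not Snowpark APIs):

import logging
from typing import Any, Callable, Dict, List

_logger = logging.getLogger(__name__)


def compile_with_fallback(
    original_queries: Dict[str, List[Any]],
    run_optimizations: Callable[[], Dict[str, List[Any]]],
    report_failure: Callable[..., None],
) -> Dict[str, List[Any]]:
    # Start from the unoptimized queries so a valid result always exists.
    queries = original_queries
    try:
        # The optimization passes may replace the queries entirely.
        queries = run_optimizations()
    except Exception as e:
        # Fall back to the original queries and record why optimization was skipped.
        _logger.debug(f"Skipping optimization due to error: {e}")
        report_failure(error_type=type(e).__name__, error_message=str(e))
    return queries


if __name__ == "__main__":
    original = {"queries": ["SELECT 1"], "post_actions": []}

    def failing_pass():
        raise ValueError("boom")

    # The failing pass is swallowed; the original queries are returned.
    result = compile_with_fallback(original, failing_pass, lambda **kw: print(kw))
    assert result == original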
diff --git a/src/snowflake/snowpark/_internal/compiler/telemetry_constants.py b/src/snowflake/snowpark/_internal/compiler/telemetry_constants.py
index b0e97a7b856..d395d2df114 100644
--- a/src/snowflake/snowpark/_internal/compiler/telemetry_constants.py
+++ b/src/snowflake/snowpark/_internal/compiler/telemetry_constants.py
@@ -23,6 +23,7 @@ class CompilationStageTelemetryField(Enum):
         "snowpark_large_query_breakdown_optimization_skipped"
     )
     TYPE_COMPILATION_STAGE_STATISTICS = "snowpark_compilation_stage_statistics"
+    TYPE_COMPILATION_STAGE_FAILED = "snowpark_compilation_stage_failed"
     TYPE_LARGE_QUERY_BREAKDOWN_UPDATE_COMPLEXITY_BOUNDS = (
         "snowpark_large_query_breakdown_update_complexity_bounds"
     )
@@ -30,6 +31,8 @@ class CompilationStageTelemetryField(Enum):
     # keys
     KEY_REASON = "reason"
     PLAN_UUID = "plan_uuid"
+    ERROR_TYPE = "error_type"
+    ERROR_MESSAGE = "error_message"
     TIME_TAKEN_FOR_COMPILATION = "time_taken_for_compilation_sec"
     TIME_TAKEN_FOR_DEEP_COPY_PLAN = "time_taken_for_deep_copy_plan_sec"
     TIME_TAKEN_FOR_CTE_OPTIMIZATION = "time_taken_for_cte_optimization_sec"
diff --git a/src/snowflake/snowpark/_internal/telemetry.py b/src/snowflake/snowpark/_internal/telemetry.py
index f149c796f8b..54e16786d0a 100644
--- a/src/snowflake/snowpark/_internal/telemetry.py
+++ b/src/snowflake/snowpark/_internal/telemetry.py
@@ -482,6 +482,22 @@ def send_query_compilation_summary_telemetry(
         }
         self.send(message)
 
+    def send_query_compilation_stage_failed_telemetry(
+        self, session_id: int, plan_uuid: str, error_type: str, error_message: str
+    ) -> None:
+        message = {
+            **self._create_basic_telemetry_data(
+                CompilationStageTelemetryField.TYPE_COMPILATION_STAGE_FAILED.value
+            ),
+            TelemetryField.KEY_DATA.value: {
+                TelemetryField.SESSION_ID.value: session_id,
+                CompilationStageTelemetryField.PLAN_UUID.value: plan_uuid,
+                CompilationStageTelemetryField.ERROR_TYPE.value: error_type,
+                CompilationStageTelemetryField.ERROR_MESSAGE.value: error_message,
+            },
+        }
+        self.send(message)
+
     def send_temp_table_cleanup_telemetry(
         self,
         session_id: str,
diff --git a/tests/integ/test_large_query_breakdown.py b/tests/integ/test_large_query_breakdown.py
index 81d6baa1666..8daa9818a28 100644
--- a/tests/integ/test_large_query_breakdown.py
+++ b/tests/integ/test_large_query_breakdown.py
@@ -705,6 +705,35 @@ def test_add_parent_plan_uuid_to_statement_params(session, large_query_df):
         assert call.kwargs["_statement_params"]["_PLAN_UUID"] == plan.uuid
 
 
+@pytest.mark.skipif(
+    IS_IN_STORED_PROC, reason="SNOW-609328: support caplog in SP regression test"
+)
+@pytest.mark.parametrize("error_type", [AssertionError, ValueError, RuntimeError])
+@patch("snowflake.snowpark._internal.compiler.plan_compiler.LargeQueryBreakdown.apply")
+def test_optimization_skipped_with_exceptions(
+    mock_lqb_apply, session, large_query_df, caplog, error_type
+):
+    """Test large query breakdown is skipped when there are exceptions"""
+    caplog.clear()
+    mock_lqb_apply.side_effect = error_type("test exception")
+    with caplog.at_level(logging.DEBUG):
+        with patch.object(
+            session._conn._telemetry_client,
+            "send_query_compilation_stage_failed_telemetry",
+        ) as patch_send:
+            queries = large_query_df.queries
+
+    assert "Skipping optimization due to error:" in caplog.text
+    assert len(queries["queries"]) == 1
+    assert len(queries["post_actions"]) == 0
+
+    patch_send.assert_called_once()
+    _, kwargs = patch_send.call_args
+    print(kwargs)
+    assert kwargs["error_message"] == "test exception"
+    assert kwargs["error_type"] == error_type.__name__
+
+
 def test_complexity_bounds_affect_num_partitions(session, large_query_df):
     """Test complexity bounds affect number of partitions.
 
     Also test that when partitions are added, drop table queries are added.
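The new integration test drives the failure path end to end: LargeQueryBreakdown.apply is patched to raise each parametrized exception, caplog must capture the debug message, the DataFrame still compiles to its single unoptimized query, and the patched telemetry method receives the error type and message. For reference, a stripped-down, self-contained sketch of the same caplog/side_effect/call_args verification pattern, using a hypothetical optimize helper rather than the Snowpark compiler:

import logging
from unittest.mock import MagicMock

_logger = logging.getLogger(__name__)


def optimize(apply_fn, report_failure):
    # Mirrors the compiler's behavior: log, report, and fall back on failure.
    try:
        return apply_fn()
    except Exception as e:
        _logger.debug(f"Skipping optimization due to error: {e}")
        report_failure(error_type=type(e).__name__, error_message=str(e))
        return None


def test_failure_is_logged_and_reported(caplog):
    apply_fn = MagicMock(side_effect=ValueError("test exception"))
    report_failure = MagicMock()
    with caplog.at_level(logging.DEBUG):
        optimize(apply_fn, report_failure)
    # The debug message is captured and the failure callback sees the error details.
    assert "Skipping optimization due to error:" in caplog.text
    report_failure.assert_called_once()
    _, kwargs = report_failure.call_args
    assert kwargs["error_type"] == "ValueError"
    assert kwargs["error_message"] == "test exception"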