From df04f58a0518154b8e85ed472ce3d7b36f154f17 Mon Sep 17 00:00:00 2001 From: sfc-gh-mvashishtha Date: Wed, 11 Dec 2024 01:46:53 -0800 Subject: [PATCH 1/5] Make streams ingest via a UDTF. The UDTF runs forever and can't connect to the kafka cluster. Signed-off-by: sfc-gh-mvashishtha --- snowpark_streaming_demo.py | 46 +++++++++++++ .../_internal/analyzer/snowflake_plan_node.py | 6 ++ src/snowflake/snowpark/dataframe.py | 6 +- src/snowflake/snowpark/dataframe_reader.py | 40 ++++++++++- src/snowflake/snowpark/dataframe_writer.py | 6 +- src/snowflake/snowpark/kafka_ingest_udtf.py | 66 +++++++++++++++++++ src/snowflake/snowpark/session.py | 14 +++- 7 files changed, 179 insertions(+), 5 deletions(-) create mode 100644 snowpark_streaming_demo.py create mode 100644 src/snowflake/snowpark/kafka_ingest_udtf.py diff --git a/snowpark_streaming_demo.py b/snowpark_streaming_demo.py new file mode 100644 index 0000000000..a304ddcf6f --- /dev/null +++ b/snowpark_streaming_demo.py @@ -0,0 +1,46 @@ +from snowflake.snowpark.session import Session +from snowflake.snowpark.functions import parse_json, col +from snowflake.snowpark.types import StructType, MapType, StructField +import logging; logging.getLogger("snowflake.snowpark").setLevel(logging.DEBUG) + +logging.basicConfig() + + +session = Session.builder.create() + +# Subscribe to 1 topic +source_df = ( + session + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", "host1:port1,host2:port2") + .option("topic", "topic1") + .option("partition_id", 1) + .schema( + StructType( + [ + StructField(column_identifier="records", datatype=MapType()) + ] + ) + ) + .load() +) + + +print(source_df.collect()) + +# transformation (simple) +transformed_df = ( + source_df + .select(col("*")) +) + +# Write to output data sink +sink_query = ( + source_df + .writeStream + .format("snowflake") + .option("table", "") +) + +sink_query.start() \ No newline at end of file diff --git a/src/snowflake/snowpark/_internal/analyzer/snowflake_plan_node.py b/src/snowflake/snowpark/_internal/analyzer/snowflake_plan_node.py index 017ec43316..511c44f1f6 100644 --- a/src/snowflake/snowpark/_internal/analyzer/snowflake_plan_node.py +++ b/src/snowflake/snowpark/_internal/analyzer/snowflake_plan_node.py @@ -350,3 +350,9 @@ def __init__( self.file_format_name = file_format_name self.file_format_type = file_format_type self.copy_options = copy_options + +class StreamSource(SnowflakeValues): + def __init__(self, output: List[Attribute], data: List[Row], schema) -> None: + super().__init__(output, data) + self.schema = schema + diff --git a/src/snowflake/snowpark/dataframe.py b/src/snowflake/snowpark/dataframe.py index c4dd09095f..497765f1be 100644 --- a/src/snowflake/snowpark/dataframe.py +++ b/src/snowflake/snowpark/dataframe.py @@ -162,7 +162,7 @@ from snowflake.snowpark.dataframe_analytics_functions import DataFrameAnalyticsFunctions from snowflake.snowpark.dataframe_na_functions import DataFrameNaFunctions from snowflake.snowpark.dataframe_stat_functions import DataFrameStatFunctions -from snowflake.snowpark.dataframe_writer import DataFrameWriter +from snowflake.snowpark.dataframe_writer import DataFrameWriter, DataStreamWriter from snowflake.snowpark.exceptions import SnowparkDataframeException from snowflake.snowpark.functions import ( abs as abs_, @@ -3933,6 +3933,10 @@ def write(self, _emit_ast: bool = True) -> DataFrameWriter: return self._writer + @property + def writeStream(self): + return DataStreamWriter(self) + @df_collect_api_telemetry @publicapi def copy_into_table( 
diff --git a/src/snowflake/snowpark/dataframe_reader.py b/src/snowflake/snowpark/dataframe_reader.py index ec2160c32c..eeb802f448 100644 --- a/src/snowflake/snowpark/dataframe_reader.py +++ b/src/snowflake/snowpark/dataframe_reader.py @@ -8,6 +8,9 @@ import snowflake.snowpark import snowflake.snowpark._internal.proto.generated.ast_pb2 as proto +import snowflake.snowpark.dataframe_reader +from snowflake.snowpark.functions import lit +from snowflake.snowpark.kafka_ingest_udtf import KafkaFetch from snowflake.snowpark._internal.analyzer.analyzer_utils import ( create_file_format_statement, drop_file_format_if_exists_statement, @@ -38,7 +41,7 @@ from snowflake.snowpark.column import METADATA_COLUMN_TYPES, Column, _to_col_if_str from snowflake.snowpark.dataframe import DataFrame from snowflake.snowpark.exceptions import SnowparkSessionException -from snowflake.snowpark.functions import sql_expr +from snowflake.snowpark.functions import sql_expr, udtf from snowflake.snowpark.mock._connection import MockServerConnection from snowflake.snowpark.table import Table from snowflake.snowpark.types import StructType, VariantType @@ -478,7 +481,7 @@ def _format(self) -> Optional[str]: @_format.setter def _format(self, value: str) -> None: canon_format = value.strip().lower() - allowed_formats = ["csv", "json", "avro", "parquet", "orc", "xml"] + allowed_formats = ["csv", "json", "avro", "parquet", "orc", "xml", "kafka"] if canon_format not in allowed_formats: raise ValueError( f"Invalid format '{value}'. Supported formats are {allowed_formats}." @@ -980,3 +983,36 @@ def _read_semi_structured_file(self, path: str, format: str) -> DataFrame: df._reader = self set_api_call_source(df, f"DataFrameReader.{format.lower()}") return df + + + + +import json +import logging +import time + +from confluent_kafka import Consumer, KafkaException, TopicPartition, TIMESTAMP_CREATE_TIME, TIMESTAMP_LOG_APPEND_TIME,KafkaError + +import confluent_kafka + +class DataStreamReader(DataFrameReader): + def load(self) -> DataFrame: + bootstrap_servers = self._cur_options["kafka.bootstrap.servers".upper()] + topic = self._cur_options["topic".upper()] + partition_id = self._cur_options["partition_id".upper()] + self._session.custom_package_usage_config['force_push'] = True + self._session.custom_package_usage_config['enabled'] = True + self._session.add_import(snowflake.snowpark.kafka_ingest_udtf.__file__, import_path="snowflake.snowpark.kafka_ingest_udtf") + self._session.add_packages(["python-confluent-kafka"]) + + kafka_udtf = udtf( + KafkaFetch, + output_schema=self._user_schema, + ) + return self._session.table_function( + kafka_udtf( + lit(bootstrap_servers), + lit(topic), + lit(partition_id) + ) + ) \ No newline at end of file diff --git a/src/snowflake/snowpark/dataframe_writer.py b/src/snowflake/snowpark/dataframe_writer.py index 5e7adf2a7f..635871884b 100644 --- a/src/snowflake/snowpark/dataframe_writer.py +++ b/src/snowflake/snowpark/dataframe_writer.py @@ -630,7 +630,7 @@ def _format(self) -> str: @_format.setter def _format(self, value: str) -> None: - allowed_formats = ["csv", "json", "parquet"] + allowed_formats = ["csv", "json", "parquet", "snowflake"] canon_file_format_name = value.strip().lower() if canon_file_format_name not in allowed_formats: raise ValueError( @@ -916,3 +916,7 @@ def parquet( ) saveAsTable = save_as_table + +class DataStreamWriter(DataFrameWriter): + def start(self): + raise NotImplementedError("cannot write a data stream yet.") \ No newline at end of file diff --git 
a/src/snowflake/snowpark/kafka_ingest_udtf.py b/src/snowflake/snowpark/kafka_ingest_udtf.py new file mode 100644 index 0000000000..ccbc0ba730 --- /dev/null +++ b/src/snowflake/snowpark/kafka_ingest_udtf.py @@ -0,0 +1,66 @@ + + +import json +import logging +import time + +from confluent_kafka import Consumer, KafkaException, TopicPartition, TIMESTAMP_CREATE_TIME, TIMESTAMP_LOG_APPEND_TIME,KafkaError + + +class KafkaFetch: + def __init__(self): + self.__consumer = {} + + def createConsumerIfNotExist(self, bootstrap_servers, topic, partition_id): + if (topic, partition_id) in self.__consumer: + return self.__consumer[(topic, partition_id)] + # Kafka consumer configuration + consumer_config = { + 'bootstrap.servers': bootstrap_servers, # Address of the Kafka server + 'group.id': 'my-consumer-group', # Consumer group ID + 'auto.offset.reset': 'earliest', # Start reading at the beginning if no offset is committed + 'allow.auto.create.topics': False, + 'enable.auto.commit': False, + } + + # Create a Consumer instance + consumer = Consumer(consumer_config) + + # Assign partition + topic_partition = TopicPartition(topic, partition_id) + consumer.assign([topic_partition]) + + self.__consumer[(topic, partition_id)] = consumer + return consumer + + def process(self, bootstrap_servers: str, topic: str, partition_id: int): + consumer = self.createConsumerIfNotExist(bootstrap_servers=bootstrap_servers, topic=topic, partition_id=partition_id) + + try: + while True: + # Poll for a message + msg = consumer.poll(1.0) # Timeout in seconds + + if msg is None: + # No message received within the timeout + continue + + if msg.error(): + # Handle Kafka errors + if msg.error().code() == KafkaError._PARTITION_EOF: + # End of partition event + logging.error(f"End of partition reached: {msg.error()}") + elif msg.error(): + raise KafkaException(msg.error()) + else: + # Successfully received a message + logging.info(f"Received message: {msg.value().decode('utf-8')}") + yield (msg.value().decode('utf-8'),) + + except: + logging.error("Consumer Error") + + finally: + consumer.close() + + diff --git a/src/snowflake/snowpark/session.py b/src/snowflake/snowpark/session.py index d02d9101f4..da8fd20268 100644 --- a/src/snowflake/snowpark/session.py +++ b/src/snowflake/snowpark/session.py @@ -53,6 +53,7 @@ from snowflake.snowpark._internal.analyzer.snowflake_plan_node import ( Range, SnowflakeValues, + StreamSource ) from snowflake.snowpark._internal.analyzer.table_function import ( FlattenFunction, @@ -140,7 +141,7 @@ _use_scoped_temp_objects, ) from snowflake.snowpark.dataframe import DataFrame -from snowflake.snowpark.dataframe_reader import DataFrameReader +from snowflake.snowpark.dataframe_reader import DataFrameReader, DataStreamReader from snowflake.snowpark.exceptions import ( SnowparkClientException, SnowparkSessionException, @@ -2621,6 +2622,11 @@ def read(self) -> "DataFrameReader": supported sources (e.g. a file in a stage) as a DataFrame.""" return DataFrameReader(self) + + @property + def readStream(self): + return DataStreamReader(self) + @property def session_id(self) -> int: """Returns an integer that represents the session ID of this session.""" @@ -3021,6 +3027,7 @@ def create_dataframe( data: Union[List, Tuple, "pandas.DataFrame"], schema: Optional[Union[StructType, Iterable[str]]] = None, _emit_ast: bool = True, + stream: bool = False ) -> DataFrame: """Creates a new DataFrame containing the specified values from the local data. 
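Reviewer note on kafka_ingest_udtf.py added in this patch: the class follows the Snowpark Python UDTF handler contract, where process() is a generator and each yielded tuple supplies one output row matching the declared output_schema. A minimal local sketch of that contract (no confluent_kafka; column names and payload values are illustrative), runnable outside Snowflake:

# Reviewer sketch only: a stand-in for a KafkaFetch-style handler.
import json
import random


class FakeStreamFetch:
    def process(self, bootstrap_servers: str, topic: str, partition_id: int):
        for offset in range(3):
            payload = {
                "topic": topic,
                "partition": partition_id,
                "offset": offset,
                "price": round(random.uniform(10.0, 500.0), 2),
            }
            # One yielded tuple becomes one row of the table function output.
            yield (str(offset), json.dumps(payload))


if __name__ == "__main__":
    # Outside Snowflake the handler is just a generator; drain it the same way
    # the server-side runtime would.
    for row in FakeStreamFetch().process("host1:port1", "topic1", 1):
        print(row)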
@@ -3354,6 +3361,11 @@ def convert_row_to_list( _emit_ast=False, ).select(project_columns, _emit_ast=False) if self.sql_simplifier_enabled + else + DataFrame( + self, StreamSource(attrs, converted, project_columns) + ) + if stream else DataFrame( self, SnowflakeValues(attrs, converted, schema_query=schema_query), From 6cb6e165a9c7280e6b22da7b909b0ceb0eba5106 Mon Sep 17 00:00:00 2001 From: sfc-gh-mvashishtha Date: Thu, 12 Dec 2024 00:23:34 -0800 Subject: [PATCH 2/5] Join UDTF result to a static table and write the result to a dynamic table Signed-off-by: sfc-gh-mvashishtha --- snowpark_streaming_demo.py | 67 ++++++++++++++----- .../_internal/analyzer/snowflake_plan.py | 5 +- src/snowflake/snowpark/dataframe_reader.py | 12 ++++ src/snowflake/snowpark/kafka_ingest_udtf.py | 47 ++++++++----- src/snowflake/snowpark/table_function.py | 5 ++ src/snowflake/snowpark/udtf.py | 2 +- 6 files changed, 101 insertions(+), 37 deletions(-) diff --git a/snowpark_streaming_demo.py b/snowpark_streaming_demo.py index a304ddcf6f..92de4d204d 100644 --- a/snowpark_streaming_demo.py +++ b/snowpark_streaming_demo.py @@ -1,15 +1,41 @@ from snowflake.snowpark.session import Session from snowflake.snowpark.functions import parse_json, col -from snowflake.snowpark.types import StructType, MapType, StructField +from snowflake.snowpark.types import StructType, MapType, StructField, StringType import logging; logging.getLogger("snowflake.snowpark").setLevel(logging.DEBUG) +import pandas as pd + + +# Function to generate random JSON data +def generate_json_data(): + import random + import time + import datetime + return { + "id": random.randint(1, 1000), + "name": f"Item-{random.randint(1, 100)}", + "price": round(random.uniform(10.0, 500.0), 2), + "timestamp": datetime.datetime.now() + } logging.basicConfig() session = Session.builder.create() +# Create static dataframe +session.create_dataframe( + pd.DataFrame( + { + 'KEY': [str(i) for i in range(10)], + 'STATIC_VALUE': [generate_json_data() for _ in range(10)] + } + ) +).write.save_as_table(table_name="static_df", mode="overwrite") +static_df = session.table("static_df") + + # Subscribe to 1 topic -source_df = ( +kafka_ingest_df = ( session .readStream .format("kafka") @@ -19,28 +45,33 @@ .schema( StructType( [ - StructField(column_identifier="records", datatype=MapType()) + StructField(column_identifier="KEY", datatype=StringType()), + StructField(column_identifier="STREAM_VALUE", datatype=StringType()) ] ) ) .load() ) +# Join kafka ingest to static table, and write result to dynamic table. +joined = kafka_ingest_df.join(static_df, on='KEY') +joined.create_or_replace_dynamic_table( + 'dynamic_join_result', + warehouse=session.connection.warehouse, + lag='1 hour', + + ) -print(source_df.collect()) - -# transformation (simple) -transformed_df = ( - source_df - .select(col("*")) -) +# Clean up dynamic table. +drop_result = session.connection.cursor().execute('DROP DYNAMIC TABLE dynamic_join_result;') +assert drop_result is not None -# Write to output data sink -sink_query = ( - source_df - .writeStream - .format("snowflake") - .option("table", "
") -) +# # Write streaming dataframe to output data sink +# sink_query = ( +# source_df +# .writeStream +# .format("snowflake") +# .option("table", "
") +# ) -sink_query.start() \ No newline at end of file +# sink_query.start() \ No newline at end of file diff --git a/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py b/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py index e00a78dd69..339bd25c7b 100644 --- a/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py +++ b/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py @@ -1577,7 +1577,10 @@ def from_table_function( table_function_statement(func, source_plan.table_function.operators), source_plan, ) - return self.query(table_function_statement(func), None) + return self.query( + table_function_statement(func, operators=[a.name for a in source_plan.table_function.output_schema]), + source_plan=None + ) def join_table_function( self, diff --git a/src/snowflake/snowpark/dataframe_reader.py b/src/snowflake/snowpark/dataframe_reader.py index eeb802f448..2e7f08e527 100644 --- a/src/snowflake/snowpark/dataframe_reader.py +++ b/src/snowflake/snowpark/dataframe_reader.py @@ -1005,10 +1005,22 @@ def load(self) -> DataFrame: self._session.add_import(snowflake.snowpark.kafka_ingest_udtf.__file__, import_path="snowflake.snowpark.kafka_ingest_udtf") self._session.add_packages(["python-confluent-kafka"]) + self._session.sql("create or replace stage mystage").collect() kafka_udtf = udtf( KafkaFetch, output_schema=self._user_schema, + # Dynamic tables can't depend on the temporary UDTF, so we must make + # a permanent UDTF. + # Note: https://docs.snowflake.com/en/release-notes/bcr-bundles/2024_01/bcr-1489 + is_permanent=True, + replace=True, + name='my_streaming_udtf', + stage_location="@mystage" ) + # + # "In a dynamic table definition, SELECT blocks that read from user-defined + # table functions (UDTF) must explicitly specify columns and can’t use *." + # but snowpark table_function() uses a star... 
return self._session.table_function( kafka_udtf( lit(bootstrap_servers), diff --git a/src/snowflake/snowpark/kafka_ingest_udtf.py b/src/snowflake/snowpark/kafka_ingest_udtf.py index ccbc0ba730..f3ed856870 100644 --- a/src/snowflake/snowpark/kafka_ingest_udtf.py +++ b/src/snowflake/snowpark/kafka_ingest_udtf.py @@ -6,6 +6,17 @@ from confluent_kafka import Consumer, KafkaException, TopicPartition, TIMESTAMP_CREATE_TIME, TIMESTAMP_LOG_APPEND_TIME,KafkaError +# Function to generate random JSON data +def generate_json_data(): + import random + import time + import datetime + return { + "id": random.randint(1, 1000), + "name": f"Item-{random.randint(1, 100)}", + "price": round(random.uniform(10.0, 500.0), 2), + "timestamp": datetime.datetime.now() + } class KafkaFetch: def __init__(self): @@ -37,25 +48,27 @@ def process(self, bootstrap_servers: str, topic: str, partition_id: int): consumer = self.createConsumerIfNotExist(bootstrap_servers=bootstrap_servers, topic=topic, partition_id=partition_id) try: - while True: - # Poll for a message - msg = consumer.poll(1.0) # Timeout in seconds + for i in range(10): + # # Poll for a message + # msg = consumer.poll(1.0) # Timeout in seconds + + # if msg is None: + # # No message received within the timeout + # continue - if msg is None: - # No message received within the timeout - continue + # if msg.error(): + # # Handle Kafka errors + # if msg.error().code() == KafkaError._PARTITION_EOF: + # # End of partition event + # logging.error(f"End of partition reached: {msg.error()}") + # elif msg.error(): + # raise KafkaException(msg.error()) + # else: + # # Successfully received a message + # logging.info(f"Received message: {msg.value().decode('utf-8')}") + # yield (msg.value().decode('utf-8'),) - if msg.error(): - # Handle Kafka errors - if msg.error().code() == KafkaError._PARTITION_EOF: - # End of partition event - logging.error(f"End of partition reached: {msg.error()}") - elif msg.error(): - raise KafkaException(msg.error()) - else: - # Successfully received a message - logging.info(f"Received message: {msg.value().decode('utf-8')}") - yield (msg.value().decode('utf-8'),) + yield (str(i), str(generate_json_data())) except: logging.error("Consumer Error") diff --git a/src/snowflake/snowpark/table_function.py b/src/snowflake/snowpark/table_function.py index 35de21c600..ce991e6e0e 100644 --- a/src/snowflake/snowpark/table_function.py +++ b/src/snowflake/snowpark/table_function.py @@ -47,6 +47,7 @@ def __init__( *func_arguments: ColumnOrName, _ast: Optional[proto.Expr] = None, _emit_ast: bool = True, + output_schema = None, **func_named_arguments: ColumnOrName, ) -> None: if func_arguments and func_named_arguments: @@ -65,6 +66,7 @@ def __init__( self._aliases: Optional[Iterable[str]] = None self._api_call_source = None self._ast = _ast + self._output_schema = output_schema def _set_api_call_source(self, api_call_source): self._api_call_source = api_call_source @@ -235,6 +237,7 @@ def _create_table_function_expression( order_by = None aliases = None api_call_source = None + output_schema = None if args and named_args: raise ValueError("A table function shouldn't have both args and named args.") if isinstance(func, str): @@ -256,6 +259,7 @@ def _create_table_function_expression( order_by = func._order_by aliases = func._aliases api_call_source = func._api_call_source + output_schema = func._output_schema else: raise TypeError( "'func' should be a function name in str, a list of strs that have all or a part of the fully qualified name, or a TableFunctionCall 
instance." @@ -280,6 +284,7 @@ def _create_table_function_expression( ) table_function_expression.aliases = aliases table_function_expression.api_call_source = api_call_source + table_function_expression.output_schema = output_schema return table_function_expression diff --git a/src/snowflake/snowpark/udtf.py b/src/snowflake/snowpark/udtf.py index 69908890c3..f921b5ae97 100644 --- a/src/snowflake/snowpark/udtf.py +++ b/src/snowflake/snowpark/udtf.py @@ -123,7 +123,7 @@ def __call__( build_udtf_apply(udtf_expr, self._ast_id, *arguments, **named_arguments) table_function_call = TableFunctionCall( - self.name, *arguments, **named_arguments, _ast=udtf_expr + self.name, *arguments, **named_arguments, _ast=udtf_expr, output_schema=self._output_schema ) table_function_call._set_api_call_source("UserDefinedTableFunction.__call__") From 6418c582b39e9aee402c621eaaeaea16b8a71057 Mon Sep 17 00:00:00 2001 From: sfc-gh-mvashishtha Date: Thu, 12 Dec 2024 13:50:41 -0800 Subject: [PATCH 3/5] Instead of writing result to dynamic table, call a UDF to write stream Signed-off-by: sfc-gh-mvashishtha --- snowpark_streaming_demo.py | 59 ++++++++++++------- src/snowflake/snowpark/dataframe_reader.py | 3 +- src/snowflake/snowpark/dataframe_writer.py | 35 ++++++++++- src/snowflake/snowpark/kafka_ingest_udtf.py | 2 +- .../snowpark/write_stream_to_table.py | 2 + 5 files changed, 76 insertions(+), 25 deletions(-) create mode 100644 src/snowflake/snowpark/write_stream_to_table.py diff --git a/snowpark_streaming_demo.py b/snowpark_streaming_demo.py index 92de4d204d..ff5da21622 100644 --- a/snowpark_streaming_demo.py +++ b/snowpark_streaming_demo.py @@ -1,8 +1,9 @@ from snowflake.snowpark.session import Session from snowflake.snowpark.functions import parse_json, col -from snowflake.snowpark.types import StructType, MapType, StructField, StringType +from snowflake.snowpark.types import StructType, MapType, StructField, StringType, IntegerType, FloatType, TimestampType import logging; logging.getLogger("snowflake.snowpark").setLevel(logging.DEBUG) import pandas as pd +from snowflake.snowpark.async_job import AsyncJob # Function to generate random JSON data @@ -34,6 +35,16 @@ def generate_json_data(): static_df = session.table("static_df") +kafka_event_schema = StructType( + [ + StructField(column_identifier="ID", datatype=IntegerType()), + StructField(column_identifier="NAME", datatype=StringType()), + StructField(column_identifier="PRICE", datatype=FloatType()), + StructField(column_identifier="TIMESTAMP", datatype=TimestampType()), + ] + ) + + # Subscribe to 1 topic kafka_ingest_df = ( session @@ -42,29 +53,37 @@ def generate_json_data(): .option("kafka.bootstrap.servers", "host1:port1,host2:port2") .option("topic", "topic1") .option("partition_id", 1) - .schema( - StructType( - [ - StructField(column_identifier="KEY", datatype=StringType()), - StructField(column_identifier="STREAM_VALUE", datatype=StringType()) - ] - ) - ) + .schema(kafka_event_schema) .load() ) -# Join kafka ingest to static table, and write result to dynamic table. 
-joined = kafka_ingest_df.join(static_df, on='KEY') -joined.create_or_replace_dynamic_table( - 'dynamic_join_result', - warehouse=session.connection.warehouse, - lag='1 hour', - - ) +RESULT_TABLE_NAME = "dynamic_join_result"; + +transformed_df = kafka_ingest_df \ + .select(col("id"), col("timestamp"), col("name")) \ + .filter(col("price") > 100.0) + + +""" +This query looks like + +SELECT write_stream_udf('dynamic_join_result', "id", "timestamp", "name") +FROM (SELECT id, + name, + price, + timestamp + FROM ( TABLE (my_streaming_udtf('host1:port1,host2:port2', 'topic1', 1 + :: INT + ) ))) +WHERE ( "price" > 100.0 ) +""" + +streaming_query: AsyncJob = transformed_df \ + .writeStream \ + .toTable(RESULT_TABLE_NAME) + +streaming_query.cancel() -# Clean up dynamic table. -drop_result = session.connection.cursor().execute('DROP DYNAMIC TABLE dynamic_join_result;') -assert drop_result is not None # # Write streaming dataframe to output data sink # sink_query = ( diff --git a/src/snowflake/snowpark/dataframe_reader.py b/src/snowflake/snowpark/dataframe_reader.py index 2e7f08e527..e7970265fe 100644 --- a/src/snowflake/snowpark/dataframe_reader.py +++ b/src/snowflake/snowpark/dataframe_reader.py @@ -1000,12 +1000,13 @@ def load(self) -> DataFrame: bootstrap_servers = self._cur_options["kafka.bootstrap.servers".upper()] topic = self._cur_options["topic".upper()] partition_id = self._cur_options["partition_id".upper()] + self._session.custom_package_usage_config['force_push'] = True self._session.custom_package_usage_config['enabled'] = True self._session.add_import(snowflake.snowpark.kafka_ingest_udtf.__file__, import_path="snowflake.snowpark.kafka_ingest_udtf") self._session.add_packages(["python-confluent-kafka"]) - self._session.sql("create or replace stage mystage").collect() + kafka_udtf = udtf( KafkaFetch, output_schema=self._user_schema, diff --git a/src/snowflake/snowpark/dataframe_writer.py b/src/snowflake/snowpark/dataframe_writer.py index 635871884b..8729c8c0ca 100644 --- a/src/snowflake/snowpark/dataframe_writer.py +++ b/src/snowflake/snowpark/dataframe_writer.py @@ -8,12 +8,14 @@ import snowflake.snowpark # for forward references of type hints import snowflake.snowpark._internal.proto.generated.ast_pb2 as proto +from snowflake.snowpark.write_stream_to_table import write_stream_to_table from snowflake.snowpark._internal.analyzer.snowflake_plan_node import ( CopyIntoLocationNode, SaveMode, SnowflakeCreateTable, TableCreationSource, ) +from snowflake.snowpark.async_job import AsyncJob from snowflake.snowpark._internal.ast.utils import ( build_expr_from_snowpark_column_or_col_name, debug_check_missing_ast, @@ -40,8 +42,9 @@ warning, ) from snowflake.snowpark.async_job import AsyncJob, _AsyncResultType +from snowflake.snowpark.types import StringType from snowflake.snowpark.column import Column, _to_col_if_str -from snowflake.snowpark.functions import sql_expr +from snowflake.snowpark.functions import sql_expr, udf, lit, col from snowflake.snowpark.mock._connection import MockServerConnection from snowflake.snowpark.row import Row @@ -917,6 +920,32 @@ def parquet( saveAsTable = save_as_table + class DataStreamWriter(DataFrameWriter): - def start(self): - raise NotImplementedError("cannot write a data stream yet.") \ No newline at end of file + def toTable(self, table_name: str) -> AsyncJob: + self._dataframe.session.custom_package_usage_config['force_push'] = True + self._dataframe.session.custom_package_usage_config['enabled'] = True + 
self._dataframe.session.add_import(snowflake.snowpark.write_stream_to_table.__file__, import_path="snowflake.snowpark.write_stream_to_table") + self._dataframe.session.sql("create or replace stage mystage").collect() + + + write_stream_udf = udf( + write_stream_to_table, + input_types= + [ + StringType(), + *(f.datatype for f in self._dataframe.schema.fields) + ], + is_permanent=True, + replace=True, + name='write_stream_udf', + stage_location="@mystage" + ) + + return self._dataframe.select(write_stream_udf( + lit(table_name), + *( + col(f.name) + for f in self._dataframe.schema.fields + ) + )).collect_nowait() \ No newline at end of file diff --git a/src/snowflake/snowpark/kafka_ingest_udtf.py b/src/snowflake/snowpark/kafka_ingest_udtf.py index f3ed856870..0a16b0579e 100644 --- a/src/snowflake/snowpark/kafka_ingest_udtf.py +++ b/src/snowflake/snowpark/kafka_ingest_udtf.py @@ -68,7 +68,7 @@ def process(self, bootstrap_servers: str, topic: str, partition_id: int): # logging.info(f"Received message: {msg.value().decode('utf-8')}") # yield (msg.value().decode('utf-8'),) - yield (str(i), str(generate_json_data())) + yield tuple(generate_json_data().values()) except: logging.error("Consumer Error") diff --git a/src/snowflake/snowpark/write_stream_to_table.py b/src/snowflake/snowpark/write_stream_to_table.py new file mode 100644 index 0000000000..9d765db27b --- /dev/null +++ b/src/snowflake/snowpark/write_stream_to_table.py @@ -0,0 +1,2 @@ +def write_stream_to_table(table_name: str, *args) -> str: + return f"streaming query result to table {table_name}" From e60adfe2ddccc57e30fe3b31c4b0dfa0f26557b6 Mon Sep 17 00:00:00 2001 From: sfc-gh-mvashishtha Date: Thu, 12 Dec 2024 14:14:21 -0800 Subject: [PATCH 4/5] Read stream from table Signed-off-by: sfc-gh-mvashishtha --- snowpark_streaming_demo.py | 15 ++++- src/snowflake/snowpark/dataframe_reader.py | 69 ++++++++++++---------- src/snowflake/snowpark/dataframe_writer.py | 5 +- 3 files changed, 54 insertions(+), 35 deletions(-) diff --git a/snowpark_streaming_demo.py b/snowpark_streaming_demo.py index ff5da21622..2f1c619d26 100644 --- a/snowpark_streaming_demo.py +++ b/snowpark_streaming_demo.py @@ -57,7 +57,7 @@ def generate_json_data(): .load() ) -RESULT_TABLE_NAME = "dynamic_join_result"; +LANDING_TABLE_NAME = "dynamic_join_result"; transformed_df = kafka_ingest_df \ .select(col("id"), col("timestamp"), col("name")) \ @@ -80,11 +80,22 @@ def generate_json_data(): streaming_query: AsyncJob = transformed_df \ .writeStream \ - .toTable(RESULT_TABLE_NAME) + .toTable(LANDING_TABLE_NAME) + + streaming_query.cancel() +# Read stream from a table +df_streamed_from_table = ( + session + .readStream + .format("table") + .option("table_name", LANDING_TABLE_NAME) +) + + # # Write streaming dataframe to output data sink # sink_query = ( # source_df diff --git a/src/snowflake/snowpark/dataframe_reader.py b/src/snowflake/snowpark/dataframe_reader.py index e7970265fe..823ebebbf8 100644 --- a/src/snowflake/snowpark/dataframe_reader.py +++ b/src/snowflake/snowpark/dataframe_reader.py @@ -481,7 +481,7 @@ def _format(self) -> Optional[str]: @_format.setter def _format(self, value: str) -> None: canon_format = value.strip().lower() - allowed_formats = ["csv", "json", "avro", "parquet", "orc", "xml", "kafka"] + allowed_formats = ["csv", "json", "avro", "parquet", "orc", "xml", "kafka", "table"] if canon_format not in allowed_formats: raise ValueError( f"Invalid format '{value}'. Supported formats are {allowed_formats}." 
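Reviewer note: with "table" now accepted as a stream format, the reader is driven the same way as the Kafka path: pick the format, set the source-specific options, then call load(). A usage sketch matching the demo later in this series (the session is assumed to already exist and the table name is a placeholder):

# Sketch of the intended "table" streaming-read call chain; in this patch
# load() resolves to session.table(<TABLE_NAME>) for this format.
df_streamed_from_table = (
    session
    .readStream
    .format("table")
    .option("table_name", "MY_LANDING_TABLE")
    .load()
)
print(df_streamed_from_table.schema)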
@@ -997,35 +997,40 @@ def _read_semi_structured_file(self, path: str, format: str) -> DataFrame: class DataStreamReader(DataFrameReader): def load(self) -> DataFrame: - bootstrap_servers = self._cur_options["kafka.bootstrap.servers".upper()] - topic = self._cur_options["topic".upper()] - partition_id = self._cur_options["partition_id".upper()] - - self._session.custom_package_usage_config['force_push'] = True - self._session.custom_package_usage_config['enabled'] = True - self._session.add_import(snowflake.snowpark.kafka_ingest_udtf.__file__, import_path="snowflake.snowpark.kafka_ingest_udtf") - self._session.add_packages(["python-confluent-kafka"]) - self._session.sql("create or replace stage mystage").collect() - - kafka_udtf = udtf( - KafkaFetch, - output_schema=self._user_schema, - # Dynamic tables can't depend on the temporary UDTF, so we must make - # a permanent UDTF. - # Note: https://docs.snowflake.com/en/release-notes/bcr-bundles/2024_01/bcr-1489 - is_permanent=True, - replace=True, - name='my_streaming_udtf', - stage_location="@mystage" - ) - # - # "In a dynamic table definition, SELECT blocks that read from user-defined - # table functions (UDTF) must explicitly specify columns and can’t use *." - # but snowpark table_function() uses a star... - return self._session.table_function( - kafka_udtf( - lit(bootstrap_servers), - lit(topic), - lit(partition_id) + if self._format == "kafka": + bootstrap_servers = self._cur_options["kafka.bootstrap.servers".upper()] + topic = self._cur_options["topic".upper()] + partition_id = self._cur_options["partition_id".upper()] + + self._session.custom_package_usage_config['force_push'] = True + self._session.custom_package_usage_config['enabled'] = True + self._session.add_import(snowflake.snowpark.kafka_ingest_udtf.__file__, import_path="snowflake.snowpark.kafka_ingest_udtf") + self._session.add_packages(["python-confluent-kafka"]) + self._session.sql("create or replace stage mystage").collect() + + kafka_udtf = udtf( + KafkaFetch, + output_schema=self._user_schema, + # Dynamic tables can't depend on the temporary UDTF, so we must make + # a permanent UDTF. + # Note: https://docs.snowflake.com/en/release-notes/bcr-bundles/2024_01/bcr-1489 + is_permanent=True, + replace=True, + name='my_streaming_udtf', + stage_location="@mystage" + ) + # + # "In a dynamic table definition, SELECT blocks that read from user-defined + # table functions (UDTF) must explicitly specify columns and can’t use *." + # but snowpark table_function() uses a star... 
+ return self._session.table_function( + kafka_udtf( + lit(bootstrap_servers), + lit(topic), + lit(partition_id) + ) ) - ) \ No newline at end of file + elif self._format == "table": + return self._session.table(self._cur_options["TABLE_NAME"]) + else: + raise NotImplementedError(f"cannot stream read in format {self._format}") \ No newline at end of file diff --git a/src/snowflake/snowpark/dataframe_writer.py b/src/snowflake/snowpark/dataframe_writer.py index 8729c8c0ca..8d884bc7df 100644 --- a/src/snowflake/snowpark/dataframe_writer.py +++ b/src/snowflake/snowpark/dataframe_writer.py @@ -948,4 +948,7 @@ def toTable(self, table_name: str) -> AsyncJob: col(f.name) for f in self._dataframe.schema.fields ) - )).collect_nowait() \ No newline at end of file + )).collect_nowait() + + def outputMode(self, output_mode: str) -> "DataStreamWriter": + raise NotImplementedError \ No newline at end of file From f9260246d8a4d21a4bef9773e22640bcbeb97b32 Mon Sep 17 00:00:00 2001 From: sfc-gh-mvashishtha Date: Thu, 12 Dec 2024 17:02:14 -0800 Subject: [PATCH 5/5] Support writing to dynamic table if stream source is table. Otherwise, use UDTF to stream to table. Signed-off-by: sfc-gh-mvashishtha --- snowpark_streaming_demo.py | 142 +++++++++++---- src/snowflake/snowpark/dataframe.py | 162 ++++++++++++------ src/snowflake/snowpark/dataframe_reader.py | 4 +- src/snowflake/snowpark/dataframe_writer.py | 81 ++++++--- .../snowpark/relational_grouped_dataframe.py | 4 +- 5 files changed, 284 insertions(+), 109 deletions(-) diff --git a/snowpark_streaming_demo.py b/snowpark_streaming_demo.py index 2f1c619d26..dcd8e49217 100644 --- a/snowpark_streaming_demo.py +++ b/snowpark_streaming_demo.py @@ -23,16 +23,32 @@ def generate_json_data(): session = Session.builder.create() +STATIC_TABLE_NAME = "static_df" + +SIMULATED_STREAM_DATA_NAME = "static_df2" + # Create static dataframe session.create_dataframe( pd.DataFrame( { - 'KEY': [str(i) for i in range(10)], - 'STATIC_VALUE': [generate_json_data() for _ in range(10)] + 'ID': [str(i) for i in range(1000)], + 'STATIC_VALUE': [generate_json_data() for _ in range(1000)] + } + ) +).write.save_as_table(table_name=STATIC_TABLE_NAME, mode="overwrite") +static_df = session.table(STATIC_TABLE_NAME) + +# Create static dataframe 2 +data = [generate_json_data() for _ in range(10)] +session.create_dataframe( + pd.DataFrame( + { + "ID": [row["id"] for row in data ], + "TIMESTAMP": [row["timestamp"] for row in data], + "NAME": [row["name"] for row in data], } ) -).write.save_as_table(table_name="static_df", mode="overwrite") -static_df = session.table("static_df") +).write.save_as_table(table_name=SIMULATED_STREAM_DATA_NAME, mode="overwrite") kafka_event_schema = StructType( @@ -57,32 +73,48 @@ def generate_json_data(): .load() ) -LANDING_TABLE_NAME = "dynamic_join_result"; +LANDING_TABLE_NAME = "dynamic_join_result" transformed_df = kafka_ingest_df \ .select(col("id"), col("timestamp"), col("name")) \ .filter(col("price") > 100.0) - -""" -This query looks like - -SELECT write_stream_udf('dynamic_join_result', "id", "timestamp", "name") -FROM (SELECT id, - name, - price, - timestamp - FROM ( TABLE (my_streaming_udtf('host1:port1,host2:port2', 'topic1', 1 - :: INT - ) ))) -WHERE ( "price" > 100.0 ) -""" +assert transformed_df._stream_source == "kafka" streaming_query: AsyncJob = transformed_df \ .writeStream \ .toTable(LANDING_TABLE_NAME) +# The source table is from a kafka stream, so we write the result via UDTF. 
+ +""" +This query looks like + +SELECT + write_stream_udf('dynamic_join_result', "id", "timestamp", "name") +FROM + ( + SELECT + id, + name, + price, + timestamp + FROM + ( + TABLE ( + my_streaming_udtf( + 'host1:port1,host2:port2', + 'topic1', + 1::INT + ) + ) + ) + ) +WHERE + ("price" > 100.0) +""" + streaming_query.cancel() @@ -92,16 +124,68 @@ def generate_json_data(): session .readStream .format("table") - .option("table_name", LANDING_TABLE_NAME) -) + # TODO: Temporarily reading from another static table here because the UDTF + # currently doesn't produce an output table. + .option("table_name", SIMULATED_STREAM_DATA_NAME) + # .option("table_name", LANDING_TABLE_NAME) +).load() + +complex_df = df_streamed_from_table.join(static_df, on="ID").groupBy("NAME").count() -# # Write streaming dataframe to output data sink -# sink_query = ( -# source_df -# .writeStream -# .format("snowflake") -# .option("table", "
") -# ) +FINAL_TABLE_NAME = "final_table" + + +assert complex_df._stream_source == "table" + +# One source is a Snowflake table, and the other source is a static table, so +# we write the result as a dynamic table instead of using a UDTF. + +""" +The query here is: + +CREATE +OR REPLACE DYNAMIC TABLE final_table LAG = '60 seconds' WAREHOUSE = NEW_WAREHOUSE REFRESH_MODE = 'incremental' AS +SELECT + * +FROM + ( + SELECT + "NAME", + count(1) AS "COUNT" + FROM + ( + SELECT + * + FROM + ( + ( + SELECT + "ID" AS "ID", + "TIMESTAMP" AS "TIMESTAMP", + "NAME" AS "NAME" + FROM + static_df2 + ) AS SNOWPARK_LEFT + INNER JOIN ( + SELECT + "ID" AS "ID", + "STATIC_VALUE" AS "STATIC_VALUE" + FROM + static_df + ) AS SNOWPARK_RIGHT USING (ID) + ) + ) + GROUP BY + "NAME" + ) +""" -# sink_query.start() \ No newline at end of file +( + complex_df + .writeStream + .outputMode("append") + # Dynamic Tables do not support lag values under 60 second(s). + .trigger(processingTime='60 seconds') + .toTable(FINAL_TABLE_NAME) +) \ No newline at end of file diff --git a/src/snowflake/snowpark/dataframe.py b/src/snowflake/snowpark/dataframe.py index 497765f1be..43d09ccc9c 100644 --- a/src/snowflake/snowpark/dataframe.py +++ b/src/snowflake/snowpark/dataframe.py @@ -164,6 +164,7 @@ from snowflake.snowpark.dataframe_stat_functions import DataFrameStatFunctions from snowflake.snowpark.dataframe_writer import DataFrameWriter, DataStreamWriter from snowflake.snowpark.exceptions import SnowparkDataframeException + from snowflake.snowpark.functions import ( abs as abs_, col, @@ -333,6 +334,18 @@ def _disambiguate( return lhs_remapped, rhs_remapped +from functools import wraps + +def propagate_stream_source(f): + @wraps(f) + def g(self, *args, **kwargs): + from snowflake.snowpark.dataframe import DataFrame + result = f(self, *args, **kwargs) + if isinstance(result, DataFrame): + return result.set_stream_source(self._stream_source) + return result + return g + class DataFrame: """Represents a lazily-evaluated relational dataset that contains a collection of :class:`Row` objects with columns defined by a schema (column name and type). @@ -573,13 +586,14 @@ def __init__( plan: Optional[LogicalPlan] = None, is_cached: bool = False, _ast_stmt: Optional[proto.Assign] = None, - _emit_ast: bool = True, + _emit_ast: bool = False, ) -> None: """ :param int _ast_stmt: The AST Assign atom corresponding to this dataframe value. We track its assigned ID in the slot self._ast_id. This allows this value to be referred to symbolically when it's referenced in subsequent dataframe expressions. """ + self._stream_source = None self._session = session self._ast_id = None if _emit_ast: @@ -618,6 +632,9 @@ def __init__( self.replace = self._na.replace self._alias: Optional[str] = None + + + def _set_ast_ref(self, sp_dataframe_expr_builder: Any) -> None: """ @@ -644,7 +661,7 @@ def collect( block: bool = True, log_on_exception: bool = False, case_sensitive: bool = True, - _emit_ast: bool = True, + _emit_ast: bool = False, ) -> List[Row]: ... # pragma: no cover @@ -657,7 +674,7 @@ def collect( block: bool = False, log_on_exception: bool = False, case_sensitive: bool = True, - _emit_ast: bool = True, + _emit_ast: bool = False, ) -> AsyncJob: ... 
# pragma: no cover @@ -670,7 +687,7 @@ def collect( block: bool = True, log_on_exception: bool = False, case_sensitive: bool = True, - _emit_ast: bool = True, + _emit_ast: bool = False, ) -> Union[List[Row], AsyncJob]: """Executes the query representing this DataFrame and returns the result as a list of :class:`Row` objects. @@ -723,7 +740,7 @@ def collect_nowait( statement_params: Optional[Dict[str, str]] = None, log_on_exception: bool = False, case_sensitive: bool = True, - _emit_ast: bool = True, + _emit_ast: bool = False, ) -> AsyncJob: """Executes the query representing this DataFrame asynchronously and returns: class:`AsyncJob`. It is equivalent to ``collect(block=False)``. @@ -819,7 +836,7 @@ def to_local_iterator( statement_params: Optional[Dict[str, str]] = None, block: bool = True, case_sensitive: bool = True, - _emit_ast: bool = True, + _emit_ast: bool = False, ) -> Iterator[Row]: ... # pragma: no cover @@ -831,7 +848,7 @@ def to_local_iterator( statement_params: Optional[Dict[str, str]] = None, block: bool = False, case_sensitive: bool = True, - _emit_ast: bool = True, + _emit_ast: bool = False, ) -> AsyncJob: ... # pragma: no cover @@ -843,7 +860,7 @@ def to_local_iterator( statement_params: Optional[Dict[str, str]] = None, block: bool = True, case_sensitive: bool = True, - _emit_ast: bool = True, + _emit_ast: bool = False, ) -> Union[Iterator[Row], AsyncJob]: """Executes the query representing this DataFrame and returns an iterator of :class:`Row` objects that you can use to retrieve the results. @@ -916,6 +933,7 @@ def __copy__(self) -> "DataFrame": # a separate AST entity to model deep-copying. A deep-copy would generate here a new ID different from self._ast_id. df = DataFrame(self._session, new_plan) df._ast_id = self._ast_id + df._stream_source = self._stream_source return df if installed_pandas: @@ -928,7 +946,7 @@ def to_pandas( *, statement_params: Optional[Dict[str, str]] = None, block: bool = True, - _emit_ast: bool = True, + _emit_ast: bool = False, **kwargs: Dict[str, Any], ) -> pandas.DataFrame: ... # pragma: no cover @@ -940,7 +958,7 @@ def to_pandas( *, statement_params: Optional[Dict[str, str]] = None, block: bool = False, - _emit_ast: bool = True, + _emit_ast: bool = False, **kwargs: Dict[str, Any], ) -> AsyncJob: ... # pragma: no cover @@ -952,7 +970,7 @@ def to_pandas( *, statement_params: Optional[Dict[str, str]] = None, block: bool = True, - _emit_ast: bool = True, + _emit_ast: bool = False, **kwargs: Dict[str, Any], ) -> Union["pandas.DataFrame", AsyncJob]: """ @@ -1024,7 +1042,7 @@ def to_pandas_batches( *, statement_params: Optional[Dict[str, str]] = None, block: bool = True, - _emit_ast: bool = True, + _emit_ast: bool = False, **kwargs: Dict[str, Any], ) -> Iterator[pandas.DataFrame]: ... # pragma: no cover @@ -1036,7 +1054,7 @@ def to_pandas_batches( *, statement_params: Optional[Dict[str, str]] = None, block: bool = False, - _emit_ast: bool = True, + _emit_ast: bool = False, **kwargs: Dict[str, Any], ) -> AsyncJob: ... 
# pragma: no cover @@ -1048,7 +1066,7 @@ def to_pandas_batches( *, statement_params: Optional[Dict[str, str]] = None, block: bool = True, - _emit_ast: bool = True, + _emit_ast: bool = False, **kwargs: Dict[str, Any], ) -> Union[Iterator["pandas.DataFrame"], AsyncJob]: """ @@ -1165,7 +1183,7 @@ def to_snowpark_pandas( self, index_col: Optional[Union[str, List[str]]] = None, columns: Optional[List[str]] = None, - _emit_ast: bool = True, + _emit_ast: bool = False, ) -> "modin.pandas.DataFrame": """ Convert the Snowpark DataFrame to Snowpark pandas DataFrame. @@ -1259,6 +1277,7 @@ def to_snowpark_pandas( return snowpandas_df + @propagate_stream_source def __getitem__(self, item: Union[str, Column, List, Tuple, int]): _emit_ast = self._ast_id is not None @@ -1274,6 +1293,7 @@ def __getitem__(self, item: Union[str, Column, List, Tuple, int]): else: raise TypeError(f"Unexpected item type: {type(item)}") + @propagate_stream_source def __getattr__(self, name: str): # Snowflake DB ignores cases when there is no quotes. if name.lower() not in [c.lower() for c in self.columns]: @@ -1301,6 +1321,7 @@ def columns(self) -> List[str]: return self.schema.names @publicapi + @propagate_stream_source def col(self, col_name: str, _emit_ast: bool = True) -> Column: """Returns a reference to a column in the DataFrame.""" expr = None @@ -1316,6 +1337,7 @@ def col(self, col_name: str, _emit_ast: bool = True) -> Column: @df_api_usage @publicapi + @propagate_stream_source def select( self, *cols: Union[ @@ -1323,7 +1345,7 @@ def select( Iterable[Union[ColumnOrName, TableFunctionCall]], ], _ast_stmt: Optional[proto.Assign] = None, - _emit_ast: bool = True, + _emit_ast: bool = False, ) -> "DataFrame": """Returns a new DataFrame with the specified Column expressions as output (similar to SELECT in SQL). Only the Columns specified as arguments will be @@ -1492,11 +1514,12 @@ def select( @df_api_usage @publicapi + @propagate_stream_source def select_expr( self, *exprs: Union[str, Iterable[str]], _ast_stmt: proto.Assign = None, - _emit_ast: bool = True, + _emit_ast: bool = False, ) -> "DataFrame": """ Projects a set of SQL expressions and returns a new :class:`DataFrame`. @@ -1550,6 +1573,7 @@ def select_expr( @df_api_usage @publicapi + @propagate_stream_source def drop( self, *cols: Union[ColumnOrName, Iterable[ColumnOrName]], _emit_ast: bool = True ) -> "DataFrame": @@ -1639,11 +1663,12 @@ def drop( @df_api_usage @publicapi + @propagate_stream_source def filter( self, expr: ColumnOrSqlExpr, _ast_stmt: proto.Assign = None, - _emit_ast: bool = True, + _emit_ast: bool = False, ) -> "DataFrame": """Filters rows based on the specified conditional expression (similar to WHERE in SQL). @@ -1695,11 +1720,12 @@ def filter( @df_api_usage @publicapi + @propagate_stream_source def sort( self, *cols: Union[ColumnOrName, Iterable[ColumnOrName]], ascending: Optional[Union[bool, int, List[Union[bool, int]]]] = None, - _emit_ast: bool = True, + _emit_ast: bool = False, ) -> "DataFrame": """Sorts a DataFrame by the specified expressions (similar to ORDER BY in SQL). @@ -1834,6 +1860,7 @@ def sort( @experimental(version="1.5.0") @publicapi + @propagate_stream_source def alias(self, name: str, _emit_ast: bool = True): """Returns an aliased dataframe in which the columns can now be referenced to using `col(, )`. 
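Reviewer note: most of this dataframe.py hunk is mechanical; @propagate_stream_source is applied to each method that returns a new DataFrame so that the _stream_source tag survives transformations such as select and filter. The pattern in isolation, using toy classes instead of Snowpark (illustrative only):

from functools import wraps


def propagate_stream_source(f):
    # Copy the stream-source tag from `self` onto any frame-like result.
    @wraps(f)
    def g(self, *args, **kwargs):
        result = f(self, *args, **kwargs)
        if isinstance(result, ToyFrame):
            result._stream_source = self._stream_source
        return result
    return g


class ToyFrame:
    def __init__(self, rows, stream_source=None):
        self.rows = rows
        self._stream_source = stream_source

    @propagate_stream_source
    def filter(self, predicate):
        return ToyFrame([r for r in self.rows if predicate(r)])


kafka_frame = ToyFrame([1, 150, 42], stream_source="kafka")
assert kafka_frame.filter(lambda r: r > 100)._stream_source == "kafka"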
@@ -1902,10 +1929,11 @@ def alias(self, name: str, _emit_ast: bool = True): @df_api_usage @publicapi + @propagate_stream_source def agg( self, *exprs: Union[Column, Tuple[ColumnOrName, str], Dict[str, str]], - _emit_ast: bool = True, + _emit_ast: bool = False, ) -> "DataFrame": """Aggregate the data in the DataFrame. Use this method if you don't need to group the data (:func:`group_by`). @@ -2024,7 +2052,7 @@ def group_by( self, *cols: Union[ColumnOrName, Iterable[ColumnOrName]], _ast_stmt: Optional[proto.Assign] = None, - _emit_ast: bool = True, + _emit_ast: bool = False, ) -> "snowflake.snowpark.RelationalGroupedDataFrame": """Groups rows by the columns specified by expressions (similar to GROUP BY in SQL). @@ -2084,6 +2112,7 @@ def group_by( grouping_exprs, snowflake.snowpark.relational_grouped_dataframe._GroupByType(), _ast_stmt=stmt, + stream_source=self._stream_source ) if _emit_ast: @@ -2099,7 +2128,7 @@ def group_by_grouping_sets( "snowflake.snowpark.GroupingSets", Iterable["snowflake.snowpark.GroupingSets"], ], - _emit_ast: bool = True, + _emit_ast: bool = False, ) -> "snowflake.snowpark.RelationalGroupedDataFrame": """Performs a SQL `GROUP BY GROUPING SETS `_. @@ -2187,6 +2216,7 @@ def cube( @df_api_usage @publicapi + @propagate_stream_source def distinct( self, _ast_stmt: proto.Assign = None, _emit_ast: bool = True ) -> "DataFrame": @@ -2217,11 +2247,12 @@ def distinct( return df @publicapi + @propagate_stream_source def drop_duplicates( self, *subset: Union[str, Iterable[str]], _ast_stmt: proto.Assign = None, - _emit_ast: bool = True, + _emit_ast: bool = False, ) -> "DataFrame": """Creates a new DataFrame by removing duplicated rows on given subset of columns. @@ -2295,7 +2326,7 @@ def pivot( Union[Iterable[LiteralType], "snowflake.snowpark.DataFrame"] ] = None, default_on_null: Optional[LiteralType] = None, - _emit_ast: bool = True, + _emit_ast: bool = False, ) -> "snowflake.snowpark.RelationalGroupedDataFrame": """Rotates this DataFrame by turning the unique values from one column in the input expression into multiple columns and aggregating results where required on any @@ -2372,17 +2403,19 @@ def pivot( pc[0], pivot_values, default_on_null ), _ast_stmt=stmt, + stream_source=self._stream_source ) @df_api_usage @publicapi + @propagate_stream_source def unpivot( self, value_column: str, name_column: str, column_list: List[ColumnOrName], include_nulls: bool = False, - _emit_ast: bool = True, + _emit_ast: bool = False, ) -> "DataFrame": """Rotates a table by transforming columns into rows. UNPIVOT is a relational operator that accepts two columns (from a table or subquery), along with a list of columns, and generates a row for each column specified in the list. In a query, it is specified in the FROM clause after the table name or subquery. @@ -2455,12 +2488,13 @@ def unpivot( @df_api_usage @publicapi + @propagate_stream_source def limit( self, n: int, offset: int = 0, _ast_stmt: proto.Assign = None, - _emit_ast: bool = True, + _emit_ast: bool = False, ) -> "DataFrame": """Returns a new DataFrame that contains at most ``n`` rows from the current DataFrame, skipping ``offset`` rows from the beginning (similar to LIMIT and OFFSET in SQL). 
@@ -2842,7 +2876,7 @@ def natural_join( self, right: "DataFrame", how: Optional[str] = None, - _emit_ast: bool = True, + _emit_ast: bool = False, **kwargs, ) -> "DataFrame": """Performs a natural join of the specified type (``how``) with the @@ -2932,7 +2966,7 @@ def join( lsuffix: str = "", rsuffix: str = "", match_condition: Optional[Column] = None, - _emit_ast: bool = True, + _emit_ast: bool = False, **kwargs, ) -> "DataFrame": """Performs a join of the specified type (``how``) with the current @@ -3299,6 +3333,16 @@ def join( if rsuffix: ast.rsuffix.value = rsuffix + if self._stream_source is None and right._stream_source is not None: + result_stream_source = right._stream_source + elif self._stream_source is not None and right._stream_source is None: + result_stream_source = self._stream_source + elif self._stream_source is None and right._stream_source is None: + result_stream_source = None + elif self._stream_source != right._stream_source: + raise NotImplementedError("cannot join two dataframes from different stream sources") + + return self._join_dataframes( right, using_columns, @@ -3307,7 +3351,7 @@ def join( rsuffix=rsuffix, match_condition=match_condition, _ast_stmt=stmt, - ) + ).set_stream_source(result_stream_source) raise TypeError("Invalid type for join. Must be Dataframe") @@ -3317,7 +3361,7 @@ def join_table_function( self, func: Union[str, List[str], TableFunctionCall], *func_arguments: ColumnOrName, - _emit_ast: bool = True, + _emit_ast: bool = False, **func_named_arguments: ColumnOrName, ) -> "DataFrame": """Lateral joins the current DataFrame with the output of the specified table function. @@ -3481,7 +3525,7 @@ def cross_join( *, lsuffix: str = "", rsuffix: str = "", - _emit_ast: bool = True, + _emit_ast: bool = False, ) -> "DataFrame": """Performs a cross join, which returns the Cartesian product of the current :class:`DataFrame` and another :class:`DataFrame` (``right``). @@ -3656,12 +3700,13 @@ def _join_dataframes_internal( @df_api_usage @publicapi + @propagate_stream_source def with_column( self, col_name: str, col: Union[Column, TableFunctionCall], ast_stmt: proto.Expr = None, - _emit_ast: bool = True, + _emit_ast: bool = False, ) -> "DataFrame": """ Returns a DataFrame with an additional column with the specified name @@ -3719,12 +3764,13 @@ def with_column( @df_api_usage @publicapi + @propagate_stream_source def with_columns( self, col_names: List[str], values: List[Union[Column, TableFunctionCall]], _ast_stmt: proto.Expr = None, - _emit_ast: bool = True, + _emit_ast: bool = False, ) -> "DataFrame": """Returns a DataFrame with additional columns with the specified names ``col_names``. The columns are computed by using the specified expressions @@ -3843,7 +3889,7 @@ def count( *, statement_params: Optional[Dict[str, str]] = None, block: bool = True, - _emit_ast: bool = True, + _emit_ast: bool = False, ) -> int: ... # pragma: no cover @@ -3854,7 +3900,7 @@ def count( *, statement_params: Optional[Dict[str, str]] = None, block: bool = False, - _emit_ast: bool = True, + _emit_ast: bool = False, ) -> AsyncJob: ... # pragma: no cover @@ -3864,7 +3910,7 @@ def count( *, statement_params: Optional[Dict[str, str]] = None, block: bool = True, - _emit_ast: bool = True, + _emit_ast: bool = False, ) -> Union[int, AsyncJob]: """Executes the query representing this DataFrame and returns the number of rows in the result (similar to the COUNT function in SQL). 
@@ -3951,7 +3997,7 @@ def copy_into_table( format_type_options: Optional[Dict[str, Any]] = None, statement_params: Optional[Dict[str, str]] = None, iceberg_config: Optional[dict] = None, - _emit_ast: bool = True, + _emit_ast: bool = False, **copy_options: Any, ) -> List[Row]: """Executes a `COPY INTO
`__ command to load data from files in a stage location into a specified table. @@ -4187,7 +4233,7 @@ def show( max_width: int = 50, *, statement_params: Optional[Dict[str, str]] = None, - _emit_ast: bool = True, + _emit_ast: bool = False, ) -> None: """Evaluates this DataFrame and prints out the first ``n`` rows with the specified maximum number of characters per column. @@ -4220,6 +4266,7 @@ def show( ) @df_api_usage @publicapi + @propagate_stream_source def flatten( self, input: ColumnOrName, @@ -4227,7 +4274,7 @@ def flatten( outer: bool = False, recursive: bool = False, mode: str = "BOTH", - _emit_ast: bool = True, + _emit_ast: bool = False, ) -> "DataFrame": """Flattens (explodes) compound values into multiple rows. @@ -4352,7 +4399,7 @@ def _lateral( ) def _show_string( - self, n: int = 10, max_width: int = 50, _emit_ast: bool = True, **kwargs + self, n: int = 10, max_width: int = 50, _emit_ast: bool = False, **kwargs ) -> str: query = self._plan.queries[-1].sql.strip().lower() @@ -4461,7 +4508,7 @@ def create_or_replace_view( *, comment: Optional[str] = None, statement_params: Optional[Dict[str, str]] = None, - _emit_ast: bool = True, + _emit_ast: bool = False, ) -> List[Row]: """Creates a view that captures the computation expressed by this DataFrame. @@ -4528,7 +4575,7 @@ def create_or_replace_dynamic_table( max_data_extension_time: Optional[int] = None, statement_params: Optional[Dict[str, str]] = None, iceberg_config: Optional[dict] = None, - _emit_ast: bool = True, + _emit_ast: bool = False, ) -> List[Row]: """Creates a dynamic table that captures the computation expressed by this DataFrame. @@ -4668,7 +4715,7 @@ def create_or_replace_temp_view( *, comment: Optional[str] = None, statement_params: Optional[Dict[str, str]] = None, - _emit_ast: bool = True, + _emit_ast: bool = False, ) -> List[Row]: """Creates a temporary view that returns the same results as this DataFrame. @@ -4797,7 +4844,7 @@ def first( *, statement_params: Optional[Dict[str, str]] = None, block: bool = True, - _emit_ast: bool = True, + _emit_ast: bool = False, ) -> Union[Optional[Row], List[Row]]: ... # pragma: no cover @@ -4809,7 +4856,7 @@ def first( *, statement_params: Optional[Dict[str, str]] = None, block: bool = False, - _emit_ast: bool = True, + _emit_ast: bool = False, ) -> AsyncJob: ... # pragma: no cover @@ -4820,7 +4867,7 @@ def first( *, statement_params: Optional[Dict[str, str]] = None, block: bool = True, - _emit_ast: bool = True, + _emit_ast: bool = False, ) -> Union[Optional[Row], List[Row], AsyncJob]: """Executes the query representing this DataFrame and returns the first ``n`` rows of the results. @@ -4882,11 +4929,12 @@ def first( @df_api_usage @publicapi + @propagate_stream_source def sample( self, frac: Optional[float] = None, n: Optional[int] = None, - _emit_ast: bool = True, + _emit_ast: bool = False, ) -> "DataFrame": """Samples rows based on either the number of rows to be returned or a percentage of rows to be returned. 
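Reviewer note: the join hunk earlier in this file combines the two inputs' stream tags: a tagged side wins over an untagged one, and two different tags raise. Restated as a small standalone function (the equal-and-non-None case is assumed to pass the shared tag through):

def combine_stream_sources(left_source, right_source):
    # Mirrors the precedence intended for _stream_source in DataFrame.join.
    if left_source is None:
        return right_source
    if right_source is None or right_source == left_source:
        return left_source
    raise NotImplementedError(
        "cannot join two dataframes from different stream sources"
    )


assert combine_stream_sources(None, "table") == "table"
assert combine_stream_sources("kafka", None) == "kafka"
assert combine_stream_sources(None, None) is None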
@@ -4938,7 +4986,7 @@ def _validate_sample_input(frac: Optional[float] = None, n: Optional[int] = None if n is not None and n < 0: raise ValueError(f"'n' value {n} must be greater than 0") - @property + @property def na(self) -> DataFrameNaFunctions: """ Returns a :class:`DataFrameNaFunctions` object that provides functions for @@ -4954,6 +5002,7 @@ def session(self) -> "snowflake.snowpark.Session": return self._session @publicapi + @propagate_stream_source def describe( self, *cols: Union[str, List[str]], _emit_ast: bool = True ) -> "DataFrame": @@ -5068,11 +5117,12 @@ def describe( @df_api_usage @publicapi + @propagate_stream_source def rename( self, col_or_mapper: Union[ColumnOrName, dict], new_column: str = None, - _emit_ast: bool = True, + _emit_ast: bool = False, ): """ Returns a DataFrame with the specified column ``col_or_mapper`` renamed as ``new_column``. If ``col_or_mapper`` @@ -5155,12 +5205,13 @@ def rename( @df_api_usage @publicapi + @propagate_stream_source def with_column_renamed( self, existing: ColumnOrName, new: str, _ast_stmt: Optional[proto.Assign] = None, - _emit_ast: bool = True, + _emit_ast: bool = False, ) -> "DataFrame": """Returns a DataFrame with the specified column ``existing`` renamed as ``new``. @@ -5236,11 +5287,12 @@ def with_column_renamed( @df_collect_api_telemetry @publicapi + @propagate_stream_source def cache_result( self, *, statement_params: Optional[Dict[str, str]] = None, - _emit_ast: bool = True, + _emit_ast: bool = False, ) -> "Table": """Caches the content of this DataFrame to create a new cached Table DataFrame. @@ -5377,7 +5429,7 @@ def random_split( seed: Optional[int] = None, *, statement_params: Optional[Dict[str, str]] = None, - _emit_ast: bool = True, + _emit_ast: bool = False, ) -> List["DataFrame"]: """ Randomly splits the current DataFrame into separate DataFrames, @@ -5749,6 +5801,12 @@ def print_schema(self, level: Optional[int] = None) -> None: # withColumns = with_columns + def set_stream_source(self, stream_source: str) -> "DataFrame": + result = self.__copy__() + result._stream_source = stream_source + return result + +@propagate_stream_source def map( dataframe: DataFrame, func: Callable, diff --git a/src/snowflake/snowpark/dataframe_reader.py b/src/snowflake/snowpark/dataframe_reader.py index 823ebebbf8..8ae93c6f9e 100644 --- a/src/snowflake/snowpark/dataframe_reader.py +++ b/src/snowflake/snowpark/dataframe_reader.py @@ -1029,8 +1029,8 @@ def load(self) -> DataFrame: lit(topic), lit(partition_id) ) - ) + ).set_stream_source("kafka") elif self._format == "table": - return self._session.table(self._cur_options["TABLE_NAME"]) + return self._session.table(self._cur_options["TABLE_NAME"]).set_stream_source("table") else: raise NotImplementedError(f"cannot stream read in format {self._format}") \ No newline at end of file diff --git a/src/snowflake/snowpark/dataframe_writer.py b/src/snowflake/snowpark/dataframe_writer.py index 8d884bc7df..35f86fb6e4 100644 --- a/src/snowflake/snowpark/dataframe_writer.py +++ b/src/snowflake/snowpark/dataframe_writer.py @@ -922,33 +922,64 @@ def parquet( class DataStreamWriter(DataFrameWriter): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._output_mode = None + self._processing_time = None + def toTable(self, table_name: str) -> AsyncJob: - self._dataframe.session.custom_package_usage_config['force_push'] = True - self._dataframe.session.custom_package_usage_config['enabled'] = True - 
self._dataframe.session.add_import(snowflake.snowpark.write_stream_to_table.__file__, import_path="snowflake.snowpark.write_stream_to_table") - self._dataframe.session.sql("create or replace stage mystage").collect() - - - write_stream_udf = udf( - write_stream_to_table, - input_types= - [ - StringType(), - *(f.datatype for f in self._dataframe.schema.fields) - ], - is_permanent=True, - replace=True, - name='write_stream_udf', - stage_location="@mystage" - ) + if self._dataframe._stream_source is None: + raise NotImplementedError("Could not track streaming source of this dataframe") + elif self._dataframe._stream_source == "kafka": + self._dataframe.session.custom_package_usage_config['force_push'] = True + self._dataframe.session.custom_package_usage_config['enabled'] = True + self._dataframe.session.add_import(snowflake.snowpark.write_stream_to_table.__file__, import_path="snowflake.snowpark.write_stream_to_table") + self._dataframe.session.sql("create or replace stage mystage").collect() + + + write_stream_udf = udf( + write_stream_to_table, + input_types= + [ + StringType(), + *(f.datatype for f in self._dataframe.schema.fields) + ], + is_permanent=True, + replace=True, + name='write_stream_udf', + stage_location="@mystage" + ) - return self._dataframe.select(write_stream_udf( - lit(table_name), - *( - col(f.name) - for f in self._dataframe.schema.fields + return self._dataframe.select(write_stream_udf( + lit(table_name), + *( + col(f.name) + for f in self._dataframe.schema.fields + ) + )).collect_nowait() + elif self._dataframe._stream_source == "table": + if self._output_mode == "append": + refresh_mode = "incremental" + elif self._output_mode == "complete": + refresh_mode = "full" + else: + raise NotImplementedError(f"unsupported output mode {self._output_mode}") + self._dataframe.create_or_replace_dynamic_table( + table_name, + warehouse=self._dataframe.session.connection.warehouse, + lag=self._processing_time, + refresh_mode=refresh_mode ) - )).collect_nowait() + else: + raise NotImplementedError(f"Cannot write dataframe with source {self._dataframe._stream_source}") def outputMode(self, output_mode: str) -> "DataStreamWriter": - raise NotImplementedError \ No newline at end of file + self._output_mode = output_mode + return self + + def trigger(self, **kwargs) -> "DataStreamWriter": + if list(kwargs.keys()) != ["processingTime"]: + raise NotImplementedError("can only handle trigger with processingTime=") + self._processing_time = kwargs["processingTime"] + return self + \ No newline at end of file diff --git a/src/snowflake/snowpark/relational_grouped_dataframe.py b/src/snowflake/snowpark/relational_grouped_dataframe.py index b47d5c2de2..1eeb1bd779 100644 --- a/src/snowflake/snowpark/relational_grouped_dataframe.py +++ b/src/snowflake/snowpark/relational_grouped_dataframe.py @@ -160,12 +160,14 @@ def __init__( grouping_exprs: List[Expression], group_type: _GroupType, _ast_stmt: Optional[proto.Assign] = None, + stream_source = None ) -> None: self._dataframe = df self._grouping_exprs = grouping_exprs self._group_type = group_type self._df_api_call = None self._ast_id = _ast_stmt.var_id.bitfield1 if _ast_stmt is not None else None + self._stream_source = stream_source def _to_df( self, @@ -246,7 +248,7 @@ def _to_df( group_plan, _ast_stmt=_ast_stmt, _emit_ast=_emit_ast, - ) + ).set_stream_source(self._stream_source) @relational_group_df_api_usage @publicapi
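Reviewer note to close the series: for a table-sourced stream, DataStreamWriter.toTable maps the Spark-style settings onto dynamic-table parameters: outputMode("append") becomes REFRESH_MODE = 'incremental', outputMode("complete") becomes 'full', and trigger(processingTime=...) supplies the LAG, while a kafka-sourced stream is instead routed row by row through the permanent write_stream_udf. A small restatement of the output-mode mapping (illustrative only):

def refresh_mode_for(output_mode: str) -> str:
    # Mirrors the dispatch in DataStreamWriter.toTable for the "table" source.
    modes = {"append": "incremental", "complete": "full"}
    if output_mode not in modes:
        raise NotImplementedError(f"unsupported output mode {output_mode}")
    return modes[output_mode]


assert refresh_mode_for("append") == "incremental"
assert refresh_mode_for("complete") == "full"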