refactor: Refactor snowflake to use spmc abstractions #26900

Open · wants to merge 5 commits into base: master
Changes from all commits
3 changes: 3 additions & 0 deletions posthog/settings/temporal.py
@@ -20,6 +20,9 @@
"BATCH_EXPORT_S3_RECORD_BATCH_QUEUE_MAX_SIZE_BYTES", 0, type_cast=int
)
BATCH_EXPORT_SNOWFLAKE_UPLOAD_CHUNK_SIZE_BYTES: int = 1024 * 1024 * 100 # 100MB
BATCH_EXPORT_SNOWFLAKE_RECORD_BATCH_QUEUE_MAX_SIZE_BYTES: int = get_from_env(
"BATCH_EXPORT_SNOWFLAKE_RECORD_BATCH_QUEUE_MAX_SIZE_BYTES", 1024 * 1024 * 300, type_cast=int
)
BATCH_EXPORT_POSTGRES_UPLOAD_CHUNK_SIZE_BYTES: int = 1024 * 1024 * 50 # 50MB
BATCH_EXPORT_BIGQUERY_UPLOAD_CHUNK_SIZE_BYTES: int = 1024 * 1024 * 100 # 100MB
BATCH_EXPORT_BIGQUERY_RECORD_BATCH_QUEUE_MAX_SIZE_BYTES: int = get_from_env(
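For context, the new BATCH_EXPORT_SNOWFLAKE_RECORD_BATCH_QUEUE_MAX_SIZE_BYTES setting bounds the in-memory record batch queue that the SPMC pipeline uses for Snowflake exports, mirroring the existing BigQuery setting. A minimal sketch of how it ends up being consumed, based on the RecordBatchQueue construction that appears later in this diff (the helper function is illustrative and not part of the PR):

# Illustrative only: mirrors the RecordBatchQueue construction shown later in this diff.
from django.conf import settings

from posthog.temporal.batch_exports.spmc import RecordBatchQueue


def build_snowflake_record_batch_queue() -> RecordBatchQueue:
    # Defaults to 300MB (1024 * 1024 * 300); get_from_env allows overriding it with the
    # BATCH_EXPORT_SNOWFLAKE_RECORD_BATCH_QUEUE_MAX_SIZE_BYTES environment variable.
    return RecordBatchQueue(max_size_bytes=settings.BATCH_EXPORT_SNOWFLAKE_RECORD_BATCH_QUEUE_MAX_SIZE_BYTES)
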
268 changes: 136 additions & 132 deletions posthog/temporal/batch_exports/snowflake_batch_export.py
@@ -12,7 +12,7 @@
import snowflake.connector
from django.conf import settings
from snowflake.connector.connection import SnowflakeConnection
from snowflake.connector.errors import OperationalError, InterfaceError
from snowflake.connector.errors import InterfaceError, OperationalError
from temporalio import activity, workflow
from temporalio.common import RetryPolicy

@@ -31,32 +31,43 @@
default_fields,
execute_batch_export_insert_activity,
get_data_interval,
iter_model_records,
start_batch_export_run,
)
from posthog.temporal.batch_exports.metrics import (
get_bytes_exported_metric,
get_rows_exported_metric,
from posthog.temporal.batch_exports.heartbeat import (
BatchExportRangeHeartbeatDetails,
DateRange,
should_resume_from_activity_heartbeat,
)
from posthog.temporal.batch_exports.spmc import (
Consumer,
Producer,
RecordBatchQueue,
run_consumer_loop,
wait_for_schema_or_producer,
)
from posthog.temporal.batch_exports.temporary_file import (
BatchExportTemporaryFile,
JSONLBatchExportWriter,
WriterFormat,
)
from posthog.temporal.batch_exports.utils import (
JsonType,
apeek_first_and_rewind,
cast_record_batch_json_columns,
set_status_to_running_task,
)
from posthog.temporal.common.clickhouse import get_client
from posthog.temporal.common.heartbeat import Heartbeater
from posthog.temporal.common.logger import bind_temporal_worker_logger
from posthog.temporal.batch_exports.heartbeat import (
BatchExportRangeHeartbeatDetails,
DateRange,
HeartbeatParseError,
should_resume_from_activity_heartbeat,
)

NON_RETRYABLE_ERROR_TYPES = [
# Raised when we cannot connect to Snowflake.
"DatabaseError",
# Raised by Snowflake when a query cannot be compiled.
# Usually this means we don't have table permissions or something doesn't exist (db, schema).
"ProgrammingError",
# Raised by Snowflake with an incorrect account name.
"ForbiddenError",
# Our own exception when we can't connect to Snowflake, usually due to invalid parameters.
"SnowflakeConnectionError",
]


class SnowflakeFileNotUploadedError(Exception):
@@ -91,37 +102,9 @@ class SnowflakeRetryableConnectionError(Exception):

@dataclasses.dataclass
class SnowflakeHeartbeatDetails(BatchExportRangeHeartbeatDetails):
"""The Snowflake batch export details included in every heartbeat.

Attributes:
file_no: The file number of the last file we managed to upload.
"""

file_no: int = 0

@classmethod
def deserialize_details(cls, details: collections.abc.Sequence[typing.Any]) -> dict[str, typing.Any]:
"""Attempt to initialize HeartbeatDetails from an activity's details."""
file_no = 0
remaining = super().deserialize_details(details)

if len(remaining["_remaining"]) == 0:
return {"file_no": 0, **remaining}

first_detail = remaining["_remaining"][0]
remaining["_remaining"] = remaining["_remaining"][1:]

try:
file_no = int(first_detail)
except (TypeError, ValueError) as e:
raise HeartbeatParseError("file_no") from e

return {"file_no": file_no, **remaining}
"""The Snowflake batch export details included in every heartbeat."""

def serialize_details(self) -> tuple[typing.Any, ...]:
"""Attempt to initialize HeartbeatDetails from an activity's details."""
serialized_parent_details = super().serialize_details()
return (*serialized_parent_details[:-1], self.file_no, self._remaining)
pass


@dataclasses.dataclass
@@ -344,22 +327,16 @@ async def put_file_to_snowflake_table(
file: BatchExportTemporaryFile,
table_stage_prefix: str,
table_name: str,
file_no: int,
):
"""Executes a PUT query using the provided cursor to the provided table_name.

Sadly, Snowflake's execute_async does not work with PUT statements. So, we pass the execute
call to run_in_executor: Since execute ends up boiling down to blocking IO (HTTP request),
the event loop should not be locked up.

We add a file_no to the file_name when executing PUT as Snowflake will reject any files with the same
name. Since batch exports re-use the same file, our name does not change, but we don't want Snowflake
to reject or overwrite our new data.

Args:
file: The name of the local file to PUT.
table_name: The name of the Snowflake table where to PUT the file.
file_no: An int to identify which file number this is.

Raises:
TypeError: If we don't get a tuple back from Snowflake (should never happen).
@@ -371,7 +348,7 @@ async def put_file_to_snowflake_table(
# So we ask mypy to be nice with us.
reader = io.BufferedReader(file) # type: ignore
query = f"""
PUT file://{file.name}_{file_no}.jsonl '@%"{table_name}"/{table_stage_prefix}'
PUT file://{file.name} '@%"{table_name}"/{table_stage_prefix}'
"""

with self.connection.cursor() as cursor:
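The docstring above captures the key constraint: Snowflake's execute_async cannot run PUT statements, so the blocking cursor.execute call is handed to an executor to keep the event loop free. A generic, self-contained sketch of that technique (plain asyncio; the connection and query here are placeholders, not the PR's code):

# Generic illustration of offloading a blocking driver call with run_in_executor;
# "connection" and "query" are placeholders, not PostHog code.
import asyncio
import functools


def blocking_put(connection, query: str):
    # snowflake.connector's cursor.execute blocks on network I/O (an HTTP request),
    # so it must not run directly on the event loop.
    with connection.cursor() as cursor:
        cursor.execute(query)
        return cursor.fetchone()


async def put_without_blocking_event_loop(connection, query: str):
    loop = asyncio.get_running_loop()
    # None selects the default ThreadPoolExecutor; heartbeats and other tasks keep
    # running while the PUT executes in a worker thread.
    return await loop.run_in_executor(None, functools.partial(blocking_put, connection, query))
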
@@ -518,6 +495,53 @@ def snowflake_default_fields() -> list[BatchExportField]:
return batch_export_fields


class SnowflakeConsumer(Consumer):
def __init__(
self,
heartbeater: Heartbeater,
heartbeat_details: SnowflakeHeartbeatDetails,
data_interval_start: dt.datetime | str | None,
writer_format: WriterFormat,
snowflake_client: SnowflakeClient,
snowflake_table: str,
snowflake_table_stage_prefix: str,
):
super().__init__(heartbeater, heartbeat_details, data_interval_start, writer_format)
self.heartbeat_details: SnowflakeHeartbeatDetails = heartbeat_details
self.snowflake_table = snowflake_table
self.snowflake_client = snowflake_client
self.snowflake_table_stage_prefix = snowflake_table_stage_prefix

async def flush(
self,
batch_export_file: BatchExportTemporaryFile,
records_since_last_flush: int,
bytes_since_last_flush: int,
flush_counter: int,
last_date_range: DateRange,
is_last: bool,
error: Exception | None,
):
await self.logger.ainfo(
"Putting file %s containing %s records with size %s bytes",
flush_counter,
records_since_last_flush,
bytes_since_last_flush,
)

await self.snowflake_client.put_file_to_snowflake_table(
batch_export_file,
self.snowflake_table_stage_prefix,
self.snowflake_table,
)

await self.logger.adebug("Loaded %s to Snowflake table '%s'", records_since_last_flush, self.snowflake_table)
self.rows_exported_counter.add(records_since_last_flush)
self.bytes_exported_counter.add(bytes_since_last_flush)

self.heartbeat_details.track_done_range(last_date_range, self.data_interval_start)


def get_snowflake_fields_from_record_schema(
record_schema: pa.Schema, known_variant_columns: list[str]
) -> list[SnowflakeField]:
@@ -594,42 +618,63 @@ async def insert_into_snowflake_activity(inputs: SnowflakeInsertInputs) -> Recor
details = SnowflakeHeartbeatDetails()

done_ranges: list[DateRange] = details.done_ranges
if done_ranges:
data_interval_start: str | None = done_ranges[-1][1].isoformat()
else:
data_interval_start = inputs.data_interval_start

current_flush_counter = details.file_no

rows_exported = get_rows_exported_metric()
bytes_exported = get_bytes_exported_metric()

model: BatchExportModel | BatchExportSchema | None = None
if inputs.batch_export_schema is None and "batch_export_model" in {
field.name for field in dataclasses.fields(inputs)
}:
model = inputs.batch_export_model
if model is not None:
model_name = model.name
extra_query_parameters = model.schema["values"] if model.schema is not None else None
fields = model.schema["fields"] if model.schema is not None else None
else:
model_name = "events"
extra_query_parameters = None
fields = None
else:
model = inputs.batch_export_schema
model_name = "custom"
extra_query_parameters = model["values"] if model is not None else {}
fields = model["fields"] if model is not None else None

records_iterator = iter_model_records(
client=client,
model=model,
data_interval_start = (
dt.datetime.fromisoformat(inputs.data_interval_start) if inputs.data_interval_start else None
)
data_interval_end = dt.datetime.fromisoformat(inputs.data_interval_end)
full_range = (data_interval_start, data_interval_end)

queue = RecordBatchQueue(max_size_bytes=settings.BATCH_EXPORT_SNOWFLAKE_RECORD_BATCH_QUEUE_MAX_SIZE_BYTES)
producer = Producer(clickhouse_client=client)
producer_task = producer.start(
queue=queue,
model_name=model_name,
is_backfill=inputs.is_backfill,
team_id=inputs.team_id,
interval_start=data_interval_start,
interval_end=inputs.data_interval_end,
full_range=full_range,
done_ranges=done_ranges,
fields=fields,
destination_default_fields=snowflake_default_fields(),
exclude_events=inputs.exclude_events,
include_events=inputs.include_events,
destination_default_fields=snowflake_default_fields(),
is_backfill=inputs.is_backfill,
extra_query_parameters=extra_query_parameters,
)
records_completed = 0

record_batch_schema = await wait_for_schema_or_producer(queue, producer_task)
if record_batch_schema is None:
return records_completed

record_batch_schema = pa.schema(
# NOTE: For some reason, some record batches set fields as non-nullable, whereas other
# record batches have the same fields as nullable.
# Until we figure it out, we set all fields to nullable. There are some fields we know
# are not nullable, but I'm opting for the more flexible option until we figure out why
# schemas differ between batches.
[field.with_nullable(True) for field in record_batch_schema if field.name != "_inserted_at"]
)
first_record_batch, records_iterator = await apeek_first_and_rewind(records_iterator)

if first_record_batch is None:
return 0

known_variant_columns = ["properties", "people_set", "people_set_once", "person_properties"]
first_record_batch = cast_record_batch_json_columns(first_record_batch, json_columns=known_variant_columns)

if model is None or (isinstance(model, BatchExportModel) and model.name == "events"):
table_fields = [
@@ -647,10 +692,8 @@
]

else:
column_names = [column for column in first_record_batch.schema.names if column != "_inserted_at"]
record_schema = first_record_batch.select(column_names).schema
table_fields = get_snowflake_fields_from_record_schema(
record_schema,
record_batch_schema,
known_variant_columns=known_variant_columns,
)

@@ -671,57 +714,28 @@ async def insert_into_snowflake_activity(inputs: SnowflakeInsertInputs) -> Recor
stagle_table_name, data_interval_end_str, table_fields, create=requires_merge, delete=requires_merge
) as snow_stage_table,
):
record_columns = [field[0] for field in table_fields]
record_schema = pa.schema(
[field.with_nullable(True) for field in first_record_batch.select(record_columns).schema]
)

async def flush_to_snowflake(
local_results_file,
records_since_last_flush,
bytes_since_last_flush,
flush_counter: int,
last_date_range: DateRange,
last: bool,
error: Exception | None,
):
logger.info(
"Putting %sfile %s containing %s records with size %s bytes",
"last " if last else "",
flush_counter,
records_since_last_flush,
bytes_since_last_flush,
)

table = snow_stage_table if requires_merge else snow_table

await snow_client.put_file_to_snowflake_table(
local_results_file, data_interval_end_str, table, flush_counter
)
rows_exported.add(records_since_last_flush)
bytes_exported.add(bytes_since_last_flush)

details.track_done_range(last_date_range, data_interval_start)
details.file_no = flush_counter
heartbeater.set_from_heartbeat_details(details)

writer = JSONLBatchExportWriter(
records_completed = await run_consumer_loop(
queue=queue,
consumer_cls=SnowflakeConsumer,
producer_task=producer_task,
heartbeater=heartbeater,
heartbeat_details=details,
data_interval_end=data_interval_end,
data_interval_start=data_interval_start,
schema=record_batch_schema,
writer_format=WriterFormat.JSONL,
max_bytes=settings.BATCH_EXPORT_SNOWFLAKE_UPLOAD_CHUNK_SIZE_BYTES,
flush_callable=flush_to_snowflake,
json_columns=known_variant_columns,
snowflake_client=snow_client,
snowflake_table=snow_stage_table if requires_merge else snow_table,
snowflake_table_stage_prefix=data_interval_end_str,
multiple_files=True,
)

async with writer.open_temporary_file(current_flush_counter):
async for record_batch in records_iterator:
record_batch = cast_record_batch_json_columns(record_batch, json_columns=known_variant_columns)

await writer.write_record_batch(record_batch)

details.complete_done_ranges(inputs.data_interval_end)
heartbeater.set_from_heartbeat_details(details)

await snow_client.copy_loaded_files_to_snowflake_table(
snow_stage_table if requires_merge else snow_table, data_interval_end_str
)

if requires_merge:
merge_key = (
("team_id", "INT64"),
@@ -734,7 +748,7 @@ async def flush_to_snowflake(
merge_key=merge_key,
)

return writer.records_total
return records_completed


@workflow.defn(name="snowflake-export", failure_exception_types=[workflow.NondeterminismError])
@@ -811,16 +825,6 @@ async def run(self, inputs: SnowflakeBatchExportInputs):
insert_into_snowflake_activity,
insert_inputs,
interval=inputs.interval,
non_retryable_error_types=[
# Raised when we cannot connect to Snowflake.
"DatabaseError",
# Raised by Snowflake when a query cannot be compiled.
# Usually this means we don't have table permissions or something doesn't exist (db, schema).
"ProgrammingError",
# Raised by Snowflake with an incorrect account name.
"ForbiddenError",
# Our own exception when we can't connect to Snowflake, usually due to invalid parameters.
"SnowflakeConnectionError",
],
non_retryable_error_types=NON_RETRYABLE_ERROR_TYPES,
finish_inputs=finish_inputs,
)
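Stepping back, the diff replaces the bespoke flush_to_snowflake callback and JSONLBatchExportWriter loop with the shared SPMC abstractions: Producer.start streams record batches from ClickHouse into a byte-bounded RecordBatchQueue, wait_for_schema_or_producer short-circuits empty exports, and run_consumer_loop drains the queue through SnowflakeConsumer, whose flush PUTs each JSONL chunk to the table stage and tracks heartbeat progress. A generic, self-contained illustration of the bounded producer/consumer pattern being adopted (none of these names come from the PostHog codebase):

# Toy version of the bounded producer/consumer pattern; NOT PostHog's spmc implementation.
import asyncio


async def produce(queue: asyncio.Queue, batches: list[bytes]) -> None:
    for batch in batches:
        # put() blocks when the queue is full, applying backpressure to the producer.
        await queue.put(batch)
    await queue.put(None)  # sentinel: no more data


async def consume(queue: asyncio.Queue, max_chunk_bytes: int) -> int:
    buffered_bytes = 0
    total_batches = 0
    while (batch := await queue.get()) is not None:
        buffered_bytes += len(batch)
        total_batches += 1
        if buffered_bytes >= max_chunk_bytes:
            # Stand-in for flushing a JSONL file and PUT-ing it to the table stage.
            buffered_bytes = 0
    return total_batches


async def main() -> None:
    # The real RecordBatchQueue is bounded by bytes; asyncio.Queue is bounded by item count.
    queue: asyncio.Queue = asyncio.Queue(maxsize=4)
    batches = [b"x" * 1024 for _ in range(10)]
    _, total = await asyncio.gather(produce(queue, batches), consume(queue, max_chunk_bytes=4 * 1024))
    print(f"consumed {total} record batches")


asyncio.run(main())

In the PR itself the queue is bounded by BATCH_EXPORT_SNOWFLAKE_RECORD_BATCH_QUEUE_MAX_SIZE_BYTES and chunking is driven by BATCH_EXPORT_SNOWFLAKE_UPLOAD_CHUNK_SIZE_BYTES, so memory stays bounded even when ClickHouse produces batches faster than Snowflake accepts files.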