From f31e6860d2c3af0ee8f1c8f1c894595830885ab2 Mon Sep 17 00:00:00 2001 From: Marcin Rudolf Date: Sun, 8 Sep 2024 22:04:01 +0200 Subject: [PATCH] refactors writers and buffered code, improves docs --- dlt/common/data_writers/buffered.py | 43 ++++---- dlt/common/data_writers/writers.py | 99 ++++++++----------- dlt/common/libs/pyarrow.py | 24 +++++ .../dlt-ecosystem/file-formats/parquet.md | 9 +- 4 files changed, 95 insertions(+), 80 deletions(-) diff --git a/dlt/common/data_writers/buffered.py b/dlt/common/data_writers/buffered.py index f18b997ee2..e2b6c9a442 100644 --- a/dlt/common/data_writers/buffered.py +++ b/dlt/common/data_writers/buffered.py @@ -99,32 +99,17 @@ def write_data_item(self, item: TDataItems, columns: TTableSchemaColumns) -> int # until the first chunk is written we can change the columns schema freely if columns is not None: self._current_columns = dict(columns) - - new_rows_count: int - if isinstance(item, List): - # update row count, if item supports "num_rows" it will be used to count items - if len(item) > 0 and hasattr(item[0], "num_rows"): - new_rows_count = sum(tbl.num_rows for tbl in item) - else: - new_rows_count = len(item) - # items coming in a single list will be written together, no matter how many there are - self._buffered_items.extend(item) - else: - self._buffered_items.append(item) - # update row count, if item supports "num_rows" it will be used to count items - if hasattr(item, "num_rows"): - new_rows_count = item.num_rows - else: - new_rows_count = 1 + # add item to buffer and count new rows + new_rows_count = self._buffer_items_with_row_count(item) self._buffered_items_count += new_rows_count + # set last modification date + self._last_modified = time.time() # flush if max buffer exceeded, the second path of the expression prevents empty data frames to pile up in the buffer if ( self._buffered_items_count >= self.buffer_max_items or len(self._buffered_items) >= self.buffer_max_items ): self._flush_items() - # set last modification date - self._last_modified = time.time() # rotate the file if max_bytes exceeded if self._file: # rotate on max file size @@ -221,6 +206,26 @@ def __exit__(self, exc_type: Type[BaseException], exc_val: BaseException, exc_tb if not in_exception: raise + def _buffer_items_with_row_count(self, item: TDataItems) -> int: + """Adds `item` to in-memory buffer and counts new rows, depending in item type""" + new_rows_count: int + if isinstance(item, List): + # update row count, if item supports "num_rows" it will be used to count items + if len(item) > 0 and hasattr(item[0], "num_rows"): + new_rows_count = sum(tbl.num_rows for tbl in item) + else: + new_rows_count = len(item) + # items coming in a single list will be written together, no matter how many there are + self._buffered_items.extend(item) + else: + self._buffered_items.append(item) + # update row count, if item supports "num_rows" it will be used to count items + if hasattr(item, "num_rows"): + new_rows_count = item.num_rows + else: + new_rows_count = 1 + return new_rows_count + def _rotate_file(self, allow_empty_file: bool = False) -> DataWriterMetrics: metrics = self._flush_and_close_file(allow_empty_file) self._file_name = ( diff --git a/dlt/common/data_writers/writers.py b/dlt/common/data_writers/writers.py index d5d4ee278e..4311fb270e 100644 --- a/dlt/common/data_writers/writers.py +++ b/dlt/common/data_writers/writers.py @@ -36,7 +36,7 @@ ) from dlt.common.metrics import DataWriterMetrics from dlt.common.schema.typing import TTableSchemaColumns -from 
dlt.common.typing import StrAny +from dlt.common.typing import StrAny, TDataItem if TYPE_CHECKING: @@ -72,8 +72,8 @@ def __init__(self, f: IO[Any], caps: DestinationCapabilitiesContext = None) -> N def write_header(self, columns_schema: TTableSchemaColumns) -> None: # noqa pass - def write_data(self, rows: Sequence[Any]) -> None: - self.items_count += len(rows) + def write_data(self, items: Sequence[TDataItem]) -> None: + self.items_count += len(items) def write_footer(self) -> None: # noqa pass @@ -81,9 +81,9 @@ def write_footer(self) -> None: # noqa def close(self) -> None: # noqa pass - def write_all(self, columns_schema: TTableSchemaColumns, rows: Sequence[Any]) -> None: + def write_all(self, columns_schema: TTableSchemaColumns, items: Sequence[TDataItem]) -> None: self.write_header(columns_schema) - self.write_data(rows) + self.write_data(items) self.write_footer() @classmethod @@ -156,9 +156,9 @@ def writer_spec(cls) -> FileWriterSpec: class JsonlWriter(DataWriter): - def write_data(self, rows: Sequence[Any]) -> None: - super().write_data(rows) - for row in rows: + def write_data(self, items: Sequence[TDataItem]) -> None: + super().write_data(items) + for row in items: json.dump(row, self._f) self._f.write(b"\n") @@ -175,12 +175,12 @@ def writer_spec(cls) -> FileWriterSpec: class TypedJsonlListWriter(JsonlWriter): - def write_data(self, rows: Sequence[Any]) -> None: + def write_data(self, items: Sequence[TDataItem]) -> None: # skip JsonlWriter when calling super - super(JsonlWriter, self).write_data(rows) + super(JsonlWriter, self).write_data(items) # write all rows as one list which will require to write just one line # encode types with PUA characters - json.typed_dump(rows, self._f) + json.typed_dump(items, self._f) self._f.write(b"\n") @classmethod @@ -222,11 +222,11 @@ def write_header(self, columns_schema: TTableSchemaColumns) -> None: if self.writer_type == "default": self._f.write("VALUES\n") - def write_data(self, rows: Sequence[Any]) -> None: - super().write_data(rows) + def write_data(self, items: Sequence[TDataItem]) -> None: + super().write_data(items) # do not write empty rows, such things may be produced by Arrow adapters - if len(rows) == 0: + if len(items) == 0: return def write_row(row: StrAny, last_row: bool = False) -> None: @@ -244,11 +244,11 @@ def write_row(row: StrAny, last_row: bool = False) -> None: self._f.write(self.sep) # write rows - for row in rows[:-1]: + for row in items[:-1]: write_row(row) # write last row without separator so we can write footer eventually - write_row(rows[-1], last_row=True) + write_row(items[-1], last_row=True) self._chunks_written += 1 def write_footer(self) -> None: @@ -342,19 +342,19 @@ def write_header(self, columns_schema: TTableSchemaColumns) -> None: ] self.writer = self._create_writer(self.schema) - def write_data(self, rows: Sequence[Any]) -> None: - super().write_data(rows) + def write_data(self, items: Sequence[TDataItem]) -> None: + super().write_data(items) from dlt.common.libs.pyarrow import pyarrow # replace complex types with json for key in self.complex_indices: - for row in rows: + for row in items: if (value := row.get(key)) is not None: # TODO: make this configurable if value is not None and not isinstance(value, str): row[key] = json.dumps(value) - table = pyarrow.Table.from_pylist(rows, schema=self.schema) + table = pyarrow.Table.from_pylist(items, schema=self.schema) # Write self.writer.write_table(table, row_group_size=self.parquet_row_group_size) @@ -423,10 +423,10 @@ def write_header(self, 
columns_schema: TTableSchemaColumns) -> None: i for i, field in columns_schema.items() if field["data_type"] == "binary" ] - def write_data(self, rows: Sequence[Any]) -> None: + def write_data(self, items: Sequence[TDataItem]) -> None: # convert bytes and json if self.complex_indices or self.bytes_indices: - for row in rows: + for row in items: for key in self.complex_indices: if (value := row.get(key)) is not None: row[key] = json.dumps(value) @@ -445,9 +445,9 @@ def write_data(self, rows: Sequence[Any]) -> None: " type as binary.", ) - self.writer.writerows(rows) + self.writer.writerows(items) # count rows that got written - self.items_count += sum(len(row) for row in rows) + self.items_count += sum(len(row) for row in items) def close(self) -> None: self.writer = None @@ -471,35 +471,20 @@ def write_header(self, columns_schema: TTableSchemaColumns) -> None: # Schema will be written as-is from the arrow table self._column_schema = columns_schema - def write_data(self, rows: Sequence[Any]) -> None: - from dlt.common.libs.pyarrow import pyarrow + def write_data(self, items: Sequence[TDataItem]) -> None: + from dlt.common.libs.pyarrow import concat_batches_and_tables_in_order - if not rows: + if not items: return # concat batches and tables into a single one, preserving order # pyarrow writer starts a row group for each item it writes (even with 0 rows) # it also converts batches into tables internally. by creating a single table # we allow the user rudimentary control over row group size via max buffered items - batches = [] - tables = [] - for row in rows: - self.items_count += row.num_rows - if isinstance(row, pyarrow.RecordBatch): - batches.append(row) - elif isinstance(row, pyarrow.Table): - if batches: - tables.append(pyarrow.Table.from_batches(batches)) - batches = [] - tables.append(row) - else: - raise ValueError(f"Unsupported type {type(row)}") - if batches: - tables.append(pyarrow.Table.from_batches(batches)) - - table = pyarrow.concat_tables(tables, promote_options="none") + table = concat_batches_and_tables_in_order(items) + self.items_count += table.num_rows if not self.writer: self.writer = self._create_writer(table.schema) - # write concatenated tables, "none" options ensures 0 copy concat + # write concatenated tables self.writer.write_table(table, row_group_size=self.parquet_row_group_size) def write_footer(self) -> None: @@ -544,12 +529,12 @@ def __init__( def write_header(self, columns_schema: TTableSchemaColumns) -> None: self._columns_schema = columns_schema - def write_data(self, rows: Sequence[Any]) -> None: + def write_data(self, items: Sequence[TDataItem]) -> None: from dlt.common.libs.pyarrow import pyarrow import pyarrow.csv - for row in rows: - if isinstance(row, (pyarrow.Table, pyarrow.RecordBatch)): + for item in items: + if isinstance(item, (pyarrow.Table, pyarrow.RecordBatch)): if not self.writer: if self.quoting == "quote_needed": quoting = "needed" @@ -560,14 +545,14 @@ def write_data(self, rows: Sequence[Any]) -> None: try: self.writer = pyarrow.csv.CSVWriter( self._f, - row.schema, + item.schema, write_options=pyarrow.csv.WriteOptions( include_header=self.include_header, delimiter=self._delimiter_b, quoting_style=quoting, ), ) - self._first_schema = row.schema + self._first_schema = item.schema except pyarrow.ArrowInvalid as inv_ex: if "Unsupported Type" in str(inv_ex): raise InvalidDataItem( @@ -579,18 +564,18 @@ def write_data(self, rows: Sequence[Any]) -> None: ) raise # make sure that Schema stays the same - if not 
row.schema.equals(self._first_schema): + if not item.schema.equals(self._first_schema): raise InvalidDataItem( "csv", "arrow", "Arrow schema changed without rotating the file. This may be internal" " error or misuse of the writer.\nFirst" - f" schema:\n{self._first_schema}\n\nCurrent schema:\n{row.schema}", + f" schema:\n{self._first_schema}\n\nCurrent schema:\n{item.schema}", ) # write headers only on the first write try: - self.writer.write(row) + self.writer.write(item) except pyarrow.ArrowInvalid as inv_ex: if "Invalid UTF8 payload" in str(inv_ex): raise InvalidDataItem( @@ -611,9 +596,9 @@ def write_data(self, rows: Sequence[Any]) -> None: ) raise else: - raise ValueError(f"Unsupported type {type(row)}") + raise ValueError(f"Unsupported type {type(item)}") # count rows that got written - self.items_count += row.num_rows + self.items_count += item.num_rows def write_footer(self) -> None: if self.writer is None and self.include_header: @@ -649,8 +634,8 @@ def writer_spec(cls) -> FileWriterSpec: class ArrowToObjectAdapter: """A mixin that will convert object writer into arrow writer.""" - def write_data(self, rows: Sequence[Any]) -> None: - for batch in rows: + def write_data(self, items: Sequence[TDataItem]) -> None: + for batch in items: # convert to object data item format super().write_data(batch.to_pylist()) # type: ignore[misc] diff --git a/dlt/common/libs/pyarrow.py b/dlt/common/libs/pyarrow.py index e9dcfaf095..14ca1fb46f 100644 --- a/dlt/common/libs/pyarrow.py +++ b/dlt/common/libs/pyarrow.py @@ -500,6 +500,30 @@ def cast_arrow_schema_types( return schema +def concat_batches_and_tables_in_order( + tables_or_batches: Iterable[Union[pyarrow.Table, pyarrow.RecordBatch]] +) -> pyarrow.Table: + """Concatenate iterable of tables and batches into a single table, preserving row order. Zero copy is used during + concatenation so schemas must be identical. + """ + batches = [] + tables = [] + for item in tables_or_batches: + if isinstance(item, pyarrow.RecordBatch): + batches.append(item) + elif isinstance(item, pyarrow.Table): + if batches: + tables.append(pyarrow.Table.from_batches(batches)) + batches = [] + tables.append(item) + else: + raise ValueError(f"Unsupported type {type(item)}") + if batches: + tables.append(pyarrow.Table.from_batches(batches)) + # "none" option ensures 0 copy concat + return pyarrow.concat_tables(tables, promote_options="none") + + class NameNormalizationCollision(ValueError): def __init__(self, reason: str) -> None: msg = f"Arrow column name collision after input data normalization. {reason}" diff --git a/docs/website/docs/dlt-ecosystem/file-formats/parquet.md b/docs/website/docs/dlt-ecosystem/file-formats/parquet.md index 39f7e70051..30f7051386 100644 --- a/docs/website/docs/dlt-ecosystem/file-formats/parquet.md +++ b/docs/website/docs/dlt-ecosystem/file-formats/parquet.md @@ -80,15 +80,16 @@ To our best knowledge, arrow will convert your timezone aware DateTime(s) to UTC ### Row group size -The `pyarrow` parquet writer writes each item, i.e. table or record batch, in a separate row group. +The `pyarrow` parquet writer writes each item, i.e. table or record batch, in a separate row group. This may lead to many small row groups which may not be optimal for certain query engines. For example, `duckdb` parallelizes on a row group. `dlt` allows controlling the size of the row group by -buffering and concatenating tables and batches before they are written. The concatenation is done as a zero-copy to save memory. 
-You can control the memory needed by setting the count of records to be buffered as follows:
+[buffering and concatenating tables](../../reference/performance.md#controlling-in-memory-buffers) and batches before they are written. The concatenation is a zero-copy operation, which saves memory.
+You can control the size of the row group by setting the maximum number of rows kept in the buffer:

```toml
[extract.data_writer]
buffer_max_items=10e6
```

Mind that `dlt` holds the tables in memory. Thus, 1,000,000 rows in the example above may consume a significant amount of RAM.
-`row_group_size` has limited utility with `pyarrow` writer. It will split large tables into many groups if set below item buffer size.
\ No newline at end of file
+The `row_group_size` configuration setting has limited utility with the `pyarrow` writer. It may be useful when you write individual, very large pyarrow tables
+or when your in-memory buffer is very large.
\ No newline at end of file
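
The refactored `_buffer_items_with_row_count` in `buffered.py` keeps the original counting rule: a list of arrow items is counted via `num_rows`, a plain list by its length, and a single object counts as one row unless it exposes `num_rows`. Below is a simplified, standalone restatement of that rule for illustration only; it is not the actual class method and the sample data is made up.

```py
# Simplified, standalone restatement of the row counting rule used by
# BufferedDataWriter._buffer_items_with_row_count (illustration only).
from typing import Any

import pyarrow as pa


def count_new_rows(item: Any) -> int:
    if isinstance(item, list):
        # arrow tables/batches report their own row counts
        if len(item) > 0 and hasattr(item[0], "num_rows"):
            return sum(tbl.num_rows for tbl in item)
        # plain python rows count one by one
        return len(item)
    # a single arrow item vs. a single python row
    return item.num_rows if hasattr(item, "num_rows") else 1


assert count_new_rows([{"id": 1}, {"id": 2}]) == 2
assert count_new_rows(pa.table({"id": [1, 2, 3]})) == 3
assert count_new_rows([pa.table({"id": [1]}), pa.table({"id": [2, 3]})]) == 3
```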
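A minimal sketch of how the `concat_batches_and_tables_in_order` helper added to `dlt/common/libs/pyarrow.py` behaves. The sample schema, column names, and values are made up for illustration; the one requirement taken from its docstring is that all inputs share an identical schema, because the final concatenation is zero-copy.

```py
# Sketch: exercising concat_batches_and_tables_in_order (added in this patch).
# The sample schema and data below are illustrative only.
import pyarrow as pa

from dlt.common.libs.pyarrow import concat_batches_and_tables_in_order

schema = pa.schema([("id", pa.int64()), ("name", pa.string())])

# a mix of record batches and tables, all with the same schema
batch_1 = pa.record_batch([pa.array([1, 2]), pa.array(["a", "b"])], schema=schema)
table_1 = pa.table({"id": [3], "name": ["c"]})
batch_2 = pa.record_batch([pa.array([4]), pa.array(["d"])], schema=schema)

combined = concat_batches_and_tables_in_order([batch_1, table_1, batch_2])

# row order of the inputs is preserved and everything lands in one table,
# so the parquet writer emits a single row group for the whole buffer
assert combined.num_rows == 4
assert combined["id"].to_pylist() == [1, 2, 3, 4]
```

This is what `ArrowWriter.write_data` now delegates to, and it is why `buffer_max_items` acts as a rudimentary row group size control.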
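To verify what row group layout you actually get after tuning `buffer_max_items`, the parquet metadata can be inspected with plain `pyarrow`. This is a sketch; the file path is a placeholder for one of the parquet files produced by your pipeline.

```py
# Sketch: inspect the row group layout of a parquet file written by the pipeline.
# The path is a placeholder; substitute a file from your load package.
import pyarrow.parquet as pq

metadata = pq.ParquetFile("table.parquet").metadata

print("row groups:", metadata.num_row_groups)
for i in range(metadata.num_row_groups):
    # with a larger buffer_max_items you should see fewer, larger row groups
    print(f"row group {i}: {metadata.row_group(i).num_rows} rows")
```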