diff --git a/python/tests/test_writer.py b/python/tests/test_writer.py index 5ca2a257ec..137b35e8f3 100644 --- a/python/tests/test_writer.py +++ b/python/tests/test_writer.py @@ -304,12 +304,19 @@ def test_write_iterator( assert DeltaTable(tmp_path).to_pyarrow_table() == sample_data +@pytest.parametrize("large_dtypes", [True, False]) +@pytest.parametrize("constructor", [ + lambda table: table.to_pyarrow_dataset(), + lambda table: table.to_pyarrow_table(), + lambda table: table.to_pyarrow_table().to_batches()[0] +]) def test_write_dataset( - tmp_path: pathlib.Path, existing_table: DeltaTable, sample_data: pa.Table + tmp_path: pathlib.Path, existing_table: DeltaTable, sample_data: pa.Table, + large_dtypes: bool, constructor ): - dataset = existing_table.to_pyarrow_dataset() + dataset = constructor(existing_table) - write_deltalake(tmp_path, dataset, mode="overwrite") + write_deltalake(tmp_path, dataset, mode="overwrite", large_dtypes=large_dtypes) assert DeltaTable(tmp_path).to_pyarrow_table() == sample_data