From caa72a6ef2d9855b36650139d40ce2aeef2c88d9 Mon Sep 17 00:00:00 2001
From: yuzelin
Date: Mon, 11 Nov 2024 19:35:00 +0800
Subject: [PATCH 1/4] Fix that field nullability affects write

---
 .../org/apache/paimon/python/BytesWriter.java | 42 ++++++++++-
 paimon_python_java/pypaimon.py                | 23 +++---
 .../tests/test_write_and_read.py              | 74 +++++++++++++++++++
 3 files changed, 126 insertions(+), 13 deletions(-)

diff --git a/paimon_python_java/paimon-python-java-bridge/src/main/java/org/apache/paimon/python/BytesWriter.java b/paimon_python_java/paimon-python-java-bridge/src/main/java/org/apache/paimon/python/BytesWriter.java
index 7cf6267..e107bd3 100644
--- a/paimon_python_java/paimon-python-java-bridge/src/main/java/org/apache/paimon/python/BytesWriter.java
+++ b/paimon_python_java/paimon-python-java-bridge/src/main/java/org/apache/paimon/python/BytesWriter.java
@@ -18,6 +18,7 @@

 package org.apache.paimon.python;

+import org.apache.paimon.arrow.ArrowUtils;
 import org.apache.paimon.arrow.reader.ArrowBatchReader;
 import org.apache.paimon.data.InternalRow;
 import org.apache.paimon.table.sink.TableWrite;
@@ -27,8 +28,12 @@
 import org.apache.arrow.memory.RootAllocator;
 import org.apache.arrow.vector.VectorSchemaRoot;
 import org.apache.arrow.vector.ipc.ArrowStreamReader;
+import org.apache.arrow.vector.types.pojo.Field;

 import java.io.ByteArrayInputStream;
+import java.util.List;
+import java.util.Objects;
+import java.util.stream.Collectors;

 /** Write Arrow bytes to Paimon. */
 public class BytesWriter {
@@ -36,17 +41,30 @@ public class BytesWriter {
     private final TableWrite tableWrite;
     private final ArrowBatchReader arrowBatchReader;
     private final BufferAllocator allocator;
+    private final List<Field> arrowFields;

     public BytesWriter(TableWrite tableWrite, RowType rowType) {
         this.tableWrite = tableWrite;
         this.arrowBatchReader = new ArrowBatchReader(rowType);
         this.allocator = new RootAllocator();
+        arrowFields =
+                rowType.getFields().stream()
+                        .map(f -> ArrowUtils.toArrowField(f.name(), f.type()))
+                        .collect(Collectors.toList());
     }

-    public void write(byte[] bytes) throws Exception {
+    public void write(byte[] bytes, boolean needCheckSchema) throws Exception {
         ByteArrayInputStream bais = new ByteArrayInputStream(bytes);
         ArrowStreamReader arrowStreamReader = new ArrowStreamReader(bais, allocator);
         VectorSchemaRoot vsr = arrowStreamReader.getVectorSchemaRoot();
+        if (needCheckSchema && !checkSchema(arrowFields, vsr.getSchema().getFields())) {
+            throw new RuntimeException(
+                    String.format(
+                            "Input schema isn't consistent with table schema.\n"
+                                    + "\tTable schema is: %s\n"
+                                    + "\tInput schema is: %s",
+                            arrowFields, vsr.getSchema().getFields()));
+        }

         while (arrowStreamReader.loadNextBatch()) {
             Iterable<InternalRow> rows = arrowBatchReader.readBatch(vsr);
@@ -60,4 +78,26 @@ public void write(byte[] bytes) throws Exception {
     public void close() {
         allocator.close();
     }
+
+    private boolean checkSchema(List<Field> expectedFields, List<Field> actualFields) {
+        if (expectedFields.size() != actualFields.size()) {
+            return false;
+        }
+
+        for (int i = 0; i < expectedFields.size(); i++) {
+            Field expectedField = expectedFields.get(i);
+            Field actualField = actualFields.get(i);
+            if (!checkField(expectedField, actualField)
+                    || !checkSchema(expectedField.getChildren(), actualField.getChildren())) {
+                return false;
+            }
+        }
+
+        return true;
+    }
+
+    private boolean checkField(Field expected, Field actual) {
+        return Objects.equals(expected.getName(), actual.getName())
+                && Objects.equals(expected.getType(), actual.getType());
+    }
 }
diff --git a/paimon_python_java/pypaimon.py b/paimon_python_java/pypaimon.py
index 0d3101b..50a889d 100644
--- a/paimon_python_java/pypaimon.py
+++ b/paimon_python_java/pypaimon.py
@@ -218,23 +218,22 @@ def __init__(self, j_batch_table_write, j_row_type, arrow_schema: pa.Schema):

     def write_arrow(self, table):
         for record_batch in table.to_reader():
-            # TODO: can we use a reusable stream?
-            stream = pa.BufferOutputStream()
-            with pa.RecordBatchStreamWriter(stream, self._arrow_schema) as writer:
-                writer.write(record_batch)
-            arrow_bytes = stream.getvalue().to_pybytes()
-            self._j_bytes_writer.write(arrow_bytes)
+            # TODO: can we use a reusable stream in #_write_arrow_batch ?
+            self._write_arrow_batch(record_batch, True)

     def write_arrow_batch(self, record_batch):
-        stream = pa.BufferOutputStream()
-        with pa.RecordBatchStreamWriter(stream, self._arrow_schema) as writer:
-            writer.write(record_batch)
-        arrow_bytes = stream.getvalue().to_pybytes()
-        self._j_bytes_writer.write(arrow_bytes)
+        self._write_arrow_batch(record_batch, True)

     def write_pandas(self, dataframe: pd.DataFrame):
         record_batch = pa.RecordBatch.from_pandas(dataframe, schema=self._arrow_schema)
-        self.write_arrow_batch(record_batch)
+        self._write_arrow_batch(record_batch, False)
+
+    def _write_arrow_batch(self, record_batch, check_schema):
+        stream = pa.BufferOutputStream()
+        with pa.RecordBatchStreamWriter(stream, record_batch.schema) as writer:
+            writer.write(record_batch)
+        arrow_bytes = stream.getvalue().to_pybytes()
+        self._j_bytes_writer.write(arrow_bytes, check_schema)

     def prepare_commit(self) -> List['CommitMessage']:
         j_commit_messages = self._j_batch_table_write.prepareCommit()
diff --git a/paimon_python_java/tests/test_write_and_read.py b/paimon_python_java/tests/test_write_and_read.py
index e1ea72b..b468e9f 100644
--- a/paimon_python_java/tests/test_write_and_read.py
+++ b/paimon_python_java/tests/test_write_and_read.py
@@ -22,6 +22,7 @@
 import unittest
 import pandas as pd
 import pyarrow as pa
+from py4j.protocol import Py4JJavaError

 from paimon_python_api import Schema
 from paimon_python_java import Catalog
@@ -371,3 +372,76 @@ def test_overwrite(self):
         df2['f0'] = df2['f0'].astype('int32')
         pd.testing.assert_frame_equal(
             actual_df2.reset_index(drop=True), df2.reset_index(drop=True))
+
+    def testWriteWrongSchema(self):
+        schema = Schema(self.simple_pa_schema)
+        self.catalog.create_table('default.test_wrong_schema', schema, False)
+        table = self.catalog.get_table('default.test_wrong_schema')
+
+        data = {
+            'f0': [1, 2, 3],
+            'f1': ['a', 'b', 'c'],
+        }
+        df = pd.DataFrame(data)
+        schema = pa.schema([
+            ('f0', pa.int64()),
+            ('f1', pa.string())
+        ])
+        record_batch = pa.RecordBatch.from_pandas(df, schema)
+
+        write_builder = table.new_batch_write_builder()
+        table_write = write_builder.new_write()
+
+        with self.assertRaises(Py4JJavaError) as e:
+            table_write.write_arrow_batch(record_batch)
+        self.assertEqual(
+            str(e.exception.java_exception),
+            '''java.lang.RuntimeException: Input schema isn't consistent with table schema.
+\tTable schema is: [f0: Int(32, true), f1: Utf8]
+\tInput schema is: [f0: Int(64, true), f1: Utf8]''')
+
+    def testIgnoreNullable(self):
+        pa_schema1 = pa.schema([
+            ('f0', pa.int32(), False),
+            ('f1', pa.string())
+        ])
+
+        pa_schema2 = pa.schema([
+            ('f0', pa.int32()),
+            ('f1', pa.string())
+        ])
+
+        # write nullable to non-null
+        self._testIgnoreNullableImpl('test_ignore_nullable1', pa_schema1, pa_schema2)
+
+        # write non-null to nullable
+        self._testIgnoreNullableImpl('test_ignore_nullable2', pa_schema2, pa_schema1)
+
+    def _testIgnoreNullableImpl(self, table_name, table_schema, data_schema):
+        schema = Schema(table_schema)
+        self.catalog.create_table(f'default.{table_name}', schema, False)
+        table = self.catalog.get_table(f'default.{table_name}')
+
+        data = {
+            'f0': [1, 2, 3],
+            'f1': ['a', 'b', 'c'],
+        }
+        df = pd.DataFrame(data)
+        record_batch = pa.RecordBatch.from_pandas(pd.DataFrame(data), data_schema)
+
+        write_builder = table.new_batch_write_builder()
+        table_write = write_builder.new_write()
+        table_commit = write_builder.new_commit()
+        table_write.write_arrow_batch(record_batch)
+        table_commit.commit(table_write.prepare_commit())
+
+        table_write.close()
+        table_commit.close()
+
+        read_builder = table.new_read_builder()
+        table_scan = read_builder.new_scan()
+        table_read = read_builder.new_read()
+        actual_df = table_read.to_pandas(table_scan.plan().splits())
+        df['f0'] = df['f0'].astype('int32')
+        pd.testing.assert_frame_equal(
+            actual_df.reset_index(drop=True), df.reset_index(drop=True))

From 47931ad56969fc90b0a8ac87666cf90339099c5c Mon Sep 17 00:00:00 2001
From: yuzelin
Date: Tue, 12 Nov 2024 09:44:30 +0800
Subject: [PATCH 2/4] fix

---
 .../src/main/java/org/apache/paimon/python/BytesWriter.java | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/paimon_python_java/paimon-python-java-bridge/src/main/java/org/apache/paimon/python/BytesWriter.java b/paimon_python_java/paimon-python-java-bridge/src/main/java/org/apache/paimon/python/BytesWriter.java
index e107bd3..64aea4e 100644
--- a/paimon_python_java/paimon-python-java-bridge/src/main/java/org/apache/paimon/python/BytesWriter.java
+++ b/paimon_python_java/paimon-python-java-bridge/src/main/java/org/apache/paimon/python/BytesWriter.java
@@ -87,7 +87,7 @@ private boolean checkSchema(List<Field> expectedFields, List<Field> actualFields
         for (int i = 0; i < expectedFields.size(); i++) {
             Field expectedField = expectedFields.get(i);
             Field actualField = actualFields.get(i);
-            if (!checkField(expectedField, actualField)
+            if (!checkFieldIgnoreNullability(expectedField, actualField)
                     || !checkSchema(expectedField.getChildren(), actualField.getChildren())) {
                 return false;
             }
@@ -96,7 +96,7 @@ private boolean checkSchema(List<Field> expectedFields, List<Field> actualFields
         return true;
     }

-    private boolean checkField(Field expected, Field actual) {
+    private boolean checkFieldIgnoreNullability(Field expected, Field actual) {
         return Objects.equals(expected.getName(), actual.getName())
                 && Objects.equals(expected.getType(), actual.getType());
     }

From b786c73b4c42587d6c4835b525e50cc1c01859fa Mon Sep 17 00:00:00 2001
From: yuzelin
Date: Tue, 12 Nov 2024 10:20:17 +0800
Subject: [PATCH 3/4] fix

---
 .../org/apache/paimon/python/BytesWriter.java | 16 ++++++----------
 paimon_python_java/pypaimon.py                | 10 +++++-----
 2 files changed, 11 insertions(+), 15 deletions(-)

diff --git a/paimon_python_java/paimon-python-java-bridge/src/main/java/org/apache/paimon/python/BytesWriter.java b/paimon_python_java/paimon-python-java-bridge/src/main/java/org/apache/paimon/python/BytesWriter.java
index 64aea4e..2c3b8c7 100644
--- a/paimon_python_java/paimon-python-java-bridge/src/main/java/org/apache/paimon/python/BytesWriter.java
+++ b/paimon_python_java/paimon-python-java-bridge/src/main/java/org/apache/paimon/python/BytesWriter.java
@@ -32,7 +32,6 @@

 import java.io.ByteArrayInputStream;
 import java.util.List;
-import java.util.Objects;
 import java.util.stream.Collectors;

 /** Write Arrow bytes to Paimon. */
@@ -53,11 +52,11 @@ public BytesWriter(TableWrite tableWrite, RowType rowType) {
                         .collect(Collectors.toList());
     }

-    public void write(byte[] bytes, boolean needCheckSchema) throws Exception {
+    public void write(byte[] bytes) throws Exception {
         ByteArrayInputStream bais = new ByteArrayInputStream(bytes);
         ArrowStreamReader arrowStreamReader = new ArrowStreamReader(bais, allocator);
         VectorSchemaRoot vsr = arrowStreamReader.getVectorSchemaRoot();
-        if (needCheckSchema && !checkSchema(arrowFields, vsr.getSchema().getFields())) {
+        if (!checkTypesIgnoreNullability(arrowFields, vsr.getSchema().getFields())) {
             throw new RuntimeException(
                     String.format(
                             "Input schema isn't consistent with table schema.\n"
@@ -79,7 +78,8 @@ public void close() {
         allocator.close();
     }

-    private boolean checkSchema(List<Field> expectedFields, List<Field> actualFields) {
+    private boolean checkTypesIgnoreNullability(
+            List<Field> expectedFields, List<Field> actualFields) {
         if (expectedFields.size() != actualFields.size()) {
             return false;
         }
@@ -87,7 +87,8 @@ private boolean checkSchema(List<Field> expectedFields, List<Field> actualFields
         for (int i = 0; i < expectedFields.size(); i++) {
             Field expectedField = expectedFields.get(i);
             Field actualField = actualFields.get(i);
-            if (!checkFieldIgnoreNullability(expectedField, actualField)
+            // ArrowType doesn't have nullability (similar to DataTypeRoot)
+            if (!actualField.getType().equals(expectedField.getType())
                     || !checkSchema(expectedField.getChildren(), actualField.getChildren())) {
                 return false;
             }
@@ -95,9 +96,4 @@ private boolean checkSchema(List<Field> expectedFields, List<Field> actualFields
         return true;
     }
-
-    private boolean checkFieldIgnoreNullability(Field expected, Field actual) {
-        return Objects.equals(expected.getName(), actual.getName())
-                && Objects.equals(expected.getType(), actual.getType());
-    }
 }
diff --git a/paimon_python_java/pypaimon.py b/paimon_python_java/pypaimon.py
index 50a889d..16c7a69 100644
--- a/paimon_python_java/pypaimon.py
+++ b/paimon_python_java/pypaimon.py
@@ -219,21 +219,21 @@ def __init__(self, j_batch_table_write, j_row_type, arrow_schema: pa.Schema):
     def write_arrow(self, table):
         for record_batch in table.to_reader():
             # TODO: can we use a reusable stream in #_write_arrow_batch ?
-            self._write_arrow_batch(record_batch, True)
+            self._write_arrow_batch(record_batch)

     def write_arrow_batch(self, record_batch):
-        self._write_arrow_batch(record_batch, True)
+        self._write_arrow_batch(record_batch)

     def write_pandas(self, dataframe: pd.DataFrame):
         record_batch = pa.RecordBatch.from_pandas(dataframe, schema=self._arrow_schema)
-        self._write_arrow_batch(record_batch, False)
+        self._write_arrow_batch(record_batch)

-    def _write_arrow_batch(self, record_batch, check_schema):
+    def _write_arrow_batch(self, record_batch):
         stream = pa.BufferOutputStream()
         with pa.RecordBatchStreamWriter(stream, record_batch.schema) as writer:
             writer.write(record_batch)
         arrow_bytes = stream.getvalue().to_pybytes()
-        self._j_bytes_writer.write(arrow_bytes, check_schema)
+        self._j_bytes_writer.write(arrow_bytes)

     def prepare_commit(self) -> List['CommitMessage']:
         j_commit_messages = self._j_batch_table_write.prepareCommit()

From 34f72092de9e2e308fa40edb75226d68b6d149df Mon Sep 17 00:00:00 2001
From: yuzelin
Date: Tue, 12 Nov 2024 10:45:30 +0800
Subject: [PATCH 4/4] fix

---
 .../src/main/java/org/apache/paimon/python/BytesWriter.java | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/paimon_python_java/paimon-python-java-bridge/src/main/java/org/apache/paimon/python/BytesWriter.java b/paimon_python_java/paimon-python-java-bridge/src/main/java/org/apache/paimon/python/BytesWriter.java
index 2c3b8c7..f2ca4e1 100644
--- a/paimon_python_java/paimon-python-java-bridge/src/main/java/org/apache/paimon/python/BytesWriter.java
+++ b/paimon_python_java/paimon-python-java-bridge/src/main/java/org/apache/paimon/python/BytesWriter.java
@@ -89,7 +89,8 @@ private boolean checkTypesIgnoreNullability(
             Field actualField = actualFields.get(i);
             // ArrowType doesn't have nullability (similar to DataTypeRoot)
             if (!actualField.getType().equals(expectedField.getType())
-                    || !checkSchema(expectedField.getChildren(), actualField.getChildren())) {
+                    || !checkTypesIgnoreNullability(
+                            expectedField.getChildren(), actualField.getChildren())) {
                 return false;
             }
         }
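Illustrative usage (a minimal sketch, not part of the patch series): it shows the behavior the new tests exercise, written against the pypaimon API used in test_write_and_read.py. It assumes a `catalog` object set up the same way as in the tests' fixtures; the table name 'default.nullability_demo' and the function name `demo` are hypothetical. After this fix, a write whose Arrow schema differs from the table schema only in field nullability succeeds, while a genuine type mismatch (for example int64 data into an int32 column) still raises Py4JJavaError from the Java bridge.

    import pandas as pd
    import pyarrow as pa
    from paimon_python_api import Schema

    def demo(catalog):
        # Hypothetical table: f0 is declared as a non-nullable int32 column.
        table_schema = pa.schema([
            pa.field('f0', pa.int32(), nullable=False),
            pa.field('f1', pa.string())
        ])
        catalog.create_table('default.nullability_demo', Schema(table_schema), False)
        table = catalog.get_table('default.nullability_demo')

        write_builder = table.new_batch_write_builder()
        table_write = write_builder.new_write()
        table_commit = write_builder.new_commit()

        # The batch is converted with a nullable int32 schema; the nullability
        # difference is now ignored, so this write no longer fails.
        df = pd.DataFrame({'f0': [1, 2, 3], 'f1': ['a', 'b', 'c']})
        data_schema = pa.schema([('f0', pa.int32()), ('f1', pa.string())])
        table_write.write_arrow_batch(pa.RecordBatch.from_pandas(df, data_schema))

        table_commit.commit(table_write.prepare_commit())
        table_write.close()
        table_commit.close()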