Fix that field nullability affects write #24

Merged: 4 commits, Nov 12, 2024
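Based on the diff and the new tests, this PR does two things: the Python writer now serializes every RecordBatch against the batch's own schema instead of the table's Arrow schema, and the Java-side BytesWriter gains a check that compares the incoming Arrow fields to the table fields while ignoring nullability. The sketch below is a minimal example of the behaviour this enables, mirroring the new testIgnoreNullable test; the `catalog` object and warehouse setup are assumed to be the same as in the test harness and are not part of this PR.

```python
import pandas as pd
import pyarrow as pa

from paimon_python_api import Schema

# `catalog` is assumed to be created as in the tests' setUp
# (a filesystem-warehouse Catalog from paimon_python_java).
table_schema = pa.schema([('f0', pa.int32(), False),   # f0 declared NOT NULL
                          ('f1', pa.string())])
catalog.create_table('default.demo', Schema(table_schema), False)
table = catalog.get_table('default.demo')

# The batch built from pandas has nullable fields; with this fix the write
# succeeds, because nullability is ignored when the schemas are compared.
data_schema = pa.schema([('f0', pa.int32()), ('f1', pa.string())])
record_batch = pa.RecordBatch.from_pandas(
    pd.DataFrame({'f0': [1, 2, 3], 'f1': ['a', 'b', 'c']}), data_schema)

write_builder = table.new_batch_write_builder()
table_write = write_builder.new_write()
table_commit = write_builder.new_commit()
table_write.write_arrow_batch(record_batch)
table_commit.commit(table_write.prepare_commit())
table_write.close()
table_commit.close()
```

The Java side of the check lives in BytesWriter.java, whose diff follows.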
@@ -18,6 +18,7 @@
 
 package org.apache.paimon.python;
 
+import org.apache.paimon.arrow.ArrowUtils;
 import org.apache.paimon.arrow.reader.ArrowBatchReader;
 import org.apache.paimon.data.InternalRow;
 import org.apache.paimon.table.sink.TableWrite;
@@ -27,26 +28,42 @@
 import org.apache.arrow.memory.RootAllocator;
 import org.apache.arrow.vector.VectorSchemaRoot;
 import org.apache.arrow.vector.ipc.ArrowStreamReader;
+import org.apache.arrow.vector.types.pojo.Field;
 
 import java.io.ByteArrayInputStream;
+import java.util.List;
+import java.util.stream.Collectors;
 
 /** Write Arrow bytes to Paimon. */
 public class BytesWriter {
 
     private final TableWrite tableWrite;
     private final ArrowBatchReader arrowBatchReader;
     private final BufferAllocator allocator;
+    private final List<Field> arrowFields;
 
     public BytesWriter(TableWrite tableWrite, RowType rowType) {
         this.tableWrite = tableWrite;
         this.arrowBatchReader = new ArrowBatchReader(rowType);
         this.allocator = new RootAllocator();
+        arrowFields =
+                rowType.getFields().stream()
+                        .map(f -> ArrowUtils.toArrowField(f.name(), f.type()))
+                        .collect(Collectors.toList());
     }
 
     public void write(byte[] bytes) throws Exception {
         ByteArrayInputStream bais = new ByteArrayInputStream(bytes);
         ArrowStreamReader arrowStreamReader = new ArrowStreamReader(bais, allocator);
         VectorSchemaRoot vsr = arrowStreamReader.getVectorSchemaRoot();
+        if (!checkTypesIgnoreNullability(arrowFields, vsr.getSchema().getFields())) {
+            throw new RuntimeException(
+                    String.format(
+                            "Input schema isn't consistent with table schema.\n"
+                                    + "\tTable schema is: %s\n"
+                                    + "\tInput schema is: %s",
+                            arrowFields, vsr.getSchema().getFields()));
+        }
 
         while (arrowStreamReader.loadNextBatch()) {
             Iterable<InternalRow> rows = arrowBatchReader.readBatch(vsr);
@@ -60,4 +77,24 @@ public void write(byte[] bytes) throws Exception {
     public void close() {
         allocator.close();
     }
+
+    private boolean checkTypesIgnoreNullability(
+            List<Field> expectedFields, List<Field> actualFields) {
+        if (expectedFields.size() != actualFields.size()) {
+            return false;
+        }
+
+        for (int i = 0; i < expectedFields.size(); i++) {
+            Field expectedField = expectedFields.get(i);
+            Field actualField = actualFields.get(i);
+            // ArrowType doesn't have nullability (similar to DataTypeRoot)
+            if (!actualField.getType().equals(expectedField.getType())
+                    || !checkTypesIgnoreNullability(
+                            expectedField.getChildren(), actualField.getChildren())) {
+                return false;
+            }
+        }
+
+        return true;
+    }
 }
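For readers on the Python side, a rough pyarrow analogue of the new checkTypesIgnoreNullability helper is sketched below. It is hypothetical (not part of this PR) and only mirrors the idea: compare field types, recursing into nested children, while never consulting the nullable flag. The Java version compares ArrowType objects directly, so this is an approximation rather than a byte-for-byte port.

```python
import pyarrow as pa
from typing import List


def check_types_ignore_nullability(expected: List[pa.Field],
                                   actual: List[pa.Field]) -> bool:
    # Field-by-field comparison; the fields' nullable flags are never consulted.
    if len(expected) != len(actual):
        return False
    return all(_same_type(e.type, a.type) for e, a in zip(expected, actual))


def _same_type(expected: pa.DataType, actual: pa.DataType) -> bool:
    # Nested types need recursion, because full pyarrow type equality would also
    # compare the nullability of the child fields.
    if pa.types.is_struct(expected) and pa.types.is_struct(actual):
        if expected.num_fields != actual.num_fields:
            return False
        return check_types_ignore_nullability(
            [expected.field(i) for i in range(expected.num_fields)],
            [actual.field(i) for i in range(actual.num_fields)])
    if pa.types.is_list(expected) and pa.types.is_list(actual):
        return _same_type(expected.value_type, actual.value_type)
    if pa.types.is_map(expected) and pa.types.is_map(actual):
        return (_same_type(expected.key_type, actual.key_type)
                and _same_type(expected.item_type, actual.item_type))
    # Primitive types carry no nullability of their own, so plain equality is enough.
    return expected == actual
```

With the two schemas from testIgnoreNullable below, this returns True even though f0 differs only in its nullable flag, while the int32/int64 mismatch from testWriteWrongSchema still returns False.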
21 changes: 10 additions & 11 deletions paimon_python_java/pypaimon.py
@@ -218,24 +218,23 @@ def __init__(self, j_batch_table_write, j_row_type, arrow_schema: pa.Schema):
 
     def write_arrow(self, table):
         for record_batch in table.to_reader():
-            # TODO: can we use a reusable stream?
-            stream = pa.BufferOutputStream()
-            with pa.RecordBatchStreamWriter(stream, self._arrow_schema) as writer:
-                writer.write(record_batch)
-            arrow_bytes = stream.getvalue().to_pybytes()
-            self._j_bytes_writer.write(arrow_bytes)
+            # TODO: can we use a reusable stream in #_write_arrow_batch ?
+            self._write_arrow_batch(record_batch)
 
     def write_arrow_batch(self, record_batch):
+        self._write_arrow_batch(record_batch)
+
+    def write_pandas(self, dataframe: pd.DataFrame):
+        record_batch = pa.RecordBatch.from_pandas(dataframe, schema=self._arrow_schema)
+        self._write_arrow_batch(record_batch)
+
+    def _write_arrow_batch(self, record_batch):
         stream = pa.BufferOutputStream()
-        with pa.RecordBatchStreamWriter(stream, self._arrow_schema) as writer:
+        with pa.RecordBatchStreamWriter(stream, record_batch.schema) as writer:
             writer.write(record_batch)
         arrow_bytes = stream.getvalue().to_pybytes()
         self._j_bytes_writer.write(arrow_bytes)
 
-    def write_pandas(self, dataframe: pd.DataFrame):
-        record_batch = pa.RecordBatch.from_pandas(dataframe, schema=self._arrow_schema)
-        self.write_arrow_batch(record_batch)
-
     def prepare_commit(self) -> List['CommitMessage']:
         j_commit_messages = self._j_batch_table_write.prepareCommit()
         return list(map(lambda cm: CommitMessage(cm), j_commit_messages))
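All three public write paths (write_arrow, write_arrow_batch, write_pandas) now funnel through the private _write_arrow_batch, which serializes one batch as an Arrow IPC stream written against the batch's own schema rather than self._arrow_schema, and hands the raw bytes to the Java BytesWriter. Keying the stream to the batch's own schema avoids the kind of mismatch the PR title describes, where a batch whose nullability flags differed from the table's Arrow schema would trip up the write; consistency with the table schema is now enforced on the Java side, ignoring nullability. A standalone sketch of that serialization round trip, with pyarrow's reader standing in for the Java ArrowStreamReader:

```python
import pyarrow as pa

batch = pa.RecordBatch.from_pydict({'f0': [1, 2, 3], 'f1': ['a', 'b', 'c']})

# Serialize exactly as _write_arrow_batch does: one IPC stream per batch,
# written against the batch's own schema.
stream = pa.BufferOutputStream()
with pa.RecordBatchStreamWriter(stream, batch.schema) as writer:
    writer.write(batch)
arrow_bytes = stream.getvalue().to_pybytes()

# The Java side reads these bytes with an ArrowStreamReader over a
# ByteArrayInputStream; pyarrow's own stream reader shows what it receives.
reader = pa.ipc.open_stream(arrow_bytes)
assert reader.schema.equals(batch.schema)
for read_back in reader:
    assert read_back.equals(batch)
```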
74 changes: 74 additions & 0 deletions paimon_python_java/tests/test_write_and_read.py
@@ -22,6 +22,7 @@
 import unittest
 import pandas as pd
 import pyarrow as pa
+from py4j.protocol import Py4JJavaError
 
 from paimon_python_api import Schema
 from paimon_python_java import Catalog
@@ -371,3 +372,76 @@ def test_overwrite(self):
         df2['f0'] = df2['f0'].astype('int32')
         pd.testing.assert_frame_equal(
             actual_df2.reset_index(drop=True), df2.reset_index(drop=True))
+
+    def testWriteWrongSchema(self):
+        schema = Schema(self.simple_pa_schema)
+        self.catalog.create_table('default.test_wrong_schema', schema, False)
+        table = self.catalog.get_table('default.test_wrong_schema')
+
+        data = {
+            'f0': [1, 2, 3],
+            'f1': ['a', 'b', 'c'],
+        }
+        df = pd.DataFrame(data)
+        schema = pa.schema([
+            ('f0', pa.int64()),
+            ('f1', pa.string())
+        ])
+        record_batch = pa.RecordBatch.from_pandas(df, schema)
+
+        write_builder = table.new_batch_write_builder()
+        table_write = write_builder.new_write()
+
+        with self.assertRaises(Py4JJavaError) as e:
+            table_write.write_arrow_batch(record_batch)
+        self.assertEqual(
+            str(e.exception.java_exception),
+            '''java.lang.RuntimeException: Input schema isn't consistent with table schema.
+\tTable schema is: [f0: Int(32, true), f1: Utf8]
+\tInput schema is: [f0: Int(64, true), f1: Utf8]''')
+
+    def testIgnoreNullable(self):
+        pa_schema1 = pa.schema([
+            ('f0', pa.int32(), False),
+            ('f1', pa.string())
+        ])
+
+        pa_schema2 = pa.schema([
+            ('f0', pa.int32()),
+            ('f1', pa.string())
+        ])
+
+        # write nullable to non-null
+        self._testIgnoreNullableImpl('test_ignore_nullable1', pa_schema1, pa_schema2)
+
+        # write non-null to nullable
+        self._testIgnoreNullableImpl('test_ignore_nullable2', pa_schema2, pa_schema1)
+
+    def _testIgnoreNullableImpl(self, table_name, table_schema, data_schema):
+        schema = Schema(table_schema)
+        self.catalog.create_table(f'default.{table_name}', schema, False)
+        table = self.catalog.get_table(f'default.{table_name}')
+
+        data = {
+            'f0': [1, 2, 3],
+            'f1': ['a', 'b', 'c'],
+        }
+        df = pd.DataFrame(data)
+        record_batch = pa.RecordBatch.from_pandas(pd.DataFrame(data), data_schema)
+
+        write_builder = table.new_batch_write_builder()
+        table_write = write_builder.new_write()
+        table_commit = write_builder.new_commit()
+        table_write.write_arrow_batch(record_batch)
+        table_commit.commit(table_write.prepare_commit())
+
+        table_write.close()
+        table_commit.close()
+
+        read_builder = table.new_read_builder()
+        table_scan = read_builder.new_scan()
+        table_read = read_builder.new_read()
+        actual_df = table_read.to_pandas(table_scan.plan().splits())
+        df['f0'] = df['f0'].astype('int32')
+        pd.testing.assert_frame_equal(
+            actual_df.reset_index(drop=True), df.reset_index(drop=True))