
Commit

change arrow start stream length to be written in ArrowStreamPandasSerializer
BryanCutler committed Mar 18, 2019
1 parent 93bb831 commit bc08d1b
Showing 1 changed file with 18 additions and 16 deletions.
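
In short, this commit moves the write of the START_ARROW_STREAM special length out of ArrowStreamSerializer and into ArrowStreamPandasSerializer.dump_stream. The marker is now written only after the first record batch has been created, so an error raised while building that batch can still be reported back to the JVM before the binary Arrow stream begins. Below is a minimal, self-contained sketch of that deferred-marker pattern, not Spark's actual code; write_int and the START_ARROW_STREAM value here are stand-ins modeled on pyspark/serializers.py.

import io
import struct

START_ARROW_STREAM = -6  # stand-in for SpecialLengths.START_ARROW_STREAM

def write_int(value, stream):
    # Spark frames control values as big-endian 4-byte ints; mirrored here.
    stream.write(struct.pack("!i", value))

def with_start_marker(batches, stream):
    # Yield batches unchanged, writing the start marker before the first one.
    # Because the marker is written only after the first batch was produced
    # without error, an exception raised while creating it propagates before
    # any stream bytes are sent.
    first = True
    for batch in batches:
        if first:
            write_int(START_ARROW_STREAM, stream)
            first = False
        yield batch

if __name__ == "__main__":
    out = io.BytesIO()
    for payload in with_start_marker(iter([b"batch-0", b"batch-1"]), out):
        out.write(payload)  # a real serializer would write Arrow batches here
    print(out.getvalue())   # 4 marker bytes followed by the two payloads
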
34 changes: 18 additions & 16 deletions python/pyspark/serializers.py
@@ -222,25 +222,13 @@ class ArrowStreamSerializer(Serializer):
     """
     Serializes Arrow record batches as a stream.
     """
-    def __init__(self, send_start_stream=True):
-        self._send_start_stream = send_start_stream
-
-    def _init_dump_stream(self, stream):
-        """
-        Called just before writing an Arrow stream
-        """
-        # NOTE: this is required by Pandas UDFs to be called after creating first record batch so
-        # that any errors can be sent back to the JVM, but not interfere with the Arrow stream
-        if self._send_start_stream:
-            write_int(SpecialLengths.START_ARROW_STREAM, stream)
 
     def dump_stream(self, iterator, stream):
         import pyarrow as pa
         writer = None
         try:
             for batch in iterator:
                 if writer is None:
-                    self._init_dump_stream(stream)
                     writer = pa.RecordBatchStreamWriter(stream, batch.schema)
                 writer.write_batch(batch)
         finally:
@@ -346,10 +334,11 @@ class ArrowStreamPandasSerializer(ArrowStreamSerializer):
     """
 
     def __init__(self, timezone, safecheck, assign_cols_by_name, send_start_stream=True):
-        super(ArrowStreamPandasSerializer, self).__init__(send_start_stream)
+        super(ArrowStreamPandasSerializer, self).__init__()
         self._timezone = timezone
         self._safecheck = safecheck
         self._assign_cols_by_name = assign_cols_by_name
+        self._send_start_stream = send_start_stream
 
     def arrow_to_pandas(self, arrow_column):
         from pyspark.sql.types import from_arrow_type, \
@@ -366,15 +355,28 @@ def dump_stream(self, iterator, stream):
         """
         batches = (_create_batch(series, self._timezone, self._safecheck, self._assign_cols_by_name)
                    for series in iterator)
-        super(ArrowStreamPandasSerializer, self).dump_stream(batches, stream)
+
+        def init_stream_yield_batches():
+            # NOTE: START_ARROW_STREAM is required by Pandas UDFs, called after creating the first
+            # record batch so any errors can be sent back to the JVM before the Arrow stream starts
+            should_write_start_length = True
+            for batch in batches:
+                if should_write_start_length:
+                    write_int(SpecialLengths.START_ARROW_STREAM, stream)
+                    should_write_start_length = False
+                yield batch
+
+        iterator = init_stream_yield_batches() if self._send_start_stream else batches
+
+        super(ArrowStreamPandasSerializer, self).dump_stream(iterator, stream)
 
     def load_stream(self, stream):
         """
         Deserialize ArrowRecordBatches to an Arrow table and return as a list of pandas.Series.
         """
-        batch_iter = super(ArrowStreamPandasSerializer, self).load_stream(stream)
+        batches = super(ArrowStreamPandasSerializer, self).load_stream(stream)
         import pyarrow as pa
-        for batch in batch_iter:
+        for batch in batches:
             yield [self.arrow_to_pandas(c) for c in pa.Table.from_batches([batch]).itercolumns()]
 
     def __repr__(self):
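For context, here is a sketch of how a consumer of such a framed stream might distinguish the start marker from an error report. It is an illustration only, not Spark's JVM-side reader; PYTHON_EXCEPTION_THROWN and read_int are assumptions modeled on the SpecialLengths values in pyspark/serializers.py.

import struct

START_ARROW_STREAM = -6       # assumed special length announcing the stream
PYTHON_EXCEPTION_THROWN = -2  # assumed special length announcing an error

def read_int(stream):
    return struct.unpack("!i", stream.read(4))[0]

def open_arrow_stream(stream):
    # Return the stream positioned at the Arrow data, or raise the error the
    # worker reported before the stream started.
    flag = read_int(stream)
    if flag == START_ARROW_STREAM:
        return stream
    if flag == PYTHON_EXCEPTION_THROWN:
        length = read_int(stream)
        raise RuntimeError(stream.read(length).decode("utf-8"))
    raise ValueError("unexpected framing flag: %d" % flag)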
