Skip to content

Commit

Permalink
Adjust test assertions to pyarrow values
Browse files Browse the repository at this point in the history
  • Loading branch information
plamut committed Sep 3, 2020
1 parent b4deb53 commit 8df5a86
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 4 deletions.
9 changes: 8 additions & 1 deletion tests/system/v1/test_reader_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -368,7 +368,14 @@ def test_decoding_data_types(

stream = session.streams[0].name

rows = list(client.read_rows(stream).rows(session))
if data_format == bigquery_storage_v1.enums.DataFormat.AVRO:
rows = list(client.read_rows(stream).rows(session))
else:
assert data_format == bigquery_storage_v1.enums.DataFormat.ARROW
rows = list(
dict((key, value.as_py()) for key, value in row_dict.items())
for row_dict in client.read_rows(stream).rows(session)
)

expected_result = {
u"string_field": u"Price: € 9.95.",
Expand Down
9 changes: 8 additions & 1 deletion tests/system/v1beta1/test_reader_v1beta1.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,7 +382,14 @@ def test_decoding_data_types(
stream=session.streams[0]
)

rows = list(client.read_rows(stream_pos).rows(session))
if data_format == bigquery_storage_v1beta1.enums.DataFormat.AVRO:
rows = list(client.read_rows(stream_pos).rows(session))
else:
assert data_format == bigquery_storage_v1beta1.enums.DataFormat.ARROW
rows = list(
dict((key, value.as_py()) for key, value in row_dict.items())
for row_dict in client.read_rows(stream_pos).rows(session)
)

expected_result = {
u"string_field": u"Price: € 9.95.",
Expand Down
5 changes: 4 additions & 1 deletion tests/unit/test_reader_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,7 +326,10 @@ def test_rows_w_scalars_arrow(class_under_test, mock_client):
arrow_batches = _bq_to_arrow_batches(SCALAR_BLOCKS, arrow_schema)

reader = class_under_test(arrow_batches, mock_client, "", 0, {})
got = tuple(reader.rows(read_session))
got = tuple(
dict((key, value.as_py()) for key, value in row_dict.items())
for row_dict in reader.rows(read_session)
)

expected = tuple(itertools.chain.from_iterable(SCALAR_BLOCKS))
assert got == expected
Expand Down
5 changes: 4 additions & 1 deletion tests/unit/test_reader_v1beta1.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,7 +343,10 @@ def test_rows_w_scalars_arrow(class_under_test, mock_client):
reader = class_under_test(
arrow_batches, mock_client, bigquery_storage_v1beta1.types.StreamPosition(), {}
)
got = tuple(reader.rows(read_session))
got = tuple(
dict((key, value.as_py()) for key, value in row_dict.items())
for row_dict in reader.rows(read_session)
)

expected = tuple(itertools.chain.from_iterable(SCALAR_BLOCKS))
assert got == expected
Expand Down

0 comments on commit 8df5a86

Please sign in to comment.