Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: recurse into Map datatype when hydrating dictionaries #6645

Merged
merged 1 commit into from
Oct 30, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
161 changes: 161 additions & 0 deletions arrow-flight/src/encode.rs
Original file line number Diff line number Diff line change
Expand Up @@ -505,6 +505,15 @@ fn prepare_field_for_flight(
.with_metadata(field.metadata().clone())
}
}
// Recurse into the Map's entries field so that dictionary-encoded
// "keys"/"values" children are rewritten by the same logic applied to
// top-level fields (hydrated to their value type, or tracked for
// resending, depending on `send_dictionaries`). Nullability, name and
// metadata of the outer field are preserved.
DataType::Map(inner, sorted) => Field::new(
field.name(),
DataType::Map(
prepare_field_for_flight(inner, dictionary_tracker, send_dictionaries).into(),
*sorted,
),
field.is_nullable(),
)
.with_metadata(field.metadata().clone()),
_ => field.as_ref().clone(),
}
}
Expand Down Expand Up @@ -684,6 +693,7 @@ mod tests {
use arrow_cast::pretty::pretty_format_batches;
use arrow_ipc::MetadataVersion;
use arrow_schema::{UnionFields, UnionMode};
use builder::{GenericStringBuilder, MapBuilder};
use std::collections::HashMap;

use super::*;
Expand Down Expand Up @@ -1275,6 +1285,157 @@ mod tests {
verify_flight_round_trip(vec![batch1, batch2, batch3]).await;
}

#[tokio::test]
async fn test_dictionary_map_hydration() {
let mut builder = MapBuilder::new(
None,
StringDictionaryBuilder::<UInt16Type>::new(),
StringDictionaryBuilder::<UInt16Type>::new(),
);

// {"k1":"a","k2":null,"k3":"b"}
builder.keys().append_value("k1");
builder.values().append_value("a");
builder.keys().append_value("k2");
builder.values().append_null();
builder.keys().append_value("k3");
builder.values().append_value("b");
builder.append(true).unwrap();

let arr1 = builder.finish();

// {"k1":"c","k2":null,"k3":"d"}
builder.keys().append_value("k1");
builder.values().append_value("c");
builder.keys().append_value("k2");
builder.values().append_null();
builder.keys().append_value("k3");
builder.values().append_value("d");
builder.append(true).unwrap();

let arr2 = builder.finish();

let schema = Arc::new(Schema::new(vec![Field::new_map(
"dict_map",
"entries",
Field::new_dictionary("keys", DataType::UInt16, DataType::Utf8, false),
Field::new_dictionary("values", DataType::UInt16, DataType::Utf8, true),
false,
false,
)]));

let batch1 = RecordBatch::try_new(schema.clone(), vec![Arc::new(arr1)]).unwrap();
let batch2 = RecordBatch::try_new(schema.clone(), vec![Arc::new(arr2)]).unwrap();

let stream = futures::stream::iter(vec![Ok(batch1), Ok(batch2)]);

let encoder = FlightDataEncoderBuilder::default().build(stream);

let mut decoder = FlightDataDecoder::new(encoder);
let expected_schema = Schema::new(vec![Field::new_map(
"dict_map",
"entries",
Field::new("keys", DataType::Utf8, false),
Field::new("values", DataType::Utf8, true),
false,
false,
)]);

let expected_schema = Arc::new(expected_schema);

// Builder without dictionary fields
let mut builder = MapBuilder::new(
None,
GenericStringBuilder::<i32>::new(),
GenericStringBuilder::<i32>::new(),
);

// {"k1":"a","k2":null,"k3":"b"}
builder.keys().append_value("k1");
builder.values().append_value("a");
builder.keys().append_value("k2");
builder.values().append_null();
builder.keys().append_value("k3");
builder.values().append_value("b");
builder.append(true).unwrap();

let arr1 = builder.finish();

// {"k1":"c","k2":null,"k3":"d"}
builder.keys().append_value("k1");
builder.values().append_value("c");
builder.keys().append_value("k2");
builder.values().append_null();
builder.keys().append_value("k3");
builder.values().append_value("d");
builder.append(true).unwrap();

let arr2 = builder.finish();

let mut expected_arrays = vec![arr1, arr2].into_iter();

while let Some(decoded) = decoder.next().await {
let decoded = decoded.unwrap();
match decoded.payload {
DecodedPayload::None => {}
DecodedPayload::Schema(s) => assert_eq!(s, expected_schema),
DecodedPayload::RecordBatch(b) => {
assert_eq!(b.schema(), expected_schema);
let expected_array = expected_arrays.next().unwrap();
let map_array =
downcast_array::<MapArray>(b.column_by_name("dict_map").unwrap());

assert_eq!(map_array, expected_array);
}
}
}
}

#[tokio::test]
async fn test_dictionary_map_resend() {
    // Row data for the two batches: (key, optional value) per map entry.
    let rows: [&[(&str, Option<&str>)]; 2] = [
        // {"k1":"a","k2":null,"k3":"b"}
        &[("k1", Some("a")), ("k2", None), ("k3", Some("b"))],
        // {"k1":"c","k2":null,"k3":"d"}
        &[("k1", Some("c")), ("k2", None), ("k3", Some("d"))],
    ];

    let schema = Arc::new(Schema::new(vec![Field::new_map(
        "dict_map",
        "entries",
        Field::new_dictionary("keys", DataType::UInt16, DataType::Utf8, false),
        Field::new_dictionary("values", DataType::UInt16, DataType::Utf8, true),
        false,
        false,
    )]));

    // Both keys and values are dictionary encoded.
    let mut builder = MapBuilder::new(
        None,
        StringDictionaryBuilder::<UInt16Type>::new(),
        StringDictionaryBuilder::<UInt16Type>::new(),
    );

    let mut batches = Vec::with_capacity(rows.len());
    for row in rows {
        for &(key, value) in row {
            builder.keys().append_value(key);
            match value {
                Some(v) => builder.values().append_value(v),
                None => builder.values().append_null(),
            }
        }
        builder.append(true).unwrap();
        batches
            .push(RecordBatch::try_new(schema.clone(), vec![Arc::new(builder.finish())]).unwrap());
    }

    verify_flight_round_trip(batches).await;
}

async fn verify_flight_round_trip(mut batches: Vec<RecordBatch>) {
let expected_schema = batches.first().unwrap().schema();

Expand Down
Loading