-
Notifications
You must be signed in to change notification settings - Fork 1.2k
/
row.rs
236 lines (207 loc) · 8.49 KB
/
row.rs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
use crate::aggregates::group_values::GroupValues;
use ahash::RandomState;
use arrow::compute::cast;
use arrow::record_batch::RecordBatch;
use arrow::row::{RowConverter, Rows, SortField};
use arrow_array::{Array, ArrayRef};
use arrow_schema::{DataType, SchemaRef};
use datafusion_common::hash_utils::create_hashes;
use datafusion_common::{DataFusionError, Result};
use datafusion_execution::memory_pool::proxy::{RawTableAllocExt, VecAllocExt};
use datafusion_expr::EmitTo;
use hashbrown::raw::RawTable;
/// A [`GroupValues`] making use of [`Rows`]
///
/// Stores each distinct group key exactly once in the arrow row format
/// and maps incoming rows to dense `group_index` values via a hash table
/// keyed on precomputed hashes.
pub struct GroupValuesRows {
    /// The output schema
    schema: SchemaRef,

    /// Converter for the group values
    row_converter: RowConverter,

    /// Logically maps group values to a group_index in
    /// [`Self::group_values`] and in each accumulator
    ///
    /// Uses the raw API of hashbrown to avoid actually storing the
    /// keys (group values) in the table
    ///
    /// keys: u64 hashes of the GroupValue
    /// values: (hash, group_index)
    map: RawTable<(u64, usize)>,

    /// The size of `map` in bytes
    map_size: usize,

    /// The actual group by values, stored in arrow [`Row`] format.
    /// `group_values[i]` holds the group value for group_index `i`.
    ///
    /// The row format is used to compare group keys quickly and store
    /// them efficiently in memory. Quick comparison is especially
    /// important for multi-column group keys.
    ///
    /// `None` on construction; temporarily taken (and restored) during
    /// `intern` and `emit`.
    ///
    /// [`Row`]: arrow::row::Row
    group_values: Option<Rows>,

    /// Buffer reused across input batches to store hashes, avoiding a
    /// fresh allocation per batch
    hashes_buffer: Vec<u64>,

    /// Random state for creating hashes
    random_state: RandomState,
}
impl GroupValuesRows {
pub fn try_new(schema: SchemaRef) -> Result<Self> {
let row_converter = RowConverter::new(
schema
.fields()
.iter()
.map(|f| SortField::new(f.data_type().clone()))
.collect(),
)?;
let map = RawTable::with_capacity(0);
Ok(Self {
schema,
row_converter,
map,
map_size: 0,
group_values: None,
hashes_buffer: Default::default(),
random_state: Default::default(),
})
}
}
impl GroupValues for GroupValuesRows {
    /// Maps each input row in `cols` to its group index, writing one index
    /// per row into `groups` (cleared first) and creating new groups for
    /// previously unseen key values.
    fn intern(&mut self, cols: &[ArrayRef], groups: &mut Vec<usize>) -> Result<()> {
        // Convert the group keys into the row format
        // Avoid reallocation when https://github.com/apache/arrow-rs/issues/4479 is available
        let group_rows = self.row_converter.convert_columns(cols)?;
        let n_rows = group_rows.num_rows();

        // Take ownership of the stored rows for this call; restored before
        // returning below.
        let mut group_values = match self.group_values.take() {
            Some(group_values) => group_values,
            None => self.row_converter.empty_rows(0, 0),
        };

        // tracks to which group each of the input rows belongs
        groups.clear();

        // 1.1 Calculate the group keys for the group values
        // Reuse the hash buffer across batches; resize fills with 0 before
        // create_hashes overwrites each slot.
        let batch_hashes = &mut self.hashes_buffer;
        batch_hashes.clear();
        batch_hashes.resize(n_rows, 0);
        create_hashes(cols, &self.random_state, batch_hashes)?;

        for (row, &hash) in batch_hashes.iter().enumerate() {
            let entry = self.map.get_mut(hash, |(_hash, group_idx)| {
                // verify that a group that we are inserting with hash is
                // actually the same key value as the group in
                // existing_idx (aka group_values @ row)
                group_rows.row(row) == group_values.row(*group_idx)
            });

            let group_idx = match entry {
                // Existing group_index for this group value
                Some((_hash, group_idx)) => *group_idx,
                // 1.2 Need to create new entry for the group
                None => {
                    // Add new entry to aggr_state and save newly created index
                    let group_idx = group_values.num_rows();
                    group_values.push(group_rows.row(row));

                    // for hasher function, use precomputed hash value;
                    // insert_accounted also updates self.map_size as the
                    // table grows
                    self.map.insert_accounted(
                        (hash, group_idx),
                        |(hash, _group_index)| *hash,
                        &mut self.map_size,
                    );
                    group_idx
                }
            };
            groups.push(group_idx);
        }

        self.group_values = Some(group_values);

        Ok(())
    }

    /// Returns the total memory used by this structure, in bytes: the row
    /// converter, the stored group rows, the hash table, and the hash buffer.
    fn size(&self) -> usize {
        let group_values_size = self.group_values.as_ref().map(|v| v.size()).unwrap_or(0);
        self.row_converter.size()
            + group_values_size
            + self.map_size
            + self.hashes_buffer.allocated_size()
    }

    /// Returns true if no groups have been stored.
    fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Returns the number of distinct groups stored so far.
    fn len(&self) -> usize {
        self.group_values
            .as_ref()
            .map(|group_values| group_values.num_rows())
            .unwrap_or(0)
    }

    /// Emits stored group values as output arrays, one per group column.
    ///
    /// `EmitTo::All` drains every group; `EmitTo::First(n)` removes the
    /// first `n` groups and shifts the remaining groups' indexes down by
    /// `n` so they stay consistent with the accumulators.
    ///
    /// # Panics
    ///
    /// Panics if called before any values have been interned
    /// (`self.group_values` is `None`).
    fn emit(&mut self, emit_to: EmitTo) -> Result<Vec<ArrayRef>> {
        let mut group_values = self
            .group_values
            .take()
            .expect("Can not emit from empty rows");

        let mut output = match emit_to {
            EmitTo::All => {
                let output = self.row_converter.convert_rows(&group_values)?;
                group_values.clear();
                output
            }
            EmitTo::First(n) => {
                let groups_rows = group_values.iter().take(n);
                let output = self.row_converter.convert_rows(groups_rows)?;
                // Clear out first n group keys by copying them to a new Rows.
                // TODO file some ticket in arrow-rs to make this more efficient?
                let mut new_group_values = self.row_converter.empty_rows(0, 0);
                for row in group_values.iter().skip(n) {
                    new_group_values.push(row);
                }
                std::mem::swap(&mut new_group_values, &mut group_values);

                // SAFETY: self.map outlives iterator and is not modified concurrently
                // NOTE(review): relies on hashbrown's raw-API guarantee that
                // erase() during bucket iteration is sound — confirm against
                // the pinned hashbrown version's docs.
                unsafe {
                    for bucket in self.map.iter() {
                        // Decrement group index by n
                        match bucket.as_ref().1.checked_sub(n) {
                            // Group index was >= n, shift value down
                            Some(sub) => bucket.as_mut().1 = sub,
                            // Group index was < n, so remove from table
                            None => self.map.erase(bucket),
                        }
                    }
                }
                output
            }
        };

        // TODO: Materialize dictionaries in group keys (#7647)
        // convert_rows produces the dictionary's value type; cast back to
        // the dictionary type the output schema expects.
        for (field, array) in self.schema.fields.iter().zip(&mut output) {
            let expected = field.data_type();
            if let DataType::Dictionary(_, v) = expected {
                let actual = array.data_type();
                if v.as_ref() != actual {
                    return Err(DataFusionError::Internal(format!(
                        "Converted group rows expected dictionary of {v} got {actual}"
                    )));
                }
                *array = cast(array.as_ref(), expected)?;
            }
        }

        self.group_values = Some(group_values);
        Ok(output)
    }

    /// Clears all stored groups and shrinks internal buffers to capacities
    /// sized for `batch`'s row count, releasing excess memory between
    /// partitions.
    fn clear_shrink(&mut self, batch: &RecordBatch) {
        let count = batch.num_rows();
        // Keep the Rows allocation but clear its contents for reuse
        self.group_values = self.group_values.take().map(|mut rows| {
            rows.clear();
            rows
        });
        self.map.clear();
        self.map.shrink_to(count, |_| 0); // hasher does not matter since the map is cleared
        // Re-derive map_size from the (possibly shrunk) capacity
        self.map_size = self.map.capacity() * std::mem::size_of::<(u64, usize)>();
        self.hashes_buffer.clear();
        self.hashes_buffer.shrink_to(count);
    }
}