Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix group by array low cardinality arguments #4055

Merged
merged 2 commits into from
Jan 29, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 37 additions & 5 deletions dbms/src/DataTypes/DataTypeLowCardinalityHelpers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,10 +48,24 @@ ColumnPtr recursiveRemoveLowCardinality(const ColumnPtr & column)
return column;

if (const auto * column_array = typeid_cast<const ColumnArray *>(column.get()))
return ColumnArray::create(recursiveRemoveLowCardinality(column_array->getDataPtr()), column_array->getOffsetsPtr());
{
auto & data = column_array->getDataPtr();
auto data_no_lc = recursiveRemoveLowCardinality(data);
if (data.get() == data_no_lc.get())
return column;

return ColumnArray::create(data_no_lc, column_array->getOffsetsPtr());
}

if (const auto * column_const = typeid_cast<const ColumnConst *>(column.get()))
return ColumnConst::create(recursiveRemoveLowCardinality(column_const->getDataColumnPtr()), column_const->size());
{
auto & nested = column_const->getDataColumnPtr();
auto nested_no_lc = recursiveRemoveLowCardinality(nested);
if (nested.get() == nested_no_lc.get())
return column;

return ColumnConst::create(nested_no_lc, column_const->size());
}

if (const auto * column_tuple = typeid_cast<const ColumnTuple *>(column.get()))
{
Expand All @@ -76,8 +90,14 @@ ColumnPtr recursiveLowCardinalityConversion(const ColumnPtr & column, const Data
return column;

if (const auto * column_const = typeid_cast<const ColumnConst *>(column.get()))
return ColumnConst::create(recursiveLowCardinalityConversion(column_const->getDataColumnPtr(), from_type, to_type),
column_const->size());
{
auto & nested = column_const->getDataColumnPtr();
auto nested_no_lc = recursiveLowCardinalityConversion(nested, from_type, to_type);
if (nested.get() == nested_no_lc.get())
return column;

return ColumnConst::create(nested_no_lc, column_const->size());
}

if (const auto * low_cardinality_type = typeid_cast<const DataTypeLowCardinality *>(from_type.get()))
{
Expand Down Expand Up @@ -125,11 +145,23 @@ ColumnPtr recursiveLowCardinalityConversion(const ColumnPtr & column, const Data
Columns columns = column_tuple->getColumns();
auto & from_elements = from_tuple_type->getElements();
auto & to_elements = to_tuple_type->getElements();

bool has_converted = false;

for (size_t i = 0; i < columns.size(); ++i)
{
auto & element = columns[i];
element = recursiveLowCardinalityConversion(element, from_elements.at(i), to_elements.at(i));
auto element_no_lc = recursiveLowCardinalityConversion(element, from_elements.at(i), to_elements.at(i));
if (element.get() != element_no_lc.get())
{
element = element_no_lc;
has_converted = true;
}
}

if (!has_converted)
return column;

return ColumnTuple::create(columns);
}
}
Expand Down
12 changes: 7 additions & 5 deletions dbms/src/Interpreters/Aggregator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -768,11 +768,12 @@ bool Aggregator::executeOnBlock(const Block & block, AggregatedDataVariants & re
materialized_columns.push_back(block.safeGetByPosition(params.keys[i]).column->convertToFullColumnIfConst());
key_columns[i] = materialized_columns.back().get();

if (const auto * low_cardinality_column = typeid_cast<const ColumnLowCardinality *>(key_columns[i]))
if (!result.isLowCardinality())
{
if (!result.isLowCardinality())
auto column_no_lc = recursiveRemoveLowCardinality(key_columns[i]->getPtr());
if (column_no_lc.get() != key_columns[i])
{
materialized_columns.push_back(low_cardinality_column->convertToFullColumn());
materialized_columns.emplace_back(std::move(column_no_lc));
key_columns[i] = materialized_columns.back().get();
}
}
Expand All @@ -788,9 +789,10 @@ bool Aggregator::executeOnBlock(const Block & block, AggregatedDataVariants & re
materialized_columns.push_back(block.safeGetByPosition(params.aggregates[i].arguments[j]).column->convertToFullColumnIfConst());
aggregate_columns[i][j] = materialized_columns.back().get();

if (auto * col_low_cardinality = typeid_cast<const ColumnLowCardinality *>(aggregate_columns[i][j]))
auto column_no_lc = recursiveRemoveLowCardinality(aggregate_columns[i][j]->getPtr());
if (column_no_lc.get() != aggregate_columns[i][j])
{
materialized_columns.push_back(col_low_cardinality->convertToFullColumn());
materialized_columns.emplace_back(std::move(column_no_lc));
aggregate_columns[i][j] = materialized_columns.back().get();
}
}
Expand Down
14 changes: 10 additions & 4 deletions dbms/src/Interpreters/Join.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -253,12 +253,18 @@ void Join::setSampleBlock(const Block & block)

size_t keys_size = key_names_right.size();
ColumnRawPtrs key_columns(keys_size);
Columns materialized_columns(keys_size);
Columns materialized_columns;

for (size_t i = 0; i < keys_size; ++i)
{
materialized_columns[i] = recursiveRemoveLowCardinality(block.getByName(key_names_right[i]).column);
key_columns[i] = materialized_columns[i].get();
auto & column = block.getByName(key_names_right[i]).column;
key_columns[i] = column.get();
auto column_no_lc = recursiveRemoveLowCardinality(column);
if (column.get() != column_no_lc.get())
{
materialized_columns.emplace_back(std::move(column_no_lc));
key_columns[i] = materialized_columns[i].get();
}

/// We will join only keys, where all components are not NULL.
if (key_columns[i]->isColumnNullable())
Expand Down Expand Up @@ -914,7 +920,7 @@ void Join::joinGetImpl(Block & block, const String & column_name, const Maps & m


// TODO: support composite key
// TODO: return multible columns as named tuple
// TODO: return multiple columns as named tuple
// TODO: return array of values when strictness == ASTTableJoin::Strictness::All
void Join::joinGet(Block & block, const String & column_name) const
{
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
2019-01-14 1 ['aaa','aaa','bbb','ccc']
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
SET allow_experimental_low_cardinality_type = 1;

DROP TABLE IF EXISTS test.table1;
DROP TABLE IF EXISTS test.table2;

CREATE TABLE test.table1
(
dt Date,
id Int32,
arr Array(LowCardinality(String))
) ENGINE = MergeTree PARTITION BY toMonday(dt)
ORDER BY (dt, id) SETTINGS index_granularity = 8192;

CREATE TABLE test.table2
(
dt Date,
id Int32,
arr Array(LowCardinality(String))
) ENGINE = MergeTree PARTITION BY toMonday(dt)
ORDER BY (dt, id) SETTINGS index_granularity = 8192;

insert into test.table1 (dt, id, arr) values ('2019-01-14', 1, ['aaa']);
insert into test.table2 (dt, id, arr) values ('2019-01-14', 1, ['aaa','bbb','ccc']);

select dt, id, groupArrayArray(arr)
from (
select dt, id, arr from test.table1
where dt = '2019-01-14' and id = 1
UNION ALL
select dt, id, arr from test.table2
where dt = '2019-01-14' and id = 1
)
group by dt, id;