Skip to content

Commit

Permalink
Merge pull request ClickHouse#4629 from bgranvea/simple_aggregate_fun…
Browse files Browse the repository at this point in the history
…ction

SimpleAggregateFunction data type
  • Loading branch information
alexey-milovidov authored May 9, 2019
2 parents 8be1bc0 + 42b07c5 commit 8f8d2c0
Show file tree
Hide file tree
Showing 14 changed files with 484 additions and 113 deletions.
46 changes: 44 additions & 2 deletions dbms/src/DataStreams/AggregatingSortedBlockInputStream.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
#include <DataStreams/AggregatingSortedBlockInputStream.h>
#include <Common/typeid_cast.h>
#include <Common/StringUtils/StringUtils.h>
#include <DataTypes/DataTypeAggregateFunction.h>
#include <DataTypes/DataTypeCustomSimpleAggregateFunction.h>


namespace DB
Expand All @@ -22,7 +24,7 @@ AggregatingSortedBlockInputStream::AggregatingSortedBlockInputStream(
ColumnWithTypeAndName & column = header.safeGetByPosition(i);

/// We leave only states of aggregate functions.
if (!startsWith(column.type->getName(), "AggregateFunction"))
if (!dynamic_cast<const DataTypeAggregateFunction *>(column.type.get()) && !dynamic_cast<const DataTypeCustomSimpleAggregateFunction *>(column.type->getCustomName()))
{
column_numbers_not_to_aggregate.push_back(i);
continue;
Expand All @@ -40,7 +42,17 @@ AggregatingSortedBlockInputStream::AggregatingSortedBlockInputStream(
continue;
}

column_numbers_to_aggregate.push_back(i);
if (auto simple_aggr = dynamic_cast<const DataTypeCustomSimpleAggregateFunction *>(column.type->getCustomName()))
{
// simple aggregate function
SimpleAggregateDescription desc{simple_aggr->getFunction(), i};
columns_to_simple_aggregate.emplace_back(std::move(desc));
}
else
{
// standard aggregate function
column_numbers_to_aggregate.push_back(i);
}
}
}

Expand Down Expand Up @@ -91,7 +103,11 @@ void AggregatingSortedBlockInputStream::merge(MutableColumns & merged_columns, s

/// if there are enough rows accumulated and the last one is calculated completely
if (key_differs && merged_rows >= max_block_size)
{
/// Write the simple aggregation result for the previous group.
insertSimpleAggregationResult(merged_columns);
return;
}

queue.pop();

Expand All @@ -110,6 +126,14 @@ void AggregatingSortedBlockInputStream::merge(MutableColumns & merged_columns, s
for (auto & column_to_aggregate : columns_to_aggregate)
column_to_aggregate->insertDefault();

/// Write the simple aggregation result for the previous group.
if (merged_rows > 0)
insertSimpleAggregationResult(merged_columns);

/// Reset simple aggregation states for next row
for (auto & desc : columns_to_simple_aggregate)
desc.createState();

++merged_rows;
}

Expand All @@ -127,6 +151,9 @@ void AggregatingSortedBlockInputStream::merge(MutableColumns & merged_columns, s
}
}

/// Write the simple aggregation result for the previous group.
insertSimpleAggregationResult(merged_columns);

finished = true;
}

Expand All @@ -138,6 +165,21 @@ void AggregatingSortedBlockInputStream::addRow(SortCursor & cursor)
size_t j = column_numbers_to_aggregate[i];
columns_to_aggregate[i]->insertMergeFrom(*cursor->all_columns[j], cursor->pos);
}

for (auto & desc : columns_to_simple_aggregate)
{
auto & col = cursor->all_columns[desc.column_number];
desc.add_function(desc.function.get(), desc.state.data(), &col, cursor->pos, nullptr);
}
}

void AggregatingSortedBlockInputStream::insertSimpleAggregationResult(MutableColumns & merged_columns)
{
for (auto & desc : columns_to_simple_aggregate)
{
desc.function->insertResultInto(desc.state.data(), *merged_columns[desc.column_number]);
desc.destroyState();
}
}

}
51 changes: 51 additions & 0 deletions dbms/src/DataStreams/AggregatingSortedBlockInputStream.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include <DataStreams/MergingSortedBlockInputStream.h>
#include <AggregateFunctions/IAggregateFunction.h>
#include <Columns/ColumnAggregateFunction.h>
#include <Common/AlignedBuffer.h>


namespace DB
Expand Down Expand Up @@ -38,10 +39,13 @@ class AggregatingSortedBlockInputStream : public MergingSortedBlockInputStream
/// Read finished.
bool finished = false;

struct SimpleAggregateDescription;

/// Columns with which numbers should be aggregated.
ColumnNumbers column_numbers_to_aggregate;
ColumnNumbers column_numbers_not_to_aggregate;
std::vector<ColumnAggregateFunction *> columns_to_aggregate;
std::vector<SimpleAggregateDescription> columns_to_simple_aggregate;

RowRef current_key; /// The current primary key.
RowRef next_key; /// The primary key of the next row.
Expand All @@ -54,6 +58,53 @@ class AggregatingSortedBlockInputStream : public MergingSortedBlockInputStream
/** Extract all states of aggregate functions and merge them with the current group.
*/
void addRow(SortCursor & cursor);

/** Insert all values of current row for simple aggregate functions
*/
void insertSimpleAggregationResult(MutableColumns & merged_columns);

/// Stores information for aggregation of SimpleAggregateFunction columns
struct SimpleAggregateDescription
{
/// An aggregate function 'anyLast', 'sum'...
AggregateFunctionPtr function;
IAggregateFunction::AddFunc add_function;
size_t column_number;
AlignedBuffer state;
bool created = false;

SimpleAggregateDescription(const AggregateFunctionPtr & function_, const size_t column_number_) : function(function_), column_number(column_number_)
{
add_function = function->getAddressOfAddFunction();
state.reset(function->sizeOfData(), function->alignOfData());
}

void createState()
{
if (created)
return;
function->create(state.data());
created = true;
}

void destroyState()
{
if (!created)
return;
function->destroy(state.data());
created = false;
}

/// Explicitly destroy aggregation state if the stream is terminated
~SimpleAggregateDescription()
{
destroyState();
}

SimpleAggregateDescription() = default;
SimpleAggregateDescription(SimpleAggregateDescription &&) = default;
SimpleAggregateDescription(const SimpleAggregateDescription &) = delete;
};
};

}
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
#pragma once

#include <memory>
#include <cstddef>
#include <Core/Types.h>

namespace DB
{
Expand All @@ -10,21 +12,21 @@ class WriteBuffer;
struct FormatSettings;
class IColumn;

/** Further refinment of the properties of data type.
*
* Contains methods for serialization/deserialization.
* Implementations of this interface represent a data type domain (example: IPv4)
* which is a refinement of the exsitgin type with a name and specific text
* representation.
*
* IDataTypeDomain is totally immutable object. You can always share them.
/** Allow to customize an existing data type and set a different name and/or text serialization/deserialization methods.
* See use in IPv4 and IPv6 data types, and also in SimpleAggregateFunction.
*/
class IDataTypeDomain
class IDataTypeCustomName
{
public:
virtual ~IDataTypeDomain() {}
virtual ~IDataTypeCustomName() {}

virtual const char* getName() const = 0;
virtual String getName() const = 0;
};

class IDataTypeCustomTextSerialization
{
public:
virtual ~IDataTypeCustomTextSerialization() {}

/** Text serialization for displaying on a terminal or saving into a text file, and the like.
* Without escaping or quoting.
Expand Down Expand Up @@ -56,4 +58,31 @@ class IDataTypeDomain
virtual void serializeTextXML(const IColumn & column, size_t row_num, WriteBuffer & ostr, const FormatSettings & settings) const = 0;
};

using DataTypeCustomNamePtr = std::unique_ptr<const IDataTypeCustomName>;
using DataTypeCustomTextSerializationPtr = std::unique_ptr<const IDataTypeCustomTextSerialization>;

/** Describe a data type customization
*/
struct DataTypeCustomDesc
{
DataTypeCustomNamePtr name;
DataTypeCustomTextSerializationPtr text_serialization;

DataTypeCustomDesc(DataTypeCustomNamePtr name_, DataTypeCustomTextSerializationPtr text_serialization_)
: name(std::move(name_)), text_serialization(std::move(text_serialization_)) {}
};

using DataTypeCustomDescPtr = std::unique_ptr<DataTypeCustomDesc>;

/** A simple implementation of IDataTypeCustomName
*/
class DataTypeCustomFixedName : public IDataTypeCustomName
{
private:
String name;
public:
DataTypeCustomFixedName(String name_) : name(name_) {}
String getName() const override { return name; }
};

} // namespace DB
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
#include <Columns/ColumnsNumber.h>
#include <Common/Exception.h>
#include <Common/formatIPv6.h>
#include <DataTypes/DataTypeDomainWithSimpleSerialization.h>
#include <DataTypes/DataTypeCustomSimpleTextSerialization.h>
#include <DataTypes/DataTypeFactory.h>
#include <DataTypes/IDataTypeDomain.h>
#include <DataTypes/DataTypeCustom.h>
#include <Functions/FunctionHelpers.h>
#include <Functions/FunctionsCoding.h>

Expand All @@ -20,20 +20,15 @@ namespace ErrorCodes
namespace
{

class DataTypeDomainIPv4 : public DataTypeDomainWithSimpleSerialization
class DataTypeCustomIPv4Serialization : public DataTypeCustomSimpleTextSerialization
{
public:
const char * getName() const override
{
return "IPv4";
}

void serializeText(const IColumn & column, size_t row_num, WriteBuffer & ostr, const FormatSettings &) const override
{
const auto col = checkAndGetColumn<ColumnUInt32>(&column);
if (!col)
{
throw Exception(String(getName()) + " domain can only serialize columns of type UInt32." + column.getName(), ErrorCodes::ILLEGAL_COLUMN);
throw Exception("IPv4 type can only serialize columns of type UInt32." + column.getName(), ErrorCodes::ILLEGAL_COLUMN);
}

char buffer[IPV4_MAX_TEXT_LENGTH + 1] = {'\0'};
Expand All @@ -48,7 +43,7 @@ class DataTypeDomainIPv4 : public DataTypeDomainWithSimpleSerialization
ColumnUInt32 * col = typeid_cast<ColumnUInt32 *>(&column);
if (!col)
{
throw Exception(String(getName()) + " domain can only deserialize columns of type UInt32." + column.getName(), ErrorCodes::ILLEGAL_COLUMN);
throw Exception("IPv4 type can only deserialize columns of type UInt32." + column.getName(), ErrorCodes::ILLEGAL_COLUMN);
}

char buffer[IPV4_MAX_TEXT_LENGTH + 1] = {'\0'};
Expand All @@ -63,20 +58,16 @@ class DataTypeDomainIPv4 : public DataTypeDomainWithSimpleSerialization
}
};

class DataTypeDomainIPv6 : public DataTypeDomainWithSimpleSerialization
class DataTypeCustomIPv6Serialization : public DataTypeCustomSimpleTextSerialization
{
public:
const char * getName() const override
{
return "IPv6";
}

void serializeText(const IColumn & column, size_t row_num, WriteBuffer & ostr, const FormatSettings &) const override
{
const auto col = checkAndGetColumn<ColumnFixedString>(&column);
if (!col)
{
throw Exception(String(getName()) + " domain can only serialize columns of type FixedString(16)." + column.getName(), ErrorCodes::ILLEGAL_COLUMN);
throw Exception("IPv6 type domain can only serialize columns of type FixedString(16)." + column.getName(), ErrorCodes::ILLEGAL_COLUMN);
}

char buffer[IPV6_MAX_TEXT_LENGTH + 1] = {'\0'};
Expand All @@ -91,7 +82,7 @@ class DataTypeDomainIPv6 : public DataTypeDomainWithSimpleSerialization
ColumnFixedString * col = typeid_cast<ColumnFixedString *>(&column);
if (!col)
{
throw Exception(String(getName()) + " domain can only deserialize columns of type FixedString(16)." + column.getName(), ErrorCodes::ILLEGAL_COLUMN);
throw Exception("IPv6 type domain can only deserialize columns of type FixedString(16)." + column.getName(), ErrorCodes::ILLEGAL_COLUMN);
}

char buffer[IPV6_MAX_TEXT_LENGTH + 1] = {'\0'};
Expand All @@ -100,7 +91,7 @@ class DataTypeDomainIPv6 : public DataTypeDomainWithSimpleSerialization
std::string ipv6_value(IPV6_BINARY_LENGTH, '\0');
if (!parseIPv6(buffer, reinterpret_cast<unsigned char *>(ipv6_value.data())))
{
throw Exception(String("Invalid ") + getName() + " value.", ErrorCodes::CANNOT_PARSE_DOMAIN_VALUE_FROM_STRING);
throw Exception("Invalid IPv6 value.", ErrorCodes::CANNOT_PARSE_DOMAIN_VALUE_FROM_STRING);
}

col->insertString(ipv6_value);
Expand All @@ -111,8 +102,17 @@ class DataTypeDomainIPv6 : public DataTypeDomainWithSimpleSerialization

void registerDataTypeDomainIPv4AndIPv6(DataTypeFactory & factory)
{
factory.registerDataTypeDomain("UInt32", std::make_unique<DataTypeDomainIPv4>());
factory.registerDataTypeDomain("FixedString(16)", std::make_unique<DataTypeDomainIPv6>());
factory.registerSimpleDataTypeCustom("IPv4", []
{
return std::make_pair(DataTypeFactory::instance().get("UInt32"),
std::make_unique<DataTypeCustomDesc>(std::make_unique<DataTypeCustomFixedName>("IPv4"), std::make_unique<DataTypeCustomIPv4Serialization>()));
});

factory.registerSimpleDataTypeCustom("IPv6", []
{
return std::make_pair(DataTypeFactory::instance().get("FixedString(16)"),
std::make_unique<DataTypeCustomDesc>(std::make_unique<DataTypeCustomFixedName>("IPv6"), std::make_unique<DataTypeCustomIPv6Serialization>()));
});
}

} // namespace DB
Loading

0 comments on commit 8f8d2c0

Please sign in to comment.