
Commit

changes after review
alexander kozhikhov committed Jan 23, 2019
1 parent 19ca2f3 commit fd8f9c4
Showing 6 changed files with 138 additions and 86 deletions.
55 changes: 33 additions & 22 deletions dbms/src/AggregateFunctions/AggregateFunctionMLMethod.cpp
@@ -7,34 +7,45 @@
 namespace DB
 {

 namespace
 {

 using FuncLinearRegression = AggregateFunctionMLMethod<LinearRegressionData, NameLinearRegression>;

 template <class Method>
 AggregateFunctionPtr createAggregateFunctionMLMethod(
-    const std::string & name, const DataTypes & arguments, const Array & parameters)
+    const std::string & name, const DataTypes & argument_types, const Array & parameters)
 {
     if (parameters.size() > 1)
         throw Exception("Aggregate function " + name + " requires at most one parameter", ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);

-    Float64 lr;
-    if (parameters.empty())
-        lr = Float64(0.01);
-    else
-        lr = static_cast<const Float64>(parameters[0].template get<Float64>());
+    for (size_t i = 0; i < argument_types.size(); ++i)
+    {
+        if (!WhichDataType(argument_types[i]).isFloat64())
+            throw Exception("Illegal type " + argument_types[i]->getName() + " of argument " +
+                            std::to_string(i) + "for aggregate function " + name,
+                            ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
+    }
+
+    Float64 learning_rate;
+    if (parameters.empty())
+    {
+        learning_rate = Float64(0.01);
+    } else
+    {
+        learning_rate = applyVisitor(FieldVisitorConvertToNumber<Float64>(), parameters[0]);
+    }

-    if (arguments.size() < 2)
+    if (argument_types.size() < 2)
         throw Exception("Aggregate function " + name + " requires at least two arguments", ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);

-    return std::make_shared<Method>(arguments.size() - 1, lr);
+    return std::make_shared<Method>(argument_types.size() - 1, learning_rate);
 }

 }

 void registerAggregateFunctionMLMethod(AggregateFunctionFactory & factory) {
     factory.registerFunction("LinearRegression", createAggregateFunctionMLMethod<FuncLinearRegression>);
 }

 }
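Two details of the new version are worth spelling out. The argument-type loop now rejects anything that is not Float64 up front, and the learning-rate parameter is read with applyVisitor(FieldVisitorConvertToNumber<Float64>(), parameters[0]) instead of parameters[0].get<Float64>(), presumably so that an integer literal such as LinearRegression(1) is also accepted: the literal is stored in the Field as an integer, and the converting visitor coerces any numeric field to Float64, while a strict get<Float64>() does not. A minimal standalone sketch of that conversion idea, using std::variant in place of ClickHouse's Field (all names below are illustrative, not ClickHouse APIs):

```cpp
#include <cstdint>
#include <iostream>
#include <variant>

// Stand-in for ClickHouse's Field: a parameter literal may be parsed
// as an unsigned integer, a signed integer, or a floating point value.
using ParamField = std::variant<uint64_t, int64_t, double>;

// Stand-in for a converting visitor: turn whatever numeric alternative
// is stored into a double.
double convertToFloat64(const ParamField & field)
{
    return std::visit([](auto value) { return static_cast<double>(value); }, field);
}

int main()
{
    ParamField from_int_literal = uint64_t{1};   // e.g. LinearRegression(1)
    ParamField from_float_literal = 0.05;        // e.g. LinearRegression(0.05)

    // std::get<double>(from_int_literal) would throw std::bad_variant_access,
    // much like a strict typed read of a Field holding an integer; the
    // visitor-based conversion handles both literals.
    std::cout << convertToFloat64(from_int_literal) << '\n';   // 1
    std::cout << convertToFloat64(from_float_literal) << '\n'; // 0.05
}
```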
104 changes: 58 additions & 46 deletions dbms/src/AggregateFunctions/AggregateFunctionMLMethod.h
@@ -22,74 +22,82 @@

 namespace DB {

-struct LinearRegressionData {
-    Float64 bias{0.0};
-    std::vector<Float64> w1;
-    Float64 learning_rate{0.01};
-    UInt32 iter_num = 0;
-    UInt32 param_num = 0;
+namespace ErrorCodes
+{
+    extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
+    extern const int BAD_ARGUMENTS;
+}

-    void add(Float64 target, std::vector<Float64>& feature, Float64 learning_rate_, UInt32 param_num_) {
-        if (w1.empty()) {
-            learning_rate = learning_rate_;
-            param_num = param_num_;
-            w1.resize(param_num);
-        }
+struct LinearRegressionData
+{
+    LinearRegressionData()
+    {}
+    LinearRegressionData(Float64 learning_rate_, UInt32 param_num_)
+        : learning_rate(learning_rate_) {
+        weights.resize(param_num_);
+    }
+
+    Float64 bias{0.0};
+    std::vector<Float64> weights;
+    Float64 learning_rate;
+    UInt32 iter_num = 0;

+    void add(Float64 target, const IColumn ** columns, size_t row_num)
+    {
         Float64 derivative = (target - bias);
-        for (size_t i = 0; i < param_num; ++i)
+        for (size_t i = 0; i < weights.size(); ++i)
         {
-            derivative -= w1[i] * feature[i];
+            derivative -= weights[i] * static_cast<const ColumnVector<Float64> &>(*columns[i + 1]).getData()[row_num];
         }
         derivative *= (2 * learning_rate);

         bias += derivative;
-        for (size_t i = 0; i < param_num; ++i)
+        for (size_t i = 0; i < weights.size(); ++i)
         {
-            w1[i] += derivative * feature[i];
+            weights[i] += derivative * static_cast<const ColumnVector<Float64> &>(*columns[i + 1]).getData()[row_num];;
         }

         ++iter_num;
     }

-    void merge(const LinearRegressionData & rhs) {
+    void merge(const LinearRegressionData & rhs)
+    {
         if (iter_num == 0 && rhs.iter_num == 0)
-            throw std::runtime_error("Strange...");
-
-        if (param_num == 0) {
-            param_num = rhs.param_num;
-            w1.resize(param_num);
-        }
+            return;

         Float64 frac = static_cast<Float64>(iter_num) / (iter_num + rhs.iter_num);
         Float64 rhs_frac = static_cast<Float64>(rhs.iter_num) / (iter_num + rhs.iter_num);

-        for (size_t i = 0; i < param_num; ++i)
+        for (size_t i = 0; i < weights.size(); ++i)
         {
-            w1[i] = w1[i] * frac + rhs.w1[i] * rhs_frac;
+            weights[i] = weights[i] * frac + rhs.weights[i] * rhs_frac;
         }

         bias = bias * frac + rhs.bias * rhs_frac;
         iter_num += rhs.iter_num;
     }

-    void write(WriteBuffer & buf) const {
+    void write(WriteBuffer & buf) const
+    {
         writeBinary(bias, buf);
-        writeBinary(w1, buf);
+        writeBinary(weights, buf);
         writeBinary(iter_num, buf);
     }

-    void read(ReadBuffer & buf) {
+    void read(ReadBuffer & buf)
+    {
         readBinary(bias, buf);
-        readBinary(w1, buf);
+        readBinary(weights, buf);
         readBinary(iter_num, buf);
     }
-    Float64 predict(std::vector<Float64>& predict_feature) const {
+
+    Float64 predict(const std::vector<Float64>& predict_feature) const
+    {
         Float64 res{0.0};
-        for (size_t i = 0; i < static_cast<size_t>(param_num); ++i)
+        for (size_t i = 0; i < predict_feature.size(); ++i)
         {
-            res += predict_feature[i] * w1[i];
+            res += predict_feature[i] * weights[i];
         }
         res += bias;

@@ -118,18 +126,16 @@ class AggregateFunctionMLMethod final : public IAggregateFunctionDataHelper<Data
         return std::make_shared<DataTypeNumber<Float64>>();
     }

+    void create(AggregateDataPtr place) const override
+    {
+        new (place) Data(learning_rate, param_num);
+    }
+
     void add(AggregateDataPtr place, const IColumn ** columns, size_t row_num, Arena *) const override
     {
         const auto & target = static_cast<const ColumnVector<Float64> &>(*columns[0]);

-        std::vector<Float64> x(param_num);
-        for (size_t i = 0; i < param_num; ++i)
-        {
-            x[i] = static_cast<const ColumnVector<Float64> &>(*columns[i + 1]).getData()[row_num];
-        }
-
-        this->data(place).add(target.getData()[row_num], x, learning_rate, param_num);
+        this->data(place).add(target.getData()[row_num], columns, row_num);
     }

     void merge(AggregateDataPtr place, ConstAggregateDataPtr rhs, Arena *) const override
@@ -149,20 +155,26 @@

     void predictResultInto(ConstAggregateDataPtr place, IColumn & to, Block & block, size_t row_num, const ColumnNumbers & arguments) const
     {
+        if (arguments.size() != param_num + 1)
+            throw Exception("Predict got incorrect number of arguments. Got: " + std::to_string(arguments.size()) + ". Required: " + std::to_string(param_num + 1),
+                            ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
+
         auto &column = dynamic_cast<ColumnVector<Float64> &>(to);

         std::vector<Float64> predict_features(arguments.size() - 1);
-        // for (size_t row_num = 0, rows = block.rows(); row_num < rows; ++row_num) {
         for (size_t i = 1; i < arguments.size(); ++i) {
-            // predict_features[i] = array_elements[i].get<Float64>();
-            predict_features[i - 1] = applyVisitor(FieldVisitorConvertToNumber<Float64>(), (*block.getByPosition(arguments[i]).column)[row_num]);
+            const auto& element = (*block.getByPosition(arguments[i]).column)[row_num];
+            if (element.getType() != Field::Types::Float64)
+                throw Exception("Prediction arguments must be values of type Float",
+                                ErrorCodes::BAD_ARGUMENTS);
+
+            predict_features[i - 1] = element.get<Float64>();
         }
-        // column.getData().push_back(this->data(place).predict(predict_features));
         column.getData().push_back(this->data(place).predict(predict_features));
-        // }
     }

-    void insertResultInto(ConstAggregateDataPtr place, IColumn & to) const override {
+    void insertResultInto(ConstAggregateDataPtr place, IColumn & to) const override
+    {
         auto &column = dynamic_cast<ColumnVector<Float64> &>(to);
         std::ignore = column;
         std::ignore = place;
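Taken together, LinearRegressionData implements plain stochastic gradient descent on squared error: the prediction for a row is bias + sum_i weights[i] * x_i, add() moves the bias and every weight by 2 * learning_rate * (target - prediction) times 1 or x_i respectively, and merge() combines two partial states as an average weighted by how many rows each side has seen. The following self-contained toy version of the same update uses plain std::vector features instead of IColumn data; the names and the driver in main() are illustrative only, not ClickHouse code:

```cpp
#include <cstddef>
#include <iostream>
#include <vector>

// Toy counterpart of LinearRegressionData, operating on std::vector features.
struct ToyLinearRegressionState
{
    double bias = 0.0;
    std::vector<double> weights;
    double learning_rate;
    unsigned iter_num = 0;

    ToyLinearRegressionState(double lr, size_t param_num)
        : weights(param_num, 0.0), learning_rate(lr) {}

    // One stochastic gradient step for squared error: move bias and weights
    // along 2 * lr * (target - prediction) * feature.
    void add(double target, const std::vector<double> & feature)
    {
        double derivative = target - bias;
        for (size_t i = 0; i < weights.size(); ++i)
            derivative -= weights[i] * feature[i];
        derivative *= 2 * learning_rate;

        bias += derivative;
        for (size_t i = 0; i < weights.size(); ++i)
            weights[i] += derivative * feature[i];
        ++iter_num;
    }

    // Combine two partial states: average bias and weights, weighted by how
    // many rows each state has seen (mirrors LinearRegressionData::merge).
    void merge(const ToyLinearRegressionState & rhs)
    {
        if (iter_num == 0 && rhs.iter_num == 0)
            return;
        double frac = static_cast<double>(iter_num) / (iter_num + rhs.iter_num);
        double rhs_frac = static_cast<double>(rhs.iter_num) / (iter_num + rhs.iter_num);
        for (size_t i = 0; i < weights.size(); ++i)
            weights[i] = weights[i] * frac + rhs.weights[i] * rhs_frac;
        bias = bias * frac + rhs.bias * rhs_frac;
        iter_num += rhs.iter_num;
    }

    double predict(const std::vector<double> & feature) const
    {
        double res = bias;
        for (size_t i = 0; i < feature.size(); ++i)
            res += weights[i] * feature[i];
        return res;
    }
};

int main()
{
    ToyLinearRegressionState state(0.01, 1);
    // Fit y = 2x on a few noise-free samples.
    for (int epoch = 0; epoch < 1000; ++epoch)
        for (double x : {1.0, 2.0, 3.0})
            state.add(2.0 * x, {x});
    std::cout << state.predict({4.0}) << '\n';   // close to 8
}
```

With a learning rate of 0.01 the toy state fitted on y = 2x predicts roughly 8 for x = 4, which is the behaviour the aggregate state above should reproduce when fed the same rows.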
2 changes: 1 addition & 1 deletion dbms/src/AggregateFunctions/IAggregateFunction.h
@@ -137,7 +137,7 @@ class IAggregateFunctionDataHelper : public IAggregateFunctionHelper<Derived>
     static const Data & data(ConstAggregateDataPtr place) { return *reinterpret_cast<const Data*>(place); }

 public:
-    void create(AggregateDataPtr place) const override
+    virtual void create(AggregateDataPtr place) const override
     {
         new (place) Data;
     }
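For context, create() constructs the aggregation state with placement new into memory the caller has already set aside (an Arena in ClickHouse), so it runs a constructor but performs no allocation of its own; marking it virtual in the helper lines up with AggregateFunctionMLMethod overriding it above to pass (learning_rate, param_num) into the Data constructor. A minimal generic illustration of the placement-new pattern, using made-up types rather than ClickHouse ones:

```cpp
#include <iostream>
#include <new>

struct State
{
    double learning_rate;
    explicit State(double lr) : learning_rate(lr) {}
};

int main()
{
    // The caller owns raw, suitably aligned memory for the state; create()
    // only runs a constructor in that memory, it does not allocate.
    alignas(State) char place[sizeof(State)];

    State * state = new (place) State(0.01);   // what an overriding create() boils down to
    std::cout << state->learning_rate << '\n';

    state->~State();                           // the destroy() counterpart
}
```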
61 changes: 44 additions & 17 deletions dbms/src/Columns/ColumnAggregateFunction.cpp
@@ -17,6 +17,7 @@ namespace ErrorCodes
 {
     extern const int PARAMETER_OUT_OF_BOUND;
     extern const int SIZES_OF_COLUMNS_DOESNT_MATCH;
+    extern const int ILLEGAL_TYPE_OF_ARGUMENT;
 }

@@ -32,6 +33,23 @@ void ColumnAggregateFunction::addArena(ArenaPtr arena_)
     arenas.push_back(arena_);
 }

+bool ColumnAggregateFunction::convertion(MutableColumnPtr* res_) const
+{
+    if (const AggregateFunctionState * function_state = typeid_cast<const AggregateFunctionState *>(func.get()))
+    {
+        auto res = createView();
+        res->set(function_state->getNestedFunction());
+        res->data.assign(data.begin(), data.end());
+        *res_ = std::move(res);
+        return true;
+    }
+
+    MutableColumnPtr res = func->getReturnType()->createColumn();
+    res->reserve(data.size());
+    *res_ = std::move(res);
+    return false;
+}
+
 MutableColumnPtr ColumnAggregateFunction::convertToValues() const
 {
     /** If the aggregate function returns an unfinalized/unfinished state,
@@ -64,38 +82,46 @@ MutableColumnPtr ColumnAggregateFunction::convertToValues() const
      * AggregateFunction(quantileTiming(0.5), UInt64)
      * into UInt16 - already finished result of `quantileTiming`.
      */
-    if (const AggregateFunctionState * function_state = typeid_cast<const AggregateFunctionState *>(func.get()))
+//    if (const AggregateFunctionState * function_state = typeid_cast<const AggregateFunctionState *>(func.get()))
+//    {
+//        auto res = createView();
+//        res->set(function_state->getNestedFunction());
+//        res->data.assign(data.begin(), data.end());
+//        return res;
+//    }
+//
+//    MutableColumnPtr res = func->getReturnType()->createColumn();
+//    res->reserve(data.size());
+    MutableColumnPtr res;
+    if (convertion(&res))
     {
-        auto res = createView();
-        res->set(function_state->getNestedFunction());
-        res->data.assign(data.begin(), data.end());
         return res;
     }

-    MutableColumnPtr res = func->getReturnType()->createColumn();
-    res->reserve(data.size());
-
     for (auto val : data)
         func->insertResultInto(val, *res);

     return res;
 }

-//MutableColumnPtr ColumnAggregateFunction::predictValues(std::vector<Float64> predict_feature) const
 MutableColumnPtr ColumnAggregateFunction::predictValues(Block & block, const ColumnNumbers & arguments) const
 {
-    if (const AggregateFunctionState * function_state = typeid_cast<const AggregateFunctionState *>(func.get()))
+//    if (const AggregateFunctionState * function_state = typeid_cast<const AggregateFunctionState *>(func.get()))
+//    {
+//        auto res = createView();
+//        res->set(function_state->getNestedFunction());
+//        res->data.assign(data.begin(), data.end());
+//        return res;
+//    }
+//
+//    MutableColumnPtr res = func->getReturnType()->createColumn();
+//    res->reserve(data.size());
+    MutableColumnPtr res;
+    if (convertion(&res))
     {
-        auto res = createView();
-        res->set(function_state->getNestedFunction());
-        res->data.assign(data.begin(), data.end());
         return res;
     }

-    MutableColumnPtr res = func->getReturnType()->createColumn();
-    res->reserve(data.size());
-
-    // const AggregateFunctionMLMethod * ML_function = typeid_cast<const AggregateFunctionMLMethod *>(func.get());
     auto ML_function = typeid_cast<const AggregateFunctionMLMethod<LinearRegressionData, NameLinearRegression> *>(func.get());
     if (ML_function)
     {
@@ -105,7 +131,8 @@ MutableColumnPtr ColumnAggregateFunction::predictValues(Block & block, const Col
             ++row_num;
         }
     } else {
-
+        throw Exception("Illegal aggregate function is passed",
+                        ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
     }

     return res;
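The new convertion() helper factors the shared prologue of convertToValues() and predictValues() into one place: it either builds a view over the stored states (when the column wraps a -State combinator) and returns true, or allocates an empty column of the function's return type and returns false, handing the column back through the out-parameter in both cases. A stripped-down sketch of that bool-plus-out-parameter shape, with illustrative types rather than the ClickHouse signatures:

```cpp
#include <cstddef>
#include <iostream>
#include <memory>
#include <vector>

using ColumnPtr = std::unique_ptr<std::vector<double>>;

// Returns true when the caller should hand the prepared column back as-is
// (the "state view" case); returns false when the caller still has to fill it.
bool prepareResult(bool is_state_combinator, size_t rows, ColumnPtr * out)
{
    if (is_state_combinator)
    {
        *out = std::make_unique<std::vector<double>>(rows, 0.0);  // pretend: view over the stored states
        return true;
    }
    *out = std::make_unique<std::vector<double>>();               // pretend: empty result column
    (*out)->reserve(rows);
    return false;
}

ColumnPtr convertToValues(bool is_state_combinator, size_t rows)
{
    ColumnPtr res;
    if (prepareResult(is_state_combinator, rows, &res))
        return res;                       // nothing left to finalize
    for (size_t i = 0; i < rows; ++i)
        res->push_back(double(i));        // pretend: insertResultInto for each state
    return res;
}

int main()
{
    std::cout << convertToValues(false, 3)->size() << '\n';  // 3
    std::cout << convertToValues(true, 3)->size() << '\n';   // 3
}
```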
1 change: 1 addition & 0 deletions dbms/src/Columns/ColumnAggregateFunction.h
@@ -116,6 +116,7 @@ class ColumnAggregateFunction final : public COWPtrHelper<IColumn, ColumnAggrega
     std::string getName() const override { return "AggregateFunction(" + func->getName() + ")"; }
     const char * getFamilyName() const override { return "AggregateFunction"; }

+    bool convertion(MutableColumnPtr* res_) const;
     MutableColumnPtr predictValues(Block & block, const ColumnNumbers & arguments) const;

     size_t size() const override
1 change: 1 addition & 0 deletions dbms/src/Functions/evalMLMethod.cpp
@@ -59,6 +59,7 @@ namespace DB

     void executeImpl(Block & block, const ColumnNumbers & arguments, size_t result, size_t /*input_rows_count*/) override
     {
+        // TODO: make ML aggregate functions a separate class so that this check can happen right here instead of inside predictValues()
         const ColumnAggregateFunction * column_with_states
             = typeid_cast<const ColumnAggregateFunction *>(&*block.getByPosition(arguments.at(0)).column);
         if (!column_with_states)
