Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix encode error for bit/date/datetime type #189

Closed
wants to merge 16 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion dbms/src/Flash/Coprocessor/DAGBlockOutputStream.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ void DAGBlockOutputStream::write(const Block & block)
for (size_t j = 0; j < block.columns(); j++)
{
auto field = (*block.getByPosition(j).column.get())[i];
EncodeDatum(field, getCodecFlagByFieldType(result_field_types[j]), current_ss);
EncodeDatum(field, getCodecFlagByFieldType(result_field_types[j]), current_ss, block.getByPosition(j).type);
}
// Encode current row
records_per_chunk++;
Expand Down
8 changes: 8 additions & 0 deletions dbms/src/Flash/Coprocessor/DAGCodec.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#include <Flash/Coprocessor/DAGCodec.h>

#include <Storages/Transaction/Codec.h>
#include <Storages/Transaction/DateTimeInfo.h>
#include <Storages/Transaction/TiKVRecordFormat.h>

namespace DB
Expand Down Expand Up @@ -62,4 +63,11 @@ Decimal decodeDAGDecimal(const String & s)
return DecodeDecimal(cursor, s);
}

DateTimeInfo decodeDAGDateTime(const String & s)
{
UInt64 packed = decodeDAGUInt64(s);
DateTimeInfo info(packed);
return info;
}

} // namespace DB
2 changes: 2 additions & 0 deletions dbms/src/Flash/Coprocessor/DAGCodec.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

#include <Common/Decimal.h>
#include <Core/Field.h>
#include <Storages/Transaction/DateTimeInfo.h>

namespace DB
{
Expand All @@ -21,5 +22,6 @@ Float64 decodeDAGFloat64(const String &);
String decodeDAGString(const String &);
String decodeDAGBytes(const String &);
Decimal decodeDAGDecimal(const String &);
DateTimeInfo decodeDAGDateTime(const String &);

} // namespace DB
6 changes: 3 additions & 3 deletions dbms/src/Flash/Coprocessor/DAGExpressionAnalyzer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -265,9 +265,9 @@ String DAGExpressionAnalyzer::appendCastIfNeeded(const tipb::Expr & expr, Expres
type_expr.set_tp(tipb::ExprType::String);
std::stringstream ss;
type_expr.set_val(expected_type->getName());
auto type_field_type = type_expr.field_type();
type_field_type.set_tp(0xfe);
type_field_type.set_flag(1);
auto * type_field_type = type_expr.mutable_field_type();
type_field_type->set_tp(0xfe);
type_field_type->set_flag(1);
getActions(type_expr, actions);

Names cast_argument_names;
Expand Down
21 changes: 20 additions & 1 deletion dbms/src/Flash/Coprocessor/DAGUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,16 @@ String exprToString(const tipb::Expr & expr, const NamesAndTypesList & input_col
return decodeDAGBytes(expr.val());
case tipb::ExprType::MysqlDecimal:
return decodeDAGDecimal(expr.val()).toString();
case tipb::ExprType::MysqlTime:
// todo use the time zone info in dag request
if (expr.has_field_type() && expr.field_type().tp() == TiDB::TypeDate)
{
return std::to_string((UInt64)decodeDAGDateTime(expr.val()).makeDayNum(DateLUT::instance()));
}
else
{
return std::to_string(decodeDAGDateTime(expr.val()).makeDateTime(DateLUT::instance()));
}
case tipb::ExprType::ColumnRef:
column_id = decodeDAGInt64(expr.val());
if (column_id < 0 || column_id >= (ColumnID)input_col.size())
Expand Down Expand Up @@ -210,12 +220,21 @@ Field decodeLiteral(const tipb::Expr & expr)
return decodeDAGBytes(expr.val());
case tipb::ExprType::MysqlDecimal:
return decodeDAGDecimal(expr.val());
case tipb::ExprType::MysqlTime:
// todo use the time zone info in dag request
if (expr.has_field_type() && expr.field_type().tp() == TiDB::TypeDate)
{
return (UInt64)decodeDAGDateTime(expr.val()).makeDayNum(DateLUT::instance());
}
else
{
return (Int64)decodeDAGDateTime(expr.val()).makeDateTime(DateLUT::instance());
}
case tipb::ExprType::MysqlBit:
case tipb::ExprType::MysqlDuration:
case tipb::ExprType::MysqlEnum:
case tipb::ExprType::MysqlHex:
case tipb::ExprType::MysqlSet:
case tipb::ExprType::MysqlTime:
case tipb::ExprType::MysqlJson:
case tipb::ExprType::ValueList:
throw Exception(tipb::ExprType_Name(expr.tp()) + " is not supported yet", ErrorCodes::UNSUPPORTED_METHOD);
Expand Down
173 changes: 172 additions & 1 deletion dbms/src/Functions/FunctionsComparison.h
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,113 @@ inline int memcmp16(const void * a, const void * b)
return 0;
}

inline time_t dateToDateTime(UInt32 date_data)
{
DayNum_t day_num(date_data);
LocalDate local_date(day_num);
// todo use timezone info
return DateLUT::instance().makeDateTime(local_date.year(), local_date.month(), local_date.day(), 0, 0, 0);
}

template <typename A, typename B, template <typename, typename> class Op, bool is_left_date>
struct DateDateTimeComparisonImpl
{
static void NO_INLINE vector_vector(const PaddedPODArray<A> & a, const PaddedPODArray<B> & b, PaddedPODArray<UInt8> & c)
{
size_t size = a.size();
const A * a_pos = &a[0];
const B * b_pos = &b[0];
UInt8 * c_pos = &c[0];
const A * a_end = a_pos + size;
while (a_pos < a_end)
{
if (is_left_date)
{
using OpType = B;
time_t date_time = dateToDateTime(*a_pos);
*c_pos = Op<OpType, OpType>::apply((OpType)date_time, *b_pos);
}
else
{
using OpType = A;
time_t date_time = dateToDateTime(*b_pos);
*c_pos = Op<OpType, OpType>::apply(*a_pos, (OpType)date_time);
}
++a_pos;
++b_pos;
++c_pos;
}
}

static void NO_INLINE vector_constant(const PaddedPODArray<A> & a, B b, PaddedPODArray<UInt8> & c)
{
if (!is_left_date)
{
// datetime vector with date constant
using OpType = A;
time_t date_time = dateToDateTime(b);
NumComparisonImpl<OpType, OpType, Op<OpType, OpType>>::vector_constant(a, (OpType) date_time, c);
}
else
{
using OpType = B;
size_t size = a.size();
const A * a_pos = &a[0];
UInt8 * c_pos = &c[0];
const A * a_end = a_pos + size;

while (a_pos < a_end)
{
time_t date_time = dateToDateTime(*a_pos);
*c_pos = Op<OpType, OpType>::apply((OpType)date_time, b);
++a_pos;
++c_pos;
}
}
}

static void constant_vector(A a, const PaddedPODArray<B> & b, PaddedPODArray<UInt8> & c)
{
if (is_left_date)
{
// date constant with datetime vector
using OpType = B;
time_t date_time = dateToDateTime(a);
NumComparisonImpl<OpType, OpType, Op<OpType, OpType>>::constant_vector((OpType)date_time, b, c);
}
else
{
using OpType = A;
size_t size = b.size();
const B * b_pos = &b[0];
UInt8 * c_pos = &c[0];
const B * b_end = b_pos + size;

while (b_pos < b_end)
{
time_t date_time = dateToDateTime(*b_pos);
*c_pos = Op<OpType, OpType>::apply(a, (OpType)date_time);
++b_pos;
++c_pos;
}
}
}

static void constant_constant(A a, B b, UInt8 & c) {
if (is_left_date)
{
using OpType = B;
time_t date_time = dateToDateTime(a);
NumComparisonImpl<OpType, OpType, Op<OpType, OpType>>::constant_constant((OpType) date_time, b, c);
}
else
{
using OpType = A;
time_t date_time = dateToDateTime(b);
NumComparisonImpl<OpType, OpType, Op<OpType, OpType>>::constant_constant(a, (OpType) date_time, c);
}
}
};

template <typename Op>
struct StringComparisonImpl
Expand Down Expand Up @@ -1009,6 +1116,67 @@ class FunctionComparison : public IFunction
}
}

bool executeDateWithDateTimeOrDateTimeWithDate(
Block &block, size_t result, const IColumn *col_left_untyped, const IColumn *col_right_untyped,
const DataTypePtr &left_type, const DataTypePtr &right_type)
{
if ((checkDataType<DataTypeDate>(left_type.get()) && checkDataType<DataTypeDateTime>(right_type.get()))
|| (checkDataType<DataTypeDateTime>(left_type.get()) && checkDataType<DataTypeDate>(right_type.get())))
{
bool is_left_date = checkDataType<DataTypeDate>(left_type.get());
if (is_left_date)
{
return executeDateAndDateTimeCompare<UInt32, Int64, true>(block, result, col_left_untyped, col_right_untyped);
}
else
{
return executeDateAndDateTimeCompare<Int64, UInt32, false>(block, result, col_left_untyped, col_right_untyped);
}
}
return false;
}

template <typename T0, typename T1, bool is_left_date>
bool executeDateAndDateTimeCompare(Block & block, size_t result, const IColumn * c0, const IColumn * c1) {
bool c0_const = c0->isColumnConst();
bool c1_const = c1->isColumnConst();

if (c0_const && c1_const)
{
UInt8 res = 0;
DateDateTimeComparisonImpl<T0, T1, Op, is_left_date>::constant_constant(
checkAndGetColumnConst<ColumnVector<T0>>(c0)->template getValue<T0>(),
checkAndGetColumnConst<ColumnVector<T1>>(c1)-> template getValue<T1>(), res);
block.getByPosition(result).column = DataTypeUInt8().createColumnConst(c0->size(), toField(res));
}
else
{
auto c_res = ColumnUInt8::create();
ColumnUInt8::Container & vec_res = c_res->getData();
vec_res.resize(c0->size());
if (c0_const)
{
DateDateTimeComparisonImpl<T0, T1, Op, is_left_date>::constant_vector(
checkAndGetColumnConst<ColumnVector<T0>>(c0)-> template getValue<T0>(),
checkAndGetColumn<ColumnVector<T1>>(c1)->getData(), vec_res);
}
else if (c1_const)
{
DateDateTimeComparisonImpl<T0, T1, Op, is_left_date>::vector_constant(
checkAndGetColumn<ColumnVector<T0>>(c0)->getData(),
checkAndGetColumnConst<ColumnVector<T1>>(c1)-> template getValue<T1>(), vec_res);
}
else
{
DateDateTimeComparisonImpl<T0, T1, Op, true>::vector_vector(
checkAndGetColumn<ColumnVector<T0>>(c0)->getData(),
checkAndGetColumn<ColumnVector<T1>>(c1)->getData(), vec_res);
}
block.getByPosition(result).column = std::move(c_res);
}
return true;
}

public:
String getName() const override
{
Expand Down Expand Up @@ -1107,7 +1275,10 @@ class FunctionComparison : public IFunction

if (left_is_num && right_is_num)
{
if (!( executeNumLeftType<UInt8>(block, result, col_left_untyped, col_right_untyped)
if (!(executeDateWithDateTimeOrDateTimeWithDate(block, result, col_left_untyped, col_right_untyped,
col_with_type_and_name_left.type,
col_with_type_and_name_right.type)
|| executeNumLeftType<UInt8>(block, result, col_left_untyped, col_right_untyped)
|| executeNumLeftType<UInt16>(block, result, col_left_untyped, col_right_untyped)
|| executeNumLeftType<UInt32>(block, result, col_left_untyped, col_right_untyped)
|| executeNumLeftType<UInt64>(block, result, col_left_untyped, col_right_untyped)
Expand Down
32 changes: 31 additions & 1 deletion dbms/src/Storages/Transaction/Codec.cpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,14 @@
#include <Storages/Transaction/Codec.h>

#include <Common/typeid_cast.h>
#include <DataTypes/DataTypeDate.h>
#include <DataTypes/DataTypeDateTime.h>
#include <DataTypes/DataTypeNullable.h>
#include <DataTypes/IDataType.h>
#include <Storages/Transaction/DateTimeInfo.h>
#include <Storages/Transaction/TiDB.h>
#include <Storages/Transaction/TiKVVarInt.h>
#include <Storages/Transaction/TypeMapping.h>

namespace DB
{
Expand Down Expand Up @@ -296,14 +303,33 @@ inline T getFieldValue(const Field & field)
}
}

void EncodeDatum(const Field & field, TiDB::CodecFlag flag, std::stringstream & ss)
void EncodeDateTime(const Field & field, const DataTypePtr & ch_type, std::stringstream & ss)
{
UInt64 packed_value;
auto * date_type = typeid_cast<const DataTypeDate *>(ch_type.get());
if (date_type != nullptr)
{
packed_value = DateTimeInfo(getFieldValue<UInt32>(field)).packedToUInt64();
}
else
{
auto * date_time_type = typeid_cast<const DataTypeDateTime *>(ch_type.get());
packed_value = DateTimeInfo(getFieldValue<Int64>(field), date_time_type->getTimeZone()).packedToUInt64();
}
EncodeUInt<UInt64>(packed_value, ss);
}

void EncodeDatum(const Field & field, TiDB::CodecFlag flag, std::stringstream & ss, const DataTypePtr & ch_type)
{
if (field.isNull())
{
ss << UInt8(TiDB::CodecFlagNil);
return;
}
ss << UInt8(flag);
auto non_nullable_type = ch_type == nullptr
? nullptr
: (ch_type->isNullable() ? std::dynamic_pointer_cast<const DataTypeNullable>(ch_type)->getNestedType() : ch_type);
switch (flag)
{
case TiDB::CodecFlagDecimal:
Expand All @@ -313,6 +339,10 @@ void EncodeDatum(const Field & field, TiDB::CodecFlag flag, std::stringstream &
case TiDB::CodecFlagFloat:
return EncodeFloat64(getFieldValue<Float64>(field), ss);
case TiDB::CodecFlagUInt:
if (non_nullable_type != nullptr && non_nullable_type->isDateOrDateTime())
{
return EncodeDateTime(field, non_nullable_type, ss);
}
return EncodeUInt<UInt64>(getFieldValue<UInt64>(field), ss);
case TiDB::CodecFlagInt:
return EncodeInt64(getFieldValue<Int64>(field), ss);
Expand Down
3 changes: 2 additions & 1 deletion dbms/src/Storages/Transaction/Codec.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

#include <Common/Decimal.h>
#include <Core/Field.h>
#include <DataTypes/IDataType.h>
#include <IO/Endian.h>
#include <Storages/Transaction/TiDB.h>

Expand Down Expand Up @@ -62,6 +63,6 @@ void EncodeVarUInt(UInt64 num, std::stringstream & ss);

void EncodeDecimal(const Decimal & dec, std::stringstream & ss);

void EncodeDatum(const Field & field, TiDB::CodecFlag flag, std::stringstream & ss);
void EncodeDatum(const Field & field, TiDB::CodecFlag flag, std::stringstream & ss, const DataTypePtr & ch_type = nullptr);

} // namespace DB
Loading