Skip to content

Commit

Permalink
fix decimal cast bug (#7026) (#7031)
Browse files Browse the repository at this point in the history
close #6994
  • Loading branch information
ti-chi-bot authored Mar 14, 2023
1 parent 531217a commit 3236b2b
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 4 deletions.
4 changes: 2 additions & 2 deletions dbms/src/Functions/FunctionsTiDBConversion.h
Original file line number Diff line number Diff line change
Expand Up @@ -943,7 +943,7 @@ struct TiDBConvertToDecimal
const Context & context)
{
using UType = typename U::NativeType;
CastInternalType value = static_cast<CastInternalType>(v.value);
auto value = static_cast<CastInternalType>(v.value);

if (v_scale < scale)
{
Expand All @@ -952,8 +952,8 @@ struct TiDBConvertToDecimal
else if (v_scale > scale)
{
context.getDAGContext()->handleTruncateError("cast decimal as decimal");
value /= scale_mul;
const bool need_to_round = ((value < 0 ? -value : value) % scale_mul) >= (scale_mul / 2);
value /= scale_mul;
if (need_to_round)
{
if (value < 0)
Expand Down
57 changes: 55 additions & 2 deletions dbms/src/Functions/tests/gtest_tidb_conversion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1283,7 +1283,7 @@ try
}
CATCH

TEST_F(TestTidbConversion, castDecimalAsReal)
TEST_F(TestTidbConversion, castDecimalAsTime)
try
{
testReturnNull<Decimal64, MyDateTime>(DecimalField64(11, 1), std::make_tuple(19, 1), 6);
Expand All @@ -1308,6 +1308,59 @@ try
}
CATCH

TEST_F(TestTidbConversion, castDecimalAsDecimalWithRound)
try
{
DAGContext * dag_context = context.getDAGContext();
UInt64 ori_flags = dag_context->getFlags();
dag_context->addFlag(TiDBSQLFlags::TRUNCATE_AS_WARNING);
dag_context->clearWarnings();

/// decimal32 to decimal32/64/128/256
ASSERT_COLUMN_EQ(createColumn<Decimal32>(std::make_tuple(5, 2), {"1.23", "1.56", "1.01", "1.00", "-1.23", "-1.56", "-1.01", "-1.00"}),
executeFunction(func_name, {createColumn<Decimal32>(std::make_tuple(5, 4), {"1.2300", "1.5600", "1.0056", "1.0023", "-1.2300", "-1.5600", "-1.0056", "-1.0023"}), createCastTypeConstColumn("Decimal(5,2)")}));
ASSERT_COLUMN_EQ(createColumn<Decimal64>(std::make_tuple(15, 2), {"1.23", "1.56", "1.01", "1.00", "-1.23", "-1.56", "-1.01", "-1.00"}),
executeFunction(func_name, {createColumn<Decimal32>(std::make_tuple(5, 4), {"1.2300", "1.5600", "1.0056", "1.0023", "-1.2300", "-1.5600", "-1.0056", "-1.0023"}), createCastTypeConstColumn("Decimal(15,2)")}));
ASSERT_COLUMN_EQ(createColumn<Decimal128>(std::make_tuple(25, 2), {"1.23", "1.56", "1.01", "1.00", "-1.23", "-1.56", "-1.01", "-1.00"}),
executeFunction(func_name, {createColumn<Decimal32>(std::make_tuple(5, 4), {"1.2300", "1.5600", "1.0056", "1.0023", "-1.2300", "-1.5600", "-1.0056", "-1.0023"}), createCastTypeConstColumn("Decimal(25,2)")}));
ASSERT_COLUMN_EQ(createColumn<Decimal256>(std::make_tuple(45, 2), {"1.23", "1.56", "1.01", "1.00", "-1.23", "-1.56", "-1.01", "-1.00"}),
executeFunction(func_name, {createColumn<Decimal32>(std::make_tuple(5, 4), {"1.2300", "1.5600", "1.0056", "1.0023", "-1.2300", "-1.5600", "-1.0056", "-1.0023"}), createCastTypeConstColumn("Decimal(45,2)")}));

/// decimal64 to decimal32/64/128/256
ASSERT_COLUMN_EQ(createColumn<Decimal32>(std::make_tuple(5, 2), {"1.23", "1.56", "1.01", "1.00", "-1.23", "-1.56", "-1.01", "-1.00"}),
executeFunction(func_name, {createColumn<Decimal64>(std::make_tuple(15, 4), {"1.2300", "1.5600", "1.0056", "1.0023", "-1.2300", "-1.5600", "-1.0056", "-1.0023"}), createCastTypeConstColumn("Decimal(5,2)")}));
ASSERT_COLUMN_EQ(createColumn<Decimal64>(std::make_tuple(15, 2), {"1.23", "1.56", "1.01", "1.00", "-1.23", "-1.56", "-1.01", "-1.00"}),
executeFunction(func_name, {createColumn<Decimal64>(std::make_tuple(15, 4), {"1.2300", "1.5600", "1.0056", "1.0023", "-1.2300", "-1.5600", "-1.0056", "-1.0023"}), createCastTypeConstColumn("Decimal(15,2)")}));
ASSERT_COLUMN_EQ(createColumn<Decimal128>(std::make_tuple(25, 2), {"1.23", "1.56", "1.01", "1.00", "-1.23", "-1.56", "-1.01", "-1.00"}),
executeFunction(func_name, {createColumn<Decimal64>(std::make_tuple(15, 4), {"1.2300", "1.5600", "1.0056", "1.0023", "-1.2300", "-1.5600", "-1.0056", "-1.0023"}), createCastTypeConstColumn("Decimal(25,2)")}));
ASSERT_COLUMN_EQ(createColumn<Decimal256>(std::make_tuple(45, 2), {"1.23", "1.56", "1.01", "1.00", "-1.23", "-1.56", "-1.01", "-1.00"}),
executeFunction(func_name, {createColumn<Decimal64>(std::make_tuple(15, 4), {"1.2300", "1.5600", "1.0056", "1.0023", "-1.2300", "-1.5600", "-1.0056", "-1.0023"}), createCastTypeConstColumn("Decimal(45,2)")}));

/// decimal128 to decimal32/64/128/256
ASSERT_COLUMN_EQ(createColumn<Decimal32>(std::make_tuple(5, 2), {"1.23", "1.56", "1.01", "1.00", "-1.23", "-1.56", "-1.01", "-1.00"}),
executeFunction(func_name, {createColumn<Decimal128>(std::make_tuple(25, 4), {"1.2300", "1.5600", "1.0056", "1.0023", "-1.2300", "-1.5600", "-1.0056", "-1.0023"}), createCastTypeConstColumn("Decimal(5,2)")}));
ASSERT_COLUMN_EQ(createColumn<Decimal64>(std::make_tuple(15, 2), {"1.23", "1.56", "1.01", "1.00", "-1.23", "-1.56", "-1.01", "-1.00"}),
executeFunction(func_name, {createColumn<Decimal128>(std::make_tuple(25, 4), {"1.2300", "1.5600", "1.0056", "1.0023", "-1.2300", "-1.5600", "-1.0056", "-1.0023"}), createCastTypeConstColumn("Decimal(15,2)")}));
ASSERT_COLUMN_EQ(createColumn<Decimal128>(std::make_tuple(25, 2), {"1.23", "1.56", "1.01", "1.00", "-1.23", "-1.56", "-1.01", "-1.00"}),
executeFunction(func_name, {createColumn<Decimal128>(std::make_tuple(25, 4), {"1.2300", "1.5600", "1.0056", "1.0023", "-1.2300", "-1.5600", "-1.0056", "-1.0023"}), createCastTypeConstColumn("Decimal(25,2)")}));
ASSERT_COLUMN_EQ(createColumn<Decimal256>(std::make_tuple(45, 2), {"1.23", "1.56", "1.01", "1.00", "-1.23", "-1.56", "-1.01", "-1.00"}),
executeFunction(func_name, {createColumn<Decimal128>(std::make_tuple(25, 4), {"1.2300", "1.5600", "1.0056", "1.0023", "-1.2300", "-1.5600", "-1.0056", "-1.0023"}), createCastTypeConstColumn("Decimal(45,2)")}));

/// decimal256 to decimal32/64/128/256
ASSERT_COLUMN_EQ(createColumn<Decimal32>(std::make_tuple(5, 2), {"1.23", "1.56", "1.01", "1.00", "-1.23", "-1.56", "-1.01", "-1.00"}),
executeFunction(func_name, {createColumn<Decimal256>(std::make_tuple(45, 4), {"1.2300", "1.5600", "1.0056", "1.0023", "-1.2300", "-1.5600", "-1.0056", "-1.0023"}), createCastTypeConstColumn("Decimal(5,2)")}));
ASSERT_COLUMN_EQ(createColumn<Decimal64>(std::make_tuple(15, 2), {"1.23", "1.56", "1.01", "1.00", "-1.23", "-1.56", "-1.01", "-1.00"}),
executeFunction(func_name, {createColumn<Decimal256>(std::make_tuple(45, 4), {"1.2300", "1.5600", "1.0056", "1.0023", "-1.2300", "-1.5600", "-1.0056", "-1.0023"}), createCastTypeConstColumn("Decimal(15,2)")}));
ASSERT_COLUMN_EQ(createColumn<Decimal128>(std::make_tuple(25, 2), {"1.23", "1.56", "1.01", "1.00", "-1.23", "-1.56", "-1.01", "-1.00"}),
executeFunction(func_name, {createColumn<Decimal256>(std::make_tuple(45, 4), {"1.2300", "1.5600", "1.0056", "1.0023", "-1.2300", "-1.5600", "-1.0056", "-1.0023"}), createCastTypeConstColumn("Decimal(25,2)")}));
ASSERT_COLUMN_EQ(createColumn<Decimal256>(std::make_tuple(45, 2), {"1.23", "1.56", "1.01", "1.00", "-1.23", "-1.56", "-1.01", "-1.00"}),
executeFunction(func_name, {createColumn<Decimal256>(std::make_tuple(45, 4), {"1.2300", "1.5600", "1.0056", "1.0023", "-1.2300", "-1.5600", "-1.0056", "-1.0023"}), createCastTypeConstColumn("Decimal(45,2)")}));

dag_context->setFlags(ori_flags);
dag_context->clearWarnings();
}
CATCH

TEST_F(TestTidbConversion, castTimeAsReal)
try
{
Expand Down Expand Up @@ -1624,7 +1677,7 @@ TEST_F(TestTidbConversion, skipCheckOverflowIntToDeciaml)
ASSERT_TRUE(FunctionTiDBCast<>::canSkipCheckOverflowForDecimal<DataTypeUInt64>(uint64_ptr, prec_decimal256, scale));
}

TEST_F(TestTidbConversion, skipCheckOverflowDecimalToDeciaml)
TEST_F(TestTidbConversion, skipCheckOverflowDecimalToDecimal)
{
DataTypePtr decimal32_ptr_8_3 = createDecimal(8, 3);
DataTypePtr decimal32_ptr_8_2 = createDecimal(8, 2);
Expand Down

0 comments on commit 3236b2b

Please sign in to comment.