// Patch summary (touches compute-test.cc, kernels/cast.cc, kernels/cast.h):
//  * cast.h: CastOptions gains `allow_float_truncate`, an explicit
//    CastOptions(bool safe) constructor and Safe()/Unsafe() factories.
//  * cast.cc: adds the `is_float_downcast` trait (floating-point input,
//    integer output, smaller output c_type) and a CastFunctor specialization
//    for it. With allow_float_truncate == false the functor sets
//    Status::Invalid("Floating point value truncated") when the value does
//    not survive the static_cast unchanged.
//  * compute-test.cc: renumbers the locals in the downcast tests, makes the
//    last underflow case in ToIntDowncastSafe use its own vector (v6), adds
//    TestCast.FloatingPointToInt, and enables allow_float_truncate before
//    the float64 -> integer list casts in ListToList.
//
// Change made in this revision of the patch:
//  * In the null-aware safe loop the validity bitmap reader was advanced but
//    never read, so the payload of a null slot could raise the truncation
//    error. The check is now gated on is_valid_reader.IsSet().
//
// NOTE(review): the default CastOptions() sets allow_float_truncate(true)
//   while the other two flags default to false and Safe() sets it to false --
//   confirm this asymmetry is intended.
// NOTE(review): is_float_downcast requires sizeof(O_T) < sizeof(I_T), so
//   float32 -> int32 does not reach the truncation check; the float32 case in
//   FloatingPointToInt expects truncating success under Safe() -- confirm.
// NOTE(review): this text appears to have lost its line breaks and `<...>`
//   template argument lists during extraction; they are not restored here.
diff --git a/cpp/src/arrow/compute/compute-test.cc b/cpp/src/arrow/compute/compute-test.cc index 8bf7d1de246..a1dfdefef9b 100644 --- a/cpp/src/arrow/compute/compute-test.cc +++ b/cpp/src/arrow/compute/compute-test.cc @@ -184,12 +184,6 @@ TEST_F(TestCast, ToIntUpcast) { vector e3 = {0, 100, 200, 255, 0}; CheckCase(uint8(), v3, is_valid, int16(), e3, options); - - // floating point to integer - vector v4 = {1.5, 0, 0.5, -1.5, 5.5}; - vector e4 = {1, 0, 0, -1, 5}; - CheckCase(float64(), v4, is_valid, int32(), e4, - options); } TEST_F(TestCast, OverflowInNullSlot) { @@ -218,32 +212,32 @@ TEST_F(TestCast, ToIntDowncastSafe) { vector is_valid = {true, false, true, true, true}; // int16 to uint8, no overflow/underrun - vector v5 = {0, 100, 200, 1, 2}; - vector e5 = {0, 100, 200, 1, 2}; - CheckCase(int16(), v5, is_valid, uint8(), e5, + vector v1 = {0, 100, 200, 1, 2}; + vector e1 = {0, 100, 200, 1, 2}; + CheckCase(int16(), v1, is_valid, uint8(), e1, options); // int16 to uint8, with overflow - vector v6 = {0, 100, 256, 0, 0}; - CheckFails(int16(), v6, is_valid, uint8(), options); + vector v2 = {0, 100, 256, 0, 0}; + CheckFails(int16(), v2, is_valid, uint8(), options); // underflow - vector v7 = {0, 100, -1, 0, 0}; - CheckFails(int16(), v7, is_valid, uint8(), options); + vector v3 = {0, 100, -1, 0, 0}; + CheckFails(int16(), v3, is_valid, uint8(), options); // int32 to int16, no overflow - vector v8 = {0, 1000, 2000, 1, 2}; - vector e8 = {0, 1000, 2000, 1, 2}; - CheckCase(int32(), v8, is_valid, int16(), e8, + vector v4 = {0, 1000, 2000, 1, 2}; + vector e4 = {0, 1000, 2000, 1, 2}; + CheckCase(int32(), v4, is_valid, int16(), e4, options); // int32 to int16, overflow - vector v9 = {0, 1000, 2000, 70000, 0}; - CheckFails(int32(), v9, is_valid, int16(), options); + vector v5 = {0, 1000, 2000, 70000, 0}; + CheckFails(int32(), v5, is_valid, int16(), options); // underflow - vector v10 = {0, 1000, 2000, -70000, 0}; - CheckFails(int32(), v9, is_valid, int16(), options); + vector v6 = 
{0, 1000, 2000, -70000, 0}; + CheckFails(int32(), v6, is_valid, int16(), options); } TEST_F(TestCast, ToIntDowncastUnsafe) { @@ -253,41 +247,75 @@ TEST_F(TestCast, ToIntDowncastUnsafe) { vector is_valid = {true, false, true, true, true}; // int16 to uint8, no overflow/underrun - vector v5 = {0, 100, 200, 1, 2}; - vector e5 = {0, 100, 200, 1, 2}; - CheckCase(int16(), v5, is_valid, uint8(), e5, + vector v1 = {0, 100, 200, 1, 2}; + vector e1 = {0, 100, 200, 1, 2}; + CheckCase(int16(), v1, is_valid, uint8(), e1, options); // int16 to uint8, with overflow - vector v6 = {0, 100, 256, 0, 0}; - vector e6 = {0, 100, 0, 0, 0}; - CheckCase(int16(), v6, is_valid, uint8(), e6, + vector v2 = {0, 100, 256, 0, 0}; + vector e2 = {0, 100, 0, 0, 0}; + CheckCase(int16(), v2, is_valid, uint8(), e2, options); // underflow - vector v7 = {0, 100, -1, 0, 0}; - vector e7 = {0, 100, 255, 0, 0}; - CheckCase(int16(), v7, is_valid, uint8(), e7, + vector v3 = {0, 100, -1, 0, 0}; + vector e3 = {0, 100, 255, 0, 0}; + CheckCase(int16(), v3, is_valid, uint8(), e3, options); // int32 to int16, no overflow - vector v8 = {0, 1000, 2000, 1, 2}; - vector e8 = {0, 1000, 2000, 1, 2}; - CheckCase(int32(), v8, is_valid, int16(), e8, + vector v4 = {0, 1000, 2000, 1, 2}; + vector e4 = {0, 1000, 2000, 1, 2}; + CheckCase(int32(), v4, is_valid, int16(), e4, options); // int32 to int16, overflow // TODO(wesm): do we want to allow this? we could set to null - vector v9 = {0, 1000, 2000, 70000, 0}; - vector e9 = {0, 1000, 2000, 4464, 0}; - CheckCase(int32(), v9, is_valid, int16(), e9, + vector v5 = {0, 1000, 2000, 70000, 0}; + vector e5 = {0, 1000, 2000, 4464, 0}; + CheckCase(int32(), v5, is_valid, int16(), e5, options); // underflow // TODO(wesm): do we want to allow this? 
we could set overflow to null - vector v10 = {0, 1000, 2000, -70000, 0}; - vector e10 = {0, 1000, 2000, -4464, 0}; - CheckCase(int32(), v10, is_valid, int16(), e10, + vector v6 = {0, 1000, 2000, -70000, 0}; + vector e6 = {0, 1000, 2000, -4464, 0}; + CheckCase(int32(), v6, is_valid, int16(), e6, + options); +} + +TEST_F(TestCast, FloatingPointToInt) { + auto options = CastOptions::Safe(); + + vector is_valid = {true, false, true, true, true}; + vector all_valid = {true, true, true, true, true}; + + // float32 to integer + vector v1 = {1.5, 0, 0.5, -1.5, 5.5}; + vector e1 = {1, 0, 0, -1, 5}; + CheckCase(float32(), v1, is_valid, int32(), e1, + options); + CheckCase(float32(), v1, all_valid, int32(), e1, + options); + + // float64 to integer + vector v2 = {1.0, 0, 0.0, -1.0, 5.0}; + vector e2 = {1, 0, 0, -1, 5}; + CheckCase(float64(), v2, is_valid, int32(), e2, + options); + CheckCase(float64(), v2, all_valid, int32(), e2, + options); + + vector v3 = {1.5, 0, 0.5, -1.5, 5.5}; + vector e3 = {1, 0, 0, -1, 5}; + CheckFails(float64(), v3, is_valid, int32(), options); + CheckFails(float64(), v3, all_valid, int32(), options); + + options.allow_float_truncate = true; + CheckCase(float64(), v3, is_valid, int32(), e3, + options); + CheckCase(float64(), v3, all_valid, int32(), e3, options); } @@ -982,18 +1010,14 @@ TEST_F(TestCast, ListToList) { ASSERT_OK( ListArray::FromArrays(*offsets, *float64_plain_array, pool_, &float64_list_array)); - this->CheckPass(*int32_list_array, *int64_list_array, int64_list_array->type(), - options); - this->CheckPass(*int32_list_array, *float64_list_array, float64_list_array->type(), - options); - this->CheckPass(*int64_list_array, *int32_list_array, int32_list_array->type(), - options); - this->CheckPass(*int64_list_array, *float64_list_array, float64_list_array->type(), - options); - this->CheckPass(*float64_list_array, *int32_list_array, int32_list_array->type(), - options); - this->CheckPass(*float64_list_array, *int64_list_array, 
int64_list_array->type(), - options); + CheckPass(*int32_list_array, *int64_list_array, int64_list_array->type(), options); + CheckPass(*int32_list_array, *float64_list_array, float64_list_array->type(), options); + CheckPass(*int64_list_array, *int32_list_array, int32_list_array->type(), options); + CheckPass(*int64_list_array, *float64_list_array, float64_list_array->type(), options); + + options.allow_float_truncate = true; + CheckPass(*float64_list_array, *int32_list_array, int32_list_array->type(), options); + CheckPass(*float64_list_array, *int64_list_array, int64_list_array->type(), options); } // ---------------------------------------------------------------------- diff --git a/cpp/src/arrow/compute/kernels/cast.cc b/cpp/src/arrow/compute/kernels/cast.cc index 1101ce708ad..2a0479d68b4 100644 --- a/cpp/src/arrow/compute/kernels/cast.cc +++ b/cpp/src/arrow/compute/kernels/cast.cc @@ -193,6 +193,23 @@ struct is_integer_downcast< (sizeof(O_T) < sizeof(I_T)))); }; +template +struct is_float_downcast { + static constexpr bool value = false; +}; + +template +struct is_float_downcast< + O, I, + typename std::enable_if::value && + std::is_base_of::value>::type> { + using O_T = typename O::c_type; + using I_T = typename I::c_type; + + // Smaller output size + static constexpr bool value = !std::is_same::value && (sizeof(O_T) < sizeof(I_T)); +}; + template struct CastFunctor::value && @@ -252,9 +269,54 @@ struct CastFunctor +struct CastFunctor::value>::type> { + void operator()(FunctionContext* ctx, const CastOptions& options, + const ArrayData& input, ArrayData* output) { + using in_type = typename I::c_type; + using out_type = typename O::c_type; + + auto in_offset = input.offset; + const in_type* in_data = GetValues(input, 1); + auto out_data = GetMutableValues(output, 1); + + if (options.allow_float_truncate) { + // unsafe cast + for (int64_t i = 0; i < input.length; ++i) { + *out_data++ = static_cast(*in_data++); + } + } else { + // safe cast + if 
(input.null_count != 0) { + internal::BitmapReader is_valid_reader(input.buffers[0]->data(), in_offset, + input.length); + for (int64_t i = 0; i < input.length; ++i) { + auto out_value = static_cast(*in_data); + if (ARROW_PREDICT_FALSE(is_valid_reader.IsSet() && out_value != *in_data)) { + ctx->SetStatus(Status::Invalid("Floating point value truncated")); + } + *out_data++ = out_value; + in_data++; + is_valid_reader.Next(); + } + } else { + for (int64_t i = 0; i < input.length; ++i) { + auto out_value = static_cast(*in_data); + if (ARROW_PREDICT_FALSE(out_value != *in_data)) { + ctx->SetStatus(Status::Invalid("Floating point value truncated")); + } + *out_data++ = out_value; + in_data++; + } + } + } + } +}; + template struct CastFunctor::value && + !is_float_downcast::value && !is_integer_downcast::value>::type> { void operator()(FunctionContext* ctx, const CastOptions& options, const ArrayData& input, ArrayData* output) { diff --git a/cpp/src/arrow/compute/kernels/cast.h b/cpp/src/arrow/compute/kernels/cast.h index b75bb7b6c15..8392c188dfd 100644 --- a/cpp/src/arrow/compute/kernels/cast.h +++ b/cpp/src/arrow/compute/kernels/cast.h @@ -35,10 +35,23 @@ class DataType; namespace compute { struct ARROW_EXPORT CastOptions { - CastOptions() : allow_int_overflow(false), allow_time_truncate(false) {} + CastOptions() + : allow_int_overflow(false), + allow_time_truncate(false), + allow_float_truncate(true) {} + + explicit CastOptions(bool safe) + : allow_int_overflow(!safe), + allow_time_truncate(!safe), + allow_float_truncate(!safe) {} + + static CastOptions Safe() { return CastOptions(true); } + + static CastOptions Unsafe() { return CastOptions(false); } bool allow_int_overflow; bool allow_time_truncate; + bool allow_float_truncate; }; /// \since 0.7.0