diff --git a/cpp/src/arrow/compute/cast.cc b/cpp/src/arrow/compute/cast.cc index e8bbfd34706..68a2b12379e 100644 --- a/cpp/src/arrow/compute/cast.cc +++ b/cpp/src/arrow/compute/cast.cc @@ -25,6 +25,7 @@ #include #include #include +#include #include "arrow/array.h" #include "arrow/buffer.h" @@ -68,6 +69,24 @@ namespace arrow { namespace compute { +template +inline const T* GetValuesAs(const ArrayData& data, int i) { + return reinterpret_cast(data.buffers[i]->data()) + data.offset; +} + +namespace { + +void CopyData(const Array& input, ArrayData* output) { + auto in_data = input.data(); + output->length = in_data->length; + output->null_count = input.null_count(); + output->buffers = in_data->buffers; + output->offset = in_data->offset; + output->child_data = in_data->child_data; +} + +} // namespace + // ---------------------------------------------------------------------- // Zero copy casts @@ -77,7 +96,9 @@ struct is_zero_copy_cast { }; template -struct is_zero_copy_cast::value>::type> { +struct is_zero_copy_cast< + O, I, typename std::enable_if::value && + !std::is_base_of::value>::type> { static constexpr bool value = true; }; @@ -102,10 +123,7 @@ template struct CastFunctor::value>::type> { void operator()(FunctionContext* ctx, const CastOptions& options, const Array& input, ArrayData* output) { - auto in_data = input.data(); - output->null_count = input.null_count(); - output->buffers = in_data->buffers; - output->child_data = in_data->child_data; + CopyData(input, output); } }; @@ -119,6 +137,7 @@ struct CastFunctorbuffers[1]; + DCHECK_EQ(output->offset, 0); memset(buf->mutable_data(), 0, buf->size()); } }; @@ -139,12 +158,16 @@ struct CastFunctorbuffers[1]->data(); - auto out = reinterpret_cast(output->buffers[1]->mutable_data()); constexpr auto kOne = static_cast(1); constexpr auto kZero = static_cast(0); + + auto in_data = input.data(); + internal::BitmapReader bit_reader(in_data->buffers[1]->data(), in_data->offset, + in_data->length); + auto out = reinterpret_cast(output->buffers[1]->mutable_data()); for (int64_t i = 0; i < input.length(); ++i) { - *out++ = BitUtil::GetBit(data, i) ? kOne : kZero; + *out++ = bit_reader.IsSet() ? kOne : kZero; + bit_reader.Next(); } } }; @@ -189,7 +212,9 @@ struct CastFunctor::v void operator()(FunctionContext* ctx, const CastOptions& options, const Array& input, ArrayData* output) { using in_type = typename I::c_type; - auto in_data = reinterpret_cast(input.data()->buffers[1]->data()); + DCHECK_EQ(output->offset, 0); + + const in_type* in_data = GetValuesAs(*input.data(), 1); uint8_t* out_data = reinterpret_cast(output->buffers[1]->mutable_data()); for (int64_t i = 0; i < input.length(); ++i) { BitUtil::SetBitTo(out_data, i, (*in_data++) != 0); @@ -204,12 +229,11 @@ struct CastFunctoroffset, 0); auto in_offset = input.offset(); - const auto& input_buffers = input.data()->buffers; - - auto in_data = reinterpret_cast(input_buffers[1]->data()) + in_offset; + const in_type* in_data = GetValuesAs(*input.data(), 1); auto out_data = reinterpret_cast(output->buffers[1]->mutable_data()); if (!options.allow_int_overflow) { @@ -217,14 +241,15 @@ struct CastFunctor(std::numeric_limits::min()); if (input.null_count() > 0) { - const uint8_t* is_valid = input_buffers[0]->data(); - int64_t is_valid_offset = in_offset; + internal::BitmapReader is_valid_reader(input.data()->buffers[0]->data(), + in_offset, input.length()); for (int64_t i = 0; i < input.length(); ++i) { - if (ARROW_PREDICT_FALSE(BitUtil::GetBit(is_valid, is_valid_offset++) && + if (ARROW_PREDICT_FALSE(is_valid_reader.IsSet() && (*in_data > kMax || *in_data < kMin))) { ctx->SetStatus(Status::Invalid("Integer value out of bounds")); } *out_data++ = static_cast(*in_data++); + is_valid_reader.Next(); } } else { for (int64_t i = 0; i < input.length(); ++i) { @@ -251,7 +276,7 @@ struct CastFunctor(input.data()->buffers[1]->data()); + const in_type* in_data = GetValuesAs(*input.data(), 1); auto out_data = reinterpret_cast(output->buffers[1]->mutable_data()); for (int64_t i = 0; i < input.length(); ++i) { *out_data++ = static_cast(*in_data++); @@ -259,6 +284,125 @@ struct CastFunctor +inline void ShiftTime(FunctionContext* ctx, const CastOptions& options, + const bool is_multiply, const int64_t factor, const Array& input, + ArrayData* output) { + const in_type* in_data = GetValuesAs(*input.data(), 1); + auto out_data = reinterpret_cast(output->buffers[1]->mutable_data()); + + if (is_multiply) { + for (int64_t i = 0; i < input.length(); i++) { + out_data[i] = static_cast(in_data[i] * factor); + } + } else { + if (options.allow_time_truncate) { + for (int64_t i = 0; i < input.length(); i++) { + out_data[i] = static_cast(in_data[i] / factor); + } + } else { + for (int64_t i = 0; i < input.length(); i++) { + out_data[i] = static_cast(in_data[i] / factor); + if (input.IsValid(i) && (out_data[i] * factor != in_data[i])) { + std::stringstream ss; + ss << "Casting from " << input.type()->ToString() << " to " + << output->type->ToString() << " would lose data: " << in_data[i]; + ctx->SetStatus(Status::Invalid(ss.str())); + break; + } + } + } + } +} + +namespace { + +// {is_multiply, factor} +const std::pair kTimeConversionTable[4][4] = { + {{true, 1}, {true, 1000}, {true, 1000000}, {true, 1000000000L}}, // SECOND + {{false, 1000}, {true, 1}, {true, 1000}, {true, 1000000}}, // MILLI + {{false, 1000000}, {false, 1000}, {true, 1}, {true, 1000}}, // MICRO + {{false, 1000000000L}, {false, 1000000}, {false, 1000}, {true, 1}}, // NANO +}; + +} // namespace + +template <> +struct CastFunctor { + void operator()(FunctionContext* ctx, const CastOptions& options, const Array& input, + ArrayData* output) { + // If units are the same, zero copy, otherwise convert + const auto& in_type = static_cast(*input.type()); + const auto& out_type = static_cast(*output->type); + + if (in_type.unit() == out_type.unit()) { + CopyData(input, output); + return; + } + + std::pair conversion = + kTimeConversionTable[static_cast(in_type.unit())] + [static_cast(out_type.unit())]; + + ShiftTime(ctx, options, conversion.first, conversion.second, input, + output); + } +}; + +// ---------------------------------------------------------------------- +// From one time32 or time64 to another + +template +struct CastFunctor::value && + std::is_base_of::value>::type> { + void operator()(FunctionContext* ctx, const CastOptions& options, const Array& input, + ArrayData* output) { + using in_t = typename I::c_type; + using out_t = typename O::c_type; + + // If units are the same, zero copy, otherwise convert + const auto& in_type = static_cast(*input.type()); + const auto& out_type = static_cast(*output->type); + + if (in_type.unit() == out_type.unit()) { + CopyData(input, output); + return; + } + + std::pair conversion = + kTimeConversionTable[static_cast(in_type.unit())] + [static_cast(out_type.unit())]; + + ShiftTime(ctx, options, conversion.first, conversion.second, input, + output); + } +}; + +// ---------------------------------------------------------------------- +// Between date32 and date64 + +constexpr int64_t kMillisecondsInDay = 86400000; + +template <> +struct CastFunctor { + void operator()(FunctionContext* ctx, const CastOptions& options, const Array& input, + ArrayData* output) { + ShiftTime(ctx, options, true, kMillisecondsInDay, input, output); + } +}; + +template <> +struct CastFunctor { + void operator()(FunctionContext* ctx, const CastOptions& options, const Array& input, + ArrayData* output) { + ShiftTime(ctx, options, false, kMillisecondsInDay, input, output); + } +}; + // ---------------------------------------------------------------------- // Dictionary to other things @@ -271,9 +415,8 @@ void UnpackFixedSizeBinaryDictionary(FunctionContext* ctx, const Array& indices, internal::BitmapReader valid_bits_reader(indices.null_bitmap_data(), indices.offset(), indices.length()); - const index_c_type* in = - reinterpret_cast(indices.data()->buffers[1]->data()) + - indices.offset(); + const index_c_type* in = GetValuesAs(*indices.data(), 1); + uint8_t* out = output->buffers[1]->mutable_data(); int32_t byte_width = static_cast(*output->type).byte_width(); @@ -336,9 +479,7 @@ Status UnpackBinaryDictionary(FunctionContext* ctx, const Array& indices, internal::BitmapReader valid_bits_reader(indices.null_bitmap_data(), indices.offset(), indices.length()); - const index_c_type* in = - reinterpret_cast(indices.data()->buffers[1]->data()) + - indices.offset(); + const index_c_type* in = GetValuesAs(*indices.data(), 1); for (int64_t i = 0; i < indices.length(); ++i) { if (valid_bits_reader.IsSet()) { int32_t length; @@ -409,9 +550,7 @@ void UnpackPrimitiveDictionary(const Array& indices, const c_type* dictionary, internal::BitmapReader valid_bits_reader(indices.null_bitmap_data(), indices.offset(), indices.length()); - const index_c_type* in = - reinterpret_cast(indices.data()->buffers[1]->data()) + - indices.offset(); + const index_c_type* in = GetValuesAs(*indices.data(), 1); for (int64_t i = 0; i < indices.length(); ++i) { if (valid_bits_reader.IsSet()) { out[i] = dictionary[in[i]]; @@ -436,9 +575,8 @@ struct CastFunctortype)) << "Dictionary type: " << values_type << " target type: " << (*output->type); - auto dictionary = - reinterpret_cast(type.dictionary()->data()->buffers[1]->data()) + - type.dictionary()->offset(); + const c_type* dictionary = GetValuesAs(*type.dictionary()->data(), 1); + auto out = reinterpret_cast(output->buffers[1]->mutable_data()); const Array& indices = *dict_array.indices(); switch (indices.type()->id()) { @@ -481,6 +619,9 @@ static Status AllocateIfNotPreallocated(FunctionContext* ctx, const Array& input int64_t bitmap_size = BitUtil::BytesForBits(length); RETURN_NOT_OK(ctx->Allocate(bitmap_size, &validity_bitmap)); memset(validity_bitmap->mutable_data(), 0, bitmap_size); + } else if (input.offset() != 0) { + RETURN_NOT_OK(CopyBitmap(ctx->memory_pool(), validity_bitmap->data(), input.offset(), + length, &validity_bitmap)); } if (out->buffers.size() == 2) { @@ -598,13 +739,21 @@ class CastKernel : public UnaryKernel { FN(Int64Type, Time64Type); \ FN(Int64Type, Date64Type); -#define DATE32_CASES(FN, IN_TYPE) FN(Date32Type, Date32Type); +#define DATE32_CASES(FN, IN_TYPE) \ + FN(Date32Type, Date32Type); \ + FN(Date32Type, Date64Type); -#define DATE64_CASES(FN, IN_TYPE) FN(Date64Type, Date64Type); +#define DATE64_CASES(FN, IN_TYPE) \ + FN(Date64Type, Date64Type); \ + FN(Date64Type, Date32Type); -#define TIME32_CASES(FN, IN_TYPE) FN(Time32Type, Time32Type); +#define TIME32_CASES(FN, IN_TYPE) \ + FN(Time32Type, Time32Type); \ + FN(Time32Type, Time64Type); -#define TIME64_CASES(FN, IN_TYPE) FN(Time64Type, Time64Type); +#define TIME64_CASES(FN, IN_TYPE) \ + FN(Time64Type, Time32Type); \ + FN(Time64Type, Time64Type); #define TIMESTAMP_CASES(FN, IN_TYPE) FN(TimestampType, TimestampType); diff --git a/cpp/src/arrow/compute/cast.h b/cpp/src/arrow/compute/cast.h index 7a07512b2ad..d7bde20d607 100644 --- a/cpp/src/arrow/compute/cast.h +++ b/cpp/src/arrow/compute/cast.h @@ -34,9 +34,10 @@ class FunctionContext; class UnaryKernel; struct CastOptions { - CastOptions() : allow_int_overflow(false) {} + CastOptions() : allow_int_overflow(false), allow_time_truncate(false) {} bool allow_int_overflow; + bool allow_time_truncate; }; /// \since 0.7.0 diff --git a/cpp/src/arrow/compute/compute-test.cc b/cpp/src/arrow/compute/compute-test.cc index 8a595178d05..8a7ef923b47 100644 --- a/cpp/src/arrow/compute/compute-test.cc +++ b/cpp/src/arrow/compute/compute-test.cc @@ -68,7 +68,7 @@ class TestCast : public ComputeFixture, public TestBase { const std::shared_ptr& out_type, const CastOptions& options) { std::shared_ptr result; ASSERT_OK(Cast(&ctx_, input, out_type, options, &result)); - AssertArraysEqual(expected, *result); + ASSERT_ARRAYS_EQUAL(expected, *result); } template @@ -105,6 +105,11 @@ class TestCast : public ComputeFixture, public TestBase { ArrayFromVector(out_type, out_values, &expected); } CheckPass(*input, *expected, out_type, options); + + // Check a sliced variant + if (input->length() > 1) { + CheckPass(*input->Slice(1), *expected->Slice(1), out_type, options); + } } }; @@ -270,6 +275,205 @@ TEST_F(TestCast, ToIntDowncastUnsafe) { options); } +TEST_F(TestCast, TimestampToTimestamp) { + CastOptions options; + + auto CheckTimestampCast = [this]( + const CastOptions& options, TimeUnit::type from_unit, TimeUnit::type to_unit, + const std::vector& from_values, const std::vector& to_values, + const std::vector& is_valid) { + CheckCase( + timestamp(from_unit), from_values, is_valid, timestamp(to_unit), to_values, + options); + }; + + vector is_valid = {true, false, true, true, true}; + + // Multiply promotions + vector v1 = {0, 100, 200, 1, 2}; + vector e1 = {0, 100000, 200000, 1000, 2000}; + CheckTimestampCast(options, TimeUnit::SECOND, TimeUnit::MILLI, v1, e1, is_valid); + + vector v2 = {0, 100, 200, 1, 2}; + vector e2 = {0, 100000000L, 200000000L, 1000000, 2000000}; + CheckTimestampCast(options, TimeUnit::SECOND, TimeUnit::MICRO, v2, e2, is_valid); + + vector v3 = {0, 100, 200, 1, 2}; + vector e3 = {0, 100000000000L, 200000000000L, 1000000000L, 2000000000L}; + CheckTimestampCast(options, TimeUnit::SECOND, TimeUnit::NANO, v3, e3, is_valid); + + vector v4 = {0, 100, 200, 1, 2}; + vector e4 = {0, 100000, 200000, 1000, 2000}; + CheckTimestampCast(options, TimeUnit::MILLI, TimeUnit::MICRO, v4, e4, is_valid); + + vector v5 = {0, 100, 200, 1, 2}; + vector e5 = {0, 100000000L, 200000000L, 1000000, 2000000}; + CheckTimestampCast(options, TimeUnit::MILLI, TimeUnit::NANO, v5, e5, is_valid); + + vector v6 = {0, 100, 200, 1, 2}; + vector e6 = {0, 100000, 200000, 1000, 2000}; + CheckTimestampCast(options, TimeUnit::MICRO, TimeUnit::NANO, v6, e6, is_valid); + + // Zero copy + std::shared_ptr arr; + vector v7 = {0, 70000, 2000, 1000, 0}; + ArrayFromVector(timestamp(TimeUnit::SECOND), is_valid, v7, + &arr); + CheckZeroCopy(*arr, timestamp(TimeUnit::SECOND)); + + // Divide, truncate + vector v8 = {0, 100123, 200456, 1123, 2456}; + vector e8 = {0, 100, 200, 1, 2}; + + options.allow_time_truncate = true; + CheckTimestampCast(options, TimeUnit::MILLI, TimeUnit::SECOND, v8, e8, is_valid); + CheckTimestampCast(options, TimeUnit::MICRO, TimeUnit::MILLI, v8, e8, is_valid); + CheckTimestampCast(options, TimeUnit::NANO, TimeUnit::MICRO, v8, e8, is_valid); + + vector v9 = {0, 100123000, 200456000, 1123000, 2456000}; + vector e9 = {0, 100, 200, 1, 2}; + CheckTimestampCast(options, TimeUnit::MICRO, TimeUnit::SECOND, v9, e9, is_valid); + CheckTimestampCast(options, TimeUnit::NANO, TimeUnit::MILLI, v9, e9, is_valid); + + vector v10 = {0, 100123000000L, 200456000000L, 1123000000L, 2456000000}; + vector e10 = {0, 100, 200, 1, 2}; + CheckTimestampCast(options, TimeUnit::NANO, TimeUnit::SECOND, v10, e10, is_valid); + + // Disallow truncate, failures + options.allow_time_truncate = false; + CheckFails(timestamp(TimeUnit::MILLI), v8, is_valid, + timestamp(TimeUnit::SECOND), options); + CheckFails(timestamp(TimeUnit::MICRO), v8, is_valid, + timestamp(TimeUnit::MILLI), options); + CheckFails(timestamp(TimeUnit::NANO), v8, is_valid, + timestamp(TimeUnit::MICRO), options); + CheckFails(timestamp(TimeUnit::MICRO), v9, is_valid, + timestamp(TimeUnit::SECOND), options); + CheckFails(timestamp(TimeUnit::NANO), v9, is_valid, + timestamp(TimeUnit::MILLI), options); + CheckFails(timestamp(TimeUnit::NANO), v10, is_valid, + timestamp(TimeUnit::SECOND), options); +} + +TEST_F(TestCast, TimeToTime) { + CastOptions options; + + vector is_valid = {true, false, true, true, true}; + + // Multiply promotions + vector v1 = {0, 100, 200, 1, 2}; + vector e1 = {0, 100000, 200000, 1000, 2000}; + CheckCase( + time32(TimeUnit::SECOND), v1, is_valid, time32(TimeUnit::MILLI), e1, options); + + vector v2 = {0, 100, 200, 1, 2}; + vector e2 = {0, 100000000L, 200000000L, 1000000, 2000000}; + CheckCase( + time32(TimeUnit::SECOND), v2, is_valid, time64(TimeUnit::MICRO), e2, options); + + vector v3 = {0, 100, 200, 1, 2}; + vector e3 = {0, 100000000000L, 200000000000L, 1000000000L, 2000000000L}; + CheckCase( + time32(TimeUnit::SECOND), v3, is_valid, time64(TimeUnit::NANO), e3, options); + + vector v4 = {0, 100, 200, 1, 2}; + vector e4 = {0, 100000, 200000, 1000, 2000}; + CheckCase( + time32(TimeUnit::MILLI), v4, is_valid, time64(TimeUnit::MICRO), e4, options); + + vector v5 = {0, 100, 200, 1, 2}; + vector e5 = {0, 100000000L, 200000000L, 1000000, 2000000}; + CheckCase( + time32(TimeUnit::MILLI), v5, is_valid, time64(TimeUnit::NANO), e5, options); + + vector v6 = {0, 100, 200, 1, 2}; + vector e6 = {0, 100000, 200000, 1000, 2000}; + CheckCase( + time64(TimeUnit::MICRO), v6, is_valid, time64(TimeUnit::NANO), e6, options); + + // Zero copy + std::shared_ptr arr; + vector v7 = {0, 70000, 2000, 1000, 0}; + ArrayFromVector(time64(TimeUnit::MICRO), is_valid, v7, &arr); + CheckZeroCopy(*arr, time64(TimeUnit::MICRO)); + + // Divide, truncate + vector v8 = {0, 100123, 200456, 1123, 2456}; + vector e8 = {0, 100, 200, 1, 2}; + + options.allow_time_truncate = true; + CheckCase( + time32(TimeUnit::MILLI), v8, is_valid, time32(TimeUnit::SECOND), e8, options); + CheckCase( + time64(TimeUnit::MICRO), v8, is_valid, time32(TimeUnit::MILLI), e8, options); + CheckCase( + time64(TimeUnit::NANO), v8, is_valid, time64(TimeUnit::MICRO), e8, options); + + vector v9 = {0, 100123000, 200456000, 1123000, 2456000}; + vector e9 = {0, 100, 200, 1, 2}; + CheckCase( + time64(TimeUnit::MICRO), v9, is_valid, time32(TimeUnit::SECOND), e9, options); + CheckCase( + time64(TimeUnit::NANO), v9, is_valid, time32(TimeUnit::MILLI), e9, options); + + vector v10 = {0, 100123000000L, 200456000000L, 1123000000L, 2456000000}; + vector e10 = {0, 100, 200, 1, 2}; + CheckCase( + time64(TimeUnit::NANO), v10, is_valid, time32(TimeUnit::SECOND), e10, options); + + // Disallow truncate, failures + + options.allow_time_truncate = false; + CheckFails(time32(TimeUnit::MILLI), v8, is_valid, time32(TimeUnit::SECOND), + options); + CheckFails(time64(TimeUnit::MICRO), v8, is_valid, time32(TimeUnit::MILLI), + options); + CheckFails(time64(TimeUnit::NANO), v8, is_valid, time64(TimeUnit::MICRO), + options); + CheckFails(time64(TimeUnit::MICRO), v9, is_valid, time32(TimeUnit::SECOND), + options); + CheckFails(time64(TimeUnit::NANO), v9, is_valid, time32(TimeUnit::MILLI), + options); + CheckFails(time64(TimeUnit::NANO), v10, is_valid, time32(TimeUnit::SECOND), + options); +} + +TEST_F(TestCast, DateToDate) { + CastOptions options; + + vector is_valid = {true, false, true, true, true}; + + constexpr int64_t F = 86400000; + + // Multiply promotion + vector v1 = {0, 100, 200, 1, 2}; + vector e1 = {0, 100 * F, 200 * F, F, 2 * F}; + CheckCase(date32(), v1, is_valid, date64(), + e1, options); + + // Zero copy + std::shared_ptr arr; + vector v2 = {0, 70000, 2000, 1000, 0}; + vector v3 = {0, 70000, 2000, 1000, 0}; + ArrayFromVector(date32(), is_valid, v2, &arr); + CheckZeroCopy(*arr, date32()); + + ArrayFromVector(date64(), is_valid, v3, &arr); + CheckZeroCopy(*arr, date64()); + + // Divide, truncate + vector v8 = {0, 100 * F + 123, 200 * F + 456, F + 123, 2 * F + 456}; + vector e8 = {0, 100, 200, 1, 2}; + + options.allow_time_truncate = true; + CheckCase(date64(), v8, is_valid, date32(), + e8, options); + + // Disallow truncate, failures + options.allow_time_truncate = false; + CheckFails(date64(), v8, is_valid, date32(), options); +} + TEST_F(TestCast, ToDouble) { CastOptions options; vector is_valid = {true, false, true, true, true}; @@ -335,7 +539,7 @@ TEST_F(TestCast, FromNull) { ASSERT_EQ(length, result->null_count()); // OK to look at bitmaps - AssertArraysEqual(*result, *result); + ASSERT_ARRAYS_EQUAL(*result, *result); } TEST_F(TestCast, PreallocatedMemory) { @@ -373,7 +577,7 @@ TEST_F(TestCast, PreallocatedMemory) { std::shared_ptr expected; ArrayFromVector(int64(), is_valid, e1, &expected); - AssertArraysEqual(*expected, *result); + ASSERT_ARRAYS_EQUAL(*expected, *result); } template diff --git a/cpp/src/arrow/test-util.h b/cpp/src/arrow/test-util.h index 83ebdea4a85..044fb9476ca 100644 --- a/cpp/src/arrow/test-util.h +++ b/cpp/src/arrow/test-util.h @@ -281,15 +281,20 @@ Status MakeArray(const std::vector& valid_bytes, const std::vector& return builder->Finish(out); } -void AssertArraysEqual(const Array& expected, const Array& actual) { - if (!actual.Equals(expected)) { - std::stringstream pp_result; - std::stringstream pp_expected; +#define ASSERT_ARRAYS_EQUAL(LEFT, RIGHT) \ + do { \ + if (!(LEFT).Equals((RIGHT))) { \ + std::stringstream pp_result; \ + std::stringstream pp_expected; \ + \ + EXPECT_OK(PrettyPrint(RIGHT, 0, &pp_result)); \ + EXPECT_OK(PrettyPrint(LEFT, 0, &pp_expected)); \ + FAIL() << "Got: \n" << pp_result.str() << "\nExpected: \n" << pp_expected.str(); \ + } \ + } while (false) - EXPECT_OK(PrettyPrint(actual, 0, &pp_result)); - EXPECT_OK(PrettyPrint(expected, 0, &pp_expected)); - FAIL() << "Got: \n" << pp_result.str() << "\nExpected: \n" << pp_expected.str(); - } +void AssertArraysEqual(const Array& expected, const Array& actual) { + ASSERT_ARRAYS_EQUAL(expected, actual); } #define ASSERT_BATCHES_EQUAL(LEFT, RIGHT) \ diff --git a/cpp/src/arrow/type.h b/cpp/src/arrow/type.h index 443828423e7..878fdf29efe 100644 --- a/cpp/src/arrow/type.h +++ b/cpp/src/arrow/type.h @@ -228,7 +228,11 @@ class ARROW_EXPORT FloatingPoint : public Number { virtual Precision precision() const = 0; }; -class ARROW_EXPORT NestedType : public DataType { +/// \class ParametricType +/// \brief A superclass for types having additional metadata +class ParametricType {}; + +class ARROW_EXPORT NestedType : public DataType, public ParametricType { public: using DataType::DataType; }; @@ -444,7 +448,7 @@ class ARROW_EXPORT BinaryType : public DataType, public NoExtraMeta { }; // BinaryType type is represents lists of 1-byte values. -class ARROW_EXPORT FixedSizeBinaryType : public FixedWidthType { +class ARROW_EXPORT FixedSizeBinaryType : public FixedWidthType, public ParametricType { public: static constexpr Type::type type_id = Type::FIXED_SIZE_BINARY; @@ -611,7 +615,7 @@ static inline std::ostream& operator<<(std::ostream& os, TimeUnit::type unit) { return os; } -class ARROW_EXPORT TimeType : public FixedWidthType { +class ARROW_EXPORT TimeType : public FixedWidthType, public ParametricType { public: TimeUnit::type unit() const { return unit_; } @@ -650,7 +654,7 @@ class ARROW_EXPORT Time64Type : public TimeType { std::string name() const override { return "time64"; } }; -class ARROW_EXPORT TimestampType : public FixedWidthType { +class ARROW_EXPORT TimestampType : public FixedWidthType, public ParametricType { public: using Unit = TimeUnit; diff --git a/python/pyarrow/array.pxi b/python/pyarrow/array.pxi index c596d2ad8e7..72262f0c981 100644 --- a/python/pyarrow/array.pxi +++ b/python/pyarrow/array.pxi @@ -260,8 +260,8 @@ cdef class Array: type = _ensure_type(target_type) - if not safe: - options.allow_int_overflow = 1 + options.allow_int_overflow = not safe + options.allow_time_truncate = not safe with nogil: check_status(Cast(_context(), self.ap[0], type.sp_type, diff --git a/python/pyarrow/includes/libarrow.pxd b/python/pyarrow/includes/libarrow.pxd index 0e5d4a8eddc..809bb96b7a4 100644 --- a/python/pyarrow/includes/libarrow.pxd +++ b/python/pyarrow/includes/libarrow.pxd @@ -747,6 +747,7 @@ cdef extern from "arrow/compute/api.h" namespace "arrow::compute" nogil: cdef cppclass CCastOptions" arrow::compute::CastOptions": c_bool allow_int_overflow + c_bool allow_time_truncate CStatus Cast(CFunctionContext* context, const CArray& array, const shared_ptr[CDataType]& to_type, diff --git a/python/pyarrow/tests/test_array.py b/python/pyarrow/tests/test_array.py index 418076f8196..e3a4c97567e 100644 --- a/python/pyarrow/tests/test_array.py +++ b/python/pyarrow/tests/test_array.py @@ -290,6 +290,32 @@ def test_cast_integers_unsafe(): _check_cast_case(case, safe=False) +def test_cast_timestamp_unit(): + # ARROW-1680 + val = datetime.datetime.now() + s = pd.Series([val]) + s_nyc = s.dt.tz_localize('tzlocal()').dt.tz_convert('America/New_York') + + us_with_tz = pa.timestamp('us', tz='America/New_York') + arr = pa.Array.from_pandas(s_nyc, type=us_with_tz) + + arr2 = pa.Array.from_pandas(s, type=pa.timestamp('us')) + + assert arr[0].as_py() == s_nyc[0] + assert arr2[0].as_py() == s[0] + + # Disallow truncation + arr = pa.array([123123], type='int64').cast(pa.timestamp('ms')) + expected = pa.array([123], type='int64').cast(pa.timestamp('s')) + + target = pa.timestamp('s') + with pytest.raises(ValueError): + arr.cast(target) + + result = arr.cast(target, safe=False) + assert result.equals(expected) + + def test_cast_signed_to_unsigned(): safe_cases = [ (np.array([0, 1, 2, 3], dtype='i1'), pa.uint8(),