diff --git a/cpp/src/arrow/compute/compute-test.cc b/cpp/src/arrow/compute/compute-test.cc index ba5c93546a6..cd4b2bb30e6 100644 --- a/cpp/src/arrow/compute/compute-test.cc +++ b/cpp/src/arrow/compute/compute-test.cc @@ -17,6 +17,7 @@ #include #include +#include #include #include #include @@ -769,6 +770,123 @@ TEST_F(TestCast, OffsetOutputBuffer) { int16(), e3); } +TEST_F(TestCast, StringToBoolean) { + CastOptions options; + + vector is_valid = {true, false, true, true, true}; + + vector v1 = {"False", "true", "true", "True", "false"}; + vector v2 = {"0", "1", "1", "1", "0"}; + vector e = {false, true, true, true, false}; + CheckCase(utf8(), v1, is_valid, boolean(), + e, options); + CheckCase(utf8(), v2, is_valid, boolean(), + e, options); +} + +TEST_F(TestCast, StringToBooleanErrors) { + CastOptions options; + + vector is_valid = {true}; + + CheckFails(utf8(), {"false "}, is_valid, boolean(), options); + CheckFails(utf8(), {"T"}, is_valid, boolean(), options); +} + +TEST_F(TestCast, StringToNumber) { + CastOptions options; + + vector is_valid = {true, false, true, true, true}; + + // string to int + vector v_int = {"0", "1", "127", "-1", "0"}; + vector e_int8 = {0, 1, 127, -1, 0}; + vector e_int16 = {0, 1, 127, -1, 0}; + vector e_int32 = {0, 1, 127, -1, 0}; + vector e_int64 = {0, 1, 127, -1, 0}; + CheckCase(utf8(), v_int, is_valid, int8(), + e_int8, options); + CheckCase(utf8(), v_int, is_valid, int16(), + e_int16, options); + CheckCase(utf8(), v_int, is_valid, int32(), + e_int32, options); + CheckCase(utf8(), v_int, is_valid, int64(), + e_int64, options); + + v_int = {"2147483647", "0", "-2147483648", "0", "0"}; + e_int32 = {2147483647, 0, -2147483648LL, 0, 0}; + CheckCase(utf8(), v_int, is_valid, int32(), + e_int32, options); + v_int = {"9223372036854775807", "0", "-9223372036854775808", "0", "0"}; + e_int64 = {9223372036854775807LL, 0, (-9223372036854775807LL - 1), 0, 0}; + CheckCase(utf8(), v_int, is_valid, int64(), + e_int64, options); + + // string to uint + vector v_uint = {"0", "1", "127", "255", "0"}; + vector e_uint8 = {0, 1, 127, 255, 0}; + vector e_uint16 = {0, 1, 127, 255, 0}; + vector e_uint32 = {0, 1, 127, 255, 0}; + vector e_uint64 = {0, 1, 127, 255, 0}; + CheckCase(utf8(), v_uint, is_valid, + uint8(), e_uint8, options); + CheckCase(utf8(), v_uint, is_valid, + uint16(), e_uint16, options); + CheckCase(utf8(), v_uint, is_valid, + uint32(), e_uint32, options); + CheckCase(utf8(), v_uint, is_valid, + uint64(), e_uint64, options); + + v_uint = {"4294967295", "0", "0", "0", "0"}; + e_uint32 = {4294967295, 0, 0, 0, 0}; + CheckCase(utf8(), v_uint, is_valid, + uint32(), e_uint32, options); + v_uint = {"18446744073709551615", "0", "0", "0", "0"}; + e_uint64 = {18446744073709551615ULL, 0, 0, 0, 0}; + CheckCase(utf8(), v_uint, is_valid, + uint64(), e_uint64, options); + + // string to float + vector v_float = {"0.1", "1.2", "127.3", "200.4", "0.5"}; + vector e_float = {0.1f, 1.2f, 127.3f, 200.4f, 0.5f}; + vector e_double = {0.1, 1.2, 127.3, 200.4, 0.5}; + CheckCase(utf8(), v_float, is_valid, + float32(), e_float, options); + CheckCase(utf8(), v_float, is_valid, + float64(), e_double, options); + + // Test that casting is locale-independent + auto global_locale = std::locale(); + try { + // French locale uses the comma as decimal point + std::locale::global(std::locale("fr_FR.UTF-8")); + } catch (std::runtime_error) { + // Locale unavailable, ignore + } + CheckCase(utf8(), v_float, is_valid, + float32(), e_float, options); + CheckCase(utf8(), v_float, is_valid, + float64(), e_double, options); + std::locale::global(global_locale); +} + +TEST_F(TestCast, StringToNumberErrors) { + CastOptions options; + + vector is_valid = {true}; + + CheckFails(utf8(), {"z"}, is_valid, int8(), options); + CheckFails(utf8(), {"12 z"}, is_valid, int8(), options); + CheckFails(utf8(), {"128"}, is_valid, int8(), options); + CheckFails(utf8(), {"-129"}, is_valid, int8(), options); + CheckFails(utf8(), {"0.5"}, is_valid, int8(), options); + + CheckFails(utf8(), {"256"}, is_valid, uint8(), options); + CheckFails(utf8(), {"-1"}, is_valid, uint8(), options); + + CheckFails(utf8(), {"z"}, is_valid, float32(), options); +} + template class TestDictionaryCast : public TestCast {}; diff --git a/cpp/src/arrow/compute/kernels/cast.cc b/cpp/src/arrow/compute/kernels/cast.cc index 39925d78358..8b14b7b5664 100644 --- a/cpp/src/arrow/compute/kernels/cast.cc +++ b/cpp/src/arrow/compute/kernels/cast.cc @@ -17,10 +17,12 @@ #include "arrow/compute/kernels/cast.h" +#include #include #include #include #include +#include #include #include #include @@ -727,6 +729,178 @@ struct CastFunctor +typename std::enable_if::value, + bool>::type static CastStringToNumber(std::istringstream& ibuf, + T* out) { + ibuf >> *out; + return !ibuf.fail() && ibuf.eof(); +} + +// For integers, not all integer widths are handled by the C++ stdlib, so +// we check for limits outselves. + +template +typename std::enable_if::value && std::is_signed::value, + bool>::type static CastStringToNumber(std::istringstream& ibuf, + T* out) { + static constexpr bool need_long_long = sizeof(T) > sizeof(long); // NOLINT + static constexpr T min_value = std::numeric_limits::min(); + static constexpr T max_value = std::numeric_limits::max(); + + if (need_long_long) { + long long res; // NOLINT + ibuf >> res; + *out = static_cast(res); // may downcast + if (res < min_value || res > max_value) { + return false; + } + } else { + long res; // NOLINT + ibuf >> res; + *out = static_cast(res); // may downcast + if (res < min_value || res > max_value) { + return false; + } + } + return !ibuf.fail() && ibuf.eof(); +} + +template +typename std::enable_if::value && std::is_unsigned::value, + bool>::type static CastStringToNumber(std::istringstream& ibuf, + T* out) { + static constexpr bool need_long_long = sizeof(T) > sizeof(unsigned long); // NOLINT + static constexpr T max_value = std::numeric_limits::max(); + + if (need_long_long) { + unsigned long long res; // NOLINT + ibuf >> res; + *out = static_cast(res); // may downcast + if (res > max_value) { + return false; + } + } else { + unsigned long res; // NOLINT + ibuf >> res; + *out = static_cast(res); // may downcast + if (res > max_value) { + return false; + } + } + return !ibuf.fail() && ibuf.eof(); +} + +template +struct CastFunctor> { + void operator()(FunctionContext* ctx, const CastOptions& options, + const ArrayData& input, ArrayData* output) { + using out_type = typename O::c_type; + + StringArray input_array(input.Copy()); + auto out_data = GetMutableValues(output, 1); + errno = 0; + // Instantiate the stringstream outside of the loop + std::istringstream ibuf; + ibuf.imbue(std::locale::classic()); + + for (int64_t i = 0; i < input.length; ++i, ++out_data) { + if (input_array.IsNull(i)) { + continue; + } + auto str = input_array.GetString(i); + ibuf.clear(); + ibuf.str(str); + if (!CastStringToNumber(ibuf, out_data)) { + std::stringstream ss; + ss << "Failed to cast String '" << str << "' into " << output->type->ToString(); + ctx->SetStatus(Status(StatusCode::Invalid, ss.str())); + return; + } + } + } +}; + +// ---------------------------------------------------------------------- +// String to Boolean + +// Helper function to cast a C string to a boolean. Returns true on success, +// false on error. + +static bool CastStringtoBoolean(const char* s, size_t length, bool* out) { + if (length == 1) { + // "0" or "1"? + if (s[0] == '0') { + *out = false; + return true; + } + if (s[0] == '1') { + *out = true; + return true; + } + return false; + } + if (length == 4) { + // "true"? + *out = true; + return ((s[0] == 't' || s[0] == 'T') && (s[1] == 'r' || s[1] == 'R') && + (s[2] == 'u' || s[2] == 'U') && (s[3] == 'e' || s[3] == 'E')); + } + if (length == 5) { + // "false"? + *out = false; + return ((s[0] == 'f' || s[0] == 'F') && (s[1] == 'a' || s[1] == 'A') && + (s[2] == 'l' || s[2] == 'L') && (s[3] == 's' || s[3] == 'S') && + (s[4] == 'e' || s[4] == 'E')); + } + return false; +} + +template +struct CastFunctor::value>::type> { + void operator()(FunctionContext* ctx, const CastOptions& options, + const ArrayData& input, ArrayData* output) { + StringArray input_array(input.Copy()); + internal::FirstTimeBitmapWriter writer(output->buffers[1]->mutable_data(), + output->offset, input.length); + + for (int64_t i = 0; i < input.length; ++i) { + if (input_array.IsNull(i)) { + writer.Next(); + continue; + } + + int32_t length = -1; + auto str = input_array.GetValue(i, &length); + bool value; + if (!CastStringtoBoolean(reinterpret_cast(str), + static_cast(length), &value)) { + std::stringstream ss; + ss << "Failed to cast String '" << input_array.GetString(i) << "' into " + << output->type->ToString(); + ctx->SetStatus(Status(StatusCode::Invalid, ss.str())); + return; + } + + if (value) { + writer.Set(); + } else { + writer.Clear(); + } + writer.Next(); + } + writer.Finish(); + } +}; + // ---------------------------------------------------------------------- typedef std::function& CAST_FUNCTION_CASE(Time32Type); CAST_FUNCTION_CASE(Time64Type); CAST_FUNCTION_CASE(TimestampType); + CAST_FUNCTION_CASE(StringType); CAST_FUNCTION_CASE(DictionaryType); case Type::LIST: RETURN_NOT_OK(GetListCastFunc(in_type, out_type, options, kernel));