// Patch overview (placed ahead of the first "diff --git" header so that no hunk
// content or hunk line count is altered).
//
// Files touched by this patch:
//   * cpp/CMakeLists.txt
//       Appends "/bigobj" to CMAKE_CXX_FLAGS when MSVC is set.
//   * cpp/src/arrow/compute/compute-test.cc
//       Adds TestCast.StringToBoolean: casts utf8 values spelled
//       "False"/"true"/"True"/"false" and "0"/"1" to boolean.
//       Adds TestCast.StringToNumber: casts utf8 values to int8..int64,
//       uint8..uint64, float32 and float64.
//       Every case uses is_valid = {true, false, true, true, true}, i.e. the
//       second slot is marked not valid.
//   * cpp/src/arrow/compute/kernels/cast.cc
//       Adds two CastStringToNumeric overloads selected via std::enable_if:
//       one returns boost::lexical_cast directly; the other feeds the
//       lexical_cast result through boost::numeric_cast (its own comment gives
//       the reason: lexical_cast does not support 8-bit int/uint).
//       Adds a string -> number CastFunctor: walks input.length slots, for a
//       null slot only advances the output pointer, otherwise stores the
//       converted value; any exception is turned into
//       Status(StatusCode::SerializationError, "Failed to cast String ...") on
//       the FunctionContext and the loop returns immediately.
//       Adds a string -> boolean CastFunctor: lower-cases the value with
//       boost::algorithm::to_lower_copy, maps "true"/"false" directly and
//       otherwise falls back to boost::lexical_cast; failures are reported the
//       same way. Bits are written through internal::BitmapWriter starting at
//       output->offset, and writer.Finish() runs only when the loop completes.
//       Registers CAST_FUNCTION_CASE(StringType) among the cast function cases.
//
// NOTE(review): in this copy the angle-bracket contents look stripped (bare
// "#include", "vector is_valid", "std::enable_if::value", "template class ...").
// Presumably an extraction artifact -- confirm against the original patch
// before applying or compiling anything from it.
// NOTE(review): the unsigned test inputs contain no negative strings and no
// test in this diff exercises the failure path; confirm how
// boost::lexical_cast treats inputs such as "-1" for unsigned targets.
diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index 52a843a7dfd..80bdf158280 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -308,6 +308,10 @@ include(ThirdpartyToolchain) set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${CXX_COMMON_FLAGS}") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${ARROW_CXXFLAGS}") +if (MSVC) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /bigobj") +endif() + if ("${COMPILER_FAMILY}" STREQUAL "clang") # Using Clang with ccache causes a bunch of spurious warnings that are # purportedly fixed in the next version of ccache. See the following for details: diff --git a/cpp/src/arrow/compute/compute-test.cc b/cpp/src/arrow/compute/compute-test.cc index e8dc2bca83d..ea927992942 100644 --- a/cpp/src/arrow/compute/compute-test.cc +++ b/cpp/src/arrow/compute/compute-test.cc @@ -769,6 +769,65 @@ TEST_F(TestCast, OffsetOutputBuffer) { int16(), e3); } +TEST_F(TestCast, StringToBoolean) { + CastOptions options; + + vector is_valid = {true, false, true, true, true}; + + vector v1 = {"False", "true", "true", "True", "false"}; + vector v2 = {"0", "1", "1", "1", "0"}; + vector e = {false, true, true, true, false}; + CheckCase(utf8(), v1, is_valid, boolean(), + e, options); + CheckCase(utf8(), v2, is_valid, boolean(), + e, options); +} + +TEST_F(TestCast, StringToNumber) { + CastOptions options; + + vector is_valid = {true, false, true, true, true}; + + // string to int + vector v_int = {"0", "1", "127", "-1", "0"}; + vector e_int8 = {0, 1, 127, -1, 0}; + vector e_int16 = {0, 1, 127, -1, 0}; + vector e_int32 = {0, 1, 127, -1, 0}; + vector e_int64 = {0, 1, 127, -1, 0}; + CheckCase(utf8(), v_int, is_valid, int8(), + e_int8, options); + CheckCase(utf8(), v_int, is_valid, int16(), + e_int16, options); + CheckCase(utf8(), v_int, is_valid, int32(), + e_int32, options); + CheckCase(utf8(), v_int, is_valid, int64(), + e_int64, options); + + // string to uint + vector v_uint = {"0", "1", "127", "255", "0"}; + vector e_uint8 = {0, 1, 127, 255, 0}; + vector e_uint16 = 
{0, 1, 127, 255, 0}; + vector e_uint32 = {0, 1, 127, 255, 0}; + vector e_uint64 = {0, 1, 127, 255, 0}; + CheckCase(utf8(), v_uint, is_valid, + uint8(), e_uint8, options); + CheckCase(utf8(), v_uint, is_valid, + uint16(), e_uint16, options); + CheckCase(utf8(), v_uint, is_valid, + uint32(), e_uint32, options); + CheckCase(utf8(), v_uint, is_valid, + uint64(), e_uint64, options); + + // string to float + vector v_float = {"0.1", "1.2", "127.3", "200.4", "0.5"}; + vector e_float = {0.1f, 1.2f, 127.3f, 200.4f, 0.5f}; + vector e_double = {0.1, 1.2, 127.3, 200.4, 0.5}; + CheckCase(utf8(), v_float, is_valid, + float32(), e_float, options); + CheckCase(utf8(), v_float, is_valid, + float64(), e_double, options); +} + template class TestDictionaryCast : public TestCast {}; diff --git a/cpp/src/arrow/compute/kernels/cast.cc b/cpp/src/arrow/compute/kernels/cast.cc index 185a966cd90..138d7e01388 100644 --- a/cpp/src/arrow/compute/kernels/cast.cc +++ b/cpp/src/arrow/compute/kernels/cast.cc @@ -17,6 +17,9 @@ #include "arrow/compute/kernels/cast.h" +#include +#include +#include #include #include #include @@ -735,6 +738,104 @@ struct CastFunctor +typename std::enable_if::value && !std::is_same::value && + !std::is_same::value, + T>::type +CastStringToNumeric(const std::string& s) { + return boost::lexical_cast(s); +} + +template +typename std::enable_if::value || std::is_same::value, + T>::type +CastStringToNumeric(const std::string& s) { + // Convert to int before casting to T + // because boost::lexical_cast does not support 8bit int/uint. 
+ return boost::numeric_cast(boost::lexical_cast(s)); +} + +template +struct CastFunctor::value>::type> { + void operator()(FunctionContext* ctx, const CastOptions& options, + const ArrayData& input, ArrayData* output) { + using out_type = typename O::c_type; + StringArray input_array(input.Copy()); + + auto out_data = GetMutableValues(output, 1); + + for (int64_t i = 0; i < input.length; ++i) { + if (input_array.IsNull(i)) { + out_data++; + continue; + } + + std::string s = input_array.GetString(i); + + try { + *out_data++ = CastStringToNumeric(s); + } catch (...) { + std::stringstream ss; + ss << "Failed to cast String '" << s << "' into " << output->type->ToString(); + ctx->SetStatus(Status(StatusCode::SerializationError, ss.str())); + return; + } + } + } +}; + +// ---------------------------------------------------------------------- +// String to Boolean + +template +struct CastFunctor::value>::type> { + void operator()(FunctionContext* ctx, const CastOptions& options, + const ArrayData& input, ArrayData* output) { + StringArray input_array(input.Copy()); + internal::BitmapWriter writer(output->buffers[1]->mutable_data(), output->offset, + input.length); + + for (int64_t i = 0; i < input.length; ++i) { + if (input_array.IsNull(i)) { + writer.Next(); + continue; + } + + auto s = input_array.GetString(i); + auto s_lower = boost::algorithm::to_lower_copy(s); + bool flag; + + if (s_lower == "true") { + flag = true; + } else if (s_lower == "false") { + flag = false; + } else { + try { + flag = boost::lexical_cast(s); + } catch (...) 
{ + std::stringstream ss; + ss << "Failed to cast String '" << s << "' into " << output->type->ToString(); + ctx->SetStatus(Status(StatusCode::SerializationError, ss.str())); + return; + } + } + + if (flag) { + writer.Set(); + } else { + writer.Clear(); + } + writer.Next(); + } + writer.Finish(); + } +}; + // ---------------------------------------------------------------------- typedef std::function& CAST_FUNCTION_CASE(Time32Type); CAST_FUNCTION_CASE(Time64Type); CAST_FUNCTION_CASE(TimestampType); + CAST_FUNCTION_CASE(StringType); CAST_FUNCTION_CASE(DictionaryType); case Type::LIST: RETURN_NOT_OK(GetListCastFunc(in_type, out_type, options, kernel));