Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
118 changes: 118 additions & 0 deletions cpp/src/arrow/compute/compute-test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

#include <cstdint>
#include <cstdlib>
#include <locale>
#include <memory>
#include <numeric>
#include <sstream>
Expand Down Expand Up @@ -769,6 +770,123 @@ TEST_F(TestCast, OffsetOutputBuffer) {
int16(), e3);
}

// Casting utf8 strings to boolean: both the word spellings (in mixed case)
// and the digit spellings must produce the same boolean column.
TEST_F(TestCast, StringToBoolean) {
  CastOptions options;

  // The second slot is null, so its string value is irrelevant.
  vector<bool> validity = {true, false, true, true, true};
  vector<bool> expected = {false, true, true, true, false};

  // Word spellings
  vector<std::string> words = {"False", "true", "true", "True", "false"};
  CheckCase<StringType, std::string, BooleanType, bool>(utf8(), words, validity,
                                                        boolean(), expected, options);

  // Digit spellings
  vector<std::string> digits = {"0", "1", "1", "1", "0"};
  CheckCase<StringType, std::string, BooleanType, bool>(utf8(), digits, validity,
                                                        boolean(), expected, options);
}

// Strings that are not an exact boolean spelling must make the cast fail.
TEST_F(TestCast, StringToBooleanErrors) {
  CastOptions options;
  vector<bool> is_valid = {true};

  // Trailing whitespace and abbreviations are both rejected.
  for (const char* rejected : {"false ", "T"}) {
    CheckFails<StringType, std::string>(utf8(), {rejected}, is_valid, boolean(),
                                        options);
  }
}

// Casting utf8 strings to every signed, unsigned and floating-point type,
// including the extreme values of the 32-bit and 64-bit integer types, and
// checking that float parsing does not depend on the global C++ locale.
TEST_F(TestCast, StringToNumber) {
  CastOptions options;

  // The second slot is null in every case below.
  vector<bool> is_valid = {true, false, true, true, true};

  // string to int
  vector<std::string> v_int = {"0", "1", "127", "-1", "0"};
  vector<int8_t> e_int8 = {0, 1, 127, -1, 0};
  vector<int16_t> e_int16 = {0, 1, 127, -1, 0};
  vector<int32_t> e_int32 = {0, 1, 127, -1, 0};
  vector<int64_t> e_int64 = {0, 1, 127, -1, 0};
  CheckCase<StringType, std::string, Int8Type, int8_t>(utf8(), v_int, is_valid, int8(),
                                                       e_int8, options);
  CheckCase<StringType, std::string, Int16Type, int16_t>(utf8(), v_int, is_valid, int16(),
                                                         e_int16, options);
  CheckCase<StringType, std::string, Int32Type, int32_t>(utf8(), v_int, is_valid, int32(),
                                                         e_int32, options);
  CheckCase<StringType, std::string, Int64Type, int64_t>(utf8(), v_int, is_valid, int64(),
                                                         e_int64, options);

  // Extreme values of int32
  v_int = {"2147483647", "0", "-2147483648", "0", "0"};
  e_int32 = {2147483647, 0, -2147483648LL, 0, 0};
  CheckCase<StringType, std::string, Int32Type, int32_t>(utf8(), v_int, is_valid, int32(),
                                                         e_int32, options);
  // Extreme values of int64 (the minimum is spelled as an expression because
  // the literal 9223372036854775808 does not fit a signed 64-bit type)
  v_int = {"9223372036854775807", "0", "-9223372036854775808", "0", "0"};
  e_int64 = {9223372036854775807LL, 0, (-9223372036854775807LL - 1), 0, 0};
  CheckCase<StringType, std::string, Int64Type, int64_t>(utf8(), v_int, is_valid, int64(),
                                                         e_int64, options);

  // string to uint
  vector<std::string> v_uint = {"0", "1", "127", "255", "0"};
  vector<uint8_t> e_uint8 = {0, 1, 127, 255, 0};
  vector<uint16_t> e_uint16 = {0, 1, 127, 255, 0};
  vector<uint32_t> e_uint32 = {0, 1, 127, 255, 0};
  vector<uint64_t> e_uint64 = {0, 1, 127, 255, 0};
  CheckCase<StringType, std::string, UInt8Type, uint8_t>(utf8(), v_uint, is_valid,
                                                         uint8(), e_uint8, options);
  CheckCase<StringType, std::string, UInt16Type, uint16_t>(utf8(), v_uint, is_valid,
                                                           uint16(), e_uint16, options);
  CheckCase<StringType, std::string, UInt32Type, uint32_t>(utf8(), v_uint, is_valid,
                                                           uint32(), e_uint32, options);
  CheckCase<StringType, std::string, UInt64Type, uint64_t>(utf8(), v_uint, is_valid,
                                                           uint64(), e_uint64, options);

  // Maximum values of uint32 and uint64
  v_uint = {"4294967295", "0", "0", "0", "0"};
  e_uint32 = {4294967295, 0, 0, 0, 0};
  CheckCase<StringType, std::string, UInt32Type, uint32_t>(utf8(), v_uint, is_valid,
                                                           uint32(), e_uint32, options);
  v_uint = {"18446744073709551615", "0", "0", "0", "0"};
  e_uint64 = {18446744073709551615ULL, 0, 0, 0, 0};
  CheckCase<StringType, std::string, UInt64Type, uint64_t>(utf8(), v_uint, is_valid,
                                                           uint64(), e_uint64, options);

  // string to float
  vector<std::string> v_float = {"0.1", "1.2", "127.3", "200.4", "0.5"};
  vector<float> e_float = {0.1f, 1.2f, 127.3f, 200.4f, 0.5f};
  vector<double> e_double = {0.1, 1.2, 127.3, 200.4, 0.5};
  CheckCase<StringType, std::string, FloatType, float>(utf8(), v_float, is_valid,
                                                       float32(), e_float, options);
  CheckCase<StringType, std::string, DoubleType, double>(utf8(), v_float, is_valid,
                                                         float64(), e_double, options);

  // Test that casting is locale-independent
  auto global_locale = std::locale();
  try {
    // French locale uses the comma as decimal point
    std::locale::global(std::locale("fr_FR.UTF-8"));
  } catch (const std::runtime_error&) {
    // Locale unavailable, ignore: the checks below then simply run again
    // under the unchanged global locale.
  }
  CheckCase<StringType, std::string, FloatType, float>(utf8(), v_float, is_valid,
                                                       float32(), e_float, options);
  CheckCase<StringType, std::string, DoubleType, double>(utf8(), v_float, is_valid,
                                                         float64(), e_double, options);
  // Restore the locale that was active when the test started
  std::locale::global(global_locale);
}

// Strings that are malformed or out of range for the target numeric type
// must make the cast fail.
TEST_F(TestCast, StringToNumberErrors) {
  CastOptions options;
  vector<bool> is_valid = {true};

  // int8: non-numeric text, trailing garbage, out-of-range values, fractions
  for (const char* rejected : {"z", "12 z", "128", "-129", "0.5"}) {
    CheckFails<StringType, std::string>(utf8(), {rejected}, is_valid, int8(), options);
  }

  // uint8: above the maximum, and negative
  for (const char* rejected : {"256", "-1"}) {
    CheckFails<StringType, std::string>(utf8(), {rejected}, is_valid, uint8(), options);
  }

  // float32: non-numeric text
  CheckFails<StringType, std::string>(utf8(), {"z"}, is_valid, float32(), options);
}

// Fixture template for the dictionary cast tests; it adds nothing of its own
// on top of TestCast and only exists to be parameterized on TestType.
template <class TestType>
class TestDictionaryCast : public TestCast {
};

Expand Down
190 changes: 190 additions & 0 deletions cpp/src/arrow/compute/kernels/cast.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,12 @@

#include "arrow/compute/kernels/cast.h"

#include <cerrno>
#include <cstdint>
#include <cstring>
#include <functional>
#include <limits>
#include <locale>
#include <memory>
#include <sstream>
#include <string>
Expand Down Expand Up @@ -727,6 +729,178 @@ struct CastFunctor<T, DictionaryType,
}
};

// ----------------------------------------------------------------------
// String to Number

// Cast a string to a number. Returns true on success, false on error.
// We rely on C++ istringstream for locale-independent parsing, which might
// not be the fastest option.

// Floating-point overload: extract a value of type T from `ibuf` into `*out`.
// Succeeds only if the extraction itself worked and it consumed the stream's
// contents up to the end (so trailing characters are an error).
template <typename T>
typename std::enable_if<std::is_floating_point<T>::value,
                        bool>::type static CastStringToNumber(std::istringstream& ibuf,
                                                              T* out) {
  if (!(ibuf >> *out)) {
    // Extraction failed (failbit or badbit is set)
    return false;
  }
  // Anything left unread means the string was not purely a number
  return ibuf.eof();
}

// For integers, not all integer widths are handled by the C++ stdlib, so
// we check for limits ourselves.

// Signed-integer overload: parse the contents of `ibuf` as a signed integer
// and store it in `*out`.
//
// The value is first extracted into a `long` (or a `long long` when T is
// wider than `long`) and then range-checked against the limits of T.
//
// Returns true only if the extraction succeeded, consumed the whole stream
// and the value fits in T.  `*out` is written only on success.
template <typename T>
typename std::enable_if<std::is_integral<T>::value && std::is_signed<T>::value,
                        bool>::type static CastStringToNumber(std::istringstream& ibuf,
                                                              T* out) {
  // Widest-needed intermediate type, selected at compile time
  using parse_type = typename std::conditional<(sizeof(T) > sizeof(long)),  // NOLINT
                                               long long, long>::type;      // NOLINT

  // Must be initialized: if the stream is already exhausted (e.g. an empty
  // string) the extraction does not run and would leave `res` indeterminate.
  parse_type res = 0;
  ibuf >> res;
  if (ibuf.fail() || !ibuf.eof()) {
    // Not a number, overflow of the intermediate type, or trailing characters
    return false;
  }
  if (res < std::numeric_limits<T>::min() || res > std::numeric_limits<T>::max()) {
    // Parsed fine but does not fit in the destination type
    return false;
  }
  *out = static_cast<T>(res);
  return true;
}

// Unsigned-integer overload: parse the contents of `ibuf` as an unsigned
// integer and store it in `*out`.
//
// The value is first extracted into an `unsigned long` (or an
// `unsigned long long` when T is wider than `unsigned long`) and then
// range-checked against the maximum of T.
//
// Returns true only if the input is non-negative, the extraction succeeded,
// consumed the whole stream and the value fits in T.  `*out` is written only
// on success.
template <typename T>
typename std::enable_if<std::is_integral<T>::value && std::is_unsigned<T>::value,
                        bool>::type static CastStringToNumber(std::istringstream& ibuf,
                                                              T* out) {
  // Widest-needed intermediate type, selected at compile time
  using parse_type =
      typename std::conditional<(sizeof(T) > sizeof(unsigned long)),  // NOLINT
                                unsigned long long,                   // NOLINT
                                unsigned long>::type;                 // NOLINT

  // Stream extraction into an unsigned type accepts a leading minus sign and
  // negates the value modulo 2^N, so "-1" would silently become the maximum
  // of the intermediate type and pass the range check when T is that wide.
  // Skip the leading whitespace (as the extraction would) and reject an
  // explicit minus sign up front.
  ibuf >> std::ws;
  if (ibuf.peek() == '-') {
    return false;
  }

  // Must be initialized: if the stream is already exhausted (e.g. an empty
  // string) the extraction does not run and would leave `res` indeterminate.
  parse_type res = 0;
  ibuf >> res;
  if (ibuf.fail() || !ibuf.eof()) {
    // Not a number, overflow of the intermediate type, or trailing characters
    return false;
  }
  if (res > std::numeric_limits<T>::max()) {
    // Parsed fine but does not fit in the destination type
    return false;
  }
  *out = static_cast<T>(res);
  return true;
}

// Cast kernel from StringType to any numeric output type O.
//
// Every non-null input string is parsed with CastStringToNumber into the
// corresponding output slot; null slots are skipped without touching the
// output.  On the first string that fails to parse, an Invalid status naming
// the offending string and the output type is set on `ctx` and the kernel
// returns immediately.
template <typename O>
struct CastFunctor<O, StringType, enable_if_number<O>> {
  void operator()(FunctionContext* ctx, const CastOptions& options,
                  const ArrayData& input, ArrayData* output) {
    using out_type = typename O::c_type;

    StringArray input_array(input.Copy());
    auto out_data = GetMutableValues<out_type>(output, 1);
    // NOTE(review): nothing in this functor reads errno afterwards (parse
    // errors are reported through the stream state) -- confirm whether this
    // reset is still needed.
    errno = 0;
    // Instantiate the stringstream outside of the loop
    std::istringstream ibuf;
    // Parse with the classic "C" locale so that the result does not depend
    // on the global locale (e.g. on its decimal separator).
    ibuf.imbue(std::locale::classic());

    for (int64_t i = 0; i < input.length; ++i, ++out_data) {
      if (input_array.IsNull(i)) {
        continue;
      }
      auto str = input_array.GetString(i);
      // Reset the error/eof flags left by the previous value before reusing
      // the stream for the next string.
      ibuf.clear();
      ibuf.str(str);
      if (!CastStringToNumber(ibuf, out_data)) {
        std::stringstream ss;
        ss << "Failed to cast String '" << str << "' into " << output->type->ToString();
        ctx->SetStatus(Status(StatusCode::Invalid, ss.str()));
        return;
      }
    }
  }
};

// ----------------------------------------------------------------------
// String to Boolean

// Helper function to cast a C string to a boolean. Returns true on success,
// false on error.

// Parse the `length` bytes starting at `s` as a boolean.
//
// Accepted spellings are exactly "0", "1", "true" and "false", the two words
// in any mix of upper and lower case.  Anything else (including surrounding
// whitespace) is an error.
//
// Returns true and stores the value in `*out` on success, false on error.
// On error `*out` may still have been written for 4- and 5-byte inputs and
// must not be relied upon.
static bool CastStringtoBoolean(const char* s, size_t length, bool* out) {
  // ASCII-only case-insensitive comparison of the first `n` bytes of `s`.
  // std::tolower is deliberately not used because it is locale-dependent.
  const auto matches = [s](const char* lower, const char* upper, size_t n) -> bool {
    for (size_t i = 0; i < n; ++i) {
      if (s[i] != lower[i] && s[i] != upper[i]) {
        return false;
      }
    }
    return true;
  };

  switch (length) {
    case 1:
      // "0" or "1"?
      if (s[0] == '0') {
        *out = false;
        return true;
      }
      if (s[0] == '1') {
        *out = true;
        return true;
      }
      return false;
    case 4:
      // "true"?
      *out = true;
      return matches("true", "TRUE", 4);
    case 5:
      // "false"?
      *out = false;
      return matches("false", "FALSE", 5);
    default:
      return false;
  }
}

template <typename O>
struct CastFunctor<O, StringType,
typename std::enable_if<std::is_same<BooleanType, O>::value>::type> {
void operator()(FunctionContext* ctx, const CastOptions& options,
const ArrayData& input, ArrayData* output) {
StringArray input_array(input.Copy());
internal::FirstTimeBitmapWriter writer(output->buffers[1]->mutable_data(),
output->offset, input.length);

for (int64_t i = 0; i < input.length; ++i) {
if (input_array.IsNull(i)) {
writer.Next();
continue;
}

int32_t length = -1;
auto str = input_array.GetValue(i, &length);
bool value;
if (!CastStringtoBoolean(reinterpret_cast<const char*>(str),
static_cast<size_t>(length), &value)) {
std::stringstream ss;
ss << "Failed to cast String '" << input_array.GetString(i) << "' into "
<< output->type->ToString();
ctx->SetStatus(Status(StatusCode::Invalid, ss.str()));
return;
}

if (value) {
writer.Set();
} else {
writer.Clear();
}
writer.Next();
}
writer.Finish();
}
};

// ----------------------------------------------------------------------

typedef std::function<void(FunctionContext*, const CastOptions& options, const ArrayData&,
Expand Down Expand Up @@ -905,6 +1079,20 @@ class CastKernel : public UnaryKernel {
FN(TimestampType, Date64Type); \
FN(TimestampType, Int64Type);

// List of the casts available from StringType: applies FN once per
// (StringType, <output type>) pair, covering string, boolean, the signed and
// unsigned integer types, float and double.
// Note that IN_TYPE is not referenced in the expansion; the input type is
// spelled out as StringType on every line.
#define STRING_CASES(FN, IN_TYPE) \
FN(StringType, StringType); \
FN(StringType, BooleanType); \
FN(StringType, UInt8Type); \
FN(StringType, Int8Type); \
FN(StringType, UInt16Type); \
FN(StringType, Int16Type); \
FN(StringType, UInt32Type); \
FN(StringType, Int32Type); \
FN(StringType, UInt64Type); \
FN(StringType, Int64Type); \
FN(StringType, FloatType); \
FN(StringType, DoubleType);

#define DICTIONARY_CASES(FN, IN_TYPE) \
FN(IN_TYPE, NullType); \
FN(IN_TYPE, Time32Type); \
Expand Down Expand Up @@ -962,6 +1150,7 @@ GET_CAST_FUNCTION(DATE64_CASES, Date64Type);
GET_CAST_FUNCTION(TIME32_CASES, Time32Type);
GET_CAST_FUNCTION(TIME64_CASES, Time64Type);
GET_CAST_FUNCTION(TIMESTAMP_CASES, TimestampType);
GET_CAST_FUNCTION(STRING_CASES, StringType);
GET_CAST_FUNCTION(DICTIONARY_CASES, DictionaryType);

#define CAST_FUNCTION_CASE(InType) \
Expand Down Expand Up @@ -1009,6 +1198,7 @@ Status GetCastFunction(const DataType& in_type, const std::shared_ptr<DataType>&
CAST_FUNCTION_CASE(Time32Type);
CAST_FUNCTION_CASE(Time64Type);
CAST_FUNCTION_CASE(TimestampType);
CAST_FUNCTION_CASE(StringType);
CAST_FUNCTION_CASE(DictionaryType);
case Type::LIST:
RETURN_NOT_OK(GetListCastFunc(in_type, out_type, options, kernel));
Expand Down