Skip to content

Commit c84aac8

Browse files
pitrouwesm
authored andcommitted
ARROW-1491: [C++] Add casting from strings to numbers and booleans
The implementation for numbers uses the C standard strto* functions. This makes casting a bit lenient (it will accept whitespace).
1 parent edfbf84 commit c84aac8

File tree

2 files changed

+306
-0
lines changed

2 files changed

+306
-0
lines changed

cpp/src/arrow/compute/compute-test.cc

Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -769,6 +769,109 @@ TEST_F(TestCast, OffsetOutputBuffer) {
769769
int16(), e3);
770770
}
771771

772+
TEST_F(TestCast, StringToBoolean) {
773+
CastOptions options;
774+
775+
vector<bool> is_valid = {true, false, true, true, true};
776+
777+
vector<std::string> v1 = {"False", "true", "true", "True", "false"};
778+
vector<std::string> v2 = {"0", "1", "1", "1", "0"};
779+
vector<bool> e = {false, true, true, true, false};
780+
CheckCase<StringType, std::string, BooleanType, bool>(utf8(), v1, is_valid, boolean(),
781+
e, options);
782+
CheckCase<StringType, std::string, BooleanType, bool>(utf8(), v2, is_valid, boolean(),
783+
e, options);
784+
}
785+
786+
TEST_F(TestCast, StringToBooleanErrors) {
787+
CastOptions options;
788+
789+
vector<bool> is_valid = {true};
790+
791+
CheckFails<StringType, std::string>(utf8(), {"false "}, is_valid, boolean(), options);
792+
CheckFails<StringType, std::string>(utf8(), {"T"}, is_valid, boolean(), options);
793+
}
794+
795+
TEST_F(TestCast, StringToNumber) {
796+
CastOptions options;
797+
798+
vector<bool> is_valid = {true, false, true, true, true};
799+
800+
// string to int
801+
vector<std::string> v_int = {"0", "1", "127", "-1", "0"};
802+
vector<int8_t> e_int8 = {0, 1, 127, -1, 0};
803+
vector<int16_t> e_int16 = {0, 1, 127, -1, 0};
804+
vector<int32_t> e_int32 = {0, 1, 127, -1, 0};
805+
vector<int64_t> e_int64 = {0, 1, 127, -1, 0};
806+
CheckCase<StringType, std::string, Int8Type, int8_t>(utf8(), v_int, is_valid, int8(),
807+
e_int8, options);
808+
CheckCase<StringType, std::string, Int16Type, int16_t>(utf8(), v_int, is_valid, int16(),
809+
e_int16, options);
810+
CheckCase<StringType, std::string, Int32Type, int32_t>(utf8(), v_int, is_valid, int32(),
811+
e_int32, options);
812+
CheckCase<StringType, std::string, Int64Type, int64_t>(utf8(), v_int, is_valid, int64(),
813+
e_int64, options);
814+
815+
v_int = {"2147483647", "0", "-2147483648", "0", "0"};
816+
e_int32 = {2147483647, 0, -2147483648LL, 0, 0};
817+
CheckCase<StringType, std::string, Int32Type, int32_t>(utf8(), v_int, is_valid, int32(),
818+
e_int32, options);
819+
v_int = {"9223372036854775807", "0", "-9223372036854775808", "0", "0"};
820+
e_int64 = {9223372036854775807LL, 0, (-9223372036854775807LL - 1), 0, 0};
821+
CheckCase<StringType, std::string, Int64Type, int64_t>(utf8(), v_int, is_valid, int64(),
822+
e_int64, options);
823+
824+
// string to uint
825+
vector<std::string> v_uint = {"0", "1", "127", "255", "0"};
826+
vector<uint8_t> e_uint8 = {0, 1, 127, 255, 0};
827+
vector<uint16_t> e_uint16 = {0, 1, 127, 255, 0};
828+
vector<uint32_t> e_uint32 = {0, 1, 127, 255, 0};
829+
vector<uint64_t> e_uint64 = {0, 1, 127, 255, 0};
830+
CheckCase<StringType, std::string, UInt8Type, uint8_t>(utf8(), v_uint, is_valid,
831+
uint8(), e_uint8, options);
832+
CheckCase<StringType, std::string, UInt16Type, uint16_t>(utf8(), v_uint, is_valid,
833+
uint16(), e_uint16, options);
834+
CheckCase<StringType, std::string, UInt32Type, uint32_t>(utf8(), v_uint, is_valid,
835+
uint32(), e_uint32, options);
836+
CheckCase<StringType, std::string, UInt64Type, uint64_t>(utf8(), v_uint, is_valid,
837+
uint64(), e_uint64, options);
838+
839+
v_uint = {"4294967295", "0", "0", "0", "0"};
840+
e_uint32 = {4294967295, 0, 0, 0, 0};
841+
CheckCase<StringType, std::string, UInt32Type, uint32_t>(utf8(), v_uint, is_valid,
842+
uint32(), e_uint32, options);
843+
v_uint = {"18446744073709551615", "0", "0", "0", "0"};
844+
e_uint64 = {18446744073709551615ULL, 0, 0, 0, 0};
845+
CheckCase<StringType, std::string, UInt64Type, uint64_t>(utf8(), v_uint, is_valid,
846+
uint64(), e_uint64, options);
847+
848+
// string to float
849+
vector<std::string> v_float = {"0.1", "1.2", "127.3", "200.4", "0.5"};
850+
vector<float> e_float = {0.1f, 1.2f, 127.3f, 200.4f, 0.5f};
851+
vector<double> e_double = {0.1, 1.2, 127.3, 200.4, 0.5};
852+
CheckCase<StringType, std::string, FloatType, float>(utf8(), v_float, is_valid,
853+
float32(), e_float, options);
854+
CheckCase<StringType, std::string, DoubleType, double>(utf8(), v_float, is_valid,
855+
float64(), e_double, options);
856+
}
857+
858+
TEST_F(TestCast, StringToNumberErrors) {
859+
CastOptions options;
860+
861+
vector<bool> is_valid = {true};
862+
863+
CheckFails<StringType, std::string>(utf8(), {"z"}, is_valid, int8(), options);
864+
CheckFails<StringType, std::string>(utf8(), {"12 z"}, is_valid, int8(), options);
865+
CheckFails<StringType, std::string>(utf8(), {"128"}, is_valid, int8(), options);
866+
CheckFails<StringType, std::string>(utf8(), {"-129"}, is_valid, int8(), options);
867+
CheckFails<StringType, std::string>(utf8(), {"0.5"}, is_valid, int8(), options);
868+
869+
CheckFails<StringType, std::string>(utf8(), {"256"}, is_valid, uint8(), options);
870+
CheckFails<StringType, std::string>(utf8(), {"-1"}, is_valid, uint8(), options);
871+
872+
CheckFails<StringType, std::string>(utf8(), {"z"}, is_valid, float32(), options);
873+
}
874+
772875
template <typename TestType>
773876
class TestDictionaryCast : public TestCast {};
774877

cpp/src/arrow/compute/kernels/cast.cc

Lines changed: 203 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,9 @@
1717

1818
#include "arrow/compute/kernels/cast.h"
1919

20+
#include <cerrno>
2021
#include <cstdint>
22+
#include <cstdlib>
2123
#include <cstring>
2224
#include <functional>
2325
#include <limits>
@@ -727,6 +729,191 @@ struct CastFunctor<T, DictionaryType,
727729
}
728730
};
729731

732+
// ----------------------------------------------------------------------
733+
// String to Number
734+
735+
// Polymorphic wrapper around strtof() and friends
736+
static void StringToFloat(const char* str, char** str_end, float* out) {
737+
*out = strtof(str, str_end);
738+
}
739+
740+
static void StringToFloat(const char* str, char** str_end, double* out) {
741+
*out = strtod(str, str_end);
742+
}
743+
744+
// Function to cast a C string to a number. Returns true on success,
745+
// false on error.
746+
747+
template <typename T>
748+
typename std::enable_if<std::is_floating_point<T>::value,
749+
bool>::type static CastStringToNumber(const char* str,
750+
size_t length, T* out) {
751+
// Need a null-terminated copy to pass to the C library converters
752+
std::string null_terminated(str, length);
753+
str = null_terminated.data();
754+
char* str_end;
755+
StringToFloat(str, &str_end, out);
756+
return (errno == 0 && static_cast<size_t>(str_end - str) == length);
757+
}
758+
759+
template <typename T>
760+
typename std::enable_if<std::is_integral<T>::value && std::is_signed<T>::value,
761+
bool>::type static CastStringToNumber(const char* str,
762+
size_t length, T* out) {
763+
static constexpr bool need_long_long = sizeof(T) > sizeof(long); // NOLINT
764+
static constexpr T min_value = std::numeric_limits<T>::min();
765+
static constexpr T max_value = std::numeric_limits<T>::max();
766+
767+
// Need a null-terminated copy to pass to the C library converters
768+
std::string null_terminated(str, length);
769+
str = null_terminated.data();
770+
char* str_end;
771+
if (need_long_long) {
772+
auto res = std::strtoll(str, &str_end, 10);
773+
*out = static_cast<T>(res); // may downcast
774+
if (res < min_value || res > max_value) {
775+
return false;
776+
}
777+
} else {
778+
auto res = std::strtol(str, &str_end, 10);
779+
*out = static_cast<T>(res); // may downcast
780+
if (res < min_value || res > max_value) {
781+
return false;
782+
}
783+
}
784+
return (errno == 0 && static_cast<size_t>(str_end - str) == length);
785+
}
786+
787+
template <typename T>
788+
typename std::enable_if<std::is_integral<T>::value && std::is_unsigned<T>::value,
789+
bool>::type static CastStringToNumber(const char* str,
790+
size_t length, T* out) {
791+
static constexpr bool need_long_long = sizeof(T) > sizeof(unsigned long); // NOLINT
792+
static constexpr T max_value = std::numeric_limits<T>::max();
793+
794+
// Need a null-terminated copy to pass to the C library converters
795+
std::string null_terminated(str, length);
796+
str = null_terminated.data();
797+
char* str_end;
798+
if (need_long_long) {
799+
auto res = std::strtoull(str, &str_end, 10);
800+
*out = static_cast<T>(res); // may downcast
801+
if (res > max_value) {
802+
return false;
803+
}
804+
} else {
805+
auto res = std::strtoul(str, &str_end, 10);
806+
*out = static_cast<T>(res); // may downcast
807+
if (res > max_value) {
808+
return false;
809+
}
810+
}
811+
return (errno == 0 && static_cast<size_t>(str_end - str) == length);
812+
}
813+
814+
template <typename O>
815+
struct CastFunctor<O, StringType, typename std::enable_if<is_number<O>::value>::type> {
816+
void operator()(FunctionContext* ctx, const CastOptions& options,
817+
const ArrayData& input, ArrayData* output) {
818+
using out_type = typename O::c_type;
819+
820+
StringArray input_array(input.Copy());
821+
auto out_data = GetMutableValues<out_type>(output, 1);
822+
errno = 0;
823+
824+
for (int64_t i = 0; i < input.length; ++i) {
825+
if (input_array.IsNull(i)) {
826+
out_data++;
827+
continue;
828+
}
829+
int32_t length = -1;
830+
auto str = input_array.GetValue(i, &length);
831+
if (!CastStringToNumber(reinterpret_cast<const char*>(str),
832+
static_cast<size_t>(length), out_data)) {
833+
std::stringstream ss;
834+
ss << "Failed to cast String '" << input_array.GetString(i) << "' into "
835+
<< output->type->ToString();
836+
ctx->SetStatus(Status(StatusCode::Invalid, ss.str()));
837+
return;
838+
}
839+
++out_data;
840+
}
841+
}
842+
};
843+
844+
// ----------------------------------------------------------------------
845+
// String to Boolean
846+
847+
// Helper function to cast a C string to a boolean. Returns true on success,
848+
// false on error.
849+
850+
static bool CastStringtoBoolean(const char* s, size_t length, bool* out) {
851+
if (length == 1) {
852+
// "0" or "1"?
853+
if (s[0] == '0') {
854+
*out = false;
855+
return true;
856+
}
857+
if (s[0] == '1') {
858+
*out = true;
859+
return true;
860+
}
861+
return false;
862+
}
863+
if (length == 4) {
864+
// "true"?
865+
*out = true;
866+
return ((s[0] == 't' || s[0] == 'T') && (s[1] == 'r' || s[1] == 'R') &&
867+
(s[2] == 'u' || s[2] == 'U') && (s[3] == 'e' || s[3] == 'E'));
868+
}
869+
if (length == 5) {
870+
// "false"?
871+
*out = false;
872+
return ((s[0] == 'f' || s[0] == 'F') && (s[1] == 'a' || s[1] == 'A') &&
873+
(s[2] == 'l' || s[2] == 'L') && (s[3] == 's' || s[3] == 'S') &&
874+
(s[4] == 'e' || s[4] == 'E'));
875+
}
876+
return false;
877+
}
878+
879+
template <typename O>
880+
struct CastFunctor<O, StringType,
881+
typename std::enable_if<std::is_same<BooleanType, O>::value>::type> {
882+
void operator()(FunctionContext* ctx, const CastOptions& options,
883+
const ArrayData& input, ArrayData* output) {
884+
StringArray input_array(input.Copy());
885+
internal::FirstTimeBitmapWriter writer(output->buffers[1]->mutable_data(),
886+
output->offset, input.length);
887+
888+
for (int64_t i = 0; i < input.length; ++i) {
889+
if (input_array.IsNull(i)) {
890+
writer.Next();
891+
continue;
892+
}
893+
894+
int32_t length = -1;
895+
auto str = input_array.GetValue(i, &length);
896+
bool value;
897+
if (!CastStringtoBoolean(reinterpret_cast<const char*>(str),
898+
static_cast<size_t>(length), &value)) {
899+
std::stringstream ss;
900+
ss << "Failed to cast String '" << input_array.GetString(i) << "' into "
901+
<< output->type->ToString();
902+
ctx->SetStatus(Status(StatusCode::Invalid, ss.str()));
903+
return;
904+
}
905+
906+
if (value) {
907+
writer.Set();
908+
} else {
909+
writer.Clear();
910+
}
911+
writer.Next();
912+
}
913+
writer.Finish();
914+
}
915+
};
916+
730917
// ----------------------------------------------------------------------
731918

732919
typedef std::function<void(FunctionContext*, const CastOptions& options, const ArrayData&,
@@ -905,6 +1092,20 @@ class CastKernel : public UnaryKernel {
9051092
FN(TimestampType, Date64Type); \
9061093
FN(TimestampType, Int64Type);
9071094

1095+
#define STRING_CASES(FN, IN_TYPE) \
1096+
FN(StringType, StringType); \
1097+
FN(StringType, BooleanType); \
1098+
FN(StringType, UInt8Type); \
1099+
FN(StringType, Int8Type); \
1100+
FN(StringType, UInt16Type); \
1101+
FN(StringType, Int16Type); \
1102+
FN(StringType, UInt32Type); \
1103+
FN(StringType, Int32Type); \
1104+
FN(StringType, UInt64Type); \
1105+
FN(StringType, Int64Type); \
1106+
FN(StringType, FloatType); \
1107+
FN(StringType, DoubleType);
1108+
9081109
#define DICTIONARY_CASES(FN, IN_TYPE) \
9091110
FN(IN_TYPE, NullType); \
9101111
FN(IN_TYPE, Time32Type); \
@@ -962,6 +1163,7 @@ GET_CAST_FUNCTION(DATE64_CASES, Date64Type);
9621163
GET_CAST_FUNCTION(TIME32_CASES, Time32Type);
9631164
GET_CAST_FUNCTION(TIME64_CASES, Time64Type);
9641165
GET_CAST_FUNCTION(TIMESTAMP_CASES, TimestampType);
1166+
GET_CAST_FUNCTION(STRING_CASES, StringType);
9651167
GET_CAST_FUNCTION(DICTIONARY_CASES, DictionaryType);
9661168

9671169
#define CAST_FUNCTION_CASE(InType) \
@@ -1009,6 +1211,7 @@ Status GetCastFunction(const DataType& in_type, const std::shared_ptr<DataType>&
10091211
CAST_FUNCTION_CASE(Time32Type);
10101212
CAST_FUNCTION_CASE(Time64Type);
10111213
CAST_FUNCTION_CASE(TimestampType);
1214+
CAST_FUNCTION_CASE(StringType);
10121215
CAST_FUNCTION_CASE(DictionaryType);
10131216
case Type::LIST:
10141217
RETURN_NOT_OK(GetListCastFunc(in_type, out_type, options, kernel));

0 commit comments

Comments
 (0)