Skip to content

Commit

Permalink
Reverting some testing changes
Browse files Browse the repository at this point in the history
  • Loading branch information
rok committed Apr 17, 2024
1 parent cfac91a commit b691a6c
Show file tree
Hide file tree
Showing 15 changed files with 110 additions and 87 deletions.
2 changes: 0 additions & 2 deletions cpp/src/arrow/acero/hash_join_node_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
#include "arrow/api.h"
#include "arrow/compute/kernels/row_encoder_internal.h"
#include "arrow/compute/kernels/test_util.h"
#include "arrow/extension/uuid.h"
#include "arrow/testing/extension_type.h"
#include "arrow/testing/gtest_util.h"
#include "arrow/testing/matchers.h"
Expand All @@ -48,7 +47,6 @@ using compute::SortIndices;
using compute::SortKey;
using compute::Take;
using compute::internal::RowEncoder;
using extension::uuid;

namespace acero {

Expand Down
4 changes: 1 addition & 3 deletions cpp/src/arrow/acero/util_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,13 @@

#include "arrow/acero/hash_join_node.h"
#include "arrow/acero/schema_util.h"
#include "arrow/testing/extension_type.h"
#include "arrow/testing/gtest_util.h"
#include "arrow/testing/matchers.h"

#include "arrow/extension/uuid.h"

using testing::Eq;

namespace arrow {
using extension::uuid;
namespace acero {

const char* kLeftSuffix = ".left";
Expand Down
20 changes: 9 additions & 11 deletions cpp/src/arrow/c/bridge_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@
#include "arrow/c/bridge.h"
#include "arrow/c/helpers.h"
#include "arrow/c/util_internal.h"
#include "arrow/extension/uuid.h"
#include "arrow/ipc/json_simple.h"
#include "arrow/memory_pool.h"
#include "arrow/testing/builder.h"
Expand All @@ -54,7 +53,6 @@

namespace arrow {

using extension::uuid;
using internal::ArrayExportGuard;
using internal::ArrayExportTraits;
using internal::ArrayStreamExportGuard;
Expand Down Expand Up @@ -2192,12 +2190,13 @@ TEST_F(TestSchemaImport, Dictionary) {
TEST_F(TestSchemaImport, UnregisteredExtension) {
FillPrimitive("w:16");
c_struct_.metadata = kEncodedUuidMetadata.c_str();
auto expected = uuid();
auto expected = fixed_size_binary(16);
CheckImport(expected);
}

TEST_F(TestSchemaImport, RegisteredExtension) {
{
ExtensionTypeGuard guard(uuid());
FillPrimitive("w:16");
c_struct_.metadata = kEncodedUuidMetadata.c_str();
auto expected = uuid();
Expand Down Expand Up @@ -2323,14 +2322,16 @@ TEST_F(TestSchemaImport, DictionaryError) {
}

TEST_F(TestSchemaImport, ExtensionError) {
ExtensionTypeGuard guard(uuid());

// Storage type doesn't match
FillPrimitive("w:15");
c_struct_.metadata = kEncodedUuidMetadata.c_str();
CheckImportError();

// Invalid serialization
std::string bogus_metadata = kEncodedUuidMetadata;
bogus_metadata[bogus_metadata.size() - 4] += 1;
bogus_metadata[bogus_metadata.size() - 5] += 1;
FillPrimitive("w:16");
c_struct_.metadata = bogus_metadata.c_str();
CheckImportError();
Expand Down Expand Up @@ -3706,11 +3707,7 @@ std::shared_ptr<Field> GetStorageWithMetadata(const std::string& field_name,
}

TEST_F(TestSchemaRoundtrip, UnregisteredExtension) {
TestWithTypeFactory(uuid, []() { return uuid(); });
TestWithTypeFactory(complex128, []() {
return struct_({::arrow::field("real", float64(), /*nullable=*/false),
::arrow::field("imag", float64(), /*nullable=*/false)});
});
TestWithTypeFactory(uuid, []() { return fixed_size_binary(16); });
TestWithTypeFactory(dict_extension_type, []() { return dictionary(int8(), utf8()); });

// Inside nested type.
Expand All @@ -3722,7 +3719,7 @@ TEST_F(TestSchemaRoundtrip, UnregisteredExtension) {
}

TEST_F(TestSchemaRoundtrip, RegisteredExtension) {
ExtensionTypeGuard guard({dict_extension_type(), complex128()});
ExtensionTypeGuard guard({uuid(), dict_extension_type(), complex128()});
TestWithTypeFactory(uuid);
TestWithTypeFactory(dict_extension_type);
TestWithTypeFactory(complex128);
Expand Down Expand Up @@ -4081,7 +4078,7 @@ TEST_F(TestArrayRoundtrip, Dictionary) {
}

TEST_F(TestArrayRoundtrip, RegisteredExtension) {
ExtensionTypeGuard guard({smallint(), complex128(), dict_extension_type()});
ExtensionTypeGuard guard({smallint(), complex128(), dict_extension_type(), uuid()});

TestWithArrayFactory(ExampleSmallint);
TestWithArrayFactory(ExampleUuid);
Expand Down Expand Up @@ -4110,6 +4107,7 @@ TEST_F(TestArrayRoundtrip, UnregisteredExtension) {
};

TestWithArrayFactory(ExampleSmallint, StorageExtractor(ExampleSmallint));
TestWithArrayFactory(ExampleUuid, StorageExtractor(ExampleUuid));
TestWithArrayFactory(ExampleComplex128, StorageExtractor(ExampleComplex128));
TestWithArrayFactory(ExampleDictExtension, StorageExtractor(ExampleDictExtension));
}
Expand Down
7 changes: 5 additions & 2 deletions cpp/src/arrow/extension_type.cc
Original file line number Diff line number Diff line change
Expand Up @@ -145,10 +145,13 @@ namespace internal {
static void CreateGlobalRegistry() {
g_registry = std::make_shared<ExtensionTypeRegistryImpl>();

std::vector<std::shared_ptr<DataType>> ext_types{::arrow::extension::uuid()};
#ifdef ARROW_JSON
ext_types.push_back(extension::fixed_shape_tensor(int64(), {}));
std::vector<std::shared_ptr<DataType>> ext_types{
extension::fixed_shape_tensor(int64(), {}), ::arrow::extension::uuid()};
#else
std::vector<std::shared_ptr<DataType>> ext_types{::arrow::extension::uuid()};
#endif

// Register canonical extension types
for (const auto& ext_type : ext_types) {
ARROW_CHECK_OK(
Expand Down
23 changes: 15 additions & 8 deletions cpp/src/arrow/extension_type_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@

#include "arrow/array/array_nested.h"
#include "arrow/array/util.h"
#include "arrow/extension/uuid.h"
#include "arrow/extension_type.h"
#include "arrow/io/memory.h"
#include "arrow/ipc/options.h"
Expand All @@ -42,8 +41,6 @@

namespace arrow {

using extension::uuid;

class Parametric1Array : public ExtensionArray {
public:
using ExtensionArray::ExtensionArray;
Expand Down Expand Up @@ -179,13 +176,22 @@ class ExtStructType : public ExtensionType {
std::string Serialize() const override { return "ext-struct-type-unique-code"; }
};

class TestExtensionType : public ::testing::Test {};
class TestExtensionType : public ::testing::Test {
public:
void SetUp() { ASSERT_OK(RegisterExtensionType(std::make_shared<ExampleUuidType>())); }

void TearDown() {
if (GetExtensionType("uuid")) {
ASSERT_OK(UnregisterExtensionType("uuid"));
}
}
};

TEST_F(TestExtensionType, ExtensionTypeTest) {
auto type_not_exist = GetExtensionType("uuid-unknown");
ASSERT_EQ(type_not_exist, nullptr);

auto registered_type = GetExtensionType("arrow.uuid");
auto registered_type = GetExtensionType("uuid");
ASSERT_NE(registered_type, nullptr);

auto type = uuid();
Expand Down Expand Up @@ -230,9 +236,10 @@ TEST_F(TestExtensionType, UnrecognizedExtension) {

ASSERT_OK_AND_ASSIGN(auto complete_ipc_stream, out_stream->Finish());

ASSERT_OK(UnregisterExtensionType("arrow.uuid"));
auto ext_metadata = key_value_metadata(
{{"ARROW:extension:name", "arrow.uuid"}, {"ARROW:extension:metadata", ""}});
ASSERT_OK(UnregisterExtensionType("uuid"));
auto ext_metadata =
key_value_metadata({{"ARROW:extension:name", "uuid"},
{"ARROW:extension:metadata", "uuid-serialized"}});
auto ext_field = field("f0", fixed_size_binary(16), true, ext_metadata);
auto batch_no_ext = RecordBatch::Make(schema({ext_field}), 4, {storage_arr});

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,7 @@ class IntegrationTestScenario : public Scenario {

Status RunClient(std::unique_ptr<FlightClient> client) override {
// Make sure the required extension types are registered.
ExtensionTypeGuard uuid_ext_guard(uuid());
ExtensionTypeGuard dict_ext_guard(dict_extension_type());

FlightDescriptor descr{FlightDescriptor::PATH, "", {FLAGS_path}};
Expand Down
39 changes: 19 additions & 20 deletions cpp/src/arrow/integration/json_integration_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@
#include "arrow/array.h"
#include "arrow/array/builder_binary.h"
#include "arrow/array/builder_primitive.h"
#include "arrow/extension/uuid.h"
#include "arrow/integration/json_integration.h"
#include "arrow/integration/json_internal.h"
#include "arrow/io/file.h"
Expand Down Expand Up @@ -68,14 +67,11 @@ DEFINE_bool(validate_times, true,

namespace arrow::internal::integration {

using extension::uuid;
using internal::TemporaryDir;
using ipc::DictionaryFieldMapper;
using ipc::DictionaryMemo;
using ipc::IpcWriteOptions;
using ipc::MetadataVersion;

namespace testing {
using ::arrow::internal::TemporaryDir;
using ::arrow::ipc::DictionaryFieldMapper;
using ::arrow::ipc::DictionaryMemo;
using ::arrow::ipc::IpcWriteOptions;
using ::arrow::ipc::MetadataVersion;

using namespace ::arrow::ipc::test; // NOLINT

Expand Down Expand Up @@ -228,7 +224,7 @@ Status RunCommand(const std::string& json_path, const std::string& arrow_path,
const std::string& command) {
// Make sure the required extension types are registered, as they will be
// referenced in test data.
ExtensionTypeGuard ext_guard({dict_extension_type()});
ExtensionTypeGuard ext_guard({uuid(), dict_extension_type()});

if (json_path == "") {
return Status::Invalid("Must specify json file name");
Expand Down Expand Up @@ -468,8 +464,8 @@ static const char* json_example2 = R"example(
"nullable": true,
"children" : [],
"metadata" : [
{"key": "ARROW:extension:name", "value": "arrow.uuid"},
{"key": "ARROW:extension:metadata", "value": ""}
{"key": "ARROW:extension:name", "value": "uuid"},
{"key": "ARROW:extension:metadata", "value": "uuid-serialized"}
]
},
{
Expand Down Expand Up @@ -1031,10 +1027,12 @@ TEST(TestJsonFileReadWrite, JsonExample1) {

TEST(TestJsonFileReadWrite, JsonExample2) {
// Example 2: two extension types (one registered, one unregistered)
auto uuid_type = arrow::extension::uuid();
auto uuid_type = uuid();
auto buffer = Buffer::Wrap(json_example2, strlen(json_example2));

{
ExtensionTypeGuard ext_guard(uuid_type);

ASSERT_OK_AND_ASSIGN(auto reader, IntegrationJsonReader::Open(buffer));
// The second field is an unregistered extension and will be read as
// its underlying storage.
Expand All @@ -1048,11 +1046,13 @@ TEST(TestJsonFileReadWrite, JsonExample2) {

auto storage_array =
ArrayFromJSON(fixed_size_binary(16), R"(["0123456789abcdef", null])");
AssertArraysEqual(*batch->column(0),
arrow::extension::UuidArray(uuid_type, storage_array));
AssertArraysEqual(*batch->column(0), ExampleUuidArray(uuid_type, storage_array));

AssertArraysEqual(*batch->column(1), NullArray(2));
}

// Should fail now that the Uuid extension is unregistered
ASSERT_RAISES(KeyError, IntegrationJsonReader::Open(buffer));
}

TEST(TestJsonFileReadWrite, JsonExample3) {
Expand Down Expand Up @@ -1119,7 +1119,7 @@ class TestJsonRoundTrip : public ::testing::TestWithParam<MakeRecordBatch*> {
};

void CheckRoundtrip(const RecordBatch& batch) {
ExtensionTypeGuard guard({dict_extension_type(), complex128()});
ExtensionTypeGuard guard({uuid(), dict_extension_type(), complex128()});

TestSchemaRoundTrip(batch.schema());

Expand Down Expand Up @@ -1176,16 +1176,16 @@ const std::vector<ipc::test::MakeRecordBatch*> kBatchCases = {
INSTANTIATE_TEST_SUITE_P(TestJsonRoundTrip, TestJsonRoundTrip,
::testing::ValuesIn(kBatchCases));

} // namespace testing
} // namespace arrow::internal::integration

int main(int argc, char** argv) {
gflags::ParseCommandLineFlags(&argc, &argv, true);

int ret = 0;

if (FLAGS_integration) {
arrow::Status result = arrow::internal::integration::testing::RunCommand(
FLAGS_json, FLAGS_arrow, FLAGS_mode);
arrow::Status result =
arrow::internal::integration::RunCommand(FLAGS_json, FLAGS_arrow, FLAGS_mode);
if (!result.ok()) {
std::cout << "Error message: " << result.ToString() << std::endl;
ret = 1;
Expand All @@ -1197,4 +1197,3 @@ int main(int argc, char** argv) {
gflags::ShutDownCommandLineFlags();
return ret;
}
} // namespace arrow::internal::integration
4 changes: 1 addition & 3 deletions cpp/src/arrow/ipc/read_write_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@
#include "arrow/array.h"
#include "arrow/array/builder_primitive.h"
#include "arrow/buffer_builder.h"
#include "arrow/extension/uuid.h"
#include "arrow/io/file.h"
#include "arrow/io/memory.h"
#include "arrow/io/test_common.h"
Expand Down Expand Up @@ -60,7 +59,6 @@

namespace arrow {

using extension::uuid;
using internal::checked_cast;
using internal::checked_pointer_cast;
using internal::TemporaryDir;
Expand Down Expand Up @@ -409,7 +407,7 @@ static int g_file_number = 0;
class ExtensionTypesMixin {
public:
// Register the extension types required to ensure roundtripping
ExtensionTypesMixin() : ext_guard_({dict_extension_type(), complex128()}) {}
ExtensionTypesMixin() : ext_guard_({uuid(), dict_extension_type(), complex128()}) {}

protected:
ExtensionTypeGuard ext_guard_;
Expand Down
7 changes: 2 additions & 5 deletions cpp/src/arrow/ipc/test_common.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
#include "arrow/array/builder_binary.h"
#include "arrow/array/builder_primitive.h"
#include "arrow/array/builder_time.h"
#include "arrow/extension/uuid.h"
#include "arrow/ipc/test_common.h"
#include "arrow/pretty_print.h"
#include "arrow/record_batch.h"
Expand All @@ -51,8 +50,6 @@

namespace arrow {

using extension::uuid;
using extension::UuidArray;
using internal::checked_cast;

namespace ipc {
Expand Down Expand Up @@ -1091,9 +1088,9 @@ Status MakeUuid(std::shared_ptr<RecordBatch>* out) {
auto f1 = field("f1", uuid_type, /*nullable=*/false);
auto schema = ::arrow::schema({f0, f1});

auto a0 = std::make_shared<UuidArray>(
auto a0 = std::make_shared<ExampleUuidArray>(
uuid_type, ArrayFromJSON(storage_type, R"(["0123456789abcdef", null])"));
auto a1 = std::make_shared<UuidArray>(
auto a1 = std::make_shared<ExampleUuidArray>(
uuid_type,
ArrayFromJSON(storage_type, R"(["ZYXWVUTSRQPONMLK", "JIHGFEDBA9876543"])"));

Expand Down
Loading

0 comments on commit b691a6c

Please sign in to comment.