From e178c00476c6d4543796fae85f0008632a7674b0 Mon Sep 17 00:00:00 2001 From: Hongze Zhang Date: Thu, 15 Jul 2021 15:54:34 +0800 Subject: [PATCH] Create utility dir arrow/jniutil (#27) * Create utility dir arrow/jniutil * fix --- cpp/cmake_modules/DefineOptions.cmake | 2 + cpp/src/arrow/CMakeLists.txt | 18 +- cpp/src/arrow/jniutil/CMakeLists.txt | 31 ++ cpp/src/arrow/jniutil/jni_util.cc | 432 ++++++++++++++++++ .../{jni/dataset => arrow/jniutil}/jni_util.h | 18 +- .../jniutil}/jni_util_test.cc | 8 +- cpp/src/jni/dataset/CMakeLists.txt | 11 +- cpp/src/jni/dataset/jni_util.cc | 242 ---------- cpp/src/jni/dataset/jni_wrapper.cc | 160 +++---- java/dataset/CMakeLists.txt | 1 - java/dataset/pom.xml | 10 + .../apache/arrow/dataset/file/JniWrapper.java | 16 + .../apache/arrow/dataset/jni/JniWrapper.java | 7 +- .../dataset/jni/NativeRecordBatchHandle.java | 106 ----- .../arrow/dataset/jni/NativeScanner.java | 36 +- .../jni/UnsafeRecordBatchSerializer.java | 266 +++++++++++ 16 files changed, 864 insertions(+), 500 deletions(-) create mode 100644 cpp/src/arrow/jniutil/CMakeLists.txt create mode 100644 cpp/src/arrow/jniutil/jni_util.cc rename cpp/src/{jni/dataset => arrow/jniutil}/jni_util.h (87%) rename cpp/src/{jni/dataset => arrow/jniutil}/jni_util_test.cc (96%) delete mode 100644 cpp/src/jni/dataset/jni_util.cc delete mode 100644 java/dataset/src/main/java/org/apache/arrow/dataset/jni/NativeRecordBatchHandle.java create mode 100644 java/dataset/src/main/java/org/apache/arrow/dataset/jni/UnsafeRecordBatchSerializer.java diff --git a/cpp/cmake_modules/DefineOptions.cmake b/cpp/cmake_modules/DefineOptions.cmake index ec1e0b6352a36..1da851514c07a 100644 --- a/cpp/cmake_modules/DefineOptions.cmake +++ b/cpp/cmake_modules/DefineOptions.cmake @@ -260,6 +260,8 @@ if("${CMAKE_SOURCE_DIR}" STREQUAL "${CMAKE_CURRENT_SOURCE_DIR}") define_option(ARROW_JNI "Build the Arrow JNI lib" OFF) + define_option(ARROW_JNIUTIL "Build Arrow JNI utilities" ON) + define_option(ARROW_JSON "Build Arrow 
with JSON support (requires RapidJSON)" OFF) define_option(ARROW_MIMALLOC "Build the Arrow mimalloc-based allocator" OFF) diff --git a/cpp/src/arrow/CMakeLists.txt b/cpp/src/arrow/CMakeLists.txt index 2933457287407..7e852ab67d17d 100644 --- a/cpp/src/arrow/CMakeLists.txt +++ b/cpp/src/arrow/CMakeLists.txt @@ -488,6 +488,14 @@ if(ARROW_FILESYSTEM) list(APPEND ARROW_TESTING_SRCS filesystem/test_util.cc) endif() +if(ARROW_JNIUTIL) + find_package(JNI REQUIRED) + list(APPEND ARROW_SRCS + jniutil/jni_util.cc) + + set(ARROW_PRIVATE_INCLUDES ${ARROW_PRIVATE_INCLUDES} ${JNI_INCLUDE_DIRS}) +endif() + if(ARROW_IPC) list(APPEND ARROW_SRCS @@ -571,7 +579,11 @@ add_arrow_lib(arrow ${ARROW_STATIC_LINK_LIBS} ${ARROW_STATIC_INSTALL_INTERFACE_LIBS} SHARED_INSTALL_INTERFACE_LIBS - ${ARROW_SHARED_INSTALL_INTERFACE_LIBS}) + ${ARROW_SHARED_INSTALL_INTERFACE_LIBS} + STATIC_INSTALL_INTERFACE_LIBS + ${ARROW_STATIC_INSTALL_INTERFACE_LIBS} + PRIVATE_INCLUDES + ${JNI_INCLUDE_DIRS}) add_dependencies(arrow ${ARROW_LIBRARIES}) @@ -772,6 +784,10 @@ if(ARROW_IPC) add_subdirectory(ipc) endif() +if(ARROW_JNIUTIL) + add_subdirectory(jniutil) +endif() + if(ARROW_JSON) add_subdirectory(json) endif() diff --git a/cpp/src/arrow/jniutil/CMakeLists.txt b/cpp/src/arrow/jniutil/CMakeLists.txt new file mode 100644 index 0000000000000..d1a11fdf18777 --- /dev/null +++ b/cpp/src/arrow/jniutil/CMakeLists.txt @@ -0,0 +1,31 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# +# arrow_dataset_jni +# + +arrow_install_all_headers("arrow/jniutil") + +find_package(JNI REQUIRED) + +add_arrow_test(arrow_jniutil_test + SOURCES + jni_util_test.cc + jni_util.cc + EXTRA_INCLUDES + ${JNI_INCLUDE_DIRS}) diff --git a/cpp/src/arrow/jniutil/jni_util.cc b/cpp/src/arrow/jniutil/jni_util.cc new file mode 100644 index 0000000000000..d1e6fb98c4097 --- /dev/null +++ b/cpp/src/arrow/jniutil/jni_util.cc @@ -0,0 +1,432 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +#include "arrow/jniutil/jni_util.h" +#include "arrow/ipc/metadata_internal.h" +#include "arrow/util/base64.h" +#include "arrow/util/key_value_metadata.h" +#include "arrow/util/logging.h" +#include "arrow/util/string_view.h" + +#include + +namespace arrow { + +namespace flatbuf = org::apache::arrow::flatbuf; + +namespace jniutil { + +class ReservationListenableMemoryPool::Impl { + public: + explicit Impl(arrow::MemoryPool *pool, std::shared_ptr listener, + int64_t block_size) + : pool_(pool), + listener_(listener), + block_size_(block_size), + blocks_reserved_(0), + bytes_reserved_(0) {} + + arrow::Status Allocate(int64_t size, uint8_t **out) { + RETURN_NOT_OK(UpdateReservation(size)); + arrow::Status error = pool_->Allocate(size, out); + if (!error.ok()) { + RETURN_NOT_OK(UpdateReservation(-size)); + return error; + } + return arrow::Status::OK(); + } + + arrow::Status Reallocate(int64_t old_size, int64_t new_size, uint8_t **ptr) { + bool reserved = false; + int64_t diff = new_size - old_size; + if (new_size >= old_size) { + // new_size >= old_size, pre-reserve bytes from listener before allocating + // from underlying pool + RETURN_NOT_OK(UpdateReservation(diff)); + reserved = true; + } + arrow::Status error = pool_->Reallocate(old_size, new_size, ptr); + if (!error.ok()) { + if (reserved) { + // roll back reservations on error + RETURN_NOT_OK(UpdateReservation(-diff)); + } + return error; + } + if (!reserved) { + // otherwise (e.g. 
new_size < old_size), make updates after calling underlying pool + RETURN_NOT_OK(UpdateReservation(diff)); + } + return arrow::Status::OK(); + } + + void Free(uint8_t *buffer, int64_t size) { + pool_->Free(buffer, size); + // FIXME: See ARROW-11143, currently method ::Free doesn't allow Status return + arrow::Status s = UpdateReservation(-size); + if (!s.ok()) { + ARROW_LOG(FATAL) << "Failed to update reservation while freeing bytes: " + << s.message(); + return; + } + } + + arrow::Status UpdateReservation(int64_t diff) { + int64_t granted = Reserve(diff); + if (granted == 0) { + return arrow::Status::OK(); + } + if (granted < 0) { + RETURN_NOT_OK(listener_->OnRelease(-granted)); + return arrow::Status::OK(); + } + RETURN_NOT_OK(listener_->OnReservation(granted)); + return arrow::Status::OK(); + } + + int64_t Reserve(int64_t diff) { + std::lock_guard lock(mutex_); + bytes_reserved_ += diff; + int64_t new_block_count; + if (bytes_reserved_ == 0) { + new_block_count = 0; + } else { + // ceil to get the required block number + new_block_count = (bytes_reserved_ - 1) / block_size_ + 1; + } + int64_t bytes_granted = (new_block_count - blocks_reserved_) * block_size_; + blocks_reserved_ = new_block_count; + return bytes_granted; + } + + int64_t bytes_allocated() { return pool_->bytes_allocated(); } + + int64_t max_memory() { return pool_->max_memory(); } + + std::string backend_name() { return pool_->backend_name(); } + + std::shared_ptr get_listener() { return listener_; } + + private: + arrow::MemoryPool *pool_; + std::shared_ptr listener_; + int64_t block_size_; + int64_t blocks_reserved_; + int64_t bytes_reserved_; + std::mutex mutex_; +}; + +/// \brief Buffer implementation that binds to a +/// Java buffer reference. Java buffer's release +/// method will be called once when being destructed. 
+class JavaAllocatedBuffer : public Buffer { + public: + JavaAllocatedBuffer(JNIEnv* env, jobject cleaner_ref, jmethodID cleaner_method_ref, + uint8_t* buffer, int32_t len) + : Buffer(buffer, len), + env_(env), + cleaner_ref_(cleaner_ref), + cleaner_method_ref_(cleaner_method_ref) {} + + ~JavaAllocatedBuffer() override { + env_->CallVoidMethod(cleaner_ref_, cleaner_method_ref_); + env_->DeleteGlobalRef(cleaner_ref_); + } + + private: + JNIEnv* env_; + jobject cleaner_ref_; + jmethodID cleaner_method_ref_; +}; + +ReservationListenableMemoryPool::ReservationListenableMemoryPool( + MemoryPool *pool, std::shared_ptr listener, int64_t block_size) { + impl_.reset(new Impl(pool, listener, block_size)); +} + +arrow::Status ReservationListenableMemoryPool::Allocate(int64_t size, uint8_t **out) { + return impl_->Allocate(size, out); +} + +arrow::Status ReservationListenableMemoryPool::Reallocate(int64_t old_size, + int64_t new_size, + uint8_t **ptr) { + return impl_->Reallocate(old_size, new_size, ptr); +} + +void ReservationListenableMemoryPool::Free(uint8_t *buffer, int64_t size) { + return impl_->Free(buffer, size); +} + +int64_t ReservationListenableMemoryPool::bytes_allocated() const { + return impl_->bytes_allocated(); +} + +int64_t ReservationListenableMemoryPool::max_memory() const { + return impl_->max_memory(); +} + +std::string ReservationListenableMemoryPool::backend_name() const { + return impl_->backend_name(); +} + +std::shared_ptr ReservationListenableMemoryPool::get_listener() { + return impl_->get_listener(); +} + +ReservationListenableMemoryPool::~ReservationListenableMemoryPool() {} + +Status CheckException(JNIEnv *env) { + if (env->ExceptionCheck()) { + env->ExceptionDescribe(); + env->ExceptionClear(); + return Status::Invalid("Error during calling Java code from native code"); + } + return Status::OK(); +} + +jclass CreateGlobalClassReference(JNIEnv *env, const char *class_name) { + jclass local_class = env->FindClass(class_name); + jclass global_class 
= (jclass) env->NewGlobalRef(local_class); + env->DeleteLocalRef(local_class); + return global_class; +} + +arrow::Result GetMethodID(JNIEnv *env, jclass this_class, const char *name, + const char *sig) { + jmethodID ret = env->GetMethodID(this_class, name, sig); + if (ret == nullptr) { + std::string error_message = "Unable to find method " + std::string(name) + + " within signature" + std::string(sig); + return arrow::Status::Invalid(error_message); + } + return ret; +} + +arrow::Result GetStaticMethodID(JNIEnv *env, jclass this_class, + const char *name, const char *sig) { + jmethodID ret = env->GetStaticMethodID(this_class, name, sig); + if (ret == nullptr) { + std::string error_message = "Unable to find static method " + std::string(name) + + " within signature" + std::string(sig); + return arrow::Status::Invalid(error_message); + } + return ret; +} + +std::string JStringToCString(JNIEnv *env, jstring string) { + if (string == nullptr) { + return std::string(); + } + const char *chars = env->GetStringUTFChars(string, nullptr); + std::string ret(chars); + env->ReleaseStringUTFChars(string, chars); + return ret; +} + +std::vector ToStringVector(JNIEnv *env, jobjectArray &str_array) { + int length = env->GetArrayLength(str_array); + std::vector vector; + for (int i = 0; i < length; i++) { + auto string = reinterpret_cast(env->GetObjectArrayElement(str_array, i)); + vector.push_back(JStringToCString(env, string)); + } + return vector; +} + +arrow::Result ToSchemaByteArray(JNIEnv *env, + std::shared_ptr schema) { + ARROW_ASSIGN_OR_RAISE( + std::shared_ptr buffer, + arrow::ipc::SerializeSchema(*schema, arrow::default_memory_pool())) + + jbyteArray out = env->NewByteArray(buffer->size()); + auto src = reinterpret_cast(buffer->data()); + env->SetByteArrayRegion(out, 0, buffer->size(), src); + return out; +} + +arrow::Result> FromSchemaByteArray( + JNIEnv *env, jbyteArray schemaBytes) { + arrow::ipc::DictionaryMemo in_memo; + int schemaBytes_len = 
env->GetArrayLength(schemaBytes); + jbyte *schemaBytes_data = env->GetByteArrayElements(schemaBytes, nullptr); + auto serialized_schema = std::make_shared( + reinterpret_cast(schemaBytes_data), schemaBytes_len); + arrow::io::BufferReader buf_reader(serialized_schema); + ARROW_ASSIGN_OR_RAISE(std::shared_ptr schema, + arrow::ipc::ReadSchema(&buf_reader, &in_memo)) + env->ReleaseByteArrayElements(schemaBytes, schemaBytes_data, JNI_ABORT); + return schema; +} + +Status SetMetadataForSingleField(std::shared_ptr array_data, + std::vector &nodes_meta, + std::vector &buffers_meta, + std::shared_ptr &custom_metadata) { + nodes_meta.push_back({array_data->length, array_data->null_count, 0L}); + + for (size_t i = 0; i < array_data->buffers.size(); i++) { + auto buffer = array_data->buffers.at(i); + uint8_t *data = nullptr; + int64_t size = 0; + if (buffer != nullptr) { + data = (uint8_t *) buffer->data(); + size = buffer->size(); + } + ipc::internal::BufferMetadata buffer_metadata{}; + buffer_metadata.offset = reinterpret_cast(data); + buffer_metadata.length = size; + // store buffer refs into custom metadata + jlong ref = CreateNativeRef(buffer); + custom_metadata->Append( + "NATIVE_BUFFER_REF_" + std::to_string(i), + util::base64_encode(arrow::util::string_view(reinterpret_cast(&ref), sizeof(ref)))); + buffers_meta.push_back(buffer_metadata); + } + + auto children_data = array_data->child_data; + for (const auto &child_data : children_data) { + RETURN_NOT_OK( + SetMetadataForSingleField(child_data, nodes_meta, buffers_meta, custom_metadata)); + } + return Status::OK(); +} + +Result> SerializeMetadata(const RecordBatch &batch, + const ipc::IpcWriteOptions &options) { + std::vector nodes; + std::vector buffers; + std::shared_ptr custom_metadata = + std::make_shared(); + for (const auto &column : batch.columns()) { + auto array_data = column->data(); + RETURN_NOT_OK(SetMetadataForSingleField(array_data, nodes, buffers, custom_metadata)); + } + std::shared_ptr meta_buffer; + 
RETURN_NOT_OK(ipc::internal::WriteRecordBatchMessage( + batch.num_rows(), 0L, custom_metadata, nodes, buffers, options, &meta_buffer)); + // no message body is needed for JNI serialization/deserialization + int32_t meta_length = -1; + ARROW_ASSIGN_OR_RAISE(auto stream, io::BufferOutputStream::Create(1024L)); + RETURN_NOT_OK(ipc::WriteMessage(*meta_buffer, options, stream.get(), &meta_length)); + return stream->Finish(); +} + +Result SerializeUnsafeFromNative(JNIEnv *env, + const std::shared_ptr &batch) { + ARROW_ASSIGN_OR_RAISE(auto meta_buffer, + SerializeMetadata(*batch, ipc::IpcWriteOptions::Defaults())); + + jbyteArray ret = env->NewByteArray(meta_buffer->size()); + auto src = reinterpret_cast(meta_buffer->data()); + env->SetByteArrayRegion(ret, 0, meta_buffer->size(), src); + return ret; +} + +Result> MakeArrayData( + JNIEnv *env, const flatbuf::RecordBatch &batch_meta, + const std::shared_ptr &custom_metadata, + const std::shared_ptr &type, int32_t *field_offset, + int32_t *buffer_offset) { + const org::apache::arrow::flatbuf::FieldNode *field = + batch_meta.nodes()->Get((*field_offset)++); + int32_t own_buffer_size = static_cast(type->layout().buffers.size()); + std::vector> buffers; + for (int32_t i = *buffer_offset; i < *buffer_offset + own_buffer_size; i++) { + const org::apache::arrow::flatbuf::Buffer *java_managed_buffer = + batch_meta.buffers()->Get(i); + const std::string &cleaner_object_ref_base64 = + util::base64_decode(custom_metadata->value(i * 2)); + const std::string &cleaner_method_ref_base64 = + util::base64_decode(custom_metadata->value(i * 2 + 1)); + const auto *cleaner_object_ref = + reinterpret_cast(cleaner_object_ref_base64.data()); + const auto *cleaner_method_ref = + reinterpret_cast(cleaner_method_ref_base64.data()); + auto buffer = std::make_shared( + env, reinterpret_cast(*cleaner_object_ref), + reinterpret_cast(*cleaner_method_ref), + reinterpret_cast(java_managed_buffer->offset()), + java_managed_buffer->length()); + 
buffers.push_back(buffer); + } + (*buffer_offset) += own_buffer_size; + if (type->num_fields() == 0) { + return ArrayData::Make(type, field->length(), buffers, field->null_count()); + } + std::vector> children_array_data; + for (const auto &child_field : type->fields()) { + ARROW_ASSIGN_OR_RAISE(auto child_array_data, + MakeArrayData(env, batch_meta, custom_metadata, + child_field->type(), field_offset, buffer_offset)) + children_array_data.push_back(child_array_data); + } + return ArrayData::Make(type, field->length(), buffers, children_array_data, + field->null_count()); +} + +Result> DeserializeUnsafeFromJava( + JNIEnv *env, std::shared_ptr schema, jbyteArray byte_array) { + int bytes_len = env->GetArrayLength(byte_array); + jbyte *byte_data = env->GetByteArrayElements(byte_array, nullptr); + io::BufferReader meta_reader(reinterpret_cast(byte_data), + static_cast(bytes_len)); + ARROW_ASSIGN_OR_RAISE(auto meta_message, ipc::ReadMessage(&meta_reader)) + auto meta_buffer = meta_message->metadata(); + auto custom_metadata = meta_message->custom_metadata(); + const flatbuf::Message *flat_meta = nullptr; + RETURN_NOT_OK( + ipc::internal::VerifyMessage(meta_buffer->data(), meta_buffer->size(), &flat_meta)); + auto batch_meta = flat_meta->header_as_RecordBatch(); + + // Record batch serialized from java should have two ref IDs per buffer: cleaner object + // ref and cleaner method ref. The refs are originally of 64bit integer type and encoded + // within base64. 
+ if (custom_metadata->size() != + static_cast(batch_meta->buffers()->size() * 2)) { + return Status::SerializationError( + "Buffer count mismatch between metadata and Java managed refs"); + } + + std::vector> columns_array_data; + int32_t field_offset = 0; + int32_t buffer_offset = 0; + for (int32_t i = 0; i < schema->num_fields(); i++) { + auto field = schema->field(i); + ARROW_ASSIGN_OR_RAISE(auto column_array_data, + MakeArrayData(env, *batch_meta, custom_metadata, field->type(), + &field_offset, &buffer_offset)) + columns_array_data.push_back(column_array_data); + } + if (field_offset != static_cast(batch_meta->nodes()->size())) { + return Status::SerializationError( + "Deserialization failed: Field count is not " + "as expected based on type layout"); + } + if (buffer_offset != static_cast(batch_meta->buffers()->size())) { + return Status::SerializationError( + "Deserialization failed: Buffer count is not " + "as expected based on type layout"); + } + int64_t length = batch_meta->length(); + env->ReleaseByteArrayElements(byte_array, byte_data, JNI_ABORT); + return RecordBatch::Make(schema, length, columns_array_data); +} + +} // namespace jniutil +} // namespace arrow diff --git a/cpp/src/jni/dataset/jni_util.h b/cpp/src/arrow/jniutil/jni_util.h similarity index 87% rename from cpp/src/jni/dataset/jni_util.h rename to cpp/src/arrow/jniutil/jni_util.h index c76033ae633b6..8ab094b1d7ada 100644 --- a/cpp/src/jni/dataset/jni_util.h +++ b/cpp/src/arrow/jniutil/jni_util.h @@ -27,8 +27,9 @@ #include namespace arrow { -namespace dataset { -namespace jni { +namespace jniutil { + +Status CheckException(JNIEnv* env); jclass CreateGlobalClassReference(JNIEnv* env, const char* class_name); @@ -48,6 +49,16 @@ arrow::Result ToSchemaByteArray(JNIEnv* env, arrow::Result> FromSchemaByteArray(JNIEnv* env, jbyteArray schemaBytes); +/// \brief Serialize arrow::RecordBatch to jbyteArray (Java byte array byte[]). 
For +/// letting Java code manage lifecycles of buffers in the input batch, shared pointer IDs +/// pointing to the buffers are serialized into buffer metadata. +Result SerializeUnsafeFromNative(JNIEnv* env, + const std::shared_ptr& batch); + +/// \brief Deserialize jbyteArray (Java byte array byte[]) to arrow::RecordBatch. +Result> DeserializeUnsafeFromJava( + JNIEnv* env, std::shared_ptr schema, jbyteArray byte_array); + /// \brief Create a new shared_ptr on heap from shared_ptr t to prevent /// the managed object from being garbage-collected. /// @@ -130,6 +141,5 @@ class ReservationListenableMemoryPool : public arrow::MemoryPool { std::unique_ptr impl_; }; -} // namespace jni -} // namespace dataset +} // namespace jniutil } // namespace arrow diff --git a/cpp/src/jni/dataset/jni_util_test.cc b/cpp/src/arrow/jniutil/jni_util_test.cc similarity index 96% rename from cpp/src/jni/dataset/jni_util_test.cc rename to cpp/src/arrow/jniutil/jni_util_test.cc index 589f00b1cc750..893c28125c0c5 100644 --- a/cpp/src/jni/dataset/jni_util_test.cc +++ b/cpp/src/arrow/jniutil/jni_util_test.cc @@ -19,11 +19,10 @@ #include "arrow/memory_pool.h" #include "arrow/testing/gtest_util.h" -#include "jni/dataset/jni_util.h" +#include "arrow/jniutil/jni_util.h" namespace arrow { -namespace dataset { -namespace jni { +namespace jniutil { class MyListener : public ReservationListener { public: @@ -129,6 +128,5 @@ TEST(ReservationListenableMemoryPool, BlockSize2) { ASSERT_EQ(1, listener->release_count()); } -} // namespace jni -} // namespace dataset +} // namespace jniutil } // namespace arrow diff --git a/cpp/src/jni/dataset/CMakeLists.txt b/cpp/src/jni/dataset/CMakeLists.txt index f3e309b614aed..fc5fa26a5907e 100644 --- a/cpp/src/jni/dataset/CMakeLists.txt +++ b/cpp/src/jni/dataset/CMakeLists.txt @@ -35,7 +35,7 @@ set(ARROW_BUILD_STATIC OFF) set(ARROW_DATASET_JNI_LIBS arrow_dataset_static) -set(ARROW_DATASET_JNI_SOURCES jni_wrapper.cc jni_util.cc) +set(ARROW_DATASET_JNI_SOURCES 
jni_wrapper.cc) add_arrow_lib(arrow_dataset_jni BUILD_SHARED @@ -55,11 +55,4 @@ add_arrow_lib(arrow_dataset_jni arrow_static arrow_dataset_java) -add_dependencies(arrow_dataset_jni ${ARROW_DATASET_JNI_LIBRARIES}) - -add_arrow_test(dataset_jni_test - SOURCES - jni_util_test.cc - jni_util.cc - EXTRA_INCLUDES - ${JNI_INCLUDE_DIRS}) +add_dependencies(arrow_dataset_jni ${ARROW_DATASET_JNI_LIBRARIES}) \ No newline at end of file diff --git a/cpp/src/jni/dataset/jni_util.cc b/cpp/src/jni/dataset/jni_util.cc deleted file mode 100644 index 113669a4cf6dd..0000000000000 --- a/cpp/src/jni/dataset/jni_util.cc +++ /dev/null @@ -1,242 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. 
- -#include "jni/dataset/jni_util.h" - -#include "arrow/util/logging.h" - -#include - -namespace arrow { -namespace dataset { -namespace jni { - -class ReservationListenableMemoryPool::Impl { - public: - explicit Impl(arrow::MemoryPool* pool, std::shared_ptr listener, - int64_t block_size) - : pool_(pool), - listener_(listener), - block_size_(block_size), - blocks_reserved_(0), - bytes_reserved_(0) {} - - arrow::Status Allocate(int64_t size, uint8_t** out) { - RETURN_NOT_OK(UpdateReservation(size)); - arrow::Status error = pool_->Allocate(size, out); - if (!error.ok()) { - RETURN_NOT_OK(UpdateReservation(-size)); - return error; - } - return arrow::Status::OK(); - } - - arrow::Status Reallocate(int64_t old_size, int64_t new_size, uint8_t** ptr) { - bool reserved = false; - int64_t diff = new_size - old_size; - if (new_size >= old_size) { - // new_size >= old_size, pre-reserve bytes from listener before allocating - // from underlying pool - RETURN_NOT_OK(UpdateReservation(diff)); - reserved = true; - } - arrow::Status error = pool_->Reallocate(old_size, new_size, ptr); - if (!error.ok()) { - if (reserved) { - // roll back reservations on error - RETURN_NOT_OK(UpdateReservation(-diff)); - } - return error; - } - if (!reserved) { - // otherwise (e.g. 
new_size < old_size), make updates after calling underlying pool - RETURN_NOT_OK(UpdateReservation(diff)); - } - return arrow::Status::OK(); - } - - void Free(uint8_t* buffer, int64_t size) { - pool_->Free(buffer, size); - // FIXME: See ARROW-11143, currently method ::Free doesn't allow Status return - arrow::Status s = UpdateReservation(-size); - if (!s.ok()) { - ARROW_LOG(FATAL) << "Failed to update reservation while freeing bytes: " - << s.message(); - return; - } - } - - arrow::Status UpdateReservation(int64_t diff) { - int64_t granted = Reserve(diff); - if (granted == 0) { - return arrow::Status::OK(); - } - if (granted < 0) { - RETURN_NOT_OK(listener_->OnRelease(-granted)); - return arrow::Status::OK(); - } - RETURN_NOT_OK(listener_->OnReservation(granted)); - return arrow::Status::OK(); - } - - int64_t Reserve(int64_t diff) { - std::lock_guard lock(mutex_); - bytes_reserved_ += diff; - int64_t new_block_count; - if (bytes_reserved_ == 0) { - new_block_count = 0; - } else { - // ceil to get the required block number - new_block_count = (bytes_reserved_ - 1) / block_size_ + 1; - } - int64_t bytes_granted = (new_block_count - blocks_reserved_) * block_size_; - blocks_reserved_ = new_block_count; - return bytes_granted; - } - - int64_t bytes_allocated() { return pool_->bytes_allocated(); } - - int64_t max_memory() { return pool_->max_memory(); } - - std::string backend_name() { return pool_->backend_name(); } - - std::shared_ptr get_listener() { return listener_; } - - private: - arrow::MemoryPool* pool_; - std::shared_ptr listener_; - int64_t block_size_; - int64_t blocks_reserved_; - int64_t bytes_reserved_; - std::mutex mutex_; -}; - -ReservationListenableMemoryPool::ReservationListenableMemoryPool( - MemoryPool* pool, std::shared_ptr listener, int64_t block_size) { - impl_.reset(new Impl(pool, listener, block_size)); -} - -arrow::Status ReservationListenableMemoryPool::Allocate(int64_t size, uint8_t** out) { - return impl_->Allocate(size, out); -} - 
-arrow::Status ReservationListenableMemoryPool::Reallocate(int64_t old_size, - int64_t new_size, - uint8_t** ptr) { - return impl_->Reallocate(old_size, new_size, ptr); -} - -void ReservationListenableMemoryPool::Free(uint8_t* buffer, int64_t size) { - return impl_->Free(buffer, size); -} - -int64_t ReservationListenableMemoryPool::bytes_allocated() const { - return impl_->bytes_allocated(); -} - -int64_t ReservationListenableMemoryPool::max_memory() const { - return impl_->max_memory(); -} - -std::string ReservationListenableMemoryPool::backend_name() const { - return impl_->backend_name(); -} - -std::shared_ptr ReservationListenableMemoryPool::get_listener() { - return impl_->get_listener(); -} - -ReservationListenableMemoryPool::~ReservationListenableMemoryPool() {} - -jclass CreateGlobalClassReference(JNIEnv* env, const char* class_name) { - jclass local_class = env->FindClass(class_name); - jclass global_class = (jclass)env->NewGlobalRef(local_class); - env->DeleteLocalRef(local_class); - return global_class; -} - -arrow::Result GetMethodID(JNIEnv* env, jclass this_class, const char* name, - const char* sig) { - jmethodID ret = env->GetMethodID(this_class, name, sig); - if (ret == nullptr) { - std::string error_message = "Unable to find method " + std::string(name) + - " within signature" + std::string(sig); - return arrow::Status::Invalid(error_message); - } - return ret; -} - -arrow::Result GetStaticMethodID(JNIEnv* env, jclass this_class, - const char* name, const char* sig) { - jmethodID ret = env->GetStaticMethodID(this_class, name, sig); - if (ret == nullptr) { - std::string error_message = "Unable to find static method " + std::string(name) + - " within signature" + std::string(sig); - return arrow::Status::Invalid(error_message); - } - return ret; -} - -std::string JStringToCString(JNIEnv* env, jstring string) { - if (string == nullptr) { - return std::string(); - } - const char* chars = env->GetStringUTFChars(string, nullptr); - std::string 
ret(chars); - env->ReleaseStringUTFChars(string, chars); - return ret; -} - -std::vector ToStringVector(JNIEnv* env, jobjectArray& str_array) { - int length = env->GetArrayLength(str_array); - std::vector vector; - for (int i = 0; i < length; i++) { - auto string = reinterpret_cast(env->GetObjectArrayElement(str_array, i)); - vector.push_back(JStringToCString(env, string)); - } - return vector; -} - -arrow::Result ToSchemaByteArray(JNIEnv* env, - std::shared_ptr schema) { - ARROW_ASSIGN_OR_RAISE( - std::shared_ptr buffer, - arrow::ipc::SerializeSchema(*schema, arrow::default_memory_pool())) - - jbyteArray out = env->NewByteArray(buffer->size()); - auto src = reinterpret_cast(buffer->data()); - env->SetByteArrayRegion(out, 0, buffer->size(), src); - return out; -} - -arrow::Result> FromSchemaByteArray( - JNIEnv* env, jbyteArray schemaBytes) { - arrow::ipc::DictionaryMemo in_memo; - int schemaBytes_len = env->GetArrayLength(schemaBytes); - jbyte* schemaBytes_data = env->GetByteArrayElements(schemaBytes, nullptr); - auto serialized_schema = std::make_shared( - reinterpret_cast(schemaBytes_data), schemaBytes_len); - arrow::io::BufferReader buf_reader(serialized_schema); - ARROW_ASSIGN_OR_RAISE(std::shared_ptr schema, - arrow::ipc::ReadSchema(&buf_reader, &in_memo)) - env->ReleaseByteArrayElements(schemaBytes, schemaBytes_data, JNI_ABORT); - return schema; -} - -} // namespace jni -} // namespace dataset -} // namespace arrow diff --git a/cpp/src/jni/dataset/jni_wrapper.cc b/cpp/src/jni/dataset/jni_wrapper.cc index 041542804ce86..93f382059fc6b 100644 --- a/cpp/src/jni/dataset/jni_wrapper.cc +++ b/cpp/src/jni/dataset/jni_wrapper.cc @@ -23,10 +23,9 @@ #include "arrow/dataset/file_base.h" #include "arrow/filesystem/localfs.h" #include "arrow/ipc/api.h" +#include "arrow/jniutil/jni_util.h" #include "arrow/util/iterator.h" -#include "jni/dataset/jni_util.h" - #include "org_apache_arrow_dataset_file_JniWrapper.h" #include "org_apache_arrow_dataset_jni_JniWrapper.h" #include 
"org_apache_arrow_dataset_jni_NativeMemoryPool.h" @@ -37,14 +36,8 @@ jclass illegal_access_exception_class; jclass illegal_argument_exception_class; jclass runtime_exception_class; -jclass record_batch_handle_class; -jclass record_batch_handle_field_class; -jclass record_batch_handle_buffer_class; jclass java_reservation_listener_class; -jmethodID record_batch_handle_constructor; -jmethodID record_batch_handle_field_constructor; -jmethodID record_batch_handle_buffer_constructor; jmethodID reserve_memory_method; jmethodID unreserve_memory_method; @@ -89,7 +82,7 @@ arrow::Result> GetFileFormat( } } -class ReserveFromJava : public arrow::dataset::jni::ReservationListener { +class ReserveFromJava : public arrow::jniutil::ReservationListener { public: ReserveFromJava(JavaVM* vm, jobject java_reservation_listener) : vm_(vm), java_reservation_listener_(java_reservation_listener) {} @@ -100,11 +93,7 @@ class ReserveFromJava : public arrow::dataset::jni::ReservationListener { return arrow::Status::Invalid("JNIEnv was not attached to current thread"); } env->CallObjectMethod(java_reservation_listener_, reserve_memory_method, size); - if (env->ExceptionCheck()) { - env->ExceptionDescribe(); - env->ExceptionClear(); - return arrow::Status::Invalid("Error calling Java side reservation listener"); - } + RETURN_NOT_OK(arrow::jniutil::CheckException(env)); return arrow::Status::OK(); } @@ -114,11 +103,7 @@ class ReserveFromJava : public arrow::dataset::jni::ReservationListener { return arrow::Status::Invalid("JNIEnv was not attached to current thread"); } env->CallObjectMethod(java_reservation_listener_, unreserve_memory_method, size); - if (env->ExceptionCheck()) { - env->ExceptionDescribe(); - env->ExceptionClear(); - return arrow::Status::Invalid("Error calling Java side reservation listener"); - } + RETURN_NOT_OK(arrow::jniutil::CheckException(env)); return arrow::Status::OK(); } @@ -169,18 +154,18 @@ class DisposableScannerAdaptor { } // namespace -using 
arrow::dataset::jni::CreateGlobalClassReference; -using arrow::dataset::jni::CreateNativeRef; -using arrow::dataset::jni::FromSchemaByteArray; -using arrow::dataset::jni::GetMethodID; -using arrow::dataset::jni::JStringToCString; -using arrow::dataset::jni::ReleaseNativeRef; -using arrow::dataset::jni::RetrieveNativeInstance; -using arrow::dataset::jni::ToSchemaByteArray; -using arrow::dataset::jni::ToStringVector; +using arrow::jniutil::CreateGlobalClassReference; +using arrow::jniutil::CreateNativeRef; +using arrow::jniutil::FromSchemaByteArray; +using arrow::jniutil::GetMethodID; +using arrow::jniutil::JStringToCString; +using arrow::jniutil::ReleaseNativeRef; +using arrow::jniutil::RetrieveNativeInstance; +using arrow::jniutil::ToSchemaByteArray; +using arrow::jniutil::ToStringVector; -using arrow::dataset::jni::ReservationListenableMemoryPool; -using arrow::dataset::jni::ReservationListener; +using arrow::jniutil::ReservationListenableMemoryPool; +using arrow::jniutil::ReservationListener; #define JNI_METHOD_START try { // macro ended @@ -205,34 +190,10 @@ jint JNI_OnLoad(JavaVM* vm, void* reserved) { CreateGlobalClassReference(env, "Ljava/lang/IllegalArgumentException;"); runtime_exception_class = CreateGlobalClassReference(env, "Ljava/lang/RuntimeException;"); - - record_batch_handle_class = - CreateGlobalClassReference(env, - "Lorg/apache/arrow/" - "dataset/jni/NativeRecordBatchHandle;"); - record_batch_handle_field_class = - CreateGlobalClassReference(env, - "Lorg/apache/arrow/" - "dataset/jni/NativeRecordBatchHandle$Field;"); - record_batch_handle_buffer_class = - CreateGlobalClassReference(env, - "Lorg/apache/arrow/" - "dataset/jni/NativeRecordBatchHandle$Buffer;"); java_reservation_listener_class = CreateGlobalClassReference(env, "Lorg/apache/arrow/" "dataset/jni/ReservationListener;"); - - record_batch_handle_constructor = - JniGetOrThrow(GetMethodID(env, record_batch_handle_class, "", - "(J[Lorg/apache/arrow/dataset/" - 
"jni/NativeRecordBatchHandle$Field;" - "[Lorg/apache/arrow/dataset/" - "jni/NativeRecordBatchHandle$Buffer;)V")); - record_batch_handle_field_constructor = - JniGetOrThrow(GetMethodID(env, record_batch_handle_field_class, "", "(JJ)V")); - record_batch_handle_buffer_constructor = JniGetOrThrow( - GetMethodID(env, record_batch_handle_buffer_class, "", "(JJJJ)V")); reserve_memory_method = JniGetOrThrow(GetMethodID(env, java_reservation_listener_class, "reserve", "(J)V")); unreserve_memory_method = JniGetOrThrow( @@ -250,9 +211,6 @@ void JNI_OnUnload(JavaVM* vm, void* reserved) { env->DeleteGlobalRef(illegal_access_exception_class); env->DeleteGlobalRef(illegal_argument_exception_class); env->DeleteGlobalRef(runtime_exception_class); - env->DeleteGlobalRef(record_batch_handle_class); - env->DeleteGlobalRef(record_batch_handle_field_class); - env->DeleteGlobalRef(record_batch_handle_buffer_class); env->DeleteGlobalRef(java_reservation_listener_class); default_memory_pool_id = -1L; @@ -458,9 +416,9 @@ Java_org_apache_arrow_dataset_jni_JniWrapper_getSchemaFromScanner(JNIEnv* env, j /* * Class: org_apache_arrow_dataset_jni_JniWrapper * Method: nextRecordBatch - * Signature: (J)Lorg/apache/arrow/dataset/jni/NativeRecordBatchHandle; + * Signature: (J)[B */ -JNIEXPORT jobject JNICALL Java_org_apache_arrow_dataset_jni_JniWrapper_nextRecordBatch( +JNIEXPORT jbyteArray JNICALL Java_org_apache_arrow_dataset_jni_JniWrapper_nextRecordBatch( JNIEnv* env, jobject, jlong scanner_id) { JNI_METHOD_START std::shared_ptr scanner_adaptor = @@ -471,12 +429,8 @@ JNIEXPORT jobject JNICALL Java_org_apache_arrow_dataset_jni_JniWrapper_nextRecor if (record_batch == nullptr) { return nullptr; // stream ended } - std::shared_ptr schema = record_batch->schema(); - jobjectArray field_array = - env->NewObjectArray(schema->num_fields(), record_batch_handle_field_class, nullptr); - - std::vector> buffers; - for (int i = 0; i < schema->num_fields(); ++i) { + std::vector> offset_zeroed_arrays; + for (int 
i = 0; i < record_batch->num_columns(); ++i) { // TODO: If the array has an offset then we need to de-offset the array // in order for it to be properly consumed on the Java end. // This forces a copy, it would be nice to avoid this if Java @@ -485,43 +439,19 @@ JNIEXPORT jobject JNICALL Java_org_apache_arrow_dataset_jni_JniWrapper_nextRecor // // Generally a non-zero offset will occur whenever the scanner batch // size is smaller than the batch size of the underlying files. - auto column = record_batch->column(i); - if (column->offset() != 0) { - column = JniGetOrThrow(arrow::Concatenate({column})); - } - auto dataArray = column->data(); - jobject field = env->NewObject(record_batch_handle_field_class, - record_batch_handle_field_constructor, - column->length(), column->null_count()); - env->SetObjectArrayElement(field_array, i, field); - - for (auto& buffer : dataArray->buffers) { - buffers.push_back(buffer); - } - } - - jobjectArray buffer_array = - env->NewObjectArray(buffers.size(), record_batch_handle_buffer_class, nullptr); - - for (size_t j = 0; j < buffers.size(); ++j) { - auto buffer = buffers[j]; - uint8_t* data = nullptr; - int64_t size = 0; - int64_t capacity = 0; - if (buffer != nullptr) { - data = (uint8_t*)buffer->data(); - size = buffer->size(); - capacity = buffer->capacity(); + std::shared_ptr array = record_batch->column(i); + if (array->offset() == 0) { + offset_zeroed_arrays.push_back(array); + continue; } - jobject buffer_handle = env->NewObject(record_batch_handle_buffer_class, - record_batch_handle_buffer_constructor, - CreateNativeRef(buffer), data, size, capacity); - env->SetObjectArrayElement(buffer_array, j, buffer_handle); + std::shared_ptr offset_zeroed = + JniGetOrThrow(arrow::Concatenate({array})); + offset_zeroed_arrays.push_back(offset_zeroed); } - jobject ret = env->NewObject(record_batch_handle_class, record_batch_handle_constructor, - record_batch->num_rows(), field_array, buffer_array); - return ret; + std::shared_ptr 
offset_zeroed_batch = arrow::RecordBatch::Make(
+      record_batch->schema(), record_batch->num_rows(), offset_zeroed_arrays);
+  return JniGetOrThrow(arrow::jniutil::SerializeUnsafeFromNative(env, offset_zeroed_batch));
   JNI_METHOD_END(nullptr)
 }
 
@@ -555,3 +485,43 @@ Java_org_apache_arrow_dataset_file_JniWrapper_makeFileSystemDatasetFactory(
   return CreateNativeRef(d);
   JNI_METHOD_END(-1L)
 }
+
+/*
+ * Class:     org_apache_arrow_dataset_file_JniWrapper
+ * Method:    newJniGlobalReference
+ * Signature: (Ljava/lang/Object;)J
+ */
+JNIEXPORT jlong JNICALL
+Java_org_apache_arrow_dataset_file_JniWrapper_newJniGlobalReference(JNIEnv* env, jobject,
+                                                                    jobject referent) {
+  JNI_METHOD_START
+  return reinterpret_cast<jlong>(env->NewGlobalRef(referent));
+  JNI_METHOD_END(-1L)
+}
+
+/*
+ * Class:     org_apache_arrow_dataset_file_JniWrapper
+ * Method:    newJniMethodReference
+ * Signature: (Ljava/lang/String;Ljava/lang/String;Ljava/lang/String;)J
+ */
+JNIEXPORT jlong JNICALL
+Java_org_apache_arrow_dataset_file_JniWrapper_newJniMethodReference(JNIEnv* env, jobject,
+                                                                    jstring class_sig,
+                                                                    jstring method_name,
+                                                                    jstring method_sig) {
+  JNI_METHOD_START
+  jclass clazz = env->FindClass(JStringToCString(env, class_sig).data());
+  if (clazz == nullptr) {
+    // Class lookup failed and left a Java exception pending; stop before GetMethodID.
+    return -1L;
+  }
+  jmethodID jmethod_id =
+      env->GetMethodID(clazz, JStringToCString(env, method_name).data(),
+                       JStringToCString(env, method_sig).data());
+  if (jmethod_id == nullptr) {
+    // Method lookup failed; the pending Java exception surfaces once we return.
+    return -1L;
+  }
+  return reinterpret_cast<jlong>(jmethod_id);
+  JNI_METHOD_END(-1L)
+}
diff --git a/java/dataset/CMakeLists.txt b/java/dataset/CMakeLists.txt
index 07e2d0ae8fc3d..5b6e4a9ce241a 100644
--- a/java/dataset/CMakeLists.txt
+++ b/java/dataset/CMakeLists.txt
@@ -33,7 +33,6 @@ message("generating headers to ${JNI_HEADERS_DIR}")
 add_jar(arrow_dataset_java
     src/main/java/org/apache/arrow/dataset/jni/JniLoader.java
     src/main/java/org/apache/arrow/dataset/jni/JniWrapper.java
-    src/main/java/org/apache/arrow/dataset/jni/NativeRecordBatchHandle.java
     src/main/java/org/apache/arrow/dataset/file/JniWrapper.java
src/main/java/org/apache/arrow/dataset/jni/NativeMemoryPool.java src/main/java/org/apache/arrow/dataset/jni/ReservationListener.java diff --git a/java/dataset/pom.xml b/java/dataset/pom.xml index 9a80a547c1c17..b19fb980edc28 100644 --- a/java/dataset/pom.xml +++ b/java/dataset/pom.xml @@ -38,6 +38,11 @@ compile ${arrow.vector.classifier} + + org.apache.arrow + arrow-format + ${project.version} + org.apache.arrow arrow-memory-core @@ -50,6 +55,11 @@ ${project.version} test + + com.google.flatbuffers + flatbuffers-java + 1.12.0 + org.apache.parquet parquet-avro diff --git a/java/dataset/src/main/java/org/apache/arrow/dataset/file/JniWrapper.java b/java/dataset/src/main/java/org/apache/arrow/dataset/file/JniWrapper.java index f69d8205192c0..3c90fadb4de6e 100644 --- a/java/dataset/src/main/java/org/apache/arrow/dataset/file/JniWrapper.java +++ b/java/dataset/src/main/java/org/apache/arrow/dataset/file/JniWrapper.java @@ -45,4 +45,20 @@ private JniWrapper() { */ public native long makeFileSystemDatasetFactory(String uri, int fileFormat); + /** + * Create a Jni global reference for the object. + * @param object the input object + * @return the native pointer of global reference object. + */ + public native long newJniGlobalReference(Object object); + + /** + * Create a Jni method reference. + * @param classSignature signature of the class defining the target method + * @param methodName method name + * @param methodSignature signature of the target method + * @return the native pointer of method reference object. 
+ */ + public native long newJniMethodReference(String classSignature, String methodName, + String methodSignature); } diff --git a/java/dataset/src/main/java/org/apache/arrow/dataset/jni/JniWrapper.java b/java/dataset/src/main/java/org/apache/arrow/dataset/jni/JniWrapper.java index 7dd54e7648f28..a8e1c73be828a 100644 --- a/java/dataset/src/main/java/org/apache/arrow/dataset/jni/JniWrapper.java +++ b/java/dataset/src/main/java/org/apache/arrow/dataset/jni/JniWrapper.java @@ -93,10 +93,13 @@ private JniWrapper() { /** * Read next record batch from the specified scanner. + * * @param scannerId the native pointer of the arrow::dataset::Scanner instance. - * @return an instance of {@link NativeRecordBatchHandle} describing the overall layout of the native record batch. + * @return a flatbuffers-serialized + * {@link org.apache.arrow.flatbuf.Message} describing + * the overall layout of the native record batch. */ - public native NativeRecordBatchHandle nextRecordBatch(long scannerId); + public native byte[] nextRecordBatch(long scannerId); /** * Release the Buffer by destroying its reference held by JNI wrapper. diff --git a/java/dataset/src/main/java/org/apache/arrow/dataset/jni/NativeRecordBatchHandle.java b/java/dataset/src/main/java/org/apache/arrow/dataset/jni/NativeRecordBatchHandle.java deleted file mode 100644 index dd90fd1c1ddb7..0000000000000 --- a/java/dataset/src/main/java/org/apache/arrow/dataset/jni/NativeRecordBatchHandle.java +++ /dev/null @@ -1,106 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.arrow.dataset.jni; - -import java.util.Arrays; -import java.util.List; - -/** - * Hold pointers to a Arrow C++ RecordBatch. - */ -public class NativeRecordBatchHandle { - - private final long numRows; - private final List fields; - private final List buffers; - - /** - * Constructor. - * - * @param numRows Total row number of the associated RecordBatch - * @param fields Metadata of fields - * @param buffers Retained Arrow buffers - */ - public NativeRecordBatchHandle(long numRows, Field[] fields, Buffer[] buffers) { - this.numRows = numRows; - this.fields = Arrays.asList(fields); - this.buffers = Arrays.asList(buffers); - } - - /** - * Returns the total row number of the associated RecordBatch. - * @return Total row number of the associated RecordBatch. - */ - public long getNumRows() { - return numRows; - } - - /** - * Returns Metadata of fields. - * @return Metadata of fields. - */ - public List getFields() { - return fields; - } - - /** - * Returns the buffers. - * @return Retained Arrow buffers. - */ - public List getBuffers() { - return buffers; - } - - /** - * Field metadata. - */ - public static class Field { - public final long length; - public final long nullCount; - - public Field(long length, long nullCount) { - this.length = length; - this.nullCount = nullCount; - } - } - - /** - * Pointers and metadata of the targeted Arrow buffer. 
- */ - public static class Buffer { - public final long nativeInstanceId; - public final long memoryAddress; - public final long size; - public final long capacity; - - /** - * Constructor. - * - * @param nativeInstanceId Native instance's id - * @param memoryAddress Memory address of the first byte - * @param size Size (in bytes) - * @param capacity Capacity (in bytes) - */ - public Buffer(long nativeInstanceId, long memoryAddress, long size, long capacity) { - this.nativeInstanceId = nativeInstanceId; - this.memoryAddress = memoryAddress; - this.size = size; - this.capacity = capacity; - } - } -} diff --git a/java/dataset/src/main/java/org/apache/arrow/dataset/jni/NativeScanner.java b/java/dataset/src/main/java/org/apache/arrow/dataset/jni/NativeScanner.java index 24c298067afde..ea2c9edf4ec30 100644 --- a/java/dataset/src/main/java/org/apache/arrow/dataset/jni/NativeScanner.java +++ b/java/dataset/src/main/java/org/apache/arrow/dataset/jni/NativeScanner.java @@ -18,23 +18,15 @@ package org.apache.arrow.dataset.jni; import java.io.IOException; -import java.util.ArrayList; import java.util.Collections; import java.util.NoSuchElementException; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.locks.Lock; import java.util.concurrent.locks.ReadWriteLock; import java.util.concurrent.locks.ReentrantReadWriteLock; -import java.util.stream.Collectors; import org.apache.arrow.dataset.scanner.ScanTask; import org.apache.arrow.dataset.scanner.Scanner; -import org.apache.arrow.memory.ArrowBuf; -import org.apache.arrow.memory.BufferAllocator; -import org.apache.arrow.memory.BufferLedger; -import org.apache.arrow.memory.NativeUnderlyingMemory; -import org.apache.arrow.memory.util.LargeMemoryUtil; -import org.apache.arrow.vector.ipc.message.ArrowFieldNode; import org.apache.arrow.vector.ipc.message.ArrowRecordBatch; import org.apache.arrow.vector.types.pojo.Schema; import org.apache.arrow.vector.util.SchemaUtility; @@ -81,39 +73,21 @@ public boolean 
hasNext() { if (peek != null) { return true; } - final NativeRecordBatchHandle handle; + final byte[] bytes; readLock.lock(); try { if (closed) { throw new NativeInstanceReleasedException(); } - handle = JniWrapper.get().nextRecordBatch(scannerId); + bytes = JniWrapper.get().nextRecordBatch(scannerId); } finally { readLock.unlock(); } - if (handle == null) { + if (bytes == null) { return false; } - final ArrayList buffers = new ArrayList<>(); - for (NativeRecordBatchHandle.Buffer buffer : handle.getBuffers()) { - final BufferAllocator allocator = context.getAllocator(); - final int size = LargeMemoryUtil.checkedCastToInt(buffer.size); - final NativeUnderlyingMemory am = NativeUnderlyingMemory.create(allocator, - size, buffer.nativeInstanceId, buffer.memoryAddress); - BufferLedger ledger = am.associate(allocator); - ArrowBuf buf = new ArrowBuf(ledger, null, size, buffer.memoryAddress); - buffers.add(buf); - } - - try { - final int numRows = LargeMemoryUtil.checkedCastToInt(handle.getNumRows()); - peek = new ArrowRecordBatch(numRows, handle.getFields().stream() - .map(field -> new ArrowFieldNode(field.length, field.nullCount)) - .collect(Collectors.toList()), buffers); - return true; - } finally { - buffers.forEach(buffer -> buffer.getReferenceManager().release()); - } + peek = UnsafeRecordBatchSerializer.deserializeUnsafe(context.getAllocator(), bytes); + return true; } @Override diff --git a/java/dataset/src/main/java/org/apache/arrow/dataset/jni/UnsafeRecordBatchSerializer.java b/java/dataset/src/main/java/org/apache/arrow/dataset/jni/UnsafeRecordBatchSerializer.java new file mode 100644 index 0000000000000..11402738fbc1a --- /dev/null +++ b/java/dataset/src/main/java/org/apache/arrow/dataset/jni/UnsafeRecordBatchSerializer.java @@ -0,0 +1,266 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.dataset.jni; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.nio.channels.Channels; +import java.util.ArrayList; +import java.util.Base64; +import java.util.List; +import java.util.stream.Collectors; + +import org.apache.arrow.dataset.file.JniWrapper; +import org.apache.arrow.flatbuf.Buffer; +import org.apache.arrow.flatbuf.FieldNode; +import org.apache.arrow.flatbuf.KeyValue; +import org.apache.arrow.flatbuf.Message; +import org.apache.arrow.flatbuf.MessageHeader; +import org.apache.arrow.flatbuf.RecordBatch; +import org.apache.arrow.memory.ArrowBuf; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.BufferLedger; +import org.apache.arrow.memory.NativeUnderlyingMemory; +import org.apache.arrow.memory.util.LargeMemoryUtil; +import org.apache.arrow.util.Preconditions; +import org.apache.arrow.vector.compression.NoCompressionCodec; +import org.apache.arrow.vector.ipc.ReadChannel; +import org.apache.arrow.vector.ipc.WriteChannel; +import org.apache.arrow.vector.ipc.message.ArrowBodyCompression; +import org.apache.arrow.vector.ipc.message.ArrowFieldNode; +import org.apache.arrow.vector.ipc.message.ArrowMessage; +import org.apache.arrow.vector.ipc.message.ArrowRecordBatch; +import 
org.apache.arrow.vector.ipc.message.FBSerializable; +import org.apache.arrow.vector.ipc.message.FBSerializables; +import org.apache.arrow.vector.ipc.message.IpcOption; +import org.apache.arrow.vector.ipc.message.MessageMetadataResult; +import org.apache.arrow.vector.ipc.message.MessageSerializer; + +import com.google.flatbuffers.FlatBufferBuilder; + +/** + * A set of serialization utility methods against {@link org.apache.arrow.vector.ipc.message.ArrowRecordBatch}. + * + *

<p>This utility should only be used across the JNI boundary, because the
+ * record batch being serialized has to stay alive during the whole life cycle
+ * of the native record batch deserialized from it. This design was chosen to
+ * achieve zero-copy transfer of the buffer bodies.
+ */
+public class UnsafeRecordBatchSerializer {
+
+  /**
+   * Callback run in response to the native arrow::Buffer destructor after a
+   * Java side {@link ArrowBuf} has been transferred to native code via JNI. Each
+   * transferred C++ buffer is treated as holding one reference count of the
+   * Java side buffer, so that the Java side memory management can tell when
+   * native code has finished using the buffer. This way the corresponding
+   * allocated memory space can be correctly collected.
+   *
+   * @see UnsafeRecordBatchSerializer#serializeUnsafe(ArrowRecordBatch)
+   */
+  private static class TransferredReferenceCleaner implements Runnable {
+    private static final long NATIVE_METHOD_REF =
+        JniWrapper.get().newJniMethodReference("Ljava/lang/Runnable;", "run", "()V");
+    private final ArrowBuf buf;
+
+    private TransferredReferenceCleaner(ArrowBuf buf) {
+      this.buf = buf;
+    }
+
+    @Override
+    public void run() {
+      buf.getReferenceManager().release();
+    }
+  }
+
+  /**
+   * Deserialize natively serialized bytes into an {@link ArrowRecordBatch} using flatbuffers.
+   * The input byte array must have been written by native code as a
+   * {@link Message}
+   * that carries one native buffer reference per buffer in its custom metadata.
+   *
+   * @param allocator Allocator that the deserialized buffers are associated with
+   * @param bytes flatbuffers byte array
+   * @return the deserialized record batch
+   * @see NativeUnderlyingMemory
+   */
+  public static ArrowRecordBatch deserializeUnsafe(
+      BufferAllocator allocator,
+      byte[] bytes) {
+    final ReadChannel metaIn = new ReadChannel(
+        Channels.newChannel(new ByteArrayInputStream(bytes)));
+
+    final Message metaMessage;
+    try {
+      final MessageMetadataResult result = MessageSerializer.readMessage(metaIn);
+      Preconditions.checkNotNull(result);
+      metaMessage = result.getMessage();
+    } catch (IOException e) {
+      throw new RuntimeException("Failed to deserialize record batch metadata", e);
+    }
+    final RecordBatch batchMeta = (RecordBatch) metaMessage.header(new RecordBatch());
+    Preconditions.checkNotNull(batchMeta);
+    if (batchMeta.buffersLength() != metaMessage.customMetadataLength()) {
+      throw new IllegalArgumentException("Buffer count mismatch between metadata and native managed refs");
+    }
+
+    final ArrayList<ArrowBuf> buffers = new ArrayList<>();
+    // Build the buffers inside the try block so that a failure midway still releases the ones already created.
+    try {
+      for (int i = 0; i < batchMeta.buffersLength(); i++) {
+        final Buffer bufferMeta = batchMeta.buffers(i);
+        final KeyValue keyValue = metaMessage.customMetadata(i); // custom metadata containing native buffer refs
+        final byte[] refDecoded = Base64.getDecoder().decode(keyValue.value());
+        final long nativeBufferRef = ByteBuffer.wrap(refDecoded).order(ByteOrder.LITTLE_ENDIAN).getLong();
+        final int size = LargeMemoryUtil.checkedCastToInt(bufferMeta.length());
+        final NativeUnderlyingMemory am = NativeUnderlyingMemory.create(allocator,
+            size, nativeBufferRef, bufferMeta.offset());
+        BufferLedger ledger = am.associate(allocator);
+        ArrowBuf buf = new ArrowBuf(ledger, null, size, bufferMeta.offset());
+        buffers.add(buf);
+      }
+      final int numRows = LargeMemoryUtil.checkedCastToInt(batchMeta.length());
+      final List<ArrowFieldNode> nodes = new ArrayList<>(batchMeta.nodesLength());
+      for (int i = 0; i < batchMeta.nodesLength(); i++) {
+        final FieldNode node = batchMeta.nodes(i);
+        nodes.add(new ArrowFieldNode(node.length(), node.nullCount()));
+      }
+      return new ArrowRecordBatch(numRows, nodes, buffers);
+    } finally {
+      buffers.forEach(buffer -> buffer.getReferenceManager().release());
+    }
+  }
+
+  /**
+   * Serialize from {@link ArrowRecordBatch} to flatbuffers bytes for native use. A cleaner callback
+   * {@link TransferredReferenceCleaner} will be created for each individual serialized
+   * buffer. The callback should be invoked once the buffer is collected from native code.
+   * We use the callback to decrease reference count of Java side {@link ArrowBuf} here.
+   *
+   * @param batch input record batch
+   * @return serialized bytes
+   * @see TransferredReferenceCleaner
+   */
+  public static byte[] serializeUnsafe(ArrowRecordBatch batch) {
+    final ArrowBodyCompression bodyCompression = batch.getBodyCompression();
+    if (bodyCompression.getCodec() != NoCompressionCodec.COMPRESSION_TYPE) {
+      throw new UnsupportedOperationException("Could not serialize compressed buffers");
+    }
+
+    final FlatBufferBuilder builder = new FlatBufferBuilder();
+    List<ArrowBuf> buffers = batch.getBuffers();
+    int[] metadataOffsets = new int[buffers.size() * 2];
+    for (int i = 0, buffersSize = buffers.size(); i < buffersSize; i++) {
+      ArrowBuf buffer = buffers.get(i);
+      final TransferredReferenceCleaner cleaner = new TransferredReferenceCleaner(buffer);
+      // cleaner object ref
+      long objectRefValue = JniWrapper.get().newJniGlobalReference(cleaner);
+      byte[] objectRefBytes = ByteBuffer.allocate(Long.BYTES).order(ByteOrder.LITTLE_ENDIAN)
+          .putLong(objectRefValue).array();
+      metadataOffsets[i * 2] = KeyValue.createKeyValue(builder, builder.createString("JAVA_BUFFER_CO_REF_" + i),
+          builder.createString(Base64.getEncoder().encodeToString(objectRefBytes)));
+      // cleaner method ref
+      long methodRefValue = TransferredReferenceCleaner.NATIVE_METHOD_REF;
+      byte[] methodRefBytes = ByteBuffer.allocate(Long.BYTES).order(ByteOrder.LITTLE_ENDIAN)
+          .putLong(methodRefValue).array();
+      metadataOffsets[i * 2 + 1] =
+          KeyValue.createKeyValue(builder, builder.createString("JAVA_BUFFER_CM_REF_" + i),
+              builder.createString(Base64.getEncoder().encodeToString(methodRefBytes)));
+    }
+    final ArrowMessage unsafeRecordMessage = new UnsafeRecordBatchMetadataMessage(batch);
+    final int batchOffset = unsafeRecordMessage.writeTo(builder);
+    final int customMetadataOffset = Message.createCustomMetadataVector(builder, metadataOffsets);
+    Message.startMessage(builder);
+    Message.addHeaderType(builder, unsafeRecordMessage.getMessageType());
+    Message.addHeader(builder, batchOffset);
+    Message.addVersion(builder, IpcOption.DEFAULT.metadataVersion.toFlatbufID());
+    Message.addBodyLength(builder, unsafeRecordMessage.computeBodyLength());
+    Message.addCustomMetadata(builder, customMetadataOffset);
+    builder.finish(Message.endMessage(builder));
+    final ByteBuffer metaBuffer = builder.dataBuffer();
+    final ByteArrayOutputStream out = new ByteArrayOutputStream();
+    try {
+      MessageSerializer.writeMessageBuffer(new WriteChannel(Channels.newChannel(out)), metaBuffer.remaining(),
+          metaBuffer, IpcOption.DEFAULT);
+    } catch (IOException e) {
+      throw new RuntimeException("Failed to serialize Java record batch", e);
+    }
+    return out.toByteArray();
+  }
+
+  /**
+   * IPC message for record batches that are based on unsafe shared virtual memory.
+   */
+  public static class UnsafeRecordBatchMetadataMessage implements ArrowMessage {
+    private ArrowRecordBatch delegated;
+
+    public UnsafeRecordBatchMetadataMessage(ArrowRecordBatch delegated) {
+      this.delegated = delegated;
+    }
+
+    @Override
+    public long computeBodyLength() {
+      return 0L;
+    }
+
+    @Override
+    public <T> T accepts(ArrowMessageVisitor<T> visitor) {
+      throw new UnsupportedOperationException();
+    }
+
+    @Override
+    public byte getMessageType() {
+      return MessageHeader.RecordBatch;
+    }
+
+    @Override
+    public int writeTo(FlatBufferBuilder builder) {
+      final List<ArrowFieldNode> nodes = delegated.getNodes();
+      final List<ArrowBuf> buffers = delegated.getBuffers();
+      final ArrowBodyCompression bodyCompression = delegated.getBodyCompression();
+      final int length = delegated.getLength();
+      RecordBatch.startNodesVector(builder, nodes.size());
+      int nodesOffset = FBSerializables.writeAllStructsToVector(builder, nodes);
+      RecordBatch.startBuffersVector(builder, buffers.size());
+      int buffersOffset = FBSerializables.writeAllStructsToVector(builder, buffers.stream()
+          .map(buf -> (FBSerializable) b -> Buffer.createBuffer(b, buf.memoryAddress(),
+              buf.getReferenceManager().getSize()))
+          .collect(Collectors.toList()));
+      int compressOffset = 0;
+      if (bodyCompression.getCodec() != NoCompressionCodec.COMPRESSION_TYPE) {
+        compressOffset = bodyCompression.writeTo(builder);
+      }
+      RecordBatch.startRecordBatch(builder);
+      RecordBatch.addLength(builder, length);
+      RecordBatch.addNodes(builder, nodesOffset);
+      RecordBatch.addBuffers(builder, buffersOffset);
+      if (bodyCompression.getCodec() != NoCompressionCodec.COMPRESSION_TYPE) {
+        RecordBatch.addCompression(builder, compressOffset);
+      }
+      return RecordBatch.endRecordBatch(builder);
+    }
+
+    @Override
+    public void close() throws Exception {
+      delegated.close();
+    }
+  }
+}
\ No newline at end of file