Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

GH-36026: [C++][ORC] Catch all ORC exceptions to avoid crash #40697

Merged
merged 6 commits into from
Mar 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 39 additions & 19 deletions cpp/src/arrow/adapters/orc/adapter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,17 +18,14 @@
#include "arrow/adapters/orc/adapter.h"

#include <algorithm>
#include <cstdint>
#include <cstdlib>
#include <filesystem>
#include <functional>
#include <list>
#include <memory>
#include <sstream>
#include <string>
#include <system_error>
#include <utility>
#include <vector>

#include "arrow/adapters/orc/util.h"
#include "arrow/buffer.h"
#include "arrow/builder.h"
#include "arrow/io/interfaces.h"
#include "arrow/memory_pool.h"
Expand All @@ -37,14 +34,11 @@
#include "arrow/table.h"
#include "arrow/table_builder.h"
#include "arrow/type.h"
#include "arrow/type_traits.h"
#include "arrow/util/bit_util.h"
#include "arrow/util/checked_cast.h"
#include "arrow/util/decimal.h"
#include "arrow/util/key_value_metadata.h"
#include "arrow/util/macros.h"
#include "arrow/util/range.h"
#include "arrow/util/visibility.h"
#include "orc/Exceptions.hh"

// alias to not interfere with nested orc namespace
Expand Down Expand Up @@ -80,6 +74,12 @@ namespace liborc = orc;
} \
catch (const liborc::NotImplementedYet& e) { \
return Status::NotImplemented(e.what()); \
} \
catch (const std::exception& e) { \
wgtmac marked this conversation as resolved.
Show resolved Hide resolved
return Status::UnknownError(e.what()); \
} \
catch (...) { \
return Status::UnknownError("ORC error"); \
Comment on lines +81 to +82
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need this? (Is catch (const std::exception& e) enough?)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Usually it is enough. But just in case?

}

#define ORC_CATCH_NOT_OK(_s) \
Expand Down Expand Up @@ -173,7 +173,7 @@ class OrcStripeReader : public RecordBatchReader {
int64_t batch_size_;
};

liborc::RowReaderOptions default_row_reader_options() {
liborc::RowReaderOptions DefaultRowReaderOptions() {
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This function was added by me in the past and now I have changed its name to follow the same style.

liborc::RowReaderOptions options;
// Orc timestamp type is error-prone since it serializes values in the writer timezone
// and reads them back in the reader timezone. To avoid this, both the Apache Orc C++
Expand All @@ -183,6 +183,24 @@ liborc::RowReaderOptions default_row_reader_options() {
return options;
}

// Proactively check timezone database availability for ORC versions older than 2.0.0.
//
// When the TZDIR environment variable is set, the directory it names is probed;
// otherwise /usr/share/zoneinfo is probed.
//
// Returns Status::OK() when GetOrcMajorVersion() >= 2 or the probed path exists,
// and Status::Invalid when it does not exist or cannot be inspected.
Status CheckTimeZoneDatabaseAvailability() {
  if (GetOrcMajorVersion() >= 2) {
    return Status::OK();
  }
  const char* tz_dir = std::getenv("TZDIR");
  const std::filesystem::path tzdb_path =
      tz_dir != nullptr ? tz_dir : "/usr/share/zoneinfo";
  // Use the non-throwing overload: a filesystem failure (e.g. permission denied
  // on a parent directory) must be reported as a Status instead of escaping as
  // std::filesystem::filesystem_error. On error, exists() returns false.
  std::error_code ec;
  const bool is_tzdb_available = std::filesystem::exists(tzdb_path, ec);
  if (!is_tzdb_available) {
    return Status::Invalid(
        "IANA time zone database is unavailable but required by ORC."
        " Please install it to /usr/share/zoneinfo or set TZDIR env to the installed"
        " directory");
  }
  return Status::OK();
}

} // namespace

class ORCFileReader::Impl {
Expand Down Expand Up @@ -332,47 +350,47 @@ class ORCFileReader::Impl {
}

// Reads a Table using the default row-reader options and the schema returned by
// ReadSchema().
Result<std::shared_ptr<Table>> Read() {
// Diff view: the first `opts` line is the pre-rename call, the second is its
// renamed replacement.
liborc::RowReaderOptions opts = default_row_reader_options();
liborc::RowReaderOptions opts = DefaultRowReaderOptions();
ARROW_ASSIGN_OR_RAISE(auto schema, ReadSchema());
return ReadTable(opts, schema);
}

// Reads a Table using the default row-reader options and the caller-supplied
// `schema`, which is passed to ReadTable unchanged.
Result<std::shared_ptr<Table>> Read(const std::shared_ptr<Schema>& schema) {
// Diff view: the first `opts` line is the pre-rename call, the second is its
// renamed replacement.
liborc::RowReaderOptions opts = default_row_reader_options();
liborc::RowReaderOptions opts = DefaultRowReaderOptions();
return ReadTable(opts, schema);
}

// Reads a Table after applying `include_indices` to the row-reader options via
// SelectIndices. The schema is obtained from ReadSchema(opts), i.e. after the
// selection has been applied to `opts`.
Result<std::shared_ptr<Table>> Read(const std::vector<int>& include_indices) {
// Diff view: the first `opts` line is the pre-rename call, the second is its
// renamed replacement.
liborc::RowReaderOptions opts = default_row_reader_options();
liborc::RowReaderOptions opts = DefaultRowReaderOptions();
RETURN_NOT_OK(SelectIndices(&opts, include_indices));
ARROW_ASSIGN_OR_RAISE(auto schema, ReadSchema(opts));
return ReadTable(opts, schema);
}

// Reads a Table after applying `include_names` to the row-reader options via
// SelectNames. The schema is obtained from ReadSchema(opts), i.e. after the
// selection has been applied to `opts`.
Result<std::shared_ptr<Table>> Read(const std::vector<std::string>& include_names) {
// Diff view: the first `opts` line is the pre-rename call, the second is its
// renamed replacement.
liborc::RowReaderOptions opts = default_row_reader_options();
liborc::RowReaderOptions opts = DefaultRowReaderOptions();
RETURN_NOT_OK(SelectNames(&opts, include_names));
ARROW_ASSIGN_OR_RAISE(auto schema, ReadSchema(opts));
return ReadTable(opts, schema);
}

// Reads a Table after applying `include_indices` to the row-reader options via
// SelectIndices, using the caller-supplied `schema` rather than one derived
// from the options.
Result<std::shared_ptr<Table>> Read(const std::shared_ptr<Schema>& schema,
const std::vector<int>& include_indices) {
// Diff view: the first `opts` line is the pre-rename call, the second is its
// renamed replacement.
liborc::RowReaderOptions opts = default_row_reader_options();
liborc::RowReaderOptions opts = DefaultRowReaderOptions();
RETURN_NOT_OK(SelectIndices(&opts, include_indices));
return ReadTable(opts, schema);
}

// Reads stripe number `stripe` as a single RecordBatch: SelectStripe applies the
// stripe to the row-reader options, the schema is obtained from ReadSchema(opts),
// and ReadBatch is called with that stripe's `num_rows` taken from `stripes_`.
Result<std::shared_ptr<RecordBatch>> ReadStripe(int64_t stripe) {
// Diff view: the first `opts` line is the pre-rename call, the second is its
// renamed replacement.
liborc::RowReaderOptions opts = default_row_reader_options();
liborc::RowReaderOptions opts = DefaultRowReaderOptions();
RETURN_NOT_OK(SelectStripe(&opts, stripe));
ARROW_ASSIGN_OR_RAISE(auto schema, ReadSchema(opts));
// NOTE(review): `stripe` indexes `stripes_` directly here; presumably SelectStripe
// has already rejected out-of-range values — confirm.
return ReadBatch(opts, schema, stripes_[static_cast<size_t>(stripe)].num_rows);
}

Result<std::shared_ptr<RecordBatch>> ReadStripe(
int64_t stripe, const std::vector<int>& include_indices) {
liborc::RowReaderOptions opts = default_row_reader_options();
liborc::RowReaderOptions opts = DefaultRowReaderOptions();
RETURN_NOT_OK(SelectIndices(&opts, include_indices));
RETURN_NOT_OK(SelectStripe(&opts, stripe));
ARROW_ASSIGN_OR_RAISE(auto schema, ReadSchema(opts));
Expand All @@ -381,7 +399,7 @@ class ORCFileReader::Impl {

Result<std::shared_ptr<RecordBatch>> ReadStripe(
int64_t stripe, const std::vector<std::string>& include_names) {
liborc::RowReaderOptions opts = default_row_reader_options();
liborc::RowReaderOptions opts = DefaultRowReaderOptions();
RETURN_NOT_OK(SelectNames(&opts, include_names));
RETURN_NOT_OK(SelectStripe(&opts, stripe));
ARROW_ASSIGN_OR_RAISE(auto schema, ReadSchema(opts));
Expand Down Expand Up @@ -487,7 +505,7 @@ class ORCFileReader::Impl {
return nullptr;
}

liborc::RowReaderOptions opts = default_row_reader_options();
liborc::RowReaderOptions opts = DefaultRowReaderOptions();
if (!include_indices.empty()) {
RETURN_NOT_OK(SelectIndices(&opts, include_indices));
}
Expand All @@ -508,7 +526,7 @@ class ORCFileReader::Impl {

Result<std::shared_ptr<RecordBatchReader>> GetRecordBatchReader(
int64_t batch_size, const std::vector<std::string>& include_names) {
liborc::RowReaderOptions opts = default_row_reader_options();
liborc::RowReaderOptions opts = DefaultRowReaderOptions();
if (!include_names.empty()) {
RETURN_NOT_OK(SelectNames(&opts, include_names));
}
Expand Down Expand Up @@ -541,6 +559,7 @@ ORCFileReader::~ORCFileReader() {}

Result<std::unique_ptr<ORCFileReader>> ORCFileReader::Open(
const std::shared_ptr<io::RandomAccessFile>& file, MemoryPool* pool) {
RETURN_NOT_OK(CheckTimeZoneDatabaseAvailability());
auto result = std::unique_ptr<ORCFileReader>(new ORCFileReader());
RETURN_NOT_OK(result->impl_->Open(file, pool));
return std::move(result);
Expand Down Expand Up @@ -779,7 +798,7 @@ class ORCFileWriter::Impl {
&(arrow_index_offset[i]), (root->fields)[i]));
}
root->numElements = (root->fields)[0]->numElements;
writer_->add(*batch);
ORC_CATCH_NOT_OK(writer_->add(*batch));
batch->clear();
num_rows -= batch_size;
}
Expand Down Expand Up @@ -807,6 +826,7 @@ ORCFileWriter::ORCFileWriter() { impl_.reset(new ORCFileWriter::Impl()); }

Result<std::unique_ptr<ORCFileWriter>> ORCFileWriter::Open(
io::OutputStream* output_stream, const WriteOptions& writer_options) {
RETURN_NOT_OK(CheckTimeZoneDatabaseAvailability());
std::unique_ptr<ORCFileWriter> result =
std::unique_ptr<ORCFileWriter>(new ORCFileWriter());
Status status = result->impl_->Open(output_stream, writer_options);
Expand Down
19 changes: 19 additions & 0 deletions cpp/src/arrow/adapters/orc/adapter_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,10 @@
#include "arrow/status.h"
#include "arrow/table.h"
#include "arrow/testing/gtest_util.h"
#include "arrow/testing/matchers.h"
#include "arrow/testing/random.h"
#include "arrow/type.h"
#include "arrow/util/io_util.h"
#include "arrow/util/key_value_metadata.h"

namespace liborc = orc;
Expand Down Expand Up @@ -636,6 +638,23 @@ TEST(TestAdapterReadWrite, FieldAttributesRoundTrip) {
AssertSchemaEqual(schema, read_schema, /*check_metadata=*/true);
}

// Checks that, for ORC major versions below 2, opening an ORC writer and an ORC
// reader both report StatusCode::Invalid with the time-zone-database message
// when TZDIR is set to "/wrong/path".
TEST(TestAdapterReadWrite, ThrowWhenTZDBUnavaiable) {
  const int orc_major_version = adapters::orc::GetOrcMajorVersion();
  if (orc_major_version >= 2) {
    GTEST_SKIP() << "Only ORC pre-2.0.0 versions have the time zone database check";
  }

  // Point TZDIR at a path that is expected not to exist.
  EnvVarGuard bogus_tzdir("TZDIR", "/wrong/path");
  const char* expected_message =
      "IANA time zone database is unavailable but required by ORC";

  // Writer side.
  EXPECT_OK_AND_ASSIGN(auto sink, io::BufferOutputStream::Create(1024));
  EXPECT_THAT(
      adapters::orc::ORCFileWriter::Open(sink.get(), adapters::orc::WriteOptions()),
      Raises(StatusCode::Invalid, testing::HasSubstr(expected_message)));

  // Reader side, fed with whatever the output stream holds once finished.
  EXPECT_OK_AND_ASSIGN(auto finished_buffer, sink->Finish());
  EXPECT_THAT(
      adapters::orc::ORCFileReader::Open(
          std::make_shared<io::BufferReader>(finished_buffer), default_memory_pool()),
      Raises(StatusCode::Invalid, testing::HasSubstr(expected_message)));
}

// Trivial

class TestORCWriterTrivialNoWrite : public ::testing::Test {};
Expand Down
8 changes: 8 additions & 0 deletions cpp/src/arrow/adapters/orc/util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@

#include "orc/MemoryPool.hh"
#include "orc/OrcFile.hh"
#include "orc/orc-config.hh"

// alias to not interfere with nested orc namespace
namespace liborc = orc;
Expand Down Expand Up @@ -1220,6 +1221,13 @@ Result<std::shared_ptr<Field>> GetArrowField(const std::string& name,
return field(name, std::move(arrow_type), nullable, std::move(metadata));
}

// Returns the major component of the ORC_VERSION macro: the integer parsed from
// the text that precedes the first '.' (or from the whole string if it has no
// '.'). std::stoi throws if that text does not start with a valid integer.
int GetOrcMajorVersion() {
  const std::string full_version(ORC_VERSION);
  const std::string major_component =
      full_version.substr(0, full_version.find('.'));
  return std::stoi(major_component);
}

} // namespace orc
} // namespace adapters
} // namespace arrow
3 changes: 3 additions & 0 deletions cpp/src/arrow/adapters/orc/util.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,9 @@ ARROW_EXPORT Status WriteBatch(const ChunkedArray& chunked_array, int64_t length
int* arrow_chunk_offset, int64_t* arrow_index_offset,
liborc::ColumnVectorBatch* column_vector_batch);

/// \brief Get the major version provided by the official ORC C++ library.
ARROW_EXPORT int GetOrcMajorVersion();

} // namespace orc
} // namespace adapters
} // namespace arrow
Loading