Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
103 changes: 89 additions & 14 deletions src/processing/encryption_sequencer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@
#include <cassert>
#include <cstring>
#include <memory>
#include <chrono>
#include <cstdlib>

using namespace dbps::external;
using namespace dbps::enum_utils;
Expand All @@ -42,6 +44,11 @@ namespace {
constexpr const char* ENCRYPTION_MODE_KEY_DATA_PAGE = "encrypt_mode_data_page";
constexpr const char* ENCRYPTION_MODE_PER_BLOCK = "per_block";
constexpr const char* ENCRYPTION_MODE_PER_VALUE = "per_value";

bool ShouldLogStepTimings() {
const char* env = std::getenv("DBPS_LOG_ENCRYPT_TIMING");
return env == nullptr || std::string(env) == "1";
}
}

// Helper function to create encryptor instance
Expand Down Expand Up @@ -126,6 +133,24 @@ bool DataBatchEncryptionSequencer::DecodeAndEncrypt(const std::vector<uint8_t>&

auto encryption_mode_key = GetEncryptionModeKey();

// ++++++ FORCED PER-BLOCK ENCRYPTION ++++++ vvvvvv
if (false) {
std::cout << "+++++ FORCED PER-BLOCK ENCRYPTION +++++" << " datatype_: " << to_string(datatype_) << std::endl;
std::cout << "+++++ FORCED PER-BLOCK ENCRYPTION +++++" << " compression_: " << to_string(compression_) << std::endl;
std::cout << "+++++ FORCED PER-BLOCK ENCRYPTION +++++" << " encoding_: " << to_string(encoding_) << std::endl;
std::cout << "+++++ FORCED PER-BLOCK ENCRYPTION +++++" << " page_type_: " << std::get<std::string>(encoding_attributes_converted_.at("page_type")) << std::endl;
encrypted_result_ = encryptor_->EncryptBlock(plaintext);
if (encrypted_result_.empty()) {
error_stage_ = "encryption";
error_message_ = "Failed to encrypt data";
return false;
}
encryption_metadata_[encryption_mode_key] = ENCRYPTION_MODE_PER_BLOCK;
encryption_metadata_[DBPS_VERSION_KEY] = DBPS_VERSION;
return true;
}
// ++++++ FORCED PER-BLOCK ENCRYPTION ++++++ ^^^^^^

/*
* Note on try-catch block:
* - When fully done, DecodeAndEncrypt will support per-value encryption for all cases, except for
Expand All @@ -137,24 +162,74 @@ bool DataBatchEncryptionSequencer::DecodeAndEncrypt(const std::vector<uint8_t>&
* - Once per-value encryption for all cases is complete, the try-catch block and the call to EncryptBlock must be removed.
*/
try {
// Decompress and split plaintext into level and value bytes
auto [level_bytes, value_bytes] = DecompressAndSplit(
plaintext, compression_, encoding_attributes_converted_);

// Parse value bytes into typed list
auto typed_list = ParseValueBytesIntoTypedList(value_bytes, datatype_, datatype_length_, encoding_);

// Encrypt the typed list and level bytes, then join them into a single encrypted byte vector.
auto encrypted_value_bytes = encryptor_->EncryptValueList(typed_list);
auto encrypted_level_bytes = encryptor_->EncryptBlock(level_bytes);
auto joined_encrypted_bytes = JoinWithLengthPrefix(encrypted_level_bytes, encrypted_value_bytes);

// Compress the joined encrypted bytes
encrypted_result_ = Compress(joined_encrypted_bytes, encrypted_compression_);
const bool log_timings = ShouldLogStepTimings();
using Clock = std::chrono::steady_clock;
std::vector<std::pair<std::string, long long>> timings;

std::vector<uint8_t> level_bytes;
std::vector<uint8_t> value_bytes;
TypedListValues typed_list;
std::vector<uint8_t> encrypted_value_bytes;
std::vector<uint8_t> encrypted_level_bytes;
std::vector<uint8_t> joined_encrypted_bytes;

auto time_step = [&](const char* label, const std::function<void()>& fn) {
if (!log_timings) {
fn();
return;
}
auto start = Clock::now();
fn();
auto end = Clock::now();
auto micros = std::chrono::duration_cast<std::chrono::microseconds>(end - start).count();
timings.emplace_back(label, micros);
};

time_step("DecompressAndSplit", [&]() {
auto split = DecompressAndSplit(plaintext, compression_, encoding_attributes_converted_);
level_bytes = std::move(split.level_bytes);
value_bytes = std::move(split.value_bytes);
});

time_step("ParseValueBytesIntoTypedList", [&]() {
typed_list = ParseValueBytesIntoTypedList(value_bytes, datatype_, datatype_length_, encoding_);
});

time_step("EncryptValueList", [&]() {
encrypted_value_bytes = encryptor_->EncryptValueList(typed_list);
});

time_step("EncryptBlock(level_bytes)", [&]() {
encrypted_level_bytes = encryptor_->EncryptBlock(level_bytes);
});

time_step("JoinWithLengthPrefix", [&]() {
joined_encrypted_bytes = JoinWithLengthPrefix(encrypted_level_bytes, encrypted_value_bytes);
});

time_step("Compress", [&]() {
encrypted_result_ = Compress(joined_encrypted_bytes, encrypted_compression_);
});

// Set the encryption type to per-value
encryption_metadata_[encryption_mode_key] = ENCRYPTION_MODE_PER_VALUE;
encryption_metadata_[DBPS_VERSION_KEY] = DBPS_VERSION;

std::cout << "+++++ PER-VALUE ENCRYPTION +++++" << " datatype_: " << to_string(datatype_) << std::endl;
std::cout << "+++++ PER-VALUE ENCRYPTION +++++" << " compression_: " << to_string(compression_) << std::endl;
std::cout << "+++++ PER-VALUE ENCRYPTION +++++" << " encoding_: " << to_string(encoding_) << std::endl;
std::cout << "+++++ PER-VALUE ENCRYPTION +++++" << " page_type_: " << std::get<std::string>(encoding_attributes_converted_.at("page_type")) << std::endl;
std::cout << "+++++ PER-VALUE ENCRYPTION +++++" << " encrypted_compression_: " << to_string(encrypted_compression_) << std::endl;
const auto typed_list_size = std::visit([](const auto& values) { return values.size(); }, typed_list);
std::cout << "+++++ PER-VALUE ENCRYPTION +++++" << " typed_list size: " << typed_list_size << std::endl;

if (log_timings) {
std::cout << "+++++ DecodeAndEncrypt timings (microseconds) +++++\n";
for (const auto& entry : timings) {
std::cout << " " << entry.first << ": " << entry.second << "\n";
}
}

return true;
}
// Allow fallback to per-block encryption, only for explicitly unsupported conditions. See note above.
Expand Down
83 changes: 63 additions & 20 deletions src/processing/encryptors/basic_encryptor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
#include "../../common/enum_utils.h"
#include <functional>
#include <iostream>
#include <cstdlib>
#include <chrono>
#include "../value_encryption_utils.h"

using namespace dbps::value_encryption_utils;
Expand Down Expand Up @@ -52,6 +54,17 @@ namespace {
std::vector<uint8_t> DecryptByteArray(const std::vector<uint8_t>& data, const std::string& key_id) {
return EncryptByteArray(data, key_id); // for XOR encryption, decryption is the same as encryption
}

bool ShouldLogValueEncryption() {
return false;
// const char* env = std::getenv("DBPS_LOG_VALUE_ENCRYPTION");
// return env != nullptr && std::string(env) == "1";
}

bool ShouldLogValueEncryptionTiming() {
const char* env = std::getenv("DBPS_LOG_VALUE_ENCRYPT_TIMING");
return env == nullptr || std::string(env) == "1";
}
}

std::vector<uint8_t> BasicEncryptor::EncryptBlock(const std::vector<uint8_t>& data) {
Expand All @@ -66,23 +79,40 @@ std::vector<uint8_t> BasicEncryptor::DecryptBlock(const std::vector<uint8_t>& da
std::vector<uint8_t> BasicEncryptor::EncryptValueList(
const TypedListValues& typed_list) {

// Printout the typed list.
auto print_result = TypedListToString(typed_list);
if (print_result.length() > 1000) {
std::cout << "Encrypt value - Decoded plaintext data (first 1000 chars):\n"
<< print_result.substr(0, 1000) << "...";
} else {
std::cout << "Encrypt value - Decoded plaintext data:\n" << print_result;
}
const bool log_timings = ShouldLogValueEncryptionTiming();
using Clock = std::chrono::steady_clock;
std::vector<std::pair<std::string, long long>> timings;

// Printout the additional context parameters.
std::cout << "Context parameters:\n"
<< " column_name: " << column_name_ << "\n"
<< " user_id: " << user_id_ << "\n"
<< " key_id: " << key_id_ << "\n"
<< " application_context: " << application_context_ << "\n"
<< " datatype: " << dbps::enum_utils::to_string(datatype_) << "\n"
<< std::endl;
auto time_step = [&](const char* label, const std::function<void()>& fn) {
if (!log_timings) {
fn();
return;
}
auto start = Clock::now();
fn();
auto end = Clock::now();
auto micros = std::chrono::duration_cast<std::chrono::microseconds>(end - start).count();
timings.emplace_back(label, micros);
};

if (ShouldLogValueEncryption()) {
// Printout the typed list.
auto print_result = TypedListToString(typed_list);
if (print_result.length() > 1000) {
std::cout << "Encrypt value - Decoded plaintext data (first 1000 chars):\n"
<< print_result.substr(0, 1000) << "...\n";
} else {
std::cout << "Encrypt value - Decoded plaintext data:\n" << print_result << "\n";
}

// Printout the additional context parameters.
std::cout << "Context parameters:\n"
<< " column_name: " << column_name_ << "\n"
<< " user_id: " << user_id_ << "\n"
<< " key_id: " << key_id_ << "\n"
<< " application_context: " << application_context_ << "\n"
<< " datatype: " << dbps::enum_utils::to_string(datatype_) << "\n";
}

// create a closure for the encrypt function (to be used below)
// the closure captures the key_bytes and calls the EncryptByteArray function.
Expand All @@ -96,13 +126,26 @@ std::vector<uint8_t> BasicEncryptor::EncryptValueList(
// (1) encrypt the list of values. Each element in the list is encrypted separately
// using the key and the EncryptByteArray function.

std::vector<EncryptedValue> encrypted_values = EncryptTypedListValues(
typed_list,
encrypt_function);
std::vector<EncryptedValue> encrypted_values;
time_step("EncryptTypedListValues", [&]() {
encrypted_values = EncryptTypedListValues(
typed_list,
encrypt_function);
});

// (2) concatenate the encrypted values into a single byte blob.
// (the blob encodes #of elements and the size of each element)
std::vector<uint8_t> concatenated_encrypted_bytes = ConcatenateEncryptedValues(encrypted_values);
std::vector<uint8_t> concatenated_encrypted_bytes;
time_step("ConcatenateEncryptedValues", [&]() {
concatenated_encrypted_bytes = ConcatenateEncryptedValues(encrypted_values);
});

if (log_timings) {
std::cout << "EncryptValueList timings (microseconds):\n";
for (const auto& entry : timings) {
std::cout << " " << entry.first << ": " << entry.second << "\n";
}
}

return concatenated_encrypted_bytes;
} // EncryptValueList
Expand Down
45 changes: 42 additions & 3 deletions src/processing/parquet_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,20 @@
#include "compression_utils.h"
#include <cstring>
#include <iostream>
#include <chrono>
#include <cstdlib>
#include <functional>

using namespace dbps::external;
using namespace dbps::enum_utils;
using namespace dbps::compression;

namespace {
bool ShouldLogParseValueTiming() {
const char* env = std::getenv("DBPS_LOG_PARSE_VALUE_TIMING");
return env == nullptr || std::string(env) == "1";
}
}
int CalculateLevelBytesLength(const std::vector<uint8_t>& raw,
const AttributesMap& encoding_attribs) {

Expand Down Expand Up @@ -353,9 +362,39 @@ TypedListValues ParseValueBytesIntoTypedList(
Type::type datatype,
const std::optional<int>& datatype_length,
Encoding::type encoding) {
std::vector<RawValueBytes> raw_values =
SliceValueBytesIntoRawBytes(bytes, datatype, datatype_length, encoding);
return BuildTypedListFromRawBytes(datatype, raw_values);
const bool log_timings = ShouldLogParseValueTiming();
using Clock = std::chrono::steady_clock;
std::vector<std::pair<std::string, long long>> timings;

auto time_step = [&](const char* label, const std::function<void()>& fn) {
if (!log_timings) {
fn();
return;
}
auto start = Clock::now();
fn();
auto end = Clock::now();
auto micros = std::chrono::duration_cast<std::chrono::microseconds>(end - start).count();
timings.emplace_back(label, micros);
};

std::vector<RawValueBytes> raw_values;
time_step("SliceValueBytesIntoRawBytes", [&]() {
raw_values = SliceValueBytesIntoRawBytes(bytes, datatype, datatype_length, encoding);
});

TypedListValues typed_list;
time_step("BuildTypedListFromRawBytes", [&]() {
typed_list = BuildTypedListFromRawBytes(datatype, raw_values);
});

if (log_timings) {
std::cout << "ParseValueBytesIntoTypedList timings (microseconds):\n";
for (const auto& entry : timings) {
std::cout << " " << entry.first << ": " << entry.second << "\n";
}
}
return typed_list;
}

std::vector<uint8_t> GetTypedListAsValueBytes(
Expand Down