Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Broadcast Nested Loop Join - Left Anti and Left Semi #159

Merged
merged 2 commits into from
Feb 24, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 44 additions & 0 deletions src/enclave/App/App.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -555,6 +555,50 @@ Java_edu_berkeley_cs_rise_opaque_execution_SGXEnclave_NonObliviousSortMergeJoin(
return ret;
}

JNIEXPORT jbyteArray JNICALL
Java_edu_berkeley_cs_rise_opaque_execution_SGXEnclave_BroadcastNestedLoopJoin(
JNIEnv *env, jobject obj, jlong eid, jbyteArray join_expr, jbyteArray outer_rows, jbyteArray inner_rows) {
(void)obj;

jboolean if_copy;

uint32_t join_expr_length = (uint32_t) env->GetArrayLength(join_expr);
uint8_t *join_expr_ptr = (uint8_t *) env->GetByteArrayElements(join_expr, &if_copy);

uint32_t outer_rows_length = (uint32_t) env->GetArrayLength(outer_rows);
uint8_t *outer_rows_ptr = (uint8_t *) env->GetByteArrayElements(outer_rows, &if_copy);

uint32_t inner_rows_length = (uint32_t) env->GetArrayLength(inner_rows);
uint8_t *inner_rows_ptr = (uint8_t *) env->GetByteArrayElements(inner_rows, &if_copy);

uint8_t *output_rows = nullptr;
size_t output_rows_length = 0;

if (outer_rows_ptr == nullptr) {
ocall_throw("BroadcastNestedLoopJoin: JNI failed to get inner byte array.");
} else if (inner_rows_ptr == nullptr) {
ocall_throw("BroadcastNestedLoopJoin: JNI failed to get outer byte array.");
} else {
oe_check_and_time("Broadcast Nested Loop Join",
ecall_broadcast_nested_loop_join(
(oe_enclave_t*)eid,
join_expr_ptr, join_expr_length,
outer_rows_ptr, outer_rows_length,
inner_rows_ptr, inner_rows_length,
&output_rows, &output_rows_length));
}

jbyteArray ret = env->NewByteArray(output_rows_length);
env->SetByteArrayRegion(ret, 0, output_rows_length, (jbyte *) output_rows);
free(output_rows);

env->ReleaseByteArrayElements(join_expr, (jbyte *) join_expr_ptr, 0);
env->ReleaseByteArrayElements(outer_rows, (jbyte *) outer_rows_ptr, 0);
env->ReleaseByteArrayElements(inner_rows, (jbyte *) inner_rows_ptr, 0);

return ret;
}

JNIEXPORT jobject JNICALL
Java_edu_berkeley_cs_rise_opaque_execution_SGXEnclave_NonObliviousAggregate(
JNIEnv *env, jobject obj, jlong eid, jbyteArray agg_op, jbyteArray input_rows, jboolean isPartial) {
Expand Down
4 changes: 4 additions & 0 deletions src/enclave/App/SGXEnclave.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,10 @@ extern "C" {
Java_edu_berkeley_cs_rise_opaque_execution_SGXEnclave_NonObliviousSortMergeJoin(
JNIEnv *, jobject, jlong, jbyteArray, jbyteArray);

JNIEXPORT jbyteArray JNICALL
Java_edu_berkeley_cs_rise_opaque_execution_SGXEnclave_BroadcastNestedLoopJoin(
JNIEnv *, jobject, jlong, jbyteArray, jbyteArray, jbyteArray);

JNIEXPORT jobject JNICALL
Java_edu_berkeley_cs_rise_opaque_execution_SGXEnclave_NonObliviousAggregate(
JNIEnv *, jobject, jlong, jbyteArray, jbyteArray, jboolean);
Expand Down
54 changes: 54 additions & 0 deletions src/enclave/Enclave/BroadcastNestedLoopJoin.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
#include "BroadcastNestedLoopJoin.h"

#include "ExpressionEvaluation.h"
#include "FlatbuffersReaders.h"
#include "FlatbuffersWriters.h"
#include "common.h"

/** C++ implementation of a broadcast nested loop join.
* Assumes outer_rows is streamed and inner_rows is broadcast.
* DOES NOT rely on rows to be tagged primary or secondary, and that
* assumption will break the implementation.
*/
void broadcast_nested_loop_join(
uint8_t *join_expr, size_t join_expr_length,
uint8_t *outer_rows, size_t outer_rows_length,
uint8_t *inner_rows, size_t inner_rows_length,
uint8_t **output_rows, size_t *output_rows_length) {

FlatbuffersJoinExprEvaluator join_expr_eval(join_expr, join_expr_length);
const tuix::JoinType join_type = join_expr_eval.get_join_type();

RowReader outer_r(BufferRefView<tuix::EncryptedBlocks>(outer_rows, outer_rows_length));
RowWriter w;

while (outer_r.has_next()) {
const tuix::Row *outer = outer_r.next();
bool o_i_match = false;

RowReader inner_r(BufferRefView<tuix::EncryptedBlocks>(inner_rows, inner_rows_length));
const tuix::Row *inner;
while (inner_r.has_next()) {
inner = inner_r.next();
o_i_match |= join_expr_eval.eval_condition(outer, inner);
}

switch(join_type) {
case tuix::JoinType_LeftAnti:
wzheng marked this conversation as resolved.
Show resolved Hide resolved
if (!o_i_match) {
w.append(outer);
}
break;
case tuix::JoinType_LeftSemi:
if (o_i_match) {
w.append(outer);
}
break;
default:
throw std::runtime_error(
std::string("Join type not supported: ")
+ std::string(to_string(join_type)));
}
}
w.output_buffer(output_rows, output_rows_length);
}
8 changes: 8 additions & 0 deletions src/enclave/Enclave/BroadcastNestedLoopJoin.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
#include <cstddef>
#include <cstdint>

void broadcast_nested_loop_join(
uint8_t *join_expr, size_t join_expr_length,
uint8_t *outer_rows, size_t outer_rows_length,
uint8_t *inner_rows, size_t inner_rows_length,
uint8_t **output_rows, size_t *output_rows_length);
3 changes: 2 additions & 1 deletion src/enclave/Enclave/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@ set(SOURCES
Flatbuffers.cpp
FlatbuffersReaders.cpp
FlatbuffersWriters.cpp
Join.cpp
NonObliviousSortMergeJoin.cpp
BroadcastNestedLoopJoin.cpp
Limit.cpp
Project.cpp
Sort.cpp
Expand Down
22 changes: 21 additions & 1 deletion src/enclave/Enclave/Enclave.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
#include "Aggregate.h"
#include "Crypto.h"
#include "Filter.h"
#include "Join.h"
#include "NonObliviousSortMergeJoin.h"
#include "BroadcastNestedLoopJoin.h"
#include "Limit.h"
#include "Project.h"
#include "Sort.h"
Expand Down Expand Up @@ -161,6 +162,25 @@ void ecall_non_oblivious_sort_merge_join(uint8_t *join_expr, size_t join_expr_le
}
}

void ecall_broadcast_nested_loop_join(uint8_t *join_expr, size_t join_expr_length,
uint8_t *outer_rows, size_t outer_rows_length,
uint8_t *inner_rows, size_t inner_rows_length,
uint8_t **output_rows, size_t *output_rows_length) {
// Guard against operating on arbitrary enclave memory
assert(oe_is_outside_enclave(outer_rows, outer_rows_length) == 1);
assert(oe_is_outside_enclave(inner_rows, inner_rows_length) == 1);
__builtin_ia32_lfence();

try {
broadcast_nested_loop_join(join_expr, join_expr_length,
outer_rows, outer_rows_length,
inner_rows, inner_rows_length,
output_rows, output_rows_length);
} catch (const std::runtime_error &e) {
ocall_throw(e.what());
}
}

void ecall_non_oblivious_aggregate(
uint8_t *agg_op, size_t agg_op_length,
uint8_t *input_rows, size_t input_rows_length,
Expand Down
6 changes: 6 additions & 0 deletions src/enclave/Enclave/Enclave.edl
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,12 @@ enclave {
[user_check] uint8_t *input_rows, size_t input_rows_length,
[out] uint8_t **output_rows, [out] size_t *output_rows_length);

public void ecall_broadcast_nested_loop_join(
[in, count=join_expr_length] uint8_t *join_expr, size_t join_expr_length,
[user_check] uint8_t *outer_rows, size_t outer_rows_length,
[user_check] uint8_t *inner_rows, size_t inner_rows_length,
[out] uint8_t **output_rows, [out] size_t *output_rows_length);

public void ecall_non_oblivious_aggregate(
[in, count=agg_op_length] uint8_t *agg_op, size_t agg_op_length,
[user_check] uint8_t *input_rows, size_t input_rows_length,
Expand Down
100 changes: 73 additions & 27 deletions src/enclave/Enclave/ExpressionEvaluation.h
Original file line number Diff line number Diff line change
Expand Up @@ -1787,60 +1787,104 @@ class FlatbuffersJoinExprEvaluator {
}

const tuix::JoinExpr* join_expr = flatbuffers::GetRoot<tuix::JoinExpr>(buf);
join_type = join_expr->join_type();

if (join_expr->left_keys()->size() != join_expr->right_keys()->size()) {
throw std::runtime_error("Mismatched join key lengths");
}
for (auto key_it = join_expr->left_keys()->begin();
key_it != join_expr->left_keys()->end(); ++key_it) {
left_key_evaluators.emplace_back(
std::unique_ptr<FlatbuffersExpressionEvaluator>(
new FlatbuffersExpressionEvaluator(*key_it)));
join_type = join_expr->join_type();
if (join_expr->condition() != NULL) {
condition_eval = std::unique_ptr<FlatbuffersExpressionEvaluator>(
new FlatbuffersExpressionEvaluator(join_expr->condition()));
}
for (auto key_it = join_expr->right_keys()->begin();
key_it != join_expr->right_keys()->end(); ++key_it) {
right_key_evaluators.emplace_back(
std::unique_ptr<FlatbuffersExpressionEvaluator>(
new FlatbuffersExpressionEvaluator(*key_it)));
is_equi_join = false;

if (join_expr->left_keys() != NULL && join_expr->right_keys() != NULL) {
is_equi_join = true;
if (join_expr->condition() != NULL) {
throw std::runtime_error("Equi join cannot have condition");
}
if (join_expr->left_keys()->size() != join_expr->right_keys()->size()) {
throw std::runtime_error("Mismatched join key lengths");
}
for (auto key_it = join_expr->left_keys()->begin();
key_it != join_expr->left_keys()->end(); ++key_it) {
left_key_evaluators.emplace_back(
std::unique_ptr<FlatbuffersExpressionEvaluator>(
new FlatbuffersExpressionEvaluator(*key_it)));
}
for (auto key_it = join_expr->right_keys()->begin();
key_it != join_expr->right_keys()->end(); ++key_it) {
right_key_evaluators.emplace_back(
std::unique_ptr<FlatbuffersExpressionEvaluator>(
new FlatbuffersExpressionEvaluator(*key_it)));
}
}
}

/**
* Return true if the given row is from the primary table, indicated by its first field, which
* must be an IntegerField.
/** Return true if the given row is from the primary table, indicated by its first field, which
* must be an IntegerField.
* Rows MUST have been tagged in Scala.
*/
bool is_primary(const tuix::Row *row) {
return static_cast<const tuix::IntegerField *>(
row->field_values()->Get(0)->value())->value() == 0;
}

/** Return true if the two rows are from the same join group. */
bool is_same_group(const tuix::Row *row1, const tuix::Row *row2) {
auto &row1_evaluators = is_primary(row1) ? left_key_evaluators : right_key_evaluators;
auto &row2_evaluators = is_primary(row2) ? left_key_evaluators : right_key_evaluators;
/** Returns the row evaluator corresponding to the primary row
* Rows MUST have been tagged in Scala.
*/
const tuix::Row *get_primary_row(
const tuix::Row *row1, const tuix::Row *row2) {
return is_primary(row1) ? row1 : row2;
}

/** Return true if the two rows satisfy the join condition. */
bool eval_condition(const tuix::Row *row1, const tuix::Row *row2) {
builder.Clear();
bool row1_equals_row2;

/** Check equality for equi joins. If it is a non-equi join,
* the key evaluators will be empty, so the code never enters the for loop.
*/
auto &row1_evaluators = is_primary(row1) ? left_key_evaluators : right_key_evaluators;
auto &row2_evaluators = is_primary(row2) ? left_key_evaluators : right_key_evaluators;
for (uint32_t i = 0; i < row1_evaluators.size(); i++) {
const tuix::Field *row1_eval_tmp = row1_evaluators[i]->eval(row1);
auto row1_eval_offset = flatbuffers_copy(row1_eval_tmp, builder);
auto row1_field = flatbuffers::GetTemporaryPointer<tuix::Field>(builder, row1_eval_offset);

const tuix::Field *row2_eval_tmp = row2_evaluators[i]->eval(row2);
auto row2_eval_offset = flatbuffers_copy(row2_eval_tmp, builder);
auto row2_field = flatbuffers::GetTemporaryPointer<tuix::Field>(builder, row2_eval_offset);

bool row1_equals_row2 =
flatbuffers::Offset<tuix::Field> comparison = eval_binary_comparison<tuix::EqualTo, std::equal_to>(
builder,
row1_field,
row2_field);
row1_equals_row2 =
static_cast<const tuix::BooleanField *>(
flatbuffers::GetTemporaryPointer<tuix::Field>(
builder,
eval_binary_comparison<tuix::EqualTo, std::equal_to>(
builder,
flatbuffers::GetTemporaryPointer<tuix::Field>(builder, row1_eval_offset),
flatbuffers::GetTemporaryPointer<tuix::Field>(builder, row2_eval_offset)))
->value())->value();
comparison)->value())->value();

if (!row1_equals_row2) {
return false;
}
}

/* Check condition for non-equi joins */
if (!is_equi_join) {
std::vector<flatbuffers::Offset<tuix::Field>> concat_fields;
for (auto field : *row1->field_values()) {
concat_fields.push_back(flatbuffers_copy<tuix::Field>(field, builder));
}
for (auto field : *row2->field_values()) {
concat_fields.push_back(flatbuffers_copy<tuix::Field>(field, builder));
}
flatbuffers::Offset<tuix::Row> concat = tuix::CreateRowDirect(builder, &concat_fields);
const tuix::Row *concat_ptr = flatbuffers::GetTemporaryPointer<tuix::Row>(builder, concat);

const tuix::Field *condition_result = condition_eval->eval(concat_ptr);

return static_cast<const tuix::BooleanField *>(condition_result->value())->value();
}
return true;
}

Expand All @@ -1853,6 +1897,8 @@ class FlatbuffersJoinExprEvaluator {
tuix::JoinType join_type;
std::vector<std::unique_ptr<FlatbuffersExpressionEvaluator>> left_key_evaluators;
std::vector<std::unique_ptr<FlatbuffersExpressionEvaluator>> right_key_evaluators;
bool is_equi_join;
std::unique_ptr<FlatbuffersExpressionEvaluator> condition_eval;
};

class AggregateExpressionEvaluator {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
#include "Join.h"
#include "NonObliviousSortMergeJoin.h"

#include "ExpressionEvaluation.h"
#include "FlatbuffersReaders.h"
#include "FlatbuffersWriters.h"
#include "common.h"

/** C++ implementation of a non-oblivious sort merge join.
* Rows MUST be tagged primary or secondary for this to work.
*/
void non_oblivious_sort_merge_join(
uint8_t *join_expr, size_t join_expr_length,
uint8_t *input_rows, size_t input_rows_length,
Expand All @@ -25,7 +28,7 @@ void non_oblivious_sort_merge_join(

if (join_expr_eval.is_primary(current)) {
if (last_primary_of_group.get()
&& join_expr_eval.is_same_group(last_primary_of_group.get(), current)) {
&& join_expr_eval.eval_condition(last_primary_of_group.get(), current)) {
// Add this primary row to the current group
primary_group.append(current);
last_primary_of_group.set(current);
Expand All @@ -50,13 +53,13 @@ void non_oblivious_sort_merge_join(
} else {
// Output the joined rows resulting from this foreign row
if (last_primary_of_group.get()
&& join_expr_eval.is_same_group(last_primary_of_group.get(), current)) {
&& join_expr_eval.eval_condition(last_primary_of_group.get(), current)) {
auto primary_group_buffer = primary_group.output_buffer();
RowReader primary_group_reader(primary_group_buffer.view());
while (primary_group_reader.has_next()) {
const tuix::Row *primary = primary_group_reader.next();

if (!join_expr_eval.is_same_group(primary, current)) {
if (!join_expr_eval.eval_condition(primary, current)) {
throw std::runtime_error(
std::string("Invariant violation: rows of primary_group "
"are not of the same group: ")
Expand Down
Original file line number Diff line number Diff line change
@@ -1,12 +1,7 @@
#include <cstddef>
#include <cstdint>

#ifndef JOIN_H
#define JOIN_H

void non_oblivious_sort_merge_join(
uint8_t *join_expr, size_t join_expr_length,
uint8_t *input_rows, size_t input_rows_length,
uint8_t **output_rows, size_t *output_rows_length);

#endif
9 changes: 5 additions & 4 deletions src/flatbuffers/operators.fbs
Original file line number Diff line number Diff line change
Expand Up @@ -54,10 +54,11 @@ enum JoinType : ubyte {
}
table JoinExpr {
join_type:JoinType;
// Currently only cross joins and equijoins are supported, so we store
// parallel arrays of key expressions and the join outputs pairs of rows
// where each expression from the left is equal to the matching expression
// from the right.
// In the case of equi joins, we store parallel arrays of key expressions and have the join output
// pairs of rows where each expression from the left is equal to the matching expression from the right.
left_keys:[Expr];
right_keys:[Expr];
// In the case of non-equi joins, we pass in a condition as an expression and evaluate that on each pair of rows.
// TODO: have equi joins use this condition rather than an additional filter operation.
condition:Expr;
}
Loading