Skip to content

Commit

Permalink
matching in strategies.scala (#159)
Browse files Browse the repository at this point in the history
This PR is the first of two parts towards making TPC-H 16 work: the other will be implementing `is_distinct` for aggregate operations.

`BroadcastNestedLoopJoin` is Spark's "catch all" for non-equi joins. It works by first picking a side to broadcast, then iterating through every possible row combination and checking the non-equi condition against the pair.
  • Loading branch information
octaviansima authored Feb 24, 2021
1 parent 3c28b5f commit 432eef8
Show file tree
Hide file tree
Showing 16 changed files with 418 additions and 63 deletions.
44 changes: 44 additions & 0 deletions src/enclave/App/App.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -555,6 +555,50 @@ Java_edu_berkeley_cs_rise_opaque_execution_SGXEnclave_NonObliviousSortMergeJoin(
return ret;
}

/**
 * JNI entry point for the broadcast nested loop join.
 *
 * Pins the three Java byte arrays, forwards them to
 * ecall_broadcast_nested_loop_join, and copies whatever the ecall produced
 * into a fresh Java byte array.
 *
 * @param env        JNI environment of the calling thread.
 * @param obj        Unused receiver object.
 * @param eid        Enclave handle, passed to the ecall as an oe_enclave_t*.
 * @param join_expr  Serialized join expression.
 * @param outer_rows Rows of the outer side.
 * @param inner_rows Rows of the inner side.
 * @return The joined rows, or nullptr when a Java exception is pending or the
 *         result array could not be allocated.
 */
JNIEXPORT jbyteArray JNICALL
Java_edu_berkeley_cs_rise_opaque_execution_SGXEnclave_BroadcastNestedLoopJoin(
  JNIEnv *env, jobject obj, jlong eid, jbyteArray join_expr, jbyteArray outer_rows, jbyteArray inner_rows) {
  (void)obj;

  const uint32_t join_expr_length = static_cast<uint32_t>(env->GetArrayLength(join_expr));
  uint8_t *join_expr_ptr =
    reinterpret_cast<uint8_t *>(env->GetByteArrayElements(join_expr, nullptr));

  const uint32_t outer_rows_length = static_cast<uint32_t>(env->GetArrayLength(outer_rows));
  uint8_t *outer_rows_ptr =
    reinterpret_cast<uint8_t *>(env->GetByteArrayElements(outer_rows, nullptr));

  const uint32_t inner_rows_length = static_cast<uint32_t>(env->GetArrayLength(inner_rows));
  uint8_t *inner_rows_ptr =
    reinterpret_cast<uint8_t *>(env->GetByteArrayElements(inner_rows, nullptr));

  uint8_t *output_rows = nullptr;
  size_t output_rows_length = 0;

  // Each message names the array that actually failed to pin.
  if (join_expr_ptr == nullptr) {
    ocall_throw("BroadcastNestedLoopJoin: JNI failed to get join expression byte array.");
  } else if (outer_rows_ptr == nullptr) {
    ocall_throw("BroadcastNestedLoopJoin: JNI failed to get outer byte array.");
  } else if (inner_rows_ptr == nullptr) {
    ocall_throw("BroadcastNestedLoopJoin: JNI failed to get inner byte array.");
  } else {
    oe_check_and_time("Broadcast Nested Loop Join",
                      ecall_broadcast_nested_loop_join(
                        (oe_enclave_t*)eid,
                        join_expr_ptr, join_expr_length,
                        outer_rows_ptr, outer_rows_length,
                        inner_rows_ptr, inner_rows_length,
                        &output_rows, &output_rows_length));
  }

  // Allocating a new array while an exception is pending is not permitted by
  // JNI, so the result is only built on the success path.
  // NOTE(review): presumably ocall_throw raises the exception on this thread's
  // JNIEnv - confirm against its definition.
  jbyteArray ret = nullptr;
  if (!env->ExceptionCheck()) {
    const jsize ret_length = static_cast<jsize>(output_rows_length);
    ret = env->NewByteArray(ret_length);
    if (ret != nullptr && ret_length > 0) {
      env->SetByteArrayRegion(ret, 0, ret_length, reinterpret_cast<const jbyte *>(output_rows));
    }
  }
  free(output_rows);

  // Only release arrays that were successfully pinned.
  if (join_expr_ptr != nullptr) {
    env->ReleaseByteArrayElements(join_expr, reinterpret_cast<jbyte *>(join_expr_ptr), 0);
  }
  if (outer_rows_ptr != nullptr) {
    env->ReleaseByteArrayElements(outer_rows, reinterpret_cast<jbyte *>(outer_rows_ptr), 0);
  }
  if (inner_rows_ptr != nullptr) {
    env->ReleaseByteArrayElements(inner_rows, reinterpret_cast<jbyte *>(inner_rows_ptr), 0);
  }

  return ret;
}

JNIEXPORT jobject JNICALL
Java_edu_berkeley_cs_rise_opaque_execution_SGXEnclave_NonObliviousAggregate(
JNIEnv *env, jobject obj, jlong eid, jbyteArray agg_op, jbyteArray input_rows, jboolean isPartial) {
Expand Down
4 changes: 4 additions & 0 deletions src/enclave/App/SGXEnclave.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,10 @@ extern "C" {
Java_edu_berkeley_cs_rise_opaque_execution_SGXEnclave_NonObliviousSortMergeJoin(
JNIEnv *, jobject, jlong, jbyteArray, jbyteArray);

// JNI binding for the broadcast nested loop join.
// After the enclave id (jlong), the byte arrays are, in order: the join
// expression, the outer rows and the inner rows. Returns the joined rows.
JNIEXPORT jbyteArray JNICALL
Java_edu_berkeley_cs_rise_opaque_execution_SGXEnclave_BroadcastNestedLoopJoin(
JNIEnv *, jobject, jlong, jbyteArray, jbyteArray, jbyteArray);

// JNI binding for the non-oblivious aggregate.
// After the enclave id (jlong), the arguments are, in order: the aggregate
// operator (agg_op), the input rows, and the isPartial flag.
JNIEXPORT jobject JNICALL
Java_edu_berkeley_cs_rise_opaque_execution_SGXEnclave_NonObliviousAggregate(
JNIEnv *, jobject, jlong, jbyteArray, jbyteArray, jboolean);
Expand Down
54 changes: 54 additions & 0 deletions src/enclave/Enclave/BroadcastNestedLoopJoin.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
#include "BroadcastNestedLoopJoin.h"

#include "ExpressionEvaluation.h"
#include "FlatbuffersReaders.h"
#include "FlatbuffersWriters.h"
#include "common.h"

/** C++ implementation of a broadcast nested loop join.
* Assumes outer_rows is streamed and inner_rows is broadcast.
* DOES NOT rely on rows to be tagged primary or secondary, and that
* assumption will break the implementation.
*/
void broadcast_nested_loop_join(
uint8_t *join_expr, size_t join_expr_length,
uint8_t *outer_rows, size_t outer_rows_length,
uint8_t *inner_rows, size_t inner_rows_length,
uint8_t **output_rows, size_t *output_rows_length) {

FlatbuffersJoinExprEvaluator join_expr_eval(join_expr, join_expr_length);
const tuix::JoinType join_type = join_expr_eval.get_join_type();

RowReader outer_r(BufferRefView<tuix::EncryptedBlocks>(outer_rows, outer_rows_length));
RowWriter w;

while (outer_r.has_next()) {
const tuix::Row *outer = outer_r.next();
bool o_i_match = false;

RowReader inner_r(BufferRefView<tuix::EncryptedBlocks>(inner_rows, inner_rows_length));
const tuix::Row *inner;
while (inner_r.has_next()) {
inner = inner_r.next();
o_i_match |= join_expr_eval.eval_condition(outer, inner);
}

switch(join_type) {
case tuix::JoinType_LeftAnti:
if (!o_i_match) {
w.append(outer);
}
break;
case tuix::JoinType_LeftSemi:
if (o_i_match) {
w.append(outer);
}
break;
default:
throw std::runtime_error(
std::string("Join type not supported: ")
+ std::string(to_string(join_type)));
}
}
w.output_buffer(output_rows, output_rows_length);
}
8 changes: 8 additions & 0 deletions src/enclave/Enclave/BroadcastNestedLoopJoin.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
#ifndef BROADCAST_NESTED_LOOP_JOIN_H
#define BROADCAST_NESTED_LOOP_JOIN_H

#include <cstddef>
#include <cstdint>

/**
 * Broadcast nested loop join over two row buffers.
 *
 * @param join_expr          Serialized join expression (join type and condition).
 * @param join_expr_length   Size of join_expr in bytes.
 * @param outer_rows         Buffer holding the outer (streamed) rows.
 * @param outer_rows_length  Size of outer_rows in bytes.
 * @param inner_rows         Buffer holding the inner (broadcast) rows.
 * @param inner_rows_length  Size of inner_rows in bytes.
 * @param output_rows        Receives the buffer with the joined rows.
 * @param output_rows_length Receives the size of that buffer.
 */
void broadcast_nested_loop_join(
  uint8_t *join_expr, size_t join_expr_length,
  uint8_t *outer_rows, size_t outer_rows_length,
  uint8_t *inner_rows, size_t inner_rows_length,
  uint8_t **output_rows, size_t *output_rows_length);

#endif // BROADCAST_NESTED_LOOP_JOIN_H
3 changes: 2 additions & 1 deletion src/enclave/Enclave/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@ set(SOURCES
Flatbuffers.cpp
FlatbuffersReaders.cpp
FlatbuffersWriters.cpp
Join.cpp
NonObliviousSortMergeJoin.cpp
BroadcastNestedLoopJoin.cpp
Limit.cpp
Project.cpp
Sort.cpp
Expand Down
22 changes: 21 additions & 1 deletion src/enclave/Enclave/Enclave.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
#include "Aggregate.h"
#include "Crypto.h"
#include "Filter.h"
#include "Join.h"
#include "NonObliviousSortMergeJoin.h"
#include "BroadcastNestedLoopJoin.h"
#include "Limit.h"
#include "Project.h"
#include "Sort.h"
Expand Down Expand Up @@ -161,6 +162,25 @@ void ecall_non_oblivious_sort_merge_join(uint8_t *join_expr, size_t join_expr_le
}
}

/**
 * ECALL wrapper around broadcast_nested_loop_join.
 *
 * Verifies that both row buffers lie outside the enclave, then runs the join.
 * A std::runtime_error raised by the join is reported through ocall_throw
 * instead of propagating out of the ecall.
 */
void ecall_broadcast_nested_loop_join(uint8_t *join_expr, size_t join_expr_length,
                                      uint8_t *outer_rows, size_t outer_rows_length,
                                      uint8_t *inner_rows, size_t inner_rows_length,
                                      uint8_t **output_rows, size_t *output_rows_length) {
  // Guard against operating on arbitrary enclave memory. This is an explicit
  // runtime check rather than assert(): assert() is compiled out when NDEBUG
  // is defined, which would silently remove the guard.
  if (oe_is_outside_enclave(outer_rows, outer_rows_length) != 1
      || oe_is_outside_enclave(inner_rows, inner_rows_length) != 1) {
    ocall_throw("ecall_broadcast_nested_loop_join: row buffers must reside outside the enclave");
    return;
  }
  // Keep the fence after the bounds check so the buffers are not touched
  // before the check has completed.
  __builtin_ia32_lfence();

  try {
    broadcast_nested_loop_join(join_expr, join_expr_length,
                               outer_rows, outer_rows_length,
                               inner_rows, inner_rows_length,
                               output_rows, output_rows_length);
  } catch (const std::runtime_error &e) {
    ocall_throw(e.what());
  }
}

void ecall_non_oblivious_aggregate(
uint8_t *agg_op, size_t agg_op_length,
uint8_t *input_rows, size_t input_rows_length,
Expand Down
6 changes: 6 additions & 0 deletions src/enclave/Enclave/Enclave.edl
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,12 @@ enclave {
[user_check] uint8_t *input_rows, size_t input_rows_length,
[out] uint8_t **output_rows, [out] size_t *output_rows_length);

// ECALL for the broadcast nested loop join.
// join_expr is an [in] buffer of join_expr_length bytes. outer_rows and
// inner_rows are [user_check] pointers: the ecall implementation is
// responsible for validating them (it checks them with oe_is_outside_enclave).
// The joined rows are returned through output_rows / output_rows_length.
public void ecall_broadcast_nested_loop_join(
[in, count=join_expr_length] uint8_t *join_expr, size_t join_expr_length,
[user_check] uint8_t *outer_rows, size_t outer_rows_length,
[user_check] uint8_t *inner_rows, size_t inner_rows_length,
[out] uint8_t **output_rows, [out] size_t *output_rows_length);

public void ecall_non_oblivious_aggregate(
[in, count=agg_op_length] uint8_t *agg_op, size_t agg_op_length,
[user_check] uint8_t *input_rows, size_t input_rows_length,
Expand Down
100 changes: 73 additions & 27 deletions src/enclave/Enclave/ExpressionEvaluation.h
Original file line number Diff line number Diff line change
Expand Up @@ -1787,60 +1787,104 @@ class FlatbuffersJoinExprEvaluator {
}

const tuix::JoinExpr* join_expr = flatbuffers::GetRoot<tuix::JoinExpr>(buf);
join_type = join_expr->join_type();

if (join_expr->left_keys()->size() != join_expr->right_keys()->size()) {
throw std::runtime_error("Mismatched join key lengths");
}
for (auto key_it = join_expr->left_keys()->begin();
key_it != join_expr->left_keys()->end(); ++key_it) {
left_key_evaluators.emplace_back(
std::unique_ptr<FlatbuffersExpressionEvaluator>(
new FlatbuffersExpressionEvaluator(*key_it)));
join_type = join_expr->join_type();
if (join_expr->condition() != NULL) {
condition_eval = std::unique_ptr<FlatbuffersExpressionEvaluator>(
new FlatbuffersExpressionEvaluator(join_expr->condition()));
}
for (auto key_it = join_expr->right_keys()->begin();
key_it != join_expr->right_keys()->end(); ++key_it) {
right_key_evaluators.emplace_back(
std::unique_ptr<FlatbuffersExpressionEvaluator>(
new FlatbuffersExpressionEvaluator(*key_it)));
is_equi_join = false;

if (join_expr->left_keys() != NULL && join_expr->right_keys() != NULL) {
is_equi_join = true;
if (join_expr->condition() != NULL) {
throw std::runtime_error("Equi join cannot have condition");
}
if (join_expr->left_keys()->size() != join_expr->right_keys()->size()) {
throw std::runtime_error("Mismatched join key lengths");
}
for (auto key_it = join_expr->left_keys()->begin();
key_it != join_expr->left_keys()->end(); ++key_it) {
left_key_evaluators.emplace_back(
std::unique_ptr<FlatbuffersExpressionEvaluator>(
new FlatbuffersExpressionEvaluator(*key_it)));
}
for (auto key_it = join_expr->right_keys()->begin();
key_it != join_expr->right_keys()->end(); ++key_it) {
right_key_evaluators.emplace_back(
std::unique_ptr<FlatbuffersExpressionEvaluator>(
new FlatbuffersExpressionEvaluator(*key_it)));
}
}
}

/** Reports whether the given row comes from the primary table.
 * The decision is read from the row's first field, which must be an
 * IntegerField; a value of 0 marks the primary table.
 * Rows MUST have been tagged in Scala.
 */
bool is_primary(const tuix::Row *row) {
  const auto *tag_field = row->field_values()->Get(0);
  const auto *tag = static_cast<const tuix::IntegerField *>(tag_field->value());
  return tag->value() == 0;
}

/** Return true if the two rows are from the same join group. */
bool is_same_group(const tuix::Row *row1, const tuix::Row *row2) {
auto &row1_evaluators = is_primary(row1) ? left_key_evaluators : right_key_evaluators;
auto &row2_evaluators = is_primary(row2) ? left_key_evaluators : right_key_evaluators;
/** Picks the primary row out of a pair of rows.
 * Returns row1 when it is tagged as primary; otherwise returns row2
 * (row2's own tag is not inspected).
 * Rows MUST have been tagged in Scala.
 */
const tuix::Row *get_primary_row(
  const tuix::Row *row1, const tuix::Row *row2) {
  if (is_primary(row1)) {
    return row1;
  }
  return row2;
}

/** Return true if the two rows satisfy the join condition.
 *
 * Equi joins: every left/right key expression pair must evaluate to equal
 * values. The rows must be tagged, because is_primary() selects which set of
 * key evaluators applies to each row.
 *
 * Non-equi joins: the fields of row1 followed by the fields of row2 are
 * concatenated into a single row and the join condition is evaluated on it.
 * Row tags are not consulted in this case.
 *
 * Throws std::runtime_error if a non-equi join carries no condition.
 */
bool eval_condition(const tuix::Row *row1, const tuix::Row *row2) {
  builder.Clear();

  if (is_equi_join) {
    // Only consult the primary/secondary tag for equi joins: rows of a
    // non-equi join are not tagged, so their first field must not be read
    // as a tag.
    auto &row1_evaluators = is_primary(row1) ? left_key_evaluators : right_key_evaluators;
    auto &row2_evaluators = is_primary(row2) ? left_key_evaluators : right_key_evaluators;
    for (uint32_t i = 0; i < row1_evaluators.size(); i++) {
      const tuix::Field *row1_eval_tmp = row1_evaluators[i]->eval(row1);
      auto row1_eval_offset = flatbuffers_copy(row1_eval_tmp, builder);

      const tuix::Field *row2_eval_tmp = row2_evaluators[i]->eval(row2);
      auto row2_eval_offset = flatbuffers_copy(row2_eval_tmp, builder);

      // Resolve the temporary pointers only after BOTH copies: a write to the
      // builder may reallocate its buffer and invalidate earlier pointers.
      auto row1_field = flatbuffers::GetTemporaryPointer<tuix::Field>(builder, row1_eval_offset);
      auto row2_field = flatbuffers::GetTemporaryPointer<tuix::Field>(builder, row2_eval_offset);

      flatbuffers::Offset<tuix::Field> comparison =
        eval_binary_comparison<tuix::EqualTo, std::equal_to>(
          builder,
          row1_field,
          row2_field);
      const bool row1_equals_row2 =
        static_cast<const tuix::BooleanField *>(
          flatbuffers::GetTemporaryPointer<tuix::Field>(
            builder,
            comparison)->value())->value();

      if (!row1_equals_row2) {
        return false;
      }
    }
    return true;
  }

  /* Check condition for non-equi joins */
  if (!condition_eval) {
    throw std::runtime_error("Non-equi join requires a condition");
  }

  std::vector<flatbuffers::Offset<tuix::Field>> concat_fields;
  for (auto field : *row1->field_values()) {
    concat_fields.push_back(flatbuffers_copy<tuix::Field>(field, builder));
  }
  for (auto field : *row2->field_values()) {
    concat_fields.push_back(flatbuffers_copy<tuix::Field>(field, builder));
  }
  flatbuffers::Offset<tuix::Row> concat = tuix::CreateRowDirect(builder, &concat_fields);
  const tuix::Row *concat_ptr = flatbuffers::GetTemporaryPointer<tuix::Row>(builder, concat);

  const tuix::Field *condition_result = condition_eval->eval(concat_ptr);

  return static_cast<const tuix::BooleanField *>(condition_result->value())->value();
}

Expand All @@ -1853,6 +1897,8 @@ class FlatbuffersJoinExprEvaluator {
tuix::JoinType join_type;
std::vector<std::unique_ptr<FlatbuffersExpressionEvaluator>> left_key_evaluators;
std::vector<std::unique_ptr<FlatbuffersExpressionEvaluator>> right_key_evaluators;
bool is_equi_join;
std::unique_ptr<FlatbuffersExpressionEvaluator> condition_eval;
};

class AggregateExpressionEvaluator {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
#include "Join.h"
#include "NonObliviousSortMergeJoin.h"

#include "ExpressionEvaluation.h"
#include "FlatbuffersReaders.h"
#include "FlatbuffersWriters.h"
#include "common.h"

/** C++ implementation of a non-oblivious sort merge join.
* Rows MUST be tagged primary or secondary for this to work.
*/
void non_oblivious_sort_merge_join(
uint8_t *join_expr, size_t join_expr_length,
uint8_t *input_rows, size_t input_rows_length,
Expand All @@ -25,7 +28,7 @@ void non_oblivious_sort_merge_join(

if (join_expr_eval.is_primary(current)) {
if (last_primary_of_group.get()
&& join_expr_eval.is_same_group(last_primary_of_group.get(), current)) {
&& join_expr_eval.eval_condition(last_primary_of_group.get(), current)) {
// Add this primary row to the current group
primary_group.append(current);
last_primary_of_group.set(current);
Expand All @@ -50,13 +53,13 @@ void non_oblivious_sort_merge_join(
} else {
// Output the joined rows resulting from this foreign row
if (last_primary_of_group.get()
&& join_expr_eval.is_same_group(last_primary_of_group.get(), current)) {
&& join_expr_eval.eval_condition(last_primary_of_group.get(), current)) {
auto primary_group_buffer = primary_group.output_buffer();
RowReader primary_group_reader(primary_group_buffer.view());
while (primary_group_reader.has_next()) {
const tuix::Row *primary = primary_group_reader.next();

if (!join_expr_eval.is_same_group(primary, current)) {
if (!join_expr_eval.eval_condition(primary, current)) {
throw std::runtime_error(
std::string("Invariant violation: rows of primary_group "
"are not of the same group: ")
Expand Down
Original file line number Diff line number Diff line change
@@ -1,12 +1,7 @@
#include <cstddef>
#include <cstdint>

#ifndef JOIN_H
#define JOIN_H

/** Non-oblivious sort merge join.
 * Reads the join expression from join_expr and the rows from input_rows, and
 * returns the joined rows through output_rows / output_rows_length.
 * Rows MUST be tagged primary or secondary for this to work.
 */
void non_oblivious_sort_merge_join(
uint8_t *join_expr, size_t join_expr_length,
uint8_t *input_rows, size_t input_rows_length,
uint8_t **output_rows, size_t *output_rows_length);

#endif
9 changes: 5 additions & 4 deletions src/flatbuffers/operators.fbs
Original file line number Diff line number Diff line change
Expand Up @@ -54,10 +54,11 @@ enum JoinType : ubyte {
}
table JoinExpr {
join_type:JoinType;
// Equi joins: left_keys and right_keys are parallel arrays of key expressions. The join outputs
// pairs of rows where each expression from the left is equal to the matching expression from the
// right.
left_keys:[Expr];
right_keys:[Expr];
// Non-equi joins: a condition is passed in as an expression and evaluated on each pair of rows.
// TODO: have equi joins use this condition rather than an additional filter operation.
condition:Expr;
}
Loading

0 comments on commit 432eef8

Please sign in to comment.