Skip to content

Commit

Permalink
Nested loop join
Browse files (browse the repository at this point in the history)
  • Loading branch information
joshua-spacetime committed Nov 1, 2024
1 parent 91bd62b commit e8d228a
Showing 1 changed file with 42 additions and 45 deletions.
87 changes: 42 additions & 45 deletions crates/execution/src/iter.rs
Original file line number | Diff line number | Diff line change
Expand Up @@ -101,14 +101,14 @@ pub enum IterOp {
/// An index scan opcode takes 2 args:
/// 1. An [IndexId]
/// 2. A ptr to an [AlgebraicValue]
IndexScanEq(IndexId, u16),
IxScanEq(IndexId, u16),
/// An index range scan opcode takes 3 args:
/// 1. An [IndexId]
/// 2. A ptr to the lower bound
/// 3. A ptr to the upper bound
IndexScanRange(IndexId, Bound<u16>, Bound<u16>),
/// A cross join has 2 args, but its opcode has none
CrossJoin,
IxScanRange(IndexId, Bound<u16>, Bound<u16>),
/// Pops its 2 args from the stack
NLJoin,
/// An index join opcode takes 2 args:
/// 1. An [IndexId]
/// 2. An instruction ptr
Expand Down Expand Up @@ -162,21 +162,21 @@ impl CachedIterPlan {
// Push delta scan
stack.push(Iter::DeltaScan(tx.delta_scan_iter(table_id)));
}
IterOp::IndexScanEq(index_id, ptr) => {
IterOp::IxScanEq(index_id, ptr) => {
// Push index scan
stack.push(Iter::IndexScan(tx.index_scan_iter(index_id, &self.constant(ptr))));
}
IterOp::IndexScanRange(index_id, lower, upper) => {
IterOp::IxScanRange(index_id, lower, upper) => {
// Push range scan
let lower = lower.map(|ptr| self.constant(ptr));
let upper = upper.map(|ptr| self.constant(ptr));
stack.push(Iter::IndexScan(tx.index_scan_iter(index_id, &(lower, upper))));
}
IterOp::CrossJoin => {
// Pop args and push cross join
IterOp::NLJoin => {
// Pop args and push nested loop join
let rhs = stack.pop().unwrap();
let lhs = stack.pop().unwrap();
stack.push(Iter::CrossJoin(CrossJoinIter::new(lhs, rhs)));
stack.push(Iter::NLJoin(NestedLoopJoin::new(lhs, rhs)));
}
IterOp::IxJoin(index_id, i, n) => {
// Pop arg and push index join
Expand All @@ -187,7 +187,7 @@ impl CachedIterPlan {
let ops = &self.expr_ops[i..i + n as usize];
let program = ExprProgram::new(ops, &self.constants);
let projection = ProgramEvaluator::from(program);
stack.push(Iter::IxJoin(IxJoin::Eq(IndexJoin::new(
stack.push(Iter::IxJoin(LeftDeepJoin::Eq(IndexJoin::new(
input, index, table, blob_store, projection,
))));
}
Expand All @@ -200,7 +200,7 @@ impl CachedIterPlan {
let ops = &self.expr_ops[i..i + n as usize];
let program = ExprProgram::new(ops, &self.constants);
let projection = ProgramEvaluator::from(program);
stack.push(Iter::UniqueIxJoin(IxJoin::Eq(UniqueIndexJoin::new(
stack.push(Iter::UniqueIxJoin(LeftDeepJoin::Eq(UniqueIndexJoin::new(
input, index, table, blob_store, projection,
))));
}
Expand Down Expand Up @@ -230,12 +230,12 @@ pub enum Iter<'a> {
DeltaScan(DeltaScanIter<'a>),
/// A [RowRef] index iterator
IndexScan(IndexScanIter<'a>),
/// A cross product iterator
CrossJoin(CrossJoinIter<'a>),
/// A nested loop join iterator
NLJoin(NestedLoopJoin<'a>),
/// A non-unique (constraint) index join iterator
IxJoin(IxJoin<IndexJoin<'a>>),
IxJoin(LeftDeepJoin<IndexJoin<'a>>),
/// A unique (constraint) index join iterator
UniqueIxJoin(IxJoin<UniqueIndexJoin<'a>>),
UniqueIxJoin(LeftDeepJoin<UniqueIndexJoin<'a>>),
/// A tuple-at-a-time filter iterator
Filter(Filter<'a>),
}
Expand Down Expand Up @@ -269,7 +269,7 @@ impl<'a> Iterator for Iter<'a> {
// Filter is a passthru
iter.next()
}
Self::CrossJoin(iter) => {
Self::NLJoin(iter) => {
iter.next().map(|t| {
match t {
// A leaf join
Expand All @@ -287,8 +287,7 @@ impl<'a> Iterator for Iter<'a> {
// / \
// b c
(Tuple::Row(r), Tuple::Join(mut rows)) => {
// Returns (n+1)-tuples,
// if the rhs returns n-tuples.
// Returns an (n+1)-tuple
let mut pointers = vec![r];
pointers.append(&mut rows);
Tuple::Join(pointers)
Expand All @@ -300,8 +299,7 @@ impl<'a> Iterator for Iter<'a> {
// / \
// a b
(Tuple::Join(mut rows), Tuple::Row(r)) => {
// Returns (n+1)-tuples,
// if the lhs returns n-tuples.
// Returns an (n+1)-tuple
rows.push(r);
Tuple::Join(rows)
}
Expand All @@ -313,9 +311,7 @@ impl<'a> Iterator for Iter<'a> {
// / \ / \
// a b c d
(Tuple::Join(mut lhs), Tuple::Join(mut rhs)) => {
// Returns (n+m)-tuples,
// if the lhs returns n-tuples,
// if the rhs returns m-tuples.
// Returns an (n+m)-tuple
lhs.append(&mut rhs);
Tuple::Join(lhs)
}
Expand All @@ -326,8 +322,8 @@ impl<'a> Iterator for Iter<'a> {
}
}

/// An iterator for an index join
pub enum IxJoin<Iter> {
/// An iterator for a left deep join tree
pub enum LeftDeepJoin<Iter> {
/// A standard join
Eq(Iter),
/// A semijoin that returns the lhs
Expand All @@ -336,7 +332,7 @@ pub enum IxJoin<Iter> {
SemiRhs(Iter),
}

impl<'a, Iter> Iterator for IxJoin<Iter>
impl<'a, Iter> Iterator for LeftDeepJoin<Iter>
where
Iter: Iterator<Item = (Tuple<'a>, RowRef<'a>)>,
{
Expand Down Expand Up @@ -370,7 +366,7 @@ where
// / \
// a b
(Tuple::Join(mut rows), ptr) => {
// Returns an n+1 tuple
// Returns an (n+1)-tuple
rows.push(Row::Ptr(ptr));
Tuple::Join(rows)
}
Expand Down Expand Up @@ -496,9 +492,8 @@ impl<'a> Iterator for IndexJoin<'a> {
}
}

/// A cross join returns the cross product of its two inputs.
/// It materializes the rhs and streams the lhs.
pub struct CrossJoinIter<'a> {
/// A nested loop join returns the cross product of its inputs
pub struct NestedLoopJoin<'a> {
/// The lhs input
lhs: Box<Iter<'a>>,
/// The rhs input
Expand All @@ -511,7 +506,7 @@ pub struct CrossJoinIter<'a> {
rhs_ptr: usize,
}

impl<'a> CrossJoinIter<'a> {
impl<'a> NestedLoopJoin<'a> {
fn new(lhs: Iter<'a>, rhs: Iter<'a>) -> Self {
Self {
lhs: Box::new(lhs),
Expand All @@ -523,25 +518,27 @@ impl<'a> CrossJoinIter<'a> {
}
}

impl<'a> Iterator for CrossJoinIter<'a> {
impl<'a> Iterator for NestedLoopJoin<'a> {
type Item = (Tuple<'a>, Tuple<'a>);

fn next(&mut self) -> Option<Self::Item> {
// Materialize the rhs on the first call
if self.build.is_empty() {
self.build = self.rhs.as_mut().collect();
self.lhs_row = self.lhs.next();
self.rhs_ptr = 0;
for t in self.rhs.as_mut() {
self.build.push(t);
}
// Reset the rhs pointer
if self.rhs_ptr == self.build.len() {
self.lhs_row = self.lhs.next();
self.rhs_ptr = 0;
match self.build.get(self.rhs_ptr) {
Some(v) => {
self.rhs_ptr += 1;
self.lhs_row.as_ref().map(|u| (u.clone(), v.clone()))
}
None => {
self.rhs_ptr = 1;
self.lhs_row = self.lhs.next();
self.lhs_row
.as_ref()
.zip(self.build.first())
.map(|(u, v)| (u.clone(), v.clone()))
}
}
self.lhs_row.as_ref().map(|lhs_tuple| {
self.rhs_ptr += 1;
(lhs_tuple.clone(), self.build[self.rhs_ptr - 1].clone())
})
}
}

Expand Down

0 comments on commit e8d228a

Please sign in to comment.