Skip to content

Commit

Permalink
[query] Use valid globals reference in MWZJ and TABK
Browse files Browse the repository at this point in the history
CHANGELOG: Fix a bug, introduced in 0.2.114, in which `Table.multi_way_zip_join` and `Table.aggregate_by_key` could throw "NoSuchElementException: Ref with name __iruid_..." when one or more of the tables had a number of partitions substantially different from the desired number of output partitions.

Fixes #14245.

In both MultiWayZipJoin and TableAggregateByKey, we repartition the child but neglect to use the
new globals `Ref` from the repartitioned child. As long as `repartitionNoShuffle` does not create a
new TableStage with new globals, this is fine, but that is not, in general, true. It seems that
recently, in lowered backends, when the repartition cost is deemed "high" we generate a fresh
TableStage with a fresh globals ref.
  • Loading branch information
Dan King committed Feb 2, 2024
1 parent d261554 commit d82c34a
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 31 deletions.
17 changes: 17 additions & 0 deletions hail/python/test/hail/table/test_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -668,6 +668,23 @@ def test_multi_way_zip_join_key_downcast2(self):
ht = hl.Table.multi_way_zip_join(vcfs, 'data', 'new_globals')
assert exp_count == ht._force_count()

def test_multi_way_zip_join_highly_unbalanced_partitions__issue_14245(self):
def import_vcf(file: str, partitions: int):
return (
hl.import_vcf(file, force_bgz=True, reference_genome='GRCh38', min_partitions=partitions)
.rows()
.select()
)

hl.Table.multi_way_zip_join(
[
import_vcf(resource('gvcfs/HG00096.g.vcf.gz'), 100),
import_vcf(resource('gvcfs/HG00268.g.vcf.gz'), 1),
],
'data',
'new_globals',
).write(new_temp_file(extension='ht'))

def test_index_maintains_count(self):
t1 = hl.Table.parallelize(
[{'a': 'foo', 'b': 1}, {'a': 'bar', 'b': 2}, {'a': 'bar', 'b': 2}],
Expand Down
65 changes: 34 additions & 31 deletions hail/src/main/scala/is/hail/expr/ir/lowering/LowerTableIR.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1197,39 +1197,39 @@ object LowerTableIR {

case TableAggregateByKey(child, expr) =>
val loweredChild = lower(child)

loweredChild.repartitionNoShuffle(
val repartitioned = loweredChild.repartitionNoShuffle(
ctx,
loweredChild.partitioner.coarsen(child.typ.key.length).strictify(),
)
.mapPartition(Some(child.typ.key)) { partition =>
Let(
FastSeq("global" -> loweredChild.globals),
mapIR(StreamGroupByKey(partition, child.typ.key, missingEqual = true)) { groupRef =>
StreamAgg(
groupRef,
"row",
bindIRs(
ArrayRef(
ApplyAggOp(
FastSeq(I32(1)),
FastSeq(SelectFields(Ref("row", child.typ.rowType), child.typ.key)),
AggSignature(Take(), FastSeq(TInt32), FastSeq(child.typ.keyType)),
),
I32(0),
), // FIXME: would prefer a First() agg op
expr,
) { case Seq(key, value) =>
MakeStruct(child.typ.key.map(k =>
(k, GetField(key, k))
) ++ expr.typ.asInstanceOf[TStruct].fieldNames.map { f =>
(f, GetField(value, f))
})
},
)
},
)
}

repartitioned.mapPartition(Some(child.typ.key)) { partition =>
Let(
FastSeq("global" -> repartitioned.globals),
mapIR(StreamGroupByKey(partition, child.typ.key, missingEqual = true)) { groupRef =>
StreamAgg(
groupRef,
"row",
bindIRs(
ArrayRef(
ApplyAggOp(
FastSeq(I32(1)),
FastSeq(SelectFields(Ref("row", child.typ.rowType), child.typ.key)),
AggSignature(Take(), FastSeq(TInt32), FastSeq(child.typ.keyType)),
),
I32(0),
), // FIXME: would prefer a First() agg op
expr,
) { case Seq(key, value) =>
MakeStruct(child.typ.key.map(k =>
(k, GetField(key, k))
) ++ expr.typ.asInstanceOf[TStruct].fieldNames.map { f =>
(f, GetField(value, f))
})
},
)
},
)
}

case TableDistinct(child) =>
val loweredChild = lower(child)
Expand Down Expand Up @@ -2155,7 +2155,10 @@ object LowerTableIR {
)
val repartitioned = lowered.map(_.repartitionNoShuffle(ctx, newPartitioner))
val newGlobals = MakeStruct(FastSeq(
globalName -> MakeArray(lowered.map(_.globals), TArray(lowered.head.globalType))
globalName -> MakeArray(
repartitioned.map(_.globals),
TArray(repartitioned.head.globalType),
)
))
val globalsRef = Ref(genUID(), newGlobals.typ)

Expand Down

0 comments on commit d82c34a

Please sign in to comment.