[compiler] Redesign Repartition IR nodes to be naive coalesce only #12671

Open
wants to merge 2 commits into main
3 changes: 1 addition & 2 deletions hail/python/hail/ir/__init__.py
@@ -27,7 +27,7 @@
     TableKeyBy, TableMapRows, TableRead, MatrixEntriesTable, \
     TableFilter, TableKeyByAndAggregate, \
     TableAggregateByKey, MatrixColsTable, TableParallelize, TableHead, \
-    TableTail, TableOrderBy, TableDistinct, RepartitionStrategy, \
+    TableTail, TableOrderBy, TableDistinct, \
     TableRepartition, CastMatrixToTable, TableRename, TableMultiWayZipJoin, \
     TableFilterIntervals, TableToTableApply, MatrixToTableApply, \
     BlockMatrixToTableApply, BlockMatrixToTable, JavaTable, TableMapPartitions
@@ -296,7 +296,6 @@
     'TableTail',
     'TableOrderBy',
     'TableDistinct',
-    'RepartitionStrategy',
     'TableRepartition',
     'CastMatrixToTable',
     'TableRename',
9 changes: 4 additions & 5 deletions hail/python/hail/ir/matrix_ir.py
@@ -801,20 +801,19 @@ def _compute_type(self, deep_typecheck):
 
 
 class MatrixRepartition(MatrixIR):
-    def __init__(self, child, n, strategy):
+    def __init__(self, child, n):
         super().__init__(child)
         self.child = child
         self.n = n
-        self.strategy = strategy
 
     def _handle_randomness(self, row_uid_field_name, col_uid_field_name):
-        return MatrixRepartition(self.child.handle_randomness(row_uid_field_name, col_uid_field_name), self.n, self.strategy)
+        return MatrixRepartition(self.child.handle_randomness(row_uid_field_name, col_uid_field_name), self.n)
 
     def head_str(self):
-        return f'{self.n} {self.strategy}'
+        return f'{self.n}'
 
     def _eq(self, other):
-        return self.n == other.n and self.strategy == other.strategy
+        return self.n == other.n
 
     def _compute_type(self, deep_typecheck):
         self.child.compute_type(deep_typecheck)
15 changes: 4 additions & 11 deletions hail/python/hail/ir/table_ir.py
@@ -735,27 +735,20 @@ def _compute_type(self, deep_typecheck):
         return self.child.typ
 
 
-class RepartitionStrategy:
-    SHUFFLE = 0
-    COALESCE = 1
-    NAIVE_COALESCE = 2
-
-
 class TableRepartition(TableIR):
-    def __init__(self, child, n, strategy):
+    def __init__(self, child, n):
         super().__init__(child)
         self.child = child
         self.n = n
-        self.strategy = strategy
 
     def _handle_randomness(self, uid_field_name):
-        return TableRepartition(self.child.handle_randomness(uid_field_name), self.n, self.strategy)
+        return TableRepartition(self.child.handle_randomness(uid_field_name), self.n)
 
     def head_str(self):
-        return f'{self.n} {self.strategy}'
+        return f'{self.n}'
 
     def _eq(self, other):
-        return self.n == other.n and self.strategy == other.strategy
+        return self.n == other.n
 
     def _compute_type(self, deep_typecheck):
        self.child.compute_type(deep_typecheck)
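With both IR classes trimmed, the Python-level constructors now take only a child and a partition count. A minimal sketch of building the new nodes directly, using the internal `_tir`/`_mir` handles the same way the updated tests below do:

import hail as hl
from hail import ir

ht = hl.utils.range_table(100, n_partitions=10)
tir = ir.TableRepartition(ht._tir, 4)    # was ir.TableRepartition(ht._tir, 4, strategy)
assert tir.head_str() == '4'             # head_str now renders only the count

mt = hl.utils.range_matrix_table(100, 10, n_partitions=10)
mir = ir.MatrixRepartition(mt._mir, 4)   # was three-argument, with a RepartitionStrategy constant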
54 changes: 25 additions & 29 deletions hail/python/hail/matrixtable.py
@@ -3358,44 +3358,37 @@ def repartition(self, n_partitions: int, shuffle: bool = True) -> 'MatrixTable':
         <http://spark.apache.org/docs/latest/programming-guide.html#resilient-distributed-datasets-rdds>`__
         for details.
 
-        When ``shuffle=True``, Hail does a full shuffle of the data
-        and creates equal sized partitions. When ``shuffle=False``,
-        Hail combines existing partitions to avoid a full
-        shuffle. These algorithms correspond to the `repartition` and
-        `coalesce` commands in Spark, respectively. In particular,
-        when ``shuffle=False``, ``n_partitions`` cannot exceed current
-        number of partitions.
+        See Also
+        --------
+        :meth:`.naive_coalesce`
 
         Parameters
         ----------
         n_partitions : int
             Desired number of partitions.
         shuffle : bool
             If ``True``, use full shuffle to repartition.
 
         Returns
         -------
         :class:`.MatrixTable`
             Repartitioned dataset.
         """
-        if hl.current_backend().requires_lowering:
-            tmp = hl.utils.new_temp_file()
-
-            if len(self.row_key) == 0:
-                uid = Env.get_uid()
-                tmp2 = hl.utils.new_temp_file()
-                self.checkpoint(tmp2)
-                ht = hl.read_matrix_table(tmp2).add_row_index(uid).key_rows_by(uid)
-                ht.checkpoint(tmp)
-                return hl.read_matrix_table(tmp, _n_partitions=n_partitions).drop(uid)
-            else:
-                # checkpoint rather than write to use fast codec
-                self.checkpoint(tmp)
-                return hl.read_matrix_table(tmp, _n_partitions=n_partitions)
-
-        return MatrixTable(ir.MatrixRepartition(
-            self._mir, n_partitions,
-            ir.RepartitionStrategy.SHUFFLE if shuffle else ir.RepartitionStrategy.COALESCE))
+        if shuffle:
+            warning("'repartition': the 'shuffle' flag is deprecated, and repartition always writes"
+                    " data to disk in the temp dir and reads it back with requested partitioning.")
+
+        tmp = hl.utils.new_temp_file()
+
+        if len(self.row_key) == 0:
+            uid = Env.get_uid()
+            tmp2 = hl.utils.new_temp_file()
+            self.checkpoint(tmp2)
+            ht = hl.read_matrix_table(tmp2).add_row_index(uid).key_rows_by(uid)
+            ht.checkpoint(tmp)
+            return hl.read_matrix_table(tmp, _n_partitions=n_partitions).drop(uid)
+        else:
+            # checkpoint rather than write to use fast codec
+            self.checkpoint(tmp)
+            return hl.read_matrix_table(tmp, _n_partitions=n_partitions)
 
     @typecheck_method(max_partitions=int)
     def naive_coalesce(self, max_partitions: int) -> 'MatrixTable':
@@ -3415,6 +3408,10 @@ def naive_coalesce(self, max_partitions: int) -> 'MatrixTable':
         unbalanced dataset can be inefficient to operate on because the work is
         not evenly distributed across partitions.
 
+        See Also
+        --------
+        :meth:`.repartition`
+
         Parameters
         ----------
         max_partitions : int
@@ -3426,8 +3423,7 @@
         :class:`.MatrixTable`
             Matrix table with at most `max_partitions` partitions.
         """
-        return MatrixTable(ir.MatrixRepartition(
-            self._mir, max_partitions, ir.RepartitionStrategy.NAIVE_COALESCE))
+        return MatrixTable(ir.MatrixRepartition(self._mir, max_partitions))
 
     def cache(self) -> 'MatrixTable':
         """Persist the dataset in memory.
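A usage sketch of the new behavior (partition counts illustrative): `repartition` no longer emits a `MatrixRepartition` node at all; it always checkpoints to a temp file and re-reads with the requested partitioning, and any truthy `shuffle`, including the default, only triggers the deprecation warning:

import hail as hl

mt = hl.utils.range_matrix_table(1000, 10, n_partitions=8)

# Checkpoints to the temp dir and re-reads with 4 partitions; the default
# shuffle=True prints the deprecation warning, shuffle=False suppresses it.
mt4 = mt.repartition(4, shuffle=False)

# naive_coalesce is the cheap path: it still emits the (now strategy-free)
# MatrixRepartition IR node and can only lower the partition count.
mt2 = mt.naive_coalesce(2)
assert mt2.n_partitions() <= 2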
55 changes: 25 additions & 30 deletions hail/python/hail/table.py
@@ -2440,43 +2440,37 @@ def repartition(self, n, shuffle=True) -> 'Table':
         <http://spark.apache.org/docs/latest/programming-guide.html#resilient-distributed-datasets-rdds>`__
         for details.
 
-        When ``shuffle=True``, Hail does a full shuffle of the data
-        and creates equal sized partitions. When ``shuffle=False``,
-        Hail combines existing partitions to avoid a full shuffle.
-        These algorithms correspond to the `repartition` and
-        `coalesce` commands in Spark, respectively. In particular,
-        when ``shuffle=False``, ``n_partitions`` cannot exceed current
-        number of partitions.
+        See Also
+        --------
+        :meth:`.naive_coalesce`
 
         Parameters
         ----------
         n : int
             Desired number of partitions.
         shuffle : bool
             If ``True``, use full shuffle to repartition.
 
         Returns
         -------
         :class:`.Table`
             Repartitioned table.
         """
-        if hl.current_backend().requires_lowering:
-            tmp = hl.utils.new_temp_file()
-
-            if len(self.key) == 0:
-                uid = Env.get_uid()
-                tmp2 = hl.utils.new_temp_file()
-                self.checkpoint(tmp2)
-                ht = hl.read_table(tmp2).add_index(uid).key_by(uid)
-                ht.checkpoint(tmp)
-                return hl.read_table(tmp, _n_partitions=n).key_by().drop(uid)
-            else:
-                # checkpoint rather than write to use fast codec
-                self.checkpoint(tmp)
-                return hl.read_table(tmp, _n_partitions=n)
-
-        return Table(ir.TableRepartition(
-            self._tir, n, ir.RepartitionStrategy.SHUFFLE if shuffle else ir.RepartitionStrategy.COALESCE))
+        if shuffle:
+            warning("'repartition': the 'shuffle' flag is deprecated, and repartition always writes"
+                    " data to disk in the temp dir and reads it back with requested partitioning.")
+
+        tmp = hl.utils.new_temp_file()
+
+        if len(self.key) == 0:
+            uid = Env.get_uid()
+            tmp2 = hl.utils.new_temp_file()
+            self.checkpoint(tmp2)
+            ht = hl.read_table(tmp2).add_index(uid).key_by(uid)
+            ht.checkpoint(tmp)
+            return hl.read_table(tmp, _n_partitions=n).key_by().drop(uid)
+        else:
+            # checkpoint rather than write to use fast codec
+            self.checkpoint(tmp)
+            return hl.read_table(tmp, _n_partitions=n)
 
     @typecheck_method(max_partitions=int)
     def naive_coalesce(self, max_partitions: int) -> 'Table':
@@ -2496,6 +2490,10 @@ def naive_coalesce(self, max_partitions: int) -> 'Table':
         unbalanced dataset can be inefficient to operate on because the work is
         not evenly distributed across partitions.
 
+        See Also
+        --------
+        :meth:`.repartition`
+
         Parameters
         ----------
         max_partitions : int
@@ -2507,11 +2505,8 @@
         :class:`.Table`
             Table with at most `max_partitions` partitions.
         """
-        if hl.current_backend().requires_lowering:
-            return self.repartition(max_partitions)
-
         return Table(ir.TableRepartition(
-            self._tir, max_partitions, ir.RepartitionStrategy.NAIVE_COALESCE))
+            self._tir, max_partitions))
 
     @typecheck_method(other=table_type)
     def semi_join(self, other: 'Table') -> 'Table':
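The unkeyed branch is the subtle part: with no key there is no ordering for the re-read to respect, so `repartition` keys by a temporary row index, checkpoints, re-reads, then unkeys and drops the index. A sketch of that path:

import hail as hl

ht = hl.utils.range_table(100, n_partitions=8).key_by()  # unkeyed table

# Hits the uid branch above: temporary index key, checkpoint, re-read with
# the requested partitioning, then key_by().drop(uid) restores the original shape.
ht4 = ht.repartition(4, shuffle=False)
assert ht4.n_partitions() <= 4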
7 changes: 0 additions & 7 deletions hail/python/test/hail/matrixtable/test_matrix_table.py
@@ -2011,13 +2011,6 @@ def assert_contains_node(t, node):
     assert_contains_node(mt, ir.MatrixExplodeRows)
     assert_unique_uids(mt)
 
-    # test MatrixRepartition
-    if not hl.current_backend().requires_lowering:
-        rmt = hl.utils.range_matrix_table(20, 10, 3)
-        mt = rmt.repartition(5)
-        assert_contains_node(mt, ir.MatrixRepartition)
-        assert_unique_uids(mt)
-
Comment on lines -2014 to -2020 (Collaborator):
Can we rewrite this and the TableRepartition one to use naive_coalesce, so the handle_randomness implementations still have test coverage?

     # test MatrixUnionRows
     r, c = 5, 5
     mt = hl.utils.range_matrix_table(2*r, c)
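A hedged sketch of the rewrite the reviewer asks for, keeping `handle_randomness` coverage by driving `MatrixRepartition` through `naive_coalesce` (`assert_contains_node` and `assert_unique_uids` are the helpers already defined in this test):

# test MatrixRepartition (via naive_coalesce, which still emits the node)
rmt = hl.utils.range_matrix_table(20, 10, 3)
mt = rmt.naive_coalesce(2)
assert_contains_node(mt, ir.MatrixRepartition)
assert_unique_uids(mt)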
8 changes: 0 additions & 8 deletions hail/python/test/hail/table/test_table.py
@@ -2214,14 +2214,6 @@ def assert_contains_node(t, node):
     assert_contains_node(t, ir.TableDistinct)
     assert_unique_uids(t)
 
-    # test TableRepartition
-    if not hl.current_backend().requires_lowering:
-        rt = hl.utils.range_table(20, 3)
-        t = rt.repartition(5)
-        print(t._tir)
-        assert_contains_node(t, ir.TableRepartition)
-        assert_unique_uids(t)
-
     # test CastMatrixToTable
     mt = hl.utils.range_matrix_table(10, 10, 3)
     t = mt._localize_entries("entries", "cols")
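The analogous `TableRepartition` coverage the same review comment asks for, sketched for this test:

# test TableRepartition (via naive_coalesce)
rt = hl.utils.range_table(20, 3)
t = rt.naive_coalesce(2)
assert_contains_node(t, ir.TableRepartition)
assert_unique_uids(t)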
4 changes: 2 additions & 2 deletions hail/python/test/hail/test_ir.py
@@ -198,7 +198,7 @@ def table_irs(self):
             ir.MakeStruct([
                 ('foo', ir.NA(hl.tarray(hl.tint32)))])),
             ir.TableRange(100, 10),
-            ir.TableRepartition(table_read, 10, ir.RepartitionStrategy.COALESCE),
+            ir.TableRepartition(table_read, 10),
             ir.TableUnion(
                 [ir.TableRange(100, 10), ir.TableRange(50, 10)]),
             ir.TableExplode(table_read, ['mset']),
@@ -240,7 +240,7 @@ def matrix_irs(self):
 
         matrix_range = ir.MatrixRead(ir.MatrixRangeReader(1, 1, 10))
         matrix_irs = [
-            ir.MatrixRepartition(matrix_range, 100, ir.RepartitionStrategy.SHUFFLE),
+            ir.MatrixRepartition(matrix_range, 100),
             ir.MatrixUnionRows(matrix_range, matrix_range),
             ir.MatrixDistinctByRow(matrix_range),
             ir.MatrixRowsHead(matrix_read, 5),
2 changes: 1 addition & 1 deletion hail/src/main/scala/is/hail/expr/ir/DistinctlyKeyed.scala
@@ -26,7 +26,7 @@ object DistinctlyKeyed {
       case TableFilter(child, _) => basicChildrenCheck(IndexedSeq(child))
       case TableHead(child, _) => basicChildrenCheck(IndexedSeq(child))
       case TableTail(child, _) => basicChildrenCheck(IndexedSeq(child))
-      case TableRepartition(child, _, _) => basicChildrenCheck(IndexedSeq(child))
+      case TableRepartition(child, _) => basicChildrenCheck(IndexedSeq(child))
      case TableJoin(left, right, _, _) => basicChildrenCheck(IndexedSeq(left, right))
      case TableIntervalJoin(left, right, _, _) => basicChildrenCheck(IndexedSeq(left, right))
      case TableMultiWayZipJoin(children, _, _) => basicChildrenCheck(children)
2 changes: 1 addition & 1 deletion hail/src/main/scala/is/hail/expr/ir/LowerMatrixIR.scala
@@ -552,7 +552,7 @@ object LowerMatrixIR {
       )))))
     )
 
-    case MatrixRepartition(child, n, shuffle) => TableRepartition(lower(ctx, child, ab), n, shuffle)
+    case MatrixRepartition(child, n) => TableRepartition(lower(ctx, child, ab), n)
 
     case MatrixFilterIntervals(child, intervals, keep) => TableFilterIntervals(lower(ctx, child, ab), intervals, keep)
 
4 changes: 2 additions & 2 deletions hail/src/main/scala/is/hail/expr/ir/MatrixIR.scala
@@ -827,14 +827,14 @@ case class MatrixExplodeRows(child: MatrixIR, path: IndexedSeq[String]) extends
   val typ: MatrixType = child.typ.copy(rowType = newRow.typ)
 }
 
-case class MatrixRepartition(child: MatrixIR, n: Int, strategy: Int) extends MatrixIR {
+case class MatrixRepartition(child: MatrixIR, n: Int) extends MatrixIR {
   val typ: MatrixType = child.typ
 
   lazy val children: IndexedSeq[BaseIR] = FastIndexedSeq(child)
 
   def copy(newChildren: IndexedSeq[BaseIR]): MatrixRepartition = {
     val IndexedSeq(newChild: MatrixIR) = newChildren
-    MatrixRepartition(newChild, n, strategy)
+    MatrixRepartition(newChild, n)
   }
 
   override def columnCount: Option[Int] = child.columnCount
6 changes: 2 additions & 4 deletions hail/src/main/scala/is/hail/expr/ir/Parser.scala
@@ -1604,9 +1604,8 @@ object IRParser {
         } yield TableKeyByAndAggregate(child, expr, newKey, nPartitions, bufferSize)
       case "TableRepartition" =>
         val n = int32_literal(it)
-        val strategy = int32_literal(it)
         table_ir(env.onlyRelational)(it).map { child =>
-          TableRepartition(child, n, strategy)
+          TableRepartition(child, n)
         }
       case "TableHead" =>
         val n = int64_literal(it)
@@ -1870,9 +1869,8 @@
         matrix_ir(env.onlyRelational)(it).map(MatrixCollectColsByKey)
       case "MatrixRepartition" =>
         val n = int32_literal(it)
-        val strategy = int32_literal(it)
         matrix_ir(env.onlyRelational)(it).map { child =>
-          MatrixRepartition(child, n, strategy)
+          MatrixRepartition(child, n)
        }
       case "MatrixUnionRows" => matrix_ir_children(env.onlyRelational)(it).map(MatrixUnionRows(_))
       case "MatrixDistinctByRow" => matrix_ir(env.onlyRelational)(it).map(MatrixDistinctByRow)
4 changes: 2 additions & 2 deletions hail/src/main/scala/is/hail/expr/ir/Pretty.scala
@@ -329,7 +329,7 @@ class Pretty(width: Int, ribbonWidth: Int, elideLiterals: Boolean, maxLen: Int,
     case MatrixAnnotateColsTable(_, _, uid) => single(prettyStringLiteral(uid))
     case MatrixExplodeRows(_, path) => single(prettyIdentifiers(path))
     case MatrixExplodeCols(_, path) => single(prettyIdentifiers(path))
-    case MatrixRepartition(_, n, strategy) => single(s"$n $strategy")
+    case MatrixRepartition(_, n) => single(n.toString)
     case MatrixChooseCols(_, oldIndices) => single(prettyInts(oldIndices, elideLiterals))
     case MatrixMapCols(_, _, newKey) => single(prettyStringsOpt(newKey))
     case MatrixUnionCols(l, r, joinType) => single(joinType)
@@ -346,7 +346,7 @@
     case TableKeyBy(_, keys, isSorted) =>
       FastSeq(prettyIdentifiers(keys), Pretty.prettyBooleanLiteral(isSorted))
     case TableRange(n, nPartitions) => FastSeq(n.toString, nPartitions.toString)
-    case TableRepartition(_, n, strategy) => FastSeq(n.toString, strategy.toString)
+    case TableRepartition(_, n) => FastSeq(n.toString)
     case TableHead(_, n) => single(n.toString)
     case TableTail(_, n) => single(n.toString)
     case TableJoin(_, _, joinType, joinKey) => FastSeq(joinType, joinKey.toString)
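Together with the Parser.scala change above, the serialized text form of the node simply drops its trailing strategy integer. One way to eyeball the new rendering from Python; the printed s-expressions below are a sketch of the format, not captured output:

import hail as hl

ht = hl.utils.range_table(100, n_partitions=10).naive_coalesce(4)
print(ht._tir)
# now renders along the lines of:
#   (TableRepartition 4 (TableRange 100 10))
# where previously the strategy constant (NAIVE_COALESCE = 2) trailed the count:
#   (TableRepartition 4 2 (TableRange 100 10))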
4 changes: 2 additions & 2 deletions hail/src/main/scala/is/hail/expr/ir/PruneDeadFields.scala
@@ -334,7 +334,7 @@ object PruneDeadFields {
       case TableParallelize(rowsAndGlobal, _) =>
         memoizeValueIR(ctx, rowsAndGlobal, TStruct("rows" -> TArray(requestedType.rowType), "global" -> requestedType.globalType), memo)
       case TableRange(_, _) =>
-      case TableRepartition(child, _, _) => memoizeTableIR(ctx, child, requestedType, memo)
+      case TableRepartition(child, _) => memoizeTableIR(ctx, child, requestedType, memo)
       case TableHead(child, _) => memoizeTableIR(ctx, child, TableType(
         key = child.typ.key,
         rowType = unify(child.typ.rowType, selectKey(child.typ.rowType, child.typ.key), requestedType.rowType),
@@ -794,7 +794,7 @@
         val dep = requestedType.copy(colType = unify(child.typ.colType,
           requestedType.colType.insert(prunedPreExplosionFieldType, path.toList)._1.asInstanceOf[TStruct]))
         memoizeMatrixIR(ctx, child, dep, memo)
-      case MatrixRepartition(child, _, _) =>
+      case MatrixRepartition(child, _) =>
         memoizeMatrixIR(ctx, child, requestedType, memo)
       case MatrixUnionRows(children) =>
         memoizeMatrixIR(ctx, children.head, requestedType, memo)
2 changes: 1 addition & 1 deletion hail/src/main/scala/is/hail/expr/ir/Requiredness.scala
@@ -319,7 +319,7 @@ class Requiredness(val usesAndDefs: UsesAndDefs, ctx: ExecuteContext) {
       case TableFilter(child, _) => requiredness.unionFrom(lookup(child))
       case TableHead(child, _) => requiredness.unionFrom(lookup(child))
       case TableTail(child, _) => requiredness.unionFrom(lookup(child))
-      case TableRepartition(child, n, strategy) => requiredness.unionFrom(lookup(child))
+      case TableRepartition(child, n) => requiredness.unionFrom(lookup(child))
       case TableDistinct(child) => requiredness.unionFrom(lookup(child))
       case TableOrderBy(child, sortFields) => requiredness.unionFrom(lookup(child))
       case TableRename(child, rMap, gMap) => requiredness.unionFrom(lookup(child))