Implement basic rewrites for Unique
brandonwillard committed Nov 10, 2021
1 parent 93e2264 commit a341c88
Showing 2 changed files with 468 additions and 2 deletions.
159 changes: 158 additions & 1 deletion aesara/tensor/basic_opt.py
@@ -71,7 +71,7 @@
)
from aesara.tensor.elemwise import DimShuffle, Elemwise
from aesara.tensor.exceptions import NotScalarConstantError, ShapeError
from aesara.tensor.extra_ops import broadcast_shape
from aesara.tensor.extra_ops import BroadcastTo, Repeat, Unique, broadcast_shape
from aesara.tensor.math import all as at_all
from aesara.tensor.math import eq
from aesara.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape, shape_padleft
@@ -3495,3 +3495,160 @@ def local_Shape_i_of_broadcastable(fgraph, node):

    if shape_arg.broadcastable[node.op.i]:
        return [as_tensor_variable(1, dtype=np.int64)]


@register_useless
@register_canonicalize
@local_optimizer([Unique])
def local_Unique_scalar(fgraph, node):
    """Convert ``unique(x)`` to ``x`` when ``x`` is a scalar."""
    if not isinstance(node.op, Unique):
        return False

    if node.op.return_index or node.op.return_inverse or node.op.return_counts:
        return False

    uniqued_var = node.inputs[0]

    if uniqued_var.ndim != 0:
        return False

    old_out = node.outputs[0]
    res = as_tensor_variable(uniqued_var, ndim=old_out.ndim, dtype=old_out.dtype)
    return [res]
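
A minimal sketch of this rewrite's effect (not part of the diff; it assumes a standard Aesara install, and variable names are illustrative):

import aesara
import aesara.tensor as at
from aesara.tensor.extra_ops import unique

x = at.scalar("x")
out = unique(x)  # the unique values of a 0-d input are just that input

# With the canonicalization rewrites enabled (the default), the compiled
# graph should contain no `Unique` node: the output reduces to `x`,
# adjusted to the original output's `ndim` and `dtype`.
fn = aesara.function([x], out)
aesara.printing.debugprint(fn)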


@register_useless
@register_canonicalize
@local_optimizer([Unique])
def local_Unique_Alloc_lift(fgraph, node):
    """Convert ``unique(alloc(x, ...), axis=None)`` to ``unique(x, axis=None)``.

    This isn't really so much a lift as a "reduction/consumption".
    """
    if not isinstance(node.op, Unique):
        return False

    if (
        node.op.return_index
        or node.op.return_inverse
        or node.op.return_counts
        or node.op.axis is not None
    ):
        return False

    alloc_var = node.inputs[0]

    if not (alloc_var.owner and isinstance(alloc_var.owner.op, Alloc)):
        return False

    alloced_var, *alloc_shape = alloc_var.owner.inputs

    new_unique, *_ = node.op.make_node(alloced_var).outputs

    old_out = node.outputs[0]
    new_x = as_tensor_variable(new_unique, ndim=old_out.ndim, dtype=old_out.dtype)
    return [new_x]
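
A hedged illustration of the pattern this rewrite targets (not part of the diff; names are arbitrary):

import aesara
import aesara.tensor as at
from aesara.tensor.extra_ops import unique

x = at.vector("x")
tiled = at.alloc(x, 10, x.shape[0])  # replicate `x` across 10 rows

# `Alloc` only replicates existing values, so the replication contributes
# no new unique elements; the graph can be reduced to `unique(x)`.
out = unique(tiled)
fn = aesara.function([x], out)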


@register_useless
@register_canonicalize
@local_optimizer([Unique])
def local_Unique_BroadcastTo_lift(fgraph, node):
    """Convert ``unique(broadcast_to(x, ...), axis=None)`` to ``unique(x, axis=None)``.

    This isn't really so much a lift as a "reduction/consumption".
    """
    if not isinstance(node.op, Unique):
        return False

    if (
        node.op.return_index
        or node.op.return_inverse
        or node.op.return_counts
        or node.op.axis is not None
    ):
        return False

    bcast_var = node.inputs[0]

    if not (bcast_var.owner and isinstance(bcast_var.owner.op, BroadcastTo)):
        return False

    bcasted_var, *bcast_shape = bcast_var.owner.inputs

    new_unique, *_ = node.op.make_node(bcasted_var).outputs

    old_out = node.outputs[0]
    new_x = as_tensor_variable(new_unique, ndim=old_out.ndim, dtype=old_out.dtype)
    return [new_x]
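
The same idea, sketched for `BroadcastTo` (illustrative only, not part of the diff):

import aesara
import aesara.tensor as at
from aesara.tensor.extra_ops import broadcast_to, unique

x = at.vector("x")
# Broadcasting, like `Alloc`, only repeats values, so the set of unique
# elements is unchanged and this reduces to `unique(x)`.
out = unique(broadcast_to(x, (10, x.shape[0])))
fn = aesara.function([x], out)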


@register_useless
@register_canonicalize
@local_optimizer([Unique])
def local_Unique_Repeat_lift(fgraph, node):
    """Convert ``unique(repeat(x, ...), axis=None)`` to ``unique(x, axis=None)``.

    This isn't really so much a lift as a "reduction/consumption".
    """
    if not isinstance(node.op, Unique):
        return False

    if (
        node.op.return_index
        or node.op.return_inverse
        or node.op.return_counts
        or node.op.axis is not None
    ):
        return False

    repeat_var = node.inputs[0]

    if not (repeat_var.owner and isinstance(repeat_var.owner.op, Repeat)):
        return False

    repeated_var, *repeat_shape = repeat_var.owner.inputs

    new_unique, *_ = node.op.make_node(repeated_var).outputs

    old_out = node.outputs[0]
    new_x = as_tensor_variable(new_unique, ndim=old_out.ndim, dtype=old_out.dtype)
    return [new_x]
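
A sketch for the `Repeat` case (not part of the diff; it assumes the `repeat` helper builds a `Repeat` node here, which may not hold in every configuration):

import aesara
import aesara.tensor as at
from aesara.tensor.extra_ops import repeat, unique

x = at.vector("x")
# Each element of `x` appears three times in `repeat(x, 3)`, but the set
# of unique values is unchanged, so this reduces to `unique(x)`.
out = unique(repeat(x, 3))
fn = aesara.function([x], out)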


@register_useless
@register_canonicalize
@local_optimizer([Unique])
def local_Unique_second(fgraph, node):
    """Convert ``unique(second(y, x), axis=None)`` to ``unique(x, axis=None)``.

    This isn't really so much a lift as a "reduction/consumption".
    """
    if not isinstance(node.op, Unique):
        return False

    if (
        node.op.return_index
        or node.op.return_inverse
        or node.op.return_counts
        or node.op.axis is not None
    ):
        return False

    second_var = node.inputs[0]

    if not (
        second_var.owner
        and isinstance(second_var.owner.op, Elemwise)
        and isinstance(second_var.owner.op.scalar_op, aes.Second)
    ):
        return False

    shape_var, seconded_var = second_var.owner.inputs

    new_unique, *_ = node.op.make_node(seconded_var).outputs

    old_out = node.outputs[0]
    new_x = as_tensor_variable(new_unique, ndim=old_out.ndim, dtype=old_out.dtype)
    return [new_x]
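
And a sketch for the `second` case (illustrative only, not part of the diff):

import aesara
import aesara.tensor as at
from aesara.tensor.basic import second
from aesara.tensor.extra_ops import unique

a = at.matrix("a")
x = at.vector("x")
# `second(a, x)` (a.k.a. `fill`) broadcasts `x` to the shape of `a`,
# again only repeating values, so this reduces to `unique(x)`.
out = unique(second(a, x))
fn = aesara.function([a, x], out)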