Skip to content

Commit

Permalink
Revert "Allow shapecheck of PixelCNN++ (jax-ml#2017)"
Browse files Browse the repository at this point in the history
This reverts commit 8f538f4.

Issue: jax-ml#2245
  • Loading branch information
gnecula authored and srvasude committed May 5, 2020
1 parent 93983c5 commit 17f6333
Show file tree
Hide file tree
Showing 8 changed files with 184 additions and 225 deletions.
26 changes: 15 additions & 11 deletions jax/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@
import collections
import functools
import itertools as it
import operator as op
import os
import string
import threading
from warnings import warn

Expand All @@ -51,14 +51,14 @@
from .lib.xla_bridge import (device_count, local_device_count, devices, local_devices,
host_id, host_ids, host_count)
from .abstract_arrays import ConcreteArray, ShapedArray, raise_to_shaped
from .interpreters.masking import eval_polymorphic_shape, Poly, Mon
from .interpreters import partial_eval as pe
from .interpreters import xla
from .interpreters import pxla
from .interpreters import ad
from .interpreters import batching
from .interpreters import parallel
from .interpreters import masking
from .interpreters.masking import shapecheck, ensure_poly
from .config import flags, config, bool_env

map = safe_map
Expand Down Expand Up @@ -1053,23 +1053,24 @@ def wrapped_fun(args, logical_env):
out_shapes = map(masking.finalize_spec, out_specs, map(onp.shape, outs))
if not out_shapes == list(out_shapes_):
raise masking.ShapeError
if not all(onp.shape(out) == eval_polymorphic_shape(shape, padded_env)
for out, shape in zip(outs, out_shapes)):
if not all(onp.shape(out) == masking.eval_shape_expr(padded_env, expr)
for out, expr in zip(outs, out_shapes)):
raise masking.ShapeError
return tree_unflatten(out_tree(), outs)
return wrapped_fun

def _remap_ids(names, shape_spec):
return masking.ShapeSpec(Poly({Mon({names[id] : deg for id, deg in mon.items()})
ShapeSpec, Poly, Mon = masking.ShapeSpec, masking.Poly, masking.Mon
mdim = masking.monomorphic_dim
return ShapeSpec(Poly({Mon({names[id] : deg for id, deg in mon.items()})
: coeff for mon, coeff in poly.items()})
if poly is not masking._monomorphic_dim else
masking._monomorphic_dim for poly in shape_spec)
if poly is not mdim else mdim for poly in shape_spec)

def _bind_shapes(shape_exprs, shapes):
env = {}
for shape_expr, shape in zip(shape_exprs, shapes):
for poly, d in zip(shape_expr, shape):
if type(poly) is not Poly or poly.is_constant:
if ensure_poly(poly).is_constant:
continue
else:
(binder,), = poly # TODO generalize to handle striding
Expand All @@ -1084,13 +1085,16 @@ def shapecheck(in_shapes, out_shape, fun):
out_shapes, out_tree = tree_flatten(out_shape)
out_shapes = map(masking.parse_spec, out_shapes)
flat_fun, out_tree_ = flatten_fun_nokwargs(lu.wrap_init(fun), in_tree)
avals = map(partial(ShapedArray, dtype=onp.float32), in_shapes)
out_shapes_ = [o.shape for o in pe.abstract_eval_fun(flat_fun.call_wrapped, *avals)]
out_shapes_ = masking.shapecheck(flat_fun, in_shapes)
if out_tree != out_tree_(): raise TypeError("pytree mismatch")
if not all(map(masking._shape_spec_consistent, out_shapes, out_shapes_)):
if not all(map(_shape_spec_consistent, out_shapes, out_shapes_)):
raise masking.ShapeError
return fun

def _shape_spec_consistent(spec, expr):
return all(a == b for a, b in zip(spec, expr) if a is not masking.monomorphic_dim)


def jvp(fun, primals, tangents):
"""Computes a (forward-mode) Jacobian-vector product of `fun`.
Expand Down
Loading

0 comments on commit 17f6333

Please sign in to comment.