Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow shapecheck of PixelCNN++ #2017

Merged
merged 38 commits into from
Feb 14, 2020
Merged
Show file tree
Hide file tree
Changes from 25 commits
Commits
Show all changes
38 commits
Select commit Hold shift + click to select a range
c19c3e9
Allow shapecheck of indexing, slicing, broadcast_to, reshape, random.…
juliuskunze Jan 16, 2020
7e424c9
Fix dynamic slicing
juliuskunze Jan 17, 2020
7595316
Fix issue with float64.__index__()
juliuskunze Jan 17, 2020
eaa847d
Fix np.arange with float size, _try_canonicalize_shape
juliuskunze Jan 17, 2020
4abe17d
Cleanup: Make methods to create Poly internal (only use in Poly / sha…
juliuskunze Jan 17, 2020
79c1a7b
Fix testReshapeWithUnusualShapes (error message)
juliuskunze Jan 17, 2020
a45b879
Fix syntax for python 3.6
juliuskunze Jan 17, 2020
5d630e9
Remove Poly.__index__
juliuskunze Jan 20, 2020
212f241
Merge branch 'master' of https://github.com/google/jax into shapechec…
juliuskunze Jan 20, 2020
99455b5
Fix tests
juliuskunze Jan 20, 2020
293caf6
Split up masking.py
juliuskunze Jan 20, 2020
3283bba
Cleanup masking
juliuskunze Jan 20, 2020
97c2506
Cleanup
juliuskunze Jan 20, 2020
4433fb3
Use abstract_eval for shapecheck, remove ShapeCheckTrace(r)
juliuskunze Jan 20, 2020
1d28237
Remove shape_rules, fix test
juliuskunze Jan 20, 2020
189e586
Remove shapes.py, move code to abstract_arrays.py / api.py
juliuskunze Jan 20, 2020
8a6f1fc
Remove safe_map/zip, is_instance from abstract_arrays, test + fix Pol…
juliuskunze Jan 20, 2020
72d8021
Add missing shapecheck_test.py
juliuskunze Jan 20, 2020
f6ba303
Cleanup, minimize changes
juliuskunze Jan 20, 2020
115581e
Minimize import diff
juliuskunze Jan 20, 2020
4e9e237
Minor
juliuskunze Jan 20, 2020
847d437
Allow shapecheck of np.where
juliuskunze Jan 21, 2020
cfb19fa
Fix np.where
juliuskunze Jan 21, 2020
732315b
Simplify gather to allow retightening type assertion in ConcreteArray
juliuskunze Jan 21, 2020
cb4696a
Remove unused imports
juliuskunze Jan 21, 2020
a760fa1
Make import style consistent
juliuskunze Jan 21, 2020
987d65c
Remove is_polymorphic, special cases in sampling, split, where.
juliuskunze Jan 21, 2020
c3c4588
Move back Poly, _parse_shape_spec into masking.py to simplify diff
juliuskunze Jan 21, 2020
bc4cdfa
Move back ShapeTest into masking_test.py to simplify diff
juliuskunze Jan 21, 2020
01404ed
Minor reverts to further simplify diff
juliuskunze Jan 21, 2020
026f933
Fix tests
juliuskunze Jan 21, 2020
5bb0730
Merge branch 'master' of https://github.com/google/jax into shapechec…
juliuskunze Jan 29, 2020
4ff7c3c
Merge remote-tracking branch 'main/master' into shapecheck-pcnn
juliuskunze Jan 30, 2020
aca2724
Minimize diff
juliuskunze Jan 30, 2020
d4a8bbb
Restore copyright, cleanup imports in masking.py
juliuskunze Jan 31, 2020
1d12235
Merge branch 'master' of https://github.com/google/jax into shapechec…
juliuskunze Feb 6, 2020
dd4b8d2
Merge branch 'master' of https://github.com/google/jax into shapechec…
juliuskunze Feb 7, 2020
02f1589
Merge branch 'master' of https://github.com/google/jax into shapechec…
juliuskunze Feb 7, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
167 changes: 167 additions & 0 deletions jax/abstract_arrays.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@
from __future__ import division
from __future__ import print_function

from collections import Counter
from itertools import product
import operator as op
import numpy as onp

from . import core
Expand Down Expand Up @@ -248,3 +251,167 @@ def _make_concrete_python_scalar(x):
ad_util.jaxval_zeros_likers[t] = _zeros_like_python_scalar

core.literalable_types.update(dtypes.python_scalar_dtypes.keys())

def to_index(x):
  """Convert `x` to an index, letting polymorphic dimensions pass through.

  Works like `operator.index` for ordinary values. A value whose exact type is
  `Poly` is returned unchanged. This is not implemented as `Poly.__index__`,
  since `operator.index` only allows ints.
  """
  if type(x) is Poly:
    return x
  return op.index(x)

# TODO remove remaining usages:
def is_polymorphic(shape):
  """Return True if at least one dimension of `shape` has exact type `Poly`."""
  for dim in shape:
    if type(dim) is Poly:
      return True
  return False

def eval_polymorphic_shape(shape, values_dict):
  """Return `shape` as a tuple with every `Poly` dimension evaluated.

  Dimensions whose exact type is `Poly` are replaced by
  `dim.evaluate(values_dict)`; every other dimension is kept unchanged.
  """
  evaluated = []
  for dim in shape:
    if type(dim) is Poly:
      dim = dim.evaluate(values_dict)
    evaluated.append(dim)
  return tuple(evaluated)

def _ensure_poly(p):
  """Return `p` as a `Poly`.

  A value whose exact type is `Poly` is returned unchanged; anything else is
  wrapped as the constant polynomial whose only coefficient is `p`.
  """
  if type(p) is Poly:
    return p

  return Poly({Mon(): p})

class Poly(dict):
  """Polynomial with integer coefficients,
  usable as element in a polymorphic shape.

  type Poly = Map Mon Int -- monomials to coeffs
  type Mon = Map Str Int
  """

  def __init__(self, coeffs):
    # Makes sure Polynomials are always in canonical form to simplify operators:
    # zero coefficients are dropped and the zero polynomial is {Mon(): 0}.
    coeffs = {mon: op.index(coeff) for mon, coeff in coeffs.items() if coeff != 0}
    coeffs = {Mon(): 0} if len(coeffs) == 0 else coeffs
    super().__init__(coeffs)

  def __add__(self, other):
    """Return `self + other`; `other` may be a `Poly` or an integer."""
    coeffs = self.copy()

    for mon, coeff in _ensure_poly(other).items():
      coeffs[mon] = coeffs.get(mon, 0) + coeff

    return Poly(coeffs)

  def __sub__(self, other):
    """Return `self - other`."""
    return self + -other

  def __neg__(self):
    """Return the polynomial with every coefficient negated."""
    return Poly({mon: -coeff for mon, coeff in self.items()})

  def __mul__(self, other):
    """Return `self * other` by multiplying all pairs of terms."""
    coeffs = dict()
    for (mon1, coeff1), (mon2, coeff2) \
        in product(self.items(), _ensure_poly(other).items()):
      mon = mon1 * mon2
      coeffs[mon] = coeffs.get(mon, 0) + coeff1 * coeff2

    return Poly(coeffs)

  def __rmul__(self, other):
    return self * other

  def __radd__(self, other):
    return self + other

  def __rsub__(self, other):
    """Return `other - self` (reflected subtraction)."""
    # The reflected operand is the minuend: `5 - n` must be `-n + 5`,
    # not `n - 5`.
    return _ensure_poly(other) - self

  def __floordiv__(self, divisor):
    q, _ = divmod(self, divisor)  # pytype: disable=wrong-arg-types
    return q

  def __mod__(self, divisor):
    _, r = divmod(self, divisor)  # pytype: disable=wrong-arg-types
    return r

  def __divmod__(self, divisor):
    """Return `(quotient, remainder)` of dividing by `divisor`.

    A constant polynomial is divided as a plain int. Otherwise every
    non-constant coefficient must be exactly divisible by `divisor`; the
    constant term is floor-divided and its remainder is returned.

    Raises:
      ValueError: if a non-constant coefficient is not divisible by `divisor`.
    """
    if self.is_constant:
      return divmod(int(self), divisor)

    def divided(count):
      q, r = divmod(count, divisor)
      if r != 0:
        raise ValueError('shapecheck currently only supports strides '
                         'that exactly divide the strided axis length.')
      return q

    return Poly(
      {k: coeff // divisor if k.degree == 0 else divided(coeff)
       for k, coeff in self.items()}), self.get(Mon(), 0) % divisor

  def __hash__(self):
    # Order-independent, so polynomials that compare equal always hash equal.
    return hash(frozenset(self.items()))

  def __eq__(self, other):
    return dict.__eq__(self, _ensure_poly(other))

  def __ne__(self, other):
    return not self == other

  def __ge__(self, other):
    """Return whether `self >= other`, if that can be decided.

    Raises:
      ValueError: if none of the rules below settles the comparison.
    """
    other = _ensure_poly(other)

    if other.is_constant and self.is_constant:
      return int(self) >= int(other)

    if other.is_constant and int(other) <= 1:
      # Assume polynomials > 0, allowing to use shape rules of binops, conv:
      return True

    if self.is_constant and int(self) <= 0:
      return False # See above.

    if self == other:
      return True

    raise ValueError('Polynomials comparison "{} >= {}" is inconclusive.'
                     .format(self, other))

  def __le__(self, other):
    return _ensure_poly(other) >= self

  def __lt__(self, other):
    return not (self >= other)

  def __gt__(self, other):
    return not (_ensure_poly(other) >= self)

  def __str__(self):
    return ' + '.join('{} {}'.format(v, k)
                      if (v != 1 or k.degree == 0) else str(k)
                      for k, v in sorted(self.items())).strip()

  def __int__(self):
    assert self.is_constant

    return op.index(next(iter(self.values())))

  def evaluate(self, values_dict):
    """Return the value of the polynomial with variables taken from `values_dict`."""
    # NOTE(review): `prod` must be available at module scope; its import is
    # not part of this block -- confirm.
    return sum(coeff * prod([values_dict[name] ** deg for name, deg in mon.items()])
               for mon, coeff in self.items())

  @property
  def is_constant(self):
    """True if the only term is the degree-0 monomial."""
    return len(self) == 1 and next(iter(self)).degree == 0

class Mon(dict):  # type Mon = Map Id Int -- ids to degrees
  """Monomial, stored as a mapping from variable ids to their degrees."""

  def __hash__(self):
    # Must not depend on insertion order: dict equality ignores it, and
    # `a * b` and `b * a` insert their ids in different orders.
    return hash(frozenset(self.items()))

  def __str__(self):
    return ' '.join('{}**{}'.format(k, v) if v != 1 else str(k)
                    for k, v in sorted(self.items()))

  def __lt__(self, other):
    # sort by total degree, then lexicographically on indets; exponents break
    # the remaining ties so that the order is total (and sorting deterministic).
    self_key = self.degree, tuple(sorted(self)), tuple(sorted(self.items()))
    other_key = other.degree, tuple(sorted(other)), tuple(sorted(other.items()))
    return self_key < other_key

  def __mul__(self, other):
    """Return the product monomial: degrees of shared ids are added."""
    return Mon(Counter(self) + Counter(other))

  @property
  def degree(self):
    """Total degree, i.e. the sum of all variable degrees."""
    return sum(self.values())
117 changes: 96 additions & 21 deletions jax/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@
import collections
import functools
import itertools as it
import operator as op
import os
import string
import threading
from warnings import warn

Expand All @@ -54,15 +54,15 @@
from .lib import xla_bridge as xb
from .lib.xla_bridge import (device_count, local_device_count, devices, local_devices,
host_id, host_ids, host_count)
from .abstract_arrays import ConcreteArray, ShapedArray, raise_to_shaped
from .abstract_arrays import ConcreteArray, ShapedArray, raise_to_shaped, \
eval_polymorphic_shape, Poly, Mon
from .interpreters import partial_eval as pe
from .interpreters import xla
from .interpreters import pxla
from .interpreters import ad
from .interpreters import batching
from .interpreters import parallel
from .interpreters import masking
from .interpreters.masking import shapecheck, ensure_poly
from .config import flags, config

map = safe_map
Expand Down Expand Up @@ -1018,8 +1018,8 @@ def mask(fun, in_shapes, out_shape):
in_specs, in_shapes_tree = tree_flatten(in_shapes)
out_specs, out_shapes_tree = tree_flatten(out_shape)

in_specs = map(masking.parse_spec, in_specs)
out_specs = map(masking.parse_spec, out_specs)
in_specs = map(_parse_shape_spec, in_specs)
out_specs = map(_parse_shape_spec, out_specs)

unique_ids = collections.defaultdict(object)
in_specs = map(partial(_remap_ids, unique_ids), in_specs)
Expand All @@ -1029,56 +1029,131 @@ def wrapped_fun(args, logical_env):
args_flat, in_tree = tree_flatten(args)
if in_tree != in_shapes_tree: raise TypeError("pytree mismatch")
logical_env = {unique_ids[name] : val for name, val in logical_env.items()}
in_shapes = map(masking.finalize_spec, in_specs, map(onp.shape, args_flat))
in_shapes = map(_finalize_shape_spec, in_specs, map(onp.shape, args_flat))
padded_env = _bind_shapes(in_shapes, [x.shape for x in args_flat])
f = lu.wrap_init(fun)
flat_fun, out_tree = flatten_fun_nokwargs(f, in_tree)
outs, out_shapes_ = masking.mask_fun(
flat_fun, logical_env, padded_env, args_flat, in_shapes)
if not out_tree() == out_shapes_tree: raise TypeError("pytree mismatch")
out_shapes = map(masking.finalize_spec, out_specs, map(onp.shape, outs))
out_shapes = map(_finalize_shape_spec, out_specs, map(onp.shape, outs))
if not out_shapes == list(out_shapes_):
raise masking.ShapeError
if not all(onp.shape(out) == masking.eval_shape_expr(padded_env, expr)
for out, expr in zip(outs, out_shapes)):
raise masking.ShapeError
raise ShapeError
if not all(onp.shape(out) == eval_polymorphic_shape(shape, padded_env)
for out, shape in zip(outs, out_shapes)):
raise ShapeError
return tree_unflatten(out_tree(), outs)
return wrapped_fun

def _remap_ids(names, shape_spec):
ShapeSpec, Poly, Mon = masking.ShapeSpec, masking.Poly, masking.Mon
mdim = masking.monomorphic_dim
return ShapeSpec(Poly({Mon({names[id] : deg for id, deg in mon.items()})
: coeff for mon, coeff in poly.items()})
if poly is not mdim else mdim for poly in shape_spec)
if poly is not _monomorphic_dim else
_monomorphic_dim for poly in shape_spec)

def _bind_shapes(shape_exprs, shapes):
env = {}
for shape_expr, shape in zip(shape_exprs, shapes):
for poly, d in zip(shape_expr, shape):
if ensure_poly(poly).is_constant:
if type(poly) is not Poly or poly.is_constant:
continue
else:
(binder,), = poly # TODO generalize to handle striding
if env.setdefault(binder, d) != d: raise masking.ShapeError
if env.setdefault(binder, d) != d: raise ShapeError
return env


@curry
def shapecheck(in_shapes, out_shape, fun):
in_shapes, in_tree = tree_flatten(in_shapes)
in_shapes = map(masking.parse_spec, in_shapes)
in_shapes = map(_parse_shape_spec, in_shapes)
out_shapes, out_tree = tree_flatten(out_shape)
out_shapes = map(masking.parse_spec, out_shapes)
out_shapes = map(_parse_shape_spec, out_shapes)
flat_fun, out_tree_ = flatten_fun_nokwargs(lu.wrap_init(fun), in_tree)
out_shapes_ = masking.shapecheck(flat_fun, in_shapes)
avals = map(partial(ShapedArray, dtype=onp.float32), in_shapes)
out_shapes_ = [o.shape for o in pe.abstract_eval_fun(flat_fun.call_wrapped, *avals)]
if out_tree != out_tree_(): raise TypeError("pytree mismatch")
if not all(map(_shape_spec_consistent, out_shapes, out_shapes_)):
raise masking.ShapeError
raise ShapeError
return fun

class ShapeError(Exception):
  """Raised when shapes are inconsistent with each other or with a shape spec."""

class ShapeSyntaxError(Exception):
  """Raised when a shape spec string cannot be parsed."""

# To denote some shape expressions (for annotations) we use a small language.
#
# data ShapeSpec = ShapeSpec [Dim]
# data Dim = Id PyObj
# | Lit Int
# | Mul Dim Dim
# | Add Dim Dim
# | MonomorphicDim
#
# We'll also make a simple concrete syntax for annotation. The grammar is
#
# shape_spec ::= '(' dims ')'
# dims ::= dim ',' dims | ''
# dim ::= str | int | dim '*' dim | dim '+' dim | '_'
#
# ShapeSpecs can have some monomorphic dims inside them,
# which must be replaced with concrete shapes when known.

class ShapeSpec(tuple):
  """Tuple of dimension entries making up a shape specification."""

  def __str__(self):
    dims = ', '.join(str(dim) for dim in self)
    return 'ShapeSpec({})'.format(dims)

def _finalize_shape_spec(spec, shape):
  """Fill the monomorphic entries of `spec` with sizes taken from `shape`.

  Entries that are the `_monomorphic_dim` placeholder are replaced by the
  literal parsed from the corresponding size in `shape`; all other entries are
  kept. Returns a plain tuple.
  """
  finalized = []
  for entry, size in zip(spec, shape):
    if entry is _monomorphic_dim:
      entry = _parse_lit(size)
    finalized.append(entry)
  return tuple(finalized)

def _parse_shape_spec(spec=''):
  """Parse a shape spec string such as '(m, n)' into a `ShapeSpec`.

  Surrounding parentheses are optional and spaces are ignored. An empty spec,
  including '()' or a spec made only of spaces, yields an empty `ShapeSpec`.

  Raises:
    ShapeSyntaxError: if an opening parenthesis is not matched by a closing
      one, or (via `_parse_dim`) if a dimension cannot be parsed.
  """
  if not spec:
    return ShapeSpec(())
  if spec[0] == '(':
    if spec[-1] != ')': raise ShapeSyntaxError(spec)
    spec = spec[1:-1]
  spec = spec.replace(' ', '')
  if not spec:
    # ''.split(',') is [''], which would be rejected as a malformed dimension,
    # although the grammar allows an empty list of dims.
    return ShapeSpec(())
  return ShapeSpec(_parse_dim(dim) for dim in spec.strip(',').split(','))

def _parse_dim(spec):
  """Parse the string for a single dimension of a shape spec.

  '+' binds weakest, then '*'. The remaining atoms are integer literals
  (optionally negative), identifiers from `_identifiers`, and '_' for a
  monomorphic dimension.

  Raises:
    ShapeSyntaxError: if `spec` matches none of the above.
  """
  if '+' in spec:
    terms = map(_parse_dim, spec.split('+'))
    return onp.sum(terms)
  if '*' in spec:
    factors = map(_parse_dim, spec.split('*'))
    return prod(factors)
  is_literal = spec.isdigit() or (spec.startswith('-') and spec[1:].isdigit())
  if is_literal:
    return _parse_lit(spec)
  if spec in _identifiers:
    return _parse_id(spec)
  if spec == '_':
    return _monomorphic_dim
  raise ShapeSyntaxError(spec)

_identifiers = frozenset(string.ascii_lowercase)

def _parse_id(name):
  """Return the polynomial made of the single variable `name` (coefficient 1)."""
  monomial = Mon({name: 1})
  return Poly({monomial: 1})

def _parse_lit(val_str):
  """Return the constant polynomial with value `int(val_str)`."""
  constant_term = Mon()
  return Poly({constant_term: int(val_str)})

class MonomorphicDim(object):
  """Placeholder for a dimension written as '_' in a shape spec."""

  def __str__(self):
    return '_'

_monomorphic_dim = MonomorphicDim()

# Two convenient ways to provide shape annotations:
# 1. '(m, n)'
# 2. s_['m', 'n']

class S_(object):
  """Builds shape specs through indexing syntax, e.g. `s_['m', 'n']`."""

  def __getitem__(self, idx):
    # A tuple index is joined into a parenthesized, comma-separated spec;
    # any other index is parsed from its string form.
    if type(idx) is tuple:
      spec = '(' + ','.join(map(str, idx)) + ')'
    else:
      spec = str(idx)
    return _parse_shape_spec(spec)

s_ = S_()

def _shape_spec_consistent(spec, expr):
return all(a == b for a, b in zip(spec, expr) if a is not masking.monomorphic_dim)
return all(a == b for a, b in zip(spec, expr) if a is not _monomorphic_dim)


def jvp(fun, primals, tangents):
Expand Down
Loading