Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP: Experiment to use pytypes to add support for python type hints #69

Open
wants to merge 13 commits into
base: main
Choose a base branch
from
Open
1 change: 1 addition & 0 deletions .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ python:
install:
- pip install coverage
- pip install --upgrade pytest pytest-benchmark
- pip install pytypes

script:
- |
Expand Down
14 changes: 12 additions & 2 deletions multipledispatch/conflict.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,28 @@
from .utils import _toposort, groupby
from pytypes import is_subtype, is_Union, get_Union_params


class AmbiguityWarning(Warning):
pass


def safe_subtype(a, b):
"""Union safe subclass"""
if is_Union(a):
return any(is_subtype(tp, b) for tp in get_Union_params(a))
else:
return is_subtype(a, b)


def supercedes(a, b):
""" A is consistent and strictly more specific than B """
return len(a) == len(b) and all(map(issubclass, a, b))
return len(a) == len(b) and all(map(safe_subtype, a, b))


def consistent(a, b):
""" It is possible for an argument list to satisfy both A and B """
return (len(a) == len(b) and
all(issubclass(aa, bb) or issubclass(bb, aa)
all(safe_subtype(aa, bb) or safe_subtype(bb, aa)
for aa, bb in zip(a, b)))


Expand Down
81 changes: 60 additions & 21 deletions multipledispatch/dispatcher.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,14 @@
from warnings import warn
import inspect

import copy

from .conflict import ordering, ambiguities, super_signature, AmbiguityWarning
from .utils import expand_tuples
import itertools as itl

import itertools as itl
import pytypes
import typing


class MDNotImplementedError(NotImplementedError):
Expand Down Expand Up @@ -46,6 +51,7 @@ def restart_ordering(on_ambiguity=ambiguity_warn):
DeprecationWarning,
)


class Dispatcher(object):
""" Dispatch methods based on type signature

Expand Down Expand Up @@ -140,13 +146,17 @@ def add(self, signature, func):
>>> D = Dispatcher('add')
>>> D.add((int, int), lambda x, y: x + y)
>>> D.add((float, float), lambda x, y: x + y)
>>> D.add((typing.Optional[str], ), lambda x: x)

>>> D(1, 2)
3
>>> D(1, 2.0)
>>> D('1', 2.0)
Traceback (most recent call last):
...
NotImplementedError: Could not find signature for add: <int, float>
NotImplementedError: Could not find signature for add: <str, float>
>>> D('s')
's'
>>> D(None)

When ``add`` detects a warning it calls the ``on_ambiguity`` callback
with a dispatcher/itself, and a set of ambiguous type signature pairs
Expand All @@ -157,24 +167,35 @@ def add(self, signature, func):
annotations = self.get_func_annotations(func)
if annotations:
signature = annotations
# Make function annotation dict

def process_union(tp):
if isinstance(tp, tuple):
t = typing.Union[tuple(process_union(e) for e in tp)]
return t
else:
return tp

# Handle union types
if any(isinstance(typ, tuple) for typ in signature):
for typs in expand_tuples(signature):
self.add(typs, func)
return
signatures = expand_tuples(signature)
for signature in signatures:
signature = tuple(process_union(tp) for tp in signature)

for typ in signature:
if not isinstance(typ, type):
str_sig = ', '.join(c.__name__ if isinstance(c, type)
else str(c) for c in signature)
raise TypeError("Tried to dispatch on non-type: %s\n"
"In signature: <%s>\n"
"In function: %s" %
(typ, str_sig, self.name))
# make a copy of the function (if needed) and apply the function annotations

self.funcs[signature] = func
self._cache.clear()
# TODO: MAKE THIS type or typevar
for typ in signature:
try:
typing.Union[typ]
except TypeError:
str_sig = ', '.join(c.__name__ if isinstance(c, type)
else str(c) for c in signature)
raise TypeError("Tried to dispatch on non-type: %s\n"
"In signature: <%s>\n"
"In function: %s" %
(typ, str_sig, self.name))

self.funcs[signature] = func
self._cache.clear()

try:
del self._ordering
Expand All @@ -196,7 +217,11 @@ def reorder(self, on_ambiguity=ambiguity_warn):
return od

def __call__(self, *args, **kwargs):
types = tuple([type(arg) for arg in args])
try:
types = tuple([pytypes.deep_type(arg, 1, max_sample=10) for arg in args])
except:
# some things dont deeptype welkl
types = tuple([type(arg) for arg in args])
try:
func = self._cache[types]
except KeyError:
Expand Down Expand Up @@ -259,12 +284,26 @@ def dispatch(self, *types):
except StopIteration:
return None

@staticmethod
def get_type_vars(x):
if isinstance(x, typing.TypeVar):
yield x
if isinstance(x, typing.GenericMeta):
for e in x.__parameters__:
yield e

def dispatch_iter(self, *types):
n = len(types)
for signature in self.ordering:
if len(signature) == n and all(map(issubclass, types, signature)):
if len(signature) == n:
result = self.funcs[signature]
yield result
try:
typsig = typing.Tuple[signature]
typvars = list(self.get_type_vars(typsig))
if pytypes.is_subtype(typing.Tuple[types], typsig, bound_typevars={t.__name__: t for t in typvars}):
yield result
except pytypes.InputTypeError:
continue

def resolve(self, types):
""" Deterimine appropriate implementation for this type signature
Expand Down
6 changes: 3 additions & 3 deletions multipledispatch/tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,11 +131,11 @@ def f(x):

def test_union_types():
@dispatch((A, C))
def f(x):
def hh(x):
return 1

assert f(A()) == 1
assert f(C()) == 1
assert hh(A()) == 1
assert hh(C()) == 1


def test_namespaces():
Expand Down
29 changes: 29 additions & 0 deletions multipledispatch/tests/test_dispatcher_3only.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

from multipledispatch import dispatch
from multipledispatch.dispatcher import Dispatcher
from multipledispatch.utils import raises
import typing


def test_function_annotation_register():
Expand All @@ -30,8 +32,23 @@ def inc(x: int):
def inc(x: float):
return x - 1

@dispatch()
def inc(x: typing.Optional[str]):
return x

@dispatch()
def inc(x: typing.List[int]):
return x[0] * 4

@dispatch()
def inc(x: typing.List[str]):
return x[0] + 'b'

assert inc(1) == 2
assert inc(1.0) == 0.0
assert inc('a') == 'a'
assert inc([8]) == 32
assert inc(['a']) == 'ab'


def test_function_annotation_dispatch_custom_namespace():
Expand Down Expand Up @@ -68,6 +85,18 @@ def f(self, x: float):
assert foo.f(1.0) == 0.0


def test_diagonal_dispatch():
T = typing.TypeVar('T')
U = typing.TypeVar('U')

@dispatch()
def diag(x: T, y: T):
return 'same'

assert diag(1, 6) == 'same'
assert raises(NotImplementedError, lambda: diag(1, '1'))


def test_overlaps():
@dispatch(int)
def inc(x: int):
Expand Down
25 changes: 20 additions & 5 deletions multipledispatch/utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@

import pytypes
import typing


def raises(err, lamda):
try:
lamda()
Expand All @@ -14,15 +19,25 @@ def expand_tuples(L):

>>> expand_tuples([1, 2])
[(1, 2)]

>>> expand_tuples([1, typing.Optional[str]]) #doctest: +ELLIPSIS
[(1, <... 'str'>), (1, <... 'NoneType'>)]
"""
if not L:
return [()]
elif not isinstance(L[0], tuple):
rest = expand_tuples(L[1:])
return [(L[0],) + t for t in rest]
else:
rest = expand_tuples(L[1:])
return [(item,) + t for t in rest for item in L[0]]
if pytypes.is_Union(L[0]):
rest = expand_tuples(L[1:])
return [(item,) + t for t in rest for item in pytypes.get_Union_params(L[0])]
elif not pytypes.is_of_type(L[0], tuple):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

somehow this doesn't actually get hit when L[0] is not a tuple:

(Pdb) p L[0]
<type 'numpy.dtype'>
(Pdb) p pytypes.is_of_type(L[0], tuple)
True

This breaks importing datashape.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The root cause of this is actually more concerning than this one bug. Why doesn't pytypes.is_of_type work correctly for type objects?

rest = expand_tuples(L[1:])
return [(L[0],) + t for t in rest]
elif not isinstance(L[0], tuple):
rest = expand_tuples(L[1:])
return [(L[0],) + t for t in rest]
else:
rest = expand_tuples(L[1:])
return [(item,) + t for t in rest for item in L[0]]


# Taken from theano/theano/gof/sched.py
Expand Down