From 2ccff72a4d48622798cc3fd281356a5050731213 Mon Sep 17 00:00:00 2001 From: Chris Withers Date: Fri, 14 Feb 2020 05:11:10 +0000 Subject: [PATCH 001/159] drop python 2 support --- .circleci/config.yml | 12 ++--- docs/conftest.py | 5 +- mush/compat.py | 70 --------------------------- mush/declarations.py | 36 ++++---------- mush/tests/configparser.py | 8 --- mush/tests/conftest.py | 6 --- mush/tests/example_with_mush_clone.py | 2 +- mush/tests/example_without_mush.py | 2 +- setup.cfg | 3 -- setup.py | 9 ++-- 10 files changed, 22 insertions(+), 131 deletions(-) delete mode 100644 mush/compat.py delete mode 100644 mush/tests/configparser.py diff --git a/.circleci/config.yml b/.circleci/config.yml index 580b300..339a0d8 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -7,17 +7,17 @@ common: &common jobs: - python/pip-run-tests: - name: python27 - image: circleci/python:2.7 + name: python36 + image: circleci/python:3.6 - python/pip-run-tests: - name: python37 - image: circleci/python:3.7 + name: python38 + image: circleci/python:3.8 - python/coverage: name: coverage requires: - - python27 - - python37 + - python36 + - python38 - python/release: name: release diff --git a/docs/conftest.py b/docs/conftest.py index 0cc9d3a..291ab9e 100644 --- a/docs/conftest.py +++ b/docs/conftest.py @@ -5,8 +5,6 @@ from sybil.parsers.codeblock import CodeBlockParser from sybil.parsers.doctest import DocTestParser -from mush.compat import PY2 - sybil_collector = Sybil( parsers=[ DocTestParser(optionflags=REPORT_NDIFF|ELLIPSIS), @@ -18,5 +16,4 @@ def pytest_collect_file(parent, path): - if not PY2: - return sybil_collector(parent, path) + return sybil_collector(parent, path) diff --git a/mush/compat.py b/mush/compat.py deleted file mode 100644 index 14aa1b5..0000000 --- a/mush/compat.py +++ /dev/null @@ -1,70 +0,0 @@ -# compatibility module for different python versions -import sys -from collections import OrderedDict -from .markers import Marker - - -if sys.version_info[:2] < (3, 0): - PY2 = True - from functools import partial - from inspect import getargspec, ismethod, isclass, isfunction - - class Parameter(object): - POSITIONAL_ONLY = Marker('POSITIONAL_ONLY') - POSITIONAL_OR_KEYWORD = kind = Marker('POSITIONAL_OR_KEYWORD') - KEYWORD_ONLY = Marker('KEYWORD_ONLY') - empty = default = Marker('empty') - - class Signature(object): - __slots__ = 'parameters' - - def signature(obj): - sig = Signature() - sig.parameters = params = OrderedDict() - - bound_args = 0 - extra_kw = {} - if isclass(obj): - obj = obj.__init__ - elif isinstance(obj, partial): - bound_args = len(obj.args) - extra_kw = obj.keywords - obj = obj.func - if not (isfunction(obj) or ismethod(obj)): - obj = obj.__call__ - if not (isfunction(obj) or ismethod(obj)): - return sig - spec = getargspec(obj) - spec_args = spec.args - if callable(obj) and not isfunction(obj): - bound_args += 1 - if bound_args: - spec_args = spec.args[bound_args:] - - defaults_count = 0 if spec.defaults is None else len(spec.defaults) - default_start = len(spec_args) - defaults_count - for i, arg in enumerate(spec_args): - params[arg] = p = Parameter() - p.name = arg - if i >= default_start: - p.default = True - - for name in extra_kw: - p = params[name] - p.default = True - p.kind = p.KEYWORD_ONLY - - seen_keyword_only = False - for p in params.values(): - if p.kind is p.KEYWORD_ONLY: - seen_keyword_only = True - elif seen_keyword_only: - p.kind = p.KEYWORD_ONLY - - return sig - -else: - PY2 = False - from inspect import signature - -NoneType = type(None) diff --git a/mush/declarations.py b/mush/declarations.py index 7d6ec7b..578e809 100644 --- a/mush/declarations.py +++ b/mush/declarations.py @@ -1,12 +1,11 @@ -import sys -import types from functools import ( WRAPPER_UPDATES, - WRAPPER_ASSIGNMENTS as FUNCTOOLS_ASSIGNMENTS + WRAPPER_ASSIGNMENTS as FUNCTOOLS_ASSIGNMENTS, + update_wrapper as functools_update_wrapper, ) -from inspect import isclass, isfunction -from .compat import NoneType, signature -from .markers import missing, not_specified +from inspect import signature + +from .markers import missing def name_or_repr(obj): @@ -226,10 +225,7 @@ def process(self, o): return o -if sys.version_info[0] == 2: - ok_types = (type, types.ClassType, str, how) -else: - ok_types = (type, str, how) +ok_types = (type, str, how) def check_type(*objs): @@ -287,7 +283,7 @@ def extract_declarations(obj, explicit_requires, explicit_returns, guess=True): if isinstance(requires_, requires): pass - elif isinstance(requires_, NoneType): + elif requires_ is None: if guess: requires_ = guess_requirements(obj) elif isinstance(requires_, (list, tuple)): @@ -297,7 +293,7 @@ def extract_declarations(obj, explicit_requires, explicit_returns, guess=True): else: requires_ = requires(requires_) - if isinstance(returns_, (ReturnsType, NoneType)): + if returns_ is None or isinstance(returns_, ReturnsType): pass elif isinstance(returns_, (list, tuple)): returns_ = returns(*returns_) @@ -320,18 +316,4 @@ def update_wrapper(wrapper, An extended version of :func:`functools.update_wrapper` that also preserves Mush's annotations. """ - # copied here to backport bugfix from Python 3. - for attr in assigned: - try: - value = getattr(wrapped, attr) - except AttributeError: - pass - else: - setattr(wrapper, attr, value) - for attr in updated: - getattr(wrapper, attr).update(getattr(wrapped, attr, {})) - # Issue #17482: set __wrapped__ last so we don't inadvertently copy it - # from the wrapped function when updating __dict__ - wrapper.__wrapped__ = wrapped - # Return the wrapper so this can be used as a decorator via partial() - return wrapper + return functools_update_wrapper(wrapper, wrapped, assigned, updated) diff --git a/mush/tests/configparser.py b/mush/tests/configparser.py deleted file mode 100644 index 8bbd59b..0000000 --- a/mush/tests/configparser.py +++ /dev/null @@ -1,8 +0,0 @@ -import sys - -if sys.version_info[:2] > (3, 0): - from configparser import RawConfigParser -else: - from ConfigParser import RawConfigParser - - diff --git a/mush/tests/conftest.py b/mush/tests/conftest.py index d981b98..df9b0e0 100644 --- a/mush/tests/conftest.py +++ b/mush/tests/conftest.py @@ -4,12 +4,6 @@ from mush import returns, requires from mush.declarations import how -from ..compat import PY2 - - -def pytest_ignore_collect(path): - if 'py3' in path.basename and PY2: - return True @pytest.fixture() diff --git a/mush/tests/example_with_mush_clone.py b/mush/tests/example_with_mush_clone.py index 1d6d02e..d34331d 100644 --- a/mush/tests/example_with_mush_clone.py +++ b/mush/tests/example_with_mush_clone.py @@ -1,5 +1,5 @@ from argparse import ArgumentParser, Namespace -from .configparser import RawConfigParser +from configparser import RawConfigParser from mush import Runner, requires, attr, item import logging, os, sqlite3, sys diff --git a/mush/tests/example_without_mush.py b/mush/tests/example_without_mush.py index 387c2b9..df454ac 100644 --- a/mush/tests/example_without_mush.py +++ b/mush/tests/example_without_mush.py @@ -1,5 +1,5 @@ from argparse import ArgumentParser -from .configparser import RawConfigParser +from configparser import RawConfigParser import logging, os, sqlite3, sys log = logging.getLogger() diff --git a/setup.cfg b/setup.cfg index 3648c65..03b03d6 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,6 +1,3 @@ -[wheel] -universal = 1 - [tool:pytest] addopts = --verbose --strict norecursedirs=functional .git docs/_build diff --git a/setup.py b/setup.py index 994c11c..a030b0a 100644 --- a/setup.py +++ b/setup.py @@ -8,7 +8,7 @@ setup( name='mush', - version='2.8.1', + version='3.0.0a1', author='Chris Withers', author_email='chris@simplistix.co.uk', license='MIT', @@ -21,16 +21,15 @@ 'Development Status :: 5 - Production/Stable', 'Intended Audience :: Developers', 'License :: OSI Approved :: MIT License', - 'Programming Language :: Python :: 2', - 'Programming Language :: Python :: 2.7', 'Programming Language :: Python :: 3', - 'Programming Language :: Python :: 3.4', - 'Programming Language :: Python :: 3.5', 'Programming Language :: Python :: 3.6', + 'Programming Language :: Python :: 3.7', + 'Programming Language :: Python :: 3.8', ], packages=find_packages(), zip_safe=False, include_package_data=True, + python_requires='>=3.6', extras_require=dict( test=['pytest', 'pytest-cov', 'mock', 'sybil', 'testfixtures'], build=['sphinx', 'setuptools-git', 'wheel', 'twine'] From c11558308e931067610e7bc04a683e69803aea00 Mon Sep 17 00:00:00 2001 From: Chris Withers Date: Fri, 14 Feb 2020 05:23:21 +0000 Subject: [PATCH 002/159] Drop the mapping interface from Context --- mush/context.py | 15 +++++++++------ mush/tests/test_context.py | 28 ++++++++++++++-------------- 2 files changed, 23 insertions(+), 20 deletions(-) diff --git a/mush/context.py b/mush/context.py index 4070cc4..30f04a3 100644 --- a/mush/context.py +++ b/mush/context.py @@ -4,7 +4,7 @@ from .factory import Factory from .markers import missing -NONE_TYPE = None.__class__ +NONE_TYPE = type(None) class ContextError(Exception): @@ -58,9 +58,12 @@ def type_key(type_tuple): return type.__name__ -class Context(dict): +class Context: "Stores resources for a particular run." + def __init__(self): + self._store = {} + def add(self, it, type): """ Add a resource to the context. @@ -71,15 +74,15 @@ def add(self, it, type): if type is NONE_TYPE: raise ValueError('Cannot add None to context') - if type in self: + if type in self._store: raise ContextError('Context already contains %r' % ( type )) - self[type] = it + self._store[type] = it def __repr__(self): bits = [] - for type, value in sorted(self.items(), key=type_key): + for type, value in sorted(self._store.items(), key=type_key): bits.append('\n %r: %r' % (type, value)) if bits: bits.append('\n') @@ -108,7 +111,7 @@ def call(self, obj, requires): ops.appendleft(type.process) type = type.type - o = self.get(type, missing) + o = self._store.get(type, missing) if isinstance(o, Factory): o = self.call(o.__wrapped__, o.requires) self[type] = o diff --git a/mush/tests/test_context.py b/mush/tests/test_context.py index c109a6b..e80b144 100644 --- a/mush/tests/test_context.py +++ b/mush/tests/test_context.py @@ -23,7 +23,7 @@ def test_simple(self): context = Context() context.add(obj, TheType) - self.assertTrue(context[TheType] is obj) + self.assertTrue(context._store[TheType] is obj) expected = ( ": \n" @@ -40,7 +40,7 @@ def test_type_as_string(self): expected = ("\n" "}>") - self.assertTrue(context['my label'] is obj) + self.assertTrue(context._store['my label'] is obj) self.assertEqual(repr(context), expected) self.assertEqual(str(context), expected) @@ -49,7 +49,7 @@ class T2(object): pass obj = TheType() context = Context() context.add(obj, T2) - self.assertTrue(context[T2] is obj) + self.assertTrue(context._store[T2] is obj) expected = ("\n" "}>") @@ -80,7 +80,7 @@ def test_add_none(self): def test_add_none_with_type(self): context = Context() context.add(None, TheType) - self.assertTrue(context[TheType] is None) + self.assertTrue(context._store[TheType] is None) def test_call_basic(self): def foo(): @@ -96,7 +96,7 @@ def foo(obj): context.add('bar', 'baz') result = context.call(foo, requires('baz')) compare(result, 'bar') - compare({'baz': 'bar'}, context) + compare({'baz': 'bar'}, context._store) def test_call_requires_type(self): def foo(obj): @@ -105,7 +105,7 @@ def foo(obj): context.add('bar', TheType) result = context.call(foo, requires(TheType)) compare(result, 'bar') - compare({TheType: 'bar'}, context) + compare({TheType: 'bar'}, context._store) def test_call_requires_missing(self): def foo(obj): return obj @@ -144,7 +144,7 @@ def foo(x, y): compare(result, ('foo', 'bar')) compare({TheType: 'foo', 'baz': 'bar'}, - actual=context) + actual=context._store) def test_call_requires_optional_present(self): def foo(x=1): @@ -153,7 +153,7 @@ def foo(x=1): context.add(2, TheType) result = context.call(foo, requires(optional(TheType))) compare(result, 2) - compare({TheType: 2}, context) + compare({TheType: 2}, context._store) def test_call_requires_optional_ContextError(self): def foo(x=1): @@ -169,7 +169,7 @@ def foo(x=1): context.add(2, 'foo') result = context.call(foo, requires(optional('foo'))) compare(result, 2) - compare({'foo': 2}, context) + compare({'foo': 2}, context._store) def test_call_requires_item(self): def foo(x): @@ -234,7 +234,7 @@ def foo(): context = Context() result = context.extract(foo, nothing, returns(TheType)) compare(result, 'bar') - compare({TheType: 'bar'}, context) + compare({TheType: 'bar'}, context._store) def test_returns_sequence(self): def foo(): @@ -242,7 +242,7 @@ def foo(): context = Context() result = context.extract(foo, nothing, returns('foo', 'bar')) compare(result, (1, 2)) - compare({'foo': 1, 'bar': 2}, context) + compare({'foo': 1, 'bar': 2}, context._store) def test_returns_mapping(self): def foo(): @@ -250,7 +250,7 @@ def foo(): context = Context() result = context.extract(foo, nothing, returns_mapping()) compare(result, {'foo': 1, 'bar': 2}) - compare({'foo': 1, 'bar': 2}, context) + compare({'foo': 1, 'bar': 2}, context._store) def test_ignore_return(self): def foo(): @@ -258,11 +258,11 @@ def foo(): context = Context() result = context.extract(foo, nothing, nothing) compare(result, 'bar') - compare({}, context) + compare({}, context._store) def test_ignore_non_iterable_return(self): def foo(): pass context = Context() result = context.extract(foo, nothing, nothing) compare(result, expected=None) - compare(context, expected={}) + compare(context._store, expected={}) From 3ef4825a133f8d51bcdcd60df22eb4593f8825a1 Mon Sep 17 00:00:00 2001 From: Chris Withers Date: Fri, 14 Feb 2020 05:35:14 +0000 Subject: [PATCH 003/159] Only store one annotation. --- mush/callpoints.py | 2 +- mush/declarations.py | 21 +++++++++++++-------- mush/plug.py | 7 +++++-- mush/tests/test_declarations.py | 10 +++++----- 4 files changed, 24 insertions(+), 16 deletions(-) diff --git a/mush/callpoints.py b/mush/callpoints.py index f1275aa..481b990 100644 --- a/mush/callpoints.py +++ b/mush/callpoints.py @@ -11,7 +11,7 @@ class CallPoint(object): def __init__(self, obj, requires=None, returns=None, lazy=None): requires, returns = extract_declarations(obj, requires, returns) - lazy = lazy or getattr(obj, '__mush_lazy__', False) + lazy = lazy or getattr(obj, '__mush__', {}).get('lazy') requires = requires or nothing returns = returns or result_type if lazy: diff --git a/mush/declarations.py b/mush/declarations.py index 578e809..c5c13df 100644 --- a/mush/declarations.py +++ b/mush/declarations.py @@ -12,6 +12,12 @@ def name_or_repr(obj): return getattr(obj, '__name__', None) or repr(obj) +def set_mush(obj, key, value): + if not hasattr(obj, '__mush__'): + obj.__mush__ = {} + obj.__mush__[key] = value + + class requires(object): """ Represents requirements for a particular callable. @@ -54,14 +60,14 @@ def __repr__(self): return txt def __call__(self, obj): - obj.__mush_requires__ = self + set_mush(obj, 'requires', self) return obj class ReturnsType(object): def __call__(self, obj): - obj.__mush_returns__ = self + set_mush(obj, 'returns', self) return obj def __repr__(self): @@ -138,7 +144,7 @@ def lazy(obj): Declaration that specifies the callable should only be called the first time it is required. """ - obj.__mush_lazy__ = True + set_mush(obj, 'lazy', True) return obj @@ -271,8 +277,9 @@ def guess_requirements(obj): def extract_declarations(obj, explicit_requires, explicit_returns, guess=True): - mush_requires = getattr(obj, '__mush_requires__', None) - mush_returns = getattr(obj, '__mush_returns__', None) + mush_declarations = getattr(obj, '__mush__', {}) + mush_requires = mush_declarations.get('requires', None) + mush_returns = mush_declarations.get('returns', None) annotations = getattr(obj, '__annotations__', None) annotations = {} if annotations is None else annotations.copy() annotation_returns = annotations.pop('return', None) @@ -303,9 +310,7 @@ def extract_declarations(obj, explicit_requires, explicit_returns, guess=True): return requires_, returns_ -WRAPPER_ASSIGNMENTS = FUNCTOOLS_ASSIGNMENTS + ( - '__mush__requires__', '__mush_returns__' -) +WRAPPER_ASSIGNMENTS = FUNCTOOLS_ASSIGNMENTS + ('__mush__',) def update_wrapper(wrapper, diff --git a/mush/plug.py b/mush/plug.py index 0b59d8f..dfe1950 100644 --- a/mush/plug.py +++ b/mush/plug.py @@ -1,10 +1,13 @@ +from .declarations import set_mush + + class ignore(object): """ A decorator to explicitly mark that a method of a :class:`~mush.Plug` should not be added to a runner by :meth:`~mush.Plug.add_to` """ def __call__(self, method): - method.__mush_plug__ = self + set_mush(method, 'plug', self) return method def apply(self, runner, obj): @@ -64,5 +67,5 @@ def add_to(self, runner): if not name.startswith('_'): obj = getattr(self, name) if callable(obj): - action = getattr(obj, '__mush_plug__', default_action) + action = getattr(obj, '__mush__', {}).get('plug', default_action) action.apply(runner, obj) diff --git a/mush/tests/test_declarations.py b/mush/tests/test_declarations.py index c5d6421..613d770 100644 --- a/mush/tests/test_declarations.py +++ b/mush/tests/test_declarations.py @@ -63,7 +63,7 @@ def test_decorator_paranoid(self): def foo(): return 'bar' - compare(set(foo.__mush_requires__), {(None, Type1)}) + compare(set(foo.__mush__['requires']), {(None, Type1)}) compare(foo(), 'bar') @@ -163,7 +163,7 @@ def test_decorator(self): @returns(Type1) def foo(): return 'foo' - r = foo.__mush_returns__ + r = foo.__mush__['returns'] compare(repr(r), 'returns(Type1)') compare(dict(r.process(foo())), {Type1: 'foo'}) @@ -180,7 +180,7 @@ def test_it(self): @returns_mapping() def foo(): return {Type1: 'foo', 'bar': 'baz'} - r = foo.__mush_returns__ + r = foo.__mush__['returns'] compare(repr(r), 'returns_mapping()') compare(dict(r.process(foo())), {Type1: 'foo', 'bar': 'baz'}) @@ -194,7 +194,7 @@ def test_it(self): @returns_sequence() def foo(): return t1, t2 - r = foo.__mush_returns__ + r = foo.__mush__['returns'] compare(repr(r), 'returns_sequence()') compare(dict(r.process(foo())), {Type1: t1, Type2: t2}) @@ -206,7 +206,7 @@ def test_basic(self): @returns_result_type() def foo(): return 'foo' - r = foo.__mush_returns__ + r = foo.__mush__['returns'] compare(repr(r), 'returns_result_type()') compare(dict(r.process(foo())), {str: 'foo'}) From 121311350343d0321c17e43a8dca9658b1c4e271 Mon Sep 17 00:00:00 2001 From: Chris Withers Date: Fri, 14 Feb 2020 08:53:09 +0000 Subject: [PATCH 004/159] factor out a method to get a single requirement. --- mush/context.py | 44 ++++++++++++++++++++++++-------------------- 1 file changed, 24 insertions(+), 20 deletions(-) diff --git a/mush/context.py b/mush/context.py index 30f04a3..ddfcd3d 100644 --- a/mush/context.py +++ b/mush/context.py @@ -103,31 +103,35 @@ def call(self, obj, requires): args = [] kw = {} - for name, required in requires: - - type = required - ops = deque() - while isinstance(type, how): - ops.appendleft(type.process) - type = type.type - - o = self._store.get(type, missing) - if isinstance(o, Factory): - o = self.call(o.__wrapped__, o.requires) - self[type] = o - - for op in ops: - o = op(o) - if o is nothing: - break - + for name, requirement in requires: + o = self.get(requirement) if o is nothing: pass - elif o is missing: - raise ContextError('No %s in context' % repr(required)) elif name is None: args.append(o) else: kw[name] = o return obj(*args, **kw) + + def get(self, requirement): + spec = requirement + ops = deque() + + while isinstance(spec, how): + ops.appendleft(spec.process) + spec = spec.type + + o = self._store.get(spec, missing) + if isinstance(o, Factory): + o = o(self) + + for op in ops: + o = op(o) + if o is nothing: + break + + if o is missing: + raise ContextError('No %s in context' % repr(requirement)) + + return o From 8c2ba7b19ae57b39912da91b75305fc79988a4cf Mon Sep 17 00:00:00 2001 From: Chris Withers Date: Fri, 14 Feb 2020 11:49:26 +0000 Subject: [PATCH 005/159] make context importable from the top level mush package. --- mush/__init__.py | 2 ++ mush/tests/test_context.py | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/mush/__init__.py b/mush/__init__.py index 60fc10e..919f59c 100755 --- a/mush/__init__.py +++ b/mush/__init__.py @@ -5,8 +5,10 @@ optional, attr, item, nothing ) from .plug import Plug +from .context import Context, ContextError __all__ = [ + 'Context', 'ContextError', 'Runner', 'requires', 'optional', 'returns_result_type', 'returns_mapping', 'returns_sequence', 'returns', diff --git a/mush/tests/test_context.py b/mush/tests/test_context.py index e80b144..1074561 100644 --- a/mush/tests/test_context.py +++ b/mush/tests/test_context.py @@ -3,7 +3,7 @@ from testfixtures import ShouldRaise, compare -from mush.context import Context, ContextError +from mush import Context, ContextError from mush.declarations import ( nothing, requires, optional, item, From 65800e59b4ef35ddb27c6c7315ac7832fc9acfe1 Mon Sep 17 00:00:00 2001 From: Chris Withers Date: Fri, 14 Feb 2020 18:17:19 +0000 Subject: [PATCH 006/159] Split declaration extraction between requires and returns. Also make it a little faster when we have explicit declarations. --- mush/callpoints.py | 7 +++---- mush/declarations.py | 33 +++++++++++++++++++++------------ mush/runner.py | 7 ++++--- mush/tests/test_declarations.py | 5 +++-- 4 files changed, 31 insertions(+), 21 deletions(-) diff --git a/mush/callpoints.py b/mush/callpoints.py index 481b990..6b74d51 100644 --- a/mush/callpoints.py +++ b/mush/callpoints.py @@ -1,4 +1,4 @@ -from .declarations import result_type, nothing, extract_declarations +from .declarations import result_type, nothing, extract_requires, extract_returns from .factory import Factory @@ -6,11 +6,10 @@ class CallPoint(object): next = None previous = None - requires = nothing - returns = result_type def __init__(self, obj, requires=None, returns=None, lazy=None): - requires, returns = extract_declarations(obj, requires, returns) + requires = extract_requires(obj, requires) + returns = extract_returns(obj, returns) lazy = lazy or getattr(obj, '__mush__', {}).get('lazy') requires = requires or nothing returns = returns or result_type diff --git a/mush/declarations.py b/mush/declarations.py index c5c13df..0549926 100644 --- a/mush/declarations.py +++ b/mush/declarations.py @@ -276,17 +276,15 @@ def guess_requirements(obj): return requires(*args, **kw) -def extract_declarations(obj, explicit_requires, explicit_returns, guess=True): - mush_declarations = getattr(obj, '__mush__', {}) - mush_requires = mush_declarations.get('requires', None) - mush_returns = mush_declarations.get('returns', None) - annotations = getattr(obj, '__annotations__', None) - annotations = {} if annotations is None else annotations.copy() - annotation_returns = annotations.pop('return', None) - annotation_requires = annotations or None - - requires_ = explicit_requires or mush_requires or annotation_requires - returns_ = explicit_returns or mush_returns or annotation_returns +def extract_requires(obj, requires_, guess=True): + if requires_ is None: + mush_declarations = getattr(obj, '__mush__', {}) + requires_ = mush_declarations.get('requires', None) + if requires_ is None: + annotations = getattr(obj, '__annotations__', None) + annotations = {} if annotations is None else annotations.copy() + annotations.pop('return', None) + requires_ = annotations or None if isinstance(requires_, requires): pass @@ -300,6 +298,17 @@ def extract_declarations(obj, explicit_requires, explicit_returns, guess=True): else: requires_ = requires(requires_) + return requires_ + + +def extract_returns(obj, returns_): + if returns_ is None: + mush_declarations = getattr(obj, '__mush__', {}) + returns_ = mush_declarations.get('returns', None) + if returns_ is None: + annotations = getattr(obj, '__annotations__', {}) + returns_ = annotations.get('return') + if returns_ is None or isinstance(returns_, ReturnsType): pass elif isinstance(returns_, (list, tuple)): @@ -307,7 +316,7 @@ def extract_declarations(obj, explicit_requires, explicit_returns, guess=True): else: returns_ = returns(returns_) - return requires_, returns_ + return returns_ WRAPPER_ASSIGNMENTS = FUNCTOOLS_ASSIGNMENTS + ('__mush__',) diff --git a/mush/runner.py b/mush/runner.py index 0ca0740..c88953f 100644 --- a/mush/runner.py +++ b/mush/runner.py @@ -1,6 +1,6 @@ from .callpoints import CallPoint from .context import Context, ContextError -from .declarations import extract_declarations +from .declarations import extract_requires, extract_returns from .markers import not_specified from .modifier import Modifier from .plug import Plug @@ -170,8 +170,9 @@ def replace(self, original, replacement, requires=None, returns=None): while point: if point.obj is original: - new_requirements = extract_declarations( - replacement, requires, returns, guess=False + new_requirements = ( + extract_requires(replacement, requires, guess=False), + extract_returns(replacement, returns) ) if any(new_requirements): diff --git a/mush/tests/test_declarations.py b/mush/tests/test_declarations.py index 613d770..bc3d27f 100644 --- a/mush/tests/test_declarations.py +++ b/mush/tests/test_declarations.py @@ -7,12 +7,13 @@ requires, optional, returns, returns_mapping, returns_sequence, returns_result_type, how, item, attr, nothing, - extract_declarations + extract_requires, extract_returns ) def check_extract(obj, expected_rq, expected_rt): - rq, rt = extract_declarations(obj, None, None) + rq = extract_requires(obj, None) + rt = extract_returns(obj, None) compare(rq, expected=expected_rq, strict=True) compare(rt, expected=expected_rt, strict=True) From de63995205347fa644c781d7cc385e97226745b8 Mon Sep 17 00:00:00 2001 From: Chris Withers Date: Fri, 14 Feb 2020 18:23:19 +0000 Subject: [PATCH 007/159] Allow Context.call to extract requirements. --- mush/context.py | 5 +++-- mush/tests/test_context.py | 15 +++++++++++++++ 2 files changed, 18 insertions(+), 2 deletions(-) diff --git a/mush/context.py b/mush/context.py index ddfcd3d..7a8a5b9 100644 --- a/mush/context.py +++ b/mush/context.py @@ -1,6 +1,6 @@ from collections import deque -from .declarations import how, nothing +from .declarations import how, nothing, extract_requires from .factory import Factory from .markers import missing @@ -94,7 +94,8 @@ def extract(self, obj, requires, returns): self.add(obj, type) return result - def call(self, obj, requires): + def call(self, obj, requires=None): + requires = extract_requires(obj, requires) if isinstance(obj, Factory): self.add(obj, obj.returns.args[0]) diff --git a/mush/tests/test_context.py b/mush/tests/test_context.py index 1074561..76d45b1 100644 --- a/mush/tests/test_context.py +++ b/mush/tests/test_context.py @@ -228,6 +228,21 @@ def foo(x=1): result = context.call(foo, requires(item(optional('foo'), 'bar'))) compare(result, 'baz') + def test_call_extract_requirements(self): + def foo(param): + return param + context = Context() + context.add('bar', 'param') + result = context.call(foo) + compare(result, 'bar') + + def test_call_extract_no_requirements(self): + def foo(): + pass + context = Context() + result = context.call(foo) + compare(result, expected=None) + def test_returns_single(self): def foo(): return 'bar' From 2c4e90d62a59c6c3e22431ba6b243cf2e15147bb Mon Sep 17 00:00:00 2001 From: Chris Withers Date: Sun, 16 Feb 2020 09:29:28 +0000 Subject: [PATCH 008/159] Rename stuff in Context.add to closer match current intentions. --- mush/context.py | 13 ++++++------- mush/tests/test_context.py | 6 +++--- 2 files changed, 9 insertions(+), 10 deletions(-) diff --git a/mush/context.py b/mush/context.py index 7a8a5b9..33203f2 100644 --- a/mush/context.py +++ b/mush/context.py @@ -64,21 +64,20 @@ class Context: def __init__(self): self._store = {} - def add(self, it, type): + def add(self, resource, provides): """ Add a resource to the context. - Optionally specify the type to use for the object rather than - the type of the object itself. + Optionally specify what the resource provides. """ - if type is NONE_TYPE: + if provides is NONE_TYPE: raise ValueError('Cannot add None to context') - if type in self._store: + if provides in self._store: raise ContextError('Context already contains %r' % ( - type + provides )) - self._store[type] = it + self._store[provides] = resource def __repr__(self): bits = [] diff --git a/mush/tests/test_context.py b/mush/tests/test_context.py index 76d45b1..527dc0e 100644 --- a/mush/tests/test_context.py +++ b/mush/tests/test_context.py @@ -35,7 +35,7 @@ def test_simple(self): def test_type_as_string(self): obj = TheType() context = Context() - context.add(obj, type='my label') + context.add(obj, provides='my label') expected = ("\n" @@ -68,9 +68,9 @@ def test_clash_string_type(self): obj1 = TheType() obj2 = TheType() context = Context() - context.add(obj1, type='my label') + context.add(obj1, provides='my label') with ShouldRaise(ContextError("Context already contains 'my label'")): - context.add(obj2, type='my label') + context.add(obj2, provides='my label') def test_add_none(self): context = Context() From 7ae6fbe7286d3ccdde9eedc1fe91d353318f1b50 Mon Sep 17 00:00:00 2001 From: Chris Withers Date: Sun, 16 Feb 2020 09:29:51 +0000 Subject: [PATCH 009/159] move the default requires/returns down into the context from the callpoint --- mush/callpoints.py | 2 -- mush/declarations.py | 10 ++++---- mush/runner.py | 4 ++-- mush/tests/conftest.py | 36 ----------------------------- mush/tests/test_declarations.py | 30 ++++++++++++------------ mush/tests/test_declarations_py3.py | 16 ++++++------- 6 files changed, 30 insertions(+), 68 deletions(-) delete mode 100644 mush/tests/conftest.py diff --git a/mush/callpoints.py b/mush/callpoints.py index 6b74d51..21677e3 100644 --- a/mush/callpoints.py +++ b/mush/callpoints.py @@ -11,8 +11,6 @@ def __init__(self, obj, requires=None, returns=None, lazy=None): requires = extract_requires(obj, requires) returns = extract_returns(obj, returns) lazy = lazy or getattr(obj, '__mush__', {}).get('lazy') - requires = requires or nothing - returns = returns or result_type if lazy: obj = Factory(obj, requires, returns) requires = returns = nothing diff --git a/mush/declarations.py b/mush/declarations.py index 0549926..26312b5 100644 --- a/mush/declarations.py +++ b/mush/declarations.py @@ -276,7 +276,7 @@ def guess_requirements(obj): return requires(*args, **kw) -def extract_requires(obj, requires_, guess=True): +def extract_requires(obj, requires_, default=nothing): if requires_ is None: mush_declarations = getattr(obj, '__mush__', {}) requires_ = mush_declarations.get('requires', None) @@ -289,7 +289,7 @@ def extract_requires(obj, requires_, guess=True): if isinstance(requires_, requires): pass elif requires_ is None: - if guess: + if default is not None: requires_ = guess_requirements(obj) elif isinstance(requires_, (list, tuple)): requires_ = requires(*requires_) @@ -298,10 +298,10 @@ def extract_requires(obj, requires_, guess=True): else: requires_ = requires(requires_) - return requires_ + return requires_ or default -def extract_returns(obj, returns_): +def extract_returns(obj, returns_, default=result_type): if returns_ is None: mush_declarations = getattr(obj, '__mush__', {}) returns_ = mush_declarations.get('returns', None) @@ -316,7 +316,7 @@ def extract_returns(obj, returns_): else: returns_ = returns(returns_) - return returns_ + return returns_ or default WRAPPER_ASSIGNMENTS = FUNCTOOLS_ASSIGNMENTS + ('__mush__',) diff --git a/mush/runner.py b/mush/runner.py index c88953f..469d0bb 100644 --- a/mush/runner.py +++ b/mush/runner.py @@ -171,8 +171,8 @@ def replace(self, original, replacement, requires=None, returns=None): if point.obj is original: new_requirements = ( - extract_requires(replacement, requires, guess=False), - extract_returns(replacement, returns) + extract_requires(replacement, requires, default=None), + extract_returns(replacement, returns, default=None) ) if any(new_requirements): diff --git a/mush/tests/conftest.py b/mush/tests/conftest.py deleted file mode 100644 index df9b0e0..0000000 --- a/mush/tests/conftest.py +++ /dev/null @@ -1,36 +0,0 @@ -import pytest -from mock import Mock -from testfixtures.comparison import register, compare_simple - -from mush import returns, requires -from mush.declarations import how - - -@pytest.fixture() -def mock(): - return Mock() - - -def compare_requires(x, y, context): - diff_args = context.different(x.args, y.args, '.args') - diff_kw = context.different(x.kw, y.kw, '.args') - if diff_args or diff_kw: # pragma: no cover - return compare_simple(x, y, context) - - -def compare_returns(x, y, context): - diff_args = context.different(x.args, y.args, '.args') - if diff_args: # pragma: no cover - return compare_simple(x, y, context) - - -def compare_how(x, y, context): - diff_args = context.different(x.type, y.type, '.type') - diff_names = context.different(x.type, y.type, '.names') - if diff_args or diff_names: # pragma: no cover - return compare_simple(x, y, context) - - -register(requires, compare_requires) -register(returns, compare_returns) -register(how, compare_how) diff --git a/mush/tests/test_declarations.py b/mush/tests/test_declarations.py index bc3d27f..0f98024 100644 --- a/mush/tests/test_declarations.py +++ b/mush/tests/test_declarations.py @@ -7,8 +7,8 @@ requires, optional, returns, returns_mapping, returns_sequence, returns_result_type, how, item, attr, nothing, - extract_requires, extract_returns -) + extract_requires, extract_returns, + result_type) def check_extract(obj, expected_rq, expected_rt): @@ -230,14 +230,14 @@ def test_default_requirements_for_function(self): def foo(a, b=None): pass check_extract(foo, expected_rq=requires('a', optional('b')), - expected_rt=None) + expected_rt=result_type) def test_default_requirements_for_class(self): class MyClass(object): def __init__(self, a, b=None): pass check_extract(MyClass, expected_rq=requires('a', optional('b')), - expected_rt=None) + expected_rt=result_type) def test_extract_from_partial(self): def foo(x, y, z, a=None): pass @@ -245,7 +245,7 @@ def foo(x, y, z, a=None): pass check_extract( p, expected_rq=requires(z='z', a=optional('a'), y=optional('y')), - expected_rt=None + expected_rt=result_type ) def test_extract_from_partial_default_not_in_partial(self): @@ -254,7 +254,7 @@ def foo(a=None): pass check_extract( p, expected_rq=requires(optional('a')), - expected_rt=None + expected_rt=result_type ) def test_extract_from_partial_default_in_partial_arg(self): @@ -263,8 +263,8 @@ def foo(a=None): pass check_extract( p, # since a is already bound by the partial: - expected_rq=None, - expected_rt=None + expected_rq=nothing, + expected_rt=result_type ) def test_extract_from_partial_default_in_partial_kw(self): @@ -273,7 +273,7 @@ def foo(a=None): pass check_extract( p, expected_rq=requires(a=optional('a')), - expected_rt=None + expected_rt=result_type ) def test_extract_from_partial_required_in_partial_arg(self): @@ -282,8 +282,8 @@ def foo(a): pass check_extract( p, # since a is already bound by the partial: - expected_rq=None, - expected_rt=None + expected_rq=nothing, + expected_rt=result_type ) def test_extract_from_partial_required_in_partial_kw(self): @@ -292,7 +292,7 @@ def foo(a): pass check_extract( p, expected_rq=requires(a=optional('a')), - expected_rt=None + expected_rt=result_type ) def test_extract_from_partial_plus_one_default_not_in_partial(self): @@ -301,7 +301,7 @@ def foo(b, a=None): pass check_extract( p, expected_rq=requires('b', optional('a')), - expected_rt=None + expected_rt=result_type ) def test_extract_from_partial_plus_one_required_in_partial_arg(self): @@ -311,7 +311,7 @@ def foo(b, a): pass p, # since b is already bound: expected_rq=requires('a'), - expected_rt=None + expected_rt=result_type ) def test_extract_from_partial_plus_one_required_in_partial_kw(self): @@ -320,5 +320,5 @@ def foo(b, a): pass check_extract( p, expected_rq=requires('b', a=optional('a')), - expected_rt=None + expected_rt=result_type ) diff --git a/mush/tests/test_declarations_py3.py b/mush/tests/test_declarations_py3.py index 12d3cec..82a3265 100644 --- a/mush/tests/test_declarations_py3.py +++ b/mush/tests/test_declarations_py3.py @@ -2,8 +2,8 @@ from mush.declarations import ( requires, returns, returns_mapping, returns_sequence, item, update_wrapper, - optional -) + optional, + nothing, result_type) from mush.tests.test_declarations import check_extract @@ -19,12 +19,12 @@ def test_requires_only(self): def foo(a: 'foo'): pass check_extract(foo, expected_rq=requires(a='foo'), - expected_rt=None) + expected_rt=result_type) def test_returns_only(self): def foo() -> 'bar': pass check_extract(foo, - expected_rq=None, + expected_rq=nothing, expected_rt=returns('bar')) def test_extract_from_decorated_class(self, mock): @@ -59,14 +59,14 @@ def test_returns_mapping(self): rt = returns_mapping() def foo() -> rt: pass check_extract(foo, - expected_rq=None, + expected_rq=nothing, expected_rt=rt) def test_returns_sequence(self): rt = returns_sequence() def foo() -> rt: pass check_extract(foo, - expected_rq=None, + expected_rq=nothing, expected_rt=rt) def test_how_instance_in_annotations(self): @@ -74,7 +74,7 @@ def test_how_instance_in_annotations(self): def foo(a: how): pass check_extract(foo, expected_rq=requires(a=how), - expected_rt=None) + expected_rt=result_type) def test_default_requirements(self): def foo(a, b=1, *, c, d=None): pass @@ -83,4 +83,4 @@ def foo(a, b=1, *, c, d=None): pass optional('b'), c='c', d=optional('d')), - expected_rt=None) + expected_rt=result_type) From 670767d17d0e8928448dcb0e5cbf897809ae0245 Mon Sep 17 00:00:00 2001 From: Chris Withers Date: Sat, 15 Feb 2020 16:30:45 +0000 Subject: [PATCH 010/159] unused import --- mush/tests/test_callpoints.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mush/tests/test_callpoints.py b/mush/tests/test_callpoints.py index 955bb9e..8dd0995 100644 --- a/mush/tests/test_callpoints.py +++ b/mush/tests/test_callpoints.py @@ -1,7 +1,7 @@ from functools import update_wrapper from unittest import TestCase -from mock import Mock, call +from mock import Mock from testfixtures import compare from mush.callpoints import CallPoint From 29b42c6d5cd6ce7d3c2a18080aee4b7108c71b16 Mon Sep 17 00:00:00 2001 From: Chris Withers Date: Sun, 16 Feb 2020 07:50:54 +0000 Subject: [PATCH 011/159] No longer needed with latest testfixtures --- mush/tests/test_declarations.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mush/tests/test_declarations.py b/mush/tests/test_declarations.py index 0f98024..f762c8c 100644 --- a/mush/tests/test_declarations.py +++ b/mush/tests/test_declarations.py @@ -14,7 +14,7 @@ def check_extract(obj, expected_rq, expected_rt): rq = extract_requires(obj, None) rt = extract_returns(obj, None) - compare(rq, expected=expected_rq, strict=True) + compare(rq, expected=expected_rq, strict=True, ignore_attributes={Requirement: ['ops']}) compare(rt, expected=expected_rt, strict=True) From 490e409a04df0f7783d7219dc1615f4731985823 Mon Sep 17 00:00:00 2001 From: Chris Withers Date: Sun, 16 Feb 2020 07:51:59 +0000 Subject: [PATCH 012/159] unused --- mush/tests/test_declarations_py3.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mush/tests/test_declarations_py3.py b/mush/tests/test_declarations_py3.py index 82a3265..e1438a6 100644 --- a/mush/tests/test_declarations_py3.py +++ b/mush/tests/test_declarations_py3.py @@ -27,7 +27,7 @@ def foo() -> 'bar': pass expected_rq=nothing, expected_rt=returns('bar')) - def test_extract_from_decorated_class(self, mock): + def test_extract_from_decorated_class(self): class Wrapper(object): def __init__(self, func): From 9537b95e9f549c562a8fa922bc771689ceedf742 Mon Sep 17 00:00:00 2001 From: Chris Withers Date: Sun, 16 Feb 2020 07:57:48 +0000 Subject: [PATCH 013/159] Factor out a requirement class as a step to more flexible requirement extraction and processing. --- mush/context.py | 20 +++++-------- mush/declarations.py | 50 ++++++++++++++++++--------------- mush/tests/test_declarations.py | 34 +++++++++++----------- 3 files changed, 51 insertions(+), 53 deletions(-) diff --git a/mush/context.py b/mush/context.py index 33203f2..12e66d1 100644 --- a/mush/context.py +++ b/mush/context.py @@ -103,35 +103,29 @@ def call(self, obj, requires=None): args = [] kw = {} - for name, requirement in requires: + for requirement in requires.resolvers: o = self.get(requirement) if o is nothing: pass - elif name is None: + elif requirement.target is None: args.append(o) else: - kw[name] = o + kw[requirement.target] = o return obj(*args, **kw) def get(self, requirement): - spec = requirement - ops = deque() - - while isinstance(spec, how): - ops.appendleft(spec.process) - spec = spec.type - - o = self._store.get(spec, missing) + # extract requirement? + o = self._store.get(requirement.base, missing) if isinstance(o, Factory): o = o(self) - for op in ops: + for op in requirement.ops: o = op(o) if o is nothing: break if o is missing: - raise ContextError('No %s in context' % repr(requirement)) + raise ContextError('No %s in context' % repr(requirement.spec)) return o diff --git a/mush/declarations.py b/mush/declarations.py index 26312b5..308bdeb 100644 --- a/mush/declarations.py +++ b/mush/declarations.py @@ -1,9 +1,11 @@ +from collections import deque from functools import ( WRAPPER_UPDATES, WRAPPER_ASSIGNMENTS as FUNCTOOLS_ASSIGNMENTS, update_wrapper as functools_update_wrapper, ) from inspect import signature +from typing import List from .markers import missing @@ -18,6 +20,25 @@ def set_mush(obj, key, value): obj.__mush__[key] = value +class Requirement: + + def __init__(self, source, target=None): + self.target = target + self.spec = source + self.ops = deque() + while isinstance(source, how): + self.ops.appendleft(source.process) + source = source.type + self.base = source + + def __repr__(self): + requirement_repr = name_or_repr(self.spec) + if self.target is None: + return requirement_repr + else: + return f'{self.target}={requirement_repr}' + + class requires(object): """ Represents requirements for a particular callable. @@ -33,31 +54,14 @@ class requires(object): def __init__(self, *args, **kw): check_type(*args) check_type(*kw.values()) - self.args = args - self.kw = kw - - def __iter__(self): - """ - When iterated over, yields tuples representing individual - types required by arguments or keyword parameters in the form - ``(keyword_name, decorated_type)``. - - If the keyword name is ``None``, then the type is for - a positional argument. - """ - for arg in self.args: - yield None, arg - for k, v in self.kw.items(): - yield k, v + self.resolvers = [] + for arg in args: + self.resolvers.append(Requirement(arg)) + for k, v in kw.items(): + self.resolvers.append(Requirement(v, target=k)) def __repr__(self): - bits = [] - for arg in self.args: - bits.append(name_or_repr(arg)) - for k, v in sorted(self.kw.items()): - bits.append('%s=%s' % (k, name_or_repr(v))) - txt = 'requires(%s)' % ', '.join(bits) - return txt + return f"requires({', '.join(repr(r) for r in self.resolvers)})" def __call__(self, obj): set_mush(obj, 'requires', self) diff --git a/mush/tests/test_declarations.py b/mush/tests/test_declarations.py index f762c8c..34590a9 100644 --- a/mush/tests/test_declarations.py +++ b/mush/tests/test_declarations.py @@ -1,14 +1,14 @@ from functools import partial from unittest import TestCase from mock import Mock -from testfixtures import compare, generator, ShouldRaise +from testfixtures import compare, ShouldRaise from mush.markers import missing from mush.declarations import ( requires, optional, returns, returns_mapping, returns_sequence, returns_result_type, how, item, attr, nothing, extract_requires, extract_returns, - result_type) + result_type, Requirement) def check_extract(obj, expected_rq, expected_rt): @@ -29,27 +29,27 @@ class TestRequires(TestCase): def test_empty(self): r = requires() compare(repr(r), 'requires()') - compare(generator(), r) + compare(r.resolvers, []) def test_types(self): r = requires(Type1, Type2, x=Type3, y=Type4) compare(repr(r), 'requires(Type1, Type2, x=Type3, y=Type4)') - compare({ - (None, Type1), - (None, Type2), - ('x', Type3), - ('y', Type4), - }, set(r)) + compare(r.resolvers, expected=[ + Requirement(Type1), + Requirement(Type2), + Requirement(Type3, target='x'), + Requirement(Type4, target='y'), + ]) def test_strings(self): r = requires('1', '2', x='3', y='4') compare(repr(r), "requires('1', '2', x='3', y='4')") - compare({ - (None, '1'), - (None, '2'), - ('x', '3'), - ('y', '4'), - }, set(r)) + compare(r.resolvers, expected=[ + Requirement('1'), + Requirement('2'), + Requirement('3', target='x'), + Requirement('4', target='y'), + ]) def test_tuple_arg(self): with ShouldRaise(TypeError("('1', '2') is not a type or label")): @@ -64,7 +64,7 @@ def test_decorator_paranoid(self): def foo(): return 'bar' - compare(set(foo.__mush__['requires']), {(None, Type1)}) + compare(foo.__mush__['requires'].resolvers, expected=[Requirement(Type1)]) compare(foo(), 'bar') @@ -244,7 +244,7 @@ def foo(x, y, z, a=None): pass p = partial(foo, 1, y=2) check_extract( p, - expected_rq=requires(z='z', a=optional('a'), y=optional('y')), + expected_rq=requires(y=optional('y'), z='z', a=optional('a'), ), expected_rt=result_type ) From c982e238e302c9428d89b2effa4ffe8a226bd55f Mon Sep 17 00:00:00 2001 From: Chris Withers Date: Sun, 16 Feb 2020 09:25:35 +0000 Subject: [PATCH 014/159] simplify requires class to a list subclass --- mush/context.py | 2 +- mush/declarations.py | 10 ++++++---- mush/tests/test_declarations.py | 6 +++--- 3 files changed, 10 insertions(+), 8 deletions(-) diff --git a/mush/context.py b/mush/context.py index 12e66d1..06e2a74 100644 --- a/mush/context.py +++ b/mush/context.py @@ -103,7 +103,7 @@ def call(self, obj, requires=None): args = [] kw = {} - for requirement in requires.resolvers: + for requirement in requires: o = self.get(requirement) if o is nothing: pass diff --git a/mush/declarations.py b/mush/declarations.py index 308bdeb..8b9ae0a 100644 --- a/mush/declarations.py +++ b/mush/declarations.py @@ -39,7 +39,7 @@ def __repr__(self): return f'{self.target}={requirement_repr}' -class requires(object): +class requires(list): """ Represents requirements for a particular callable. @@ -52,16 +52,17 @@ class requires(object): """ def __init__(self, *args, **kw): + super(requires, self).__init__() check_type(*args) check_type(*kw.values()) self.resolvers = [] for arg in args: - self.resolvers.append(Requirement(arg)) + self.append(Requirement(arg)) for k, v in kw.items(): - self.resolvers.append(Requirement(v, target=k)) + self.append(Requirement(v, target=k)) def __repr__(self): - return f"requires({', '.join(repr(r) for r in self.resolvers)})" + return f"requires({', '.join(repr(r) for r in self)})" def __call__(self, obj): set_mush(obj, 'requires', self) @@ -251,6 +252,7 @@ class Nothing(requires, returns): def process(self, result): return () + #: A singleton that be used as a :class:`~mush.requires` to indicate that a #: callable has no required arguments or as a :class:`~mush.returns` to indicate #: that anything returned from a callable should be ignored. diff --git a/mush/tests/test_declarations.py b/mush/tests/test_declarations.py index 34590a9..261b0c2 100644 --- a/mush/tests/test_declarations.py +++ b/mush/tests/test_declarations.py @@ -34,7 +34,7 @@ def test_empty(self): def test_types(self): r = requires(Type1, Type2, x=Type3, y=Type4) compare(repr(r), 'requires(Type1, Type2, x=Type3, y=Type4)') - compare(r.resolvers, expected=[ + compare(r, expected=[ Requirement(Type1), Requirement(Type2), Requirement(Type3, target='x'), @@ -44,7 +44,7 @@ def test_types(self): def test_strings(self): r = requires('1', '2', x='3', y='4') compare(repr(r), "requires('1', '2', x='3', y='4')") - compare(r.resolvers, expected=[ + compare(r, expected=[ Requirement('1'), Requirement('2'), Requirement('3', target='x'), @@ -64,7 +64,7 @@ def test_decorator_paranoid(self): def foo(): return 'bar' - compare(foo.__mush__['requires'].resolvers, expected=[Requirement(Type1)]) + compare(foo.__mush__['requires'], expected=[Requirement(Type1)]) compare(foo(), 'bar') From 503074e225054cdda7c7d0483868b83ae7ed0b80 Mon Sep 17 00:00:00 2001 From: Chris Withers Date: Sun, 16 Feb 2020 09:37:50 +0000 Subject: [PATCH 015/159] move annotation declarations tests back into test_declarations --- mush/tests/test_declarations.py | 83 +++++++++++++++++++++++++++- mush/tests/test_declarations_py3.py | 86 ----------------------------- 2 files changed, 82 insertions(+), 87 deletions(-) delete mode 100644 mush/tests/test_declarations_py3.py diff --git a/mush/tests/test_declarations.py b/mush/tests/test_declarations.py index 261b0c2..437183e 100644 --- a/mush/tests/test_declarations.py +++ b/mush/tests/test_declarations.py @@ -8,7 +8,9 @@ returns_mapping, returns_sequence, returns_result_type, how, item, attr, nothing, extract_requires, extract_returns, - result_type, Requirement) + result_type, Requirement, + update_wrapper +) def check_extract(obj, expected_rq, expected_rt): @@ -322,3 +324,82 @@ def foo(b, a): pass expected_rq=requires('b', a=optional('a')), expected_rt=result_type ) + + +class TestExtractDeclarationsFromTypeAnnotations(object): + + def test_extract_from_annotations(self): + def foo(a: 'foo', b, c: 'bar' = 1, d=2) -> 'bar': pass + check_extract(foo, + expected_rq=requires(a='foo', c='bar'), + expected_rt=returns('bar')) + + def test_requires_only(self): + def foo(a: 'foo'): pass + check_extract(foo, + expected_rq=requires(a='foo'), + expected_rt=result_type) + + def test_returns_only(self): + def foo() -> 'bar': pass + check_extract(foo, + expected_rq=nothing, + expected_rt=returns('bar')) + + def test_extract_from_decorated_class(self): + + class Wrapper(object): + def __init__(self, func): + self.func = func + def __call__(self): + return 'the '+self.func() + + def my_dec(func): + return update_wrapper(Wrapper(func), func) + + @my_dec + def foo(a: 'foo'=None) -> 'bar': + return 'answer' + + compare(foo(), expected='the answer') + check_extract(foo, + expected_rq=requires(a='foo'), + expected_rt=returns('bar')) + + def test_decorator_trumps_annotations(self): + @requires('foo') + @returns('bar') + def foo(a: 'x') -> 'y': pass + check_extract(foo, + expected_rq=requires('foo'), + expected_rt=returns('bar')) + + def test_returns_mapping(self): + rt = returns_mapping() + def foo() -> rt: pass + check_extract(foo, + expected_rq=nothing, + expected_rt=rt) + + def test_returns_sequence(self): + rt = returns_sequence() + def foo() -> rt: pass + check_extract(foo, + expected_rq=nothing, + expected_rt=rt) + + def test_how_instance_in_annotations(self): + how = item('config', 'db_url') + def foo(a: how): pass + check_extract(foo, + expected_rq=requires(a=how), + expected_rt=result_type) + + def test_default_requirements(self): + def foo(a, b=1, *, c, d=None): pass + check_extract(foo, + expected_rq=requires('a', + optional('b'), + c='c', + d=optional('d')), + expected_rt=result_type) diff --git a/mush/tests/test_declarations_py3.py b/mush/tests/test_declarations_py3.py deleted file mode 100644 index e1438a6..0000000 --- a/mush/tests/test_declarations_py3.py +++ /dev/null @@ -1,86 +0,0 @@ -from testfixtures import compare - -from mush.declarations import ( - requires, returns, returns_mapping, returns_sequence, item, update_wrapper, - optional, - nothing, result_type) -from mush.tests.test_declarations import check_extract - - -class TestExtractDeclarations(object): - - def test_extract_from_annotations(self): - def foo(a: 'foo', b, c: 'bar' = 1, d=2) -> 'bar': pass - check_extract(foo, - expected_rq=requires(a='foo', c='bar'), - expected_rt=returns('bar')) - - def test_requires_only(self): - def foo(a: 'foo'): pass - check_extract(foo, - expected_rq=requires(a='foo'), - expected_rt=result_type) - - def test_returns_only(self): - def foo() -> 'bar': pass - check_extract(foo, - expected_rq=nothing, - expected_rt=returns('bar')) - - def test_extract_from_decorated_class(self): - - class Wrapper(object): - def __init__(self, func): - self.func = func - def __call__(self): - return 'the '+self.func() - - def my_dec(func): - return update_wrapper(Wrapper(func), func) - - @my_dec - def foo(a: 'foo'=None) -> 'bar': - return 'answer' - - compare(foo(), expected='the answer') - check_extract(foo, - expected_rq=requires(a='foo'), - expected_rt=returns('bar')) - - def test_decorator_trumps_annotations(self): - @requires('foo') - @returns('bar') - def foo(a: 'x') -> 'y': pass - check_extract(foo, - expected_rq=requires('foo'), - expected_rt=returns('bar')) - - def test_returns_mapping(self): - rt = returns_mapping() - def foo() -> rt: pass - check_extract(foo, - expected_rq=nothing, - expected_rt=rt) - - def test_returns_sequence(self): - rt = returns_sequence() - def foo() -> rt: pass - check_extract(foo, - expected_rq=nothing, - expected_rt=rt) - - def test_how_instance_in_annotations(self): - how = item('config', 'db_url') - def foo(a: how): pass - check_extract(foo, - expected_rq=requires(a=how), - expected_rt=result_type) - - def test_default_requirements(self): - def foo(a, b=1, *, c, d=None): pass - check_extract(foo, - expected_rq=requires('a', - optional('b'), - c='c', - d=optional('d')), - expected_rt=result_type) From f7bc6f9f347ee9725f3646c7f8d1ad131be5032d Mon Sep 17 00:00:00 2001 From: Chris Withers Date: Sun, 16 Feb 2020 10:57:34 +0000 Subject: [PATCH 016/159] Make provides an optional argument to Context.add --- mush/context.py | 11 +++++++---- mush/tests/test_context.py | 4 ++-- 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/mush/context.py b/mush/context.py index 06e2a74..d35b5ec 100644 --- a/mush/context.py +++ b/mush/context.py @@ -1,6 +1,6 @@ -from collections import deque +from typing import Optional, Any, Union, Type -from .declarations import how, nothing, extract_requires +from .declarations import nothing, extract_requires from .factory import Factory from .markers import missing @@ -64,13 +64,16 @@ class Context: def __init__(self): self._store = {} - def add(self, resource, provides): + def add(self, + resource: Any, + provides: Optional[Union[Type, str]] = None): """ Add a resource to the context. Optionally specify what the resource provides. """ - + if provides is None: + provides = type(resource) if provides is NONE_TYPE: raise ValueError('Cannot add None to context') if provides in self._store: diff --git a/mush/tests/test_context.py b/mush/tests/test_context.py index 527dc0e..8c68487 100644 --- a/mush/tests/test_context.py +++ b/mush/tests/test_context.py @@ -21,7 +21,7 @@ class TestContext(TestCase): def test_simple(self): obj = TheType() context = Context() - context.add(obj, TheType) + context.add(obj) self.assertTrue(context._store[TheType] is obj) expected = ( @@ -48,7 +48,7 @@ def test_explicit_type(self): class T2(object): pass obj = TheType() context = Context() - context.add(obj, T2) + context.add(obj, provides=T2) self.assertTrue(context._store[T2] is obj) expected = ("\n" From 6e4a333b8efe85338e27d97098d35c44f55c4af6 Mon Sep 17 00:00:00 2001 From: Chris Withers Date: Sun, 16 Feb 2020 14:29:13 +0000 Subject: [PATCH 017/159] Add support for context-side resolvers --- mush/callpoints.py | 2 +- mush/context.py | 38 ++++++++++++++------ mush/{factory.py => resolvers.py} | 14 ++++++++ mush/tests/test_context.py | 58 +++++++++++++++++++++++-------- mush/tests/test_factory.py | 12 ------- mush/tests/test_resolver.py | 21 +++++++++++ 6 files changed, 107 insertions(+), 38 deletions(-) rename mush/{factory.py => resolvers.py} (66%) delete mode 100644 mush/tests/test_factory.py create mode 100644 mush/tests/test_resolver.py diff --git a/mush/callpoints.py b/mush/callpoints.py index 21677e3..cbc2381 100644 --- a/mush/callpoints.py +++ b/mush/callpoints.py @@ -1,5 +1,5 @@ from .declarations import result_type, nothing, extract_requires, extract_returns -from .factory import Factory +from .resolvers import Factory class CallPoint(object): diff --git a/mush/context.py b/mush/context.py index d35b5ec..aac0fa7 100644 --- a/mush/context.py +++ b/mush/context.py @@ -1,8 +1,9 @@ -from typing import Optional, Any, Union, Type +from typing import Optional, Any, Union, Type, Callable, NewType -from .declarations import nothing, extract_requires -from .factory import Factory +from mush.resolvers import ValueResolver +from .declarations import nothing, extract_requires, Requirement from .markers import missing +from .resolvers import Factory NONE_TYPE = type(None) @@ -58,6 +59,11 @@ def type_key(type_tuple): return type.__name__ +ResourceKey = NewType('ResourceKey', Union[Type, str]) +ResourceValue = NewType('ResourceValue', Any) +Resolver = Callable[['Context'], ResourceValue] + + class Context: "Stores resources for a particular run." @@ -65,13 +71,18 @@ def __init__(self): self._store = {} def add(self, - resource: Any, - provides: Optional[Union[Type, str]] = None): + resource: Optional[ResourceValue] = None, + provides: Optional[ResourceKey] = None, + resolver: Optional[Resolver] = None): """ Add a resource to the context. Optionally specify what the resource provides. """ + if resolver is not None and (provides is None or resource is not None): + if resource is not None: + raise TypeError('resource cannot be supplied when using a resolver') + raise TypeError('Both provides and resolver must be supplied') if provides is None: provides = type(resource) if provides is NONE_TYPE: @@ -80,7 +91,9 @@ def add(self, raise ContextError('Context already contains %r' % ( provides )) - self._store[provides] = resource + if resolver is None: + resolver = ValueResolver(resource) + self._store[provides] = resolver def __repr__(self): bits = [] @@ -117,11 +130,14 @@ def call(self, obj, requires=None): return obj(*args, **kw) - def get(self, requirement): - # extract requirement? - o = self._store.get(requirement.base, missing) - if isinstance(o, Factory): - o = o(self) + def get(self, requirement: Requirement): + resolver = self._store.get(requirement.base, missing) + if resolver is missing: + o = missing + else: + o = resolver(self) + if isinstance(o, Factory): + o = o(self) for op in requirement.ops: o = op(o) diff --git a/mush/factory.py b/mush/resolvers.py similarity index 66% rename from mush/factory.py rename to mush/resolvers.py index 5b9d03c..d6fe5b2 100644 --- a/mush/factory.py +++ b/mush/resolvers.py @@ -1,6 +1,20 @@ from .declarations import returns as returns_declaration +class ValueResolver: + + __slots__ = ['value'] + + def __init__(self, value): + self.value = value + + def __call__(self, context): + return self.value + + def __repr__(self): + return repr(self.value) + + class Factory(object): value = None diff --git a/mush/tests/test_context.py b/mush/tests/test_context.py index 8c68487..d896e37 100644 --- a/mush/tests/test_context.py +++ b/mush/tests/test_context.py @@ -9,6 +9,7 @@ nothing, requires, optional, item, attr, returns, returns_mapping ) +from mush.resolvers import ValueResolver class TheType(object): @@ -23,7 +24,7 @@ def test_simple(self): context = Context() context.add(obj) - self.assertTrue(context._store[TheType] is obj) + compare(context._store, expected={TheType: ValueResolver(obj)}) expected = ( ": \n" @@ -40,7 +41,7 @@ def test_type_as_string(self): expected = ("\n" "}>") - self.assertTrue(context._store['my label'] is obj) + compare(context._store, expected={'my label': ValueResolver(obj)}) self.assertEqual(repr(context), expected) self.assertEqual(str(context), expected) @@ -49,13 +50,40 @@ class T2(object): pass obj = TheType() context = Context() context.add(obj, provides=T2) - self.assertTrue(context._store[T2] is obj) + compare(context._store, expected={T2: ValueResolver(obj)}) expected = ("\n" "}>") compare(repr(context), expected) compare(str(context), expected) + + def test_no_resolver_or_provides(self): + context = Context() + with ShouldRaise(ValueError('Cannot add None to context')): + context.add() + compare(context._store, expected={}) + def test_resolver_but_no_provides(self): + context = Context() + with ShouldRaise(TypeError('Both provides and resolver must be supplied')): + context.add(resolver=lambda: None) + compare(context._store, expected={}) + + def test_resolver(self): + m = Mock() + context = Context() + context.add(provides='foo', resolver=m) + m.assert_not_called() + assert context.get(requires('foo')[0]) is m.return_value + m.assert_called_with(context) + + def test_resolver_and_resource(self): + m = Mock() + context = Context() + with ShouldRaise(TypeError('resource cannot be supplied when using a resolver')): + context.add('bar', provides='foo', resolver=m) + compare(context._store, expected={}) + def test_clash(self): obj1 = TheType() obj2 = TheType() @@ -75,12 +103,12 @@ def test_clash_string_type(self): def test_add_none(self): context = Context() with ShouldRaise(ValueError('Cannot add None to context')): - context.add(None, None.__class__) + context.add(None, type(None)) def test_add_none_with_type(self): context = Context() context.add(None, TheType) - self.assertTrue(context._store[TheType] is None) + compare(context._store, expected={TheType: ValueResolver(None)}) def test_call_basic(self): def foo(): @@ -96,7 +124,7 @@ def foo(obj): context.add('bar', 'baz') result = context.call(foo, requires('baz')) compare(result, 'bar') - compare({'baz': 'bar'}, context._store) + compare({'baz': ValueResolver('bar')}, actual=context._store) def test_call_requires_type(self): def foo(obj): @@ -105,7 +133,7 @@ def foo(obj): context.add('bar', TheType) result = context.call(foo, requires(TheType)) compare(result, 'bar') - compare({TheType: 'bar'}, context._store) + compare({TheType: ValueResolver('bar')}, actual=context._store) def test_call_requires_missing(self): def foo(obj): return obj @@ -142,8 +170,8 @@ def foo(x, y): context.add('bar', 'baz') result = context.call(foo, requires(y='baz', x=TheType)) compare(result, ('foo', 'bar')) - compare({TheType: 'foo', - 'baz': 'bar'}, + compare({TheType: ValueResolver('foo'), + 'baz': ValueResolver('bar')}, actual=context._store) def test_call_requires_optional_present(self): @@ -153,7 +181,7 @@ def foo(x=1): context.add(2, TheType) result = context.call(foo, requires(optional(TheType))) compare(result, 2) - compare({TheType: 2}, context._store) + compare({TheType: ValueResolver(2)}, actual=context._store) def test_call_requires_optional_ContextError(self): def foo(x=1): @@ -169,7 +197,7 @@ def foo(x=1): context.add(2, 'foo') result = context.call(foo, requires(optional('foo'))) compare(result, 2) - compare({'foo': 2}, context._store) + compare({'foo': ValueResolver(2)}, actual=context._store) def test_call_requires_item(self): def foo(x): @@ -249,7 +277,7 @@ def foo(): context = Context() result = context.extract(foo, nothing, returns(TheType)) compare(result, 'bar') - compare({TheType: 'bar'}, context._store) + compare({TheType: ValueResolver('bar')}, actual=context._store) def test_returns_sequence(self): def foo(): @@ -257,7 +285,8 @@ def foo(): context = Context() result = context.extract(foo, nothing, returns('foo', 'bar')) compare(result, (1, 2)) - compare({'foo': 1, 'bar': 2}, context._store) + compare({'foo': ValueResolver(1), 'bar': ValueResolver(2)}, + actual=context._store) def test_returns_mapping(self): def foo(): @@ -265,7 +294,8 @@ def foo(): context = Context() result = context.extract(foo, nothing, returns_mapping()) compare(result, {'foo': 1, 'bar': 2}) - compare({'foo': 1, 'bar': 2}, context._store) + compare({'foo': ValueResolver(1), 'bar': ValueResolver(2)}, + actual=context._store) def test_ignore_return(self): def foo(): diff --git a/mush/tests/test_factory.py b/mush/tests/test_factory.py deleted file mode 100644 index 3657132..0000000 --- a/mush/tests/test_factory.py +++ /dev/null @@ -1,12 +0,0 @@ -from testfixtures import compare - -from mush import returns -from mush.factory import Factory -from mush.markers import Marker - -foo = Marker('foo') - - -def test_repr(): - f = Factory(foo, None, returns('foo')) - compare(repr(f), expected='>') diff --git a/mush/tests/test_resolver.py b/mush/tests/test_resolver.py new file mode 100644 index 0000000..53cf561 --- /dev/null +++ b/mush/tests/test_resolver.py @@ -0,0 +1,21 @@ +from testfixtures import compare + +from mush import returns +from mush.resolvers import Factory, ValueResolver +from mush.markers import Marker + +foo = Marker('foo') + + +class TestValueResolver: + + def test_repr(self): + f = ValueResolver(foo) + compare(repr(f), expected='') + + +class TestFactory: + + def test_repr(self): + f = Factory(foo, None, returns('foo')) + compare(repr(f), expected='>') From 23cd986bd4d806f74f90a77f3a6e3b4da4f68b9e Mon Sep 17 00:00:00 2001 From: Chris Withers Date: Sun, 16 Feb 2020 14:40:57 +0000 Subject: [PATCH 018/159] Add support for explicit removal from a context. --- mush/context.py | 15 ++++++++++++--- mush/tests/test_context.py | 17 +++++++++++++++++ 2 files changed, 29 insertions(+), 3 deletions(-) diff --git a/mush/context.py b/mush/context.py index aac0fa7..d67d2bd 100644 --- a/mush/context.py +++ b/mush/context.py @@ -88,13 +88,22 @@ def add(self, if provides is NONE_TYPE: raise ValueError('Cannot add None to context') if provides in self._store: - raise ContextError('Context already contains %r' % ( - provides - )) + raise ContextError(f'Context already contains {provides!r}') if resolver is None: resolver = ValueResolver(resource) self._store[provides] = resolver + def remove(self, provides: ResourceKey, *, strict: bool = True): + """ + Remove the specified resource key from the context. + + If ``strict``, then a :class:`ContextError` will be raised if the + specified resource is not present in the context. + """ + if strict and provides not in self._store: + raise ContextError(f'Context does not contain {provides!r}') + self._store.pop(provides, None) + def __repr__(self): bits = [] for type, value in sorted(self._store.items(), key=type_key): diff --git a/mush/tests/test_context.py b/mush/tests/test_context.py index d896e37..1643362 100644 --- a/mush/tests/test_context.py +++ b/mush/tests/test_context.py @@ -311,3 +311,20 @@ def foo(): pass result = context.extract(foo, nothing, nothing) compare(result, expected=None) compare(context._store, expected={}) + + def test_remove(self): + context = Context() + context.add('foo') + context.remove(str) + compare(context._store, expected={}) + + def test_remove_not_there_strict(self): + context = Context() + with ShouldRaise(ContextError("Context does not contain 'foo'")): + context.remove('foo') + compare(context._store, expected={}) + + def test_remove_not_there_not_strict(self): + context = Context() + context.remove('foo', strict=False) + compare(context._store, expected={}) From f90c099f0295c30fa8c3df2b72ce7e8ea50ef042 Mon Sep 17 00:00:00 2001 From: Chris Withers Date: Sun, 16 Feb 2020 15:11:29 +0000 Subject: [PATCH 019/159] Allow context to be obtained from itself. --- mush/context.py | 5 ++++- mush/tests/test_context.py | 6 ++++++ 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/mush/context.py b/mush/context.py index d67d2bd..b8d9f79 100644 --- a/mush/context.py +++ b/mush/context.py @@ -154,6 +154,9 @@ def get(self, requirement: Requirement): break if o is missing: - raise ContextError('No %s in context' % repr(requirement.spec)) + if requirement.base is Context: + o = self + else: + raise ContextError('No %s in context' % repr(requirement.spec)) return o diff --git a/mush/tests/test_context.py b/mush/tests/test_context.py index 1643362..3d93fb0 100644 --- a/mush/tests/test_context.py +++ b/mush/tests/test_context.py @@ -312,6 +312,12 @@ def foo(): pass compare(result, expected=None) compare(context._store, expected={}) + def test_context_contains_itself(self): + context = Context() + def return_context(context: Context): + return context + assert context.call(return_context) is context + def test_remove(self): context = Context() context.add('foo') From dfeb10a43ca1542afc1b7b782b290198f59dfa8b Mon Sep 17 00:00:00 2001 From: Chris Withers Date: Sun, 16 Feb 2020 15:20:26 +0000 Subject: [PATCH 020/159] Make lazy resolution no longer a special case in the Context. --- mush/callpoints.py | 10 ++++++---- mush/context.py | 9 +-------- mush/resolvers.py | 17 ++++++++++++----- mush/tests/test_resolver.py | 6 +++--- mush/tests/test_runner.py | 27 +++++++++++++++++++++++++++ 5 files changed, 49 insertions(+), 20 deletions(-) diff --git a/mush/callpoints.py b/mush/callpoints.py index cbc2381..95324a0 100644 --- a/mush/callpoints.py +++ b/mush/callpoints.py @@ -1,5 +1,6 @@ -from .declarations import result_type, nothing, extract_requires, extract_returns -from .resolvers import Factory +from .context import Context +from .declarations import nothing, extract_requires, extract_returns +from .resolvers import Lazy class CallPoint(object): @@ -12,8 +13,9 @@ def __init__(self, obj, requires=None, returns=None, lazy=None): returns = extract_returns(obj, returns) lazy = lazy or getattr(obj, '__mush__', {}).get('lazy') if lazy: - obj = Factory(obj, requires, returns) - requires = returns = nothing + obj = Lazy(obj, requires, returns) + requires = requires(Context) + returns = nothing self.obj = obj self.requires = requires self.returns = returns diff --git a/mush/context.py b/mush/context.py index b8d9f79..a90e21a 100644 --- a/mush/context.py +++ b/mush/context.py @@ -1,9 +1,8 @@ from typing import Optional, Any, Union, Type, Callable, NewType -from mush.resolvers import ValueResolver from .declarations import nothing, extract_requires, Requirement from .markers import missing -from .resolvers import Factory +from .resolvers import ValueResolver NONE_TYPE = type(None) @@ -121,10 +120,6 @@ def extract(self, obj, requires, returns): def call(self, obj, requires=None): requires = extract_requires(obj, requires) - if isinstance(obj, Factory): - self.add(obj, obj.returns.args[0]) - return - args = [] kw = {} @@ -145,8 +140,6 @@ def get(self, requirement: Requirement): o = missing else: o = resolver(self) - if isinstance(o, Factory): - o = o(self) for op in requirement.ops: o = op(o) diff --git a/mush/resolvers.py b/mush/resolvers.py index d6fe5b2..2475595 100644 --- a/mush/resolvers.py +++ b/mush/resolvers.py @@ -15,16 +15,23 @@ def __repr__(self): return repr(self.value) -class Factory(object): - - value = None +class Lazy(object): def __init__(self, obj, requires, returns): if not (type(returns) is returns_declaration and len(returns.args) == 1): raise TypeError('a single return type must be explicitly specified') self.__wrapped__ = obj self.requires = requires - self.returns = returns + self.provides = returns.args[0] + + def __call__(self, context): + context.add(resolver=self.resolve, provides=self.provides) + + def resolve(self, context): + result = context.call(self.__wrapped__, self.requires) + context.remove(self.provides) + context.add(result, self.provides) + return result def __repr__(self): - return '' % self.__wrapped__ + return '' % self.__wrapped__ diff --git a/mush/tests/test_resolver.py b/mush/tests/test_resolver.py index 53cf561..006c445 100644 --- a/mush/tests/test_resolver.py +++ b/mush/tests/test_resolver.py @@ -1,7 +1,7 @@ from testfixtures import compare from mush import returns -from mush.resolvers import Factory, ValueResolver +from mush.resolvers import Lazy, ValueResolver from mush.markers import Marker foo = Marker('foo') @@ -17,5 +17,5 @@ def test_repr(self): class TestFactory: def test_repr(self): - f = Factory(foo, None, returns('foo')) - compare(repr(f), expected='>') + f = Lazy(foo, None, returns('foo')) + compare(repr(f), expected='>') diff --git a/mush/tests/test_runner.py b/mush/tests/test_runner.py index 5591f4e..c540335 100644 --- a/mush/tests/test_runner.py +++ b/mush/tests/test_runner.py @@ -500,6 +500,33 @@ def job(obj): call.job(t), ], ) + def test_lazy_only_resolved_once(self): + m = Mock() + class T1(object): pass + t = T1() + + def lazy_used(): + m.lazy_used() + return t + + def job1(obj): + m.job1(obj) + + def job2(obj): + m.job2(obj) + + runner = Runner() + runner.add(lazy_used, returns=returns(T1), lazy=True) + runner.add(job1, requires(T1)) + runner.add(job2, requires(T1)) + runner() + + compare(m.mock_calls, expected=[ + call.lazy_used(), + call.job1(t), + call.job2(t), + ], ) + def test_missing_from_context_no_chain(self): class T(object): pass From cbbf31af98b28e49790d5aa6a037b64a371102d5 Mon Sep 17 00:00:00 2001 From: Chris Withers Date: Sun, 16 Feb 2020 17:36:16 +0000 Subject: [PATCH 021/159] Drop the lazy decorator. Lazy is a property of the runner, not the callable! --- mush/callpoints.py | 3 +-- mush/declarations.py | 9 --------- mush/tests/test_runner.py | 33 ++------------------------------- 3 files changed, 3 insertions(+), 42 deletions(-) diff --git a/mush/callpoints.py b/mush/callpoints.py index 95324a0..3782c10 100644 --- a/mush/callpoints.py +++ b/mush/callpoints.py @@ -8,10 +8,9 @@ class CallPoint(object): next = None previous = None - def __init__(self, obj, requires=None, returns=None, lazy=None): + def __init__(self, obj, requires=None, returns=None, lazy=False): requires = extract_requires(obj, requires) returns = extract_returns(obj, returns) - lazy = lazy or getattr(obj, '__mush__', {}).get('lazy') if lazy: obj = Lazy(obj, requires, returns) requires = requires(Context) diff --git a/mush/declarations.py b/mush/declarations.py index 8b9ae0a..051fe98 100644 --- a/mush/declarations.py +++ b/mush/declarations.py @@ -144,15 +144,6 @@ def __repr__(self): return self.__class__.__name__ + '(' + args_repr + ')' -def lazy(obj): - """ - Declaration that specifies the callable should only be called the first time - it is required. - """ - set_mush(obj, 'lazy', True) - return obj - - class how(object): """ The base class for type decorators that indicate which part of a diff --git a/mush/tests/test_runner.py b/mush/tests/test_runner.py index c540335..f9766ae 100644 --- a/mush/tests/test_runner.py +++ b/mush/tests/test_runner.py @@ -8,7 +8,7 @@ from mush.context import ContextError from mush.declarations import ( - requires, attr, item, nothing, returns, returns_mapping, lazy + requires, attr, item, nothing, returns, returns_mapping ) from mush.runner import Runner @@ -403,36 +403,7 @@ def job2(obj): call.job2(t), ], m.mock_calls) - def test_lazy_decorator(self): - m = Mock() - class T1(object): pass - class T2(object): pass - t = T1() - - @lazy - @returns(T1) - def lazy_used(): - m.lazy_used() - return t - - @lazy - @returns(T2) - def lazy_unused(): - raise AssertionError('should not be called') # pragma: no cover - - @requires(T1) - def job(obj): - m.job(obj) - - runner = Runner(lazy_used, lazy_unused, job) - runner() - - compare(m.mock_calls, expected=[ - call.lazy_used(), - call.job(t), - ], ) - - def test_lazy_imperative(self): + def test_lazy(self): m = Mock() class T1(object): pass class T2(object): pass From 4b366cb135d3ccaf6d95727e405b8129b1937416 Mon Sep 17 00:00:00 2001 From: Chris Withers Date: Sun, 16 Feb 2020 18:34:39 +0000 Subject: [PATCH 022/159] optimisation for a very common path. --- mush/context.py | 5 +++-- mush/declarations.py | 5 ++++- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/mush/context.py b/mush/context.py index a90e21a..3a232e9 100644 --- a/mush/context.py +++ b/mush/context.py @@ -1,6 +1,6 @@ from typing import Optional, Any, Union, Type, Callable, NewType -from .declarations import nothing, extract_requires, Requirement +from .declarations import nothing, extract_requires, Requirement, RequiresType from .markers import missing from .resolvers import ValueResolver @@ -118,7 +118,8 @@ def extract(self, obj, requires, returns): return result def call(self, obj, requires=None): - requires = extract_requires(obj, requires) + if requires.__class__ is not RequiresType: + requires = extract_requires(obj, requires) args = [] kw = {} diff --git a/mush/declarations.py b/mush/declarations.py index 051fe98..47ef3aa 100644 --- a/mush/declarations.py +++ b/mush/declarations.py @@ -39,7 +39,7 @@ def __repr__(self): return f'{self.target}={requirement_repr}' -class requires(list): +class RequiresType(list): """ Represents requirements for a particular callable. @@ -69,6 +69,9 @@ def __call__(self, obj): return obj +requires = RequiresType + + class ReturnsType(object): def __call__(self, obj): From a09c906cefae6c88237add123ccac38ba04f66e3 Mon Sep 17 00:00:00 2001 From: Chris Withers Date: Sun, 16 Feb 2020 18:36:48 +0000 Subject: [PATCH 023/159] Make Context.get() simpler. --- mush/context.py | 37 +++++++++++++++++-------------------- mush/tests/test_context.py | 12 +++++++++++- 2 files changed, 28 insertions(+), 21 deletions(-) diff --git a/mush/context.py b/mush/context.py index 3a232e9..c1c722d 100644 --- a/mush/context.py +++ b/mush/context.py @@ -125,7 +125,16 @@ def call(self, obj, requires=None): kw = {} for requirement in requires: - o = self.get(requirement) + o = self.get(requirement.base, missing) + + for op in requirement.ops: + o = op(o) + if o is nothing: + break + + if o is missing: + raise ContextError('No %s in context' % repr(requirement.base)) + if o is nothing: pass elif requirement.target is None: @@ -135,22 +144,10 @@ def call(self, obj, requires=None): return obj(*args, **kw) - def get(self, requirement: Requirement): - resolver = self._store.get(requirement.base, missing) - if resolver is missing: - o = missing - else: - o = resolver(self) - - for op in requirement.ops: - o = op(o) - if o is nothing: - break - - if o is missing: - if requirement.base is Context: - o = self - else: - raise ContextError('No %s in context' % repr(requirement.spec)) - - return o + def get(self, requirement: ResourceKey, default=None): + resolver = self._store.get(requirement, None) + if resolver is None: + if requirement is Context: + return self + return default + return resolver(self) diff --git a/mush/tests/test_context.py b/mush/tests/test_context.py index 3d93fb0..9a4faf2 100644 --- a/mush/tests/test_context.py +++ b/mush/tests/test_context.py @@ -74,7 +74,7 @@ def test_resolver(self): context = Context() context.add(provides='foo', resolver=m) m.assert_not_called() - assert context.get(requires('foo')[0]) is m.return_value + assert context.get('foo') is m.return_value m.assert_called_with(context) def test_resolver_and_resource(self): @@ -317,6 +317,7 @@ def test_context_contains_itself(self): def return_context(context: Context): return context assert context.call(return_context) is context + assert context.get(Context) is context def test_remove(self): context = Context() @@ -334,3 +335,12 @@ def test_remove_not_there_not_strict(self): context = Context() context.remove('foo', strict=False) compare(context._store, expected={}) + + def test_get_present(self): + context = Context() + context.add('bar', provides='foo') + compare(context.get('foo'), expected='bar') + + def test_get_missing(self): + context = Context() + compare(context.get('foo'), expected=None) From c57f733b253624dd8afbd26979a91376537afb9a Mon Sep 17 00:00:00 2001 From: Chris Withers Date: Sun, 16 Feb 2020 20:19:34 +0000 Subject: [PATCH 024/159] standardise on 'key' for the name of parameters that take a ResourceKey --- mush/context.py | 30 ++++++++++++++---------------- mush/declarations.py | 9 +++++++-- 2 files changed, 21 insertions(+), 18 deletions(-) diff --git a/mush/context.py b/mush/context.py index c1c722d..b17172b 100644 --- a/mush/context.py +++ b/mush/context.py @@ -1,6 +1,9 @@ -from typing import Optional, Any, Union, Type, Callable, NewType +from typing import Optional -from .declarations import nothing, extract_requires, Requirement, RequiresType +from .declarations import ( + nothing, extract_requires, RequiresType, + ResourceKey, ResourceValue, Resolver +) from .markers import missing from .resolvers import ValueResolver @@ -58,11 +61,6 @@ def type_key(type_tuple): return type.__name__ -ResourceKey = NewType('ResourceKey', Union[Type, str]) -ResourceValue = NewType('ResourceValue', Any) -Resolver = Callable[['Context'], ResourceValue] - - class Context: "Stores resources for a particular run." @@ -92,16 +90,16 @@ def add(self, resolver = ValueResolver(resource) self._store[provides] = resolver - def remove(self, provides: ResourceKey, *, strict: bool = True): + def remove(self, key: ResourceKey, *, strict: bool = True): """ Remove the specified resource key from the context. If ``strict``, then a :class:`ContextError` will be raised if the specified resource is not present in the context. """ - if strict and provides not in self._store: - raise ContextError(f'Context does not contain {provides!r}') - self._store.pop(provides, None) + if strict and key not in self._store: + raise ContextError(f'Context does not contain {key!r}') + self._store.pop(key, None) def __repr__(self): bits = [] @@ -125,7 +123,7 @@ def call(self, obj, requires=None): kw = {} for requirement in requires: - o = self.get(requirement.base, missing) + o = self.get(requirement.key, missing) for op in requirement.ops: o = op(o) @@ -133,7 +131,7 @@ def call(self, obj, requires=None): break if o is missing: - raise ContextError('No %s in context' % repr(requirement.base)) + raise ContextError('No %s in context' % repr(requirement.spec)) if o is nothing: pass @@ -144,10 +142,10 @@ def call(self, obj, requires=None): return obj(*args, **kw) - def get(self, requirement: ResourceKey, default=None): - resolver = self._store.get(requirement, None) + def get(self, key: ResourceKey, default=None): + resolver = self._store.get(key, None) if resolver is None: - if requirement is Context: + if key is Context: return self return default return resolver(self) diff --git a/mush/declarations.py b/mush/declarations.py index 47ef3aa..d70e09c 100644 --- a/mush/declarations.py +++ b/mush/declarations.py @@ -5,11 +5,16 @@ update_wrapper as functools_update_wrapper, ) from inspect import signature -from typing import List +from typing import List, Type, Optional, Callable, Sequence, NewType, Union, Any from .markers import missing +ResourceKey = NewType('ResourceKey', Union[Type, str]) +ResourceValue = NewType('ResourceValue', Any) +Resolver = Callable[['Context'], ResourceValue] + + def name_or_repr(obj): return getattr(obj, '__name__', None) or repr(obj) @@ -29,7 +34,7 @@ def __init__(self, source, target=None): while isinstance(source, how): self.ops.appendleft(source.process) source = source.type - self.base = source + self.key: ResourceKey = source def __repr__(self): requirement_repr = name_or_repr(self.spec) From 47a944b84f0190d73decb714e22381a86a73043b Mon Sep 17 00:00:00 2001 From: Chris Withers Date: Mon, 17 Feb 2020 08:35:28 +0000 Subject: [PATCH 025/159] Resolvers should take a default. --- mush/context.py | 2 +- mush/declarations.py | 2 +- mush/resolvers.py | 4 ++-- mush/tests/test_context.py | 11 +++++++++-- 4 files changed, 13 insertions(+), 6 deletions(-) diff --git a/mush/context.py b/mush/context.py index b17172b..92a06c6 100644 --- a/mush/context.py +++ b/mush/context.py @@ -148,4 +148,4 @@ def get(self, key: ResourceKey, default=None): if key is Context: return self return default - return resolver(self) + return resolver(self, default) diff --git a/mush/declarations.py b/mush/declarations.py index d70e09c..ff595d9 100644 --- a/mush/declarations.py +++ b/mush/declarations.py @@ -12,7 +12,7 @@ ResourceKey = NewType('ResourceKey', Union[Type, str]) ResourceValue = NewType('ResourceValue', Any) -Resolver = Callable[['Context'], ResourceValue] +Resolver = Callable[['Context', Any], ResourceValue] def name_or_repr(obj): diff --git a/mush/resolvers.py b/mush/resolvers.py index 2475595..29bb77c 100644 --- a/mush/resolvers.py +++ b/mush/resolvers.py @@ -8,7 +8,7 @@ class ValueResolver: def __init__(self, value): self.value = value - def __call__(self, context): + def __call__(self, context, default): return self.value def __repr__(self): @@ -27,7 +27,7 @@ def __init__(self, obj, requires, returns): def __call__(self, context): context.add(resolver=self.resolve, provides=self.provides) - def resolve(self, context): + def resolve(self, context, default): result = context.call(self.__wrapped__, self.requires) context.remove(self.provides) context.add(result, self.provides) diff --git a/mush/tests/test_context.py b/mush/tests/test_context.py index 9a4faf2..50c3df3 100644 --- a/mush/tests/test_context.py +++ b/mush/tests/test_context.py @@ -75,7 +75,7 @@ def test_resolver(self): context.add(provides='foo', resolver=m) m.assert_not_called() assert context.get('foo') is m.return_value - m.assert_called_with(context) + m.assert_called_with(context, None) def test_resolver_and_resource(self): m = Mock() @@ -83,7 +83,14 @@ def test_resolver_and_resource(self): with ShouldRaise(TypeError('resource cannot be supplied when using a resolver')): context.add('bar', provides='foo', resolver=m) compare(context._store, expected={}) - + + def test_resolver_with_default(self): + m = Mock() + context = Context() + context.add(provides='foo', + resolver=lambda context, default=None: context.get('foo-bar', default)) + assert context.get('foo', default=m) is m + def test_clash(self): obj1 = TheType() obj2 = TheType() From e0a5961a48c732622fcbc9da75862a579e23b2b3 Mon Sep 17 00:00:00 2001 From: Chris Withers Date: Wed, 19 Feb 2020 06:56:39 +0000 Subject: [PATCH 026/159] tighten up sphinx builds and easy markup using :any: by default. --- docs/Makefile | 2 +- docs/conf.py | 1 + mush/modifier.py | 6 +++--- mush/plug.py | 2 +- mush/runner.py | 6 +++--- 5 files changed, 9 insertions(+), 8 deletions(-) diff --git a/docs/Makefile b/docs/Makefile index e8d21ce..d5d8d40 100644 --- a/docs/Makefile +++ b/docs/Makefile @@ -9,7 +9,7 @@ PAPER = # Internal variables. PAPEROPT_a4 = -D latex_paper_size=a4 PAPEROPT_letter = -D latex_paper_size=letter -ALLSPHINXOPTS = -d _build/doctrees $(PAPEROPT_$(PAPER)) $(SPHINXOPTS) . +ALLSPHINXOPTS = -W --keep-going -d _build/doctrees $(PAPEROPT_$(PAPER)) $(SPHINXOPTS) . .PHONY: help clean html dirhtml pickle json htmlhelp qthelp latex changes linkcheck doctest diff --git a/docs/conf.py b/docs/conf.py index 91eb7a8..3e4b286 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -37,3 +37,4 @@ 'Simplistix Ltd', 'manual'), ] +default_role = 'any' diff --git a/mush/modifier.py b/mush/modifier.py index 06807a6..4f0188a 100644 --- a/mush/modifier.py +++ b/mush/modifier.py @@ -24,17 +24,17 @@ def add(self, obj, requires=None, returns=None, label=None, lazy=False): :param obj: The callable to be added. :param requires: The resources to required as parameters when calling - `obj`. These can be specified by passing a single + ``obj``. These can be specified by passing a single type, a string name or a :class:`requires` object. - :param returns: The resources that `obj` will return. + :param returns: The resources that ``obj`` will return. These can be specified as a single type, a string name or a :class:`returns`, :class:`returns_mapping`, :class:`returns_sequence` object. :param label: If specified, this is a string that adds a label to the - point where `obj` is added that can later be retrieved + point where ``obj`` is added that can later be retrieved with :meth:`Runner.__getitem__`. :param lazy: If true, ``obj`` will only be called the first time it diff --git a/mush/plug.py b/mush/plug.py index dfe1950..64b6010 100644 --- a/mush/plug.py +++ b/mush/plug.py @@ -17,7 +17,7 @@ def apply(self, runner, obj): class insert(ignore): """ A decorator to explicitly mark that a method of a :class:`~mush.Plug` should - be added to a runner by :meth:`~mush.Plug.add_to`. The `label` parameter + be added to a runner by :meth:`~mush.Plug.add_to`. The ``label`` parameter can be used to indicate a different label at which to add the method, instead of using the name of the method. """ diff --git a/mush/runner.py b/mush/runner.py index 469d0bb..84cdc0f 100644 --- a/mush/runner.py +++ b/mush/runner.py @@ -26,17 +26,17 @@ def add(self, obj, requires=None, returns=None, label=None, lazy=False): :param obj: The callable to be added. :param requires: The resources to required as parameters when calling - `obj`. These can be specified by passing a single + ``obj``. These can be specified by passing a single type, a string name or a :class:`requires` object. - :param returns: The resources that `obj` will return. + :param returns: The resources that ``obj`` will return. These can be specified as a single type, a string name or a :class:`returns`, :class:`returns_mapping`, :class:`returns_sequence` object. :param label: If specified, this is a string that adds a label to the - point where `obj` is added that can later be retrieved + point where ``obj`` is added that can later be retrieved with :meth:`Runner.__getitem__`. :param lazy: If true, ``obj`` will only be called the first time it From e35ac13cab1d7a95e5aa9275402c80f4f5638129 Mon Sep 17 00:00:00 2001 From: Chris Withers Date: Wed, 19 Feb 2020 08:13:59 +0000 Subject: [PATCH 027/159] restructure how Runner.replace works to make extract_(requires|returns) simpler. --- docs/api.txt | 2 +- mush/declarations.py | 33 +++++++++++----- mush/runner.py | 80 ++++++++++++++++++++++----------------- mush/tests/test_runner.py | 29 ++++++++------ 4 files changed, 87 insertions(+), 57 deletions(-) diff --git a/docs/api.txt b/docs/api.txt index acbfa2b..8a5aff6 100644 --- a/docs/api.txt +++ b/docs/api.txt @@ -12,7 +12,7 @@ API Reference :members: Modifier .. automodule:: mush.declarations - :members: how, nothing, result_type, update_wrapper + :members: how, nothing, result_type, update_wrapper, DeclarationsFrom .. automodule:: mush.plug :members: insert, ignore, append, Plug diff --git a/mush/declarations.py b/mush/declarations.py index ff595d9..2a2845b 100644 --- a/mush/declarations.py +++ b/mush/declarations.py @@ -1,4 +1,5 @@ from collections import deque +from enum import Enum, auto from functools import ( WRAPPER_UPDATES, WRAPPER_ASSIGNMENTS as FUNCTOOLS_ASSIGNMENTS, @@ -152,6 +153,17 @@ def __repr__(self): return self.__class__.__name__ + '(' + args_repr + ')' +class DeclarationsFrom(Enum): + #: Use declarations from the original callable. + original = auto() + #: Use declarations from the replacement callable. + replacement = auto() + + +original = DeclarationsFrom.original +replacement = DeclarationsFrom.replacement + + class how(object): """ The base class for type decorators that indicate which part of a @@ -281,8 +293,8 @@ def guess_requirements(obj): return requires(*args, **kw) -def extract_requires(obj, requires_, default=nothing): - if requires_ is None: +def extract_requires(obj, explicit=None): + if explicit is None: mush_declarations = getattr(obj, '__mush__', {}) requires_ = mush_declarations.get('requires', None) if requires_ is None: @@ -290,29 +302,30 @@ def extract_requires(obj, requires_, default=nothing): annotations = {} if annotations is None else annotations.copy() annotations.pop('return', None) requires_ = annotations or None + else: + requires_ = explicit if isinstance(requires_, requires): pass elif requires_ is None: - if default is not None: - requires_ = guess_requirements(obj) + requires_ = guess_requirements(obj) elif isinstance(requires_, (list, tuple)): requires_ = requires(*requires_) - elif isinstance(requires_, dict): - requires_ = requires(**requires_) else: requires_ = requires(requires_) - return requires_ or default + return requires_ or nothing -def extract_returns(obj, returns_, default=result_type): - if returns_ is None: +def extract_returns(obj, explicit=None): + if explicit is None: mush_declarations = getattr(obj, '__mush__', {}) returns_ = mush_declarations.get('returns', None) if returns_ is None: annotations = getattr(obj, '__annotations__', {}) returns_ = annotations.get('return') + else: + returns_ = explicit if returns_ is None or isinstance(returns_, ReturnsType): pass @@ -321,7 +334,7 @@ def extract_returns(obj, returns_, default=result_type): else: returns_ = returns(returns_) - return returns_ or default + return returns_ or result_type WRAPPER_ASSIGNMENTS = FUNCTOOLS_ASSIGNMENTS + ('__mush__',) diff --git a/mush/runner.py b/mush/runner.py index 84cdc0f..b9e8d43 100644 --- a/mush/runner.py +++ b/mush/runner.py @@ -1,6 +1,8 @@ +from typing import Callable + from .callpoints import CallPoint from .context import Context, ContextError -from .declarations import extract_requires, extract_returns +from .declarations import extract_requires, extract_returns, DeclarationsFrom from .markers import not_specified from .modifier import Modifier from .plug import Plug @@ -148,52 +150,62 @@ def clone(self, runner._copy_from(start, end, added_using) return runner - def replace(self, original, replacement, requires=None, returns=None): + def replace(self, + original: Callable, + replacement: Callable, + requires_from: DeclarationsFrom = DeclarationsFrom.replacement, + returns_from: DeclarationsFrom = DeclarationsFrom.original): """ Replace all instances of one callable with another. - No changes in requirements or call ordering will be made unless the - replacements have been decorated with requirements, or either - ``requires`` or ``returns`` have been specified. + :param original: The callable to replaced. - :param requires: The resources to required as parameters when calling - `obj`. These can be specified by passing a single - type, a string name or a :class:`requires` object. + :param replacement: The callable use instead. - :param returns: The resources that `obj` will return. - These can be specified as a single - type, a string name or a :class:`returns`, - :class:`returns_mapping`, :class:`returns_sequence` - object. + :param requires_from: + + Which :class:`requires` to use. + If :attr:`~mush.declarations.DeclarationsFrom.original`, + the existing ones will be used. + If :attr:`~mush.declarations.DeclarationsFrom.replacement`, + they will be extracted from the supplied replacements. + + :param returns_from: + + Which :class:`returns` to use. + If :attr:`~mush.declarations.DeclarationsFrom.original`, + the existing ones will be used. + If :attr:`~mush.declarations.DeclarationsFrom.replacement`, + they will be extracted from the supplied replacements. """ point = self.start while point: if point.obj is original: + if requires_from is DeclarationsFrom.replacement: + requires = extract_requires(replacement) + else: + requires = point.requires + if returns_from is DeclarationsFrom.replacement: + returns = extract_returns(replacement) + else: + returns = point.returns - new_requirements = ( - extract_requires(replacement, requires, default=None), - extract_returns(replacement, returns, default=None) - ) - - if any(new_requirements): - new_point = CallPoint(replacement, *new_requirements) - if point.previous is None: - self.start = new_point - else: - point.previous.next = new_point - if point.next is None: - self.end = new_point - else: - point.next.previous = new_point - new_point.next = point.next - for label in point.labels: - self.labels[label] = new_point - new_point.labels.add(label) - new_point.added_using = set(point.added_using) + new_point = CallPoint(replacement, requires, returns) + if point.previous is None: + self.start = new_point + else: + point.previous.next = new_point + if point.next is None: + self.end = new_point else: + point.next.previous = new_point + new_point.next = point.next - point.obj = replacement + for label in point.labels: + self.labels[label] = new_point + new_point.labels.add(label) + new_point.added_using = set(point.added_using) point = point.next diff --git a/mush/tests/test_runner.py b/mush/tests/test_runner.py index f9766ae..5d6041e 100644 --- a/mush/tests/test_runner.py +++ b/mush/tests/test_runner.py @@ -8,8 +8,8 @@ from mush.context import ContextError from mush.declarations import ( - requires, attr, item, nothing, returns, returns_mapping -) + requires, attr, item, nothing, returns, returns_mapping, + replacement, original) from mush.runner import Runner @@ -1090,9 +1090,9 @@ def job3(obj): runner = Runner(job1, job2, job3) runner.replace(job1, m.job1) m.job1.return_value = t1 - runner.replace(job2, m.job2) + runner.replace(job2, m.job2, requires_from=original) m.job2.return_value = t2 - runner.replace(job3, m.job3) + runner.replace(job3, m.job3, requires_from=original) runner() compare([ @@ -1120,7 +1120,9 @@ def job1(obj): job2 = requires(T4)(m.job2) runner = Runner(job0, job1, job2) - runner.replace(job1, requires(T2)(returns(T4)(m.job1))) + runner.replace(job1, + requires(T2)(returns(T4)(m.job1)), + returns_from=replacement) runner() compare([ @@ -1147,7 +1149,8 @@ def job1(obj): job2 = requires(T4)(m.job2) runner = Runner(job0, job1, job2) - runner.replace(job1, m.job1, requires=T2, returns=T4) + runner.replace(job1, requires(T2)(returns(T4)(m.job1)), + returns_from=replacement) runner() compare([ @@ -1163,7 +1166,9 @@ def test_replace_explicit_with_labels(self): runner['foo'].add(m.job1) runner['foo'].add(m.job2) - runner.replace(m.job2, m.jobnew, returns='mock') + runner.replace(m.job2, + returns('mock')(m.jobnew), + returns_from=replacement) runner() @@ -1196,10 +1201,10 @@ def test_replace_explicit_with_labels(self): ], actual=m.mock_calls) def test_replace_explicit_at_start(self): - m = Mock() + m = returns('mock')(Mock()) runner = Runner(m.job1, m.job2) - runner.replace(m.job1, m.jobnew, returns='mock') + runner.replace(m.job1, m.jobnew, returns_from=replacement) runner() compare([ @@ -1208,10 +1213,10 @@ def test_replace_explicit_at_start(self): ], actual=m.mock_calls) def test_replace_explicit_at_end(self): - m = Mock() + m = returns('mock')(Mock()) runner = Runner(m.job1, m.job2) - runner.replace(m.job2, m.jobnew, returns='mock') + runner.replace(m.job2, m.jobnew, returns_from=replacement) runner.add(m.jobnew2) runner() @@ -1232,7 +1237,7 @@ def barbar(sheep): runner.add(barbar, requires='flossy') compare(runner(), expected='barbar') - runner.replace(barbar, lambda dog: None) + runner.replace(barbar, lambda dog: None, requires_from=original) compare(runner(), expected=None) def test_modifier_changes_endpoint(self): From e6be88c1f584bcabe744f2ad5dcf16d484abe88c Mon Sep 17 00:00:00 2001 From: Chris Withers Date: Wed, 19 Feb 2020 08:22:28 +0000 Subject: [PATCH 028/159] implement a comparer for Requirements, so we can start refactoring them. --- mush/declarations.py | 6 +++--- mush/tests/conftest.py | 13 +++++++++++++ mush/tests/test_declarations.py | 2 +- setup.py | 2 +- 4 files changed, 18 insertions(+), 5 deletions(-) create mode 100644 mush/tests/conftest.py diff --git a/mush/declarations.py b/mush/declarations.py index 2a2845b..533dab0 100644 --- a/mush/declarations.py +++ b/mush/declarations.py @@ -31,6 +31,7 @@ class Requirement: def __init__(self, source, target=None): self.target = target self.spec = source + self.repr = name_or_repr(source) self.ops = deque() while isinstance(source, how): self.ops.appendleft(source.process) @@ -38,11 +39,10 @@ def __init__(self, source, target=None): self.key: ResourceKey = source def __repr__(self): - requirement_repr = name_or_repr(self.spec) if self.target is None: - return requirement_repr + return self.repr else: - return f'{self.target}={requirement_repr}' + return f'{self.target}={self.repr}' class RequiresType(list): diff --git a/mush/tests/conftest.py b/mush/tests/conftest.py new file mode 100644 index 0000000..fede896 --- /dev/null +++ b/mush/tests/conftest.py @@ -0,0 +1,13 @@ +from testfixtures.comparison import register, compare_object + +from mush.declarations import Requirement + + +def compare_requirement(x, y, context): + # make sure this doesn't get refactored away, since we're using it + # as a proxy to check .ops: + assert hasattr(x, 'repr') + return compare_object(x, y, context, ignore_attributes=['ops']) + + +register(Requirement, compare_requirement) diff --git a/mush/tests/test_declarations.py b/mush/tests/test_declarations.py index 437183e..341e239 100644 --- a/mush/tests/test_declarations.py +++ b/mush/tests/test_declarations.py @@ -16,7 +16,7 @@ def check_extract(obj, expected_rq, expected_rt): rq = extract_requires(obj, None) rt = extract_returns(obj, None) - compare(rq, expected=expected_rq, strict=True, ignore_attributes={Requirement: ['ops']}) + compare(rq, expected=expected_rq, strict=True) compare(rt, expected=expected_rt, strict=True) diff --git a/setup.py b/setup.py index a030b0a..8cea6ff 100644 --- a/setup.py +++ b/setup.py @@ -31,6 +31,6 @@ include_package_data=True, python_requires='>=3.6', extras_require=dict( - test=['pytest', 'pytest-cov', 'mock', 'sybil', 'testfixtures'], + test=['pytest', 'pytest-cov', 'mock', 'sybil', 'testfixtures>=6.13'], build=['sphinx', 'setuptools-git', 'wheel', 'twine'] )) From 460b96681e043bb49f563deab143bf209188b25a Mon Sep 17 00:00:00 2001 From: Chris Withers Date: Wed, 19 Feb 2020 09:02:53 +0000 Subject: [PATCH 029/159] Move target back into the requires object. Having it on the requirement felt awkward. --- mush/context.py | 6 +++--- mush/declarations.py | 25 ++++++++++++------------- mush/tests/test_declarations.py | 20 ++++++++++---------- 3 files changed, 25 insertions(+), 26 deletions(-) diff --git a/mush/context.py b/mush/context.py index 92a06c6..802949e 100644 --- a/mush/context.py +++ b/mush/context.py @@ -122,7 +122,7 @@ def call(self, obj, requires=None): args = [] kw = {} - for requirement in requires: + for target, requirement in requires: o = self.get(requirement.key, missing) for op in requirement.ops: @@ -135,10 +135,10 @@ def call(self, obj, requires=None): if o is nothing: pass - elif requirement.target is None: + elif target is None: args.append(o) else: - kw[requirement.target] = o + kw[target] = o return obj(*args, **kw) diff --git a/mush/declarations.py b/mush/declarations.py index 533dab0..fbb32d8 100644 --- a/mush/declarations.py +++ b/mush/declarations.py @@ -28,10 +28,10 @@ def set_mush(obj, key, value): class Requirement: - def __init__(self, source, target=None): - self.target = target - self.spec = source + def __init__(self, source): self.repr = name_or_repr(source) + + self.spec = source self.ops = deque() while isinstance(source, how): self.ops.appendleft(source.process) @@ -39,10 +39,7 @@ def __init__(self, source, target=None): self.key: ResourceKey = source def __repr__(self): - if self.target is None: - return self.repr - else: - return f'{self.target}={self.repr}' + return self.repr class RequiresType(list): @@ -61,14 +58,16 @@ def __init__(self, *args, **kw): super(requires, self).__init__() check_type(*args) check_type(*kw.values()) - self.resolvers = [] - for arg in args: - self.append(Requirement(arg)) - for k, v in kw.items(): - self.append(Requirement(v, target=k)) + for target, source in chain( + ((None, arg) for arg in args), + kw.items(), + ): + self.append((target, Requirement(source))) def __repr__(self): - return f"requires({', '.join(repr(r) for r in self)})" + parts = (repr(r) if t is None else f'{t}={r!r}' + for (t, r) in self) + return f"requires({', '.join(parts)})" def __call__(self, obj): set_mush(obj, 'requires', self) diff --git a/mush/tests/test_declarations.py b/mush/tests/test_declarations.py index 341e239..ae3de5b 100644 --- a/mush/tests/test_declarations.py +++ b/mush/tests/test_declarations.py @@ -31,26 +31,26 @@ class TestRequires(TestCase): def test_empty(self): r = requires() compare(repr(r), 'requires()') - compare(r.resolvers, []) + compare(r, expected=[]) def test_types(self): r = requires(Type1, Type2, x=Type3, y=Type4) compare(repr(r), 'requires(Type1, Type2, x=Type3, y=Type4)') compare(r, expected=[ - Requirement(Type1), - Requirement(Type2), - Requirement(Type3, target='x'), - Requirement(Type4, target='y'), + (None, Requirement(Type1)), + (None, Requirement(Type2)), + ('x', Requirement(Type3)), + ('y', Requirement(Type4)), ]) def test_strings(self): r = requires('1', '2', x='3', y='4') compare(repr(r), "requires('1', '2', x='3', y='4')") compare(r, expected=[ - Requirement('1'), - Requirement('2'), - Requirement('3', target='x'), - Requirement('4', target='y'), + (None, Requirement('1')), + (None, Requirement('2')), + ('x', Requirement('3')), + ('y', Requirement('4')), ]) def test_tuple_arg(self): @@ -66,7 +66,7 @@ def test_decorator_paranoid(self): def foo(): return 'bar' - compare(foo.__mush__['requires'], expected=[Requirement(Type1)]) + compare(foo.__mush__['requires'], expected=[(None, Requirement(Type1))]) compare(foo(), 'bar') From b0ca09450acb8a47b2e29d43f1a3809fa743cf2e Mon Sep 17 00:00:00 2001 From: Chris Withers Date: Wed, 19 Feb 2020 09:11:10 +0000 Subject: [PATCH 030/159] First step into allowing explicit requirement objects. --- mush/context.py | 2 +- mush/declarations.py | 15 ++++++++++----- mush/tests/test_context.py | 8 ++++++++ 3 files changed, 19 insertions(+), 6 deletions(-) diff --git a/mush/context.py b/mush/context.py index 802949e..99233d6 100644 --- a/mush/context.py +++ b/mush/context.py @@ -123,7 +123,7 @@ def call(self, obj, requires=None): kw = {} for target, requirement in requires: - o = self.get(requirement.key, missing) + o = self.get(requirement.key, requirement.default) for op in requirement.ops: o = op(o) diff --git a/mush/declarations.py b/mush/declarations.py index fbb32d8..d4013a0 100644 --- a/mush/declarations.py +++ b/mush/declarations.py @@ -6,6 +6,7 @@ update_wrapper as functools_update_wrapper, ) from inspect import signature +from itertools import chain from typing import List, Type, Optional, Callable, Sequence, NewType, Union, Any from .markers import missing @@ -28,10 +29,12 @@ def set_mush(obj, key, value): class Requirement: - def __init__(self, source): + def __init__(self, source, default=missing): self.repr = name_or_repr(source) self.spec = source + self.default = default + self.ops = deque() while isinstance(source, how): self.ops.appendleft(source.process) @@ -55,14 +58,16 @@ class RequiresType(list): """ def __init__(self, *args, **kw): - super(requires, self).__init__() + super().__init__() check_type(*args) check_type(*kw.values()) - for target, source in chain( + for target, requirement in chain( ((None, arg) for arg in args), kw.items(), ): - self.append((target, Requirement(source))) + if not isinstance(requirement, Requirement): + requirement = Requirement(requirement) + self.append((target, requirement)) def __repr__(self): parts = (repr(r) if t is None else f'{t}={r!r}' @@ -246,7 +251,7 @@ def process(self, o): return o -ok_types = (type, str, how) +ok_types = (type, str, how, Requirement) def check_type(*objs): diff --git a/mush/tests/test_context.py b/mush/tests/test_context.py index 50c3df3..203964e 100644 --- a/mush/tests/test_context.py +++ b/mush/tests/test_context.py @@ -197,6 +197,14 @@ def foo(x=1): result = context.call(foo, requires(optional(TheType))) compare(result, 1) + def test_call_requires_optional_override_source_and_default(self): + def foo(x=1): + return x + context = Context() + context.add(2, provides='x') + result = context.call(foo, requires(x=Requirement('y', default=3))) + compare(result, expected=3) + def test_call_requires_optional_string(self): def foo(x=1): return x From 7de2b3f6e63584e80bf197cf111462d223e51131 Mon Sep 17 00:00:00 2001 From: Chris Withers Date: Wed, 19 Feb 2020 09:31:45 +0000 Subject: [PATCH 031/159] lose vestigial `spec` attribute and give Requirement a proper repr. --- mush/context.py | 2 +- mush/declarations.py | 8 +++----- mush/tests/test_context.py | 2 +- mush/tests/test_declarations.py | 7 +++++++ mush/tests/test_runner.py | 12 ++++++------ 5 files changed, 18 insertions(+), 13 deletions(-) diff --git a/mush/context.py b/mush/context.py index 99233d6..0a34fb6 100644 --- a/mush/context.py +++ b/mush/context.py @@ -131,7 +131,7 @@ def call(self, obj, requires=None): break if o is missing: - raise ContextError('No %s in context' % repr(requirement.spec)) + raise ContextError('No %s in context' % requirement.repr) if o is nothing: pass diff --git a/mush/declarations.py b/mush/declarations.py index d4013a0..bf144c9 100644 --- a/mush/declarations.py +++ b/mush/declarations.py @@ -7,11 +7,10 @@ ) from inspect import signature from itertools import chain -from typing import List, Type, Optional, Callable, Sequence, NewType, Union, Any +from typing import Type, Callable, NewType, Union, Any from .markers import missing - ResourceKey = NewType('ResourceKey', Union[Type, str]) ResourceValue = NewType('ResourceValue', Any) Resolver = Callable[['Context', Any], ResourceValue] @@ -32,7 +31,6 @@ class Requirement: def __init__(self, source, default=missing): self.repr = name_or_repr(source) - self.spec = source self.default = default self.ops = deque() @@ -42,7 +40,7 @@ def __init__(self, source, default=missing): self.key: ResourceKey = source def __repr__(self): - return self.repr + return f'Requirement({self.repr}, default={self.default})' class RequiresType(list): @@ -70,7 +68,7 @@ def __init__(self, *args, **kw): self.append((target, requirement)) def __repr__(self): - parts = (repr(r) if t is None else f'{t}={r!r}' + parts = (r.repr if t is None else f'{t}={r.repr}' for (t, r) in self) return f"requires({', '.join(parts)})" diff --git a/mush/tests/test_context.py b/mush/tests/test_context.py index 203964e..4bd83b9 100644 --- a/mush/tests/test_context.py +++ b/mush/tests/test_context.py @@ -146,7 +146,7 @@ def test_call_requires_missing(self): def foo(obj): return obj context = Context() with ShouldRaise(ContextError( - "No in context" + "No TheType in context" )): context.call(foo, requires(TheType)) diff --git a/mush/tests/test_declarations.py b/mush/tests/test_declarations.py index ae3de5b..70db0ca 100644 --- a/mush/tests/test_declarations.py +++ b/mush/tests/test_declarations.py @@ -70,6 +70,13 @@ def foo(): compare(foo(), 'bar') +class TestRequirement: + + def test_repr(self): + compare(repr(Requirement('foo', default=None)), + expected="Requirement('foo', default=None)") + + class TestItem(TestCase): def test_single(self): diff --git a/mush/tests/test_runner.py b/mush/tests/test_runner.py index 5d6041e..1b8222a 100644 --- a/mush/tests/test_runner.py +++ b/mush/tests/test_runner.py @@ -514,10 +514,10 @@ def job(arg): 'While calling: '+repr(job)+' requires(T) returns_result_type()', 'with :', '', - 'No '+repr(T)+' in context', + 'No T in context', )) - compare(text, repr(s.raised)) - compare(text, str(s.raised)) + compare(text, actual=repr(s.raised)) + compare(text, actual=str(s.raised)) def test_missing_from_context_with_chain(self): class T(object): pass @@ -552,14 +552,14 @@ def job5(): pass 'While calling: '+repr(job3)+' requires(T) returns_result_type()', 'with :', '', - 'No '+repr(T)+' in context', + 'No T in context', '', 'Still to call:', repr(job4)+' requires() returns_result_type() <-- 4', repr(job5)+" requires('foo', bar='baz') returns('bob')", )) - compare(text, repr(s.raised)) - compare(text, str(s.raised)) + compare(text, actual=repr(s.raised)) + compare(text, actual=str(s.raised)) def test_job_called_badly(self): def job(arg): From 31caa665acb135654267c2448d8ebcdf597ae867 Mon Sep 17 00:00:00 2001 From: Chris Withers Date: Thu, 20 Feb 2020 08:28:55 +0000 Subject: [PATCH 032/159] Remove the optional how in favour of just using defaults. This allows resolution of an individual requirement to be moved to the Requirement object. Also removed an obscure "marker" requirement use case, which I think was a hangover from mush 1.x's attempts to do DAG-ish resolution. --- docs/use.txt | 19 +++++------ mush/__init__.py | 4 +-- mush/context.py | 19 +++-------- mush/declarations.py | 49 ++++++++++------------------- mush/markers.py | 5 ++- mush/tests/test_context.py | 41 +++++++----------------- mush/tests/test_declarations.py | 56 ++++++++++++++------------------- mush/tests/test_runner.py | 25 --------------- 8 files changed, 70 insertions(+), 148 deletions(-) diff --git a/docs/use.txt b/docs/use.txt index 9aed676..4ac7300 100755 --- a/docs/use.txt +++ b/docs/use.txt @@ -198,8 +198,6 @@ I made an apple I turned an apple into an orange I made juice out of an apple and an orange -.. _optional-resources: - Optional requirements ~~~~~~~~~~~~~~~~~~~~~ @@ -209,17 +207,17 @@ take this into account. Take the following function: .. code-block:: python - def greet(name='stranger'): + def greet(name: str = 'stranger'): print('Hello ' + name + '!') -If a name is not always be available, it can be added to a runner as follows: +If a name is not always be available, the callable's default will be used: .. code-block:: python - from mush import Runner, optional + from mush import Runner runner = Runner() - runner.add(greet, requires=optional(str)) + runner.add(greet) Now, when this runner is called, the default will be used: @@ -231,13 +229,13 @@ available: .. code-block:: python - from mush import Runner, optional + from mush import Runner def my_name_is(): return 'Slim Shady' runner = Runner(my_name_is) - runner.add(greet, requires=optional(str)) + runner.add(greet) In this case, the string returned will be used: @@ -538,9 +536,6 @@ I turned an apple into an orange I made juice out of an apple and an orange a refreshing fruit beverage -If an argument has a default, then the requirement will be made -:ref:`optional `. - Configuration Precedence ~~~~~~~~~~~~~~~~~~~~~~~~ @@ -1040,7 +1035,7 @@ To see how the configuration panned out, we would look at the :func:`repr`: >>> runner requires() returns('config') - requires(foo='config'['foo']) returns_result_type() <-- config + requires('config'['foo']) returns_result_type() <-- config requires('connection') returns_result_type() diff --git a/mush/__init__.py b/mush/__init__.py index 919f59c..cb5079b 100755 --- a/mush/__init__.py +++ b/mush/__init__.py @@ -2,7 +2,7 @@ from .declarations import ( requires, returns_result_type, returns_mapping, returns_sequence, returns, - optional, attr, item, nothing + attr, item, nothing ) from .plug import Plug from .context import Context, ContextError @@ -10,7 +10,7 @@ __all__ = [ 'Context', 'ContextError', 'Runner', - 'requires', 'optional', + 'requires', 'returns_result_type', 'returns_mapping', 'returns_sequence', 'returns', 'attr', 'item', 'Plug', 'nothing' ] diff --git a/mush/context.py b/mush/context.py index 0a34fb6..b4060c2 100644 --- a/mush/context.py +++ b/mush/context.py @@ -1,8 +1,7 @@ from typing import Optional from .declarations import ( - nothing, extract_requires, RequiresType, - ResourceKey, ResourceValue, Resolver + extract_requires, RequiresType, ResourceKey, ResourceValue, Resolver ) from .markers import missing from .resolvers import ValueResolver @@ -123,19 +122,11 @@ def call(self, obj, requires=None): kw = {} for target, requirement in requires: - o = self.get(requirement.key, requirement.default) - - for op in requirement.ops: - o = op(o) - if o is nothing: - break - + o = requirement.resolve(self) if o is missing: - raise ContextError('No %s in context' % requirement.repr) - - if o is nothing: - pass - elif target is None: + if requirement.default is missing: + raise ContextError('No %s in context' % requirement.repr) + if target is None: args.append(o) else: kw[target] = o diff --git a/mush/declarations.py b/mush/declarations.py index bf144c9..eeed764 100644 --- a/mush/declarations.py +++ b/mush/declarations.py @@ -5,7 +5,7 @@ WRAPPER_ASSIGNMENTS as FUNCTOOLS_ASSIGNMENTS, update_wrapper as functools_update_wrapper, ) -from inspect import signature +from inspect import Signature from itertools import chain from typing import Type, Callable, NewType, Union, Any @@ -39,6 +39,14 @@ def __init__(self, source, default=missing): source = source.type self.key: ResourceKey = source + def resolve(self, context): + o = context.get(self.key, missing) + if o is missing: + return self.default + for op in self.ops: + o = op(o) + return o + def __repr__(self): return f'Requirement({self.repr}, default={self.default})' @@ -198,18 +206,6 @@ def process(self, o): """ return missing -class optional(how): - """ - A :class:`~.declarations.how` that indicates the callable requires the - wrapped requirement only if it's present in the :class:`~.context.Context`. - """ - type_pattern = 'optional(%(type)s)' - - def process(self, o): - if o is missing: - return nothing - return o - class attr(how): """ @@ -276,23 +272,17 @@ def process(self, result): result_type = returns_result_type() -def maybe_optional(p): - value = p.name - if p.default is not p.empty: - value = optional(value) - return value - - def guess_requirements(obj): args = [] kw = {} - for name, p in signature(obj).parameters.items(): + for name, p in Signature.from_callable(obj).parameters.items(): + key = p.name if p.annotation is missing else p.annotation + requirement = Requirement(key, default=p.default) if p.kind in {p.POSITIONAL_ONLY, p.POSITIONAL_OR_KEYWORD}: - args.append(maybe_optional(p)) + args.append(requirement) elif p.kind is p.KEYWORD_ONLY: - kw[name] = maybe_optional(p) - if args or kw: - return requires(*args, **kw) + kw[name] = requirement + return requires(*args, **kw) def extract_requires(obj, explicit=None): @@ -300,17 +290,12 @@ def extract_requires(obj, explicit=None): mush_declarations = getattr(obj, '__mush__', {}) requires_ = mush_declarations.get('requires', None) if requires_ is None: - annotations = getattr(obj, '__annotations__', None) - annotations = {} if annotations is None else annotations.copy() - annotations.pop('return', None) - requires_ = annotations or None + requires_ = guess_requirements(obj) else: requires_ = explicit if isinstance(requires_, requires): pass - elif requires_ is None: - requires_ = guess_requirements(obj) elif isinstance(requires_, (list, tuple)): requires_ = requires(*requires_) else: @@ -319,7 +304,7 @@ def extract_requires(obj, explicit=None): return requires_ or nothing -def extract_returns(obj, explicit=None): +def extract_returns(obj: Callable, explicit: ReturnsType = None): if explicit is None: mush_declarations = getattr(obj, '__mush__', {}) returns_ = mush_declarations.get('returns', None) diff --git a/mush/markers.py b/mush/markers.py index e45b138..9ac04f6 100644 --- a/mush/markers.py +++ b/mush/markers.py @@ -1,3 +1,6 @@ +from inspect import Parameter + + class Marker(object): def __init__(self, name): @@ -8,4 +11,4 @@ def __repr__(self): not_specified = Marker('not_specified') -missing = Marker('missing') \ No newline at end of file +missing = Parameter.empty diff --git a/mush/tests/test_context.py b/mush/tests/test_context.py index 4bd83b9..b05581d 100644 --- a/mush/tests/test_context.py +++ b/mush/tests/test_context.py @@ -1,13 +1,11 @@ from unittest import TestCase -from mock import Mock +from mock import Mock from testfixtures import ShouldRaise, compare from mush import Context, ContextError - from mush.declarations import ( - nothing, requires, optional, item, - attr, returns, returns_mapping + nothing, requires, item, attr, returns, returns_mapping, Requirement ) from mush.resolvers import ValueResolver @@ -186,15 +184,15 @@ def foo(x=1): return x context = Context() context.add(2, TheType) - result = context.call(foo, requires(optional(TheType))) + result = context.call(foo, requires(TheType)) compare(result, 2) compare({TheType: ValueResolver(2)}, actual=context._store) - def test_call_requires_optional_ContextError(self): - def foo(x=1): + def test_call_requires_optional_missing(self): + def foo(x: TheType = 1): return x context = Context() - result = context.call(foo, requires(optional(TheType))) + result = context.call(foo) compare(result, 1) def test_call_requires_optional_override_source_and_default(self): @@ -206,11 +204,11 @@ def foo(x=1): compare(result, expected=3) def test_call_requires_optional_string(self): - def foo(x=1): + def foo(x:'foo'=1): return x context = Context() context.add(2, 'foo') - result = context.call(foo, requires(optional('foo'))) + result = context.call(foo) compare(result, 2) compare({'foo': ValueResolver(2)}, actual=context._store) @@ -241,11 +239,11 @@ def foo(x): result = context.call(foo, requires(item(attr('foo', 'bar'), 'baz'))) compare(result, 'bob') - def test_call_requires_optional_item_ContextError(self): - def foo(x=1): + def test_call_requires_optional_item_missing(self): + def foo(x: item('foo', 'bar') = 1): return x context = Context() - result = context.call(foo, requires(optional(item('foo', 'bar')))) + result = context.call(foo) compare(result, 1) def test_call_requires_optional_item_present(self): @@ -253,22 +251,7 @@ def foo(x=1): return x context = Context() context.add(dict(bar='baz'), 'foo') - result = context.call(foo, requires(optional(item('foo', 'bar')))) - compare(result, 'baz') - - def test_call_requires_item_optional_ContextError(self): - def foo(x=1): - return x - context = Context() - result = context.call(foo, requires(item(optional('foo'), 'bar'))) - compare(result, 1) - - def test_call_requires_item_optional_present(self): - def foo(x=1): - return x - context = Context() - context.add(dict(bar='baz'), 'foo') - result = context.call(foo, requires(item(optional('foo'), 'bar'))) + result = context.call(foo, requires((item('foo', 'bar')))) compare(result, 'baz') def test_call_extract_requirements(self): diff --git a/mush/tests/test_declarations.py b/mush/tests/test_declarations.py index 70db0ca..0cd988f 100644 --- a/mush/tests/test_declarations.py +++ b/mush/tests/test_declarations.py @@ -4,7 +4,7 @@ from testfixtures import compare, ShouldRaise from mush.markers import missing from mush.declarations import ( - requires, optional, returns, + requires, returns, returns_mapping, returns_sequence, returns_result_type, how, item, attr, nothing, extract_requires, extract_returns, @@ -136,21 +136,6 @@ def test_passed_missing(self): compare(h.process(missing), missing) -class TestOptional(TestCase): - - def test_type(self): - compare(repr(optional(Type1)), "optional(Type1)") - - def test_string(self): - compare(repr(optional('1')), "optional('1')") - - def test_present(self): - compare(optional(Type1).process(1), 1) - - def test_missing(self): - compare(optional(Type1).process(missing), nothing) - - class TestReturns(TestCase): def test_type(self): @@ -238,14 +223,14 @@ class TestExtractDeclarations(object): def test_default_requirements_for_function(self): def foo(a, b=None): pass check_extract(foo, - expected_rq=requires('a', optional('b')), + expected_rq=requires('a', Requirement('b', default=None)), expected_rt=result_type) def test_default_requirements_for_class(self): class MyClass(object): def __init__(self, a, b=None): pass check_extract(MyClass, - expected_rq=requires('a', optional('b')), + expected_rq=requires('a', Requirement('b', default=None)), expected_rt=result_type) def test_extract_from_partial(self): @@ -253,7 +238,9 @@ def foo(x, y, z, a=None): pass p = partial(foo, 1, y=2) check_extract( p, - expected_rq=requires(y=optional('y'), z='z', a=optional('a'), ), + expected_rq=requires(y=Requirement('y', default=2), + z='z', + a=Requirement('a', default=None)), expected_rt=result_type ) @@ -262,7 +249,7 @@ def foo(a=None): pass p = partial(foo) check_extract( p, - expected_rq=requires(optional('a')), + expected_rq=requires(Requirement('a', default=None)), expected_rt=result_type ) @@ -281,7 +268,7 @@ def foo(a=None): pass p = partial(foo, a=1) check_extract( p, - expected_rq=requires(a=optional('a')), + expected_rq=requires(a=Requirement('a', default=1)), expected_rt=result_type ) @@ -300,7 +287,7 @@ def foo(a): pass p = partial(foo, a=1) check_extract( p, - expected_rq=requires(a=optional('a')), + expected_rq=requires(a=Requirement('a', default=1)), expected_rt=result_type ) @@ -309,7 +296,7 @@ def foo(b, a=None): pass p = partial(foo) check_extract( p, - expected_rq=requires('b', optional('a')), + expected_rq=requires('b', Requirement('a', default=None)), expected_rt=result_type ) @@ -328,7 +315,7 @@ def foo(b, a): pass p = partial(foo, a=1) check_extract( p, - expected_rq=requires('b', a=optional('a')), + expected_rq=requires('b', a=Requirement('a', default=1)), expected_rt=result_type ) @@ -338,13 +325,16 @@ class TestExtractDeclarationsFromTypeAnnotations(object): def test_extract_from_annotations(self): def foo(a: 'foo', b, c: 'bar' = 1, d=2) -> 'bar': pass check_extract(foo, - expected_rq=requires(a='foo', c='bar'), + expected_rq=requires('foo', + 'b', + Requirement('bar', default=1), + Requirement('d', default=2)), expected_rt=returns('bar')) def test_requires_only(self): def foo(a: 'foo'): pass check_extract(foo, - expected_rq=requires(a='foo'), + expected_rq=requires('foo'), expected_rt=result_type) def test_returns_only(self): @@ -365,12 +355,12 @@ def my_dec(func): return update_wrapper(Wrapper(func), func) @my_dec - def foo(a: 'foo'=None) -> 'bar': + def foo(a: 'foo' = None) -> 'bar': return 'answer' compare(foo(), expected='the answer') check_extract(foo, - expected_rq=requires(a='foo'), + expected_rq=requires(Requirement('foo', default=None)), expected_rt=returns('bar')) def test_decorator_trumps_annotations(self): @@ -396,17 +386,17 @@ def foo() -> rt: pass expected_rt=rt) def test_how_instance_in_annotations(self): - how = item('config', 'db_url') - def foo(a: how): pass + how_instance = item('config', 'db_url') + def foo(a: how_instance): pass check_extract(foo, - expected_rq=requires(a=how), + expected_rq=requires(how_instance), expected_rt=result_type) def test_default_requirements(self): def foo(a, b=1, *, c, d=None): pass check_extract(foo, expected_rq=requires('a', - optional('b'), + Requirement('b', default=1), c='c', - d=optional('d')), + d=Requirement('d', default=None)), expected_rt=result_type) diff --git a/mush/tests/test_runner.py b/mush/tests/test_runner.py index 1b8222a..015df31 100644 --- a/mush/tests/test_runner.py +++ b/mush/tests/test_runner.py @@ -767,31 +767,6 @@ def func2(obj1, obj2, obj3): call.cm1.exit(Exception, e) ], m.mock_calls) - def test_marker_interfaces(self): - # return {Type:None} - # don't pass when a requirement is for a type but value is None - class Marker(object): pass - - m = Mock() - - def setup(): - m.setup() - return {Marker: nothing} - - @requires(Marker) - def use(): - m.use() - - runner = Runner() - runner.add(setup, returns=returns_mapping(), label='setup') - runner['setup'].add(use) - runner() - - compare([ - call.setup(), - call.use(), - ], m.mock_calls) - def test_clone(self): m = Mock() class T1(object): pass From 89583e14c4c0cdda1299d24ffa134acab4bbce5f Mon Sep 17 00:00:00 2001 From: Chris Withers Date: Fri, 21 Feb 2020 08:19:31 +0000 Subject: [PATCH 033/159] Add support for nested contexts. --- mush/context.py | 19 ++++++++++++++++++- mush/tests/test_context.py | 14 ++++++++++++++ 2 files changed, 32 insertions(+), 1 deletion(-) diff --git a/mush/context.py b/mush/context.py index b4060c2..db13292 100644 --- a/mush/context.py +++ b/mush/context.py @@ -63,6 +63,8 @@ def type_key(type_tuple): class Context: "Stores resources for a particular run." + _parent = None + def __init__(self): self._store = {} @@ -134,9 +136,24 @@ def call(self, obj, requires=None): return obj(*args, **kw) def get(self, key: ResourceKey, default=None): - resolver = self._store.get(key, None) + context = self + resolver = None + + while resolver is None and context is not None: + resolver = context._store.get(key, None) + if resolver is None: + context = context._parent + elif context is not self: + self._store[key] = resolver + if resolver is None: if key is Context: return self return default + return resolver(self, default) + + def nest(self): + nested = type(self)() + nested._parent = self + return nested diff --git a/mush/tests/test_context.py b/mush/tests/test_context.py index b05581d..eebb69c 100644 --- a/mush/tests/test_context.py +++ b/mush/tests/test_context.py @@ -342,3 +342,17 @@ def test_get_present(self): def test_get_missing(self): context = Context() compare(context.get('foo'), expected=None) + + def test_nest(self): + c1 = Context() + c1.add('a', provides='a') + c1.add('c', provides='c') + c2 = c1.nest() + c2.add('b', provides='b') + c2.add('d', provides='c') + compare(c2.get('a'), expected='a') + compare(c2.get('b'), expected='b') + compare(c2.get('c'), expected='d') + compare(c1.get('a'), expected='a') + compare(c1.get('b'), expected=None) + compare(c1.get('c'), expected='c') From ff3e77aac6b1d526449e4c8403c5c7103d588729 Mon Sep 17 00:00:00 2001 From: Chris Withers Date: Sat, 22 Feb 2020 18:35:39 +0000 Subject: [PATCH 034/159] Switch to storing values for the common case. This stops the async infection from spreading further... --- mush/context.py | 44 +++++++++++++++++++++++++------------ mush/resolvers.py | 14 ------------ mush/tests/test_context.py | 28 +++++++++++------------ mush/tests/test_resolver.py | 17 +++++++------- 4 files changed, 53 insertions(+), 50 deletions(-) diff --git a/mush/context.py b/mush/context.py index db13292..039d6f7 100644 --- a/mush/context.py +++ b/mush/context.py @@ -4,7 +4,6 @@ extract_requires, RequiresType, ResourceKey, ResourceValue, Resolver ) from .markers import missing -from .resolvers import ValueResolver NONE_TYPE = type(None) @@ -60,6 +59,19 @@ def type_key(type_tuple): return type.__name__ +class ResolvableValue: + __slots__ = ('value', 'resolver') + + def __init__(self, value, resolver=None): + self.value = value + self.resolver = resolver + + def __repr__(self): + if self.resolver is None: + return repr(self.value) + return repr(self.resolver) + + class Context: "Stores resources for a particular run." @@ -87,9 +99,7 @@ def add(self, raise ValueError('Cannot add None to context') if provides in self._store: raise ContextError(f'Context already contains {provides!r}') - if resolver is None: - resolver = ValueResolver(resource) - self._store[provides] = resolver + self._store[provides] = ResolvableValue(resource, resolver) def remove(self, key: ResourceKey, *, strict: bool = True): """ @@ -135,23 +145,29 @@ def call(self, obj, requires=None): return obj(*args, **kw) - def get(self, key: ResourceKey, default=None): + def _get(self, key, default): context = self - resolver = None + resolvable = None - while resolver is None and context is not None: - resolver = context._store.get(key, None) - if resolver is None: + while resolvable is None and context is not None: + resolvable = context._store.get(key, None) + if resolvable is None: context = context._parent elif context is not self: - self._store[key] = resolver + self._store[key] = resolvable - if resolver is None: + if resolvable is None: if key is Context: - return self - return default + return ResolvableValue(self) + return ResolvableValue(default) - return resolver(self, default) + return resolvable + + def get(self, key: ResourceKey, default=None): + resolvable = self._get(key, default) + if resolvable.resolver is not None: + return resolvable.resolver(self, default) + return resolvable.value def nest(self): nested = type(self)() diff --git a/mush/resolvers.py b/mush/resolvers.py index 29bb77c..7220b8a 100644 --- a/mush/resolvers.py +++ b/mush/resolvers.py @@ -1,20 +1,6 @@ from .declarations import returns as returns_declaration -class ValueResolver: - - __slots__ = ['value'] - - def __init__(self, value): - self.value = value - - def __call__(self, context, default): - return self.value - - def __repr__(self): - return repr(self.value) - - class Lazy(object): def __init__(self, obj, requires, returns): diff --git a/mush/tests/test_context.py b/mush/tests/test_context.py index eebb69c..2e8d668 100644 --- a/mush/tests/test_context.py +++ b/mush/tests/test_context.py @@ -1,13 +1,13 @@ from unittest import TestCase from mock import Mock +from mush.context import ResolvableValue from testfixtures import ShouldRaise, compare from mush import Context, ContextError from mush.declarations import ( nothing, requires, item, attr, returns, returns_mapping, Requirement ) -from mush.resolvers import ValueResolver class TheType(object): @@ -22,7 +22,7 @@ def test_simple(self): context = Context() context.add(obj) - compare(context._store, expected={TheType: ValueResolver(obj)}) + compare(context._store, expected={TheType: ResolvableValue(obj)}) expected = ( ": \n" @@ -39,7 +39,7 @@ def test_type_as_string(self): expected = ("\n" "}>") - compare(context._store, expected={'my label': ValueResolver(obj)}) + compare(context._store, expected={'my label': ResolvableValue(obj)}) self.assertEqual(repr(context), expected) self.assertEqual(str(context), expected) @@ -48,7 +48,7 @@ class T2(object): pass obj = TheType() context = Context() context.add(obj, provides=T2) - compare(context._store, expected={T2: ValueResolver(obj)}) + compare(context._store, expected={T2: ResolvableValue(obj)}) expected = ("\n" "}>") @@ -113,7 +113,7 @@ def test_add_none(self): def test_add_none_with_type(self): context = Context() context.add(None, TheType) - compare(context._store, expected={TheType: ValueResolver(None)}) + compare(context._store, expected={TheType: ResolvableValue(None)}) def test_call_basic(self): def foo(): @@ -129,7 +129,7 @@ def foo(obj): context.add('bar', 'baz') result = context.call(foo, requires('baz')) compare(result, 'bar') - compare({'baz': ValueResolver('bar')}, actual=context._store) + compare({'baz': ResolvableValue('bar')}, actual=context._store) def test_call_requires_type(self): def foo(obj): @@ -138,7 +138,7 @@ def foo(obj): context.add('bar', TheType) result = context.call(foo, requires(TheType)) compare(result, 'bar') - compare({TheType: ValueResolver('bar')}, actual=context._store) + compare({TheType: ResolvableValue('bar')}, actual=context._store) def test_call_requires_missing(self): def foo(obj): return obj @@ -175,8 +175,8 @@ def foo(x, y): context.add('bar', 'baz') result = context.call(foo, requires(y='baz', x=TheType)) compare(result, ('foo', 'bar')) - compare({TheType: ValueResolver('foo'), - 'baz': ValueResolver('bar')}, + compare({TheType: ResolvableValue('foo'), + 'baz': ResolvableValue('bar')}, actual=context._store) def test_call_requires_optional_present(self): @@ -186,7 +186,7 @@ def foo(x=1): context.add(2, TheType) result = context.call(foo, requires(TheType)) compare(result, 2) - compare({TheType: ValueResolver(2)}, actual=context._store) + compare({TheType: ResolvableValue(2)}, actual=context._store) def test_call_requires_optional_missing(self): def foo(x: TheType = 1): @@ -210,7 +210,7 @@ def foo(x:'foo'=1): context.add(2, 'foo') result = context.call(foo) compare(result, 2) - compare({'foo': ValueResolver(2)}, actual=context._store) + compare({'foo': ResolvableValue(2)}, actual=context._store) def test_call_requires_item(self): def foo(x): @@ -275,7 +275,7 @@ def foo(): context = Context() result = context.extract(foo, nothing, returns(TheType)) compare(result, 'bar') - compare({TheType: ValueResolver('bar')}, actual=context._store) + compare({TheType: ResolvableValue('bar')}, actual=context._store) def test_returns_sequence(self): def foo(): @@ -283,7 +283,7 @@ def foo(): context = Context() result = context.extract(foo, nothing, returns('foo', 'bar')) compare(result, (1, 2)) - compare({'foo': ValueResolver(1), 'bar': ValueResolver(2)}, + compare({'foo': ResolvableValue(1), 'bar': ResolvableValue(2)}, actual=context._store) def test_returns_mapping(self): @@ -292,7 +292,7 @@ def foo(): context = Context() result = context.extract(foo, nothing, returns_mapping()) compare(result, {'foo': 1, 'bar': 2}) - compare({'foo': ValueResolver(1), 'bar': ValueResolver(2)}, + compare({'foo': ResolvableValue(1), 'bar': ResolvableValue(2)}, actual=context._store) def test_ignore_return(self): diff --git a/mush/tests/test_resolver.py b/mush/tests/test_resolver.py index 006c445..b30564a 100644 --- a/mush/tests/test_resolver.py +++ b/mush/tests/test_resolver.py @@ -1,21 +1,22 @@ +from mush.context import ResolvableValue from testfixtures import compare from mush import returns -from mush.resolvers import Lazy, ValueResolver +from mush.resolvers import Lazy from mush.markers import Marker foo = Marker('foo') -class TestValueResolver: +class TestLazy: def test_repr(self): - f = ValueResolver(foo) - compare(repr(f), expected='') + f = Lazy(foo, None, returns('foo')) + compare(repr(f), expected='>') -class TestFactory: +class TestResolvableValue: - def test_repr(self): - f = Lazy(foo, None, returns('foo')) - compare(repr(f), expected='>') + def test_repr_with_resolver(self): + compare(repr(ResolvableValue(None, foo)), + expected='') From e5897cdbcbb5faf87b5fba2a942f751ad7813354 Mon Sep 17 00:00:00 2001 From: Chris Withers Date: Sat, 22 Feb 2020 18:35:51 +0000 Subject: [PATCH 035/159] Add pytest-asyncio for testing. --- setup.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 8cea6ff..a44bb0c 100644 --- a/setup.py +++ b/setup.py @@ -31,6 +31,13 @@ include_package_data=True, python_requires='>=3.6', extras_require=dict( - test=['pytest', 'pytest-cov', 'mock', 'sybil', 'testfixtures>=6.13'], + test=[ + 'mock', + 'pytest', + 'pytest-asyncio', + 'pytest-cov', + 'sybil', + 'testfixtures>=6.13' + ], build=['sphinx', 'setuptools-git', 'wheel', 'twine'] )) From 0333af2dcdadab19c7912fe6b258a2f9551f3e21 Mon Sep 17 00:00:00 2001 From: Chris Withers Date: Sat, 22 Feb 2020 22:39:24 +0000 Subject: [PATCH 036/159] Using another module's marker is probably not going to end well in the long run. --- mush/declarations.py | 5 +++-- mush/markers.py | 5 +---- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/mush/declarations.py b/mush/declarations.py index eeed764..0f23561 100644 --- a/mush/declarations.py +++ b/mush/declarations.py @@ -276,8 +276,9 @@ def guess_requirements(obj): args = [] kw = {} for name, p in Signature.from_callable(obj).parameters.items(): - key = p.name if p.annotation is missing else p.annotation - requirement = Requirement(key, default=p.default) + key = p.name if p.annotation is p.empty else p.annotation + default = missing if p.default is p.empty else p.default + requirement = Requirement(key, default=default) if p.kind in {p.POSITIONAL_ONLY, p.POSITIONAL_OR_KEYWORD}: args.append(requirement) elif p.kind is p.KEYWORD_ONLY: diff --git a/mush/markers.py b/mush/markers.py index 9ac04f6..77d71e7 100644 --- a/mush/markers.py +++ b/mush/markers.py @@ -1,6 +1,3 @@ -from inspect import Parameter - - class Marker(object): def __init__(self, name): @@ -11,4 +8,4 @@ def __repr__(self): not_specified = Marker('not_specified') -missing = Parameter.empty +missing = Marker('missing') From 44dd97c604d4390d205268ca241bcb2c7f52e442 Mon Sep 17 00:00:00 2001 From: Chris Withers Date: Sun, 23 Feb 2020 11:42:14 +0000 Subject: [PATCH 037/159] Drop support for getting context from itself. This really makes no sense: you already have the context object! --- mush/context.py | 7 ++++--- mush/tests/test_context.py | 1 - 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/mush/context.py b/mush/context.py index 039d6f7..8635a4d 100644 --- a/mush/context.py +++ b/mush/context.py @@ -136,7 +136,10 @@ def call(self, obj, requires=None): for target, requirement in requires: o = requirement.resolve(self) if o is missing: - if requirement.default is missing: + key = requirement.key + if isinstance(key, type) and issubclass(key, Context): + o = self + elif requirement.default is missing: raise ContextError('No %s in context' % requirement.repr) if target is None: args.append(o) @@ -157,8 +160,6 @@ def _get(self, key, default): self._store[key] = resolvable if resolvable is None: - if key is Context: - return ResolvableValue(self) return ResolvableValue(default) return resolvable diff --git a/mush/tests/test_context.py b/mush/tests/test_context.py index 2e8d668..9bb7a05 100644 --- a/mush/tests/test_context.py +++ b/mush/tests/test_context.py @@ -315,7 +315,6 @@ def test_context_contains_itself(self): def return_context(context: Context): return context assert context.call(return_context) is context - assert context.get(Context) is context def test_remove(self): context = Context() From 8a98888f7b16ddd66fd24b7ad46b2e62a6c96420 Mon Sep 17 00:00:00 2001 From: Chris Withers Date: Sun, 23 Feb 2020 12:00:58 +0000 Subject: [PATCH 038/159] Implementation of an asyncio-compatible Context. --- mush/__init__.py | 3 +- mush/asyncio.py | 62 +++++++++++++ mush/context.py | 28 ++++-- mush/declarations.py | 8 -- mush/tests/test_async_context.py | 148 +++++++++++++++++++++++++++++++ 5 files changed, 233 insertions(+), 16 deletions(-) create mode 100644 mush/asyncio.py create mode 100644 mush/tests/test_async_context.py diff --git a/mush/__init__.py b/mush/__init__.py index cb5079b..7f75fda 100755 --- a/mush/__init__.py +++ b/mush/__init__.py @@ -6,9 +6,10 @@ ) from .plug import Plug from .context import Context, ContextError +from .asyncio import AsyncContext __all__ = [ - 'Context', 'ContextError', + 'Context', 'AsyncContext', 'ContextError', 'Runner', 'requires', 'returns_result_type', 'returns_mapping', 'returns_sequence', 'returns', diff --git a/mush/asyncio.py b/mush/asyncio.py new file mode 100644 index 0000000..9ef33c5 --- /dev/null +++ b/mush/asyncio.py @@ -0,0 +1,62 @@ +import asyncio +from functools import partial + +from mush import Context, ContextError +from mush.declarations import ResourceKey, RequiresType, extract_requires +from mush.markers import missing + + +async def ensure_async(func, *args, **kw): + if asyncio.iscoroutinefunction(func): + return await func(*args, **kw) + if kw: + func = partial(func, **kw) + loop = asyncio.get_event_loop() + return await loop.run_in_executor(None, func, *args) + + +class SyncContext: + + def __init__(self, context, loop): + self.context = context + self.loop = loop + + def get(self, key: ResourceKey, default=None): + coro = self.context.get(key, default) + future = asyncio.run_coroutine_threadsafe(coro, self.loop) + return future.result() + + +class AsyncContext(Context): + + def __init__(self): + super().__init__() + self._sync_context = SyncContext(self, asyncio.get_event_loop()) + + async def get(self, key: ResourceKey, default=None): + resolvable = self._get(key, default) + if resolvable.resolver is not None: + if asyncio.iscoroutinefunction(resolvable.resolver): + context = self + else: + context = self._sync_context + return await ensure_async(resolvable.resolver, context, default) + return resolvable.value + + async def call(self, obj, requires=None): + args = [] + kw = {} + resolving = self._resolve( + obj, requires, args, kw, + self if asyncio.iscoroutinefunction(obj) else self._sync_context + ) + for requirement in resolving: + o = await self.get(requirement.key, requirement.default) + resolving.send(o) + return await ensure_async(obj, *args, **kw) + + async def extract(self, obj, requires, returns): + result = await self.call(obj, requires) + for type, obj in returns.process(result): + self.add(obj, type) + return result diff --git a/mush/context.py b/mush/context.py index 8635a4d..88a0d41 100644 --- a/mush/context.py +++ b/mush/context.py @@ -126,26 +126,40 @@ def extract(self, obj, requires, returns): self.add(obj, type) return result - def call(self, obj, requires=None): + @staticmethod + def _resolve(obj, requires, args, kw, context): + if requires.__class__ is not RequiresType: requires = extract_requires(obj, requires) - args = [] - kw = {} - for target, requirement in requires: - o = requirement.resolve(self) + o = yield requirement + + if o is not requirement.default: + for op in requirement.ops: + o = op(o) + if o is missing: key = requirement.key if isinstance(key, type) and issubclass(key, Context): - o = self - elif requirement.default is missing: + o = context + else: raise ContextError('No %s in context' % requirement.repr) + if target is None: args.append(o) else: kw[target] = o + yield + + def call(self, obj, requires=None): + args = [] + kw = {} + resolving = self._resolve(obj, requires, args, kw, self) + for requirement in resolving: + o = self.get(requirement.key, requirement.default) + resolving.send(o) return obj(*args, **kw) def _get(self, key, default): diff --git a/mush/declarations.py b/mush/declarations.py index 0f23561..fa920c6 100644 --- a/mush/declarations.py +++ b/mush/declarations.py @@ -39,14 +39,6 @@ def __init__(self, source, default=missing): source = source.type self.key: ResourceKey = source - def resolve(self, context): - o = context.get(self.key, missing) - if o is missing: - return self.default - for op in self.ops: - o = op(o) - return o - def __repr__(self): return f'Requirement({self.repr}, default={self.default})' diff --git a/mush/tests/test_async_context.py b/mush/tests/test_async_context.py new file mode 100644 index 0000000..6c88559 --- /dev/null +++ b/mush/tests/test_async_context.py @@ -0,0 +1,148 @@ +import asyncio +import pytest + +from mush import AsyncContext, Context, requires, returns +from testfixtures import compare + + +@pytest.mark.asyncio +async def test_get_is_async(): + context = AsyncContext() + result = context.get('foo', default='bar') + assert asyncio.iscoroutine(result) + compare(await result, expected='bar') + + +@pytest.mark.asyncio +async def test_get_async_resolver(): + async def resolver(*args): + return 'bar' + context = AsyncContext() + context.add(provides='foo', resolver=resolver) + compare(await context.get('foo'), expected='bar') + + +@pytest.mark.asyncio +async def test_get_async_resolver_calls_back_into_async(): + async def resolver(context, default): + return await context.get('baz') + context = AsyncContext() + context.add('bar', provides='baz') + context.add(provides='foo', resolver=resolver) + compare(await context.get('foo'), expected='bar') + + +@pytest.mark.asyncio +async def test_get_sync_resolver(): + def resolver(*args): + return 'bar' + context = AsyncContext() + context.add(provides='foo', resolver=resolver) + compare(await context.get('foo'), expected='bar') + + +@pytest.mark.asyncio +async def test_get_sync_resolver_calls_back_into_async(): + def resolver(context, default): + return context.get('baz') + context = AsyncContext() + context.add('bar', provides='baz') + context.add(provides='foo', resolver=resolver) + compare(await context.get('foo'), expected='bar') + + +@pytest.mark.asyncio +async def test_call_is_async(): + context = AsyncContext() + def it(): + return 'bar' + result = context.call(it) + assert asyncio.iscoroutine(result) + compare(await result, expected='bar') + + +@pytest.mark.asyncio +async def test_call_async(): + context = AsyncContext() + context.add('1', provides='a') + async def it(a, b='2'): + return a+b + compare(await context.call(it), expected='12') + + +@pytest.mark.asyncio +async def test_call_async_requires_context(): + context = AsyncContext() + context.add('bar', provides='baz') + async def it(context: Context): + return await context.get('baz') + compare(await context.call(it), expected='bar') + + +@pytest.mark.asyncio +async def test_call_async_requires_async_context(): + context = AsyncContext() + context.add('bar', provides='baz') + async def it(context: AsyncContext): + return await context.get('baz') + compare(await context.call(it), expected='bar') + + +@pytest.mark.asyncio +async def test_call_sync(): + context = AsyncContext() + context.add('foo', provides='baz') + def it(*, baz): + return baz+'bar' + compare(await context.call(it), expected='foobar') + + +@pytest.mark.asyncio +async def test_call_sync_requires_context(): + context = AsyncContext() + context.add('bar', provides='baz') + def it(context: Context): + return context.get('baz') + compare(await context.call(it), expected='bar') + + +@pytest.mark.asyncio +async def test_call_sync_requires_async_context(): + context = AsyncContext() + context.add('bar', provides='baz') + def it(context: AsyncContext): + return context.get('baz') + compare(await context.call(it), expected='bar') + + +@pytest.mark.asyncio +async def test_extract_is_async(): + context = AsyncContext() + def it(): + return 'bar' + result = context.extract(it, requires(), returns('baz')) + assert asyncio.iscoroutine(result) + compare(await result, expected='bar') + compare(await context.get('baz'), expected='bar') + + +@pytest.mark.asyncio +async def test_extract_async(): + context = AsyncContext() + context.add('foo', provides='bar') + async def it(context): + return await context.get('bar')+'bar' + result = context.extract(it, requires(Context), returns('baz')) + compare(await result, expected='foobar') + compare(await context.get('baz'), expected='foobar') + + +@pytest.mark.asyncio +async def test_extract_sync(): + context = AsyncContext() + context.add('foo', provides='bar') + def it(context): + return context.get('bar')+'bar' + result = context.extract(it, requires(Context), returns('baz')) + compare(await result, expected='foobar') + compare(await context.get('baz'), expected='foobar') From 00ee2083f6a09371325327fc01445809d012e3d0 Mon Sep 17 00:00:00 2001 From: Chris Withers Date: Sun, 23 Feb 2020 17:53:19 +0000 Subject: [PATCH 039/159] Add support for custom requirement classes. --- mush/asyncio.py | 23 ++++++++++++----------- mush/context.py | 5 ++++- mush/declarations.py | 9 +++++++-- mush/tests/test_async_context.py | 31 +++++++++++++++++++++++++++++++ mush/tests/test_context.py | 29 ++++++++++++++++++++++++++++- 5 files changed, 82 insertions(+), 15 deletions(-) diff --git a/mush/asyncio.py b/mush/asyncio.py index 9ef33c5..5e0c8fc 100644 --- a/mush/asyncio.py +++ b/mush/asyncio.py @@ -33,25 +33,26 @@ def __init__(self): super().__init__() self._sync_context = SyncContext(self, asyncio.get_event_loop()) + def _context_for(self, obj): + return self if asyncio.iscoroutinefunction(obj) else self._sync_context + async def get(self, key: ResourceKey, default=None): resolvable = self._get(key, default) - if resolvable.resolver is not None: - if asyncio.iscoroutinefunction(resolvable.resolver): - context = self - else: - context = self._sync_context - return await ensure_async(resolvable.resolver, context, default) + r = resolvable.resolver + if r is not None: + return await ensure_async(r, self._context_for(r), default) return resolvable.value async def call(self, obj, requires=None): args = [] kw = {} - resolving = self._resolve( - obj, requires, args, kw, - self if asyncio.iscoroutinefunction(obj) else self._sync_context - ) + resolving = self._resolve(obj, requires, args, kw, self._context_for(obj)) for requirement in resolving: - o = await self.get(requirement.key, requirement.default) + r = requirement.resolve + if r is not None: + o = await ensure_async(r, self._context_for(r)) + else: + o = await self.get(requirement.key, requirement.default) resolving.send(o) return await ensure_async(obj, *args, **kw) diff --git a/mush/context.py b/mush/context.py index 88a0d41..7ed70d2 100644 --- a/mush/context.py +++ b/mush/context.py @@ -158,7 +158,10 @@ def call(self, obj, requires=None): kw = {} resolving = self._resolve(obj, requires, args, kw, self) for requirement in resolving: - o = self.get(requirement.key, requirement.default) + if requirement.resolve: + o = requirement.resolve(self) + else: + o = self.get(requirement.key, requirement.default) resolving.send(o) return obj(*args, **kw) diff --git a/mush/declarations.py b/mush/declarations.py index fa920c6..e96db35 100644 --- a/mush/declarations.py +++ b/mush/declarations.py @@ -28,6 +28,8 @@ def set_mush(obj, key, value): class Requirement: + resolve = None + def __init__(self, source, default=missing): self.repr = name_or_repr(source) @@ -269,8 +271,11 @@ def guess_requirements(obj): kw = {} for name, p in Signature.from_callable(obj).parameters.items(): key = p.name if p.annotation is p.empty else p.annotation - default = missing if p.default is p.empty else p.default - requirement = Requirement(key, default=default) + if isinstance(p.annotation, Requirement): + requirement = p.annotation + else: + default = missing if p.default is p.empty else p.default + requirement = Requirement(key, default=default) if p.kind in {p.POSITIONAL_ONLY, p.POSITIONAL_OR_KEYWORD}: args.append(requirement) elif p.kind is p.KEYWORD_ONLY: diff --git a/mush/tests/test_async_context.py b/mush/tests/test_async_context.py index 6c88559..6b337a7 100644 --- a/mush/tests/test_async_context.py +++ b/mush/tests/test_async_context.py @@ -2,6 +2,7 @@ import pytest from mush import AsyncContext, Context, requires, returns +from mush.declarations import Requirement from testfixtures import compare @@ -146,3 +147,33 @@ def it(context): result = context.extract(it, requires(Context), returns('baz')) compare(await result, expected='foobar') compare(await context.get('baz'), expected='foobar') + + +@pytest.mark.asyncio +async def test_custom_requirement_async_resolve(): + + class FromRequest(Requirement): + async def resolve(self, context): + return (await context.get('request'))[self.key] + + def foo(bar: FromRequest('bar')): + return bar + + context = AsyncContext() + context.add({'bar': 'foo'}, provides='request') + compare(await context.call(foo), expected='foo') + + +@pytest.mark.asyncio +async def test_custom_requirement_sync_resolve(): + + class FromRequest(Requirement): + def resolve(self, context): + return context.get('request')[self.key] + + def foo(bar: FromRequest('bar')): + return bar + + context = AsyncContext() + context.add({'bar': 'foo'}, provides='request') + compare(await context.call(foo), expected='foo') diff --git a/mush/tests/test_context.py b/mush/tests/test_context.py index 9bb7a05..c0c6d1f 100644 --- a/mush/tests/test_context.py +++ b/mush/tests/test_context.py @@ -6,7 +6,7 @@ from mush import Context, ContextError from mush.declarations import ( - nothing, requires, item, attr, returns, returns_mapping, Requirement + nothing, requires, item, attr, returns, returns_mapping, Requirement, missing ) @@ -355,3 +355,30 @@ def test_nest(self): compare(c1.get('a'), expected='a') compare(c1.get('b'), expected=None) compare(c1.get('c'), expected='c') + + def test_custom_requirement(self): + + class FromRequest(Requirement): + def resolve(self, context): + return context.get('request')[self.key] + + def foo(bar: FromRequest('bar')): + return bar + + context = Context() + context.add({'bar': 'foo'}, provides='request') + compare(context.call(foo), expected='foo') + + def test_custom_requirement_returns_missing(self): + + class FromRequest(Requirement): + def resolve(self, context): + return context.get('request').get(self.key, missing) + + def foo(bar: FromRequest('bar')): + pass + + context = Context(default_requirement_type=FromRequest) + context.add({}, provides='request') + with ShouldRaise(ContextError("No 'bar' in context")): + compare(context.call(foo)) From fb955bf511fc06f788a49fb0edf6eaf7c9e8740b Mon Sep 17 00:00:00 2001 From: Chris Withers Date: Mon, 24 Feb 2020 18:41:01 +0000 Subject: [PATCH 040/159] Add support for requirements being explicitly specified and combined. Different parameters can have their requirements explicitly specified. --- docs/use.txt | 2 +- mush/callpoints.py | 4 +- mush/context.py | 6 +- mush/declarations.py | 103 +++++++++++++++++++++----------- mush/tests/test_callpoints.py | 24 ++++---- mush/tests/test_declarations.py | 47 ++++++++++++--- mush/tests/test_runner.py | 4 +- 7 files changed, 126 insertions(+), 64 deletions(-) diff --git a/docs/use.txt b/docs/use.txt index 4ac7300..23d60c7 100755 --- a/docs/use.txt +++ b/docs/use.txt @@ -951,7 +951,7 @@ If you have a base runner such as this: def parse_args(parser): return parser.parse_args() - def load_config(): + def load_config(config_url): return json.loads(urllib2.urlopen('...').read()) def finalise_things(): diff --git a/mush/callpoints.py b/mush/callpoints.py index 3782c10..2fda5ee 100644 --- a/mush/callpoints.py +++ b/mush/callpoints.py @@ -1,5 +1,5 @@ from .context import Context -from .declarations import nothing, extract_requires, extract_returns +from .declarations import nothing, extract_requires, extract_returns, RequiresType from .resolvers import Lazy @@ -13,7 +13,7 @@ def __init__(self, obj, requires=None, returns=None, lazy=False): returns = extract_returns(obj, returns) if lazy: obj = Lazy(obj, requires, returns) - requires = requires(Context) + requires = RequiresType(Context) returns = nothing self.obj = obj self.requires = requires diff --git a/mush/context.py b/mush/context.py index 7ed70d2..5835f61 100644 --- a/mush/context.py +++ b/mush/context.py @@ -132,7 +132,7 @@ def _resolve(obj, requires, args, kw, context): if requires.__class__ is not RequiresType: requires = extract_requires(obj, requires) - for target, requirement in requires: + for requirement in requires: o = yield requirement if o is not requirement.default: @@ -146,10 +146,10 @@ def _resolve(obj, requires, args, kw, context): else: raise ContextError('No %s in context' % requirement.repr) - if target is None: + if requirement.target is None: args.append(o) else: - kw[target] = o + kw[requirement.target] = o yield diff --git a/mush/declarations.py b/mush/declarations.py index e96db35..74db73a 100644 --- a/mush/declarations.py +++ b/mush/declarations.py @@ -5,7 +5,7 @@ WRAPPER_ASSIGNMENTS as FUNCTOOLS_ASSIGNMENTS, update_wrapper as functools_update_wrapper, ) -from inspect import Signature +from inspect import signature from itertools import chain from typing import Type, Callable, NewType, Union, Any @@ -30,9 +30,9 @@ class Requirement: resolve = None - def __init__(self, source, default=missing): + def __init__(self, source, default=missing, target=None): self.repr = name_or_repr(source) - + self.target = target self.default = default self.ops = deque() @@ -42,7 +42,7 @@ def __init__(self, source, default=missing): self.key: ResourceKey = source def __repr__(self): - return f'Requirement({self.repr}, default={self.default})' + return f'{type(self).__name__}({self.repr}, default={self.default})' class RequiresType(list): @@ -65,13 +65,15 @@ def __init__(self, *args, **kw): ((None, arg) for arg in args), kw.items(), ): - if not isinstance(requirement, Requirement): - requirement = Requirement(requirement) - self.append((target, requirement)) + if isinstance(requirement, Requirement): + requirement.target = target + else: + requirement = Requirement(requirement, target=target) + self.append(requirement) def __repr__(self): - parts = (r.repr if t is None else f'{t}={r.repr}' - for (t, r) in self) + parts = (r.repr if r.target is None else f'{r.target}={r.repr}' + for r in self) return f"requires({', '.join(parts)})" def __call__(self, obj): @@ -266,40 +268,71 @@ def process(self, result): result_type = returns_result_type() -def guess_requirements(obj): - args = [] - kw = {} - for name, p in Signature.from_callable(obj).parameters.items(): - key = p.name if p.annotation is p.empty else p.annotation - if isinstance(p.annotation, Requirement): +def _unpack_requires(by_name, by_index, requires_): + + for i, requirement in enumerate(requires_): + if requirement.target is None: + try: + arg = by_index[i] + except IndexError: + # case where something takes *args + arg = i + else: + arg = requirement.target + by_name[arg] = requirement + + +def extract_requires(obj, explicit=None): + # from annotations + by_name = {} + for name, p in signature(obj).parameters.items(): + if p.kind in (p.VAR_POSITIONAL, p.VAR_KEYWORD): + continue + + if isinstance(p.default, Requirement): + requirement = p.default + elif isinstance(p.annotation, Requirement): requirement = p.annotation else: + key = p.name if p.annotation is p.empty else p.annotation default = missing if p.default is p.empty else p.default requirement = Requirement(key, default=default) - if p.kind in {p.POSITIONAL_ONLY, p.POSITIONAL_OR_KEYWORD}: - args.append(requirement) - elif p.kind is p.KEYWORD_ONLY: - kw[name] = requirement - return requires(*args, **kw) + if p.kind is p.KEYWORD_ONLY: + requirement.target = p.name + by_name[name] = requirement + + by_index = list(by_name) + + # from declarations + mush_declarations = getattr(obj, '__mush__', None) + if mush_declarations is not None: + requires_ = mush_declarations.get('requires') + if requires_ is not None: + _unpack_requires(by_name, by_index, requires_) + + # explicit + if explicit is not None: + if isinstance(explicit, (list, tuple)): + requires_ = requires(*explicit) + elif not isinstance(explicit, requires): + requires_ = requires(explicit) + else: + requires_ = explicit + _unpack_requires(by_name, by_index, requires_) -def extract_requires(obj, explicit=None): - if explicit is None: - mush_declarations = getattr(obj, '__mush__', {}) - requires_ = mush_declarations.get('requires', None) - if requires_ is None: - requires_ = guess_requirements(obj) - else: - requires_ = explicit + if not by_name: + return nothing - if isinstance(requires_, requires): - pass - elif isinstance(requires_, (list, tuple)): - requires_ = requires(*requires_) - else: - requires_ = requires(requires_) + args = [] + kw = {} + for requirement in by_name.values(): + if requirement.target is None: + args.append(requirement) + else: + kw[requirement.target] = requirement - return requires_ or nothing + return requires(*args, **kw) def extract_returns(obj: Callable, explicit: ReturnsType = None): diff --git a/mush/tests/test_callpoints.py b/mush/tests/test_callpoints.py index 8dd0995..1b739bc 100644 --- a/mush/tests/test_callpoints.py +++ b/mush/tests/test_callpoints.py @@ -21,12 +21,12 @@ def test_passive_attributes(self): compare(point.labels, set()) def test_supplied_explicitly(self): - obj = object() + def foo(a1): pass rq = requires('foo') rt = returns('bar') - result = CallPoint(obj, rq, rt)(self.context) + result = CallPoint(foo, rq, rt)(self.context) compare(result, self.context.extract.return_value) - self.context.extract.assert_called_with(obj, rq, rt) + self.context.extract.assert_called_with(foo, rq, rt) def test_extract_from_decorations(self): rq = requires('foo') @@ -34,7 +34,7 @@ def test_extract_from_decorations(self): @rq @rt - def foo(): pass + def foo(a1): pass result = CallPoint(foo)(self.context) compare(result, self.context.extract.return_value) @@ -49,7 +49,7 @@ class Wrapper(object): def __init__(self, func): self.func = func def __call__(self): - return 'the '+self.func() + return self.func('the ') def my_dec(func): return update_wrapper(Wrapper(func), func) @@ -57,8 +57,8 @@ def my_dec(func): @my_dec @rq @rt - def foo(): - return 'answer' + def foo(prefix): + return prefix+'answer' self.context.extract.side_effect = lambda func, rq, rt: (func(), rq, rt) result = CallPoint(foo)(self.context) @@ -67,7 +67,7 @@ def foo(): def test_explicit_trumps_decorators(self): @requires('foo') @returns('bar') - def foo(): pass + def foo(a1): pass rq = requires('baz') rt = returns('bob') @@ -82,7 +82,7 @@ def foo(): pass compare(repr(foo)+" requires() returns_result_type()", repr(point)) def test_repr_maximal(self): - def foo(): pass + def foo(a1): pass point = CallPoint(foo, requires('foo'), returns('bar')) point.labels.add('baz') point.labels.add('bob') @@ -90,7 +90,7 @@ def foo(): pass repr(point)) def test_convert_to_requires_and_returns(self): - def foo(): pass + def foo(baz): pass point = CallPoint(foo, requires='foo', returns='bar') self.assertTrue(isinstance(point.requires, requires)) self.assertTrue(isinstance(point.returns, returns)) @@ -98,7 +98,7 @@ def foo(): pass repr(point)) def test_convert_to_requires_and_returns_tuple(self): - def foo(): pass + def foo(a1, a2): pass point = CallPoint(foo, requires=('foo', 'bar'), returns=('baz', 'bob')) @@ -108,7 +108,7 @@ def foo(): pass repr(point)) def test_convert_to_requires_and_returns_list(self): - def foo(): pass + def foo(a1, a2): pass point = CallPoint(foo, requires=['foo', 'bar'], returns=['baz', 'bob']) diff --git a/mush/tests/test_declarations.py b/mush/tests/test_declarations.py index 0cd988f..36290fa 100644 --- a/mush/tests/test_declarations.py +++ b/mush/tests/test_declarations.py @@ -37,20 +37,20 @@ def test_types(self): r = requires(Type1, Type2, x=Type3, y=Type4) compare(repr(r), 'requires(Type1, Type2, x=Type3, y=Type4)') compare(r, expected=[ - (None, Requirement(Type1)), - (None, Requirement(Type2)), - ('x', Requirement(Type3)), - ('y', Requirement(Type4)), + Requirement(Type1), + Requirement(Type2), + Requirement(Type3, target='x'), + Requirement(Type4, target='y'), ]) def test_strings(self): r = requires('1', '2', x='3', y='4') compare(repr(r), "requires('1', '2', x='3', y='4')") compare(r, expected=[ - (None, Requirement('1')), - (None, Requirement('2')), - ('x', Requirement('3')), - ('y', Requirement('4')), + Requirement('1'), + Requirement('2'), + Requirement('3', target='x'), + Requirement('4', target='y'), ]) def test_tuple_arg(self): @@ -66,7 +66,7 @@ def test_decorator_paranoid(self): def foo(): return 'bar' - compare(foo.__mush__['requires'], expected=[(None, Requirement(Type1))]) + compare(foo.__mush__['requires'], expected=[Requirement(Type1)]) compare(foo(), 'bar') @@ -400,3 +400,32 @@ def foo(a, b=1, *, c, d=None): pass c='c', d=Requirement('d', default=None)), expected_rt=result_type) + + +class TestDeclarationsFromMultipleSources: + + def test_declarations_from_different_sources(self): + r1 = Requirement('a') + r2 = Requirement('b') + r3 = Requirement('c') + + @requires(b=r2) + def foo(a: r1, b, c=r3): + pass + + check_extract(foo, + expected_rq=requires(r1, b=r2, c=r3), + expected_rt=result_type) + + def test_declaration_priorities(self): + r1 = Requirement('a') + r2 = Requirement('b') + r3 = Requirement('c') + + @requires(a=r1) + def foo(a: r2 = r3, b: str = r2, c = r3): + pass + + check_extract(foo, + expected_rq=requires(r1, b=r2, c=r3), + expected_rt=result_type) diff --git a/mush/tests/test_runner.py b/mush/tests/test_runner.py index 015df31..7d6ccc5 100644 --- a/mush/tests/test_runner.py +++ b/mush/tests/test_runner.py @@ -530,7 +530,7 @@ def job3(arg): pass # pragma: nocover def job4(): pass - def job5(): pass + def job5(foo, bar): pass runner = Runner() runner.add(job1, label='1') @@ -556,7 +556,7 @@ def job5(): pass '', 'Still to call:', repr(job4)+' requires() returns_result_type() <-- 4', - repr(job5)+" requires('foo', bar='baz') returns('bob')", + repr(job5)+" requires('foo', 'baz') returns('bob')", )) compare(text, actual=repr(s.raised)) compare(text, actual=str(s.raised)) From 2af74b800ae426f76fc0ffd8df48bd3a639ca932 Mon Sep 17 00:00:00 2001 From: Chris Withers Date: Tue, 25 Feb 2020 07:48:54 +0000 Subject: [PATCH 041/159] Move logic out of RequiresType constructor. This simplified use in extract_requires. --- mush/callpoints.py | 7 ++-- mush/declarations.py | 66 +++++++++++++++-------------------- mush/tests/test_callpoints.py | 8 ++--- 3 files changed, 37 insertions(+), 44 deletions(-) diff --git a/mush/callpoints.py b/mush/callpoints.py index 2fda5ee..25c683f 100644 --- a/mush/callpoints.py +++ b/mush/callpoints.py @@ -1,5 +1,8 @@ from .context import Context -from .declarations import nothing, extract_requires, extract_returns, RequiresType +from .declarations import ( + nothing, extract_requires, extract_returns, + requires as requires_function +) from .resolvers import Lazy @@ -13,7 +16,7 @@ def __init__(self, obj, requires=None, returns=None, lazy=False): returns = extract_returns(obj, returns) if lazy: obj = Lazy(obj, requires, returns) - requires = RequiresType(Context) + requires = requires_function(Context) returns = nothing self.obj = obj self.requires = requires diff --git a/mush/declarations.py b/mush/declarations.py index 74db73a..020d112 100644 --- a/mush/declarations.py +++ b/mush/declarations.py @@ -46,31 +46,6 @@ def __repr__(self): class RequiresType(list): - """ - Represents requirements for a particular callable. - - The passed in `args` and `kw` should map to the types, including - any required :class:`~.declarations.how`, for the matching - arguments or keyword parameters the callable requires. - - String names for resources must be used instead of types where the callable - returning those resources is configured to return the named resource. - """ - - def __init__(self, *args, **kw): - super().__init__() - check_type(*args) - check_type(*kw.values()) - for target, requirement in chain( - ((None, arg) for arg in args), - kw.items(), - ): - if isinstance(requirement, Requirement): - requirement.target = target - else: - requirement = Requirement(requirement, target=target) - self.append(requirement) - def __repr__(self): parts = (r.repr if r.target is None else f'{r.target}={r.repr}' for r in self) @@ -81,7 +56,30 @@ def __call__(self, obj): return obj -requires = RequiresType +def requires(*args, **kw): + """ + Represents requirements for a particular callable. + + The passed in ``args`` and ``kw`` should map to the types, including + any required :class:`~.declarations.how`, for the matching + arguments or keyword parameters the callable requires. + + String names for resources must be used instead of types where the callable + returning those resources is configured to return the named resource. + """ + requires_ = RequiresType() + check_type(*args) + check_type(*kw.values()) + for target, requirement in chain( + ((None, arg) for arg in args), + kw.items(), + ): + if isinstance(requirement, Requirement): + requirement.target = target + else: + requirement = Requirement(requirement, target=target) + requires_.append(requirement) + return requires_ class ReturnsType(object): @@ -252,7 +250,7 @@ def check_type(*objs): ) -class Nothing(requires, returns): +class Nothing(RequiresType, returns): def process(self, result): return () @@ -282,7 +280,7 @@ def _unpack_requires(by_name, by_index, requires_): by_name[arg] = requirement -def extract_requires(obj, explicit=None): +def extract_requires(obj: Callable, explicit=None): # from annotations by_name = {} for name, p in signature(obj).parameters.items(): @@ -315,7 +313,7 @@ def extract_requires(obj, explicit=None): if explicit is not None: if isinstance(explicit, (list, tuple)): requires_ = requires(*explicit) - elif not isinstance(explicit, requires): + elif not isinstance(explicit, RequiresType): requires_ = requires(explicit) else: requires_ = explicit @@ -324,15 +322,7 @@ def extract_requires(obj, explicit=None): if not by_name: return nothing - args = [] - kw = {} - for requirement in by_name.values(): - if requirement.target is None: - args.append(requirement) - else: - kw[requirement.target] = requirement - - return requires(*args, **kw) + return RequiresType(by_name.values()) def extract_returns(obj: Callable, explicit: ReturnsType = None): diff --git a/mush/tests/test_callpoints.py b/mush/tests/test_callpoints.py index 1b739bc..f6bec52 100644 --- a/mush/tests/test_callpoints.py +++ b/mush/tests/test_callpoints.py @@ -5,7 +5,7 @@ from testfixtures import compare from mush.callpoints import CallPoint -from mush.declarations import requires, returns, update_wrapper +from mush.declarations import requires, returns, update_wrapper, RequiresType class TestCallPoints(TestCase): @@ -92,7 +92,7 @@ def foo(a1): pass def test_convert_to_requires_and_returns(self): def foo(baz): pass point = CallPoint(foo, requires='foo', returns='bar') - self.assertTrue(isinstance(point.requires, requires)) + self.assertTrue(isinstance(point.requires, RequiresType)) self.assertTrue(isinstance(point.returns, returns)) compare(repr(foo)+" requires('foo') returns('bar')", repr(point)) @@ -102,7 +102,7 @@ def foo(a1, a2): pass point = CallPoint(foo, requires=('foo', 'bar'), returns=('baz', 'bob')) - self.assertTrue(isinstance(point.requires, requires)) + self.assertTrue(isinstance(point.requires, RequiresType)) self.assertTrue(isinstance(point.returns, returns)) compare(repr(foo)+" requires('foo', 'bar') returns('baz', 'bob')", repr(point)) @@ -112,7 +112,7 @@ def foo(a1, a2): pass point = CallPoint(foo, requires=['foo', 'bar'], returns=['baz', 'bob']) - self.assertTrue(isinstance(point.requires, requires)) + self.assertTrue(isinstance(point.requires, RequiresType)) self.assertTrue(isinstance(point.returns, returns)) compare(repr(foo)+" requires('foo', 'bar') returns('baz', 'bob')", repr(point)) From a322e9273497121fd6b53c597ed8e03794f2d94a Mon Sep 17 00:00:00 2001 From: Chris Withers Date: Tue, 25 Feb 2020 08:10:25 +0000 Subject: [PATCH 042/159] clarify types of resolver --- mush/context.py | 4 ++-- mush/declarations.py | 5 +++-- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/mush/context.py b/mush/context.py index 5835f61..af35e34 100644 --- a/mush/context.py +++ b/mush/context.py @@ -1,7 +1,7 @@ from typing import Optional from .declarations import ( - extract_requires, RequiresType, ResourceKey, ResourceValue, Resolver + extract_requires, RequiresType, ResourceKey, ResourceValue, ResourceResolver ) from .markers import missing @@ -83,7 +83,7 @@ def __init__(self): def add(self, resource: Optional[ResourceValue] = None, provides: Optional[ResourceKey] = None, - resolver: Optional[Resolver] = None): + resolver: Optional[ResourceResolver] = None): """ Add a resource to the context. diff --git a/mush/declarations.py b/mush/declarations.py index 020d112..a106e0c 100644 --- a/mush/declarations.py +++ b/mush/declarations.py @@ -13,7 +13,8 @@ ResourceKey = NewType('ResourceKey', Union[Type, str]) ResourceValue = NewType('ResourceValue', Any) -Resolver = Callable[['Context', Any], ResourceValue] +ResourceResolver = Callable[['Context', Any], ResourceValue] +RequirementResolver = Callable[['Context'], ResourceValue] def name_or_repr(obj): @@ -28,7 +29,7 @@ def set_mush(obj, key, value): class Requirement: - resolve = None + resolve: RequirementResolver = None def __init__(self, source, default=missing, target=None): self.repr = name_or_repr(source) From 70fa1adba20c1418c9537b59b19a4e1d183e3cb4 Mon Sep 17 00:00:00 2001 From: Chris Withers Date: Tue, 25 Feb 2020 13:21:42 +0000 Subject: [PATCH 043/159] Replace all the hows with a more flexible Value. This also allows Requirement to be simplified down even further. --- docs/use.txt | 27 ++-- mush/__init__.py | 5 +- mush/context.py | 5 +- mush/declarations.py | 156 +++++++++++++----------- mush/tests/conftest.py | 13 -- mush/tests/example_with_mush_clone.py | 12 +- mush/tests/example_with_mush_factory.py | 12 +- mush/tests/test_context.py | 26 ++-- mush/tests/test_declarations.py | 117 ++++++++++++------ mush/tests/test_runner.py | 22 ++-- 10 files changed, 217 insertions(+), 178 deletions(-) delete mode 100644 mush/tests/conftest.py diff --git a/docs/use.txt b/docs/use.txt index 23d60c7..0db95ea 100755 --- a/docs/use.txt +++ b/docs/use.txt @@ -274,12 +274,12 @@ of the returns resources through to the :func:`pick` function: .. code-block:: python - from mush import Runner, attr, item, requires + from mush import Runner, requires, Value runner = Runner(some_attributes, some_items) - runner.add(pick, requires(fruit1=attr(Stuff, 'fruit'), - fruit2=item(dict, 'fruit'), - fruit3=item(attr(Stuff, 'tree'), 'fruit'))) + runner.add(pick, requires(fruit1=Value(Stuff).fruit, + fruit2=Value(dict)['fruit'], + fruit3=Value(Stuff).tree['fruit'])) So now we can pick fruit from some interesting places: @@ -919,11 +919,13 @@ a remote web service: .. code-block:: python + from mush import Runner, Value + def load_config() -> 'config': return json.loads(urllib2.urlopen('...').read()) - def do_stuff(username: item('config', 'username'), - password: item('config', 'password')): + def do_stuff(username: Value('config')['username'], + password: Value('config')['password']): print('doing stuff as ' + username + ' with '+ password) runner = Runner(load_config, do_stuff) @@ -944,6 +946,7 @@ If you have a base runner such as this: .. code-block:: python from argparse import ArgumentParser, Namespace + from mush import Runner, Value def base_args(parser): parser.add_argument('config_url') @@ -960,7 +963,7 @@ If you have a base runner such as this: base_runner = Runner(ArgumentParser) base_runner.add(base_args, requires=ArgumentParser, label='args') base_runner.add(parse_args, requires=ArgumentParser) - point = base_runner.add(load_config, requires=attr(Namespace, 'config_url'), + point = base_runner.add(load_config, requires=Value(Namespace).config_url, returns='config') point.add_label('body') base_runner.add(finalise_things, label='ending') @@ -972,8 +975,8 @@ That runner might be used for a specific script as follows: def job_args(parser: ArgumentParser): parser.add_argument('--colour') - def do_stuff(username: item('config', 'username'), - colour: attr(Namespace, 'colour')): + def do_stuff(username: Value('config')['username'], + colour: Value(Namespace).colour): print(username + ' is '+ colour) runner = base_runner.clone() @@ -1014,12 +1017,12 @@ For example, consider this runner: .. code-block:: python - from mush import Runner + from mush import Runner, Value def make_config() -> 'config': return {'foo': 'bar'} - def connect(foo: item('config', 'foo')): + def connect(foo = Value('config')['foo']): return 'connection' def process(connection): @@ -1035,7 +1038,7 @@ To see how the configuration panned out, we would look at the :func:`repr`: >>> runner requires() returns('config') - requires('config'['foo']) returns_result_type() <-- config + requires(Value('config')['foo']) returns_result_type() <-- config requires('connection') returns_result_type() diff --git a/mush/__init__.py b/mush/__init__.py index 7f75fda..1d4c6c5 100755 --- a/mush/__init__.py +++ b/mush/__init__.py @@ -2,7 +2,7 @@ from .declarations import ( requires, returns_result_type, returns_mapping, returns_sequence, returns, - attr, item, nothing + Value, nothing ) from .plug import Plug from .context import Context, ContextError @@ -13,5 +13,6 @@ 'Runner', 'requires', 'returns_result_type', 'returns_mapping', 'returns_sequence', 'returns', - 'attr', 'item', 'Plug', 'nothing' + 'Value', + 'Plug', 'nothing' ] diff --git a/mush/context.py b/mush/context.py index af35e34..ad22006 100644 --- a/mush/context.py +++ b/mush/context.py @@ -138,13 +138,16 @@ def _resolve(obj, requires, args, kw, context): if o is not requirement.default: for op in requirement.ops: o = op(o) + if o is missing: + o = requirement.default + break if o is missing: key = requirement.key if isinstance(key, type) and issubclass(key, Context): o = context else: - raise ContextError('No %s in context' % requirement.repr) + raise ContextError('No %s in context' % requirement.value_repr()) if requirement.target is None: args.append(o) diff --git a/mush/declarations.py b/mush/declarations.py index a106e0c..e80104f 100644 --- a/mush/declarations.py +++ b/mush/declarations.py @@ -7,7 +7,7 @@ ) from inspect import signature from itertools import chain -from typing import Type, Callable, NewType, Union, Any +from typing import Type, Callable, NewType, Union, Any, Sequence, List, Optional from .markers import missing @@ -28,27 +28,43 @@ def set_mush(obj, key, value): class Requirement: + """ + The requirement for an individual parameter of a callable. + """ resolve: RequirementResolver = None - def __init__(self, source, default=missing, target=None): - self.repr = name_or_repr(source) - self.target = target - self.default = default - - self.ops = deque() - while isinstance(source, how): - self.ops.appendleft(source.process) - source = source.type - self.key: ResourceKey = source + def __init__(self, key, name=None, type_=None, default=missing, target=None): + self.key: ResourceKey = key + self.name: str = (key if isinstance(key, str) else None) if name is None else name + self.type: type = (key if not isinstance(key, str) else None) if type_ is None else type_ + self.target: Optional[str] = target + self.default: Any = default + #: Any operations to be performed on the resource after it + #: has been obtained. + self.ops: List['Op'] = [] + + def value_repr(self): + key = name_or_repr(self.key) + if self.ops or self.default is not missing: + default = '' if self.default is missing else f', default={self.default!r}' + ops = ''.join(repr(o) for o in self.ops) + return f'Value({key}{default}){ops}' + return key def __repr__(self): - return f'{type(self).__name__}({self.repr}, default={self.default})' + attrs = [] + for a in 'name', 'type_', 'target': + value = getattr(self, a.rstrip('_')) + if value is not None: + attrs.append(f", {a}={value!r}") + return f"{type(self).__name__}({self.value_repr()}{''.join(attrs)})" class RequiresType(list): + def __repr__(self): - parts = (r.repr if r.target is None else f'{r.target}={r.repr}' + parts = (r.value_repr() if r.target is None else f'{r.target}={r.value_repr()}' for r in self) return f"requires({', '.join(parts)})" @@ -71,14 +87,17 @@ def requires(*args, **kw): requires_ = RequiresType() check_type(*args) check_type(*kw.values()) - for target, requirement in chain( + for target, possible in chain( ((None, arg) for arg in args), kw.items(), ): - if isinstance(requirement, Requirement): - requirement.target = target + if isinstance(possible, Value): + possible = possible.requirement + if isinstance(possible, Requirement): + possible.target = target + requirement = possible else: - requirement = Requirement(requirement, target=target) + requirement = Requirement(possible, target=target) requires_.append(requirement) return requires_ @@ -169,78 +188,63 @@ class DeclarationsFrom(Enum): replacement = DeclarationsFrom.replacement -class how(object): - """ - The base class for type decorators that indicate which part of a - resource is required by a particular callable. +class Op: - :param type: The resource type to be decorated. - :param names: Used to identify the part of the resource to extract. - """ - type_pattern = '%(type)s' - name_pattern = '' + def __init__(self, name): + self.name = name - def __init__(self, type, *names): - check_type(type) - self.type = type - self.names = names - def __repr__(self): - txt = self.type_pattern % dict(type=name_or_repr(self.type)) - for name in self.names: - txt += self.name_pattern % dict(name=name) - return txt - - def process(self, o): - """ - Extract the required part of the object passed in. - :obj:`missing` should be returned if the required part - cannot be extracted. - :obj:`missing` may be passed in and is usually be handled - by returning :obj:`missing` immediately. - """ - return missing - - -class attr(how): - """ - A :class:`~.declarations.how` that indicates the callable requires the named - attribute from the decorated type. - """ - name_pattern = '.%(name)s' +class AttrOp(Op): - def process(self, o): - if o is missing: - return o + def __call__(self, o): try: - for name in self.names: - o = getattr(o, name) + return getattr(o, self.name) except AttributeError: return missing - else: - return o + def __repr__(self): + return f'.{self.name}' -class item(how): - """ - A :class:`~.declarations.how` that indicates the callable requires the named - item from the decorated type. - """ - name_pattern = '[%(name)r]' - def process(self, o): - if o is missing: - return o +class ItemOp(Op): + + def __call__(self, o): try: - for name in self.names: - o = o[name] + return o[self.name] except KeyError: return missing - else: - return o + + def __repr__(self): + return f'[{self.name!r}]' + + +class Value: + """ + Declaration indicating that the specified resource key is required. + + Values are generative, so they can be used to indicate attributes or + items from a resource are required. + + A default may be specified, which will be used if the specified + resource is not available. + """ + + def __init__(self, key: ResourceKey, *, default: Any = missing): + self.requirement = Requirement(key, default=default) + + def __getattr__(self, name): + self.requirement.ops.append(AttrOp(name)) + return self + + def __getitem__(self, name): + self.requirement.ops.append(ItemOp(name)) + return self + + def __repr__(self): + return self.requirement.value_repr() -ok_types = (type, str, how, Requirement) +ok_types = (type, str, Value, Requirement) def check_type(*objs): @@ -290,8 +294,12 @@ def extract_requires(obj: Callable, explicit=None): if isinstance(p.default, Requirement): requirement = p.default + elif isinstance(p.default, Value): + requirement = p.default.requirement elif isinstance(p.annotation, Requirement): requirement = p.annotation + elif isinstance(p.annotation, Value): + requirement = p.annotation.requirement else: key = p.name if p.annotation is p.empty else p.annotation default = missing if p.default is p.empty else p.default diff --git a/mush/tests/conftest.py b/mush/tests/conftest.py deleted file mode 100644 index fede896..0000000 --- a/mush/tests/conftest.py +++ /dev/null @@ -1,13 +0,0 @@ -from testfixtures.comparison import register, compare_object - -from mush.declarations import Requirement - - -def compare_requirement(x, y, context): - # make sure this doesn't get refactored away, since we're using it - # as a proxy to check .ops: - assert hasattr(x, 'repr') - return compare_object(x, y, context, ignore_attributes=['ops']) - - -register(Requirement, compare_requirement) diff --git a/mush/tests/example_with_mush_clone.py b/mush/tests/example_with_mush_clone.py index d34331d..d8fbae3 100644 --- a/mush/tests/example_with_mush_clone.py +++ b/mush/tests/example_with_mush_clone.py @@ -1,6 +1,6 @@ from argparse import ArgumentParser, Namespace from configparser import RawConfigParser -from mush import Runner, requires, attr, item +from mush import Runner, requires, Value import logging, os, sqlite3, sys log = logging.getLogger() @@ -43,9 +43,9 @@ def __exit__(self, type, obj, tb): base_runner.add(base_options, label='args') base_runner.extend(parse_args, parse_config) base_runner.add(setup_logging, requires( - log_path = item('config', 'log'), - quiet = attr(Namespace, 'quiet'), - verbose = attr(Namespace, 'verbose') + log_path = Value('config')['log'], + quiet = Value(Namespace).quiet, + verbose = Value(Namespace).verbose, )) @@ -62,9 +62,9 @@ def do(conn, path): main = base_runner.clone() main['args'].add(args, requires=ArgumentParser) -main.add(DatabaseHandler, requires=item('config', 'db')) +main.add(DatabaseHandler, requires=Value('config')['db']) main.add(do, - requires(attr(DatabaseHandler, 'conn'), attr(Namespace, 'path'))) + requires(Value(DatabaseHandler).conn, Value(Namespace).path)) if __name__ == '__main__': main() diff --git a/mush/tests/example_with_mush_factory.py b/mush/tests/example_with_mush_factory.py index 98c36ba..2757faa 100644 --- a/mush/tests/example_with_mush_factory.py +++ b/mush/tests/example_with_mush_factory.py @@ -1,4 +1,4 @@ -from mush import Runner, attr, item, requires +from mush import Runner, requires, Value from argparse import ArgumentParser, Namespace from .example_with_mush_clone import ( @@ -21,14 +21,14 @@ def make_runner(do): runner.add(parse_args, requires=ArgumentParser) runner.add(parse_config, requires=Namespace) runner.add(setup_logging, requires( - log_path = item('config', 'log'), - quiet = attr(Namespace, 'quiet'), - verbose = attr(Namespace, 'verbose') + log_path=Value('config')['log'], + quiet=Value(Namespace).quiet, + verbose=Value(Namespace).verbose, )) - runner.add(DatabaseHandler, requires=item('config', 'db')) + runner.add(DatabaseHandler, requires=Value('config')['db']) runner.add( do, - requires(attr(DatabaseHandler, 'conn'), attr(Namespace, 'path')) + requires(Value(DatabaseHandler).conn, Value(Namespace).path) ) return runner diff --git a/mush/tests/test_context.py b/mush/tests/test_context.py index c0c6d1f..cdf3f4b 100644 --- a/mush/tests/test_context.py +++ b/mush/tests/test_context.py @@ -4,10 +4,8 @@ from mush.context import ResolvableValue from testfixtures import ShouldRaise, compare -from mush import Context, ContextError -from mush.declarations import ( - nothing, requires, item, attr, returns, returns_mapping, Requirement, missing -) +from mush import Context, ContextError, requires, returns, nothing, returns_mapping +from mush.declarations import Requirement, Value, missing class TheType(object): @@ -153,9 +151,9 @@ def foo(obj): return obj context = Context() context.add({}, TheType) with ShouldRaise(ContextError( - "No TheType['foo'] in context" + "No Value(TheType)['foo'] in context" )): - context.call(foo, requires(item(TheType, 'foo'))) + context.call(foo, requires(Value(TheType)['foo'])) def test_call_requires_accidental_tuple(self): def foo(obj): return obj @@ -217,7 +215,7 @@ def foo(x): return x context = Context() context.add(dict(bar='baz'), 'foo') - result = context.call(foo, requires(item('foo', 'bar'))) + result = context.call(foo, requires(Value('foo')['bar'])) compare(result, 'baz') def test_call_requires_attr(self): @@ -226,7 +224,7 @@ def foo(x): m = Mock() context = Context() context.add(m, 'foo') - result = context.call(foo, requires(attr('foo', 'bar'))) + result = context.call(foo, requires(Value('foo').bar)) compare(result, m.bar) def test_call_requires_item_attr(self): @@ -235,23 +233,23 @@ def foo(x): m = Mock() m.bar= dict(baz='bob') context = Context() - context.add(m, 'foo') - result = context.call(foo, requires(item(attr('foo', 'bar'), 'baz'))) + context.add(m, provides='foo') + result = context.call(foo, requires(Value('foo').bar['baz'])) compare(result, 'bob') def test_call_requires_optional_item_missing(self): - def foo(x: item('foo', 'bar') = 1): + def foo(x: str = Value('foo', default=1)['bar']): return x context = Context() result = context.call(foo) compare(result, 1) def test_call_requires_optional_item_present(self): - def foo(x=1): + def foo(x: str = Value('foo', default=1)['bar']): return x context = Context() - context.add(dict(bar='baz'), 'foo') - result = context.call(foo, requires((item('foo', 'bar')))) + context.add(dict(bar='baz'), provides='foo') + result = context.call(foo) compare(result, 'baz') def test_call_extract_requirements(self): diff --git a/mush/tests/test_declarations.py b/mush/tests/test_declarations.py index 36290fa..f01b7d8 100644 --- a/mush/tests/test_declarations.py +++ b/mush/tests/test_declarations.py @@ -1,16 +1,20 @@ from functools import partial from unittest import TestCase + +import pytest from mock import Mock from testfixtures import compare, ShouldRaise + +from mush import Context from mush.markers import missing from mush.declarations import ( requires, returns, returns_mapping, returns_sequence, returns_result_type, - how, item, attr, nothing, + nothing, extract_requires, extract_returns, result_type, Requirement, - update_wrapper -) + update_wrapper, + Value, AttrOp) def check_extract(obj, expected_rq, expected_rt): @@ -72,68 +76,85 @@ def foo(): class TestRequirement: - def test_repr(self): - compare(repr(Requirement('foo', default=None)), - expected="Requirement('foo', default=None)") + def test_repr_minimal_name(self): + compare(repr(Requirement('foo')), + expected="Requirement('foo', name='foo')") + + def test_repr_minimal_type(self): + compare(repr(Requirement(str)), + expected="Requirement(str, type_=)") + + def test_repr_maximal(self): + r = Requirement('foo', name='n', type_='ty', default=None, target='ta') + r.ops.append(AttrOp('bar')) + compare(repr(r), + expected="Requirement(Value('foo', default=None).bar, " + "name='n', type_='ty', target='ta')") + +def check_ops(value, data, *, expected): + for op in value.requirement.ops: + data = op(data) + compare(expected, actual=data) -class TestItem(TestCase): + +class TestItem: def test_single(self): - h = item(Type1, 'foo') - compare(repr(h), "Type1['foo']") - compare(h.process(dict(foo=1)), 1) + h = Value(Type1)['foo'] + compare(repr(h), "Value(Type1)['foo']") + check_ops(h, {'foo': 1}, expected=1) def test_multiple(self): - h = item(Type1, 'foo', 'bar') - compare(repr(h), "Type1['foo']['bar']") - compare(h.process(dict(foo=dict(bar=1))), 1) + h = Value(Type1)['foo']['bar'] + compare(repr(h), "Value(Type1)['foo']['bar']") + check_ops(h, {'foo': {'bar': 1}}, expected=1) def test_missing_obj(self): - h = item(Type1, 'foo', 'bar') + h = Value(Type1)['foo']['bar'] with ShouldRaise(TypeError): - h.process(object()) + check_ops(h, object(), expected=None) def test_missing_key(self): - h = item(Type1, 'foo', 'bar') - compare(h.process({}), missing) + h = Value(Type1)['foo'] + check_ops(h, {}, expected=missing) def test_passed_missing(self): - h = item(Type1, 'foo', 'bar') - compare(h.process(missing), missing) + c = Context() + c.add({}, provides='key') + compare(c.call(lambda x: x, requires=Value('key', default=1)['foo']['bar']), + expected=1) def test_bad_type(self): + h = Value(Type1)['foo']['bar'] with ShouldRaise(TypeError): - item([], 'foo', 'bar') - - -class TestHow(TestCase): - - def test_process_on_base(self): - compare(how('foo').process('bar'), missing) + check_ops(h, [], expected=None) class TestAttr(TestCase): def test_single(self): - h = attr(Type1, 'foo') - compare(repr(h), "Type1.foo") + h = Value(Type1).foo + compare(repr(h), "Value(Type1).foo") m = Mock() - compare(h.process(m), m.foo) + check_ops(h, m, expected=m.foo) def test_multiple(self): - h = attr(Type1, 'foo', 'bar') - compare(repr(h), "Type1.foo.bar") + h = Value(Type1).foo.bar + compare(repr(h), "Value(Type1).foo.bar") m = Mock() - compare(h.process(m), m.foo.bar) + check_ops(h, m, expected=m.foo.bar) def test_missing(self): - h = attr(Type1, 'foo', 'bar') - compare(h.process(object()), missing) + h = Value(Type1).foo + compare(repr(h), "Value(Type1).foo") + check_ops(h, object(), expected=missing) def test_passed_missing(self): - h = attr(Type1, 'foo', 'bar') - compare(h.process(missing), missing) + c = Context() + c.add(object(), provides='key') + compare(c.call(lambda x: x, requires=Value('key', default=1).foo.bar), + expected=1) class TestReturns(TestCase): @@ -386,10 +407,9 @@ def foo() -> rt: pass expected_rt=rt) def test_how_instance_in_annotations(self): - how_instance = item('config', 'db_url') - def foo(a: how_instance): pass + def foo(a: Value('config')['db_url']): pass check_extract(foo, - expected_rq=requires(how_instance), + expected_rq=requires(Value('config')['db_url']), expected_rt=result_type) def test_default_requirements(self): @@ -401,6 +421,25 @@ def foo(a, b=1, *, c, d=None): pass d=Requirement('d', default=None)), expected_rt=result_type) + def test_type_only(self): + class T: pass + def foo(a: T): pass + check_extract(foo, + expected_rq=requires(Requirement(T)), + expected_rt=result_type) + + def test_type_plus_value(self): + def foo(a: str = Value('b')): pass + check_extract(foo, + expected_rq=requires(Requirement('b')), + expected_rt=result_type) + + def test_type_plus_value_with_default(self): + def foo(a: str = Value('b', default=1)): pass + check_extract(foo, + expected_rq=requires(Requirement('b', default=1)), + expected_rt=result_type) + class TestDeclarationsFromMultipleSources: diff --git a/mush/tests/test_runner.py b/mush/tests/test_runner.py index 7d6ccc5..643b6ac 100644 --- a/mush/tests/test_runner.py +++ b/mush/tests/test_runner.py @@ -1,17 +1,17 @@ from unittest import TestCase from mock import Mock, call +from mush.context import ContextError +from mush.declarations import ( + requires, returns, returns_mapping, + replacement, original, + Value) +from mush.runner import Runner from testfixtures import ( ShouldRaise, compare ) -from mush.context import ContextError -from mush.declarations import ( - requires, attr, item, nothing, returns, returns_mapping, - replacement, original) -from mush.runner import Runner - def verify(runner, *expected): seen_labels = set() @@ -612,7 +612,7 @@ def job2(obj): m.job2(obj) runner = Runner() runner.add(job1) - runner.add(job2, requires(attr(T, 'foo'))) + runner.add(job2, requires(Value(T).foo)) runner() compare([ @@ -634,7 +634,7 @@ def job2(obj): m.job2(obj) runner = Runner() runner.add(job1) - runner.add(job2, requires(attr(T, 'foo', 'bar'))) + runner.add(job2, requires(Value(T).foo.bar)) runner() compare([ @@ -654,7 +654,7 @@ def job2(obj): m.job2(obj) runner = Runner() runner.add(job1) - runner.add(job2, requires(item(MyDict, 'the_thing'))) + runner.add(job2, requires(Value(MyDict)['the_thing'])) runner() compare([ call.job1(), @@ -673,7 +673,7 @@ def job2(obj): m.job2(obj) runner = Runner() runner.add(job1) - runner.add(job2, requires(item(MyDict, 'the_thing', 'other_thing'))) + runner.add(job2, requires(Value(MyDict)['the_thing']['other_thing'])) runner() compare([ call.job1(), @@ -691,7 +691,7 @@ def job2(obj): m.job2(obj) runner = Runner() runner.add(job1) - runner.add(job2, requires(item(attr(T, 'foo'), 'baz'))) + runner.add(job2, requires(Value(T).foo['baz'])) runner() compare([ From 520c7908942de593151da91093ca2af5d18b29a5 Mon Sep 17 00:00:00 2001 From: Chris Withers Date: Tue, 25 Feb 2020 13:46:31 +0000 Subject: [PATCH 044/159] move code around to make it easier to read through --- mush/__init__.py | 4 +- mush/asyncio.py | 5 +- mush/callpoints.py | 4 +- mush/context.py | 3 +- mush/declarations.py | 253 ++++++++++---------------------- mush/extraction.py | 112 ++++++++++++++ mush/runner.py | 3 +- mush/tests/test_callpoints.py | 3 +- mush/tests/test_declarations.py | 22 ++- 9 files changed, 215 insertions(+), 194 deletions(-) create mode 100644 mush/extraction.py diff --git a/mush/__init__.py b/mush/__init__.py index 1d4c6c5..587b848 100755 --- a/mush/__init__.py +++ b/mush/__init__.py @@ -4,6 +4,7 @@ returns_result_type, returns_mapping, returns_sequence, returns, Value, nothing ) +from .extraction import extract_requires, extract_returns, update_wrapper from .plug import Plug from .context import Context, ContextError from .asyncio import AsyncContext @@ -14,5 +15,6 @@ 'requires', 'returns_result_type', 'returns_mapping', 'returns_sequence', 'returns', 'Value', - 'Plug', 'nothing' + 'Plug', 'nothing', + 'update_wrapper', ] diff --git a/mush/asyncio.py b/mush/asyncio.py index 5e0c8fc..08ac0da 100644 --- a/mush/asyncio.py +++ b/mush/asyncio.py @@ -1,9 +1,8 @@ import asyncio from functools import partial -from mush import Context, ContextError -from mush.declarations import ResourceKey, RequiresType, extract_requires -from mush.markers import missing +from mush import Context +from mush.declarations import ResourceKey async def ensure_async(func, *args, **kw): diff --git a/mush/callpoints.py b/mush/callpoints.py index 25c683f..6a211bf 100644 --- a/mush/callpoints.py +++ b/mush/callpoints.py @@ -1,8 +1,8 @@ from .context import Context from .declarations import ( - nothing, extract_requires, extract_returns, - requires as requires_function + nothing, requires as requires_function ) +from .extraction import extract_requires, extract_returns from .resolvers import Lazy diff --git a/mush/context.py b/mush/context.py index ad22006..02e449d 100644 --- a/mush/context.py +++ b/mush/context.py @@ -1,8 +1,9 @@ from typing import Optional from .declarations import ( - extract_requires, RequiresType, ResourceKey, ResourceValue, ResourceResolver + RequiresType, ResourceKey, ResourceValue, ResourceResolver ) +from .extraction import extract_requires from .markers import missing NONE_TYPE = type(None) diff --git a/mush/declarations.py b/mush/declarations.py index e80104f..54ea3ac 100644 --- a/mush/declarations.py +++ b/mush/declarations.py @@ -1,13 +1,6 @@ -from collections import deque from enum import Enum, auto -from functools import ( - WRAPPER_UPDATES, - WRAPPER_ASSIGNMENTS as FUNCTOOLS_ASSIGNMENTS, - update_wrapper as functools_update_wrapper, -) -from inspect import signature from itertools import chain -from typing import Type, Callable, NewType, Union, Any, Sequence, List, Optional +from typing import Type, Callable, NewType, Union, Any, List, Optional from .markers import missing @@ -42,7 +35,7 @@ def __init__(self, key, name=None, type_=None, default=missing, target=None): self.default: Any = default #: Any operations to be performed on the resource after it #: has been obtained. - self.ops: List['Op'] = [] + self.ops: List['ValueOp'] = [] def value_repr(self): key = name_or_repr(self.key) @@ -61,6 +54,62 @@ def __repr__(self): return f"{type(self).__name__}({self.value_repr()}{''.join(attrs)})" +class Value: + """ + Declaration indicating that the specified resource key is required. + + Values are generative, so they can be used to indicate attributes or + items from a resource are required. + + A default may be specified, which will be used if the specified + resource is not available. + """ + + def __init__(self, key: ResourceKey, *, default: Any = missing): + self.requirement = Requirement(key, default=default) + + def __getattr__(self, name): + self.requirement.ops.append(ValueAttrOp(name)) + return self + + def __getitem__(self, name): + self.requirement.ops.append(ValueItemOp(name)) + return self + + def __repr__(self): + return self.requirement.value_repr() + + +class ValueOp: + + def __init__(self, name): + self.name = name + + +class ValueAttrOp(ValueOp): + + def __call__(self, o): + try: + return getattr(o, self.name) + except AttributeError: + return missing + + def __repr__(self): + return f'.{self.name}' + + +class ValueItemOp(ValueOp): + + def __call__(self, o): + try: + return o[self.name] + except KeyError: + return missing + + def __repr__(self): + return f'[{self.name!r}]' + + class RequiresType(list): def __repr__(self): @@ -177,82 +226,9 @@ def __repr__(self): return self.__class__.__name__ + '(' + args_repr + ')' -class DeclarationsFrom(Enum): - #: Use declarations from the original callable. - original = auto() - #: Use declarations from the replacement callable. - replacement = auto() - - -original = DeclarationsFrom.original -replacement = DeclarationsFrom.replacement - - -class Op: - - def __init__(self, name): - self.name = name - - -class AttrOp(Op): - - def __call__(self, o): - try: - return getattr(o, self.name) - except AttributeError: - return missing - - def __repr__(self): - return f'.{self.name}' - - -class ItemOp(Op): - - def __call__(self, o): - try: - return o[self.name] - except KeyError: - return missing - - def __repr__(self): - return f'[{self.name!r}]' - - -class Value: - """ - Declaration indicating that the specified resource key is required. - - Values are generative, so they can be used to indicate attributes or - items from a resource are required. - - A default may be specified, which will be used if the specified - resource is not available. - """ - - def __init__(self, key: ResourceKey, *, default: Any = missing): - self.requirement = Requirement(key, default=default) - - def __getattr__(self, name): - self.requirement.ops.append(AttrOp(name)) - return self - - def __getitem__(self, name): - self.requirement.ops.append(ItemOp(name)) - return self - - def __repr__(self): - return self.requirement.value_repr() - - -ok_types = (type, str, Value, Requirement) - - -def check_type(*objs): - for obj in objs: - if not isinstance(obj, ok_types): - raise TypeError( - repr(obj)+" is not a type or label" - ) +#: A singleton indicating that a callable's return value should be +#: stored based on the type of that return value. +result_type = returns_result_type() class Nothing(RequiresType, returns): @@ -266,103 +242,24 @@ def process(self, result): #: that anything returned from a callable should be ignored. nothing = Nothing() -#: A singleton indicating that a callable's return value should be -#: stored based on the type of that return value. -result_type = returns_result_type() - -def _unpack_requires(by_name, by_index, requires_): - - for i, requirement in enumerate(requires_): - if requirement.target is None: - try: - arg = by_index[i] - except IndexError: - # case where something takes *args - arg = i - else: - arg = requirement.target - by_name[arg] = requirement - - -def extract_requires(obj: Callable, explicit=None): - # from annotations - by_name = {} - for name, p in signature(obj).parameters.items(): - if p.kind in (p.VAR_POSITIONAL, p.VAR_KEYWORD): - continue - - if isinstance(p.default, Requirement): - requirement = p.default - elif isinstance(p.default, Value): - requirement = p.default.requirement - elif isinstance(p.annotation, Requirement): - requirement = p.annotation - elif isinstance(p.annotation, Value): - requirement = p.annotation.requirement - else: - key = p.name if p.annotation is p.empty else p.annotation - default = missing if p.default is p.empty else p.default - requirement = Requirement(key, default=default) - - if p.kind is p.KEYWORD_ONLY: - requirement.target = p.name - by_name[name] = requirement - - by_index = list(by_name) - - # from declarations - mush_declarations = getattr(obj, '__mush__', None) - if mush_declarations is not None: - requires_ = mush_declarations.get('requires') - if requires_ is not None: - _unpack_requires(by_name, by_index, requires_) - - # explicit - if explicit is not None: - if isinstance(explicit, (list, tuple)): - requires_ = requires(*explicit) - elif not isinstance(explicit, RequiresType): - requires_ = requires(explicit) - else: - requires_ = explicit - _unpack_requires(by_name, by_index, requires_) - - if not by_name: - return nothing - - return RequiresType(by_name.values()) - - -def extract_returns(obj: Callable, explicit: ReturnsType = None): - if explicit is None: - mush_declarations = getattr(obj, '__mush__', {}) - returns_ = mush_declarations.get('returns', None) - if returns_ is None: - annotations = getattr(obj, '__annotations__', {}) - returns_ = annotations.get('return') - else: - returns_ = explicit +class DeclarationsFrom(Enum): + original = auto() + replacement = auto() - if returns_ is None or isinstance(returns_, ReturnsType): - pass - elif isinstance(returns_, (list, tuple)): - returns_ = returns(*returns_) - else: - returns_ = returns(returns_) - return returns_ or result_type +#: Use declarations from the original callable. +original = DeclarationsFrom.original +#: Use declarations from the replacement callable. +replacement = DeclarationsFrom.replacement -WRAPPER_ASSIGNMENTS = FUNCTOOLS_ASSIGNMENTS + ('__mush__',) +ok_types = (type, str, Value, Requirement) -def update_wrapper(wrapper, - wrapped, - assigned=WRAPPER_ASSIGNMENTS, - updated=WRAPPER_UPDATES): - """ - An extended version of :func:`functools.update_wrapper` that - also preserves Mush's annotations. - """ - return functools_update_wrapper(wrapper, wrapped, assigned, updated) +def check_type(*objs): + for obj in objs: + if not isinstance(obj, ok_types): + raise TypeError( + repr(obj)+" is not a type or label" + ) diff --git a/mush/extraction.py b/mush/extraction.py new file mode 100644 index 0000000..725e3c9 --- /dev/null +++ b/mush/extraction.py @@ -0,0 +1,112 @@ +from functools import ( + WRAPPER_ASSIGNMENTS as FUNCTOOLS_ASSIGNMENTS, + WRAPPER_UPDATES, + update_wrapper as functools_update_wrapper +) +from inspect import signature +from typing import Callable + +from .declarations import ( + Value, + requires, Requirement, RequiresType, ReturnsType, + returns, result_type, + nothing +) +from .markers import missing + + +def _unpack_requires(by_name, by_index, requires_): + + for i, requirement in enumerate(requires_): + if requirement.target is None: + try: + arg = by_index[i] + except IndexError: + # case where something takes *args + arg = i + else: + arg = requirement.target + by_name[arg] = requirement + + +def extract_requires(obj: Callable, explicit=None): + # from annotations + by_name = {} + for name, p in signature(obj).parameters.items(): + if p.kind in (p.VAR_POSITIONAL, p.VAR_KEYWORD): + continue + + if isinstance(p.default, Requirement): + requirement = p.default + elif isinstance(p.default, Value): + requirement = p.default.requirement + elif isinstance(p.annotation, Requirement): + requirement = p.annotation + elif isinstance(p.annotation, Value): + requirement = p.annotation.requirement + else: + key = p.name if p.annotation is p.empty else p.annotation + default = missing if p.default is p.empty else p.default + requirement = Requirement(key, default=default) + + if p.kind is p.KEYWORD_ONLY: + requirement.target = p.name + by_name[name] = requirement + + by_index = list(by_name) + + # from declarations + mush_declarations = getattr(obj, '__mush__', None) + if mush_declarations is not None: + requires_ = mush_declarations.get('requires') + if requires_ is not None: + _unpack_requires(by_name, by_index, requires_) + + # explicit + if explicit is not None: + if isinstance(explicit, (list, tuple)): + requires_ = requires(*explicit) + elif not isinstance(explicit, RequiresType): + requires_ = requires(explicit) + else: + requires_ = explicit + _unpack_requires(by_name, by_index, requires_) + + if not by_name: + return nothing + + return RequiresType(by_name.values()) + + +def extract_returns(obj: Callable, explicit: ReturnsType = None): + if explicit is None: + mush_declarations = getattr(obj, '__mush__', {}) + returns_ = mush_declarations.get('returns', None) + if returns_ is None: + annotations = getattr(obj, '__annotations__', {}) + returns_ = annotations.get('return') + else: + returns_ = explicit + + if returns_ is None or isinstance(returns_, ReturnsType): + pass + elif isinstance(returns_, (list, tuple)): + returns_ = returns(*returns_) + else: + returns_ = returns(returns_) + + return returns_ or result_type + + +WRAPPER_ASSIGNMENTS = FUNCTOOLS_ASSIGNMENTS + ('__mush__',) + + +def update_wrapper(wrapper, + wrapped, + assigned=WRAPPER_ASSIGNMENTS, + updated=WRAPPER_UPDATES): + """ + An extended version of :func:`functools.update_wrapper` that + also preserves Mush's annotations. + """ + return functools_update_wrapper(wrapper, wrapped, assigned, updated) diff --git a/mush/runner.py b/mush/runner.py index b9e8d43..093970b 100644 --- a/mush/runner.py +++ b/mush/runner.py @@ -2,7 +2,8 @@ from .callpoints import CallPoint from .context import Context, ContextError -from .declarations import extract_requires, extract_returns, DeclarationsFrom +from .declarations import DeclarationsFrom +from .extraction import extract_requires, extract_returns from .markers import not_specified from .modifier import Modifier from .plug import Plug diff --git a/mush/tests/test_callpoints.py b/mush/tests/test_callpoints.py index f6bec52..18ddf29 100644 --- a/mush/tests/test_callpoints.py +++ b/mush/tests/test_callpoints.py @@ -5,7 +5,8 @@ from testfixtures import compare from mush.callpoints import CallPoint -from mush.declarations import requires, returns, update_wrapper, RequiresType +from mush.declarations import requires, returns, RequiresType +from mush.extraction import update_wrapper class TestCallPoints(TestCase): diff --git a/mush/tests/test_declarations.py b/mush/tests/test_declarations.py index f01b7d8..a160441 100644 --- a/mush/tests/test_declarations.py +++ b/mush/tests/test_declarations.py @@ -6,15 +6,16 @@ from testfixtures import compare, ShouldRaise from mush import Context -from mush.markers import missing from mush.declarations import ( requires, returns, returns_mapping, returns_sequence, returns_result_type, nothing, - extract_requires, extract_returns, result_type, Requirement, - update_wrapper, - Value, AttrOp) + Value, + ValueAttrOp +) +from mush.extraction import extract_requires, extract_returns, update_wrapper +from mush.markers import missing def check_extract(obj, expected_rq, expected_rt): @@ -86,7 +87,7 @@ def test_repr_minimal_type(self): def test_repr_maximal(self): r = Requirement('foo', name='n', type_='ty', default=None, target='ta') - r.ops.append(AttrOp('bar')) + r.ops.append(ValueAttrOp('bar')) compare(repr(r), expected="Requirement(Value('foo', default=None).bar, " "name='n', type_='ty', target='ta')") @@ -428,16 +429,23 @@ def foo(a: T): pass expected_rq=requires(Requirement(T)), expected_rt=result_type) + @pytest.mark.parametrize("type_", [str, int, dict, list]) + def test_simple_type_only(self, type_): + def foo(a: type_): pass + check_extract(foo, + expected_rq=requires(Requirement('a', type_=type_)), + expected_rt=result_type) + def test_type_plus_value(self): def foo(a: str = Value('b')): pass check_extract(foo, - expected_rq=requires(Requirement('b')), + expected_rq=requires(Requirement('b', name='b', type_=str)), expected_rt=result_type) def test_type_plus_value_with_default(self): def foo(a: str = Value('b', default=1)): pass check_extract(foo, - expected_rq=requires(Requirement('b', default=1)), + expected_rq=requires(Requirement('b', name='b', type_=str, default=1)), expected_rt=result_type) From d7c0d6a4a2520624c5d40da477bf423a6bf116db Mon Sep 17 00:00:00 2001 From: Chris Withers Date: Tue, 25 Feb 2020 13:52:09 +0000 Subject: [PATCH 045/159] better name for this helper --- mush/declarations.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/mush/declarations.py b/mush/declarations.py index 54ea3ac..d68996a 100644 --- a/mush/declarations.py +++ b/mush/declarations.py @@ -134,8 +134,8 @@ def requires(*args, **kw): returning those resources is configured to return the named resource. """ requires_ = RequiresType() - check_type(*args) - check_type(*kw.values()) + valid_decoration_types(*args) + valid_decoration_types(*kw.values()) for target, possible in chain( ((None, arg) for arg in args), kw.items(), @@ -211,7 +211,7 @@ class returns(returns_result_type): """ def __init__(self, *args): - check_type(*args) + valid_decoration_types(*args) self.args = args def process(self, obj): @@ -254,12 +254,12 @@ class DeclarationsFrom(Enum): replacement = DeclarationsFrom.replacement -ok_types = (type, str, Value, Requirement) +VALID_DECORATION_TYPES = (type, str, Value, Requirement) -def check_type(*objs): +def valid_decoration_types(*objs): for obj in objs: - if not isinstance(obj, ok_types): + if not isinstance(obj, VALID_DECORATION_TYPES): raise TypeError( repr(obj)+" is not a type or label" ) From eaa1cce6748b2bd341f040b278902f497fc0bb4f Mon Sep 17 00:00:00 2001 From: Chris Withers Date: Tue, 25 Feb 2020 21:54:37 +0000 Subject: [PATCH 046/159] Correct declaration tests to use more explicit expectations. Fix two bugs highlighted by this. --- mush/extraction.py | 15 ++++- mush/tests/test_declarations.py | 99 +++++++++++++++++++++------------ 2 files changed, 76 insertions(+), 38 deletions(-) diff --git a/mush/extraction.py b/mush/extraction.py index 725e3c9..2f82ee5 100644 --- a/mush/extraction.py +++ b/mush/extraction.py @@ -1,7 +1,8 @@ from functools import ( WRAPPER_ASSIGNMENTS as FUNCTOOLS_ASSIGNMENTS, WRAPPER_UPDATES, - update_wrapper as functools_update_wrapper + update_wrapper as functools_update_wrapper, + partial ) from inspect import signature from typing import Callable @@ -31,11 +32,16 @@ def _unpack_requires(by_name, by_index, requires_): def extract_requires(obj: Callable, explicit=None): # from annotations + is_partial = isinstance(obj, partial) by_name = {} for name, p in signature(obj).parameters.items(): if p.kind in (p.VAR_POSITIONAL, p.VAR_KEYWORD): continue + # https://bugs.python.org/issue39753: + if is_partial and p.name in obj.keywords: + continue + if isinstance(p.default, Requirement): requirement = p.default elif isinstance(p.default, Value): @@ -75,6 +81,13 @@ def extract_requires(obj: Callable, explicit=None): if not by_name: return nothing + needs_target = False + for requirement in by_name.values(): + if requirement.target is not None: + needs_target = True + elif needs_target: + requirement.target = requirement.name + return RequiresType(by_name.values()) diff --git a/mush/tests/test_declarations.py b/mush/tests/test_declarations.py index a160441..a2fe61e 100644 --- a/mush/tests/test_declarations.py +++ b/mush/tests/test_declarations.py @@ -1,7 +1,6 @@ from functools import partial from unittest import TestCase -import pytest from mock import Mock from testfixtures import compare, ShouldRaise @@ -12,7 +11,9 @@ nothing, result_type, Requirement, Value, - ValueAttrOp + ValueAttrOp, + RequiresType, + ValueItemOp ) from mush.extraction import extract_requires, extract_returns, update_wrapper from mush.markers import missing @@ -191,6 +192,7 @@ def test_bad_type(self): @returns([]) def foo(): pass + class TestReturnsMapping(TestCase): def test_it(self): @@ -245,14 +247,20 @@ class TestExtractDeclarations(object): def test_default_requirements_for_function(self): def foo(a, b=None): pass check_extract(foo, - expected_rq=requires('a', Requirement('b', default=None)), + expected_rq=RequiresType(( + Requirement('a'), + Requirement('b', default=None) + )), expected_rt=result_type) def test_default_requirements_for_class(self): class MyClass(object): def __init__(self, a, b=None): pass check_extract(MyClass, - expected_rq=requires('a', Requirement('b', default=None)), + expected_rq=RequiresType(( + Requirement('a'), + Requirement('b', default=None) + )), expected_rt=result_type) def test_extract_from_partial(self): @@ -260,9 +268,10 @@ def foo(x, y, z, a=None): pass p = partial(foo, 1, y=2) check_extract( p, - expected_rq=requires(y=Requirement('y', default=2), - z='z', - a=Requirement('a', default=None)), + expected_rq=RequiresType(( + Requirement('z', target='z'), + Requirement('a', target='a', default=None) + )), expected_rt=result_type ) @@ -271,7 +280,9 @@ def foo(a=None): pass p = partial(foo) check_extract( p, - expected_rq=requires(Requirement('a', default=None)), + expected_rq=RequiresType(( + Requirement('a', default=None), + )), expected_rt=result_type ) @@ -290,7 +301,7 @@ def foo(a=None): pass p = partial(foo, a=1) check_extract( p, - expected_rq=requires(a=Requirement('a', default=1)), + expected_rq=nothing, expected_rt=result_type ) @@ -309,7 +320,7 @@ def foo(a): pass p = partial(foo, a=1) check_extract( p, - expected_rq=requires(a=Requirement('a', default=1)), + expected_rq=nothing, expected_rt=result_type ) @@ -318,7 +329,10 @@ def foo(b, a=None): pass p = partial(foo) check_extract( p, - expected_rq=requires('b', Requirement('a', default=None)), + expected_rq=RequiresType(( + Requirement('b'), + Requirement('a', default=None) + )), expected_rt=result_type ) @@ -328,7 +342,9 @@ def foo(b, a): pass check_extract( p, # since b is already bound: - expected_rq=requires('a'), + expected_rq=RequiresType(( + Requirement('a'), + )), expected_rt=result_type ) @@ -337,7 +353,9 @@ def foo(b, a): pass p = partial(foo, a=1) check_extract( p, - expected_rq=requires('b', a=Requirement('a', default=1)), + expected_rq=RequiresType(( + Requirement('b'), + )), expected_rt=result_type ) @@ -347,16 +365,18 @@ class TestExtractDeclarationsFromTypeAnnotations(object): def test_extract_from_annotations(self): def foo(a: 'foo', b, c: 'bar' = 1, d=2) -> 'bar': pass check_extract(foo, - expected_rq=requires('foo', - 'b', - Requirement('bar', default=1), - Requirement('d', default=2)), + expected_rq=RequiresType(( + Requirement('foo'), + Requirement('b'), + Requirement('bar', default=1), + Requirement('d', default=2) + )), expected_rt=returns('bar')) def test_requires_only(self): def foo(a: 'foo'): pass check_extract(foo, - expected_rq=requires('foo'), + expected_rq=RequiresType((Requirement('foo'),)), expected_rt=result_type) def test_returns_only(self): @@ -382,7 +402,7 @@ def foo(a: 'foo' = None) -> 'bar': compare(foo(), expected='the answer') check_extract(foo, - expected_rq=requires(Requirement('foo', default=None)), + expected_rq=RequiresType((Requirement('foo', default=None),)), expected_rt=returns('bar')) def test_decorator_trumps_annotations(self): @@ -390,7 +410,7 @@ def test_decorator_trumps_annotations(self): @returns('bar') def foo(a: 'x') -> 'y': pass check_extract(foo, - expected_rq=requires('foo'), + expected_rq=RequiresType((Requirement('foo'),)), expected_rt=returns('bar')) def test_returns_mapping(self): @@ -409,43 +429,40 @@ def foo() -> rt: pass def test_how_instance_in_annotations(self): def foo(a: Value('config')['db_url']): pass + requirement = Requirement('config') + requirement.ops.append(ValueItemOp('db_url')) check_extract(foo, - expected_rq=requires(Value('config')['db_url']), + expected_rq=RequiresType((requirement,)), expected_rt=result_type) def test_default_requirements(self): def foo(a, b=1, *, c, d=None): pass check_extract(foo, - expected_rq=requires('a', - Requirement('b', default=1), - c='c', - d=Requirement('d', default=None)), + expected_rq=RequiresType(( + Requirement('a'), + Requirement('b', default=1), + Requirement('c', target='c'), + Requirement('d', target='d', default=None) + )), expected_rt=result_type) def test_type_only(self): class T: pass def foo(a: T): pass check_extract(foo, - expected_rq=requires(Requirement(T)), - expected_rt=result_type) - - @pytest.mark.parametrize("type_", [str, int, dict, list]) - def test_simple_type_only(self, type_): - def foo(a: type_): pass - check_extract(foo, - expected_rq=requires(Requirement('a', type_=type_)), + expected_rq=RequiresType((Requirement(T),)), expected_rt=result_type) def test_type_plus_value(self): def foo(a: str = Value('b')): pass check_extract(foo, - expected_rq=requires(Requirement('b', name='b', type_=str)), + expected_rq=RequiresType((Requirement('b', name='b'),)), expected_rt=result_type) def test_type_plus_value_with_default(self): def foo(a: str = Value('b', default=1)): pass check_extract(foo, - expected_rq=requires(Requirement('b', name='b', type_=str, default=1)), + expected_rq=RequiresType((Requirement('b', name='b', default=1),)), expected_rt=result_type) @@ -461,7 +478,11 @@ def foo(a: r1, b, c=r3): pass check_extract(foo, - expected_rq=requires(r1, b=r2, c=r3), + expected_rq=RequiresType(( + Requirement('a'), + Requirement('b', target='b'), + Requirement('c', target='c'), + )), expected_rt=result_type) def test_declaration_priorities(self): @@ -474,5 +495,9 @@ def foo(a: r2 = r3, b: str = r2, c = r3): pass check_extract(foo, - expected_rq=requires(r1, b=r2, c=r3), + expected_rq=RequiresType(( + Requirement('a', target='a'), + Requirement('b', target='b'), + Requirement('c', target='c'), + )), expected_rt=result_type) From d190666f91b3f4e8ab62321a4f481de34db33e6b Mon Sep 17 00:00:00 2001 From: Chris Withers Date: Wed, 26 Feb 2020 07:28:46 +0000 Subject: [PATCH 047/159] Allow Value to support getting attributes with internal names. --- mush/declarations.py | 10 +++++++++- mush/tests/test_declarations.py | 15 +++++++++++++++ 2 files changed, 24 insertions(+), 1 deletion(-) diff --git a/mush/declarations.py b/mush/declarations.py index d68996a..4101be6 100644 --- a/mush/declarations.py +++ b/mush/declarations.py @@ -68,10 +68,18 @@ class Value: def __init__(self, key: ResourceKey, *, default: Any = missing): self.requirement = Requirement(key, default=default) - def __getattr__(self, name): + def attr(self, name): + """ + If you need to get an attribute called either ``attr`` or ``item`` + then you will need to call this method instead of using the + generating behaviour. + """ self.requirement.ops.append(ValueAttrOp(name)) return self + def __getattr__(self, name): + return self.attr(name) + def __getitem__(self, name): self.requirement.ops.append(ValueItemOp(name)) return self diff --git a/mush/tests/test_declarations.py b/mush/tests/test_declarations.py index a2fe61e..3145fa9 100644 --- a/mush/tests/test_declarations.py +++ b/mush/tests/test_declarations.py @@ -100,6 +100,21 @@ def check_ops(value, data, *, expected): compare(expected, actual=data) +class TestValue: + + @pytest.mark.parametrize("name", ['attr', 'requirement']) + def test_attr_special_name(self, name): + v = Value('foo') + assert v.attr(name) is v + compare(v.requirement.ops, [ValueAttrOp(name)]) + + @pytest.mark.parametrize("name", ['attr', 'requirement']) + def test_item_special_name(self, name): + v = Value('foo') + assert v[name] is v + compare(v.requirement.ops, [ValueItemOp(name)]) + + class TestItem: def test_single(self): From 6cdc4ffd7e3b6f82717c11b3ac98e6bb931d5ca6 Mon Sep 17 00:00:00 2001 From: Chris Withers Date: Wed, 26 Feb 2020 07:54:24 +0000 Subject: [PATCH 048/159] Try not to ever mutate declarations that are passed in. --- mush/declarations.py | 10 ++++++++++ mush/extraction.py | 5 ++++- mush/tests/test_callpoints.py | 8 +++++--- mush/tests/test_declarations.py | 8 ++++++++ 4 files changed, 27 insertions(+), 4 deletions(-) diff --git a/mush/declarations.py b/mush/declarations.py index 4101be6..d766e72 100644 --- a/mush/declarations.py +++ b/mush/declarations.py @@ -1,3 +1,4 @@ +from copy import copy from enum import Enum, auto from itertools import chain from typing import Type, Callable, NewType, Union, Any, List, Optional @@ -37,6 +38,14 @@ def __init__(self, key, name=None, type_=None, default=missing, target=None): #: has been obtained. self.ops: List['ValueOp'] = [] + def clone(self): + """ + Create a copy of this requirement, so it can be mutated + """ + obj = copy(self) + obj.ops = list(self.ops) + return obj + def value_repr(self): key = name_or_repr(self.key) if self.ops or self.default is not missing: @@ -151,6 +160,7 @@ def requires(*args, **kw): if isinstance(possible, Value): possible = possible.requirement if isinstance(possible, Requirement): + possible = possible.clone() possible.target = target requirement = possible else: diff --git a/mush/extraction.py b/mush/extraction.py index 2f82ee5..05745d4 100644 --- a/mush/extraction.py +++ b/mush/extraction.py @@ -56,6 +56,7 @@ def extract_requires(obj: Callable, explicit=None): requirement = Requirement(key, default=default) if p.kind is p.KEYWORD_ONLY: + requirement = requirement.clone() requirement.target = p.name by_name[name] = requirement @@ -82,11 +83,13 @@ def extract_requires(obj: Callable, explicit=None): return nothing needs_target = False - for requirement in by_name.values(): + for name, requirement in by_name.items(): if requirement.target is not None: needs_target = True elif needs_target: + requirement = requirement.clone() requirement.target = requirement.name + by_name[name] = requirement return RequiresType(by_name.values()) diff --git a/mush/tests/test_callpoints.py b/mush/tests/test_callpoints.py index 18ddf29..e8f14de 100644 --- a/mush/tests/test_callpoints.py +++ b/mush/tests/test_callpoints.py @@ -1,7 +1,7 @@ from functools import update_wrapper from unittest import TestCase -from mock import Mock +from mock import Mock, call from testfixtures import compare from mush.callpoints import CallPoint @@ -27,7 +27,8 @@ def foo(a1): pass rt = returns('bar') result = CallPoint(foo, rq, rt)(self.context) compare(result, self.context.extract.return_value) - self.context.extract.assert_called_with(foo, rq, rt) + compare(tuple(self.context.extract.mock_calls[0].args), + expected=(foo, rq, rt)) def test_extract_from_decorations(self): rq = requires('foo') @@ -75,7 +76,8 @@ def foo(a1): pass result = CallPoint(foo, requires=rq, returns=rt)(self.context) compare(result, self.context.extract.return_value) - self.context.extract.assert_called_with(foo, rq, rt) + compare(tuple(self.context.extract.mock_calls[0].args), + expected=(foo, rq, rt)) def test_repr_minimal(self): def foo(): pass diff --git a/mush/tests/test_declarations.py b/mush/tests/test_declarations.py index 3145fa9..f98bd51 100644 --- a/mush/tests/test_declarations.py +++ b/mush/tests/test_declarations.py @@ -1,6 +1,7 @@ from functools import partial from unittest import TestCase +import pytest from mock import Mock from testfixtures import compare, ShouldRaise @@ -93,6 +94,13 @@ def test_repr_maximal(self): expected="Requirement(Value('foo', default=None).bar, " "name='n', type_='ty', target='ta')") + def test_clone(self): + r = Value('foo').bar.requirement + r_ = r.clone() + assert r_ is not r + assert r_.ops is not r.ops + compare(r_, expected=r) + def check_ops(value, data, *, expected): for op in value.requirement.ops: From ead652b0de199b87171e4b4804b49978ce9a5adb Mon Sep 17 00:00:00 2001 From: Chris Withers Date: Wed, 26 Feb 2020 08:17:14 +0000 Subject: [PATCH 049/159] extract type and default information even when a declaration is present --- mush/declarations.py | 7 +++++-- mush/extraction.py | 36 ++++++++++++++++++++++++--------- mush/tests/test_declarations.py | 11 +++++++--- 3 files changed, 39 insertions(+), 15 deletions(-) diff --git a/mush/declarations.py b/mush/declarations.py index d766e72..ba49ba7 100644 --- a/mush/declarations.py +++ b/mush/declarations.py @@ -72,10 +72,13 @@ class Value: A default may be specified, which will be used if the specified resource is not available. + + A type may also be explicitly specified, but you probably shouldn't + ever use this. """ - def __init__(self, key: ResourceKey, *, default: Any = missing): - self.requirement = Requirement(key, default=default) + def __init__(self, key: ResourceKey, *, type_: type = None, default: Any = missing): + self.requirement = Requirement(key, type_=type_, default=default) def attr(self, name): """ diff --git a/mush/extraction.py b/mush/extraction.py index 05745d4..0315a9f 100644 --- a/mush/extraction.py +++ b/mush/extraction.py @@ -4,7 +4,7 @@ update_wrapper as functools_update_wrapper, partial ) -from inspect import signature +from inspect import signature, Parameter from typing import Callable from .declarations import ( @@ -15,6 +15,10 @@ ) from .markers import missing +EMPTY = Parameter.empty +#: For these types, prefer the name instead of the type. +SIMPLE_TYPES = (str, int, dict, list) + def _unpack_requires(by_name, by_index, requires_): @@ -42,21 +46,33 @@ def extract_requires(obj: Callable, explicit=None): if is_partial and p.name in obj.keywords: continue - if isinstance(p.default, Requirement): - requirement = p.default - elif isinstance(p.default, Value): - requirement = p.default.requirement + type_ = p.annotation + default = p.default + if isinstance(default, Requirement): + requirement = default + default = EMPTY + elif isinstance(default, Value): + requirement = default.requirement + default = EMPTY elif isinstance(p.annotation, Requirement): requirement = p.annotation + type_ = requirement.type elif isinstance(p.annotation, Value): requirement = p.annotation.requirement + type_ = requirement.type else: - key = p.name if p.annotation is p.empty else p.annotation - default = missing if p.default is p.empty else p.default - requirement = Requirement(key, default=default) - + if not p.annotation is EMPTY: + key = p.annotation + else: + key = p.name + requirement = Requirement(key) + + requirement = requirement.clone() + if requirement.type is None and type_ is not EMPTY and isinstance(type_, type): + requirement.type = type_ + if requirement.default is missing and default is not EMPTY: + requirement.default = default if p.kind is p.KEYWORD_ONLY: - requirement = requirement.clone() requirement.target = p.name by_name[name] = requirement diff --git a/mush/tests/test_declarations.py b/mush/tests/test_declarations.py index f98bd51..a4f77b6 100644 --- a/mush/tests/test_declarations.py +++ b/mush/tests/test_declarations.py @@ -479,15 +479,20 @@ def foo(a: T): pass def test_type_plus_value(self): def foo(a: str = Value('b')): pass check_extract(foo, - expected_rq=RequiresType((Requirement('b', name='b'),)), + expected_rq=RequiresType((Requirement('b', name='b', type_=str),)), expected_rt=result_type) def test_type_plus_value_with_default(self): def foo(a: str = Value('b', default=1)): pass check_extract(foo, - expected_rq=RequiresType((Requirement('b', name='b', default=1),)), + expected_rq=RequiresType((Requirement('b', name='b', type_=str, default=1),)), expected_rt=result_type) + def test_value_annotation_plus_default(self): + def foo(a: Value('b', type_=str) = 1): pass + check_extract(foo, + expected_rq=RequiresType((Requirement('b', name='b', type_=str, default=1),)), + expected_rt=result_type) class TestDeclarationsFromMultipleSources: @@ -520,7 +525,7 @@ def foo(a: r2 = r3, b: str = r2, c = r3): check_extract(foo, expected_rq=RequiresType(( Requirement('a', target='a'), - Requirement('b', target='b'), + Requirement('b', target='b', type_=str), Requirement('c', target='c'), )), expected_rt=result_type) From 634879475601e586d4889654e3268d9be8e110dc Mon Sep 17 00:00:00 2001 From: Chris Withers Date: Wed, 26 Feb 2020 08:18:26 +0000 Subject: [PATCH 050/159] Use the parameter name in favour of the type for simple types. These simple types are unlikely to be useful. --- docs/use.txt | 3 ++- mush/extraction.py | 2 +- mush/tests/test_declarations.py | 7 +++++++ 3 files changed, 10 insertions(+), 2 deletions(-) diff --git a/docs/use.txt b/docs/use.txt index 0db95ea..06dea5f 100755 --- a/docs/use.txt +++ b/docs/use.txt @@ -229,8 +229,9 @@ available: .. code-block:: python - from mush import Runner + from mush import Runner, returns + @returns('name') def my_name_is(): return 'Slim Shady' diff --git a/mush/extraction.py b/mush/extraction.py index 0315a9f..84cb8f3 100644 --- a/mush/extraction.py +++ b/mush/extraction.py @@ -61,7 +61,7 @@ def extract_requires(obj: Callable, explicit=None): requirement = p.annotation.requirement type_ = requirement.type else: - if not p.annotation is EMPTY: + if not (p.annotation is EMPTY or p.annotation in SIMPLE_TYPES): key = p.annotation else: key = p.name diff --git a/mush/tests/test_declarations.py b/mush/tests/test_declarations.py index a4f77b6..e0da4fd 100644 --- a/mush/tests/test_declarations.py +++ b/mush/tests/test_declarations.py @@ -476,6 +476,13 @@ def foo(a: T): pass expected_rq=RequiresType((Requirement(T),)), expected_rt=result_type) + @pytest.mark.parametrize("type_", [str, int, dict, list]) + def test_simple_type_only(self, type_): + def foo(a: type_): pass + check_extract(foo, + expected_rq=RequiresType((Requirement('a', type_=type_),)), + expected_rt=result_type) + def test_type_plus_value(self): def foo(a: str = Value('b')): pass check_extract(foo, From 123482c85d19601dc59f15d5752889f2b1fcdf10 Mon Sep 17 00:00:00 2001 From: Chris Withers Date: Wed, 26 Feb 2020 08:52:57 +0000 Subject: [PATCH 051/159] remove unused code path --- mush/extraction.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/mush/extraction.py b/mush/extraction.py index 84cb8f3..d7d749f 100644 --- a/mush/extraction.py +++ b/mush/extraction.py @@ -87,12 +87,9 @@ def extract_requires(obj: Callable, explicit=None): # explicit if explicit is not None: - if isinstance(explicit, (list, tuple)): - requires_ = requires(*explicit) - elif not isinstance(explicit, RequiresType): - requires_ = requires(explicit) - else: - requires_ = explicit + if not isinstance(explicit, (list, tuple)): + explicit = (explicit,) + requires_ = requires(*explicit) _unpack_requires(by_name, by_index, requires_) if not by_name: From 4c3a0b11c7240fe701b19fb7663c6cf18e03c502 Mon Sep 17 00:00:00 2001 From: Chris Withers Date: Thu, 27 Feb 2020 07:27:41 +0000 Subject: [PATCH 052/159] Don't do generative getattr with special names. --- mush/declarations.py | 2 ++ mush/tests/test_declarations.py | 6 ++++++ 2 files changed, 8 insertions(+) diff --git a/mush/declarations.py b/mush/declarations.py index ba49ba7..65f1015 100644 --- a/mush/declarations.py +++ b/mush/declarations.py @@ -90,6 +90,8 @@ def attr(self, name): return self def __getattr__(self, name): + if name.startswith('__'): + raise AttributeError(name) return self.attr(name) def __getitem__(self, name): diff --git a/mush/tests/test_declarations.py b/mush/tests/test_declarations.py index e0da4fd..bbe92c6 100644 --- a/mush/tests/test_declarations.py +++ b/mush/tests/test_declarations.py @@ -122,6 +122,12 @@ def test_item_special_name(self, name): assert v[name] is v compare(v.requirement.ops, [ValueItemOp(name)]) + def test_no_special_name_via_getattr(self): + v = Value('foo') + with ShouldRaise(AttributeError): + assert v.__len__ + compare(v.requirement.ops, []) + class TestItem: From f92c7aaf7a0cc3b80fdd302d051331bfbfaf1269 Mon Sep 17 00:00:00 2001 From: Chris Withers Date: Thu, 27 Feb 2020 08:59:45 +0000 Subject: [PATCH 053/159] Rework declaration extraction to layer all sources. Also document Requirement attributes and lose the inference in the constructor. --- mush/declarations.py | 19 +++-- mush/extraction.py | 81 ++++++++++++------- mush/tests/test_callpoints.py | 26 +++--- mush/tests/test_declarations.py | 138 +++++++++++++++++++++----------- 4 files changed, 175 insertions(+), 89 deletions(-) diff --git a/mush/declarations.py b/mush/declarations.py index 65f1015..3df287a 100644 --- a/mush/declarations.py +++ b/mush/declarations.py @@ -29,14 +29,18 @@ class Requirement: resolve: RequirementResolver = None def __init__(self, key, name=None, type_=None, default=missing, target=None): + #: The resource key needed for this parameter. self.key: ResourceKey = key - self.name: str = (key if isinstance(key, str) else None) if name is None else name - self.type: type = (key if not isinstance(key, str) else None) if type_ is None else type_ - self.target: Optional[str] = target + #: The name of this parameter in the callable's signature. + self.name: str = name + #: The type required for this parameter. + self.type: type = type_ + #: The default for this parameter, should the required resource be unavailable. self.default: Any = default #: Any operations to be performed on the resource after it #: has been obtained. self.ops: List['ValueOp'] = [] + self.target: Optional[str] = target def clone(self): """ @@ -77,7 +81,11 @@ class Value: ever use this. """ - def __init__(self, key: ResourceKey, *, type_: type = None, default: Any = missing): + def __init__(self, key: ResourceKey=None, *, type_: type = None, default: Any = missing): + if isinstance(key, type): + if type_ is not None: + raise TypeError('type_ cannot be specified if key is a type') + type_ = key self.requirement = Requirement(key, type_=type_, default=default) def attr(self, name): @@ -169,7 +177,8 @@ def requires(*args, **kw): possible.target = target requirement = possible else: - requirement = Requirement(possible, target=target) + type_ = None if isinstance(possible, str) else possible + requirement = Requirement(possible, name=target, type_=type_, target=target) requires_.append(requirement) return requires_ diff --git a/mush/extraction.py b/mush/extraction.py index d7d749f..9ea4613 100644 --- a/mush/extraction.py +++ b/mush/extraction.py @@ -20,18 +20,25 @@ SIMPLE_TYPES = (str, int, dict, list) -def _unpack_requires(by_name, by_index, requires_): +def _apply_requires(by_name, by_index, requires_): - for i, requirement in enumerate(requires_): - if requirement.target is None: + for i, r in enumerate(requires_): + if r.target is None: try: - arg = by_index[i] + name = by_index[i] except IndexError: # case where something takes *args - arg = i + by_name[i] = r.clone() + continue else: - arg = requirement.target - by_name[arg] = requirement + name = r.target + + existing = by_name[name] + existing.key = existing.key if r.key is None else r.key + existing.type = existing.type if r.type is None else r.type + existing.default = existing.default if r.default is missing else r.default + existing.ops = existing.ops if not r.ops else r.ops + existing.target = existing.target if r.target is None else r.target def extract_requires(obj: Callable, explicit=None): @@ -46,34 +53,56 @@ def extract_requires(obj: Callable, explicit=None): if is_partial and p.name in obj.keywords: continue - type_ = p.annotation - default = p.default + name = p.name + if isinstance(p.annotation, type) and not p.annotation is EMPTY: + type_ = p.annotation + else: + type_ = None + key = None + default = missing if p.default is EMPTY else p.default + ops = [] + + requirement = None if isinstance(default, Requirement): requirement = default - default = EMPTY + default = missing elif isinstance(default, Value): requirement = default.requirement - default = EMPTY + default = missing elif isinstance(p.annotation, Requirement): requirement = p.annotation - type_ = requirement.type elif isinstance(p.annotation, Value): requirement = p.annotation.requirement - type_ = requirement.type - else: - if not (p.annotation is EMPTY or p.annotation in SIMPLE_TYPES): + + if requirement is None: + requirement = Requirement(key) + if isinstance(p.annotation, str): key = p.annotation + elif type_ is None or issubclass(type_, SIMPLE_TYPES): + key = name else: - key = p.name - requirement = Requirement(key) + key = type_ + else: + requirement = requirement.clone() + type_ = type_ if requirement.type is None else requirement.type + if requirement.key is not None: + key = requirement.key + elif type_ is None or issubclass(type_, SIMPLE_TYPES): + key = name + else: + key = type_ + default = requirement.default if requirement.default is not missing else default + ops = requirement.ops + + requirement.key = key + requirement.name = name + requirement.type = type_ + requirement.default = default + requirement.ops = ops - requirement = requirement.clone() - if requirement.type is None and type_ is not EMPTY and isinstance(type_, type): - requirement.type = type_ - if requirement.default is missing and default is not EMPTY: - requirement.default = default if p.kind is p.KEYWORD_ONLY: requirement.target = p.name + by_name[name] = requirement by_index = list(by_name) @@ -83,26 +112,24 @@ def extract_requires(obj: Callable, explicit=None): if mush_declarations is not None: requires_ = mush_declarations.get('requires') if requires_ is not None: - _unpack_requires(by_name, by_index, requires_) + _apply_requires(by_name, by_index, requires_) # explicit if explicit is not None: if not isinstance(explicit, (list, tuple)): explicit = (explicit,) requires_ = requires(*explicit) - _unpack_requires(by_name, by_index, requires_) + _apply_requires(by_name, by_index, requires_) if not by_name: return nothing needs_target = False - for name, requirement in by_name.items(): + for requirement in by_name.values(): if requirement.target is not None: needs_target = True elif needs_target: - requirement = requirement.clone() requirement.target = requirement.name - by_name[name] = requirement return RequiresType(by_name.values()) diff --git a/mush/tests/test_callpoints.py b/mush/tests/test_callpoints.py index e8f14de..fdd4a91 100644 --- a/mush/tests/test_callpoints.py +++ b/mush/tests/test_callpoints.py @@ -1,11 +1,11 @@ from functools import update_wrapper from unittest import TestCase -from mock import Mock, call +from mock import Mock from testfixtures import compare from mush.callpoints import CallPoint -from mush.declarations import requires, returns, RequiresType +from mush.declarations import requires, returns, RequiresType, Requirement from mush.extraction import update_wrapper @@ -28,7 +28,9 @@ def foo(a1): pass result = CallPoint(foo, rq, rt)(self.context) compare(result, self.context.extract.return_value) compare(tuple(self.context.extract.mock_calls[0].args), - expected=(foo, rq, rt)) + expected=(foo, + RequiresType((Requirement('foo', name='a1'),)), + rt)) def test_extract_from_decorations(self): rq = requires('foo') @@ -40,7 +42,10 @@ def foo(a1): pass result = CallPoint(foo)(self.context) compare(result, self.context.extract.return_value) - self.context.extract.assert_called_with(foo, rq, rt) + compare(tuple(self.context.extract.mock_calls[0].args), + expected=(foo, + RequiresType((Requirement('foo', name='a1'),)), + returns('bar'))) def test_extract_from_decorated_class(self): @@ -64,20 +69,21 @@ def foo(prefix): self.context.extract.side_effect = lambda func, rq, rt: (func(), rq, rt) result = CallPoint(foo)(self.context) - compare(result, expected=('the answer', rq, rt)) + compare(result, expected=('the answer', + RequiresType((Requirement('foo', name='prefix'),)), + rt)) def test_explicit_trumps_decorators(self): @requires('foo') @returns('bar') def foo(a1): pass - rq = requires('baz') - rt = returns('bob') - - result = CallPoint(foo, requires=rq, returns=rt)(self.context) + result = CallPoint(foo, requires('baz'), returns('bob'))(self.context) compare(result, self.context.extract.return_value) compare(tuple(self.context.extract.mock_calls[0].args), - expected=(foo, rq, rt)) + expected=(foo, + RequiresType((Requirement('baz', name='a1'),)), + returns('bob'))) def test_repr_minimal(self): def foo(): pass diff --git a/mush/tests/test_declarations.py b/mush/tests/test_declarations.py index bbe92c6..bd406a6 100644 --- a/mush/tests/test_declarations.py +++ b/mush/tests/test_declarations.py @@ -44,10 +44,10 @@ def test_types(self): r = requires(Type1, Type2, x=Type3, y=Type4) compare(repr(r), 'requires(Type1, Type2, x=Type3, y=Type4)') compare(r, expected=[ - Requirement(Type1), - Requirement(Type2), - Requirement(Type3, target='x'), - Requirement(Type4, target='y'), + Requirement(Type1, type_=Type1), + Requirement(Type2, type_=Type2), + Requirement(Type3, name='x', type_=Type3, target='x'), + Requirement(Type4, name='y', type_=Type4, target='y'), ]) def test_strings(self): @@ -56,8 +56,8 @@ def test_strings(self): compare(r, expected=[ Requirement('1'), Requirement('2'), - Requirement('3', target='x'), - Requirement('4', target='y'), + Requirement('3', name='x', target='x'), + Requirement('4', name='y', target='y'), ]) def test_tuple_arg(self): @@ -73,19 +73,15 @@ def test_decorator_paranoid(self): def foo(): return 'bar' - compare(foo.__mush__['requires'], expected=[Requirement(Type1)]) + compare(foo.__mush__['requires'], expected=[Requirement(Type1, type_=Type1)]) compare(foo(), 'bar') class TestRequirement: - def test_repr_minimal_name(self): + def test_repr_minimal(self): compare(repr(Requirement('foo')), - expected="Requirement('foo', name='foo')") - - def test_repr_minimal_type(self): - compare(repr(Requirement(str)), - expected="Requirement(str, type_=)") + expected="Requirement('foo')") def test_repr_maximal(self): r = Requirement('foo', name='n', type_='ty', default=None, target='ta') @@ -128,6 +124,14 @@ def test_no_special_name_via_getattr(self): assert v.__len__ compare(v.requirement.ops, []) + def test_type_from_key(self): + v = Value(str) + compare(v.requirement.type, expected=str) + + def test_key_and_type_cannot_disagree(self): + with ShouldRaise(TypeError('type_ cannot be specified if key is a type')): + Value(key=str, type_=int) + class TestItem: @@ -277,8 +281,8 @@ def test_default_requirements_for_function(self): def foo(a, b=None): pass check_extract(foo, expected_rq=RequiresType(( - Requirement('a'), - Requirement('b', default=None) + Requirement('a', name='a'), + Requirement('b', name='b', default=None) )), expected_rt=result_type) @@ -287,8 +291,8 @@ class MyClass(object): def __init__(self, a, b=None): pass check_extract(MyClass, expected_rq=RequiresType(( - Requirement('a'), - Requirement('b', default=None) + Requirement('a', name='a'), + Requirement('b', name='b', default=None) )), expected_rt=result_type) @@ -298,8 +302,8 @@ def foo(x, y, z, a=None): pass check_extract( p, expected_rq=RequiresType(( - Requirement('z', target='z'), - Requirement('a', target='a', default=None) + Requirement('z', name='z', target='z'), + Requirement('a', name='a', target='a', default=None) )), expected_rt=result_type ) @@ -310,7 +314,7 @@ def foo(a=None): pass check_extract( p, expected_rq=RequiresType(( - Requirement('a', default=None), + Requirement('a', name='a', default=None), )), expected_rt=result_type ) @@ -359,8 +363,8 @@ def foo(b, a=None): pass check_extract( p, expected_rq=RequiresType(( - Requirement('b'), - Requirement('a', default=None) + Requirement('b', name='b'), + Requirement('a', name='a', default=None) )), expected_rt=result_type ) @@ -372,7 +376,7 @@ def foo(b, a): pass p, # since b is already bound: expected_rq=RequiresType(( - Requirement('a'), + Requirement('a', name='a'), )), expected_rt=result_type ) @@ -383,7 +387,7 @@ def foo(b, a): pass check_extract( p, expected_rq=RequiresType(( - Requirement('b'), + Requirement('b', name='b'), )), expected_rt=result_type ) @@ -395,17 +399,17 @@ def test_extract_from_annotations(self): def foo(a: 'foo', b, c: 'bar' = 1, d=2) -> 'bar': pass check_extract(foo, expected_rq=RequiresType(( - Requirement('foo'), - Requirement('b'), - Requirement('bar', default=1), - Requirement('d', default=2) + Requirement('foo', name='a'), + Requirement('b', name='b'), + Requirement('bar', name='c', default=1), + Requirement('d', name='d', default=2) )), expected_rt=returns('bar')) def test_requires_only(self): def foo(a: 'foo'): pass check_extract(foo, - expected_rq=RequiresType((Requirement('foo'),)), + expected_rq=RequiresType((Requirement('foo', name='a'),)), expected_rt=result_type) def test_returns_only(self): @@ -431,7 +435,7 @@ def foo(a: 'foo' = None) -> 'bar': compare(foo(), expected='the answer') check_extract(foo, - expected_rq=RequiresType((Requirement('foo', default=None),)), + expected_rq=RequiresType((Requirement('foo', name='a', default=None),)), expected_rt=returns('bar')) def test_decorator_trumps_annotations(self): @@ -439,7 +443,7 @@ def test_decorator_trumps_annotations(self): @returns('bar') def foo(a: 'x') -> 'y': pass check_extract(foo, - expected_rq=RequiresType((Requirement('foo'),)), + expected_rq=RequiresType((Requirement('foo', name='a'),)), expected_rt=returns('bar')) def test_returns_mapping(self): @@ -458,7 +462,7 @@ def foo() -> rt: pass def test_how_instance_in_annotations(self): def foo(a: Value('config')['db_url']): pass - requirement = Requirement('config') + requirement = Requirement('config', name='a') requirement.ops.append(ValueItemOp('db_url')) check_extract(foo, expected_rq=RequiresType((requirement,)), @@ -468,10 +472,10 @@ def test_default_requirements(self): def foo(a, b=1, *, c, d=None): pass check_extract(foo, expected_rq=RequiresType(( - Requirement('a'), - Requirement('b', default=1), - Requirement('c', target='c'), - Requirement('d', target='d', default=None) + Requirement('a', name='a'), + Requirement('b', name='b', default=1), + Requirement('c', name='c', target='c'), + Requirement('d', name='d', target='d', default=None) )), expected_rt=result_type) @@ -479,34 +483,74 @@ def test_type_only(self): class T: pass def foo(a: T): pass check_extract(foo, - expected_rq=RequiresType((Requirement(T),)), + expected_rq=RequiresType((Requirement(T, name='a', type_=T),)), expected_rt=result_type) @pytest.mark.parametrize("type_", [str, int, dict, list]) def test_simple_type_only(self, type_): def foo(a: type_): pass check_extract(foo, - expected_rq=RequiresType((Requirement('a', type_=type_),)), + expected_rq=RequiresType((Requirement('a', name='a', type_=type_),)), expected_rt=result_type) def test_type_plus_value(self): def foo(a: str = Value('b')): pass check_extract(foo, - expected_rq=RequiresType((Requirement('b', name='b', type_=str),)), + expected_rq=RequiresType((Requirement('b', name='a', type_=str),)), expected_rt=result_type) def test_type_plus_value_with_default(self): def foo(a: str = Value('b', default=1)): pass check_extract(foo, - expected_rq=RequiresType((Requirement('b', name='b', type_=str, default=1),)), + expected_rq=RequiresType(( + Requirement('b', name='a', type_=str, default=1), + )), expected_rt=result_type) def test_value_annotation_plus_default(self): def foo(a: Value('b', type_=str) = 1): pass check_extract(foo, - expected_rq=RequiresType((Requirement('b', name='b', type_=str, default=1),)), + expected_rq=RequiresType(( + Requirement('b', name='a', type_=str, default=1), + )), + expected_rt=result_type) + + def test_value_annotation_just_type_in_value_key_plus_default(self): + def foo(a: Value(str) = 1): pass + check_extract(foo, + expected_rq=RequiresType(( + Requirement(key=str, name='a', type_=str, default=1), + )), expected_rt=result_type) + def test_value_annotation_just_type_plus_default(self): + def foo(a: Value(type_=str) = 1): pass + check_extract(foo, + expected_rq=RequiresType(( + Requirement(key='a', name='a', type_=str, default=1), + )), + expected_rt=result_type) + + def test_value_unspecified_with_type(self): + class T1: pass + def foo(a: T1 = Value()): pass + check_extract(foo, + expected_rq=RequiresType((Requirement(key=T1, name='a', type_=T1),)), + expected_rt=result_type) + + def test_value_unspecified_with_simple_type(self): + def foo(a: str = Value()): pass + check_extract(foo, + expected_rq=RequiresType((Requirement(key='a', name='a', type_=str),)), + expected_rt=result_type) + + def test_value_unspecified(self): + def foo(a = Value()): pass + check_extract(foo, + expected_rq=RequiresType((Requirement(key='a', name='a'),)), + expected_rt=result_type) + + class TestDeclarationsFromMultipleSources: def test_declarations_from_different_sources(self): @@ -520,9 +564,9 @@ def foo(a: r1, b, c=r3): check_extract(foo, expected_rq=RequiresType(( - Requirement('a'), - Requirement('b', target='b'), - Requirement('c', target='c'), + Requirement('a', name='a'), + Requirement('b', name='b', target='b'), + Requirement('c', name='c', target='c'), )), expected_rt=result_type) @@ -537,8 +581,8 @@ def foo(a: r2 = r3, b: str = r2, c = r3): check_extract(foo, expected_rq=RequiresType(( - Requirement('a', target='a'), - Requirement('b', target='b', type_=str), - Requirement('c', target='c'), + Requirement('a', name='a', target='a'), + Requirement('b', name='b', target='b', type_=str), + Requirement('c', name='c', target='c'), )), expected_rt=result_type) From 7f647f7a431182eca74e154078518716ff76f797 Mon Sep 17 00:00:00 2001 From: Chris Withers Date: Thu, 27 Feb 2020 10:07:45 +0000 Subject: [PATCH 054/159] Add support for overriding the default type used for individual parameter requirements. --- mush/asyncio.py | 7 ++++--- mush/context.py | 20 +++++++++++--------- mush/extraction.py | 8 +++++--- mush/tests/test_async_context.py | 16 ++++++++++++++++ mush/tests/test_context.py | 26 ++++++++++++++++++++++++++ mush/tests/test_declarations.py | 20 ++++++++++++++++++++ 6 files changed, 82 insertions(+), 15 deletions(-) diff --git a/mush/asyncio.py b/mush/asyncio.py index 08ac0da..5dbd995 100644 --- a/mush/asyncio.py +++ b/mush/asyncio.py @@ -1,8 +1,9 @@ import asyncio from functools import partial +from typing import Type from mush import Context -from mush.declarations import ResourceKey +from mush.declarations import ResourceKey, Requirement async def ensure_async(func, *args, **kw): @@ -28,8 +29,8 @@ def get(self, key: ResourceKey, default=None): class AsyncContext(Context): - def __init__(self): - super().__init__() + def __init__(self, default_requirement_type: Type[Requirement] = Requirement): + super().__init__(default_requirement_type) self._sync_context = SyncContext(self, asyncio.get_event_loop()) def _context_for(self, obj): diff --git a/mush/context.py b/mush/context.py index 02e449d..13045bb 100644 --- a/mush/context.py +++ b/mush/context.py @@ -1,8 +1,8 @@ -from typing import Optional +from typing import Optional, Type from .declarations import ( - RequiresType, ResourceKey, ResourceValue, ResourceResolver -) + RequiresType, ResourceKey, ResourceValue, ResourceResolver, + Requirement) from .extraction import extract_requires from .markers import missing @@ -78,7 +78,8 @@ class Context: _parent = None - def __init__(self): + def __init__(self, default_requirement_type: Type[Requirement] = Requirement): + self.default_requirement_type = default_requirement_type self._store = {} def add(self, @@ -127,11 +128,10 @@ def extract(self, obj, requires, returns): self.add(obj, type) return result - @staticmethod - def _resolve(obj, requires, args, kw, context): + def _resolve(self, obj, requires, args, kw, context): if requires.__class__ is not RequiresType: - requires = extract_requires(obj, requires) + requires = extract_requires(obj, requires, self.default_requirement_type) for requirement in requires: o = yield requirement @@ -191,7 +191,9 @@ def get(self, key: ResourceKey, default=None): return resolvable.resolver(self, default) return resolvable.value - def nest(self): - nested = type(self)() + def nest(self, default_requirement_type: Type[Requirement] = None): + if default_requirement_type is None: + default_requirement_type = self.default_requirement_type + nested = self.__class__(default_requirement_type) nested._parent = self return nested diff --git a/mush/extraction.py b/mush/extraction.py index 9ea4613..998f5ee 100644 --- a/mush/extraction.py +++ b/mush/extraction.py @@ -5,7 +5,7 @@ partial ) from inspect import signature, Parameter -from typing import Callable +from typing import Callable, Type from .declarations import ( Value, @@ -41,7 +41,9 @@ def _apply_requires(by_name, by_index, requires_): existing.target = existing.target if r.target is None else r.target -def extract_requires(obj: Callable, explicit=None): +def extract_requires(obj: Callable, + explicit: RequiresType=None, + default_requirement_type: Type[Requirement] = Requirement): # from annotations is_partial = isinstance(obj, partial) by_name = {} @@ -75,7 +77,7 @@ def extract_requires(obj: Callable, explicit=None): requirement = p.annotation.requirement if requirement is None: - requirement = Requirement(key) + requirement = default_requirement_type(key) if isinstance(p.annotation, str): key = p.annotation elif type_ is None or issubclass(type_, SIMPLE_TYPES): diff --git a/mush/tests/test_async_context.py b/mush/tests/test_async_context.py index 6b337a7..7f0afdf 100644 --- a/mush/tests/test_async_context.py +++ b/mush/tests/test_async_context.py @@ -177,3 +177,19 @@ def foo(bar: FromRequest('bar')): context = AsyncContext() context.add({'bar': 'foo'}, provides='request') compare(await context.call(foo), expected='foo') + + +@pytest.mark.asyncio +async def test_default_custom_requirement(): + + + class FromRequest(Requirement): + async def resolve(self, context): + return (await context.get('request'))[self.key] + + def foo(bar): + return bar + + context = AsyncContext(FromRequest) + context.add({'bar': 'foo'}, provides='request') + compare(await context.call(foo), expected='foo') diff --git a/mush/tests/test_context.py b/mush/tests/test_context.py index cdf3f4b..dc9be10 100644 --- a/mush/tests/test_context.py +++ b/mush/tests/test_context.py @@ -354,6 +354,19 @@ def test_nest(self): compare(c1.get('b'), expected=None) compare(c1.get('c'), expected='c') + def test_nest_with_overridden_default_requirement_type(self): + class FromRequest(Requirement): pass + c1 = Context(default_requirement_type=FromRequest) + c2 = c1.nest() + assert c2.default_requirement_type is FromRequest + + def test_nest_with_explicit_default_requirement_type(self): + class Requirement1(Requirement): pass + class Requirement2(Requirement): pass + c1 = Context(default_requirement_type=Requirement1) + c2 = c1.nest(default_requirement_type=Requirement2) + assert c2.default_requirement_type is Requirement2 + def test_custom_requirement(self): class FromRequest(Requirement): @@ -380,3 +393,16 @@ def foo(bar: FromRequest('bar')): context.add({}, provides='request') with ShouldRaise(ContextError("No 'bar' in context")): compare(context.call(foo)) + + def test_default_custom_requirement(self): + + class FromRequest(Requirement): + def resolve(self, context): + return context.get('request')[self.key] + + def foo(bar): + return bar + + context = Context(default_requirement_type=FromRequest) + context.add({'bar': 'foo'}, provides='request') + compare(context.call(foo), expected='foo') diff --git a/mush/tests/test_declarations.py b/mush/tests/test_declarations.py index bd406a6..162c73c 100644 --- a/mush/tests/test_declarations.py +++ b/mush/tests/test_declarations.py @@ -550,6 +550,26 @@ def foo(a = Value()): pass expected_rq=RequiresType((Requirement(key='a', name='a'),)), expected_rt=result_type) + def test_default_requirement_type(self): + def foo(x: str = None): pass + + class FromRequest(Requirement): pass + + rq = extract_requires(foo, default_requirement_type=FromRequest) + compare(rq, strict=True, expected=RequiresType(( + FromRequest(key='x', name='x', type_=str, default=None), + ))) + + def test_default_requirement_not_used(self): + def foo(x: str = Value(default=None)): pass + + class FromRequest(Requirement): pass + + rq = extract_requires(foo, default_requirement_type=FromRequest) + compare(rq, strict=True, expected=RequiresType(( + Requirement(key='x', name='x', type_=str, default=None), + ))) + class TestDeclarationsFromMultipleSources: From 9c6d2ed112d8587fd4804861ed67bb2a35be76f0 Mon Sep 17 00:00:00 2001 From: Chris Withers Date: Thu, 27 Feb 2020 11:36:51 +0000 Subject: [PATCH 055/159] Add more stuff to the public api. --- mush/__init__.py | 6 ++++-- mush/markers.py | 2 ++ 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/mush/__init__.py b/mush/__init__.py index 587b848..69cf5c6 100755 --- a/mush/__init__.py +++ b/mush/__init__.py @@ -2,9 +2,10 @@ from .declarations import ( requires, returns_result_type, returns_mapping, returns_sequence, returns, - Value, nothing + Value, Requirement, nothing ) from .extraction import extract_requires, extract_returns, update_wrapper +from .markers import missing from .plug import Plug from .context import Context, ContextError from .asyncio import AsyncContext @@ -14,7 +15,8 @@ 'Runner', 'requires', 'returns_result_type', 'returns_mapping', 'returns_sequence', 'returns', - 'Value', + 'Value', 'Requirement', 'Plug', 'nothing', 'update_wrapper', + 'missing', ] diff --git a/mush/markers.py b/mush/markers.py index 77d71e7..0b87c8d 100644 --- a/mush/markers.py +++ b/mush/markers.py @@ -8,4 +8,6 @@ def __repr__(self): not_specified = Marker('not_specified') + +#: A sentinel object to indicate that a value is missing. missing = Marker('missing') From 693bf8e11c68098f679e1958cfe056e99006ce03 Mon Sep 17 00:00:00 2001 From: Chris Withers Date: Fri, 28 Feb 2020 07:28:47 +0000 Subject: [PATCH 056/159] introduce support for using types as resource keys. --- mush/declarations.py | 17 ++++++++++++----- mush/tests/test_context.py | 11 ++++++++++- mush/tests/test_declarations.py | 17 ++++++++++++++--- 3 files changed, 36 insertions(+), 9 deletions(-) diff --git a/mush/declarations.py b/mush/declarations.py index 3df287a..fd393a5 100644 --- a/mush/declarations.py +++ b/mush/declarations.py @@ -1,7 +1,7 @@ from copy import copy from enum import Enum, auto from itertools import chain -from typing import Type, Callable, NewType, Union, Any, List, Optional +from typing import Type, Callable, NewType, Union, Any, List, Optional, _type_check from .markers import missing @@ -291,7 +291,14 @@ class DeclarationsFrom(Enum): def valid_decoration_types(*objs): for obj in objs: - if not isinstance(obj, VALID_DECORATION_TYPES): - raise TypeError( - repr(obj)+" is not a type or label" - ) + if isinstance(obj, VALID_DECORATION_TYPES): + continue + try: + _type_check(obj, '') + except TypeError: + pass + else: + continue + raise TypeError( + repr(obj)+" is not a valid decoration type" + ) diff --git a/mush/tests/test_context.py b/mush/tests/test_context.py index dc9be10..a6b12ba 100644 --- a/mush/tests/test_context.py +++ b/mush/tests/test_context.py @@ -161,7 +161,7 @@ def foo(obj): return obj with ShouldRaise(TypeError( "(, " ") " - "is not a type or label" + "is not a valid decoration type" )): context.call(foo, requires((TheType, TheType))) @@ -336,6 +336,15 @@ def test_get_present(self): context.add('bar', provides='foo') compare(context.get('foo'), expected='bar') + def test_get_type(self): + context = Context() + context.add(['bar'], provides=List[str]) + compare(context.get(List[str]), expected=['bar']) + compare(context.get(List[int]), expected=None) + compare(context.get(List), expected=None) + # nb: this might be surprising: + compare(context.get(list), expected=None) + def test_get_missing(self): context = Context() compare(context.get('foo'), expected=None) diff --git a/mush/tests/test_declarations.py b/mush/tests/test_declarations.py index 162c73c..0ebe386 100644 --- a/mush/tests/test_declarations.py +++ b/mush/tests/test_declarations.py @@ -1,4 +1,5 @@ from functools import partial +from typing import Tuple from unittest import TestCase import pytest @@ -60,12 +61,17 @@ def test_strings(self): Requirement('4', name='y', target='y'), ]) + def test_typing(self): + r = requires(Tuple[str]) + compare(repr(r), "requires(typing.Tuple[str])") + compare(r, expected=[Requirement(Tuple[str], type_=Tuple[str])]) + def test_tuple_arg(self): - with ShouldRaise(TypeError("('1', '2') is not a type or label")): + with ShouldRaise(TypeError("('1', '2') is not a valid decoration type")): requires(('1', '2')) def test_tuple_kw(self): - with ShouldRaise(TypeError("('1', '2') is not a type or label")): + with ShouldRaise(TypeError("('1', '2') is not a valid decoration type")): requires(foo=('1', '2')) def test_decorator_paranoid(self): @@ -204,6 +210,11 @@ def test_string(self): compare(repr(r), "returns('bar')") compare(dict(r.process('foo')), {'bar': 'foo'}) + def test_typing(self): + r = returns(Tuple[str]) + compare(repr(r), 'returns(typing.Tuple[str])') + compare(dict(r.process('foo')), {Tuple[str]: 'foo'}) + def test_sequence(self): r = returns(Type1, 'bar') compare(repr(r), "returns(Type1, 'bar')") @@ -220,7 +231,7 @@ def foo(): def test_bad_type(self): with ShouldRaise(TypeError( - '[] is not a type or label' + '[] is not a valid decoration type' )): @returns([]) def foo(): pass From 71bb958b600d978352498a88ced8aaab4d16c762 Mon Sep 17 00:00:00 2001 From: Chris Withers Date: Fri, 28 Feb 2020 08:07:23 +0000 Subject: [PATCH 057/159] Aggressively cache the results of extracting declarations. It's an expensive process, and there's the mush=False option if it can't be done. This commit also types up and makes more flexible the Context.extract() method. --- mush/asyncio.py | 17 ++++++---- mush/context.py | 40 +++++++++++++++++------ mush/tests/test_async_context.py | 51 ++++++++++++++++++++++++++++- mush/tests/test_context.py | 55 +++++++++++++++++++++++++++++++- mush/tests/test_declarations.py | 4 +-- 5 files changed, 146 insertions(+), 21 deletions(-) diff --git a/mush/asyncio.py b/mush/asyncio.py index 5dbd995..bf99ed5 100644 --- a/mush/asyncio.py +++ b/mush/asyncio.py @@ -1,9 +1,9 @@ import asyncio from functools import partial -from typing import Type +from typing import Type, Callable from mush import Context -from mush.declarations import ResourceKey, Requirement +from mush.declarations import ResourceKey, Requirement, RequiresType, ReturnsType async def ensure_async(func, *args, **kw): @@ -43,10 +43,10 @@ async def get(self, key: ResourceKey, default=None): return await ensure_async(r, self._context_for(r), default) return resolvable.value - async def call(self, obj, requires=None): + async def call(self, obj: Callable, requires: RequiresType = None, *, mush: bool = True): args = [] kw = {} - resolving = self._resolve(obj, requires, args, kw, self._context_for(obj)) + resolving = self._resolve(obj, requires, args, kw, self._context_for(obj), mush) for requirement in resolving: r = requirement.resolve if r is not None: @@ -56,8 +56,11 @@ async def call(self, obj, requires=None): resolving.send(o) return await ensure_async(obj, *args, **kw) - async def extract(self, obj, requires, returns): + async def extract(self, + obj: Callable, + requires: RequiresType = None, + returns: ReturnsType = None, + mush: bool = True): result = await self.call(obj, requires) - for type, obj in returns.process(result): - self.add(obj, type) + self._process(obj, result, returns, mush) return result diff --git a/mush/context.py b/mush/context.py index 13045bb..bf2d259 100644 --- a/mush/context.py +++ b/mush/context.py @@ -1,9 +1,9 @@ -from typing import Optional, Type +from typing import Optional, Type, Callable from .declarations import ( RequiresType, ResourceKey, ResourceValue, ResourceResolver, - Requirement) -from .extraction import extract_requires + Requirement, set_mush, ReturnsType) +from .extraction import extract_requires, extract_returns from .markers import missing NONE_TYPE = type(None) @@ -122,16 +122,36 @@ def __repr__(self): bits.append('\n') return '' % ''.join(bits) - def extract(self, obj, requires, returns): - result = self.call(obj, requires) + def _process(self, obj, result, returns, mush): + if returns is None: + returns = getattr(obj, '__mush__', {}).get('returns_final') + if returns is None: + returns = extract_returns(obj, explicit=None) + if mush: + set_mush(obj, 'returns_final', returns) + for type, obj in returns.process(result): self.add(obj, type) + + def extract(self, + obj: Callable, + requires: RequiresType = None, + returns: ReturnsType = None, + mush: bool = True): + result = self.call(obj, requires) + self._process(obj, result, returns, mush) return result - def _resolve(self, obj, requires, args, kw, context): + def _resolve(self, obj, requires, args, kw, context, mush): - if requires.__class__ is not RequiresType: - requires = extract_requires(obj, requires, self.default_requirement_type) + if requires is None: + requires = getattr(obj, '__mush__', {}).get('requires_final') + if requires is None: + requires = extract_requires(obj, + explicit=None, + default_requirement_type=self.default_requirement_type) + if mush: + set_mush(obj, 'requires_final', requires) for requirement in requires: o = yield requirement @@ -157,10 +177,10 @@ def _resolve(self, obj, requires, args, kw, context): yield - def call(self, obj, requires=None): + def call(self, obj: Callable, requires: RequiresType = None, *, mush: bool = True): args = [] kw = {} - resolving = self._resolve(obj, requires, args, kw, self) + resolving = self._resolve(obj, requires, args, kw, self, mush) for requirement in resolving: if requirement.resolve: o = requirement.resolve(self) diff --git a/mush/tests/test_async_context.py b/mush/tests/test_async_context.py index 7f0afdf..6087f0b 100644 --- a/mush/tests/test_async_context.py +++ b/mush/tests/test_async_context.py @@ -1,10 +1,15 @@ import asyncio +from typing import Tuple + import pytest from mush import AsyncContext, Context, requires, returns -from mush.declarations import Requirement +from mush.context import ResolvableValue +from mush.declarations import Requirement, RequiresType from testfixtures import compare +from mush.tests.test_context import TheType + @pytest.mark.asyncio async def test_get_is_async(): @@ -116,6 +121,23 @@ def it(context: AsyncContext): compare(await context.call(it), expected='bar') +@pytest.mark.asyncio +async def test_call_default_mush(): + context = AsyncContext() + def foo(): pass + await context.call(foo) + compare(foo.__mush__['requires_final'], expected=RequiresType()) + + +@pytest.mark.asyncio +async def test_call_no_mush(): + context = AsyncContext() + def foo(): + pass + await context.call(foo, mush=False) + assert not hasattr(foo, '__mush__') + + @pytest.mark.asyncio async def test_extract_is_async(): context = AsyncContext() @@ -149,6 +171,33 @@ def it(context): compare(await context.get('baz'), expected='foobar') +@pytest.mark.asyncio +async def test_extract_minimal(): + o = TheType() + def foo() -> TheType: + return o + context = AsyncContext() + result = await context.extract(foo) + assert result is o + compare({TheType: ResolvableValue(o)}, actual=context._store) + compare(foo.__mush__['returns_final'], expected=returns(TheType)) + + +@pytest.mark.asyncio +async def test_extract_maximal(): + def foo(*args): + return args + context = AsyncContext() + context.add('a') + result = await context.extract(foo, requires(str), returns(Tuple[str]), mush=False) + compare(result, expected=('a',)) + compare({ + str: ResolvableValue('a'), + Tuple[str]: ResolvableValue(('a',)), + }, actual=context._store) + assert not hasattr(foo, '__mush__') + + @pytest.mark.asyncio async def test_custom_requirement_async_resolve(): diff --git a/mush/tests/test_context.py b/mush/tests/test_context.py index a6b12ba..f64c13d 100644 --- a/mush/tests/test_context.py +++ b/mush/tests/test_context.py @@ -1,3 +1,4 @@ +from typing import Tuple, List from unittest import TestCase from mock import Mock @@ -5,7 +6,7 @@ from testfixtures import ShouldRaise, compare from mush import Context, ContextError, requires, returns, nothing, returns_mapping -from mush.declarations import Requirement, Value, missing +from mush.declarations import Requirement, Value, missing, RequiresType class TheType(object): @@ -267,6 +268,58 @@ def foo(): result = context.call(foo) compare(result, expected=None) + def test_call_default_mush(self): + context = Context() + def foo(): pass + context.call(foo) + compare(foo.__mush__['requires_final'], expected=RequiresType()) + + def test_call_explict_mush(self): + context = Context() + def foo(): + pass + context.call(foo, mush=True) + compare(foo.__mush__['requires_final'], expected=RequiresType()) + + def test_call_explict_mush_plus_explicit_requires(self): + context = Context() + context.add('a') + def foo(*args): + return args + result = context.call(foo, requires(str), mush=True) + compare(result, ('a',)) + assert not hasattr(foo, '__mush__') + + def test_call_no_mush(self): + context = Context() + def foo(): + pass + context.call(foo, mush=False) + assert not hasattr(foo, '__mush__') + + def test_extract_minimal(self): + o = TheType() + def foo() -> TheType: + return o + context = Context() + result = context.extract(foo) + assert result is o + compare({TheType: ResolvableValue(o)}, actual=context._store) + compare(foo.__mush__['returns_final'], expected=returns(TheType)) + + def test_extract_maximal(self): + def foo(*args): + return args + context = Context() + context.add('a') + result = context.extract(foo, requires(str), returns(Tuple[str]), mush=False) + compare(result, expected=('a',)) + compare({ + str: ResolvableValue('a'), + Tuple[str]: ResolvableValue(('a',)), + }, actual=context._store) + assert not hasattr(foo, '__mush__') + def test_returns_single(self): def foo(): return 'bar' diff --git a/mush/tests/test_declarations.py b/mush/tests/test_declarations.py index 0ebe386..b2e9509 100644 --- a/mush/tests/test_declarations.py +++ b/mush/tests/test_declarations.py @@ -163,7 +163,7 @@ def test_missing_key(self): def test_passed_missing(self): c = Context() c.add({}, provides='key') - compare(c.call(lambda x: x, requires=Value('key', default=1)['foo']['bar']), + compare(c.call(lambda x: x, requires(Value('key', default=1)['foo']['bar'])), expected=1) def test_bad_type(self): @@ -194,7 +194,7 @@ def test_missing(self): def test_passed_missing(self): c = Context() c.add(object(), provides='key') - compare(c.call(lambda x: x, requires=Value('key', default=1).foo.bar), + compare(c.call(lambda x: x, requires(Value('key', default=1).foo.bar)), expected=1) From 834eb28adaa0db99ec8b30e313e2feece88055d6 Mon Sep 17 00:00:00 2001 From: Chris Withers Date: Fri, 28 Feb 2020 17:59:31 +0000 Subject: [PATCH 058/159] Pivot to caching declarations in the context. The working assumption is that the callable an application will need will be finite and relatively small, so the cache should not end up being huge. This also opens possibilities of plugging in different extraction at the context level. --- mush/asyncio.py | 9 ++++---- mush/context.py | 33 ++++++++++++++------------- mush/tests/test_async_context.py | 23 +++++++------------ mush/tests/test_context.py | 38 ++++++++++++++------------------ 4 files changed, 44 insertions(+), 59 deletions(-) diff --git a/mush/asyncio.py b/mush/asyncio.py index bf99ed5..fedee1d 100644 --- a/mush/asyncio.py +++ b/mush/asyncio.py @@ -43,10 +43,10 @@ async def get(self, key: ResourceKey, default=None): return await ensure_async(r, self._context_for(r), default) return resolvable.value - async def call(self, obj: Callable, requires: RequiresType = None, *, mush: bool = True): + async def call(self, obj: Callable, requires: RequiresType = None): args = [] kw = {} - resolving = self._resolve(obj, requires, args, kw, self._context_for(obj), mush) + resolving = self._resolve(obj, requires, args, kw, self._context_for(obj)) for requirement in resolving: r = requirement.resolve if r is not None: @@ -59,8 +59,7 @@ async def call(self, obj: Callable, requires: RequiresType = None, *, mush: bool async def extract(self, obj: Callable, requires: RequiresType = None, - returns: ReturnsType = None, - mush: bool = True): + returns: ReturnsType = None): result = await self.call(obj, requires) - self._process(obj, result, returns, mush) + self._process(obj, result, returns) return result diff --git a/mush/context.py b/mush/context.py index bf2d259..a729463 100644 --- a/mush/context.py +++ b/mush/context.py @@ -2,7 +2,8 @@ from .declarations import ( RequiresType, ResourceKey, ResourceValue, ResourceResolver, - Requirement, set_mush, ReturnsType) + Requirement, ReturnsType +) from .extraction import extract_requires, extract_returns from .markers import missing @@ -81,6 +82,8 @@ class Context: def __init__(self, default_requirement_type: Type[Requirement] = Requirement): self.default_requirement_type = default_requirement_type self._store = {} + self._requires_cache = {} + self._returns_cache = {} def add(self, resource: Optional[ResourceValue] = None, @@ -122,36 +125,30 @@ def __repr__(self): bits.append('\n') return '' % ''.join(bits) - def _process(self, obj, result, returns, mush): + def _process(self, obj, result, returns): if returns is None: - returns = getattr(obj, '__mush__', {}).get('returns_final') + returns = self._returns_cache.get(obj) if returns is None: returns = extract_returns(obj, explicit=None) - if mush: - set_mush(obj, 'returns_final', returns) + self._returns_cache[obj] = returns for type, obj in returns.process(result): self.add(obj, type) - def extract(self, - obj: Callable, - requires: RequiresType = None, - returns: ReturnsType = None, - mush: bool = True): + def extract(self, obj: Callable, requires: RequiresType = None, returns: ReturnsType = None): result = self.call(obj, requires) - self._process(obj, result, returns, mush) + self._process(obj, result, returns) return result - def _resolve(self, obj, requires, args, kw, context, mush): + def _resolve(self, obj, requires, args, kw, context): if requires is None: - requires = getattr(obj, '__mush__', {}).get('requires_final') + requires = self._requires_cache.get(obj) if requires is None: requires = extract_requires(obj, explicit=None, default_requirement_type=self.default_requirement_type) - if mush: - set_mush(obj, 'requires_final', requires) + self._requires_cache[obj] = requires for requirement in requires: o = yield requirement @@ -177,10 +174,10 @@ def _resolve(self, obj, requires, args, kw, context, mush): yield - def call(self, obj: Callable, requires: RequiresType = None, *, mush: bool = True): + def call(self, obj: Callable, requires: RequiresType = None): args = [] kw = {} - resolving = self._resolve(obj, requires, args, kw, self, mush) + resolving = self._resolve(obj, requires, args, kw, self) for requirement in resolving: if requirement.resolve: o = requirement.resolve(self) @@ -216,4 +213,6 @@ def nest(self, default_requirement_type: Type[Requirement] = None): default_requirement_type = self.default_requirement_type nested = self.__class__(default_requirement_type) nested._parent = self + nested._requires_cache = self._requires_cache + nested._returns_cache = self._returns_cache return nested diff --git a/mush/tests/test_async_context.py b/mush/tests/test_async_context.py index 6087f0b..4c8e119 100644 --- a/mush/tests/test_async_context.py +++ b/mush/tests/test_async_context.py @@ -5,7 +5,7 @@ from mush import AsyncContext, Context, requires, returns from mush.context import ResolvableValue -from mush.declarations import Requirement, RequiresType +from mush.declarations import Requirement, RequiresType, returns_result_type from testfixtures import compare from mush.tests.test_context import TheType @@ -122,20 +122,11 @@ def it(context: AsyncContext): @pytest.mark.asyncio -async def test_call_default_mush(): +async def test_call_cache_requires(): context = AsyncContext() def foo(): pass await context.call(foo) - compare(foo.__mush__['requires_final'], expected=RequiresType()) - - -@pytest.mark.asyncio -async def test_call_no_mush(): - context = AsyncContext() - def foo(): - pass - await context.call(foo, mush=False) - assert not hasattr(foo, '__mush__') + compare(context._requires_cache[foo], expected=RequiresType()) @pytest.mark.asyncio @@ -180,7 +171,8 @@ def foo() -> TheType: result = await context.extract(foo) assert result is o compare({TheType: ResolvableValue(o)}, actual=context._store) - compare(foo.__mush__['returns_final'], expected=returns(TheType)) + compare(context._requires_cache[foo], expected=RequiresType()) + compare(context._returns_cache[foo], expected=returns(TheType)) @pytest.mark.asyncio @@ -189,13 +181,14 @@ def foo(*args): return args context = AsyncContext() context.add('a') - result = await context.extract(foo, requires(str), returns(Tuple[str]), mush=False) + result = await context.extract(foo, requires(str), returns(Tuple[str])) compare(result, expected=('a',)) compare({ str: ResolvableValue('a'), Tuple[str]: ResolvableValue(('a',)), }, actual=context._store) - assert not hasattr(foo, '__mush__') + compare(context._requires_cache, expected={}) + compare(context._returns_cache, expected={}) @pytest.mark.asyncio diff --git a/mush/tests/test_context.py b/mush/tests/test_context.py index f64c13d..cfa06a3 100644 --- a/mush/tests/test_context.py +++ b/mush/tests/test_context.py @@ -268,34 +268,20 @@ def foo(): result = context.call(foo) compare(result, expected=None) - def test_call_default_mush(self): + def test_call_caches_requires(self): context = Context() def foo(): pass context.call(foo) - compare(foo.__mush__['requires_final'], expected=RequiresType()) + compare(context._requires_cache[foo], expected=RequiresType()) - def test_call_explict_mush(self): - context = Context() - def foo(): - pass - context.call(foo, mush=True) - compare(foo.__mush__['requires_final'], expected=RequiresType()) - - def test_call_explict_mush_plus_explicit_requires(self): + def test_call_explict_explicit_requires_no_cache(self): context = Context() context.add('a') def foo(*args): return args - result = context.call(foo, requires(str), mush=True) + result = context.call(foo, requires(str)) compare(result, ('a',)) - assert not hasattr(foo, '__mush__') - - def test_call_no_mush(self): - context = Context() - def foo(): - pass - context.call(foo, mush=False) - assert not hasattr(foo, '__mush__') + compare(context._requires_cache, expected={}) def test_extract_minimal(self): o = TheType() @@ -305,20 +291,22 @@ def foo() -> TheType: result = context.extract(foo) assert result is o compare({TheType: ResolvableValue(o)}, actual=context._store) - compare(foo.__mush__['returns_final'], expected=returns(TheType)) + compare(context._requires_cache[foo], expected=RequiresType()) + compare(context._returns_cache[foo], expected=returns(TheType)) def test_extract_maximal(self): def foo(*args): return args context = Context() context.add('a') - result = context.extract(foo, requires(str), returns(Tuple[str]), mush=False) + result = context.extract(foo, requires(str), returns(Tuple[str])) compare(result, expected=('a',)) compare({ str: ResolvableValue('a'), Tuple[str]: ResolvableValue(('a',)), }, actual=context._store) - assert not hasattr(foo, '__mush__') + compare(context._requires_cache, expected={}) + compare(context._returns_cache, expected={}) def test_returns_single(self): def foo(): @@ -429,6 +417,12 @@ class Requirement2(Requirement): pass c2 = c1.nest(default_requirement_type=Requirement2) assert c2.default_requirement_type is Requirement2 + def test_nest_keeps_declarations_cache(self): + c1 = Context() + c2 = c1.nest() + assert c2._requires_cache is c1._requires_cache + assert c2._returns_cache is c1._returns_cache + def test_custom_requirement(self): class FromRequest(Requirement): From 4ed16855f2117cd3efbbd3086f22e126733c4d5b Mon Sep 17 00:00:00 2001 From: Chris Withers Date: Sat, 29 Feb 2020 08:13:07 +0000 Subject: [PATCH 059/159] Split asyncio stuff into its own module, but with identical names. This is since we're going to need mirrors of more components... --- mush/__init__.py | 5 ++-- mush/asyncio.py | 8 ++--- mush/tests/test_async_context.py | 51 ++++++++++++++++---------------- 3 files changed, 33 insertions(+), 31 deletions(-) diff --git a/mush/__init__.py b/mush/__init__.py index 69cf5c6..0b2d15f 100755 --- a/mush/__init__.py +++ b/mush/__init__.py @@ -7,11 +7,11 @@ from .extraction import extract_requires, extract_returns, update_wrapper from .markers import missing from .plug import Plug +from .resolvers import Call from .context import Context, ContextError -from .asyncio import AsyncContext __all__ = [ - 'Context', 'AsyncContext', 'ContextError', + 'Context', 'ContextError', 'Runner', 'requires', 'returns_result_type', 'returns_mapping', 'returns_sequence', 'returns', @@ -19,4 +19,5 @@ 'Plug', 'nothing', 'update_wrapper', 'missing', + 'Call' ] diff --git a/mush/asyncio.py b/mush/asyncio.py index fedee1d..d11212c 100644 --- a/mush/asyncio.py +++ b/mush/asyncio.py @@ -2,7 +2,7 @@ from functools import partial from typing import Type, Callable -from mush import Context +from mush import Context as SyncContext from mush.declarations import ResourceKey, Requirement, RequiresType, ReturnsType @@ -15,7 +15,7 @@ async def ensure_async(func, *args, **kw): return await loop.run_in_executor(None, func, *args) -class SyncContext: +class SyncFromAsyncContext: def __init__(self, context, loop): self.context = context @@ -27,11 +27,11 @@ def get(self, key: ResourceKey, default=None): return future.result() -class AsyncContext(Context): +class Context(SyncContext): def __init__(self, default_requirement_type: Type[Requirement] = Requirement): super().__init__(default_requirement_type) - self._sync_context = SyncContext(self, asyncio.get_event_loop()) + self._sync_context = SyncFromAsyncContext(self, asyncio.get_event_loop()) def _context_for(self, obj): return self if asyncio.iscoroutinefunction(obj) else self._sync_context diff --git a/mush/tests/test_async_context.py b/mush/tests/test_async_context.py index 4c8e119..19ed486 100644 --- a/mush/tests/test_async_context.py +++ b/mush/tests/test_async_context.py @@ -3,9 +3,10 @@ import pytest -from mush import AsyncContext, Context, requires, returns +from mush import Context, requires, returns +from mush.asyncio import Context from mush.context import ResolvableValue -from mush.declarations import Requirement, RequiresType, returns_result_type +from mush.declarations import Requirement, RequiresType from testfixtures import compare from mush.tests.test_context import TheType @@ -13,7 +14,7 @@ @pytest.mark.asyncio async def test_get_is_async(): - context = AsyncContext() + context = Context() result = context.get('foo', default='bar') assert asyncio.iscoroutine(result) compare(await result, expected='bar') @@ -23,7 +24,7 @@ async def test_get_is_async(): async def test_get_async_resolver(): async def resolver(*args): return 'bar' - context = AsyncContext() + context = Context() context.add(provides='foo', resolver=resolver) compare(await context.get('foo'), expected='bar') @@ -32,7 +33,7 @@ async def resolver(*args): async def test_get_async_resolver_calls_back_into_async(): async def resolver(context, default): return await context.get('baz') - context = AsyncContext() + context = Context() context.add('bar', provides='baz') context.add(provides='foo', resolver=resolver) compare(await context.get('foo'), expected='bar') @@ -42,7 +43,7 @@ async def resolver(context, default): async def test_get_sync_resolver(): def resolver(*args): return 'bar' - context = AsyncContext() + context = Context() context.add(provides='foo', resolver=resolver) compare(await context.get('foo'), expected='bar') @@ -51,7 +52,7 @@ def resolver(*args): async def test_get_sync_resolver_calls_back_into_async(): def resolver(context, default): return context.get('baz') - context = AsyncContext() + context = Context() context.add('bar', provides='baz') context.add(provides='foo', resolver=resolver) compare(await context.get('foo'), expected='bar') @@ -59,7 +60,7 @@ def resolver(context, default): @pytest.mark.asyncio async def test_call_is_async(): - context = AsyncContext() + context = Context() def it(): return 'bar' result = context.call(it) @@ -69,7 +70,7 @@ def it(): @pytest.mark.asyncio async def test_call_async(): - context = AsyncContext() + context = Context() context.add('1', provides='a') async def it(a, b='2'): return a+b @@ -78,7 +79,7 @@ async def it(a, b='2'): @pytest.mark.asyncio async def test_call_async_requires_context(): - context = AsyncContext() + context = Context() context.add('bar', provides='baz') async def it(context: Context): return await context.get('baz') @@ -87,16 +88,16 @@ async def it(context: Context): @pytest.mark.asyncio async def test_call_async_requires_async_context(): - context = AsyncContext() + context = Context() context.add('bar', provides='baz') - async def it(context: AsyncContext): + async def it(context: Context): return await context.get('baz') compare(await context.call(it), expected='bar') @pytest.mark.asyncio async def test_call_sync(): - context = AsyncContext() + context = Context() context.add('foo', provides='baz') def it(*, baz): return baz+'bar' @@ -105,7 +106,7 @@ def it(*, baz): @pytest.mark.asyncio async def test_call_sync_requires_context(): - context = AsyncContext() + context = Context() context.add('bar', provides='baz') def it(context: Context): return context.get('baz') @@ -114,16 +115,16 @@ def it(context: Context): @pytest.mark.asyncio async def test_call_sync_requires_async_context(): - context = AsyncContext() + context = Context() context.add('bar', provides='baz') - def it(context: AsyncContext): + def it(context: Context): return context.get('baz') compare(await context.call(it), expected='bar') @pytest.mark.asyncio async def test_call_cache_requires(): - context = AsyncContext() + context = Context() def foo(): pass await context.call(foo) compare(context._requires_cache[foo], expected=RequiresType()) @@ -131,7 +132,7 @@ def foo(): pass @pytest.mark.asyncio async def test_extract_is_async(): - context = AsyncContext() + context = Context() def it(): return 'bar' result = context.extract(it, requires(), returns('baz')) @@ -142,7 +143,7 @@ def it(): @pytest.mark.asyncio async def test_extract_async(): - context = AsyncContext() + context = Context() context.add('foo', provides='bar') async def it(context): return await context.get('bar')+'bar' @@ -153,7 +154,7 @@ async def it(context): @pytest.mark.asyncio async def test_extract_sync(): - context = AsyncContext() + context = Context() context.add('foo', provides='bar') def it(context): return context.get('bar')+'bar' @@ -167,7 +168,7 @@ async def test_extract_minimal(): o = TheType() def foo() -> TheType: return o - context = AsyncContext() + context = Context() result = await context.extract(foo) assert result is o compare({TheType: ResolvableValue(o)}, actual=context._store) @@ -179,7 +180,7 @@ def foo() -> TheType: async def test_extract_maximal(): def foo(*args): return args - context = AsyncContext() + context = Context() context.add('a') result = await context.extract(foo, requires(str), returns(Tuple[str])) compare(result, expected=('a',)) @@ -201,7 +202,7 @@ async def resolve(self, context): def foo(bar: FromRequest('bar')): return bar - context = AsyncContext() + context = Context() context.add({'bar': 'foo'}, provides='request') compare(await context.call(foo), expected='foo') @@ -216,7 +217,7 @@ def resolve(self, context): def foo(bar: FromRequest('bar')): return bar - context = AsyncContext() + context = Context() context.add({'bar': 'foo'}, provides='request') compare(await context.call(foo), expected='foo') @@ -232,6 +233,6 @@ async def resolve(self, context): def foo(bar): return bar - context = AsyncContext(FromRequest) + context = Context(FromRequest) context.add({'bar': 'foo'}, provides='request') compare(await context.call(foo), expected='foo') From c001170dba19345f466e8b3d54541e4cd2ecd188 Mon Sep 17 00:00:00 2001 From: Chris Withers Date: Sat, 29 Feb 2020 09:00:17 +0000 Subject: [PATCH 060/159] Round out the API available to synchronous callable from an asynchronous context. --- mush/asyncio.py | 12 +++++++ mush/tests/test_async_context.py | 62 ++++++++++++++++++++++++++++++-- 2 files changed, 72 insertions(+), 2 deletions(-) diff --git a/mush/asyncio.py b/mush/asyncio.py index d11212c..88cb87c 100644 --- a/mush/asyncio.py +++ b/mush/asyncio.py @@ -20,12 +20,24 @@ class SyncFromAsyncContext: def __init__(self, context, loop): self.context = context self.loop = loop + self.remove = context.remove + self.add = context.add def get(self, key: ResourceKey, default=None): coro = self.context.get(key, default) future = asyncio.run_coroutine_threadsafe(coro, self.loop) return future.result() + def call(self, obj: Callable, requires: RequiresType = None): + coro = self.context.call(obj, requires) + future = asyncio.run_coroutine_threadsafe(coro, self.loop) + return future.result() + + def extract(self, obj: Callable, requires: RequiresType = None, returns: ReturnsType = None): + coro = self.context.extract(obj, requires, returns) + future = asyncio.run_coroutine_threadsafe(coro, self.loop) + return future.result() + class Context(SyncContext): diff --git a/mush/tests/test_async_context.py b/mush/tests/test_async_context.py index 19ed486..cf9c718 100644 --- a/mush/tests/test_async_context.py +++ b/mush/tests/test_async_context.py @@ -3,7 +3,7 @@ import pytest -from mush import Context, requires, returns +from mush import Context, Value, requires, returns from mush.asyncio import Context from mush.context import ResolvableValue from mush.declarations import Requirement, RequiresType @@ -208,7 +208,7 @@ def foo(bar: FromRequest('bar')): @pytest.mark.asyncio -async def test_custom_requirement_sync_resolve(): +async def test_custom_requirement_sync_resolve_get(): class FromRequest(Requirement): def resolve(self, context): @@ -222,6 +222,64 @@ def foo(bar: FromRequest('bar')): compare(await context.call(foo), expected='foo') +@pytest.mark.asyncio +async def test_custom_requirement_sync_resolve_call(): + + async def baz(request: dict = Value('request')): + return request['bar'] + + class Syncer(Requirement): + def resolve(self, context): + return context.call(self.key) + + def foo(bar: Syncer(baz)): + return bar + + context = Context() + context.add({'bar': 'foo'}, provides='request') + compare(await context.call(foo), expected='foo') + + +@pytest.mark.asyncio +async def test_custom_requirement_sync_resolve_extract(): + + @returns('response') + async def baz(request: dict = Value('request')): + return request['bar'] + + class Syncer(Requirement): + def resolve(self, context): + return context.extract(self.key) + + def foo(bar: Syncer(baz)): + return bar + + context = Context() + context.add({'bar': 'foo'}, provides='request') + compare(await context.call(foo), expected='foo') + compare(await context.get('response'), expected='foo') + + +@pytest.mark.asyncio +async def test_custom_requirement_sync_resolve_add_remove(): + + class Syncer(Requirement): + def resolve(self, context): + request = context.get('request') + context.remove('request') + context.add(request['bar'], provides='response') + return request['bar'] + + def foo(bar: Syncer('request')): + return bar + + context = Context() + context.add({'bar': 'foo'}, provides='request') + compare(await context.call(foo), expected='foo') + compare(await context.get('request'), expected=None) + compare(await context.get('response'), expected='foo') + + @pytest.mark.asyncio async def test_default_custom_requirement(): From 3cbe9efb10c8657ae682b967876ced827b648133 Mon Sep 17 00:00:00 2001 From: Chris Withers Date: Mon, 9 Mar 2020 07:09:15 +0000 Subject: [PATCH 061/159] Collapse Value down into Requirement. This also means that requirements always have a resolve step. Also split requirements into their own module as more are coming. --- .coveragerc | 3 + mush/__init__.py | 26 ++--- mush/asyncio.py | 19 ++-- mush/context.py | 15 +-- mush/declarations.py | 143 ++----------------------- mush/extraction.py | 22 ++-- mush/requirements.py | 137 ++++++++++++++++++++++++ mush/tests/helpers.py | 6 ++ mush/tests/test_async_context.py | 3 +- mush/tests/test_callpoints.py | 12 ++- mush/tests/test_context.py | 11 +- mush/tests/test_declarations.py | 175 ++++++++++++++++--------------- mush/tests/test_runner.py | 4 +- mush/types.py | 6 ++ 14 files changed, 307 insertions(+), 275 deletions(-) create mode 100644 mush/requirements.py create mode 100644 mush/tests/helpers.py create mode 100644 mush/types.py diff --git a/.coveragerc b/.coveragerc index 542e548..1afc40e 100644 --- a/.coveragerc +++ b/.coveragerc @@ -10,3 +10,6 @@ exclude_lines = # stuff that we don't worry about pass __name__ == '__main__' + + # circular references needed for type checking: + if TYPE_CHECKING: diff --git a/mush/__init__.py b/mush/__init__.py index 0b2d15f..4d9b313 100755 --- a/mush/__init__.py +++ b/mush/__init__.py @@ -1,23 +1,25 @@ -from .runner import Runner +from .context import Context, ContextError from .declarations import ( - requires, - returns_result_type, returns_mapping, returns_sequence, returns, - Value, Requirement, nothing + requires, returns_result_type, returns_mapping, returns_sequence, returns, nothing ) from .extraction import extract_requires, extract_returns, update_wrapper from .markers import missing from .plug import Plug -from .resolvers import Call -from .context import Context, ContextError +from .requirements import Value +from .runner import Runner __all__ = [ - 'Context', 'ContextError', + 'Context', + 'ContextError', + 'Plug', 'Runner', + 'Value', + 'missing', + 'nothing', 'requires', - 'returns_result_type', 'returns_mapping', 'returns_sequence', 'returns', - 'Value', 'Requirement', - 'Plug', 'nothing', + 'returns', + 'returns_mapping', + 'returns_result_type', + 'returns_sequence', 'update_wrapper', - 'missing', - 'Call' ] diff --git a/mush/asyncio.py b/mush/asyncio.py index 88cb87c..ff57da6 100644 --- a/mush/asyncio.py +++ b/mush/asyncio.py @@ -2,8 +2,10 @@ from functools import partial from typing import Type, Callable -from mush import Context as SyncContext -from mush.declarations import ResourceKey, Requirement, RequiresType, ReturnsType +from . import Context as SyncContext, Value as SyncValue +from .declarations import RequiresType, ReturnsType +from .requirements import Requirement +from .types import ResourceKey async def ensure_async(func, *args, **kw): @@ -15,6 +17,12 @@ async def ensure_async(func, *args, **kw): return await loop.run_in_executor(None, func, *args) +class Value(SyncValue): + + async def resolve(self, context): + return await context.get(self.key, self.default) + + class SyncFromAsyncContext: def __init__(self, context, loop): @@ -41,7 +49,7 @@ def extract(self, obj: Callable, requires: RequiresType = None, returns: Returns class Context(SyncContext): - def __init__(self, default_requirement_type: Type[Requirement] = Requirement): + def __init__(self, default_requirement_type: Type[Requirement] = Value): super().__init__(default_requirement_type) self._sync_context = SyncFromAsyncContext(self, asyncio.get_event_loop()) @@ -61,10 +69,7 @@ async def call(self, obj: Callable, requires: RequiresType = None): resolving = self._resolve(obj, requires, args, kw, self._context_for(obj)) for requirement in resolving: r = requirement.resolve - if r is not None: - o = await ensure_async(r, self._context_for(r)) - else: - o = await self.get(requirement.key, requirement.default) + o = await ensure_async(r, self._context_for(r)) resolving.send(o) return await ensure_async(obj, *args, **kw) diff --git a/mush/context.py b/mush/context.py index a729463..bf4e71b 100644 --- a/mush/context.py +++ b/mush/context.py @@ -1,11 +1,10 @@ from typing import Optional, Type, Callable -from .declarations import ( - RequiresType, ResourceKey, ResourceValue, ResourceResolver, - Requirement, ReturnsType -) +from .declarations import RequiresType, ReturnsType from .extraction import extract_requires, extract_returns from .markers import missing +from .requirements import Requirement, Value +from .types import ResourceKey, ResourceValue, ResourceResolver NONE_TYPE = type(None) @@ -79,7 +78,7 @@ class Context: _parent = None - def __init__(self, default_requirement_type: Type[Requirement] = Requirement): + def __init__(self, default_requirement_type: Type[Requirement] = Value): self.default_requirement_type = default_requirement_type self._store = {} self._requires_cache = {} @@ -179,11 +178,7 @@ def call(self, obj: Callable, requires: RequiresType = None): kw = {} resolving = self._resolve(obj, requires, args, kw, self) for requirement in resolving: - if requirement.resolve: - o = requirement.resolve(self) - else: - o = self.get(requirement.key, requirement.default) - resolving.send(o) + resolving.send(requirement.resolve(self)) return obj(*args, **kw) def _get(self, key, default): diff --git a/mush/declarations.py b/mush/declarations.py index fd393a5..e04fa47 100644 --- a/mush/declarations.py +++ b/mush/declarations.py @@ -1,18 +1,8 @@ -from copy import copy from enum import Enum, auto from itertools import chain -from typing import Type, Callable, NewType, Union, Any, List, Optional, _type_check +from typing import _type_check -from .markers import missing - -ResourceKey = NewType('ResourceKey', Union[Type, str]) -ResourceValue = NewType('ResourceValue', Any) -ResourceResolver = Callable[['Context', Any], ResourceValue] -RequirementResolver = Callable[['Context'], ResourceValue] - - -def name_or_repr(obj): - return getattr(obj, '__name__', None) or repr(obj) +from .requirements import Requirement, Value, name_or_repr def set_mush(obj, key, value): @@ -21,125 +11,6 @@ def set_mush(obj, key, value): obj.__mush__[key] = value -class Requirement: - """ - The requirement for an individual parameter of a callable. - """ - - resolve: RequirementResolver = None - - def __init__(self, key, name=None, type_=None, default=missing, target=None): - #: The resource key needed for this parameter. - self.key: ResourceKey = key - #: The name of this parameter in the callable's signature. - self.name: str = name - #: The type required for this parameter. - self.type: type = type_ - #: The default for this parameter, should the required resource be unavailable. - self.default: Any = default - #: Any operations to be performed on the resource after it - #: has been obtained. - self.ops: List['ValueOp'] = [] - self.target: Optional[str] = target - - def clone(self): - """ - Create a copy of this requirement, so it can be mutated - """ - obj = copy(self) - obj.ops = list(self.ops) - return obj - - def value_repr(self): - key = name_or_repr(self.key) - if self.ops or self.default is not missing: - default = '' if self.default is missing else f', default={self.default!r}' - ops = ''.join(repr(o) for o in self.ops) - return f'Value({key}{default}){ops}' - return key - - def __repr__(self): - attrs = [] - for a in 'name', 'type_', 'target': - value = getattr(self, a.rstrip('_')) - if value is not None: - attrs.append(f", {a}={value!r}") - return f"{type(self).__name__}({self.value_repr()}{''.join(attrs)})" - - -class Value: - """ - Declaration indicating that the specified resource key is required. - - Values are generative, so they can be used to indicate attributes or - items from a resource are required. - - A default may be specified, which will be used if the specified - resource is not available. - - A type may also be explicitly specified, but you probably shouldn't - ever use this. - """ - - def __init__(self, key: ResourceKey=None, *, type_: type = None, default: Any = missing): - if isinstance(key, type): - if type_ is not None: - raise TypeError('type_ cannot be specified if key is a type') - type_ = key - self.requirement = Requirement(key, type_=type_, default=default) - - def attr(self, name): - """ - If you need to get an attribute called either ``attr`` or ``item`` - then you will need to call this method instead of using the - generating behaviour. - """ - self.requirement.ops.append(ValueAttrOp(name)) - return self - - def __getattr__(self, name): - if name.startswith('__'): - raise AttributeError(name) - return self.attr(name) - - def __getitem__(self, name): - self.requirement.ops.append(ValueItemOp(name)) - return self - - def __repr__(self): - return self.requirement.value_repr() - - -class ValueOp: - - def __init__(self, name): - self.name = name - - -class ValueAttrOp(ValueOp): - - def __call__(self, o): - try: - return getattr(o, self.name) - except AttributeError: - return missing - - def __repr__(self): - return f'.{self.name}' - - -class ValueItemOp(ValueOp): - - def __call__(self, o): - try: - return o[self.name] - except KeyError: - return missing - - def __repr__(self): - return f'[{self.name!r}]' - - class RequiresType(list): def __repr__(self): @@ -170,15 +41,15 @@ def requires(*args, **kw): ((None, arg) for arg in args), kw.items(), ): - if isinstance(possible, Value): - possible = possible.requirement if isinstance(possible, Requirement): possible = possible.clone() possible.target = target requirement = possible else: - type_ = None if isinstance(possible, str) else possible - requirement = Requirement(possible, name=target, type_=type_, target=target) + requirement = Value(possible) + requirement.type = None if isinstance(possible, str) else possible + requirement.name = target + requirement.target = target requires_.append(requirement) return requires_ @@ -286,7 +157,7 @@ class DeclarationsFrom(Enum): replacement = DeclarationsFrom.replacement -VALID_DECORATION_TYPES = (type, str, Value, Requirement) +VALID_DECORATION_TYPES = (type, str, Requirement) def valid_decoration_types(*objs): diff --git a/mush/extraction.py b/mush/extraction.py index 998f5ee..8930158 100644 --- a/mush/extraction.py +++ b/mush/extraction.py @@ -8,11 +8,11 @@ from typing import Callable, Type from .declarations import ( - Value, - requires, Requirement, RequiresType, ReturnsType, + requires, RequiresType, ReturnsType, returns, result_type, nothing ) +from .requirements import Requirement, Value from .markers import missing EMPTY = Parameter.empty @@ -42,8 +42,8 @@ def _apply_requires(by_name, by_index, requires_): def extract_requires(obj: Callable, - explicit: RequiresType=None, - default_requirement_type: Type[Requirement] = Requirement): + explicit: RequiresType = None, + default_requirement_type: Type[Requirement] = Value): # from annotations is_partial = isinstance(obj, partial) by_name = {} @@ -68,16 +68,11 @@ def extract_requires(obj: Callable, if isinstance(default, Requirement): requirement = default default = missing - elif isinstance(default, Value): - requirement = default.requirement - default = missing elif isinstance(p.annotation, Requirement): requirement = p.annotation - elif isinstance(p.annotation, Value): - requirement = p.annotation.requirement if requirement is None: - requirement = default_requirement_type(key) + requirement = Requirement(key) if isinstance(p.annotation, str): key = p.annotation elif type_ is None or issubclass(type_, SIMPLE_TYPES): @@ -127,7 +122,12 @@ def extract_requires(obj: Callable, return nothing needs_target = False - for requirement in by_name.values(): + for name, requirement in by_name.items(): + if requirement.__class__ is Requirement: + requirement_ = default_requirement_type() + requirement_.__dict__.update(requirement.__dict__) + requirement = requirement_ + by_name[name] = requirement if requirement.target is not None: needs_target = True elif needs_target: diff --git a/mush/requirements.py b/mush/requirements.py new file mode 100644 index 0000000..88d8de6 --- /dev/null +++ b/mush/requirements.py @@ -0,0 +1,137 @@ +from copy import copy +from typing import Any, Optional, List, TYPE_CHECKING + +from .types import ResourceKey +from .markers import missing + +if TYPE_CHECKING: + from .context import Context + + +def name_or_repr(obj): + return getattr(obj, '__name__', None) or repr(obj) + + +class Op: + + def __init__(self, name): + self.name = name + + +class AttrOp(Op): + + def __call__(self, o): + try: + return getattr(o, self.name) + except AttributeError: + return missing + + def __repr__(self): + return f'.{self.name}' + + +class ItemOp(Op): + + def __call__(self, o): + try: + return o[self.name] + except KeyError: + return missing + + def __repr__(self): + return f'[{self.name!r}]' + + +class Requirement: + """ + The requirement for an individual parameter of a callable. + """ + + def __init__(self, + key: ResourceKey = None, + name: str = None, + type_: type = None, + default: Any = missing, + target: str =None): + #: The resource key needed for this parameter. + self.key: Optional[ResourceKey] = key + #: The name of this parameter in the callable's signature. + self.name: Optional[str] = name + #: The type required for this parameter. + self.type: Optional[type] = type_ + #: The default for this parameter, should the required resource be unavailable. + self.default: Optional[Any] = default + #: Any operations to be performed on the resource after it + #: has been obtained. + self.ops: List['Op'] = [] + self.target: Optional[str] = target + + def resolve(self, context: 'Context'): + raise NotImplementedError() + + def clone(self): + """ + Create a copy of this requirement, so it can be mutated + """ + obj = copy(self) + obj.ops = list(self.ops) + return obj + + def value_repr(self, params='', *, from_repr=False): + key = name_or_repr(self.key) + if self.ops or self.default is not missing or from_repr: + default = '' if self.default is missing else f', default={self.default!r}' + ops = ''.join(repr(o) for o in self.ops) + return f"{type(self).__name__}({key}{default}{params}){ops}" + return key + + def __repr__(self): + attrs = [] + for a in 'name', 'type_', 'target': + value = getattr(self, a.rstrip('_')) + if value is not None and value != self.key: + attrs.append(f", {a}={value!r}") + return self.value_repr(''.join(attrs), from_repr=True) + + def attr(self, name): + """ + If you need to get an attribute called either ``attr`` or ``item`` + then you will need to call this method instead of using the + generating behaviour. + """ + self.ops.append(AttrOp(name)) + return self + + def __getattr__(self, name): + if name.startswith('__'): + raise AttributeError(name) + return self.attr(name) + + def __getitem__(self, name): + self.ops.append(ItemOp(name)) + return self + + +class Value(Requirement): + """ + Declaration indicating that the specified resource key is required. + + Values are generative, so they can be used to indicate attributes or + items from a resource are required. + + A default may be specified, which will be used if the specified + resource is not available. + + A type may also be explicitly specified, but you probably shouldn't + ever use this. + """ + + def __init__(self, key: ResourceKey=None, *, type_: type = None, default: Any = missing): + if isinstance(key, type): + if type_ is not None: + raise TypeError('type_ cannot be specified if key is a type') + type_ = key + super().__init__(key, type_=type_, default=default) + + def resolve(self, context): + return context.get(self.key, self.default) diff --git a/mush/tests/helpers.py b/mush/tests/helpers.py new file mode 100644 index 0000000..d0c0921 --- /dev/null +++ b/mush/tests/helpers.py @@ -0,0 +1,6 @@ +def r(base, **attrs): + """ + helper for returning Requirement subclasses with extra attributes + """ + base.__dict__.update(attrs) + return base diff --git a/mush/tests/test_async_context.py b/mush/tests/test_async_context.py index cf9c718..45106d8 100644 --- a/mush/tests/test_async_context.py +++ b/mush/tests/test_async_context.py @@ -6,7 +6,8 @@ from mush import Context, Value, requires, returns from mush.asyncio import Context from mush.context import ResolvableValue -from mush.declarations import Requirement, RequiresType +from mush.declarations import RequiresType +from mush.requirements import Requirement from testfixtures import compare from mush.tests.test_context import TheType diff --git a/mush/tests/test_callpoints.py b/mush/tests/test_callpoints.py index fdd4a91..aab7ba3 100644 --- a/mush/tests/test_callpoints.py +++ b/mush/tests/test_callpoints.py @@ -5,8 +5,10 @@ from testfixtures import compare from mush.callpoints import CallPoint -from mush.declarations import requires, returns, RequiresType, Requirement +from mush.declarations import requires, returns, RequiresType +from .. import Value from mush.extraction import update_wrapper +from .helpers import r class TestCallPoints(TestCase): @@ -29,7 +31,7 @@ def foo(a1): pass compare(result, self.context.extract.return_value) compare(tuple(self.context.extract.mock_calls[0].args), expected=(foo, - RequiresType((Requirement('foo', name='a1'),)), + RequiresType([r(Value('foo'), name='a1')]), rt)) def test_extract_from_decorations(self): @@ -44,7 +46,7 @@ def foo(a1): pass compare(result, self.context.extract.return_value) compare(tuple(self.context.extract.mock_calls[0].args), expected=(foo, - RequiresType((Requirement('foo', name='a1'),)), + RequiresType([r(Value('foo'), name='a1')]), returns('bar'))) def test_extract_from_decorated_class(self): @@ -70,7 +72,7 @@ def foo(prefix): self.context.extract.side_effect = lambda func, rq, rt: (func(), rq, rt) result = CallPoint(foo)(self.context) compare(result, expected=('the answer', - RequiresType((Requirement('foo', name='prefix'),)), + RequiresType([r(Value('foo'), name='prefix')]), rt)) def test_explicit_trumps_decorators(self): @@ -82,7 +84,7 @@ def foo(a1): pass compare(result, self.context.extract.return_value) compare(tuple(self.context.extract.mock_calls[0].args), expected=(foo, - RequiresType((Requirement('baz', name='a1'),)), + RequiresType([r(Value('baz'), name='a1')]), returns('bob'))) def test_repr_minimal(self): diff --git a/mush/tests/test_context.py b/mush/tests/test_context.py index cfa06a3..e92349a 100644 --- a/mush/tests/test_context.py +++ b/mush/tests/test_context.py @@ -5,9 +5,12 @@ from mush.context import ResolvableValue from testfixtures import ShouldRaise, compare -from mush import Context, ContextError, requires, returns, nothing, returns_mapping -from mush.declarations import Requirement, Value, missing, RequiresType - +from mush import ( + Context, ContextError, requires, returns, nothing, returns_mapping, + Value, missing +) +from mush.declarations import RequiresType +from mush.requirements import Requirement class TheType(object): def __repr__(self): @@ -199,7 +202,7 @@ def foo(x=1): return x context = Context() context.add(2, provides='x') - result = context.call(foo, requires(x=Requirement('y', default=3))) + result = context.call(foo, requires(x=Value('y', default=3))) compare(result, expected=3) def test_call_requires_optional_string(self): diff --git a/mush/tests/test_declarations.py b/mush/tests/test_declarations.py index b2e9509..2f1c6fd 100644 --- a/mush/tests/test_declarations.py +++ b/mush/tests/test_declarations.py @@ -6,19 +6,17 @@ from mock import Mock from testfixtures import compare, ShouldRaise -from mush import Context +from mush import Context, Value from mush.declarations import ( requires, returns, returns_mapping, returns_sequence, returns_result_type, nothing, - result_type, Requirement, - Value, - ValueAttrOp, - RequiresType, - ValueItemOp + result_type, RequiresType ) from mush.extraction import extract_requires, extract_returns, update_wrapper from mush.markers import missing +from mush.requirements import Requirement, AttrOp, ItemOp +from .helpers import r def check_extract(obj, expected_rq, expected_rt): @@ -42,29 +40,29 @@ def test_empty(self): compare(r, expected=[]) def test_types(self): - r = requires(Type1, Type2, x=Type3, y=Type4) - compare(repr(r), 'requires(Type1, Type2, x=Type3, y=Type4)') - compare(r, expected=[ - Requirement(Type1, type_=Type1), - Requirement(Type2, type_=Type2), - Requirement(Type3, name='x', type_=Type3, target='x'), - Requirement(Type4, name='y', type_=Type4, target='y'), + r_ = requires(Type1, Type2, x=Type3, y=Type4) + compare(repr(r_), 'requires(Type1, Type2, x=Type3, y=Type4)') + compare(r_, expected=[ + Value(Type1), + Value(Type2), + r(Value(Type3), name='x', target='x'), + r(Value(Type4), name='y', target='y'), ]) def test_strings(self): - r = requires('1', '2', x='3', y='4') - compare(repr(r), "requires('1', '2', x='3', y='4')") - compare(r, expected=[ - Requirement('1'), - Requirement('2'), - Requirement('3', name='x', target='x'), - Requirement('4', name='y', target='y'), + r_ = requires('1', '2', x='3', y='4') + compare(repr(r_), "requires('1', '2', x='3', y='4')") + compare(r_, expected=[ + Value('1'), + Value('2'), + r(Value('3'), name='x', target='x'), + r(Value('4'), name='y', target='y'), ]) def test_typing(self): - r = requires(Tuple[str]) - compare(repr(r), "requires(typing.Tuple[str])") - compare(r, expected=[Requirement(Tuple[str], type_=Tuple[str])]) + r_ = requires(Tuple[str]) + compare(repr(r_), "requires(typing.Tuple[str])") + compare(r_, expected=[r(Value(Tuple[str]), type=Tuple[str])]) def test_tuple_arg(self): with ShouldRaise(TypeError("('1', '2') is not a valid decoration type")): @@ -79,10 +77,16 @@ def test_decorator_paranoid(self): def foo(): return 'bar' - compare(foo.__mush__['requires'], expected=[Requirement(Type1, type_=Type1)]) + compare(foo.__mush__['requires'], expected=[Value(Type1)]) compare(foo(), 'bar') +def check_ops(value, data, *, expected): + for op in value.ops: + data = op(data) + compare(expected, actual=data) + + class TestRequirement: def test_repr_minimal(self): @@ -91,10 +95,10 @@ def test_repr_minimal(self): def test_repr_maximal(self): r = Requirement('foo', name='n', type_='ty', default=None, target='ta') - r.ops.append(ValueAttrOp('bar')) + r.ops.append(AttrOp('bar')) compare(repr(r), - expected="Requirement(Value('foo', default=None).bar, " - "name='n', type_='ty', target='ta')") + expected="Requirement('foo', default=None, " + "name='n', type_='ty', target='ta').bar") def test_clone(self): r = Value('foo').bar.requirement @@ -103,32 +107,29 @@ def test_clone(self): assert r_.ops is not r.ops compare(r_, expected=r) + special_names = ['attr', 'ops', 'target'] -def check_ops(value, data, *, expected): - for op in value.requirement.ops: - data = op(data) - compare(expected, actual=data) - - -class TestValue: - - @pytest.mark.parametrize("name", ['attr', 'requirement']) + @pytest.mark.parametrize("name", special_names) def test_attr_special_name(self, name): - v = Value('foo') + v = Requirement('foo') + assert getattr(v, name) is not self assert v.attr(name) is v - compare(v.requirement.ops, [ValueAttrOp(name)]) + compare(v.ops, expected=[AttrOp(name)]) - @pytest.mark.parametrize("name", ['attr', 'requirement']) + @pytest.mark.parametrize("name", special_names) def test_item_special_name(self, name): - v = Value('foo') + v = Requirement('foo') assert v[name] is v - compare(v.requirement.ops, [ValueItemOp(name)]) + compare(v.ops, expected=[ItemOp(name)]) def test_no_special_name_via_getattr(self): - v = Value('foo') + v = Requirement('foo') with ShouldRaise(AttributeError): assert v.__len__ - compare(v.requirement.ops, []) + compare(v.ops, []) + + +class TestValue: def test_type_from_key(self): v = Value(str) @@ -143,12 +144,12 @@ class TestItem: def test_single(self): h = Value(Type1)['foo'] - compare(repr(h), "Value(Type1)['foo']") + compare(repr(h), expected="Value(Type1)['foo']") check_ops(h, {'foo': 1}, expected=1) def test_multiple(self): h = Value(Type1)['foo']['bar'] - compare(repr(h), "Value(Type1)['foo']['bar']") + compare(repr(h), expected="Value(Type1)['foo']['bar']") check_ops(h, {'foo': {'bar': 1}}, expected=1) def test_missing_obj(self): @@ -292,8 +293,8 @@ def test_default_requirements_for_function(self): def foo(a, b=None): pass check_extract(foo, expected_rq=RequiresType(( - Requirement('a', name='a'), - Requirement('b', name='b', default=None) + r(Value('a'), name='a'), + r(Value('b'), default=None, name='b'), )), expected_rt=result_type) @@ -302,8 +303,8 @@ class MyClass(object): def __init__(self, a, b=None): pass check_extract(MyClass, expected_rq=RequiresType(( - Requirement('a', name='a'), - Requirement('b', name='b', default=None) + r(Value('a'), name='a'), + r(Value('b'), name='b', default=None), )), expected_rt=result_type) @@ -313,8 +314,8 @@ def foo(x, y, z, a=None): pass check_extract( p, expected_rq=RequiresType(( - Requirement('z', name='z', target='z'), - Requirement('a', name='a', target='a', default=None) + r(Value('z'), name='z', target='z'), + r(Value('a'), name='a', target='a', default=None), )), expected_rt=result_type ) @@ -325,7 +326,7 @@ def foo(a=None): pass check_extract( p, expected_rq=RequiresType(( - Requirement('a', name='a', default=None), + r(Value('a'), name='a', default=None), )), expected_rt=result_type ) @@ -374,8 +375,8 @@ def foo(b, a=None): pass check_extract( p, expected_rq=RequiresType(( - Requirement('b', name='b'), - Requirement('a', name='a', default=None) + r(Value('b'), name='b'), + r(Value('a'), name='a', default=None), )), expected_rt=result_type ) @@ -387,7 +388,7 @@ def foo(b, a): pass p, # since b is already bound: expected_rq=RequiresType(( - Requirement('a', name='a'), + r(Value('a'), name='a'), )), expected_rt=result_type ) @@ -398,7 +399,7 @@ def foo(b, a): pass check_extract( p, expected_rq=RequiresType(( - Requirement('b', name='b'), + r(Value('b'), name='b'), )), expected_rt=result_type ) @@ -410,17 +411,17 @@ def test_extract_from_annotations(self): def foo(a: 'foo', b, c: 'bar' = 1, d=2) -> 'bar': pass check_extract(foo, expected_rq=RequiresType(( - Requirement('foo', name='a'), - Requirement('b', name='b'), - Requirement('bar', name='c', default=1), - Requirement('d', name='d', default=2) + r(Value('foo'), name='a'), + r(Value('b'), name='b'), + r(Value('bar'), name='c', default=1), + r(Value('d'), name='d', default=2) )), expected_rt=returns('bar')) def test_requires_only(self): def foo(a: 'foo'): pass check_extract(foo, - expected_rq=RequiresType((Requirement('foo', name='a'),)), + expected_rq=RequiresType((r(Value('foo'), name='a'),)), expected_rt=result_type) def test_returns_only(self): @@ -446,7 +447,7 @@ def foo(a: 'foo' = None) -> 'bar': compare(foo(), expected='the answer') check_extract(foo, - expected_rq=RequiresType((Requirement('foo', name='a', default=None),)), + expected_rq=RequiresType((r(Value('foo'), name='a', default=None),)), expected_rt=returns('bar')) def test_decorator_trumps_annotations(self): @@ -454,7 +455,7 @@ def test_decorator_trumps_annotations(self): @returns('bar') def foo(a: 'x') -> 'y': pass check_extract(foo, - expected_rq=RequiresType((Requirement('foo', name='a'),)), + expected_rq=RequiresType((r(Value('foo'), name='a'),)), expected_rt=returns('bar')) def test_returns_mapping(self): @@ -473,20 +474,20 @@ def foo() -> rt: pass def test_how_instance_in_annotations(self): def foo(a: Value('config')['db_url']): pass - requirement = Requirement('config', name='a') - requirement.ops.append(ValueItemOp('db_url')) check_extract(foo, - expected_rq=RequiresType((requirement,)), + expected_rq=RequiresType(( + r(Value('config'), name='a', ops=[ItemOp('db_url')]), + )), expected_rt=result_type) def test_default_requirements(self): def foo(a, b=1, *, c, d=None): pass check_extract(foo, expected_rq=RequiresType(( - Requirement('a', name='a'), - Requirement('b', name='b', default=1), - Requirement('c', name='c', target='c'), - Requirement('d', name='d', target='d', default=None) + r(Value('a'), name='a'), + r(Value('b'), name='b', default=1), + r(Value('c'), name='c', target='c'), + r(Value('d'), name='d', target='d', default=None) )), expected_rt=result_type) @@ -494,27 +495,27 @@ def test_type_only(self): class T: pass def foo(a: T): pass check_extract(foo, - expected_rq=RequiresType((Requirement(T, name='a', type_=T),)), + expected_rq=RequiresType((r(Value(T), name='a', type=T),)), expected_rt=result_type) @pytest.mark.parametrize("type_", [str, int, dict, list]) def test_simple_type_only(self, type_): def foo(a: type_): pass check_extract(foo, - expected_rq=RequiresType((Requirement('a', name='a', type_=type_),)), + expected_rq=RequiresType((r(Value('a'), name='a', type=type_),)), expected_rt=result_type) def test_type_plus_value(self): def foo(a: str = Value('b')): pass check_extract(foo, - expected_rq=RequiresType((Requirement('b', name='a', type_=str),)), + expected_rq=RequiresType((r(Value('b'), name='a', type=str),)), expected_rt=result_type) def test_type_plus_value_with_default(self): def foo(a: str = Value('b', default=1)): pass check_extract(foo, expected_rq=RequiresType(( - Requirement('b', name='a', type_=str, default=1), + r(Value('b'), name='a', type=str, default=1), )), expected_rt=result_type) @@ -522,7 +523,7 @@ def test_value_annotation_plus_default(self): def foo(a: Value('b', type_=str) = 1): pass check_extract(foo, expected_rq=RequiresType(( - Requirement('b', name='a', type_=str, default=1), + r(Value('b'), name='a', type=str, default=1), )), expected_rt=result_type) @@ -530,7 +531,7 @@ def test_value_annotation_just_type_in_value_key_plus_default(self): def foo(a: Value(str) = 1): pass check_extract(foo, expected_rq=RequiresType(( - Requirement(key=str, name='a', type_=str, default=1), + r(Value(str), name='a', type=str, default=1), )), expected_rt=result_type) @@ -538,7 +539,7 @@ def test_value_annotation_just_type_plus_default(self): def foo(a: Value(type_=str) = 1): pass check_extract(foo, expected_rq=RequiresType(( - Requirement(key='a', name='a', type_=str, default=1), + r(Value(key='a'), name='a', type=str, default=1), )), expected_rt=result_type) @@ -546,19 +547,19 @@ def test_value_unspecified_with_type(self): class T1: pass def foo(a: T1 = Value()): pass check_extract(foo, - expected_rq=RequiresType((Requirement(key=T1, name='a', type_=T1),)), + expected_rq=RequiresType((r(Value(key=T1), name='a', type=T1),)), expected_rt=result_type) def test_value_unspecified_with_simple_type(self): def foo(a: str = Value()): pass check_extract(foo, - expected_rq=RequiresType((Requirement(key='a', name='a', type_=str),)), + expected_rq=RequiresType((r(Value(key='a'), name='a', type=str),)), expected_rt=result_type) def test_value_unspecified(self): def foo(a = Value()): pass check_extract(foo, - expected_rq=RequiresType((Requirement(key='a', name='a'),)), + expected_rq=RequiresType((r(Value(key='a'), name='a'),)), expected_rt=result_type) def test_default_requirement_type(self): @@ -578,7 +579,7 @@ class FromRequest(Requirement): pass rq = extract_requires(foo, default_requirement_type=FromRequest) compare(rq, strict=True, expected=RequiresType(( - Requirement(key='x', name='x', type_=str, default=None), + r(Value(key='x'), name='x', type=str, default=None), ))) @@ -595,9 +596,9 @@ def foo(a: r1, b, c=r3): check_extract(foo, expected_rq=RequiresType(( - Requirement('a', name='a'), - Requirement('b', name='b', target='b'), - Requirement('c', name='c', target='c'), + r(Value('a'), name='a'), + r(Value('b'), name='b', target='b'), + r(Value('c'), name='c', target='c'), )), expected_rt=result_type) @@ -612,8 +613,8 @@ def foo(a: r2 = r3, b: str = r2, c = r3): check_extract(foo, expected_rq=RequiresType(( - Requirement('a', name='a', target='a'), - Requirement('b', name='b', target='b', type_=str), - Requirement('c', name='c', target='c'), + r(Value('a'), name='a', target='a'), + r(Value('b'), name='b', target='b', type=str), + r(Value('c'), name='c', target='c'), )), expected_rt=result_type) diff --git a/mush/tests/test_runner.py b/mush/tests/test_runner.py index 643b6ac..207f01f 100644 --- a/mush/tests/test_runner.py +++ b/mush/tests/test_runner.py @@ -4,8 +4,8 @@ from mush.context import ContextError from mush.declarations import ( requires, returns, returns_mapping, - replacement, original, - Value) + replacement, original) +from mush import Value from mush.runner import Runner from testfixtures import ( ShouldRaise, diff --git a/mush/types.py b/mush/types.py new file mode 100644 index 0000000..c1cd756 --- /dev/null +++ b/mush/types.py @@ -0,0 +1,6 @@ +from typing import NewType, Union, Hashable, Callable, Any + +ResourceKey = NewType('ResourceKey', Union[Hashable, Callable]) +ResourceValue = NewType('ResourceValue', Any) +ResourceResolver = Callable[['Context', Any], ResourceValue] +RequirementResolver = Callable[['Context'], ResourceValue] From e586723215ab6f058cbc6f3d9fa1da2273c9cb59 Mon Sep 17 00:00:00 2001 From: Chris Withers Date: Mon, 9 Mar 2020 09:46:02 +0000 Subject: [PATCH 062/159] Replace default requirement type with a requirement modifier callable. This allows the extraction phase of requirements to be customised. It also allows lazy to become a runner-specific thing. --- mush/asyncio.py | 10 ++++++++-- mush/callpoints.py | 21 ++++++++++++-------- mush/context.py | 18 ++++++++--------- mush/extraction.py | 15 +++++++++----- mush/modifier.py | 2 +- mush/requirements.py | 13 ++++++++++++ mush/resolvers.py | 23 --------------------- mush/runner.py | 16 ++++++++++++--- mush/tests/test_async_context.py | 8 ++++++-- mush/tests/test_callpoints.py | 34 +++++++++++++++++--------------- mush/tests/test_context.py | 25 +++++++++++++---------- mush/tests/test_declarations.py | 19 +++++++----------- mush/tests/test_resolver.py | 9 --------- mush/types.py | 7 ++++++- 14 files changed, 119 insertions(+), 101 deletions(-) delete mode 100644 mush/resolvers.py diff --git a/mush/asyncio.py b/mush/asyncio.py index ff57da6..b5db9e9 100644 --- a/mush/asyncio.py +++ b/mush/asyncio.py @@ -47,10 +47,16 @@ def extract(self, obj: Callable, requires: RequiresType = None, returns: Returns return future.result() +def default_requirement_type(requirement): + if requirement.__class__ is Requirement: + requirement.__class__ = Value + return requirement + + class Context(SyncContext): - def __init__(self, default_requirement_type: Type[Requirement] = Value): - super().__init__(default_requirement_type) + def __init__(self, requirement_modifier: RequirementModifier = default_requirement_type): + super().__init__(requirement_modifier) self._sync_context = SyncFromAsyncContext(self, asyncio.get_event_loop()) def _context_for(self, obj): diff --git a/mush/callpoints.py b/mush/callpoints.py index 6a211bf..e0f9d05 100644 --- a/mush/callpoints.py +++ b/mush/callpoints.py @@ -1,9 +1,12 @@ -from .context import Context from .declarations import ( - nothing, requires as requires_function + nothing, returns as returns_declaration + ) from .extraction import extract_requires, extract_returns -from .resolvers import Lazy + + +def do_nothing(): + pass class CallPoint(object): @@ -11,13 +14,15 @@ class CallPoint(object): next = None previous = None - def __init__(self, obj, requires=None, returns=None, lazy=False): - requires = extract_requires(obj, requires) + def __init__(self, runner, obj, requires=None, returns=None, lazy=False): + requires = extract_requires(obj, requires, runner.modify_requirement) returns = extract_returns(obj, returns) if lazy: - obj = Lazy(obj, requires, returns) - requires = requires_function(Context) - returns = nothing + if not (type(returns) is returns_declaration and len(returns.args) == 1): + raise TypeError('a single return type must be explicitly specified') + runner.lazy[returns.args[0]] = obj, requires + obj = do_nothing + requires = returns = nothing self.obj = obj self.requires = requires self.returns = returns diff --git a/mush/context.py b/mush/context.py index bf4e71b..b400747 100644 --- a/mush/context.py +++ b/mush/context.py @@ -1,10 +1,10 @@ from typing import Optional, Type, Callable from .declarations import RequiresType, ReturnsType -from .extraction import extract_requires, extract_returns +from .extraction import extract_requires, extract_returns, default_requirement_type from .markers import missing from .requirements import Requirement, Value -from .types import ResourceKey, ResourceValue, ResourceResolver +from .types import ResourceKey, ResourceValue, ResourceResolver, RequirementModifier NONE_TYPE = type(None) @@ -78,8 +78,8 @@ class Context: _parent = None - def __init__(self, default_requirement_type: Type[Requirement] = Value): - self.default_requirement_type = default_requirement_type + def __init__(self, requirement_modifier: RequirementModifier = default_requirement_type): + self._requirement_modifier = requirement_modifier self._store = {} self._requires_cache = {} self._returns_cache = {} @@ -146,7 +146,7 @@ def _resolve(self, obj, requires, args, kw, context): if requires is None: requires = extract_requires(obj, explicit=None, - default_requirement_type=self.default_requirement_type) + modifier=self._requirement_modifier) self._requires_cache[obj] = requires for requirement in requires: @@ -203,10 +203,10 @@ def get(self, key: ResourceKey, default=None): return resolvable.resolver(self, default) return resolvable.value - def nest(self, default_requirement_type: Type[Requirement] = None): - if default_requirement_type is None: - default_requirement_type = self.default_requirement_type - nested = self.__class__(default_requirement_type) + def nest(self, requirement_modifier: RequirementModifier = None): + if requirement_modifier is None: + requirement_modifier = self._requirement_modifier + nested = self.__class__(requirement_modifier) nested._parent = self nested._requires_cache = self._requires_cache nested._returns_cache = self._returns_cache diff --git a/mush/extraction.py b/mush/extraction.py index 8930158..d518812 100644 --- a/mush/extraction.py +++ b/mush/extraction.py @@ -14,6 +14,7 @@ ) from .requirements import Requirement, Value from .markers import missing +from .types import RequirementModifier EMPTY = Parameter.empty #: For these types, prefer the name instead of the type. @@ -41,9 +42,15 @@ def _apply_requires(by_name, by_index, requires_): existing.target = existing.target if r.target is None else r.target +def default_requirement_type(requirement): + if requirement.__class__ is Requirement: + requirement.__class__ = Value + return requirement + + def extract_requires(obj: Callable, explicit: RequiresType = None, - default_requirement_type: Type[Requirement] = Value): + modifier: RequirementModifier = default_requirement_type): # from annotations is_partial = isinstance(obj, partial) by_name = {} @@ -123,10 +130,8 @@ def extract_requires(obj: Callable, needs_target = False for name, requirement in by_name.items(): - if requirement.__class__ is Requirement: - requirement_ = default_requirement_type() - requirement_.__dict__.update(requirement.__dict__) - requirement = requirement_ + requirement_ = modifier(requirement) + if requirement_ is not requirement: by_name[name] = requirement if requirement.target is not None: needs_target = True diff --git a/mush/modifier.py b/mush/modifier.py index 4f0188a..691f2fa 100644 --- a/mush/modifier.py +++ b/mush/modifier.py @@ -48,7 +48,7 @@ def add(self, obj, requires=None, returns=None, label=None, lazy=False): raise ValueError('%r already points to %r' % ( label, self.runner.labels[label] )) - callpoint = CallPoint(obj, requires, returns, lazy) + callpoint = CallPoint(self.runner, obj, requires, returns, lazy) if label: self.add_label(label, callpoint) diff --git a/mush/requirements.py b/mush/requirements.py index 88d8de6..0940bec 100644 --- a/mush/requirements.py +++ b/mush/requirements.py @@ -135,3 +135,16 @@ def __init__(self, key: ResourceKey=None, *, type_: type = None, default: Any = def resolve(self, context): return context.get(self.key, self.default) + + +class Lazy(Requirement): + + runner = None + + def resolve(self, context): + result = context.get(self.key, missing) + if result is missing: + obj, requires = self.runner.lazy[self.key] + result = context.call(obj, requires) + context.add(result, provides=self.key) + return result diff --git a/mush/resolvers.py b/mush/resolvers.py deleted file mode 100644 index 7220b8a..0000000 --- a/mush/resolvers.py +++ /dev/null @@ -1,23 +0,0 @@ -from .declarations import returns as returns_declaration - - -class Lazy(object): - - def __init__(self, obj, requires, returns): - if not (type(returns) is returns_declaration and len(returns.args) == 1): - raise TypeError('a single return type must be explicitly specified') - self.__wrapped__ = obj - self.requires = requires - self.provides = returns.args[0] - - def __call__(self, context): - context.add(resolver=self.resolve, provides=self.provides) - - def resolve(self, context, default): - result = context.call(self.__wrapped__, self.requires) - context.remove(self.provides) - context.add(result, self.provides) - return result - - def __repr__(self): - return '' % self.__wrapped__ diff --git a/mush/runner.py b/mush/runner.py index 093970b..9701688 100644 --- a/mush/runner.py +++ b/mush/runner.py @@ -3,10 +3,11 @@ from .callpoints import CallPoint from .context import Context, ContextError from .declarations import DeclarationsFrom -from .extraction import extract_requires, extract_returns +from .extraction import extract_requires, extract_returns, default_requirement_type from .markers import not_specified from .modifier import Modifier from .plug import Plug +from .requirements import Lazy class Runner(object): @@ -20,8 +21,17 @@ class Runner(object): def __init__(self, *objects): self.labels = {} + self.lazy = {} self.extend(*objects) + def modify_requirement(self, requirement): + if requirement.key in self.lazy: + requirement.__class__ = Lazy + requirement.runner = self + else: + requirement = default_requirement_type(requirement) + return requirement + def add(self, obj, requires=None, returns=None, label=None, lazy=False): """ Add a callable to the runner. @@ -66,7 +76,7 @@ def _copy_from(self, start_point, end_point, added_using=None): while point: if added_using is None or added_using in point.added_using: - cloned_point = CallPoint(point.obj, point.requires, point.returns) + cloned_point = CallPoint(self, point.obj, point.requires, point.returns) cloned_point.labels = set(point.labels) for label in cloned_point.labels: self.labels[label] = cloned_point @@ -191,7 +201,7 @@ def replace(self, else: returns = point.returns - new_point = CallPoint(replacement, requires, returns) + new_point = CallPoint(self, replacement, requires, returns) if point.previous is None: self.start = new_point diff --git a/mush/tests/test_async_context.py b/mush/tests/test_async_context.py index 45106d8..50a6e59 100644 --- a/mush/tests/test_async_context.py +++ b/mush/tests/test_async_context.py @@ -284,14 +284,18 @@ def foo(bar: Syncer('request')): @pytest.mark.asyncio async def test_default_custom_requirement(): - class FromRequest(Requirement): async def resolve(self, context): return (await context.get('request'))[self.key] + def default_requirement_type(requirement): + if requirement.__class__ is Requirement: + requirement.__class__ = FromRequest + return requirement + def foo(bar): return bar - context = Context(FromRequest) + context = Context(default_requirement_type) context.add({'bar': 'foo'}, provides='request') compare(await context.call(foo), expected='foo') diff --git a/mush/tests/test_callpoints.py b/mush/tests/test_callpoints.py index aab7ba3..d857d39 100644 --- a/mush/tests/test_callpoints.py +++ b/mush/tests/test_callpoints.py @@ -6,9 +6,8 @@ from mush.callpoints import CallPoint from mush.declarations import requires, returns, RequiresType -from .. import Value from mush.extraction import update_wrapper -from .helpers import r +from mush.requirements import Requirement class TestCallPoints(TestCase): @@ -18,7 +17,7 @@ def setUp(self): def test_passive_attributes(self): # these are managed by Modifiers - point = CallPoint(self.context) + point = CallPoint(self.context, Mock()) compare(point.previous, None) compare(point.next, None) compare(point.labels, set()) @@ -27,11 +26,11 @@ def test_supplied_explicitly(self): def foo(a1): pass rq = requires('foo') rt = returns('bar') - result = CallPoint(foo, rq, rt)(self.context) + result = CallPoint(self.context, foo, rq, rt)(self.context) compare(result, self.context.extract.return_value) compare(tuple(self.context.extract.mock_calls[0].args), expected=(foo, - RequiresType([r(Value('foo'), name='a1')]), + RequiresType([Requirement('foo', name='a1')]), rt)) def test_extract_from_decorations(self): @@ -42,11 +41,11 @@ def test_extract_from_decorations(self): @rt def foo(a1): pass - result = CallPoint(foo)(self.context) + result = CallPoint(self.context, foo)(self.context) compare(result, self.context.extract.return_value) compare(tuple(self.context.extract.mock_calls[0].args), expected=(foo, - RequiresType([r(Value('foo'), name='a1')]), + RequiresType([Requirement('foo', name='a1')]), returns('bar'))) def test_extract_from_decorated_class(self): @@ -70,9 +69,9 @@ def foo(prefix): return prefix+'answer' self.context.extract.side_effect = lambda func, rq, rt: (func(), rq, rt) - result = CallPoint(foo)(self.context) + result = CallPoint(self.context, foo)(self.context) compare(result, expected=('the answer', - RequiresType([r(Value('foo'), name='prefix')]), + RequiresType([Requirement('foo', name='prefix')]), rt)) def test_explicit_trumps_decorators(self): @@ -80,21 +79,22 @@ def test_explicit_trumps_decorators(self): @returns('bar') def foo(a1): pass - result = CallPoint(foo, requires('baz'), returns('bob'))(self.context) + point = CallPoint(self.context, foo, requires('baz'), returns('bob')) + result = point(self.context) compare(result, self.context.extract.return_value) compare(tuple(self.context.extract.mock_calls[0].args), expected=(foo, - RequiresType([r(Value('baz'), name='a1')]), + RequiresType([Requirement('baz', name='a1')]), returns('bob'))) def test_repr_minimal(self): def foo(): pass - point = CallPoint(foo) + point = CallPoint(self.context, foo) compare(repr(foo)+" requires() returns_result_type()", repr(point)) def test_repr_maximal(self): def foo(a1): pass - point = CallPoint(foo, requires('foo'), returns('bar')) + point = CallPoint(self.context, foo, requires('foo'), returns('bar')) point.labels.add('baz') point.labels.add('bob') compare(repr(foo)+" requires('foo') returns('bar') <-- baz, bob", @@ -102,7 +102,7 @@ def foo(a1): pass def test_convert_to_requires_and_returns(self): def foo(baz): pass - point = CallPoint(foo, requires='foo', returns='bar') + point = CallPoint(self.context, foo, requires='foo', returns='bar') self.assertTrue(isinstance(point.requires, RequiresType)) self.assertTrue(isinstance(point.returns, returns)) compare(repr(foo)+" requires('foo') returns('bar')", @@ -110,7 +110,8 @@ def foo(baz): pass def test_convert_to_requires_and_returns_tuple(self): def foo(a1, a2): pass - point = CallPoint(foo, + point = CallPoint(self.context, + foo, requires=('foo', 'bar'), returns=('baz', 'bob')) self.assertTrue(isinstance(point.requires, RequiresType)) @@ -120,7 +121,8 @@ def foo(a1, a2): pass def test_convert_to_requires_and_returns_list(self): def foo(a1, a2): pass - point = CallPoint(foo, + point = CallPoint(self.context, + foo, requires=['foo', 'bar'], returns=['baz', 'bob']) self.assertTrue(isinstance(point.requires, RequiresType)) diff --git a/mush/tests/test_context.py b/mush/tests/test_context.py index e92349a..51d00cf 100644 --- a/mush/tests/test_context.py +++ b/mush/tests/test_context.py @@ -408,17 +408,17 @@ def test_nest(self): compare(c1.get('c'), expected='c') def test_nest_with_overridden_default_requirement_type(self): - class FromRequest(Requirement): pass - c1 = Context(default_requirement_type=FromRequest) + def modifier(): pass + c1 = Context(modifier) c2 = c1.nest() - assert c2.default_requirement_type is FromRequest + assert c2._requirement_modifier is modifier def test_nest_with_explicit_default_requirement_type(self): - class Requirement1(Requirement): pass - class Requirement2(Requirement): pass - c1 = Context(default_requirement_type=Requirement1) - c2 = c1.nest(default_requirement_type=Requirement2) - assert c2.default_requirement_type is Requirement2 + def modifier1(): pass + def modifier2(): pass + c1 = Context(modifier1) + c2 = c1.nest(modifier2) + assert c2._requirement_modifier is modifier2 def test_nest_keeps_declarations_cache(self): c1 = Context() @@ -448,7 +448,7 @@ def resolve(self, context): def foo(bar: FromRequest('bar')): pass - context = Context(default_requirement_type=FromRequest) + context = Context() context.add({}, provides='request') with ShouldRaise(ContextError("No 'bar' in context")): compare(context.call(foo)) @@ -462,6 +462,11 @@ def resolve(self, context): def foo(bar): return bar - context = Context(default_requirement_type=FromRequest) + def modifier(requirement): + if requirement.__class__ is Requirement: + requirement.__class__ = FromRequest + return requirement + + context = Context(requirement_modifier=modifier) context.add({'bar': 'foo'}, provides='request') compare(context.call(foo), expected='foo') diff --git a/mush/tests/test_declarations.py b/mush/tests/test_declarations.py index 2f1c6fd..d4d0f69 100644 --- a/mush/tests/test_declarations.py +++ b/mush/tests/test_declarations.py @@ -562,24 +562,19 @@ def foo(a = Value()): pass expected_rq=RequiresType((r(Value(key='a'), name='a'),)), expected_rt=result_type) - def test_default_requirement_type(self): + def test_requirement_modifier(self): def foo(x: str = None): pass class FromRequest(Requirement): pass - rq = extract_requires(foo, default_requirement_type=FromRequest) - compare(rq, strict=True, expected=RequiresType(( - FromRequest(key='x', name='x', type_=str, default=None), - ))) - - def test_default_requirement_not_used(self): - def foo(x: str = Value(default=None)): pass + def modifier(requirement): + if requirement.__class__ is Requirement: + requirement.__class__ = FromRequest + return requirement - class FromRequest(Requirement): pass - - rq = extract_requires(foo, default_requirement_type=FromRequest) + rq = extract_requires(foo, modifier=modifier) compare(rq, strict=True, expected=RequiresType(( - r(Value(key='x'), name='x', type=str, default=None), + FromRequest(key='x', name='x', type_=str, default=None), ))) diff --git a/mush/tests/test_resolver.py b/mush/tests/test_resolver.py index b30564a..7cadf98 100644 --- a/mush/tests/test_resolver.py +++ b/mush/tests/test_resolver.py @@ -1,20 +1,11 @@ from mush.context import ResolvableValue from testfixtures import compare -from mush import returns -from mush.resolvers import Lazy from mush.markers import Marker foo = Marker('foo') -class TestLazy: - - def test_repr(self): - f = Lazy(foo, None, returns('foo')) - compare(repr(f), expected='>') - - class TestResolvableValue: def test_repr_with_resolver(self): diff --git a/mush/types.py b/mush/types.py index c1cd756..01568c4 100644 --- a/mush/types.py +++ b/mush/types.py @@ -1,6 +1,11 @@ -from typing import NewType, Union, Hashable, Callable, Any +from typing import NewType, Union, Hashable, Callable, Any, TYPE_CHECKING + +if TYPE_CHECKING: + from .context import Context + from .requirements import Requirement ResourceKey = NewType('ResourceKey', Union[Hashable, Callable]) ResourceValue = NewType('ResourceValue', Any) ResourceResolver = Callable[['Context', Any], ResourceValue] RequirementResolver = Callable[['Context'], ResourceValue] +RequirementModifier = Callable[['Requirement'], 'Requirement'] From b349af64b9bafa2f73e29f5ef217ebcc78f3c22f Mon Sep 17 00:00:00 2001 From: Chris Withers Date: Mon, 9 Mar 2020 18:27:53 +0000 Subject: [PATCH 063/159] Get rid of resolvable resources now that lazy is container within runners. --- mush/asyncio.py | 33 +++----------- mush/context.py | 53 ++++++---------------- mush/tests/test_async_context.py | 75 ++++++-------------------------- mush/tests/test_context.py | 61 +++++++------------------- mush/tests/test_resolver.py | 13 ------ mush/types.py | 2 - 6 files changed, 48 insertions(+), 189 deletions(-) delete mode 100644 mush/tests/test_resolver.py diff --git a/mush/asyncio.py b/mush/asyncio.py index b5db9e9..3312073 100644 --- a/mush/asyncio.py +++ b/mush/asyncio.py @@ -1,11 +1,11 @@ import asyncio from functools import partial -from typing import Type, Callable +from typing import Callable -from . import Context as SyncContext, Value as SyncValue +from . import Context as SyncContext from .declarations import RequiresType, ReturnsType -from .requirements import Requirement -from .types import ResourceKey +from .extraction import default_requirement_type +from .types import RequirementModifier async def ensure_async(func, *args, **kw): @@ -17,12 +17,6 @@ async def ensure_async(func, *args, **kw): return await loop.run_in_executor(None, func, *args) -class Value(SyncValue): - - async def resolve(self, context): - return await context.get(self.key, self.default) - - class SyncFromAsyncContext: def __init__(self, context, loop): @@ -30,11 +24,7 @@ def __init__(self, context, loop): self.loop = loop self.remove = context.remove self.add = context.add - - def get(self, key: ResourceKey, default=None): - coro = self.context.get(key, default) - future = asyncio.run_coroutine_threadsafe(coro, self.loop) - return future.result() + self.get = context.get def call(self, obj: Callable, requires: RequiresType = None): coro = self.context.call(obj, requires) @@ -47,12 +37,6 @@ def extract(self, obj: Callable, requires: RequiresType = None, returns: Returns return future.result() -def default_requirement_type(requirement): - if requirement.__class__ is Requirement: - requirement.__class__ = Value - return requirement - - class Context(SyncContext): def __init__(self, requirement_modifier: RequirementModifier = default_requirement_type): @@ -62,13 +46,6 @@ def __init__(self, requirement_modifier: RequirementModifier = default_requireme def _context_for(self, obj): return self if asyncio.iscoroutinefunction(obj) else self._sync_context - async def get(self, key: ResourceKey, default=None): - resolvable = self._get(key, default) - r = resolvable.resolver - if r is not None: - return await ensure_async(r, self._context_for(r), default) - return resolvable.value - async def call(self, obj: Callable, requires: RequiresType = None): args = [] kw = {} diff --git a/mush/context.py b/mush/context.py index b400747..7ce347a 100644 --- a/mush/context.py +++ b/mush/context.py @@ -1,10 +1,9 @@ -from typing import Optional, Type, Callable +from typing import Optional, Callable from .declarations import RequiresType, ReturnsType from .extraction import extract_requires, extract_returns, default_requirement_type from .markers import missing -from .requirements import Requirement, Value -from .types import ResourceKey, ResourceValue, ResourceResolver, RequirementModifier +from .types import ResourceKey, ResourceValue, RequirementModifier NONE_TYPE = type(None) @@ -60,19 +59,6 @@ def type_key(type_tuple): return type.__name__ -class ResolvableValue: - __slots__ = ('value', 'resolver') - - def __init__(self, value, resolver=None): - self.value = value - self.resolver = resolver - - def __repr__(self): - if self.resolver is None: - return repr(self.value) - return repr(self.resolver) - - class Context: "Stores resources for a particular run." @@ -86,24 +72,19 @@ def __init__(self, requirement_modifier: RequirementModifier = default_requireme def add(self, resource: Optional[ResourceValue] = None, - provides: Optional[ResourceKey] = None, - resolver: Optional[ResourceResolver] = None): + provides: Optional[ResourceKey] = None): """ Add a resource to the context. Optionally specify what the resource provides. """ - if resolver is not None and (provides is None or resource is not None): - if resource is not None: - raise TypeError('resource cannot be supplied when using a resolver') - raise TypeError('Both provides and resolver must be supplied') if provides is None: provides = type(resource) if provides is NONE_TYPE: raise ValueError('Cannot add None to context') if provides in self._store: raise ContextError(f'Context already contains {provides!r}') - self._store[provides] = ResolvableValue(resource, resolver) + self._store[provides] = resource def remove(self, key: ResourceKey, *, strict: bool = True): """ @@ -181,27 +162,19 @@ def call(self, obj: Callable, requires: RequiresType = None): resolving.send(requirement.resolve(self)) return obj(*args, **kw) - def _get(self, key, default): + def get(self, key: ResourceKey, default=None): context = self - resolvable = None - while resolvable is None and context is not None: - resolvable = context._store.get(key, None) - if resolvable is None: + while context is not None: + value = context._store.get(key, missing) + if value is missing: context = context._parent - elif context is not self: - self._store[key] = resolvable - - if resolvable is None: - return ResolvableValue(default) - - return resolvable + else: + if context is not self: + self._store[key] = value + return value - def get(self, key: ResourceKey, default=None): - resolvable = self._get(key, default) - if resolvable.resolver is not None: - return resolvable.resolver(self, default) - return resolvable.value + return default def nest(self, requirement_modifier: RequirementModifier = None): if requirement_modifier is None: diff --git a/mush/tests/test_async_context.py b/mush/tests/test_async_context.py index 50a6e59..4bfe296 100644 --- a/mush/tests/test_async_context.py +++ b/mush/tests/test_async_context.py @@ -5,7 +5,6 @@ from mush import Context, Value, requires, returns from mush.asyncio import Context -from mush.context import ResolvableValue from mush.declarations import RequiresType from mush.requirements import Requirement from testfixtures import compare @@ -13,52 +12,6 @@ from mush.tests.test_context import TheType -@pytest.mark.asyncio -async def test_get_is_async(): - context = Context() - result = context.get('foo', default='bar') - assert asyncio.iscoroutine(result) - compare(await result, expected='bar') - - -@pytest.mark.asyncio -async def test_get_async_resolver(): - async def resolver(*args): - return 'bar' - context = Context() - context.add(provides='foo', resolver=resolver) - compare(await context.get('foo'), expected='bar') - - -@pytest.mark.asyncio -async def test_get_async_resolver_calls_back_into_async(): - async def resolver(context, default): - return await context.get('baz') - context = Context() - context.add('bar', provides='baz') - context.add(provides='foo', resolver=resolver) - compare(await context.get('foo'), expected='bar') - - -@pytest.mark.asyncio -async def test_get_sync_resolver(): - def resolver(*args): - return 'bar' - context = Context() - context.add(provides='foo', resolver=resolver) - compare(await context.get('foo'), expected='bar') - - -@pytest.mark.asyncio -async def test_get_sync_resolver_calls_back_into_async(): - def resolver(context, default): - return context.get('baz') - context = Context() - context.add('bar', provides='baz') - context.add(provides='foo', resolver=resolver) - compare(await context.get('foo'), expected='bar') - - @pytest.mark.asyncio async def test_call_is_async(): context = Context() @@ -83,7 +36,7 @@ async def test_call_async_requires_context(): context = Context() context.add('bar', provides='baz') async def it(context: Context): - return await context.get('baz') + return context.get('baz') compare(await context.call(it), expected='bar') @@ -92,7 +45,7 @@ async def test_call_async_requires_async_context(): context = Context() context.add('bar', provides='baz') async def it(context: Context): - return await context.get('baz') + return context.get('baz') compare(await context.call(it), expected='bar') @@ -139,7 +92,7 @@ def it(): result = context.extract(it, requires(), returns('baz')) assert asyncio.iscoroutine(result) compare(await result, expected='bar') - compare(await context.get('baz'), expected='bar') + compare(context.get('baz'), expected='bar') @pytest.mark.asyncio @@ -147,10 +100,10 @@ async def test_extract_async(): context = Context() context.add('foo', provides='bar') async def it(context): - return await context.get('bar')+'bar' + return context.get('bar')+'bar' result = context.extract(it, requires(Context), returns('baz')) compare(await result, expected='foobar') - compare(await context.get('baz'), expected='foobar') + compare(context.get('baz'), expected='foobar') @pytest.mark.asyncio @@ -161,7 +114,7 @@ def it(context): return context.get('bar')+'bar' result = context.extract(it, requires(Context), returns('baz')) compare(await result, expected='foobar') - compare(await context.get('baz'), expected='foobar') + compare(context.get('baz'), expected='foobar') @pytest.mark.asyncio @@ -172,7 +125,7 @@ def foo() -> TheType: context = Context() result = await context.extract(foo) assert result is o - compare({TheType: ResolvableValue(o)}, actual=context._store) + compare({TheType: o}, actual=context._store) compare(context._requires_cache[foo], expected=RequiresType()) compare(context._returns_cache[foo], expected=returns(TheType)) @@ -186,8 +139,8 @@ def foo(*args): result = await context.extract(foo, requires(str), returns(Tuple[str])) compare(result, expected=('a',)) compare({ - str: ResolvableValue('a'), - Tuple[str]: ResolvableValue(('a',)), + str: 'a', + Tuple[str]: ('a',), }, actual=context._store) compare(context._requires_cache, expected={}) compare(context._returns_cache, expected={}) @@ -198,7 +151,7 @@ async def test_custom_requirement_async_resolve(): class FromRequest(Requirement): async def resolve(self, context): - return (await context.get('request'))[self.key] + return (context.get('request'))[self.key] def foo(bar: FromRequest('bar')): return bar @@ -258,7 +211,7 @@ def foo(bar: Syncer(baz)): context = Context() context.add({'bar': 'foo'}, provides='request') compare(await context.call(foo), expected='foo') - compare(await context.get('response'), expected='foo') + compare(context.get('response'), expected='foo') @pytest.mark.asyncio @@ -277,8 +230,8 @@ def foo(bar: Syncer('request')): context = Context() context.add({'bar': 'foo'}, provides='request') compare(await context.call(foo), expected='foo') - compare(await context.get('request'), expected=None) - compare(await context.get('response'), expected='foo') + compare(context.get('request'), expected=None) + compare(context.get('response'), expected='foo') @pytest.mark.asyncio @@ -286,7 +239,7 @@ async def test_default_custom_requirement(): class FromRequest(Requirement): async def resolve(self, context): - return (await context.get('request'))[self.key] + return (context.get('request'))[self.key] def default_requirement_type(requirement): if requirement.__class__ is Requirement: diff --git a/mush/tests/test_context.py b/mush/tests/test_context.py index 51d00cf..bb3894f 100644 --- a/mush/tests/test_context.py +++ b/mush/tests/test_context.py @@ -2,7 +2,6 @@ from unittest import TestCase from mock import Mock -from mush.context import ResolvableValue from testfixtures import ShouldRaise, compare from mush import ( @@ -24,7 +23,7 @@ def test_simple(self): context = Context() context.add(obj) - compare(context._store, expected={TheType: ResolvableValue(obj)}) + compare(context._store, expected={TheType: obj}) expected = ( ": \n" @@ -41,7 +40,7 @@ def test_type_as_string(self): expected = ("\n" "}>") - compare(context._store, expected={'my label': ResolvableValue(obj)}) + compare(context._store, expected={'my label': obj}) self.assertEqual(repr(context), expected) self.assertEqual(str(context), expected) @@ -50,7 +49,7 @@ class T2(object): pass obj = TheType() context = Context() context.add(obj, provides=T2) - compare(context._store, expected={T2: ResolvableValue(obj)}) + compare(context._store, expected={T2: obj}) expected = ("\n" "}>") @@ -63,34 +62,6 @@ def test_no_resolver_or_provides(self): context.add() compare(context._store, expected={}) - def test_resolver_but_no_provides(self): - context = Context() - with ShouldRaise(TypeError('Both provides and resolver must be supplied')): - context.add(resolver=lambda: None) - compare(context._store, expected={}) - - def test_resolver(self): - m = Mock() - context = Context() - context.add(provides='foo', resolver=m) - m.assert_not_called() - assert context.get('foo') is m.return_value - m.assert_called_with(context, None) - - def test_resolver_and_resource(self): - m = Mock() - context = Context() - with ShouldRaise(TypeError('resource cannot be supplied when using a resolver')): - context.add('bar', provides='foo', resolver=m) - compare(context._store, expected={}) - - def test_resolver_with_default(self): - m = Mock() - context = Context() - context.add(provides='foo', - resolver=lambda context, default=None: context.get('foo-bar', default)) - assert context.get('foo', default=m) is m - def test_clash(self): obj1 = TheType() obj2 = TheType() @@ -115,7 +86,7 @@ def test_add_none(self): def test_add_none_with_type(self): context = Context() context.add(None, TheType) - compare(context._store, expected={TheType: ResolvableValue(None)}) + compare(context._store, expected={TheType: None}) def test_call_basic(self): def foo(): @@ -131,7 +102,7 @@ def foo(obj): context.add('bar', 'baz') result = context.call(foo, requires('baz')) compare(result, 'bar') - compare({'baz': ResolvableValue('bar')}, actual=context._store) + compare({'baz': 'bar'}, actual=context._store) def test_call_requires_type(self): def foo(obj): @@ -140,7 +111,7 @@ def foo(obj): context.add('bar', TheType) result = context.call(foo, requires(TheType)) compare(result, 'bar') - compare({TheType: ResolvableValue('bar')}, actual=context._store) + compare({TheType: 'bar'}, actual=context._store) def test_call_requires_missing(self): def foo(obj): return obj @@ -177,8 +148,8 @@ def foo(x, y): context.add('bar', 'baz') result = context.call(foo, requires(y='baz', x=TheType)) compare(result, ('foo', 'bar')) - compare({TheType: ResolvableValue('foo'), - 'baz': ResolvableValue('bar')}, + compare({TheType: 'foo', + 'baz': 'bar'}, actual=context._store) def test_call_requires_optional_present(self): @@ -188,7 +159,7 @@ def foo(x=1): context.add(2, TheType) result = context.call(foo, requires(TheType)) compare(result, 2) - compare({TheType: ResolvableValue(2)}, actual=context._store) + compare({TheType: 2}, actual=context._store) def test_call_requires_optional_missing(self): def foo(x: TheType = 1): @@ -212,7 +183,7 @@ def foo(x:'foo'=1): context.add(2, 'foo') result = context.call(foo) compare(result, 2) - compare({'foo': ResolvableValue(2)}, actual=context._store) + compare({'foo': 2}, actual=context._store) def test_call_requires_item(self): def foo(x): @@ -293,7 +264,7 @@ def foo() -> TheType: context = Context() result = context.extract(foo) assert result is o - compare({TheType: ResolvableValue(o)}, actual=context._store) + compare({TheType: o}, actual=context._store) compare(context._requires_cache[foo], expected=RequiresType()) compare(context._returns_cache[foo], expected=returns(TheType)) @@ -305,8 +276,8 @@ def foo(*args): result = context.extract(foo, requires(str), returns(Tuple[str])) compare(result, expected=('a',)) compare({ - str: ResolvableValue('a'), - Tuple[str]: ResolvableValue(('a',)), + str: 'a', + Tuple[str]: ('a',), }, actual=context._store) compare(context._requires_cache, expected={}) compare(context._returns_cache, expected={}) @@ -317,7 +288,7 @@ def foo(): context = Context() result = context.extract(foo, nothing, returns(TheType)) compare(result, 'bar') - compare({TheType: ResolvableValue('bar')}, actual=context._store) + compare({TheType: 'bar'}, actual=context._store) def test_returns_sequence(self): def foo(): @@ -325,7 +296,7 @@ def foo(): context = Context() result = context.extract(foo, nothing, returns('foo', 'bar')) compare(result, (1, 2)) - compare({'foo': ResolvableValue(1), 'bar': ResolvableValue(2)}, + compare({'foo': 1, 'bar': 2}, actual=context._store) def test_returns_mapping(self): @@ -334,7 +305,7 @@ def foo(): context = Context() result = context.extract(foo, nothing, returns_mapping()) compare(result, {'foo': 1, 'bar': 2}) - compare({'foo': ResolvableValue(1), 'bar': ResolvableValue(2)}, + compare({'foo': 1, 'bar': 2}, actual=context._store) def test_ignore_return(self): diff --git a/mush/tests/test_resolver.py b/mush/tests/test_resolver.py deleted file mode 100644 index 7cadf98..0000000 --- a/mush/tests/test_resolver.py +++ /dev/null @@ -1,13 +0,0 @@ -from mush.context import ResolvableValue -from testfixtures import compare - -from mush.markers import Marker - -foo = Marker('foo') - - -class TestResolvableValue: - - def test_repr_with_resolver(self): - compare(repr(ResolvableValue(None, foo)), - expected='') diff --git a/mush/types.py b/mush/types.py index 01568c4..0df1ee9 100644 --- a/mush/types.py +++ b/mush/types.py @@ -6,6 +6,4 @@ ResourceKey = NewType('ResourceKey', Union[Hashable, Callable]) ResourceValue = NewType('ResourceValue', Any) -ResourceResolver = Callable[['Context', Any], ResourceValue] -RequirementResolver = Callable[['Context'], ResourceValue] RequirementModifier = Callable[['Requirement'], 'Requirement'] From 8efc0afc74b3d1d1d74669c960048f20a92dadd5 Mon Sep 17 00:00:00 2001 From: Chris Withers Date: Mon, 9 Mar 2020 18:57:23 +0000 Subject: [PATCH 064/159] factor out a get_sync --- mush/declarations.py | 8 ++++++++ mush/extraction.py | 14 ++++++-------- mush/plug.py | 4 ++-- 3 files changed, 16 insertions(+), 10 deletions(-) diff --git a/mush/declarations.py b/mush/declarations.py index e04fa47..887889c 100644 --- a/mush/declarations.py +++ b/mush/declarations.py @@ -2,6 +2,7 @@ from itertools import chain from typing import _type_check +from .markers import missing from .requirements import Requirement, Value, name_or_repr @@ -11,6 +12,13 @@ def set_mush(obj, key, value): obj.__mush__[key] = value +def get_mush(obj, key, default): + __mush__ = getattr(obj, '__mush__', missing) + if __mush__ is missing: + return default + return __mush__.get(key, default) + + class RequiresType(list): def __repr__(self): diff --git a/mush/extraction.py b/mush/extraction.py index d518812..e8f7e81 100644 --- a/mush/extraction.py +++ b/mush/extraction.py @@ -10,7 +10,8 @@ from .declarations import ( requires, RequiresType, ReturnsType, returns, result_type, - nothing + nothing, + get_mush ) from .requirements import Requirement, Value from .markers import missing @@ -112,11 +113,9 @@ def extract_requires(obj: Callable, by_index = list(by_name) # from declarations - mush_declarations = getattr(obj, '__mush__', None) - if mush_declarations is not None: - requires_ = mush_declarations.get('requires') - if requires_ is not None: - _apply_requires(by_name, by_index, requires_) + mush_requires = get_mush(obj, 'requires', None) + if mush_requires is not None: + _apply_requires(by_name, by_index, mush_requires) # explicit if explicit is not None: @@ -143,8 +142,7 @@ def extract_requires(obj: Callable, def extract_returns(obj: Callable, explicit: ReturnsType = None): if explicit is None: - mush_declarations = getattr(obj, '__mush__', {}) - returns_ = mush_declarations.get('returns', None) + returns_ = get_mush(obj, 'returns', None) if returns_ is None: annotations = getattr(obj, '__annotations__', {}) returns_ = annotations.get('return') diff --git a/mush/plug.py b/mush/plug.py index 64b6010..1680d4a 100644 --- a/mush/plug.py +++ b/mush/plug.py @@ -1,4 +1,4 @@ -from .declarations import set_mush +from .declarations import set_mush, get_mush class ignore(object): @@ -67,5 +67,5 @@ def add_to(self, runner): if not name.startswith('_'): obj = getattr(self, name) if callable(obj): - action = getattr(obj, '__mush__', {}).get('plug', default_action) + action = get_mush(obj, 'plug', default_action) action.apply(runner, obj) From 552a955c286cb7bd42cfe6be45d6c5286f3d2693 Mon Sep 17 00:00:00 2001 From: Chris Withers Date: Tue, 10 Mar 2020 06:58:54 +0000 Subject: [PATCH 065/159] tuple generics' have a __name__ on Py3.6 --- mush/tests/helpers.py | 8 ++++++++ mush/tests/test_declarations.py | 8 +++++--- 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/mush/tests/helpers.py b/mush/tests/helpers.py index d0c0921..f5a547c 100644 --- a/mush/tests/helpers.py +++ b/mush/tests/helpers.py @@ -1,6 +1,14 @@ +import sys + + def r(base, **attrs): """ helper for returning Requirement subclasses with extra attributes """ base.__dict__.update(attrs) return base + + +PY_VERSION = sys.version_info[:2] + +PY_36 = PY_VERSION == (3, 6) diff --git a/mush/tests/test_declarations.py b/mush/tests/test_declarations.py index d4d0f69..4d7cfa6 100644 --- a/mush/tests/test_declarations.py +++ b/mush/tests/test_declarations.py @@ -16,7 +16,7 @@ from mush.extraction import extract_requires, extract_returns, update_wrapper from mush.markers import missing from mush.requirements import Requirement, AttrOp, ItemOp -from .helpers import r +from .helpers import r, PY_36 def check_extract(obj, expected_rq, expected_rt): @@ -61,7 +61,8 @@ def test_strings(self): def test_typing(self): r_ = requires(Tuple[str]) - compare(repr(r_), "requires(typing.Tuple[str])") + text = 'Tuple' if PY_36 else 'typing.Tuple[str]' + compare(repr(r_), f"requires({text})") compare(r_, expected=[r(Value(Tuple[str]), type=Tuple[str])]) def test_tuple_arg(self): @@ -213,7 +214,8 @@ def test_string(self): def test_typing(self): r = returns(Tuple[str]) - compare(repr(r), 'returns(typing.Tuple[str])') + text = 'Tuple' if PY_36 else 'typing.Tuple[str]' + compare(repr(r), f'returns({text})') compare(dict(r.process('foo')), {Tuple[str]: 'foo'}) def test_sequence(self): From 0e0b9f91bba471ba82671142e434689ebabd16d8 Mon Sep 17 00:00:00 2001 From: Chris Withers Date: Tue, 10 Mar 2020 07:43:25 +0000 Subject: [PATCH 066/159] make the name for this class more sensible --- mush/asyncio.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mush/asyncio.py b/mush/asyncio.py index 3312073..1503ae0 100644 --- a/mush/asyncio.py +++ b/mush/asyncio.py @@ -17,7 +17,7 @@ async def ensure_async(func, *args, **kw): return await loop.run_in_executor(None, func, *args) -class SyncFromAsyncContext: +class AsyncFromSyncContext: def __init__(self, context, loop): self.context = context @@ -41,7 +41,7 @@ class Context(SyncContext): def __init__(self, requirement_modifier: RequirementModifier = default_requirement_type): super().__init__(requirement_modifier) - self._sync_context = SyncFromAsyncContext(self, asyncio.get_event_loop()) + self._sync_context = AsyncFromSyncContext(self, asyncio.get_event_loop()) def _context_for(self, obj): return self if asyncio.iscoroutinefunction(obj) else self._sync_context From 43b6185bcec20a807a059c0598ea60ed627a7568 Mon Sep 17 00:00:00 2001 From: Chris Withers Date: Tue, 10 Mar 2020 07:43:51 +0000 Subject: [PATCH 067/159] Add a marker indicating that even if a callable is synchronous, it doesn't need running in a thread. Use it to mark Value.resolve as non-blocking. --- mush/asyncio.py | 4 +++- mush/markers.py | 11 +++++++++++ mush/requirements.py | 5 +++-- mush/tests/test_async_context.py | 29 +++++++++++++++++++++++++++++ 4 files changed, 46 insertions(+), 3 deletions(-) diff --git a/mush/asyncio.py b/mush/asyncio.py index 1503ae0..180bcae 100644 --- a/mush/asyncio.py +++ b/mush/asyncio.py @@ -9,7 +9,9 @@ async def ensure_async(func, *args, **kw): - if asyncio.iscoroutinefunction(func): + if getattr(func, '__nonblocking__', False): + return func(*args, **kw) + elif asyncio.iscoroutinefunction(func): return await func(*args, **kw) if kw: func = partial(func, **kw) diff --git a/mush/markers.py b/mush/markers.py index 0b87c8d..16dba48 100644 --- a/mush/markers.py +++ b/mush/markers.py @@ -11,3 +11,14 @@ def __repr__(self): #: A sentinel object to indicate that a value is missing. missing = Marker('missing') + + +def nonblocking(obj): + """ + A decorator to mark a method as not requiring running + in a thread, even though it's not async. + """ + # Not using set_mush / get_mush to try and keep this as + # quick as possible + obj.__nonblocking__ = True + return obj diff --git a/mush/requirements.py b/mush/requirements.py index 0940bec..1d2b98b 100644 --- a/mush/requirements.py +++ b/mush/requirements.py @@ -2,7 +2,7 @@ from typing import Any, Optional, List, TYPE_CHECKING from .types import ResourceKey -from .markers import missing +from .markers import missing, nonblocking if TYPE_CHECKING: from .context import Context @@ -133,7 +133,8 @@ def __init__(self, key: ResourceKey=None, *, type_: type = None, default: Any = type_ = key super().__init__(key, type_=type_, default=default) - def resolve(self, context): + @nonblocking + def resolve(self, context: 'Context'): return context.get(self.key, self.default) diff --git a/mush/tests/test_async_context.py b/mush/tests/test_async_context.py index 4bfe296..26cdcd1 100644 --- a/mush/tests/test_async_context.py +++ b/mush/tests/test_async_context.py @@ -1,7 +1,9 @@ import asyncio +from contextlib import contextmanager from typing import Tuple import pytest +from mock import Mock from mush import Context, Value, requires, returns from mush.asyncio import Context @@ -12,6 +14,21 @@ from mush.tests.test_context import TheType +@pytest.fixture() +def no_threads(): + # pytest-asyncio does things so we need to do this mock *in* the test: + @contextmanager + def raise_on_threads(): + loop = asyncio.get_event_loop() + original = loop.run_in_executor + loop.run_in_executor = Mock(side_effect=Exception('bad')) + try: + yield + finally: + loop.run_in_executor = original + return raise_on_threads() + + @pytest.mark.asyncio async def test_call_is_async(): context = Context() @@ -146,6 +163,18 @@ def foo(*args): compare(context._returns_cache, expected={}) +@pytest.mark.asyncio +async def test_value_resolve_does_not_run_in_thread(no_threads): + with no_threads: + context = Context() + context.add('foo', provides='baz') + + async def it(baz): + return baz+'bar' + + compare(await context.call(it), expected='foobar') + + @pytest.mark.asyncio async def test_custom_requirement_async_resolve(): From 680f8dc1fed867d8b6fbae848d6ec5b01c86262e Mon Sep 17 00:00:00 2001 From: Chris Withers Date: Tue, 10 Mar 2020 08:22:06 +0000 Subject: [PATCH 068/159] Split out a simpler ResourceError and move ContextError to be with the Runner class. --- mush/__init__.py | 5 +-- mush/context.py | 68 +++++++++----------------------------- mush/requirements.py | 7 +++- mush/runner.py | 46 ++++++++++++++++++++++++-- mush/tests/test_context.py | 31 +++++++++++------ mush/tests/test_runner.py | 9 +++-- 6 files changed, 94 insertions(+), 72 deletions(-) diff --git a/mush/__init__.py b/mush/__init__.py index 4d9b313..e63be3b 100755 --- a/mush/__init__.py +++ b/mush/__init__.py @@ -1,4 +1,4 @@ -from .context import Context, ContextError +from .context import Context, ResourceError from .declarations import ( requires, returns_result_type, returns_mapping, returns_sequence, returns, nothing ) @@ -6,12 +6,13 @@ from .markers import missing from .plug import Plug from .requirements import Value -from .runner import Runner +from .runner import Runner, ContextError __all__ = [ 'Context', 'ContextError', 'Plug', + 'ResourceError', 'Runner', 'Value', 'missing', diff --git a/mush/context.py b/mush/context.py index 7ce347a..bb42751 100644 --- a/mush/context.py +++ b/mush/context.py @@ -3,60 +3,23 @@ from .declarations import RequiresType, ReturnsType from .extraction import extract_requires, extract_returns, default_requirement_type from .markers import missing +from .requirements import Requirement from .types import ResourceKey, ResourceValue, RequirementModifier NONE_TYPE = type(None) -class ContextError(Exception): +class ResourceError(Exception): """ - Errors likely caused by incorrect building of a runner. + An exception raised when there is a problem with a `ResourceKey`. """ - def __init__(self, text, point=None, context=None): - self.text = text - self.point = point - self.context = context - - def __str__(self): - rows = [] - if self.point: - point = self.point.previous - while point: - rows.append(repr(point)) - point = point.previous - if rows: - rows.append('Already called:') - rows.append('') - rows.append('') - rows.reverse() - rows.append('') - - rows.append('While calling: '+repr(self.point)) - if self.context is not None: - rows.append('with '+repr(self.context)+':') - rows.append('') - - rows.append(self.text) - - if self.point: - point = self.point.next - if point: - rows.append('') - rows.append('Still to call:') - while point: - rows.append(repr(point)) - point = point.next - - return '\n'.join(rows) - - __repr__ = __str__ - - -def type_key(type_tuple): - type, _ = type_tuple - if isinstance(type, str): - return type - return type.__name__ + + def __init__(self, message: str, key: ResourceKey, requirement: Requirement = None): + super().__init__(message) + #: The key for the problematic resource. + self.key: ResourceKey = key + #: The requirement that caused this exception. + self.requirement: Requirement = requirement class Context: @@ -83,23 +46,23 @@ def add(self, if provides is NONE_TYPE: raise ValueError('Cannot add None to context') if provides in self._store: - raise ContextError(f'Context already contains {provides!r}') + raise ResourceError(f'Context already contains {provides!r}', provides) self._store[provides] = resource def remove(self, key: ResourceKey, *, strict: bool = True): """ Remove the specified resource key from the context. - If ``strict``, then a :class:`ContextError` will be raised if the + If ``strict``, then a :class:`ResourceError` will be raised if the specified resource is not present in the context. """ if strict and key not in self._store: - raise ContextError(f'Context does not contain {key!r}') + raise ResourceError(f'Context does not contain {key!r}', key) self._store.pop(key, None) def __repr__(self): bits = [] - for type, value in sorted(self._store.items(), key=type_key): + for type, value in sorted(self._store.items(), key=lambda o: repr(o)): bits.append('\n %r: %r' % (type, value)) if bits: bits.append('\n') @@ -145,7 +108,8 @@ def _resolve(self, obj, requires, args, kw, context): if isinstance(key, type) and issubclass(key, Context): o = context else: - raise ContextError('No %s in context' % requirement.value_repr()) + raise ResourceError(f'No {requirement!r} in context', + key, requirement) if requirement.target is None: args.append(o) diff --git a/mush/requirements.py b/mush/requirements.py index 1d2b98b..5eaa836 100644 --- a/mush/requirements.py +++ b/mush/requirements.py @@ -91,7 +91,12 @@ def __repr__(self): value = getattr(self, a.rstrip('_')) if value is not None and value != self.key: attrs.append(f", {a}={value!r}") - return self.value_repr(''.join(attrs), from_repr=True) + + key = name_or_repr(self.key) + default = '' if self.default is missing else f', default={self.default!r}' + ops = ''.join(repr(o) for o in self.ops) + + return f"{type(self).__name__}({key}{default}{''.join(attrs)}){ops}" def attr(self, name): """ diff --git a/mush/runner.py b/mush/runner.py index 9701688..3a62085 100644 --- a/mush/runner.py +++ b/mush/runner.py @@ -1,7 +1,7 @@ from typing import Callable from .callpoints import CallPoint -from .context import Context, ContextError +from .context import Context, ResourceError from .declarations import DeclarationsFrom from .extraction import extract_requires, extract_returns, default_requirement_type from .markers import not_specified @@ -265,7 +265,7 @@ def __call__(self, context=None): try: result = point(context) - except ContextError as e: + except ResourceError as e: raise ContextError(str(e), point, context) if getattr(result, '__enter__', None): @@ -288,3 +288,45 @@ def __repr__(self): return '%s' % ''.join(bits) +class ContextError(Exception): + """ + Errors likely caused by incorrect building of a runner. + """ + def __init__(self, text: str, point: CallPoint=None, context: Context = None): + self.text: str = text + self.point: CallPoint = point + self.context: Context = context + + def __str__(self): + rows = [] + if self.point: + point = self.point.previous + while point: + rows.append(repr(point)) + point = point.previous + if rows: + rows.append('Already called:') + rows.append('') + rows.append('') + rows.reverse() + rows.append('') + + rows.append('While calling: '+repr(self.point)) + if self.context is not None: + rows.append('with '+repr(self.context)+':') + rows.append('') + + rows.append(self.text) + + if self.point: + point = self.point.next + if point: + rows.append('') + rows.append('Still to call:') + while point: + rows.append(repr(point)) + point = point.next + + return '\n'.join(rows) + + __repr__ = __str__ diff --git a/mush/tests/test_context.py b/mush/tests/test_context.py index bb3894f..70bcbd5 100644 --- a/mush/tests/test_context.py +++ b/mush/tests/test_context.py @@ -5,11 +5,13 @@ from testfixtures import ShouldRaise, compare from mush import ( - Context, ContextError, requires, returns, nothing, returns_mapping, - Value, missing + Context, requires, returns, nothing, returns_mapping, Value, missing ) +from mush.context import ResourceError from mush.declarations import RequiresType from mush.requirements import Requirement +from .helpers import r + class TheType(object): def __repr__(self): @@ -67,7 +69,8 @@ def test_clash(self): obj2 = TheType() context = Context() context.add(obj1, TheType) - with ShouldRaise(ContextError('Context already contains '+repr(TheType))): + with ShouldRaise(ResourceError('Context already contains '+repr(TheType), + key=TheType)): context.add(obj2, TheType) def test_clash_string_type(self): @@ -75,7 +78,8 @@ def test_clash_string_type(self): obj2 = TheType() context = Context() context.add(obj1, provides='my label') - with ShouldRaise(ContextError("Context already contains 'my label'")): + with ShouldRaise(ResourceError("Context already contains 'my label'", + key='my label')): context.add(obj2, provides='my label') def test_add_none(self): @@ -116,8 +120,10 @@ def foo(obj): def test_call_requires_missing(self): def foo(obj): return obj context = Context() - with ShouldRaise(ContextError( - "No TheType in context" + with ShouldRaise(ResourceError( + "No Value(TheType) in context", + key=TheType, + requirement=Value(TheType), )): context.call(foo, requires(TheType)) @@ -125,8 +131,10 @@ def test_call_requires_item_missing(self): def foo(obj): return obj context = Context() context.add({}, TheType) - with ShouldRaise(ContextError( - "No Value(TheType)['foo'] in context" + with ShouldRaise(ResourceError( + "No Value(TheType)['foo'] in context", + key=TheType, + requirement=Value(TheType)['foo'], )): context.call(foo, requires(Value(TheType)['foo'])) @@ -337,7 +345,8 @@ def test_remove(self): def test_remove_not_there_strict(self): context = Context() - with ShouldRaise(ContextError("Context does not contain 'foo'")): + with ShouldRaise(ResourceError("Context does not contain 'foo'", + key='foo')): context.remove('foo') compare(context._store, expected={}) @@ -421,7 +430,9 @@ def foo(bar: FromRequest('bar')): context = Context() context.add({}, provides='request') - with ShouldRaise(ContextError("No 'bar' in context")): + with ShouldRaise(ResourceError("No FromRequest('bar') in context", + key='bar', + requirement=r(FromRequest('bar'), name='bar'))): compare(context.call(foo)) def test_default_custom_requirement(self): diff --git a/mush/tests/test_runner.py b/mush/tests/test_runner.py index 207f01f..9b76a11 100644 --- a/mush/tests/test_runner.py +++ b/mush/tests/test_runner.py @@ -1,11 +1,10 @@ from unittest import TestCase from mock import Mock, call -from mush.context import ContextError from mush.declarations import ( requires, returns, returns_mapping, replacement, original) -from mush import Value +from mush import Value, ContextError from mush.runner import Runner from testfixtures import ( ShouldRaise, @@ -514,7 +513,7 @@ def job(arg): 'While calling: '+repr(job)+' requires(T) returns_result_type()', 'with :', '', - 'No T in context', + "No Value(T, name='arg') in context", )) compare(text, actual=repr(s.raised)) compare(text, actual=str(s.raised)) @@ -552,7 +551,7 @@ def job5(foo, bar): pass 'While calling: '+repr(job3)+' requires(T) returns_result_type()', 'with :', '', - 'No T in context', + "No Value(T, name='arg') in context", '', 'Still to call:', repr(job4)+' requires() returns_result_type() <-- 4', @@ -567,7 +566,7 @@ def job(arg): runner = Runner(job) with ShouldRaise(ContextError) as s: runner() - compare(s.raised.text, expected="No 'arg' in context") + compare(s.raised.text, expected="No Value('arg') in context") def test_already_in_context(self): class T(object): pass From 137c03d7bc4d10a03400825ea8746340b14d750e Mon Sep 17 00:00:00 2001 From: Chris Withers Date: Tue, 10 Mar 2020 08:33:01 +0000 Subject: [PATCH 069/159] round out code coverage --- mush/declarations.py | 3 +-- mush/tests/test_declarations.py | 5 +++++ 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/mush/declarations.py b/mush/declarations.py index 887889c..a625959 100644 --- a/mush/declarations.py +++ b/mush/declarations.py @@ -174,10 +174,9 @@ def valid_decoration_types(*objs): continue try: _type_check(obj, '') + continue except TypeError: pass - else: - continue raise TypeError( repr(obj)+" is not a valid decoration type" ) diff --git a/mush/tests/test_declarations.py b/mush/tests/test_declarations.py index 4d7cfa6..127d897 100644 --- a/mush/tests/test_declarations.py +++ b/mush/tests/test_declarations.py @@ -129,6 +129,11 @@ def test_no_special_name_via_getattr(self): assert v.__len__ compare(v.ops, []) + def test_resolve(self): + r = Requirement() + with ShouldRaise(NotImplementedError): + r.resolve(None) + class TestValue: From fb66793a730b22e245182dc1c681d958fcf8f711 Mon Sep 17 00:00:00 2001 From: Chris Withers Date: Tue, 10 Mar 2020 08:45:58 +0000 Subject: [PATCH 070/159] Add a Call resolver for both sync and async use. --- mush/__init__.py | 3 +- mush/asyncio.py | 16 +++++++- mush/requirements.py | 23 ++++++++++- mush/tests/test_async_requirements.py | 59 +++++++++++++++++++++++++++ mush/tests/test_requirements.py | 55 +++++++++++++++++++++++++ 5 files changed, 153 insertions(+), 3 deletions(-) create mode 100644 mush/tests/test_async_requirements.py create mode 100644 mush/tests/test_requirements.py diff --git a/mush/__init__.py b/mush/__init__.py index e63be3b..26b3bd1 100755 --- a/mush/__init__.py +++ b/mush/__init__.py @@ -5,10 +5,11 @@ from .extraction import extract_requires, extract_returns, update_wrapper from .markers import missing from .plug import Plug -from .requirements import Value +from .requirements import Value, Call from .runner import Runner, ContextError __all__ = [ + 'Call', 'Context', 'ContextError', 'Plug', diff --git a/mush/asyncio.py b/mush/asyncio.py index 180bcae..3d08ac8 100644 --- a/mush/asyncio.py +++ b/mush/asyncio.py @@ -2,7 +2,7 @@ from functools import partial from typing import Callable -from . import Context as SyncContext +from . import Context as SyncContext, Call as SyncCall, missing from .declarations import RequiresType, ReturnsType from .extraction import default_requirement_type from .types import RequirementModifier @@ -65,3 +65,17 @@ async def extract(self, result = await self.call(obj, requires) self._process(obj, result, returns) return result + + +class Call(SyncCall): + + async def resolve(self, context): + result = context.get(self.key, missing) + if result is missing: + result = await context.call(self.key) + if self.cache: + context.add(result, provides=self.key) + return result + + +__all__ = ['Context', 'Call'] diff --git a/mush/requirements.py b/mush/requirements.py index 5eaa836..810fc6b 100644 --- a/mush/requirements.py +++ b/mush/requirements.py @@ -1,5 +1,5 @@ from copy import copy -from typing import Any, Optional, List, TYPE_CHECKING +from typing import Any, Optional, List, TYPE_CHECKING, Callable from .types import ResourceKey from .markers import missing, nonblocking @@ -154,3 +154,24 @@ def resolve(self, context): result = context.call(obj, requires) context.add(result, provides=self.key) return result + + +class Call(Requirement): + """ + A requirement that is resolved by calling something. + + If ``cache`` is ``True``, then the result of that call will be cached + for the duration of the context in which this requirement is resolved. + """ + + def __init__(self, obj: Callable, *, cache: bool = True): + super().__init__(obj) + self.cache: bool = cache + + def resolve(self, context): + result = context.get(self.key, missing) + if result is missing: + result = context.call(self.key) + if self.cache: + context.add(result, provides=self.key) + return result diff --git a/mush/tests/test_async_requirements.py b/mush/tests/test_async_requirements.py new file mode 100644 index 0000000..e399ae5 --- /dev/null +++ b/mush/tests/test_async_requirements.py @@ -0,0 +1,59 @@ +import pytest +from testfixtures import compare + +from mush.asyncio import Context, Call + + +class TestCall: + + @pytest.mark.asyncio + async def test_resolve(self): + context = Context() + + called = [] + + async def foo(bar: str): + called.append(1) + return bar+'b' + + async def bob(x: str = Call(foo)): + return x+'c' + + context.add('a', provides='bar') + + compare(await context.call(bob), expected='abc') + compare(await context.call(bob), expected='abc') + compare(called, expected=[1]) + compare(context.get(foo), expected='ab') + + @pytest.mark.asyncio + async def test_resolve_without_caching(self): + context = Context() + + called = [] + + def foo(bar: str): + called.append(1) + return bar+'b' + + def bob(x: str = Call(foo, cache=False)): + return x+'c' + + context.add('a', provides='bar') + + compare(await context.call(bob), expected='abc') + compare(await context.call(bob), expected='abc') + compare(called, expected=[1, 1]) + compare(context.get(foo), expected=None) + + @pytest.mark.asyncio + async def test_parts_of_a_call(self): + context = Context() + + async def foo(): + return {'a': 'b'} + + async def bob(x: str = Call(foo)['a']): + return x+'c' + + compare(await context.call(bob), expected='bc') diff --git a/mush/tests/test_requirements.py b/mush/tests/test_requirements.py new file mode 100644 index 0000000..4a4dcef --- /dev/null +++ b/mush/tests/test_requirements.py @@ -0,0 +1,55 @@ +from testfixtures import compare + +from mush import Context, Call + + +class TestCall: + + def test_resolve(self): + context = Context() + + called = [] + + def foo(bar: str): + called.append(1) + return bar+'b' + + def bob(x: str = Call(foo)): + return x+'c' + + context.add('a', provides='bar') + + compare(context.call(bob), expected='abc') + compare(context.call(bob), expected='abc') + compare(called, expected=[1]) + compare(context.get(foo), expected='ab') + + def test_resolve_without_caching(self): + context = Context() + + called = [] + + def foo(bar: str): + called.append(1) + return bar+'b' + + def bob(x: str = Call(foo, cache=False)): + return x+'c' + + context.add('a', provides='bar') + + compare(context.call(bob), expected='abc') + compare(context.call(bob), expected='abc') + compare(called, expected=[1, 1]) + compare(context.get(foo), expected=None) + + def test_parts_of_a_call(self): + context = Context() + + def foo(): + return {'a': 'b'} + + def bob(x: str = Call(foo)['a']): + return x+'c' + + compare(context.call(bob), expected='bc') From 3fac6e24acf4d03aa9b0fb2412e878a57418e689 Mon Sep 17 00:00:00 2001 From: Chris Withers Date: Tue, 10 Mar 2020 08:49:24 +0000 Subject: [PATCH 071/159] move requirement tests to test_requirements --- mush/tests/helpers.py | 6 ++ mush/tests/test_declarations.py | 137 +------------------------------- mush/tests/test_requirements.py | 133 ++++++++++++++++++++++++++++++- 3 files changed, 140 insertions(+), 136 deletions(-) diff --git a/mush/tests/helpers.py b/mush/tests/helpers.py index f5a547c..63183a7 100644 --- a/mush/tests/helpers.py +++ b/mush/tests/helpers.py @@ -12,3 +12,9 @@ def r(base, **attrs): PY_VERSION = sys.version_info[:2] PY_36 = PY_VERSION == (3, 6) + + +class Type1(object): pass +class Type2(object): pass +class Type3(object): pass +class Type4(object): pass diff --git a/mush/tests/test_declarations.py b/mush/tests/test_declarations.py index 127d897..ebec96c 100644 --- a/mush/tests/test_declarations.py +++ b/mush/tests/test_declarations.py @@ -3,10 +3,9 @@ from unittest import TestCase import pytest -from mock import Mock from testfixtures import compare, ShouldRaise -from mush import Context, Value +from mush import Value from mush.declarations import ( requires, returns, returns_mapping, returns_sequence, returns_result_type, @@ -14,9 +13,8 @@ result_type, RequiresType ) from mush.extraction import extract_requires, extract_returns, update_wrapper -from mush.markers import missing -from mush.requirements import Requirement, AttrOp, ItemOp -from .helpers import r, PY_36 +from mush.requirements import Requirement, ItemOp +from .helpers import r, PY_36, Type1, Type2, Type3, Type4 def check_extract(obj, expected_rq, expected_rt): @@ -26,12 +24,6 @@ def check_extract(obj, expected_rq, expected_rt): compare(rt, expected=expected_rt, strict=True) -class Type1(object): pass -class Type2(object): pass -class Type3(object): pass -class Type4(object): pass - - class TestRequires(TestCase): def test_empty(self): @@ -82,129 +74,6 @@ def foo(): compare(foo(), 'bar') -def check_ops(value, data, *, expected): - for op in value.ops: - data = op(data) - compare(expected, actual=data) - - -class TestRequirement: - - def test_repr_minimal(self): - compare(repr(Requirement('foo')), - expected="Requirement('foo')") - - def test_repr_maximal(self): - r = Requirement('foo', name='n', type_='ty', default=None, target='ta') - r.ops.append(AttrOp('bar')) - compare(repr(r), - expected="Requirement('foo', default=None, " - "name='n', type_='ty', target='ta').bar") - - def test_clone(self): - r = Value('foo').bar.requirement - r_ = r.clone() - assert r_ is not r - assert r_.ops is not r.ops - compare(r_, expected=r) - - special_names = ['attr', 'ops', 'target'] - - @pytest.mark.parametrize("name", special_names) - def test_attr_special_name(self, name): - v = Requirement('foo') - assert getattr(v, name) is not self - assert v.attr(name) is v - compare(v.ops, expected=[AttrOp(name)]) - - @pytest.mark.parametrize("name", special_names) - def test_item_special_name(self, name): - v = Requirement('foo') - assert v[name] is v - compare(v.ops, expected=[ItemOp(name)]) - - def test_no_special_name_via_getattr(self): - v = Requirement('foo') - with ShouldRaise(AttributeError): - assert v.__len__ - compare(v.ops, []) - - def test_resolve(self): - r = Requirement() - with ShouldRaise(NotImplementedError): - r.resolve(None) - - -class TestValue: - - def test_type_from_key(self): - v = Value(str) - compare(v.requirement.type, expected=str) - - def test_key_and_type_cannot_disagree(self): - with ShouldRaise(TypeError('type_ cannot be specified if key is a type')): - Value(key=str, type_=int) - - -class TestItem: - - def test_single(self): - h = Value(Type1)['foo'] - compare(repr(h), expected="Value(Type1)['foo']") - check_ops(h, {'foo': 1}, expected=1) - - def test_multiple(self): - h = Value(Type1)['foo']['bar'] - compare(repr(h), expected="Value(Type1)['foo']['bar']") - check_ops(h, {'foo': {'bar': 1}}, expected=1) - - def test_missing_obj(self): - h = Value(Type1)['foo']['bar'] - with ShouldRaise(TypeError): - check_ops(h, object(), expected=None) - - def test_missing_key(self): - h = Value(Type1)['foo'] - check_ops(h, {}, expected=missing) - - def test_passed_missing(self): - c = Context() - c.add({}, provides='key') - compare(c.call(lambda x: x, requires(Value('key', default=1)['foo']['bar'])), - expected=1) - - def test_bad_type(self): - h = Value(Type1)['foo']['bar'] - with ShouldRaise(TypeError): - check_ops(h, [], expected=None) - - -class TestAttr(TestCase): - - def test_single(self): - h = Value(Type1).foo - compare(repr(h), "Value(Type1).foo") - m = Mock() - check_ops(h, m, expected=m.foo) - - def test_multiple(self): - h = Value(Type1).foo.bar - compare(repr(h), "Value(Type1).foo.bar") - m = Mock() - check_ops(h, m, expected=m.foo.bar) - - def test_missing(self): - h = Value(Type1).foo - compare(repr(h), "Value(Type1).foo") - check_ops(h, object(), expected=missing) - - def test_passed_missing(self): - c = Context() - c.add(object(), provides='key') - compare(c.call(lambda x: x, requires(Value('key', default=1).foo.bar)), - expected=1) - - class TestReturns(TestCase): def test_type(self): diff --git a/mush/tests/test_requirements.py b/mush/tests/test_requirements.py index 4a4dcef..2f60268 100644 --- a/mush/tests/test_requirements.py +++ b/mush/tests/test_requirements.py @@ -1,6 +1,135 @@ -from testfixtures import compare +from unittest.case import TestCase -from mush import Context, Call +import pytest +from mock import Mock +from testfixtures import compare, ShouldRaise + +from mush import Context, Call, Value, missing, requires +from mush.requirements import Requirement, AttrOp, ItemOp +from .helpers import Type1 + + +def check_ops(value, data, *, expected): + for op in value.ops: + data = op(data) + compare(expected, actual=data) + + +class TestRequirement: + + def test_repr_minimal(self): + compare(repr(Requirement('foo')), + expected="Requirement('foo')") + + def test_repr_maximal(self): + r = Requirement('foo', name='n', type_='ty', default=None, target='ta') + r.ops.append(AttrOp('bar')) + compare(repr(r), + expected="Requirement('foo', default=None, " + "name='n', type_='ty', target='ta').bar") + + def test_clone(self): + r = Value('foo').bar.requirement + r_ = r.clone() + assert r_ is not r + assert r_.ops is not r.ops + compare(r_, expected=r) + + special_names = ['attr', 'ops', 'target'] + + @pytest.mark.parametrize("name", special_names) + def test_attr_special_name(self, name): + v = Requirement('foo') + assert getattr(v, name) is not self + assert v.attr(name) is v + compare(v.ops, expected=[AttrOp(name)]) + + @pytest.mark.parametrize("name", special_names) + def test_item_special_name(self, name): + v = Requirement('foo') + assert v[name] is v + compare(v.ops, expected=[ItemOp(name)]) + + def test_no_special_name_via_getattr(self): + v = Requirement('foo') + with ShouldRaise(AttributeError): + assert v.__len__ + compare(v.ops, []) + + def test_resolve(self): + r = Requirement() + with ShouldRaise(NotImplementedError): + r.resolve(None) + + +class TestValue: + + def test_type_from_key(self): + v = Value(str) + compare(v.requirement.type, expected=str) + + def test_key_and_type_cannot_disagree(self): + with ShouldRaise(TypeError('type_ cannot be specified if key is a type')): + Value(key=str, type_=int) + + +class TestItem: + + def test_single(self): + h = Value(Type1)['foo'] + compare(repr(h), expected="Value(Type1)['foo']") + check_ops(h, {'foo': 1}, expected=1) + + def test_multiple(self): + h = Value(Type1)['foo']['bar'] + compare(repr(h), expected="Value(Type1)['foo']['bar']") + check_ops(h, {'foo': {'bar': 1}}, expected=1) + + def test_missing_obj(self): + h = Value(Type1)['foo']['bar'] + with ShouldRaise(TypeError): + check_ops(h, object(), expected=None) + + def test_missing_key(self): + h = Value(Type1)['foo'] + check_ops(h, {}, expected=missing) + + def test_passed_missing(self): + c = Context() + c.add({}, provides='key') + compare(c.call(lambda x: x, requires(Value('key', default=1)['foo']['bar'])), + expected=1) + + def test_bad_type(self): + h = Value(Type1)['foo']['bar'] + with ShouldRaise(TypeError): + check_ops(h, [], expected=None) + + +class TestAttr(TestCase): + + def test_single(self): + h = Value(Type1).foo + compare(repr(h), "Value(Type1).foo") + m = Mock() + check_ops(h, m, expected=m.foo) + + def test_multiple(self): + h = Value(Type1).foo.bar + compare(repr(h), "Value(Type1).foo.bar") + m = Mock() + check_ops(h, m, expected=m.foo.bar) + + def test_missing(self): + h = Value(Type1).foo + compare(repr(h), "Value(Type1).foo") + check_ops(h, object(), expected=missing) + + def test_passed_missing(self): + c = Context() + c.add(object(), provides='key') + compare(c.call(lambda x: x, requires(Value('key', default=1).foo.bar)), + expected=1) class TestCall: From dd5d37c460e2d8463f15b064e86babed5e09dfe7 Mon Sep 17 00:00:00 2001 From: Chris Withers Date: Tue, 10 Mar 2020 08:50:13 +0000 Subject: [PATCH 072/159] remove py3 suffix. --- ...ple_with_mush_clone_py3.py => test_example_with_mush_clone.py} | 0 ...with_mush_factory_py3.py => test_example_with_mush_factory.py} | 0 2 files changed, 0 insertions(+), 0 deletions(-) rename mush/tests/{test_example_with_mush_clone_py3.py => test_example_with_mush_clone.py} (100%) rename mush/tests/{test_example_with_mush_factory_py3.py => test_example_with_mush_factory.py} (100%) diff --git a/mush/tests/test_example_with_mush_clone_py3.py b/mush/tests/test_example_with_mush_clone.py similarity index 100% rename from mush/tests/test_example_with_mush_clone_py3.py rename to mush/tests/test_example_with_mush_clone.py diff --git a/mush/tests/test_example_with_mush_factory_py3.py b/mush/tests/test_example_with_mush_factory.py similarity index 100% rename from mush/tests/test_example_with_mush_factory_py3.py rename to mush/tests/test_example_with_mush_factory.py From b6a256e2f1a467628bf2ae83b06044e87b575ee2 Mon Sep 17 00:00:00 2001 From: Chris Withers Date: Tue, 10 Mar 2020 11:54:16 +0000 Subject: [PATCH 073/159] Add AnyOf requirement. --- mush/__init__.py | 3 ++- mush/requirements.py | 17 ++++++++++++ mush/tests/test_async_context.py | 14 +++++++++- mush/tests/test_requirements.py | 44 ++++++++++++++++++++++++++++++-- 4 files changed, 74 insertions(+), 4 deletions(-) diff --git a/mush/__init__.py b/mush/__init__.py index 26b3bd1..2dfe2ec 100755 --- a/mush/__init__.py +++ b/mush/__init__.py @@ -5,10 +5,11 @@ from .extraction import extract_requires, extract_returns, update_wrapper from .markers import missing from .plug import Plug -from .requirements import Value, Call +from .requirements import Value, Call, AnyOf from .runner import Runner, ContextError __all__ = [ + 'AnyOf', 'Call', 'Context', 'ContextError', diff --git a/mush/requirements.py b/mush/requirements.py index 810fc6b..6e85b86 100644 --- a/mush/requirements.py +++ b/mush/requirements.py @@ -175,3 +175,20 @@ def resolve(self, context): if self.cache: context.add(result, provides=self.key) return result + + +class AnyOf(Requirement): + """ + A requirement that is resolved by any of the specified keys. + """ + + def __init__(self, *keys, default=missing): + super().__init__(keys, default=default) + + @nonblocking + def resolve(self, context: 'Context'): + for key in self.key: + value = context.get(key, missing) + if value is not missing: + return value + return self.default diff --git a/mush/tests/test_async_context.py b/mush/tests/test_async_context.py index 26cdcd1..0206a92 100644 --- a/mush/tests/test_async_context.py +++ b/mush/tests/test_async_context.py @@ -8,7 +8,7 @@ from mush import Context, Value, requires, returns from mush.asyncio import Context from mush.declarations import RequiresType -from mush.requirements import Requirement +from mush.requirements import Requirement, AnyOf from testfixtures import compare from mush.tests.test_context import TheType @@ -175,6 +175,18 @@ async def it(baz): compare(await context.call(it), expected='foobar') +@pytest.mark.asyncio +async def test_anyof_resolve_does_not_run_in_thread(no_threads): + with no_threads: + context = Context() + context.add(('foo', )) + + async def bob(x: str = AnyOf(tuple, Tuple[str])): + return x[0] + + compare(await context.call(bob), expected='foo') + + @pytest.mark.asyncio async def test_custom_requirement_async_resolve(): diff --git a/mush/tests/test_requirements.py b/mush/tests/test_requirements.py index 2f60268..2493521 100644 --- a/mush/tests/test_requirements.py +++ b/mush/tests/test_requirements.py @@ -1,11 +1,12 @@ +from typing import Tuple from unittest.case import TestCase import pytest from mock import Mock from testfixtures import compare, ShouldRaise -from mush import Context, Call, Value, missing, requires -from mush.requirements import Requirement, AttrOp, ItemOp +from mush import Context, Call, Value, missing, requires, ResourceError +from mush.requirements import Requirement, AttrOp, ItemOp, AnyOf from .helpers import Type1 @@ -182,3 +183,42 @@ def bob(x: str = Call(foo)['a']): return x+'c' compare(context.call(bob), expected='bc') + + +class TestAnyOf: + + def test_first(self): + context = Context() + context.add(('foo', )) + context.add(('bar', ), provides=Tuple[str]) + + def bob(x: str = AnyOf(tuple, Tuple[str])): + return x[0] + + compare(context.call(bob), expected='foo') + + def test_second(self): + context = Context() + context.add(('bar', ), provides=Tuple[str]) + + def bob(x: str = AnyOf(tuple, Tuple[str])): + return x[0] + + compare(context.call(bob), expected='bar') + + def test_none(self): + context = Context() + + def bob(x: str = AnyOf(tuple, Tuple[str])): + pass + + with ShouldRaise(ResourceError): + context.call(bob) + + def test_default(self): + context = Context() + + def bob(x: str = AnyOf(tuple, Tuple[str], default=(42,))): + return x[0] + + compare(context.call(bob), expected=42) From 0e973d8dce900bad403667ec76c150ce263b3657 Mon Sep 17 00:00:00 2001 From: Chris Withers Date: Tue, 10 Mar 2020 18:18:12 +0000 Subject: [PATCH 074/159] Move this type to helpers.py --- mush/tests/helpers.py | 5 +++++ mush/tests/test_async_context.py | 2 +- mush/tests/test_context.py | 13 ++++--------- 3 files changed, 10 insertions(+), 10 deletions(-) diff --git a/mush/tests/helpers.py b/mush/tests/helpers.py index 63183a7..c717a22 100644 --- a/mush/tests/helpers.py +++ b/mush/tests/helpers.py @@ -18,3 +18,8 @@ class Type1(object): pass class Type2(object): pass class Type3(object): pass class Type4(object): pass + + +class TheType(object): + def __repr__(self): + return '' diff --git a/mush/tests/test_async_context.py b/mush/tests/test_async_context.py index 0206a92..852f233 100644 --- a/mush/tests/test_async_context.py +++ b/mush/tests/test_async_context.py @@ -11,7 +11,7 @@ from mush.requirements import Requirement, AnyOf from testfixtures import compare -from mush.tests.test_context import TheType +from .helpers import TheType @pytest.fixture() diff --git a/mush/tests/test_context.py b/mush/tests/test_context.py index 70bcbd5..d0ada10 100644 --- a/mush/tests/test_context.py +++ b/mush/tests/test_context.py @@ -10,12 +10,7 @@ from mush.context import ResourceError from mush.declarations import RequiresType from mush.requirements import Requirement -from .helpers import r - - -class TheType(object): - def __repr__(self): - return '' +from .helpers import r, TheType class TestContext(TestCase): @@ -28,7 +23,7 @@ def test_simple(self): compare(context._store, expected={TheType: obj}) expected = ( ": \n" + " : \n" "}>" ) self.assertEqual(repr(context), expected) @@ -142,8 +137,8 @@ def test_call_requires_accidental_tuple(self): def foo(obj): return obj context = Context() with ShouldRaise(TypeError( - "(, " - ") " + "(, " + ") " "is not a valid decoration type" )): context.call(foo, requires((TheType, TheType))) From c2348f8afb8aa1c3b78b586583439afe66a826d9 Mon Sep 17 00:00:00 2001 From: Chris Withers Date: Tue, 10 Mar 2020 18:18:30 +0000 Subject: [PATCH 075/159] Add a Like requirement that searches up the requested type's mro. --- mush/__init__.py | 3 +- mush/requirements.py | 17 ++++++++++ mush/tests/test_async_context.py | 15 ++++++++- mush/tests/test_requirements.py | 54 +++++++++++++++++++++++++++++++- 4 files changed, 86 insertions(+), 3 deletions(-) diff --git a/mush/__init__.py b/mush/__init__.py index 2dfe2ec..9daaa43 100755 --- a/mush/__init__.py +++ b/mush/__init__.py @@ -5,7 +5,7 @@ from .extraction import extract_requires, extract_returns, update_wrapper from .markers import missing from .plug import Plug -from .requirements import Value, Call, AnyOf +from .requirements import Value, Call, AnyOf, Like from .runner import Runner, ContextError __all__ = [ @@ -13,6 +13,7 @@ 'Call', 'Context', 'ContextError', + 'Like', 'Plug', 'ResourceError', 'Runner', diff --git a/mush/requirements.py b/mush/requirements.py index 6e85b86..e7fe107 100644 --- a/mush/requirements.py +++ b/mush/requirements.py @@ -192,3 +192,20 @@ def resolve(self, context: 'Context'): if value is not missing: return value return self.default + + +class Like(Requirement): + """ + A requirements that is resolved by the specified class or + any of its base classes. + """ + + @nonblocking + def resolve(self, context: 'Context'): + for key in self.key.__mro__: + if key is object: + break + value = context.get(key, missing) + if value is not missing: + return value + return self.default diff --git a/mush/tests/test_async_context.py b/mush/tests/test_async_context.py index 852f233..4794c96 100644 --- a/mush/tests/test_async_context.py +++ b/mush/tests/test_async_context.py @@ -8,7 +8,7 @@ from mush import Context, Value, requires, returns from mush.asyncio import Context from mush.declarations import RequiresType -from mush.requirements import Requirement, AnyOf +from mush.requirements import Requirement, AnyOf, Like from testfixtures import compare from .helpers import TheType @@ -187,6 +187,19 @@ async def bob(x: str = AnyOf(tuple, Tuple[str])): compare(await context.call(bob), expected='foo') +@pytest.mark.asyncio +async def test_like_resolve_does_not_run_in_thread(no_threads): + with no_threads: + o = TheType() + context = Context() + context.add(o) + + async def bob(x: str = Like(TheType)): + return x + + assert await context.call(bob) is o + + @pytest.mark.asyncio async def test_custom_requirement_async_resolve(): diff --git a/mush/tests/test_requirements.py b/mush/tests/test_requirements.py index 2493521..daf9fae 100644 --- a/mush/tests/test_requirements.py +++ b/mush/tests/test_requirements.py @@ -6,7 +6,7 @@ from testfixtures import compare, ShouldRaise from mush import Context, Call, Value, missing, requires, ResourceError -from mush.requirements import Requirement, AttrOp, ItemOp, AnyOf +from mush.requirements import Requirement, AttrOp, ItemOp, AnyOf, Like from .helpers import Type1 @@ -222,3 +222,55 @@ def bob(x: str = AnyOf(tuple, Tuple[str], default=(42,))): return x[0] compare(context.call(bob), expected=42) + + +class Parent(object): + pass + + +class Child(Parent): + pass + + +class TestLike: + + def test_actual(self): + context = Context() + p = Parent() + c = Child() + context.add(p) + context.add(c) + + def bob(x: str = Like(Child)): + return x + + assert context.call(bob) is c + + def test_base(self): + context = Context() + p = Parent() + context.add(p) + + def bob(x: str = Like(Child)): + return x + + assert context.call(bob) is p + + def test_none(self): + context = Context() + # make sure we don't pick up object! + context.add(object()) + + def bob(x: str = Like(Child)): + pass + + with ShouldRaise(ResourceError): + context.call(bob) + + def test_default(self): + context = Context() + + def bob(x: str = Like(Child, default=42)): + return x + + compare(context.call(bob), expected=42) From a8d24ffadb92f23d0eafcdbf1a818cebd47ef44c Mon Sep 17 00:00:00 2001 From: Chris Withers Date: Tue, 10 Mar 2020 18:21:33 +0000 Subject: [PATCH 076/159] expose Requirement as a top-level name --- mush/__init__.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/mush/__init__.py b/mush/__init__.py index 9daaa43..99f8561 100755 --- a/mush/__init__.py +++ b/mush/__init__.py @@ -5,7 +5,7 @@ from .extraction import extract_requires, extract_returns, update_wrapper from .markers import missing from .plug import Plug -from .requirements import Value, Call, AnyOf, Like +from .requirements import Requirement, Value, Call, AnyOf, Like from .runner import Runner, ContextError __all__ = [ @@ -15,6 +15,7 @@ 'ContextError', 'Like', 'Plug', + 'Requirement', 'ResourceError', 'Runner', 'Value', From 151454be2105ad1053dc533a80f7796dc1d75bb3 Mon Sep 17 00:00:00 2001 From: Chris Withers Date: Thu, 26 Mar 2020 07:34:56 +0000 Subject: [PATCH 077/159] make no_threads a helper not a fixture. --- mush/tests/helpers.py | 19 +++++++++++++++++++ mush/tests/test_async_context.py | 31 +++++++------------------------ 2 files changed, 26 insertions(+), 24 deletions(-) diff --git a/mush/tests/helpers.py b/mush/tests/helpers.py index c717a22..d8b8b39 100644 --- a/mush/tests/helpers.py +++ b/mush/tests/helpers.py @@ -1,4 +1,8 @@ +import asyncio import sys +from contextlib import contextmanager + +from mock import Mock def r(base, **attrs): @@ -23,3 +27,18 @@ class Type4(object): pass class TheType(object): def __repr__(self): return '' + + +@contextmanager +def no_threads(): + loop = asyncio.get_event_loop() + original = loop.run_in_executor + loop.run_in_executor = Mock(side_effect=Exception('threads used when they should not be')) + try: + yield + finally: + loop.run_in_executor = original + + +def must_run_in_thread(func): + pass diff --git a/mush/tests/test_async_context.py b/mush/tests/test_async_context.py index 4794c96..a9255e8 100644 --- a/mush/tests/test_async_context.py +++ b/mush/tests/test_async_context.py @@ -1,9 +1,7 @@ import asyncio -from contextlib import contextmanager from typing import Tuple import pytest -from mock import Mock from mush import Context, Value, requires, returns from mush.asyncio import Context @@ -11,22 +9,7 @@ from mush.requirements import Requirement, AnyOf, Like from testfixtures import compare -from .helpers import TheType - - -@pytest.fixture() -def no_threads(): - # pytest-asyncio does things so we need to do this mock *in* the test: - @contextmanager - def raise_on_threads(): - loop = asyncio.get_event_loop() - original = loop.run_in_executor - loop.run_in_executor = Mock(side_effect=Exception('bad')) - try: - yield - finally: - loop.run_in_executor = original - return raise_on_threads() +from .helpers import TheType, no_threads @pytest.mark.asyncio @@ -164,8 +147,8 @@ def foo(*args): @pytest.mark.asyncio -async def test_value_resolve_does_not_run_in_thread(no_threads): - with no_threads: +async def test_value_resolve_does_not_run_in_thread(): + with no_threads(): context = Context() context.add('foo', provides='baz') @@ -176,8 +159,8 @@ async def it(baz): @pytest.mark.asyncio -async def test_anyof_resolve_does_not_run_in_thread(no_threads): - with no_threads: +async def test_anyof_resolve_does_not_run_in_thread(): + with no_threads(): context = Context() context.add(('foo', )) @@ -188,8 +171,8 @@ async def bob(x: str = AnyOf(tuple, Tuple[str])): @pytest.mark.asyncio -async def test_like_resolve_does_not_run_in_thread(no_threads): - with no_threads: +async def test_like_resolve_does_not_run_in_thread(): + with no_threads(): o = TheType() context = Context() context.add(o) From 9eea87f3bf67b8b8763ea0fd0dadddac58be2867 Mon Sep 17 00:00:00 2001 From: Chris Withers Date: Thu, 26 Mar 2020 07:37:03 +0000 Subject: [PATCH 078/159] Fix bug when context passed to runner but has no start point. --- mush/context.py | 4 +++- mush/runner.py | 5 ++++- mush/tests/test_runner.py | 9 ++++++++- 3 files changed, 15 insertions(+), 3 deletions(-) diff --git a/mush/context.py b/mush/context.py index bb42751..76057d4 100644 --- a/mush/context.py +++ b/mush/context.py @@ -1,5 +1,6 @@ from typing import Optional, Callable +from .callpoints import CallPoint from .declarations import RequiresType, ReturnsType from .extraction import extract_requires, extract_returns, default_requirement_type from .markers import missing @@ -25,7 +26,8 @@ def __init__(self, message: str, key: ResourceKey, requirement: Requirement = No class Context: "Stores resources for a particular run." - _parent = None + _parent: 'Context' = None + point: CallPoint = None def __init__(self, requirement_modifier: RequirementModifier = default_requirement_type): self._requirement_modifier = requirement_modifier diff --git a/mush/runner.py b/mush/runner.py index 3a62085..a0ab130 100644 --- a/mush/runner.py +++ b/mush/runner.py @@ -237,7 +237,7 @@ def __add__(self, other): runner._copy_from(r.start, r.end) return runner - def __call__(self, context=None): + def __call__(self, context: Context = None): """ Execute the callables in this runner in the required order storing objects that are returned and providing them as @@ -254,6 +254,7 @@ def __call__(self, context=None): """ if context is None: context = Context() + if context.point is None: context.point = self.start result = None @@ -272,6 +273,8 @@ def __call__(self, context=None): with result as manager: if manager not in (None, result): context.add(manager, manager.__class__) + # If the context manager swallows an exception, + # None should be returned, not the context manager: result = None result = self(context) diff --git a/mush/tests/test_runner.py b/mush/tests/test_runner.py index 9b76a11..85fcbc5 100644 --- a/mush/tests/test_runner.py +++ b/mush/tests/test_runner.py @@ -4,7 +4,7 @@ from mush.declarations import ( requires, returns, returns_mapping, replacement, original) -from mush import Value, ContextError +from mush import Value, ContextError, Context from mush.runner import Runner from testfixtures import ( ShouldRaise, @@ -1312,3 +1312,10 @@ class T2: pass def test_repr_empty(self): compare('', repr(Runner())) + + def test_passed_in_context_with_no_point(self): + context = Context() + def foo(): + return 42 + runner = Runner(foo) + compare(runner(context), expected=42) From c8ebceafb4e325b0f8f13e8aba96b3dc9da49ffa Mon Sep 17 00:00:00 2001 From: Chris Withers Date: Thu, 26 Mar 2020 07:38:04 +0000 Subject: [PATCH 079/159] move set_mush/get_mush to markers.py so they can be used there. --- mush/declarations.py | 15 +-------------- mush/extraction.py | 5 ++--- mush/markers.py | 11 +++++++++++ mush/plug.py | 2 +- 4 files changed, 15 insertions(+), 18 deletions(-) diff --git a/mush/declarations.py b/mush/declarations.py index a625959..7f29956 100644 --- a/mush/declarations.py +++ b/mush/declarations.py @@ -2,23 +2,10 @@ from itertools import chain from typing import _type_check -from .markers import missing +from .markers import set_mush from .requirements import Requirement, Value, name_or_repr -def set_mush(obj, key, value): - if not hasattr(obj, '__mush__'): - obj.__mush__ = {} - obj.__mush__[key] = value - - -def get_mush(obj, key, default): - __mush__ = getattr(obj, '__mush__', missing) - if __mush__ is missing: - return default - return __mush__.get(key, default) - - class RequiresType(list): def __repr__(self): diff --git a/mush/extraction.py b/mush/extraction.py index e8f7e81..38cd4f2 100644 --- a/mush/extraction.py +++ b/mush/extraction.py @@ -10,11 +10,10 @@ from .declarations import ( requires, RequiresType, ReturnsType, returns, result_type, - nothing, - get_mush + nothing ) from .requirements import Requirement, Value -from .markers import missing +from .markers import missing, get_mush from .types import RequirementModifier EMPTY = Parameter.empty diff --git a/mush/markers.py b/mush/markers.py index 16dba48..214bc5d 100644 --- a/mush/markers.py +++ b/mush/markers.py @@ -13,6 +13,17 @@ def __repr__(self): missing = Marker('missing') +def set_mush(obj, key, value): + if not hasattr(obj, '__mush__'): + obj.__mush__ = {} + obj.__mush__[key] = value + + +def get_mush(obj, key, default): + __mush__ = getattr(obj, '__mush__', missing) + if __mush__ is missing: + return default + return __mush__.get(key, default) def nonblocking(obj): """ A decorator to mark a method as not requiring running diff --git a/mush/plug.py b/mush/plug.py index 1680d4a..0f06ebb 100644 --- a/mush/plug.py +++ b/mush/plug.py @@ -1,4 +1,4 @@ -from .declarations import set_mush, get_mush +from .markers import set_mush, get_mush class ignore(object): From e67368f5cb72a29284750aa9002f2d25670a590c Mon Sep 17 00:00:00 2001 From: Chris Withers Date: Thu, 26 Mar 2020 08:12:07 +0000 Subject: [PATCH 080/159] Add explicit support for marking callables as blocking or non-blocking. Class instantiation now also defaults to be treated as non-blocking. --- mush/__init__.py | 4 +- mush/asyncio.py | 41 +++++++++++----- mush/markers.py | 29 +++++++++-- mush/tests/helpers.py | 24 ++++++++- mush/tests/test_async_context.py | 83 +++++++++++++++++++++++++++++--- 5 files changed, 153 insertions(+), 28 deletions(-) diff --git a/mush/__init__.py b/mush/__init__.py index 99f8561..2768ca1 100755 --- a/mush/__init__.py +++ b/mush/__init__.py @@ -3,7 +3,7 @@ requires, returns_result_type, returns_mapping, returns_sequence, returns, nothing ) from .extraction import extract_requires, extract_returns, update_wrapper -from .markers import missing +from .markers import missing, nonblocking, blocking from .plug import Plug from .requirements import Requirement, Value, Call, AnyOf, Like from .runner import Runner, ContextError @@ -19,7 +19,9 @@ 'ResourceError', 'Runner', 'Value', + 'blocking', 'missing', + 'nonblocking', 'nothing', 'requires', 'returns', diff --git a/mush/asyncio.py b/mush/asyncio.py index 3d08ac8..dc6f83c 100644 --- a/mush/asyncio.py +++ b/mush/asyncio.py @@ -5,20 +5,10 @@ from . import Context as SyncContext, Call as SyncCall, missing from .declarations import RequiresType, ReturnsType from .extraction import default_requirement_type +from .markers import get_mush, AsyncType from .types import RequirementModifier -async def ensure_async(func, *args, **kw): - if getattr(func, '__nonblocking__', False): - return func(*args, **kw) - elif asyncio.iscoroutinefunction(func): - return await func(*args, **kw) - if kw: - func = partial(func, **kw) - loop = asyncio.get_event_loop() - return await loop.run_in_executor(None, func, *args) - - class AsyncFromSyncContext: def __init__(self, context, loop): @@ -44,6 +34,31 @@ class Context(SyncContext): def __init__(self, requirement_modifier: RequirementModifier = default_requirement_type): super().__init__(requirement_modifier) self._sync_context = AsyncFromSyncContext(self, asyncio.get_event_loop()) + self._async_cache = {} + + async def _ensure_async(self, func, *args, **kw): + async_type = self._async_cache.get(func) + if async_type is None: + if asyncio.iscoroutinefunction(func): + async_type = AsyncType.async_ + else: + async_type = get_mush(func, 'async', default=None) + if async_type is None: + if isinstance(func, type): + async_type = AsyncType.nonblocking + else: + async_type = AsyncType.blocking + self._async_cache[func] = async_type + + if async_type is AsyncType.nonblocking: + return func(*args, **kw) + elif async_type is AsyncType.blocking: + if kw: + func = partial(func, **kw) + loop = asyncio.get_event_loop() + return await loop.run_in_executor(None, func, *args) + else: + return await func(*args, **kw) def _context_for(self, obj): return self if asyncio.iscoroutinefunction(obj) else self._sync_context @@ -54,9 +69,9 @@ async def call(self, obj: Callable, requires: RequiresType = None): resolving = self._resolve(obj, requires, args, kw, self._context_for(obj)) for requirement in resolving: r = requirement.resolve - o = await ensure_async(r, self._context_for(r)) + o = await self._ensure_async(r, self._context_for(r)) resolving.send(o) - return await ensure_async(obj, *args, **kw) + return await self._ensure_async(obj, *args, **kw) async def extract(self, obj: Callable, diff --git a/mush/markers.py b/mush/markers.py index 214bc5d..7738d48 100644 --- a/mush/markers.py +++ b/mush/markers.py @@ -1,3 +1,7 @@ +import asyncio +from enum import Enum, auto + + class Marker(object): def __init__(self, name): @@ -24,12 +28,29 @@ def get_mush(obj, key, default): if __mush__ is missing: return default return __mush__.get(key, default) + + +class AsyncType(Enum): + blocking = auto() + nonblocking = auto() + async_ = auto() + + def nonblocking(obj): """ - A decorator to mark a method as not requiring running + A decorator to mark a callable as not requiring running in a thread, even though it's not async. """ - # Not using set_mush / get_mush to try and keep this as - # quick as possible - obj.__nonblocking__ = True + set_mush(obj, 'async', AsyncType.nonblocking) + return obj + + +def blocking(obj): + """ + A decorator to explicitly mark a callable as requiring running + in a thread. + """ + if asyncio.iscoroutinefunction(obj): + raise TypeError('cannot mark an async function as blocking') + set_mush(obj, 'async', AsyncType.blocking) return obj diff --git a/mush/tests/helpers.py b/mush/tests/helpers.py index d8b8b39..614b0bf 100644 --- a/mush/tests/helpers.py +++ b/mush/tests/helpers.py @@ -1,6 +1,7 @@ import asyncio import sys from contextlib import contextmanager +from functools import partial from mock import Mock @@ -39,6 +40,25 @@ def no_threads(): finally: loop.run_in_executor = original - +@contextmanager def must_run_in_thread(func): - pass + seen = set() + loop = asyncio.get_event_loop() + original = loop.run_in_executor + + def recording_run_in_executor(executor, func, *args): + if isinstance(func, partial): + to_record = func.func + else: + # get the underlying method for bound methods: + to_record = getattr(func, '__func__', func) + seen.add(to_record) + return original(executor, func, *args) + + loop.run_in_executor = recording_run_in_executor + try: + yield + finally: + loop.run_in_executor = original + + assert func in seen, f'{func} was not run in a thread' diff --git a/mush/tests/test_async_context.py b/mush/tests/test_async_context.py index a9255e8..8a712b1 100644 --- a/mush/tests/test_async_context.py +++ b/mush/tests/test_async_context.py @@ -2,14 +2,14 @@ from typing import Tuple import pytest +from testfixtures import compare, ShouldRaise -from mush import Context, Value, requires, returns +from mush import Value, requires, returns, Context as SyncContext, blocking, nonblocking from mush.asyncio import Context from mush.declarations import RequiresType from mush.requirements import Requirement, AnyOf, Like -from testfixtures import compare - -from .helpers import TheType, no_threads +from .helpers import TheType, no_threads, must_run_in_thread +from ..markers import AsyncType @pytest.mark.asyncio @@ -19,7 +19,8 @@ def it(): return 'bar' result = context.call(it) assert asyncio.iscoroutine(result) - compare(await result, expected='bar') + with must_run_in_thread(it): + compare(await result, expected='bar') @pytest.mark.asyncio @@ -28,14 +29,15 @@ async def test_call_async(): context.add('1', provides='a') async def it(a, b='2'): return a+b - compare(await context.call(it), expected='12') + with no_threads(): + compare(await context.call(it), expected='12') @pytest.mark.asyncio async def test_call_async_requires_context(): context = Context() context.add('bar', provides='baz') - async def it(context: Context): + async def it(context: SyncContext): return context.get('baz') compare(await context.call(it), expected='bar') @@ -55,7 +57,8 @@ async def test_call_sync(): context.add('foo', provides='baz') def it(*, baz): return baz+'bar' - compare(await context.call(it), expected='foobar') + with must_run_in_thread(it): + compare(await context.call(it), expected='foobar') @pytest.mark.asyncio @@ -76,6 +79,61 @@ def it(context: Context): compare(await context.call(it), expected='bar') +@pytest.mark.asyncio +async def test_call_class_defaults_to_non_blocking(): + context = Context() + with no_threads(): + obj = await context.call(TheType) + assert isinstance(obj, TheType) + + +@pytest.mark.asyncio +async def test_call_class_explicitly_marked_as_blocking(): + @blocking + class BlockingType: pass + context = Context() + with must_run_in_thread(BlockingType): + obj = await context.call(BlockingType) + assert isinstance(obj, BlockingType) + + +@pytest.mark.asyncio +async def test_call_function_defaults_to_blocking(): + def foo(): + return 42 + context = Context() + with must_run_in_thread(foo): + compare(await context.call(foo), expected=42) + + +@pytest.mark.asyncio +async def test_call_function_explicitly_marked_as_non_blocking(): + @nonblocking + def foo(): + return 42 + context = Context() + with no_threads(): + compare(await context.call(foo), expected=42) + + +@pytest.mark.asyncio +async def test_call_async_function_explicitly_marked_as_non_blocking(): + # sure, I mean, whatever... + @nonblocking + async def foo(): + return 42 + context = Context() + with no_threads(): + compare(await context.call(foo), expected=42) + + +@pytest.mark.asyncio +async def test_call_async_function_explicitly_marked_as_blocking(): + with ShouldRaise(TypeError('cannot mark an async function as blocking')): + @blocking + async def foo(): pass + + @pytest.mark.asyncio async def test_call_cache_requires(): context = Context() @@ -84,6 +142,15 @@ def foo(): pass compare(context._requires_cache[foo], expected=RequiresType()) +@pytest.mark.asyncio +async def test_call_caches_asyncness(): + async def foo(): + return 42 + context = Context() + await context.call(foo) + compare(context._async_cache[foo], expected=AsyncType.async_) + + @pytest.mark.asyncio async def test_extract_is_async(): context = Context() From 886fa9e3b4900d2e2f1009ad3a399acd31d9bd85 Mon Sep 17 00:00:00 2001 From: Chris Withers Date: Thu, 2 Apr 2020 13:46:25 +0100 Subject: [PATCH 081/159] Split `nothing` into `requires_nothing` and `returns_nothing`. It's also no longer part of the public API. This makes it clearer what's being used for what, and fixes some unfortunate repr issues. --- docs/use.txt | 7 ++- mush/__init__.py | 3 +- mush/callpoints.py | 7 ++- mush/declarations.py | 105 +++++++++++++++----------------- mush/extraction.py | 6 +- mush/tests/test_context.py | 16 ++--- mush/tests/test_declarations.py | 16 ++--- 7 files changed, 76 insertions(+), 84 deletions(-) diff --git a/docs/use.txt b/docs/use.txt index 06dea5f..82de49b 100755 --- a/docs/use.txt +++ b/docs/use.txt @@ -375,17 +375,18 @@ I sold vegetables as fruit I made juice out of a tomato and a cucumber Finally, if you have a callable that returns results that you wish to ignore, -you can do so using :attr:`~mush.declarations.nothing`: +you can do so as follows: .. code-block:: python - from mush import Runner, nothing + from mush import Runner + @returns() def spam(): return 'spam' runner = Runner() - runner.add(spam, returns=nothing) + runner.add(spam) .. _named-resources: diff --git a/mush/__init__.py b/mush/__init__.py index 2768ca1..c475076 100755 --- a/mush/__init__.py +++ b/mush/__init__.py @@ -1,6 +1,6 @@ from .context import Context, ResourceError from .declarations import ( - requires, returns_result_type, returns_mapping, returns_sequence, returns, nothing + requires, returns, returns_result_type, returns_mapping, returns_sequence, ) from .extraction import extract_requires, extract_returns, update_wrapper from .markers import missing, nonblocking, blocking @@ -22,7 +22,6 @@ 'blocking', 'missing', 'nonblocking', - 'nothing', 'requires', 'returns', 'returns_mapping', diff --git a/mush/callpoints.py b/mush/callpoints.py index e0f9d05..a501d5d 100644 --- a/mush/callpoints.py +++ b/mush/callpoints.py @@ -1,7 +1,7 @@ from .declarations import ( - nothing, returns as returns_declaration + requires_nothing, returns as returns_declaration, -) + returns_nothing) from .extraction import extract_requires, extract_returns @@ -22,7 +22,8 @@ def __init__(self, runner, obj, requires=None, returns=None, lazy=False): raise TypeError('a single return type must be explicitly specified') runner.lazy[returns.args[0]] = obj, requires obj = do_nothing - requires = returns = nothing + requires = requires_nothing + returns = returns_nothing self.obj = obj self.requires = requires self.returns = returns diff --git a/mush/declarations.py b/mush/declarations.py index 7f29956..9d2f492 100644 --- a/mush/declarations.py +++ b/mush/declarations.py @@ -6,6 +6,23 @@ from .requirements import Requirement, Value, name_or_repr +VALID_DECORATION_TYPES = (type, str, Requirement) + + +def valid_decoration_types(*objs): + for obj in objs: + if isinstance(obj, VALID_DECORATION_TYPES): + continue + try: + _type_check(obj, '') + continue + except TypeError: + pass + raise TypeError( + repr(obj)+" is not a valid decoration type" + ) + + class RequiresType(list): def __repr__(self): @@ -49,6 +66,9 @@ def requires(*args, **kw): return requires_ +requires_nothing = RequiresType() + + class ReturnsType(object): def __call__(self, obj): @@ -59,6 +79,33 @@ def __repr__(self): return self.__class__.__name__ + '()' +class returns(ReturnsType): + """ + Declaration that specifies names for returned resources or overrides + the type of a returned resource. + + This declaration can be used to indicate the type or name of a single + returned resource or, if multiple arguments are passed, that the callable + will return a sequence of values where each one should be named or have its + type overridden. + """ + + def __init__(self, *args): + valid_decoration_types(*args) + self.args = args + + def process(self, obj): + if len(self.args) == 1: + yield self.args[0], obj + elif self.args: + for t, o in zip(self.args, obj): + yield t, o + + def __repr__(self): + args_repr = ', '.join(name_or_repr(arg) for arg in self.args) + return self.__class__.__name__ + '(' + args_repr + ')' + + class returns_result_type(ReturnsType): """ Default declaration that indicates a callable's return value @@ -97,50 +144,11 @@ def process(self, sequence): yield pair -class returns(returns_result_type): - """ - Declaration that specifies names for returned resources or overrides - the type of a returned resource. - - This declaration can be used to indicate the type or name of a single - returned resource or, if multiple arguments are passed, that the callable - will return a sequence of values where each one should be named or have its - type overridden. - """ - - def __init__(self, *args): - valid_decoration_types(*args) - self.args = args - - def process(self, obj): - if len(self.args) == 1: - yield self.args[0], obj - else: - for t, o in zip(self.args, obj): - yield t, o - - def __repr__(self): - args_repr = ', '.join(name_or_repr(arg) for arg in self.args) - return self.__class__.__name__ + '(' + args_repr + ')' - +returns_nothing = returns() -#: A singleton indicating that a callable's return value should be -#: stored based on the type of that return value. result_type = returns_result_type() -class Nothing(RequiresType, returns): - - def process(self, result): - return () - - -#: A singleton that be used as a :class:`~mush.requires` to indicate that a -#: callable has no required arguments or as a :class:`~mush.returns` to indicate -#: that anything returned from a callable should be ignored. -nothing = Nothing() - - class DeclarationsFrom(Enum): original = auto() replacement = auto() @@ -150,20 +158,3 @@ class DeclarationsFrom(Enum): original = DeclarationsFrom.original #: Use declarations from the replacement callable. replacement = DeclarationsFrom.replacement - - -VALID_DECORATION_TYPES = (type, str, Requirement) - - -def valid_decoration_types(*objs): - for obj in objs: - if isinstance(obj, VALID_DECORATION_TYPES): - continue - try: - _type_check(obj, '') - continue - except TypeError: - pass - raise TypeError( - repr(obj)+" is not a valid decoration type" - ) diff --git a/mush/extraction.py b/mush/extraction.py index 38cd4f2..785cf6b 100644 --- a/mush/extraction.py +++ b/mush/extraction.py @@ -5,12 +5,12 @@ partial ) from inspect import signature, Parameter -from typing import Callable, Type +from typing import Callable from .declarations import ( requires, RequiresType, ReturnsType, returns, result_type, - nothing + requires_nothing ) from .requirements import Requirement, Value from .markers import missing, get_mush @@ -124,7 +124,7 @@ def extract_requires(obj: Callable, _apply_requires(by_name, by_index, requires_) if not by_name: - return nothing + return requires_nothing needs_target = False for name, requirement in by_name.items(): diff --git a/mush/tests/test_context.py b/mush/tests/test_context.py index d0ada10..7dfa6f2 100644 --- a/mush/tests/test_context.py +++ b/mush/tests/test_context.py @@ -5,10 +5,10 @@ from testfixtures import ShouldRaise, compare from mush import ( - Context, requires, returns, nothing, returns_mapping, Value, missing + Context, requires, returns, returns_mapping, Value, missing ) from mush.context import ResourceError -from mush.declarations import RequiresType +from mush.declarations import RequiresType, requires_nothing, returns_nothing from mush.requirements import Requirement from .helpers import r, TheType @@ -91,7 +91,7 @@ def test_call_basic(self): def foo(): return 'bar' context = Context() - result = context.call(foo, nothing) + result = context.call(foo, requires_nothing) compare(result, 'bar') def test_call_requires_string(self): @@ -289,7 +289,7 @@ def test_returns_single(self): def foo(): return 'bar' context = Context() - result = context.extract(foo, nothing, returns(TheType)) + result = context.extract(foo, requires_nothing, returns(TheType)) compare(result, 'bar') compare({TheType: 'bar'}, actual=context._store) @@ -297,7 +297,7 @@ def test_returns_sequence(self): def foo(): return 1, 2 context = Context() - result = context.extract(foo, nothing, returns('foo', 'bar')) + result = context.extract(foo, requires_nothing, returns('foo', 'bar')) compare(result, (1, 2)) compare({'foo': 1, 'bar': 2}, actual=context._store) @@ -306,7 +306,7 @@ def test_returns_mapping(self): def foo(): return {'foo': 1, 'bar': 2} context = Context() - result = context.extract(foo, nothing, returns_mapping()) + result = context.extract(foo, requires_nothing, returns_mapping()) compare(result, {'foo': 1, 'bar': 2}) compare({'foo': 1, 'bar': 2}, actual=context._store) @@ -315,14 +315,14 @@ def test_ignore_return(self): def foo(): return 'bar' context = Context() - result = context.extract(foo, nothing, nothing) + result = context.extract(foo, requires_nothing, returns_nothing) compare(result, 'bar') compare({}, context._store) def test_ignore_non_iterable_return(self): def foo(): pass context = Context() - result = context.extract(foo, nothing, nothing) + result = context.extract(foo) compare(result, expected=None) compare(context._store, expected={}) diff --git a/mush/tests/test_declarations.py b/mush/tests/test_declarations.py index ebec96c..1ee4dfb 100644 --- a/mush/tests/test_declarations.py +++ b/mush/tests/test_declarations.py @@ -9,7 +9,7 @@ from mush.declarations import ( requires, returns, returns_mapping, returns_sequence, returns_result_type, - nothing, + requires_nothing, result_type, RequiresType ) from mush.extraction import extract_requires, extract_returns, update_wrapper @@ -213,7 +213,7 @@ def foo(a=None): pass check_extract( p, # since a is already bound by the partial: - expected_rq=nothing, + expected_rq=requires_nothing, expected_rt=result_type ) @@ -222,7 +222,7 @@ def foo(a=None): pass p = partial(foo, a=1) check_extract( p, - expected_rq=nothing, + expected_rq=requires_nothing, expected_rt=result_type ) @@ -232,7 +232,7 @@ def foo(a): pass check_extract( p, # since a is already bound by the partial: - expected_rq=nothing, + expected_rq=requires_nothing, expected_rt=result_type ) @@ -241,7 +241,7 @@ def foo(a): pass p = partial(foo, a=1) check_extract( p, - expected_rq=nothing, + expected_rq=requires_nothing, expected_rt=result_type ) @@ -303,7 +303,7 @@ def foo(a: 'foo'): pass def test_returns_only(self): def foo() -> 'bar': pass check_extract(foo, - expected_rq=nothing, + expected_rq=requires_nothing, expected_rt=returns('bar')) def test_extract_from_decorated_class(self): @@ -338,14 +338,14 @@ def test_returns_mapping(self): rt = returns_mapping() def foo() -> rt: pass check_extract(foo, - expected_rq=nothing, + expected_rq=requires_nothing, expected_rt=rt) def test_returns_sequence(self): rt = returns_sequence() def foo() -> rt: pass check_extract(foo, - expected_rq=nothing, + expected_rq=requires_nothing, expected_rt=rt) def test_how_instance_in_annotations(self): From 45380db07d685c05bb837719edb5003e94866213 Mon Sep 17 00:00:00 2001 From: Chris Withers Date: Sun, 29 Mar 2020 09:15:23 +0100 Subject: [PATCH 082/159] Async Runner implementation. --- mush/asyncio.py | 56 ++- mush/runner.py | 4 +- mush/tests/helpers.py | 6 +- mush/tests/test_async_runner.py | 586 ++++++++++++++++++++++++++++++++ 4 files changed, 646 insertions(+), 6 deletions(-) create mode 100644 mush/tests/test_async_runner.py diff --git a/mush/asyncio.py b/mush/asyncio.py index dc6f83c..7a0be6c 100644 --- a/mush/asyncio.py +++ b/mush/asyncio.py @@ -2,7 +2,10 @@ from functools import partial from typing import Callable -from . import Context as SyncContext, Call as SyncCall, missing +from . import ( + Context as SyncContext, Runner as SyncRunner, Call as SyncCall, + missing, ResourceError, ContextError +) from .declarations import RequiresType, ReturnsType from .extraction import default_requirement_type from .markers import get_mush, AsyncType @@ -82,6 +85,55 @@ async def extract(self, return result +class SyncContextManagerWrapper: + + def __init__(self, sync_manager): + self.sync_manager = sync_manager + self.loop = asyncio.get_event_loop() + + async def __aenter__(self): + return await self.loop.run_in_executor(None, self.sync_manager.__enter__) + + async def __aexit__(self, exc_type, exc_val, exc_tb): + return await self.loop.run_in_executor(None, self.sync_manager.__exit__, + exc_type, exc_val, exc_tb) + + +class Runner(SyncRunner): + + async def __call__(self, context: Context = None): + if context is None: + context = Context() + if context.point is None: + context.point = self.start + + result = None + + while context.point: + + point = context.point + context.point = point.next + + try: + result = manager = await point(context) + except ResourceError as e: + raise ContextError(str(e), point, context) + + if getattr(result, '__enter__', None): + manager = SyncContextManagerWrapper(result) + + if getattr(manager, '__aenter__', None): + async with manager as managed: + if managed not in (None, result): + context.add(managed) + # If the context manager swallows an exception, + # None should be returned, not the context manager: + result = None + result = await self(context) + + return result + + class Call(SyncCall): async def resolve(self, context): @@ -93,4 +145,4 @@ async def resolve(self, context): return result -__all__ = ['Context', 'Call'] +__all__ = ['Context', 'Runner', 'Call'] diff --git a/mush/runner.py b/mush/runner.py index a0ab130..177059a 100644 --- a/mush/runner.py +++ b/mush/runner.py @@ -135,7 +135,7 @@ def clone(self, label specified in this option should be cloned. This filtering is applied in addition to the above options. """ - runner = Runner() + runner = self.__class__() if start_label: start = self.labels[start_label] @@ -232,7 +232,7 @@ def __add__(self, other): Return a new :class:`Runner` containing the contents of the two :class:`Runner` instances being added together. """ - runner = Runner() + runner = self.__class__() for r in self, other: runner._copy_from(r.start, r.end) return runner diff --git a/mush/tests/helpers.py b/mush/tests/helpers.py index 614b0bf..5868260 100644 --- a/mush/tests/helpers.py +++ b/mush/tests/helpers.py @@ -40,8 +40,9 @@ def no_threads(): finally: loop.run_in_executor = original + @contextmanager -def must_run_in_thread(func): +def must_run_in_thread(*expected): seen = set() loop = asyncio.get_event_loop() original = loop.run_in_executor @@ -61,4 +62,5 @@ def recording_run_in_executor(executor, func, *args): finally: loop.run_in_executor = original - assert func in seen, f'{func} was not run in a thread' + not_seen = set(expected) - seen + assert not not_seen, f'{not_seen} not run in a thread, seen: {seen}' diff --git a/mush/tests/test_async_runner.py b/mush/tests/test_async_runner.py new file mode 100644 index 0000000..fd773d0 --- /dev/null +++ b/mush/tests/test_async_runner.py @@ -0,0 +1,586 @@ +import asyncio +from unittest.mock import Mock, call + +import pytest +from testfixtures import compare, ShouldRaise, Comparison as C + +from mush import ContextError, requires, returns +from mush.asyncio import Runner, Context +from .helpers import no_threads, must_run_in_thread + + +@pytest.mark.asyncio +async def test_call_is_async(): + def it(): + return 'bar' + runner = Runner(it) + result = runner() + assert asyncio.iscoroutine(result) + with must_run_in_thread(it): + compare(await result, expected='bar') + + +@pytest.mark.asyncio +async def test_resource_missing(): + def it(foo): + pass + runner = Runner(it) + context = Context() + with ShouldRaise(ContextError(C(str), runner.start, context)): + await runner(context) + + +@pytest.mark.asyncio +async def test_cloned_still_async(): + def it(): + return 'bar' + runner = Runner(it) + runner_ = runner.clone() + result = runner_() + assert asyncio.iscoroutine(result) + compare(await result, expected='bar') + + +@pytest.mark.asyncio +async def test_addition_still_async(): + async def foo(): + return 'foo' + @requires(str) + @returns() + async def bar(foo): + return foo+'bar' + r1 = Runner(foo) + r2 = Runner(bar) + runner = r1 + r2 + result = runner() + assert asyncio.iscoroutine(result) + compare(await result, expected='foobar') + + +class CommonCM: + m = None + context = None + swallow_exceptions = None + + +class AsyncCM(CommonCM): + + async def __aenter__(self): + self.m.enter() + if self.context is 'self': + return self + return self.context + + async def __aexit__(self, type, obj, tb): + self.m.exit(obj) + return self.swallow_exceptions + + +class SyncCM(CommonCM): + + def __enter__(self): + self.m.enter() + if self.context is 'self': + return self + return self.context + + def __exit__(self, type, obj, tb): + self.m.exit(obj) + return self.swallow_exceptions + + +def make_cm(name, type_, m, context=None, swallow_exceptions=None): + return type(name, + (type_,), + {'m': getattr(m, name.lower()), + 'context': context, + 'swallow_exceptions': swallow_exceptions}) + + +@pytest.mark.asyncio +async def test_async_context_manager(): + m = Mock() + CM = make_cm('CM', AsyncCM, m) + + async def func(): + m.func() + + runner = Runner(CM, func) + + with no_threads(): + await runner() + + compare(m.mock_calls, expected=[ + call.cm.enter(), + call.func(), + call.cm.exit(None) + ]) + + +@pytest.mark.asyncio +async def test_async_context_manager_inner_requires_cm(): + m = Mock() + CM = make_cm('CM', AsyncCM, m, context='self') + + @requires(CM) + async def func(obj): + m.func(type(obj)) + + runner = Runner(CM, func) + + with no_threads(): + await runner() + + compare(m.mock_calls, expected=[ + call.cm.enter(), + call.func(CM), + call.cm.exit(None) + ]) + + +@pytest.mark.asyncio +async def test_async_context_manager_inner_requires_context(): + m = Mock() + class CMContext: pass + cm_context = CMContext() + CM = make_cm('CM', AsyncCM, m, context=cm_context) + + @requires(CMContext) + async def func(obj): + m.func(obj) + + runner = Runner(CM, func) + + with no_threads(): + await runner() + + compare(m.mock_calls, expected=[ + call.cm.enter(), + call.func(cm_context), + call.cm.exit(None) + ]) + + +@pytest.mark.asyncio +async def test_async_context_manager_nested(): + m = Mock() + CM1 = make_cm('CM1', AsyncCM, m) + CM2 = make_cm('CM2', AsyncCM, m) + + async def func(): + m.func() + + runner = Runner(CM1, CM2, func) + + with no_threads(): + await runner() + + compare(m.mock_calls, expected=[ + call.cm1.enter(), + call.cm2.enter(), + call.func(), + call.cm2.exit(None), + call.cm1.exit(None), + ]) + + +@pytest.mark.asyncio +async def test_async_context_manager_nested_exception_inner_handles(): + m = Mock() + CM1 = make_cm('CM1', AsyncCM, m) + CM2 = make_cm('CM2', AsyncCM, m, swallow_exceptions=True) + + e = Exception() + async def func(): + raise e + + runner = Runner(CM1, CM2, func) + + with no_threads(): + await runner() + + compare(m.mock_calls, expected=[ + call.cm1.enter(), + call.cm2.enter(), + call.cm2.exit(e), + call.cm1.exit(None), + ]) + + +@pytest.mark.asyncio +async def test_async_context_manager_nested_exception_outer_handles(): + m = Mock() + CM1 = make_cm('CM1', AsyncCM, m, swallow_exceptions=True) + CM2 = make_cm('CM2', AsyncCM, m) + + e = Exception() + async def func(): + raise e + + runner = Runner(CM1, CM2, func) + + with no_threads(): + await runner() + + compare(m.mock_calls, expected=[ + call.cm1.enter(), + call.cm2.enter(), + call.cm2.exit(e), + call.cm1.exit(e), + ]) + + +@pytest.mark.asyncio +async def test_async_context_manager_exception_not_handled(): + m = Mock() + CM = make_cm('CM', AsyncCM, m) + + e = Exception('foo') + + async def func(): + raise e + + runner = Runner(CM, func) + + with no_threads(), ShouldRaise(e): + await runner() + + compare(m.mock_calls, expected=[ + call.cm.enter(), + call.cm.exit(e) + ]) + + +@pytest.mark.asyncio +async def test_sync_context_manager(): + m = Mock() + CM = make_cm('CM', SyncCM, m) + + async def func(): + m.func() + + runner = Runner(CM, func) + + with must_run_in_thread(CM.__enter__, CM.__exit__): + await runner() + + compare(m.mock_calls, expected=[ + call.cm.enter(), + call.func(), + call.cm.exit(None) + ]) + + +@pytest.mark.asyncio +async def test_sync_context_manager_inner_requires_cm(): + m = Mock() + CM = make_cm('CM', SyncCM, m, context='self') + + @requires(CM) + async def func(obj): + m.func(type(obj)) + + runner = Runner(CM, func) + + with must_run_in_thread(CM.__enter__, CM.__exit__): + await runner() + + compare(m.mock_calls, expected=[ + call.cm.enter(), + call.func(CM), + call.cm.exit(None) + ]) + + +@pytest.mark.asyncio +async def test_sync_context_manager_inner_requires_context(): + m = Mock() + class CMContext: pass + cm_context = CMContext() + CM = make_cm('CM', SyncCM, m, context=cm_context) + + @requires(CMContext) + async def func(obj): + m.func(obj) + + runner = Runner(CM, func) + + with must_run_in_thread(CM.__enter__, CM.__exit__): + await runner() + + compare(m.mock_calls, expected=[ + call.cm.enter(), + call.func(cm_context), + call.cm.exit(None) + ]) + + +@pytest.mark.asyncio +async def test_sync_context_manager_nested(): + m = Mock() + CM1 = make_cm('CM1', SyncCM, m) + CM2 = make_cm('CM2', SyncCM, m) + + async def func(): + m.func() + + runner = Runner(CM1, CM2, func) + + with must_run_in_thread(CM1.__enter__, CM1.__exit__, CM2.__enter__, CM2.__exit__): + await runner() + + compare(m.mock_calls, expected=[ + call.cm1.enter(), + call.cm2.enter(), + call.func(), + call.cm2.exit(None), + call.cm1.exit(None), + ]) + + +@pytest.mark.asyncio +async def test_sync_context_manager_nested_exception_inner_handles(): + m = Mock() + CM1 = make_cm('CM1', SyncCM, m) + CM2 = make_cm('CM2', SyncCM, m, swallow_exceptions=True) + + e = Exception() + async def func(): + raise e + + runner = Runner(CM1, CM2, func) + + with must_run_in_thread(CM1.__enter__, CM1.__exit__, CM2.__enter__, CM2.__exit__): + await runner() + + compare(m.mock_calls, expected=[ + call.cm1.enter(), + call.cm2.enter(), + call.cm2.exit(e), + call.cm1.exit(None), + ]) + + +@pytest.mark.asyncio +async def test_sync_context_manager_nested_exception_outer_handles(): + m = Mock() + CM1 = make_cm('CM1', SyncCM, m, swallow_exceptions=True) + CM2 = make_cm('CM2', SyncCM, m) + + e = Exception() + async def func(): + raise e + + runner = Runner(CM1, CM2, func) + + with must_run_in_thread(CM1.__enter__, CM1.__exit__, CM2.__enter__, CM2.__exit__): + await runner() + + compare(m.mock_calls, expected=[ + call.cm1.enter(), + call.cm2.enter(), + call.cm2.exit(e), + call.cm1.exit(e), + ]) + + +@pytest.mark.asyncio +async def test_sync_context_manager_exception_not_handled(): + m = Mock() + CM = make_cm('CM', SyncCM, m) + + e = Exception('foo') + + async def func(): + raise e + + runner = Runner(CM, func) + + with must_run_in_thread(CM.__enter__, CM.__exit__), ShouldRaise(e): + await runner() + + compare(m.mock_calls, expected=[ + call.cm.enter(), + call.cm.exit(e) + ]) + +@pytest.mark.asyncio +async def test_sync_context_then_async_context(): + m = Mock() + CM1 = make_cm('CM1', SyncCM, m) + CM2 = make_cm('CM2', AsyncCM, m) + + async def func(): + return 42 + + runner = Runner(CM1, CM2, func) + + compare(await runner(), expected=42) + + compare(m.mock_calls, expected=[ + call.cm1.enter(), + call.cm2.enter(), + call.cm2.exit(None), + call.cm1.exit(None), + ]) + + +@pytest.mark.asyncio +async def test_async_context_then_sync_context(): + m = Mock() + CM1 = make_cm('CM1', AsyncCM, m) + CM2 = make_cm('CM2', SyncCM, m) + + async def func(): + return 42 + + runner = Runner(CM1, CM2, func) + + compare(await runner(), expected=42) + + compare(m.mock_calls, expected=[ + call.cm1.enter(), + call.cm2.enter(), + call.cm2.exit(None), + call.cm1.exit(None), + ]) + + +@pytest.mark.asyncio +async def test_sync_context_then_async_context_exception_handled_inner(): + m = Mock() + CM1 = make_cm('CM1', SyncCM, m) + CM2 = make_cm('CM2', AsyncCM, m, swallow_exceptions=True) + + e = Exception() + async def func(): + raise e + + runner = Runner(CM1, CM2, func) + + # if something goes wrong *and handled by a CM*, you get None + compare(await runner(), expected=None) + + compare(m.mock_calls, expected=[ + call.cm1.enter(), + call.cm2.enter(), + call.cm2.exit(e), + call.cm1.exit(None), + ]) + + +@pytest.mark.asyncio +async def test_sync_context_then_async_context_exception_handled_outer(): + m = Mock() + CM1 = make_cm('CM1', SyncCM, m, swallow_exceptions=True) + CM2 = make_cm('CM2', AsyncCM, m) + + e = Exception() + async def func(): + raise e + + runner = Runner(CM1, CM2, func) + + # if something goes wrong *and handled by a CM*, you get None + compare(await runner(), expected=None) + + compare(m.mock_calls, expected=[ + call.cm1.enter(), + call.cm2.enter(), + call.cm2.exit(e), + call.cm1.exit(e), + ]) + + +@pytest.mark.asyncio +async def test_sync_context_then_async_context_exception_not_handled(): + m = Mock() + CM1 = make_cm('CM1', SyncCM, m) + CM2 = make_cm('CM2', AsyncCM, m) + + e = Exception('foo') + + async def func(): + raise e + + runner = Runner(CM1, CM2, func) + + with ShouldRaise(e): + await runner() + + compare(m.mock_calls, expected=[ + call.cm1.enter(), + call.cm2.enter(), + call.cm2.exit(e), + call.cm1.exit(e), + ]) + + +@pytest.mark.asyncio +async def test_async_context_then_sync_context_exception_handled_inner(): + m = Mock() + CM1 = make_cm('CM1', AsyncCM, m) + CM2 = make_cm('CM2', SyncCM, m, swallow_exceptions=True) + + e = Exception() + async def func(): + raise e + + runner = Runner(CM1, CM2, func) + + # if something goes wrong *and handled by a CM*, you get None + compare(await runner(), expected=None) + + compare(m.mock_calls, expected=[ + call.cm1.enter(), + call.cm2.enter(), + call.cm2.exit(e), + call.cm1.exit(None), + ]) + + +@pytest.mark.asyncio +async def test_async_context_then_sync_context_exception_handled_outer(): + m = Mock() + CM1 = make_cm('CM1', AsyncCM, m, swallow_exceptions=True) + CM2 = make_cm('CM2', SyncCM, m) + + e = Exception() + async def func(): + raise e + + runner = Runner(CM1, CM2, func) + + # if something goes wrong *and handled by a CM*, you get None + compare(await runner(), expected=None) + + compare(m.mock_calls, expected=[ + call.cm1.enter(), + call.cm2.enter(), + call.cm2.exit(e), + call.cm1.exit(e), + ]) + + +@pytest.mark.asyncio +async def test_async_context_then_sync_context_exception_not_handled(): + m = Mock() + CM1 = make_cm('CM1', AsyncCM, m) + CM2 = make_cm('CM2', SyncCM, m) + + e = Exception('foo') + + async def func(): + raise e + + runner = Runner(CM1, CM2, func) + + with ShouldRaise(e): + await runner() + + compare(m.mock_calls, expected=[ + call.cm1.enter(), + call.cm2.enter(), + call.cm2.exit(e), + call.cm1.exit(e), + ]) From 8ae38def8dea699228b567bb379be23f39cc1fee Mon Sep 17 00:00:00 2001 From: Chris Withers Date: Sun, 29 Mar 2020 09:29:30 +0100 Subject: [PATCH 083/159] better local variable names and leave guessing the type to the context. --- mush/runner.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/mush/runner.py b/mush/runner.py index 177059a..f8b4f83 100644 --- a/mush/runner.py +++ b/mush/runner.py @@ -270,9 +270,9 @@ def __call__(self, context: Context = None): raise ContextError(str(e), point, context) if getattr(result, '__enter__', None): - with result as manager: - if manager not in (None, result): - context.add(manager, manager.__class__) + with result as managed: + if managed not in (None, result): + context.add(managed) # If the context manager swallows an exception, # None should be returned, not the context manager: result = None From 26a0317a377e4549a0386ffd2286e99f913655ed Mon Sep 17 00:00:00 2001 From: Chris Withers Date: Thu, 2 Apr 2020 07:35:48 +0100 Subject: [PATCH 084/159] By default, no longer add context manager to the mush context when they are instantiated. You probably want to just have the context added to the mush context when it's returned from __enter__ or __aenter__. --- mush/asyncio.py | 2 +- mush/declarations.py | 4 ++-- mush/runner.py | 2 +- mush/tests/test_runner.py | 39 +++++++++++++++++++-------------------- 4 files changed, 23 insertions(+), 24 deletions(-) diff --git a/mush/asyncio.py b/mush/asyncio.py index 7a0be6c..f2d8693 100644 --- a/mush/asyncio.py +++ b/mush/asyncio.py @@ -124,7 +124,7 @@ async def __call__(self, context: Context = None): if getattr(manager, '__aenter__', None): async with manager as managed: - if managed not in (None, result): + if managed is not None: context.add(managed) # If the context manager swallows an exception, # None should be returned, not the context manager: diff --git a/mush/declarations.py b/mush/declarations.py index 9d2f492..3f78c68 100644 --- a/mush/declarations.py +++ b/mush/declarations.py @@ -111,11 +111,11 @@ class returns_result_type(ReturnsType): Default declaration that indicates a callable's return value should be used as a resource based on the type of the object returned. - ``None`` is ignored as a return value. + ``None`` is ignored as a return value, as are context managers """ def process(self, obj): - if obj is not None: + if not (obj is None or hasattr(obj, '__enter__') or hasattr(obj, '__aenter__')): yield obj.__class__, obj diff --git a/mush/runner.py b/mush/runner.py index f8b4f83..9491f9a 100644 --- a/mush/runner.py +++ b/mush/runner.py @@ -271,7 +271,7 @@ def __call__(self, context: Context = None): if getattr(result, '__enter__', None): with result as managed: - if managed not in (None, result): + if managed is not None: context.add(managed) # If the context manager swallows an exception, # None should be returned, not the context manager: diff --git a/mush/tests/test_runner.py b/mush/tests/test_runner.py index 85fcbc5..0b71ae2 100644 --- a/mush/tests/test_runner.py +++ b/mush/tests/test_runner.py @@ -723,11 +723,10 @@ def __exit__(self, type, obj, tb): def func1(obj): m.func1(type(obj)) - @requires(CM1, CM2, CM2Context) - def func2(obj1, obj2, obj3): + @requires(CM1, CM2Context) + def func2(obj1, obj2): m.func2(type(obj1), - type(obj2), - type(obj3)) + type(obj2)) return '2' runner = Runner( @@ -740,14 +739,14 @@ def func2(obj1, obj2, obj3): result = runner() compare(result, '2') - compare([ - call.cm1.enter(), - call.cm2.enter(), - call.func1(CM1), - call.func2(CM1, CM2, CM2Context), - call.cm2.exit(None, None), - call.cm1.exit(None, None) - ], m.mock_calls) + compare(m.mock_calls, expected=[ + call.cm1.enter(), + call.cm2.enter(), + call.func1(CM1), + call.func2(CM1, CM2Context), + call.cm2.exit(None, None), + call.cm1.exit(None, None) + ]) # now check with an exception m.reset_mock() @@ -757,14 +756,14 @@ def func2(obj1, obj2, obj3): # if something goes wrong, you get None compare(None, result) - compare([ - call.cm1.enter(), - call.cm2.enter(), - call.func1(CM1), - call.func2(CM1, CM2, CM2Context), - call.cm2.exit(Exception, e), - call.cm1.exit(Exception, e) - ], m.mock_calls) + compare(m.mock_calls, expected=[ + call.cm1.enter(), + call.cm2.enter(), + call.func1(CM1), + call.func2(CM1, CM2Context), + call.cm2.exit(Exception, e), + call.cm1.exit(Exception, e) + ]) def test_clone(self): m = Mock() From c1cac5c75e89164f14dfd670abd2e1ff7ae61e10 Mon Sep 17 00:00:00 2001 From: Chris Withers Date: Fri, 3 Apr 2020 07:48:36 +0100 Subject: [PATCH 085/159] work around deficiencies in iscoroutinefunction's ability to spot async stuff --- mush/asyncio.py | 8 +++++++- mush/tests/test_async_context.py | 24 ++++++++++++++++++++++++ 2 files changed, 31 insertions(+), 1 deletion(-) diff --git a/mush/asyncio.py b/mush/asyncio.py index f2d8693..0d59359 100644 --- a/mush/asyncio.py +++ b/mush/asyncio.py @@ -1,5 +1,6 @@ import asyncio from functools import partial +from types import FunctionType from typing import Callable from . import ( @@ -42,7 +43,12 @@ def __init__(self, requirement_modifier: RequirementModifier = default_requireme async def _ensure_async(self, func, *args, **kw): async_type = self._async_cache.get(func) if async_type is None: - if asyncio.iscoroutinefunction(func): + to_check = func + if isinstance(func, partial): + to_check = func.func + if asyncio.iscoroutinefunction(to_check): + async_type = AsyncType.async_ + elif asyncio.iscoroutinefunction(to_check.__call__): async_type = AsyncType.async_ else: async_type = get_mush(func, 'async', default=None) diff --git a/mush/tests/test_async_context.py b/mush/tests/test_async_context.py index 8a712b1..2f15f21 100644 --- a/mush/tests/test_async_context.py +++ b/mush/tests/test_async_context.py @@ -1,4 +1,5 @@ import asyncio +from functools import partial from typing import Tuple import pytest @@ -33,6 +34,29 @@ async def it(a, b='2'): compare(await context.call(it), expected='12') +@pytest.mark.asyncio +async def test_call_async_callable_object(): + context = Context() + + class AsyncCallable: + async def __call__(self): + return 42 + + with no_threads(): + compare(await context.call(AsyncCallable()), expected=42) + + +@pytest.mark.asyncio +async def test_call_partial_around_async(): + context = Context() + + async def it(): + return 42 + + with no_threads(): + compare(await context.call(partial(it)), expected=42) + + @pytest.mark.asyncio async def test_call_async_requires_context(): context = Context() From ce0be56372b81a6136f8fd35a2751ad7e048770b Mon Sep 17 00:00:00 2001 From: Chris Withers Date: Fri, 3 Apr 2020 08:43:01 +0100 Subject: [PATCH 086/159] don't blow up when cloning and empty runner. --- mush/runner.py | 6 +++++- mush/tests/test_runner.py | 7 +++++++ 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/mush/runner.py b/mush/runner.py index 9491f9a..1c6c154 100644 --- a/mush/runner.py +++ b/mush/runner.py @@ -152,7 +152,11 @@ def clone(self, end = self.end # check start point is before end_point - point = start.previous + if start is not None: + point = start.previous + else: + point = None + while point: if point is end: return runner diff --git a/mush/tests/test_runner.py b/mush/tests/test_runner.py index 0b71ae2..a13571c 100644 --- a/mush/tests/test_runner.py +++ b/mush/tests/test_runner.py @@ -942,6 +942,13 @@ def test_clone_added_using(self): (m.f7, {'the_label'}), ) + def test_clone_empty(self): + runner1 = Runner() + runner2 = runner1.clone() + # this gets set by the clone on runner 2, it's a class variable on runner1: + runner1.end = None + compare(expected=runner1, actual=runner2) + def test_extend(self): m = Mock() class T1(object): pass From 2216c06458ed8e86237ea212db8f454702b9efe8 Mon Sep 17 00:00:00 2001 From: Chris Withers Date: Tue, 7 Apr 2020 07:14:58 +0100 Subject: [PATCH 087/159] fix bug where a requirement of a non-default type that was explicitly specified would lose its type --- mush/extraction.py | 20 +++++++++++++------- mush/tests/test_callpoints.py | 11 ++++++----- mush/tests/test_declarations.py | 14 ++++++++++++++ 3 files changed, 33 insertions(+), 12 deletions(-) diff --git a/mush/extraction.py b/mush/extraction.py index 785cf6b..7b1099c 100644 --- a/mush/extraction.py +++ b/mush/extraction.py @@ -35,11 +35,17 @@ def _apply_requires(by_name, by_index, requires_): name = r.target existing = by_name[name] - existing.key = existing.key if r.key is None else r.key - existing.type = existing.type if r.type is None else r.type - existing.default = existing.default if r.default is missing else r.default - existing.ops = existing.ops if not r.ops else r.ops - existing.target = existing.target if r.target is None else r.target + if type(existing) is not type(r): + r_ = r.clone() + r_.name = existing.name + by_name[name] = r_ + else: + r_ = existing + r_.key = existing.key if r.key is None else r.key + r_.type = existing.type if r.type is None else r.type + r_.default = existing.default if r.default is missing else r.default + r_.ops = existing.ops if not r.ops else r.ops + r_.target = existing.target if r.target is None else r.target def default_requirement_type(requirement): @@ -52,14 +58,13 @@ def extract_requires(obj: Callable, explicit: RequiresType = None, modifier: RequirementModifier = default_requirement_type): # from annotations - is_partial = isinstance(obj, partial) by_name = {} for name, p in signature(obj).parameters.items(): if p.kind in (p.VAR_POSITIONAL, p.VAR_KEYWORD): continue # https://bugs.python.org/issue39753: - if is_partial and p.name in obj.keywords: + if isinstance(obj, partial) and p.name in obj.keywords: continue name = p.name @@ -126,6 +131,7 @@ def extract_requires(obj: Callable, if not by_name: return requires_nothing + # sort out target and apply modifier: needs_target = False for name, requirement in by_name.items(): requirement_ = modifier(requirement) diff --git a/mush/tests/test_callpoints.py b/mush/tests/test_callpoints.py index d857d39..4a368f0 100644 --- a/mush/tests/test_callpoints.py +++ b/mush/tests/test_callpoints.py @@ -7,7 +7,8 @@ from mush.callpoints import CallPoint from mush.declarations import requires, returns, RequiresType from mush.extraction import update_wrapper -from mush.requirements import Requirement +from mush.requirements import Requirement, Value +from .helpers import r class TestCallPoints(TestCase): @@ -30,7 +31,7 @@ def foo(a1): pass compare(result, self.context.extract.return_value) compare(tuple(self.context.extract.mock_calls[0].args), expected=(foo, - RequiresType([Requirement('foo', name='a1')]), + RequiresType([r(Value('foo'), name='a1')]), rt)) def test_extract_from_decorations(self): @@ -45,7 +46,7 @@ def foo(a1): pass compare(result, self.context.extract.return_value) compare(tuple(self.context.extract.mock_calls[0].args), expected=(foo, - RequiresType([Requirement('foo', name='a1')]), + RequiresType([r(Value('foo'), name='a1')]), returns('bar'))) def test_extract_from_decorated_class(self): @@ -71,7 +72,7 @@ def foo(prefix): self.context.extract.side_effect = lambda func, rq, rt: (func(), rq, rt) result = CallPoint(self.context, foo)(self.context) compare(result, expected=('the answer', - RequiresType([Requirement('foo', name='prefix')]), + RequiresType([r(Value('foo'), name='prefix')]), rt)) def test_explicit_trumps_decorators(self): @@ -84,7 +85,7 @@ def foo(a1): pass compare(result, self.context.extract.return_value) compare(tuple(self.context.extract.mock_calls[0].args), expected=(foo, - RequiresType([Requirement('baz', name='a1')]), + RequiresType([r(Value('baz'), name='a1')]), returns('bob'))) def test_repr_minimal(self): diff --git a/mush/tests/test_declarations.py b/mush/tests/test_declarations.py index 1ee4dfb..c51d5e3 100644 --- a/mush/tests/test_declarations.py +++ b/mush/tests/test_declarations.py @@ -489,3 +489,17 @@ def foo(a: r2 = r3, b: str = r2, c = r3): r(Value('c'), name='c', target='c'), )), expected_rt=result_type) + + def test_explicit_requirement_type_trumps_default_requirement_type(self): + + class FromRequest(Requirement): pass + + @requires(a=Requirement('a')) + def foo(a): + pass + + compare(actual=extract_requires(foo, requires(a=FromRequest('b'))), + strict=True, + expected=RequiresType(( + r(FromRequest('b'), name='a', target='a'), + ))) From 57f7988f387332eca069fe510290024fd9df8e46 Mon Sep 17 00:00:00 2001 From: Chris Withers Date: Tue, 7 Apr 2020 08:56:18 +0100 Subject: [PATCH 088/159] keep markers singleton, even when copied --- mush/markers.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/mush/markers.py b/mush/markers.py index 7738d48..c03a04d 100644 --- a/mush/markers.py +++ b/mush/markers.py @@ -10,6 +10,9 @@ def __init__(self, name): def __repr__(self): return '' % self.name + def __copy__(self): + return self + not_specified = Marker('not_specified') From 9d903118be53936457a86f5851d2550f75520506 Mon Sep 17 00:00:00 2001 From: Chris Withers Date: Wed, 8 Apr 2020 13:04:36 +0100 Subject: [PATCH 089/159] Use a real runner here. This feels like Callpoint knows too much about runner, but they are quite intertwined anyway... --- mush/tests/test_callpoints.py | 28 +++++++++++++++------------- 1 file changed, 15 insertions(+), 13 deletions(-) diff --git a/mush/tests/test_callpoints.py b/mush/tests/test_callpoints.py index 4a368f0..c07e963 100644 --- a/mush/tests/test_callpoints.py +++ b/mush/tests/test_callpoints.py @@ -7,7 +7,8 @@ from mush.callpoints import CallPoint from mush.declarations import requires, returns, RequiresType from mush.extraction import update_wrapper -from mush.requirements import Requirement, Value +from mush.requirements import Value +from mush.runner import Runner from .helpers import r @@ -15,10 +16,11 @@ class TestCallPoints(TestCase): def setUp(self): self.context = Mock() + self.runner = Runner() def test_passive_attributes(self): # these are managed by Modifiers - point = CallPoint(self.context, Mock()) + point = CallPoint(self.runner, Mock()) compare(point.previous, None) compare(point.next, None) compare(point.labels, set()) @@ -27,7 +29,7 @@ def test_supplied_explicitly(self): def foo(a1): pass rq = requires('foo') rt = returns('bar') - result = CallPoint(self.context, foo, rq, rt)(self.context) + result = CallPoint(self.runner, foo, rq, rt)(self.context) compare(result, self.context.extract.return_value) compare(tuple(self.context.extract.mock_calls[0].args), expected=(foo, @@ -42,7 +44,7 @@ def test_extract_from_decorations(self): @rt def foo(a1): pass - result = CallPoint(self.context, foo)(self.context) + result = CallPoint(self.runner, foo)(self.context) compare(result, self.context.extract.return_value) compare(tuple(self.context.extract.mock_calls[0].args), expected=(foo, @@ -70,7 +72,7 @@ def foo(prefix): return prefix+'answer' self.context.extract.side_effect = lambda func, rq, rt: (func(), rq, rt) - result = CallPoint(self.context, foo)(self.context) + result = CallPoint(self.runner, foo)(self.context) compare(result, expected=('the answer', RequiresType([r(Value('foo'), name='prefix')]), rt)) @@ -80,7 +82,7 @@ def test_explicit_trumps_decorators(self): @returns('bar') def foo(a1): pass - point = CallPoint(self.context, foo, requires('baz'), returns('bob')) + point = CallPoint(self.runner, foo, requires('baz'), returns('bob')) result = point(self.context) compare(result, self.context.extract.return_value) compare(tuple(self.context.extract.mock_calls[0].args), @@ -90,20 +92,20 @@ def foo(a1): pass def test_repr_minimal(self): def foo(): pass - point = CallPoint(self.context, foo) + point = CallPoint(self.runner, foo) compare(repr(foo)+" requires() returns_result_type()", repr(point)) def test_repr_maximal(self): def foo(a1): pass - point = CallPoint(self.context, foo, requires('foo'), returns('bar')) + point = CallPoint(self.runner, foo, requires('foo'), returns('bar')) point.labels.add('baz') point.labels.add('bob') - compare(repr(foo)+" requires('foo') returns('bar') <-- baz, bob", - repr(point)) + compare(expected=repr(foo)+" requires('foo') returns('bar') <-- baz, bob", + actual=repr(point)) def test_convert_to_requires_and_returns(self): def foo(baz): pass - point = CallPoint(self.context, foo, requires='foo', returns='bar') + point = CallPoint(self.runner, foo, requires='foo', returns='bar') self.assertTrue(isinstance(point.requires, RequiresType)) self.assertTrue(isinstance(point.returns, returns)) compare(repr(foo)+" requires('foo') returns('bar')", @@ -111,7 +113,7 @@ def foo(baz): pass def test_convert_to_requires_and_returns_tuple(self): def foo(a1, a2): pass - point = CallPoint(self.context, + point = CallPoint(self.runner, foo, requires=('foo', 'bar'), returns=('baz', 'bob')) @@ -122,7 +124,7 @@ def foo(a1, a2): pass def test_convert_to_requires_and_returns_list(self): def foo(a1, a2): pass - point = CallPoint(self.context, + point = CallPoint(self.runner, foo, requires=['foo', 'bar'], returns=['baz', 'bob']) From 8c4d976e65a48410208402d6d6c4e4436d300b88 Mon Sep 17 00:00:00 2001 From: Chris Withers Date: Wed, 8 Apr 2020 13:05:12 +0100 Subject: [PATCH 090/159] fix nasty bug in requirement extraction --- mush/extraction.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mush/extraction.py b/mush/extraction.py index 7b1099c..b72ed3a 100644 --- a/mush/extraction.py +++ b/mush/extraction.py @@ -136,7 +136,7 @@ def extract_requires(obj: Callable, for name, requirement in by_name.items(): requirement_ = modifier(requirement) if requirement_ is not requirement: - by_name[name] = requirement + by_name[name] = requirement = requirement_ if requirement.target is not None: needs_target = True elif needs_target: From 1ee555c48799ef739bf1c7e73339db773457d0d1 Mon Sep 17 00:00:00 2001 From: Chris Withers Date: Wed, 8 Apr 2020 13:26:50 +0100 Subject: [PATCH 091/159] simplify down repr of requirements and use consistently --- docs/use.txt | 22 +++++++++++----------- mush/declarations.py | 2 +- mush/requirements.py | 18 +++--------------- mush/tests/test_callpoints.py | 8 ++++---- mush/tests/test_declarations.py | 6 +++--- mush/tests/test_runner.py | 12 ++++++------ 6 files changed, 28 insertions(+), 40 deletions(-) diff --git a/docs/use.txt b/docs/use.txt index 82de49b..008a1a3 100755 --- a/docs/use.txt +++ b/docs/use.txt @@ -650,10 +650,10 @@ the representation of a runner gives all this information: >>> runner requires() returns_result_type() - requires(Ring) returns_result_type() <-- forged - requires(Ring) returns_result_type() - requires(Ring) returns_result_type() <-- engraved - requires(Ring) returns_result_type() + requires(Value(Ring)) returns_result_type() <-- forged + requires(Value(Ring)) returns_result_type() + requires(Value(Ring)) returns_result_type() <-- engraved + requires(Value(Ring)) returns_result_type() As you can see above, when a callable is inserted at a label, the label @@ -676,7 +676,7 @@ Now, when you add to a specific label, only that label is moved: >>> runner requires() returns_result_type() <-- before_polish - requires('ring') returns_result_type() <-- after_polish + requires(Value('ring')) returns_result_type() <-- after_polish Of course, you can still add to the end of the runner: @@ -686,8 +686,8 @@ Of course, you can still add to the end of the runner: >>> runner requires() returns_result_type() <-- before_polish - requires('ring') returns_result_type() <-- after_polish - requires('ring') returns_result_type() + requires(Value('ring')) returns_result_type() <-- after_polish + requires(Value('ring')) returns_result_type() However, the point modifier returned by getting a label from a runner will @@ -697,9 +697,9 @@ keep on moving the label as more callables are added using it: >>> runner requires() returns_result_type() <-- before_polish - requires('ring') returns_result_type() - requires('ring') returns_result_type() <-- after_polish - requires('ring') returns_result_type() + requires(Value('ring')) returns_result_type() + requires(Value('ring')) returns_result_type() <-- after_polish + requires(Value('ring')) returns_result_type() .. _plugs: @@ -1041,7 +1041,7 @@ To see how the configuration panned out, we would look at the :func:`repr`: requires() returns('config') requires(Value('config')['foo']) returns_result_type() <-- config - requires('connection') returns_result_type() + requires(Value('connection')) returns_result_type() As you can see, there is a problem with this configuration that will be exposed diff --git a/mush/declarations.py b/mush/declarations.py index 3f78c68..f449180 100644 --- a/mush/declarations.py +++ b/mush/declarations.py @@ -26,7 +26,7 @@ def valid_decoration_types(*objs): class RequiresType(list): def __repr__(self): - parts = (r.value_repr() if r.target is None else f'{r.target}={r.value_repr()}' + parts = (repr(r) if r.target is None else f'{r.target}={r!r}' for r in self) return f"requires({', '.join(parts)})" diff --git a/mush/requirements.py b/mush/requirements.py index e7fe107..f8dab19 100644 --- a/mush/requirements.py +++ b/mush/requirements.py @@ -77,26 +77,14 @@ def clone(self): obj.ops = list(self.ops) return obj - def value_repr(self, params='', *, from_repr=False): - key = name_or_repr(self.key) - if self.ops or self.default is not missing or from_repr: - default = '' if self.default is missing else f', default={self.default!r}' - ops = ''.join(repr(o) for o in self.ops) - return f"{type(self).__name__}({key}{default}{params}){ops}" - return key + def resolve(self, context: 'Context'): + raise NotImplementedError() def __repr__(self): - attrs = [] - for a in 'name', 'type_', 'target': - value = getattr(self, a.rstrip('_')) - if value is not None and value != self.key: - attrs.append(f", {a}={value!r}") - key = name_or_repr(self.key) default = '' if self.default is missing else f', default={self.default!r}' ops = ''.join(repr(o) for o in self.ops) - - return f"{type(self).__name__}({key}{default}{''.join(attrs)}){ops}" + return f"{type(self).__name__}({key}{default}){ops}" def attr(self, name): """ diff --git a/mush/tests/test_callpoints.py b/mush/tests/test_callpoints.py index c07e963..e418767 100644 --- a/mush/tests/test_callpoints.py +++ b/mush/tests/test_callpoints.py @@ -100,7 +100,7 @@ def foo(a1): pass point = CallPoint(self.runner, foo, requires('foo'), returns('bar')) point.labels.add('baz') point.labels.add('bob') - compare(expected=repr(foo)+" requires('foo') returns('bar') <-- baz, bob", + compare(expected=repr(foo)+" requires(Value('foo')) returns('bar') <-- baz, bob", actual=repr(point)) def test_convert_to_requires_and_returns(self): @@ -108,7 +108,7 @@ def foo(baz): pass point = CallPoint(self.runner, foo, requires='foo', returns='bar') self.assertTrue(isinstance(point.requires, RequiresType)) self.assertTrue(isinstance(point.returns, returns)) - compare(repr(foo)+" requires('foo') returns('bar')", + compare(repr(foo)+" requires(Value('foo')) returns('bar')", repr(point)) def test_convert_to_requires_and_returns_tuple(self): @@ -119,7 +119,7 @@ def foo(a1, a2): pass returns=('baz', 'bob')) self.assertTrue(isinstance(point.requires, RequiresType)) self.assertTrue(isinstance(point.returns, returns)) - compare(repr(foo)+" requires('foo', 'bar') returns('baz', 'bob')", + compare(repr(foo)+" requires(Value('foo'), Value('bar')) returns('baz', 'bob')", repr(point)) def test_convert_to_requires_and_returns_list(self): @@ -130,5 +130,5 @@ def foo(a1, a2): pass returns=['baz', 'bob']) self.assertTrue(isinstance(point.requires, RequiresType)) self.assertTrue(isinstance(point.returns, returns)) - compare(repr(foo)+" requires('foo', 'bar') returns('baz', 'bob')", + compare(repr(foo)+" requires(Value('foo'), Value('bar')) returns('baz', 'bob')", repr(point)) diff --git a/mush/tests/test_declarations.py b/mush/tests/test_declarations.py index c51d5e3..e1aad9a 100644 --- a/mush/tests/test_declarations.py +++ b/mush/tests/test_declarations.py @@ -33,7 +33,7 @@ def test_empty(self): def test_types(self): r_ = requires(Type1, Type2, x=Type3, y=Type4) - compare(repr(r_), 'requires(Type1, Type2, x=Type3, y=Type4)') + compare(repr(r_), 'requires(Value(Type1), Value(Type2), x=Value(Type3), y=Value(Type4))') compare(r_, expected=[ Value(Type1), Value(Type2), @@ -43,7 +43,7 @@ def test_types(self): def test_strings(self): r_ = requires('1', '2', x='3', y='4') - compare(repr(r_), "requires('1', '2', x='3', y='4')") + compare(repr(r_), "requires(Value('1'), Value('2'), x=Value('3'), y=Value('4'))") compare(r_, expected=[ Value('1'), Value('2'), @@ -54,7 +54,7 @@ def test_strings(self): def test_typing(self): r_ = requires(Tuple[str]) text = 'Tuple' if PY_36 else 'typing.Tuple[str]' - compare(repr(r_), f"requires({text})") + compare(repr(r_),expected=f"requires(Value({text}))") compare(r_, expected=[r(Value(Tuple[str]), type=Tuple[str])]) def test_tuple_arg(self): diff --git a/mush/tests/test_runner.py b/mush/tests/test_runner.py index a13571c..b1adcb0 100644 --- a/mush/tests/test_runner.py +++ b/mush/tests/test_runner.py @@ -510,10 +510,10 @@ def job(arg): runner() text = '\n'.join(( - 'While calling: '+repr(job)+' requires(T) returns_result_type()', + 'While calling: '+repr(job)+' requires(Value(T)) returns_result_type()', 'with :', '', - "No Value(T, name='arg') in context", + "No Value(T) in context", )) compare(text, actual=repr(s.raised)) compare(text, actual=str(s.raised)) @@ -548,14 +548,14 @@ def job5(foo, bar): pass repr(job1)+' requires() returns_result_type() <-- 1', repr(job2)+' requires() returns_result_type()', '', - 'While calling: '+repr(job3)+' requires(T) returns_result_type()', + 'While calling: '+repr(job3)+' requires(Value(T)) returns_result_type()', 'with :', '', - "No Value(T, name='arg') in context", + "No Value(T) in context", '', 'Still to call:', repr(job4)+' requires() returns_result_type() <-- 4', - repr(job5)+" requires('foo', 'baz') returns('bob')", + repr(job5)+" requires(Value('foo'), Value('baz')) returns('bob')", )) compare(text, actual=repr(s.raised)) compare(text, actual=str(s.raised)) @@ -1310,7 +1310,7 @@ class T2: pass compare('\n'.join(( '', ' '+repr(m.job1)+' requires() returns_result_type() <-- label1', - ' '+repr(m.job2)+" requires('foo', T1) returns(T2) <-- label2", + ' '+repr(m.job2)+" requires(Value('foo'), Value(T1)) returns(T2) <-- label2", ' '+repr(m.job3)+' requires() returns_result_type()', '' From 2300910212e89b4832895455be3cc4eb03dfec01 Mon Sep 17 00:00:00 2001 From: Chris Withers Date: Wed, 8 Apr 2020 13:41:01 +0100 Subject: [PATCH 092/159] Replace requirement.clone with .make and .make_from. Also replace any cases of modifying __class__ on objects with creating a new instance! --- mush/declarations.py | 3 +- mush/extraction.py | 29 +++++----- mush/requirements.py | 39 ++++++++++--- mush/runner.py | 3 +- mush/tests/helpers.py | 8 --- mush/tests/test_async_context.py | 4 +- mush/tests/test_callpoints.py | 9 ++- mush/tests/test_context.py | 8 +-- mush/tests/test_declarations.py | 96 ++++++++++++++++---------------- mush/tests/test_requirements.py | 66 ++++++++++++++++++++-- 10 files changed, 167 insertions(+), 98 deletions(-) diff --git a/mush/declarations.py b/mush/declarations.py index f449180..1b44bb3 100644 --- a/mush/declarations.py +++ b/mush/declarations.py @@ -54,8 +54,7 @@ def requires(*args, **kw): kw.items(), ): if isinstance(possible, Requirement): - possible = possible.clone() - possible.target = target + possible = possible.make_from(possible, target=target) requirement = possible else: requirement = Value(possible) diff --git a/mush/extraction.py b/mush/extraction.py index b72ed3a..e9e35a4 100644 --- a/mush/extraction.py +++ b/mush/extraction.py @@ -24,33 +24,32 @@ def _apply_requires(by_name, by_index, requires_): for i, r in enumerate(requires_): + if r.target is None: try: name = by_index[i] except IndexError: # case where something takes *args - by_name[i] = r.clone() + by_name[i] = r.make_from(r) continue else: name = r.target existing = by_name[name] - if type(existing) is not type(r): - r_ = r.clone() - r_.name = existing.name - by_name[name] = r_ - else: - r_ = existing - r_.key = existing.key if r.key is None else r.key - r_.type = existing.type if r.type is None else r.type - r_.default = existing.default if r.default is missing else r.default - r_.ops = existing.ops if not r.ops else r.ops - r_.target = existing.target if r.target is None else r.target + by_name[name] = r.make_from( + r, + name=existing.name, + key=existing.key if r.key is None else r.key, + type=existing.type if r.type is None else r.type, + default=existing.default if r.default is missing else r.default, + ops=existing.ops if not r.ops else r.ops, + target=existing.target if r.target is None else r.target, + ) def default_requirement_type(requirement): - if requirement.__class__ is Requirement: - requirement.__class__ = Value + if type(requirement) is Requirement: + requirement = Value.make_from(requirement) return requirement @@ -92,7 +91,7 @@ def extract_requires(obj: Callable, else: key = type_ else: - requirement = requirement.clone() + requirement = requirement.make_from(requirement) type_ = type_ if requirement.type is None else requirement.type if requirement.key is not None: key = requirement.key diff --git a/mush/requirements.py b/mush/requirements.py index f8dab19..c440be9 100644 --- a/mush/requirements.py +++ b/mush/requirements.py @@ -48,11 +48,11 @@ class Requirement: """ def __init__(self, - key: ResourceKey = None, + key: ResourceKey = None, # XXX should not be default? name: str = None, type_: type = None, default: Any = missing, - target: str =None): + target: str = None): #: The resource key needed for this parameter. self.key: Optional[ResourceKey] = key #: The name of this parameter in the callable's signature. @@ -66,15 +66,38 @@ def __init__(self, self.ops: List['Op'] = [] self.target: Optional[str] = target - def resolve(self, context: 'Context'): - raise NotImplementedError() + @classmethod + def make(cls, **attrs): + """ + Make a :class:`Requirement` instance with all attributes provided. + + .. note:: + + This method allows instances to be created with missing or invalid attributes. + While this is necessary for use cases such as testing :class:`Requirement` + instantiation or otherwise setting attributes that are not accessible from + a custom requirement's :meth:`__init__`, it should be used with care. + + :param attrs: + :return: + """ + obj = Requirement(attrs.pop('key')) + obj.__class__ = cls + obj.__dict__.update(attrs) + return obj - def clone(self): + @classmethod + def make_from(cls, source: 'Requirement', **attrs): """ - Create a copy of this requirement, so it can be mutated + Make a new instance of this requirement class, using attributes + from a source requirement overlaid with any additional + ``attrs`` that have been supplied. """ - obj = copy(self) - obj.ops = list(self.ops) + attrs_ = source.__dict__.copy() + attrs_.update(attrs) + obj = cls.make(**attrs_) + obj.ops = list(source.ops) + obj.default = copy(source.default) return obj def resolve(self, context: 'Context'): diff --git a/mush/runner.py b/mush/runner.py index 1c6c154..3c5f1a6 100644 --- a/mush/runner.py +++ b/mush/runner.py @@ -26,8 +26,7 @@ def __init__(self, *objects): def modify_requirement(self, requirement): if requirement.key in self.lazy: - requirement.__class__ = Lazy - requirement.runner = self + requirement = Lazy.make_from(requirement, runner=self) else: requirement = default_requirement_type(requirement) return requirement diff --git a/mush/tests/helpers.py b/mush/tests/helpers.py index 5868260..3abc429 100644 --- a/mush/tests/helpers.py +++ b/mush/tests/helpers.py @@ -6,14 +6,6 @@ from mock import Mock -def r(base, **attrs): - """ - helper for returning Requirement subclasses with extra attributes - """ - base.__dict__.update(attrs) - return base - - PY_VERSION = sys.version_info[:2] PY_36 = PY_VERSION == (3, 6) diff --git a/mush/tests/test_async_context.py b/mush/tests/test_async_context.py index 2f15f21..12fd9bc 100644 --- a/mush/tests/test_async_context.py +++ b/mush/tests/test_async_context.py @@ -370,8 +370,8 @@ async def resolve(self, context): return (context.get('request'))[self.key] def default_requirement_type(requirement): - if requirement.__class__ is Requirement: - requirement.__class__ = FromRequest + if type(requirement) is Requirement: + requirement = FromRequest.make_from(requirement) return requirement def foo(bar): diff --git a/mush/tests/test_callpoints.py b/mush/tests/test_callpoints.py index e418767..6fe8830 100644 --- a/mush/tests/test_callpoints.py +++ b/mush/tests/test_callpoints.py @@ -9,7 +9,6 @@ from mush.extraction import update_wrapper from mush.requirements import Value from mush.runner import Runner -from .helpers import r class TestCallPoints(TestCase): @@ -33,7 +32,7 @@ def foo(a1): pass compare(result, self.context.extract.return_value) compare(tuple(self.context.extract.mock_calls[0].args), expected=(foo, - RequiresType([r(Value('foo'), name='a1')]), + RequiresType([Value.make(key='foo', name='a1')]), rt)) def test_extract_from_decorations(self): @@ -48,7 +47,7 @@ def foo(a1): pass compare(result, self.context.extract.return_value) compare(tuple(self.context.extract.mock_calls[0].args), expected=(foo, - RequiresType([r(Value('foo'), name='a1')]), + RequiresType([Value.make(key='foo', name='a1')]), returns('bar'))) def test_extract_from_decorated_class(self): @@ -74,7 +73,7 @@ def foo(prefix): self.context.extract.side_effect = lambda func, rq, rt: (func(), rq, rt) result = CallPoint(self.runner, foo)(self.context) compare(result, expected=('the answer', - RequiresType([r(Value('foo'), name='prefix')]), + RequiresType([Value.make(key='foo', name='prefix')]), rt)) def test_explicit_trumps_decorators(self): @@ -87,7 +86,7 @@ def foo(a1): pass compare(result, self.context.extract.return_value) compare(tuple(self.context.extract.mock_calls[0].args), expected=(foo, - RequiresType([r(Value('baz'), name='a1')]), + RequiresType([Value.make(key='baz', name='a1')]), returns('bob'))) def test_repr_minimal(self): diff --git a/mush/tests/test_context.py b/mush/tests/test_context.py index 7dfa6f2..2f89f85 100644 --- a/mush/tests/test_context.py +++ b/mush/tests/test_context.py @@ -10,7 +10,7 @@ from mush.context import ResourceError from mush.declarations import RequiresType, requires_nothing, returns_nothing from mush.requirements import Requirement -from .helpers import r, TheType +from .helpers import TheType class TestContext(TestCase): @@ -427,7 +427,7 @@ def foo(bar: FromRequest('bar')): context.add({}, provides='request') with ShouldRaise(ResourceError("No FromRequest('bar') in context", key='bar', - requirement=r(FromRequest('bar'), name='bar'))): + requirement=FromRequest.make(key='bar', name='bar'))): compare(context.call(foo)) def test_default_custom_requirement(self): @@ -440,8 +440,8 @@ def foo(bar): return bar def modifier(requirement): - if requirement.__class__ is Requirement: - requirement.__class__ = FromRequest + if type(requirement) is Requirement: + requirement = FromRequest.make_from(requirement) return requirement context = Context(requirement_modifier=modifier) diff --git a/mush/tests/test_declarations.py b/mush/tests/test_declarations.py index e1aad9a..a190323 100644 --- a/mush/tests/test_declarations.py +++ b/mush/tests/test_declarations.py @@ -14,7 +14,7 @@ ) from mush.extraction import extract_requires, extract_returns, update_wrapper from mush.requirements import Requirement, ItemOp -from .helpers import r, PY_36, Type1, Type2, Type3, Type4 +from .helpers import PY_36, Type1, Type2, Type3, Type4 def check_extract(obj, expected_rq, expected_rt): @@ -37,8 +37,8 @@ def test_types(self): compare(r_, expected=[ Value(Type1), Value(Type2), - r(Value(Type3), name='x', target='x'), - r(Value(Type4), name='y', target='y'), + Value.make(key=Type3, type=Type3, name='x', target='x'), + Value.make(key=Type4, type=Type4, name='y', target='y'), ]) def test_strings(self): @@ -47,15 +47,15 @@ def test_strings(self): compare(r_, expected=[ Value('1'), Value('2'), - r(Value('3'), name='x', target='x'), - r(Value('4'), name='y', target='y'), + Value.make(key='3', name='x', target='x'), + Value.make(key='4', name='y', target='y'), ]) def test_typing(self): r_ = requires(Tuple[str]) text = 'Tuple' if PY_36 else 'typing.Tuple[str]' compare(repr(r_),expected=f"requires(Value({text}))") - compare(r_, expected=[r(Value(Tuple[str]), type=Tuple[str])]) + compare(r_, expected=[Value.make(key=Tuple[str], type=Tuple[str])]) def test_tuple_arg(self): with ShouldRaise(TypeError("('1', '2') is not a valid decoration type")): @@ -169,8 +169,8 @@ def test_default_requirements_for_function(self): def foo(a, b=None): pass check_extract(foo, expected_rq=RequiresType(( - r(Value('a'), name='a'), - r(Value('b'), default=None, name='b'), + Value.make(key='a', name='a'), + Value.make(key='b', default=None, name='b'), )), expected_rt=result_type) @@ -179,8 +179,8 @@ class MyClass(object): def __init__(self, a, b=None): pass check_extract(MyClass, expected_rq=RequiresType(( - r(Value('a'), name='a'), - r(Value('b'), name='b', default=None), + Value.make(key='a', name='a'), + Value.make(key='b', name='b', default=None), )), expected_rt=result_type) @@ -190,8 +190,8 @@ def foo(x, y, z, a=None): pass check_extract( p, expected_rq=RequiresType(( - r(Value('z'), name='z', target='z'), - r(Value('a'), name='a', target='a', default=None), + Value.make(key='z', name='z', target='z'), + Value.make(key='a', name='a', target='a', default=None), )), expected_rt=result_type ) @@ -202,7 +202,7 @@ def foo(a=None): pass check_extract( p, expected_rq=RequiresType(( - r(Value('a'), name='a', default=None), + Value.make(key='a', name='a', default=None), )), expected_rt=result_type ) @@ -251,8 +251,8 @@ def foo(b, a=None): pass check_extract( p, expected_rq=RequiresType(( - r(Value('b'), name='b'), - r(Value('a'), name='a', default=None), + Value.make(key='b', name='b'), + Value.make(key='a', name='a', default=None), )), expected_rt=result_type ) @@ -264,7 +264,7 @@ def foo(b, a): pass p, # since b is already bound: expected_rq=RequiresType(( - r(Value('a'), name='a'), + Value.make(key='a', name='a'), )), expected_rt=result_type ) @@ -275,7 +275,7 @@ def foo(b, a): pass check_extract( p, expected_rq=RequiresType(( - r(Value('b'), name='b'), + Value.make(key='b', name='b'), )), expected_rt=result_type ) @@ -287,17 +287,17 @@ def test_extract_from_annotations(self): def foo(a: 'foo', b, c: 'bar' = 1, d=2) -> 'bar': pass check_extract(foo, expected_rq=RequiresType(( - r(Value('foo'), name='a'), - r(Value('b'), name='b'), - r(Value('bar'), name='c', default=1), - r(Value('d'), name='d', default=2) + Value.make(key='foo', name='a'), + Value.make(key='b', name='b'), + Value.make(key='bar', name='c', default=1), + Value.make(key='d', name='d', default=2) )), expected_rt=returns('bar')) def test_requires_only(self): def foo(a: 'foo'): pass check_extract(foo, - expected_rq=RequiresType((r(Value('foo'), name='a'),)), + expected_rq=RequiresType((Value.make(key='foo', name='a'),)), expected_rt=result_type) def test_returns_only(self): @@ -323,7 +323,7 @@ def foo(a: 'foo' = None) -> 'bar': compare(foo(), expected='the answer') check_extract(foo, - expected_rq=RequiresType((r(Value('foo'), name='a', default=None),)), + expected_rq=RequiresType((Value.make(key='foo', name='a', default=None),)), expected_rt=returns('bar')) def test_decorator_trumps_annotations(self): @@ -331,7 +331,7 @@ def test_decorator_trumps_annotations(self): @returns('bar') def foo(a: 'x') -> 'y': pass check_extract(foo, - expected_rq=RequiresType((r(Value('foo'), name='a'),)), + expected_rq=RequiresType((Value.make(key='foo', name='a'),)), expected_rt=returns('bar')) def test_returns_mapping(self): @@ -352,7 +352,7 @@ def test_how_instance_in_annotations(self): def foo(a: Value('config')['db_url']): pass check_extract(foo, expected_rq=RequiresType(( - r(Value('config'), name='a', ops=[ItemOp('db_url')]), + Value.make(key='config', name='a', ops=[ItemOp('db_url')]), )), expected_rt=result_type) @@ -360,10 +360,10 @@ def test_default_requirements(self): def foo(a, b=1, *, c, d=None): pass check_extract(foo, expected_rq=RequiresType(( - r(Value('a'), name='a'), - r(Value('b'), name='b', default=1), - r(Value('c'), name='c', target='c'), - r(Value('d'), name='d', target='d', default=None) + Value.make(key='a', name='a'), + Value.make(key='b', name='b', default=1), + Value.make(key='c', name='c', target='c'), + Value.make(key='d', name='d', target='d', default=None) )), expected_rt=result_type) @@ -371,27 +371,27 @@ def test_type_only(self): class T: pass def foo(a: T): pass check_extract(foo, - expected_rq=RequiresType((r(Value(T), name='a', type=T),)), + expected_rq=RequiresType((Value.make(key=T, name='a', type=T),)), expected_rt=result_type) @pytest.mark.parametrize("type_", [str, int, dict, list]) def test_simple_type_only(self, type_): def foo(a: type_): pass check_extract(foo, - expected_rq=RequiresType((r(Value('a'), name='a', type=type_),)), + expected_rq=RequiresType((Value.make(key='a', name='a', type=type_),)), expected_rt=result_type) def test_type_plus_value(self): def foo(a: str = Value('b')): pass check_extract(foo, - expected_rq=RequiresType((r(Value('b'), name='a', type=str),)), + expected_rq=RequiresType((Value.make(key='b', name='a', type=str),)), expected_rt=result_type) def test_type_plus_value_with_default(self): def foo(a: str = Value('b', default=1)): pass check_extract(foo, expected_rq=RequiresType(( - r(Value('b'), name='a', type=str, default=1), + Value.make(key='b', name='a', type=str, default=1), )), expected_rt=result_type) @@ -399,7 +399,7 @@ def test_value_annotation_plus_default(self): def foo(a: Value('b', type_=str) = 1): pass check_extract(foo, expected_rq=RequiresType(( - r(Value('b'), name='a', type=str, default=1), + Value.make(key='b', name='a', type=str, default=1), )), expected_rt=result_type) @@ -407,7 +407,7 @@ def test_value_annotation_just_type_in_value_key_plus_default(self): def foo(a: Value(str) = 1): pass check_extract(foo, expected_rq=RequiresType(( - r(Value(str), name='a', type=str, default=1), + Value.make(key=str, name='a', type=str, default=1), )), expected_rt=result_type) @@ -415,7 +415,7 @@ def test_value_annotation_just_type_plus_default(self): def foo(a: Value(type_=str) = 1): pass check_extract(foo, expected_rq=RequiresType(( - r(Value(key='a'), name='a', type=str, default=1), + Value.make(key='a', name='a', type=str, default=1), )), expected_rt=result_type) @@ -423,19 +423,19 @@ def test_value_unspecified_with_type(self): class T1: pass def foo(a: T1 = Value()): pass check_extract(foo, - expected_rq=RequiresType((r(Value(key=T1), name='a', type=T1),)), + expected_rq=RequiresType((Value.make(key=T1, name='a', type=T1),)), expected_rt=result_type) def test_value_unspecified_with_simple_type(self): def foo(a: str = Value()): pass check_extract(foo, - expected_rq=RequiresType((r(Value(key='a'), name='a', type=str),)), + expected_rq=RequiresType((Value.make(key='a', name='a', type=str),)), expected_rt=result_type) def test_value_unspecified(self): def foo(a = Value()): pass check_extract(foo, - expected_rq=RequiresType((r(Value(key='a'), name='a'),)), + expected_rq=RequiresType((Value.make(key='a', name='a'),)), expected_rt=result_type) def test_requirement_modifier(self): @@ -444,8 +444,8 @@ def foo(x: str = None): pass class FromRequest(Requirement): pass def modifier(requirement): - if requirement.__class__ is Requirement: - requirement.__class__ = FromRequest + if type(requirement) is Requirement: + requirement = FromRequest.make_from(requirement) return requirement rq = extract_requires(foo, modifier=modifier) @@ -467,9 +467,9 @@ def foo(a: r1, b, c=r3): check_extract(foo, expected_rq=RequiresType(( - r(Value('a'), name='a'), - r(Value('b'), name='b', target='b'), - r(Value('c'), name='c', target='c'), + Value.make(key='a', name='a'), + Value.make(key='b', name='b', target='b'), + Value.make(key='c', name='c', target='c'), )), expected_rt=result_type) @@ -484,9 +484,9 @@ def foo(a: r2 = r3, b: str = r2, c = r3): check_extract(foo, expected_rq=RequiresType(( - r(Value('a'), name='a', target='a'), - r(Value('b'), name='b', target='b', type=str), - r(Value('c'), name='c', target='c'), + Value.make(key='a', name='a', target='a'), + Value.make(key='b', name='b', target='b', type=str), + Value.make(key='c', name='c', target='c'), )), expected_rt=result_type) @@ -501,5 +501,5 @@ def foo(a): compare(actual=extract_requires(foo, requires(a=FromRequest('b'))), strict=True, expected=RequiresType(( - r(FromRequest('b'), name='a', target='a'), + FromRequest.make(key='b', name='a', target='a'), ))) diff --git a/mush/tests/test_requirements.py b/mush/tests/test_requirements.py index daf9fae..a4b55c5 100644 --- a/mush/tests/test_requirements.py +++ b/mush/tests/test_requirements.py @@ -26,16 +26,74 @@ def test_repr_maximal(self): r = Requirement('foo', name='n', type_='ty', default=None, target='ta') r.ops.append(AttrOp('bar')) compare(repr(r), - expected="Requirement('foo', default=None, " - "name='n', type_='ty', target='ta').bar") + expected="Requirement('foo', default=None).bar") - def test_clone(self): + def test_make_allows_params_not_passed_to_constructor(self): + r = Value.make(key='x', target='a') + assert type(r) is Value + compare(r.key, expected='x') + compare(r.target, expected='a') + + def test_make_can_create_invalid_objects(self): + # So be careful! + + class SampleRequirement(Requirement): + def __init__(self, foo): + super().__init__(key='y') + self.foo = foo + + r = SampleRequirement('it') + compare(r.foo, expected='it') + + r = SampleRequirement.make(key='x') + assert 'foo' not in r.__dict__ + # ...when it really should be! + + def test_clone_using_make_from(self): r = Value('foo').bar.requirement - r_ = r.clone() + r_ = r.make_from(r) assert r_ is not r assert r_.ops is not r.ops compare(r_, expected=r) + def test_make_from_with_mutable_default(self): + r = Requirement('foo', default=[]) + r_ = r.make_from(r) + assert r_ is not r + assert r_.default is not r.default + compare(r_, expected=r) + + def test_make_from_into_new_type(self): + r = Requirement('foo').bar.requirement + r_ = Value.make_from(r) + compare(r_, expected=Value('foo').bar.requirement) + + def test_make_from_with_required_constructor_parameter(self): + + class SampleRequirement(Requirement): + def __init__(self, foo): + super().__init__('foo') + self.foo = foo + + r = Requirement('foo') + r_ = SampleRequirement.make_from(r, foo='it') + assert r_ is not r + compare(r_, expected=SampleRequirement(foo='it')) + + def test_make_from_source_has_more_attributes(self): + + class SampleRequirement(Requirement): + def __init__(self, foo): + super().__init__() + self.foo = foo + + r = SampleRequirement('it') + r_ = Requirement.make_from(r) + assert r_ is not r + + # while this is a bit ugly, it will hopefully do no harm: + assert r_.foo == 'it' + special_names = ['attr', 'ops', 'target'] @pytest.mark.parametrize("name", special_names) From c5f75639749aaea9ae1cf0487b5b99ca4ba5733a Mon Sep 17 00:00:00 2001 From: Chris Withers Date: Thu, 9 Apr 2020 09:07:51 +0100 Subject: [PATCH 093/159] Raise exceptions for problems around lazy resources Also simplify to use Call, not need for a custom requirement. --- mush/callpoints.py | 16 ++++++--- mush/requirements.py | 20 +++-------- mush/runner.py | 20 +++++++---- mush/tests/test_runner.py | 72 ++++++++++++++++++++++++++++++++++++++- 4 files changed, 102 insertions(+), 26 deletions(-) diff --git a/mush/callpoints.py b/mush/callpoints.py index a501d5d..d0723ff 100644 --- a/mush/callpoints.py +++ b/mush/callpoints.py @@ -1,8 +1,8 @@ from .declarations import ( - requires_nothing, returns as returns_declaration, - - returns_nothing) + requires_nothing, returns as returns_declaration, returns_nothing +) from .extraction import extract_requires, extract_returns +from .requirements import Call, name_or_repr def do_nothing(): @@ -20,7 +20,15 @@ def __init__(self, runner, obj, requires=None, returns=None, lazy=False): if lazy: if not (type(returns) is returns_declaration and len(returns.args) == 1): raise TypeError('a single return type must be explicitly specified') - runner.lazy[returns.args[0]] = obj, requires + key = returns.args[0] + requirement = Call(obj, requires) + if key in runner.lazy: + raise TypeError( + f'{name_or_repr(key)} has more than one lazy definition:\n' + f'{runner.lazy[key]}\n' + f'{requirement}' + ) + runner.lazy[key] = requirement obj = do_nothing requires = requires_nothing returns = returns_nothing diff --git a/mush/requirements.py b/mush/requirements.py index c440be9..1795645 100644 --- a/mush/requirements.py +++ b/mush/requirements.py @@ -154,35 +154,25 @@ def resolve(self, context: 'Context'): return context.get(self.key, self.default) -class Lazy(Requirement): - - runner = None - - def resolve(self, context): - result = context.get(self.key, missing) - if result is missing: - obj, requires = self.runner.lazy[self.key] - result = context.call(obj, requires) - context.add(result, provides=self.key) - return result - - class Call(Requirement): """ A requirement that is resolved by calling something. If ``cache`` is ``True``, then the result of that call will be cached for the duration of the context in which this requirement is resolved. + + Explicit ``requires`` can also be passed in. """ - def __init__(self, obj: Callable, *, cache: bool = True): + def __init__(self, obj: Callable, requires=None, *, cache: bool = True): super().__init__(obj) + self.requires = requires self.cache: bool = cache def resolve(self, context): result = context.get(self.key, missing) if result is missing: - result = context.call(self.key) + result = context.call(self.key, self.requires) if self.cache: context.add(result, provides=self.key) return result diff --git a/mush/runner.py b/mush/runner.py index 3c5f1a6..7f06fcd 100644 --- a/mush/runner.py +++ b/mush/runner.py @@ -7,7 +7,7 @@ from .markers import not_specified from .modifier import Modifier from .plug import Plug -from .requirements import Lazy +from .requirements import name_or_repr class Runner(object): @@ -26,7 +26,7 @@ def __init__(self, *objects): def modify_requirement(self, requirement): if requirement.key in self.lazy: - requirement = Lazy.make_from(requirement, runner=self) + requirement = self.lazy[requirement.key] else: requirement = default_requirement_type(requirement) return requirement @@ -69,7 +69,15 @@ def add_label(self, label): m.add_label(label) return m - def _copy_from(self, start_point, end_point, added_using=None): + def _copy_from(self, runner, start_point, end_point, added_using=None): + lazy_clash = set(self.lazy) & set(runner.lazy) + if lazy_clash: + raise TypeError( + 'both runners have lazy definitions for these resources:\n' + + '\n'.join(name_or_repr(key) for key in lazy_clash) + ) + self.lazy.update(runner.lazy) + previous_cloned_point = self.end point = start_point @@ -104,7 +112,7 @@ def extend(self, *objs): """ for obj in objs: if isinstance(obj, Runner): - self._copy_from(obj.start, obj.end) + self._copy_from(obj, obj.start, obj.end) else: self.add(obj) @@ -161,7 +169,7 @@ def clone(self, return runner point = point.previous - runner._copy_from(start, end, added_using) + runner._copy_from(self, start, end, added_using) return runner def replace(self, @@ -237,7 +245,7 @@ def __add__(self, other): """ runner = self.__class__() for r in self, other: - runner._copy_from(r.start, r.end) + runner._copy_from(r, r.start, r.end) return runner def __call__(self, context: Context = None): diff --git a/mush/tests/test_runner.py b/mush/tests/test_runner.py index b1adcb0..9bd62a3 100644 --- a/mush/tests/test_runner.py +++ b/mush/tests/test_runner.py @@ -4,7 +4,7 @@ from mush.declarations import ( requires, returns, returns_mapping, replacement, original) -from mush import Value, ContextError, Context +from mush import Value, ContextError, Context, Requirement from mush.runner import Runner from testfixtures import ( ShouldRaise, @@ -445,6 +445,17 @@ class T2(object): pass ): runner.add(lambda: None, returns=returns(T1, T2), lazy=True) + def test_lazy_two_callable_provide_same_type(self): + class T1(object): pass + runner = Runner() + runner.add(lambda: None, returns=returns(T1), lazy=True) + with ShouldRaise(TypeError( + 'T1 has more than one lazy definition:\n' + 'Call()\n' + 'Call()' + )): + runner.add(lambda: None, returns=returns(T1), lazy=True) + def test_lazy_per_context(self): m = Mock() class T1(object): pass @@ -470,6 +481,65 @@ def job(obj): call.job(t), ], ) + def test_lazy_after_clone(self): + m = Mock() + class T1(object): pass + t = T1() + + def lazy(): + m.lazy_used() + return t + + def job(obj): + m.job(obj) + + runner = Runner() + runner.add(lazy, returns=returns(T1), lazy=True) + runner_ = runner.clone() + runner_.add(job, requires(T1)) + runner_() + + compare(m.mock_calls, expected=[ + call.lazy_used(), + call.job(t), + ], ) + + def test_lazy_after_add(self): + m = Mock() + class T1(object): pass + t = T1() + + def lazy(): + m.lazy_used() + return t + + def job(obj): + m.job(obj) + + runner1 = Runner() + runner1.add(lazy, returns=returns(T1), lazy=True) + runner2 = Runner() + runner2.add(job, requires(T1)) + runner = runner1 + runner2 + runner() + + compare(m.mock_calls, expected=[ + call.lazy_used(), + call.job(t), + ], ) + + def test_lazy_add_clash(self): + class T1(object): pass + runner1 = Runner() + runner1.add(lambda: None, returns=returns(T1), lazy=True) + runner2 = Runner() + runner2.add(lambda: None, returns=returns(T1), lazy=True) + with ShouldRaise(TypeError( + 'both runners have lazy definitions for these resources:\n' + 'T1' + )): + runner1 + runner2 + def test_lazy_only_resolved_once(self): m = Mock() class T1(object): pass From 810ecc4277e42334cea6c7343be56bae95be9d8f Mon Sep 17 00:00:00 2001 From: Chris Withers Date: Thu, 9 Apr 2020 10:58:15 +0100 Subject: [PATCH 094/159] Add requirement_modifier support to Runners. This required further changes to lazy resources/providers and further highlights their weaknesses. --- mush/callpoints.py | 14 ++++--- mush/requirements.py | 22 +++++++--- mush/runner.py | 21 ++++++---- mush/tests/test_runner.py | 88 +++++++++++++++++++++++++++++++++++---- 4 files changed, 118 insertions(+), 27 deletions(-) diff --git a/mush/callpoints.py b/mush/callpoints.py index d0723ff..fb8627a 100644 --- a/mush/callpoints.py +++ b/mush/callpoints.py @@ -1,3 +1,5 @@ +from collections import namedtuple + from .declarations import ( requires_nothing, returns as returns_declaration, returns_nothing ) @@ -9,6 +11,9 @@ def do_nothing(): pass +LazyProvider = namedtuple('LazyProvider', ['obj', 'requires', 'returns']) + + class CallPoint(object): next = None @@ -21,14 +26,13 @@ def __init__(self, runner, obj, requires=None, returns=None, lazy=False): if not (type(returns) is returns_declaration and len(returns.args) == 1): raise TypeError('a single return type must be explicitly specified') key = returns.args[0] - requirement = Call(obj, requires) if key in runner.lazy: raise TypeError( - f'{name_or_repr(key)} has more than one lazy definition:\n' - f'{runner.lazy[key]}\n' - f'{requirement}' + f'{name_or_repr(key)} has more than one lazy provider:\n' + f'{runner.lazy[key].obj}\n' + f'{obj}' ) - runner.lazy[key] = requirement + runner.lazy[key] = LazyProvider(obj, requires, returns) obj = do_nothing requires = requires_nothing returns = returns_nothing diff --git a/mush/requirements.py b/mush/requirements.py index 1795645..f17eeb0 100644 --- a/mush/requirements.py +++ b/mush/requirements.py @@ -160,19 +160,16 @@ class Call(Requirement): If ``cache`` is ``True``, then the result of that call will be cached for the duration of the context in which this requirement is resolved. - - Explicit ``requires`` can also be passed in. """ - def __init__(self, obj: Callable, requires=None, *, cache: bool = True): + def __init__(self, obj: Callable, *, cache: bool = True): super().__init__(obj) - self.requires = requires self.cache: bool = cache def resolve(self, context): result = context.get(self.key, missing) if result is missing: - result = context.call(self.key, self.requires) + result = context.call(self.key) if self.cache: context.add(result, provides=self.key) return result @@ -210,3 +207,18 @@ def resolve(self, context: 'Context'): if value is not missing: return value return self.default + + +class Lazy(Requirement): + + def __init__(self, original, provider): + super().__init__(original.key) + self.original = original + self.provider = provider + self.ops = original.ops + + def resolve(self, context): + resource = context.get(self.key, missing) + if resource is missing: + context.extract(self.provider.obj, self.provider.requires, self.provider.returns) + return self.original.resolve(context) diff --git a/mush/runner.py b/mush/runner.py index 7f06fcd..a5e67cc 100644 --- a/mush/runner.py +++ b/mush/runner.py @@ -7,7 +7,7 @@ from .markers import not_specified from .modifier import Modifier from .plug import Plug -from .requirements import name_or_repr +from .requirements import name_or_repr, Lazy class Runner(object): @@ -19,16 +19,16 @@ class Runner(object): start = end = None - def __init__(self, *objects): + def __init__(self, *objects, requirement_modifier=default_requirement_type): + self.requirement_modifier = requirement_modifier self.labels = {} self.lazy = {} self.extend(*objects) def modify_requirement(self, requirement): + requirement = self.requirement_modifier(requirement) if requirement.key in self.lazy: - requirement = self.lazy[requirement.key] - else: - requirement = default_requirement_type(requirement) + requirement = Lazy(requirement, provider=self.lazy[requirement.key]) return requirement def add(self, obj, requires=None, returns=None, label=None, lazy=False): @@ -70,11 +70,16 @@ def add_label(self, label): return m def _copy_from(self, runner, start_point, end_point, added_using=None): + if self.requirement_modifier is not runner.requirement_modifier: + raise TypeError('requirement_modifier must be identical') + lazy_clash = set(self.lazy) & set(runner.lazy) if lazy_clash: raise TypeError( - 'both runners have lazy definitions for these resources:\n' + - '\n'.join(name_or_repr(key) for key in lazy_clash) + 'both runners have lazy providers for these resources:\n' + + '\n'.join(f'{name_or_repr(key)}: \n' + f' {self.lazy[key].obj}\n' + f' {runner.lazy[key].obj}' for key in lazy_clash) ) self.lazy.update(runner.lazy) @@ -142,7 +147,7 @@ def clone(self, label specified in this option should be cloned. This filtering is applied in addition to the above options. """ - runner = self.__class__() + runner = self.__class__(requirement_modifier=self.requirement_modifier) if start_label: start = self.labels[start_label] diff --git a/mush/tests/test_runner.py b/mush/tests/test_runner.py index 9bd62a3..b158301 100644 --- a/mush/tests/test_runner.py +++ b/mush/tests/test_runner.py @@ -447,14 +447,16 @@ class T2(object): pass def test_lazy_two_callable_provide_same_type(self): class T1(object): pass + def foo(): pass + def bar(): pass runner = Runner() - runner.add(lambda: None, returns=returns(T1), lazy=True) + runner.add(foo, returns=returns(T1), lazy=True) with ShouldRaise(TypeError( - 'T1 has more than one lazy definition:\n' - 'Call()\n' - 'Call()' + 'T1 has more than one lazy provider:\n' + f'{foo!r}\n' + f'{bar!r}' )): - runner.add(lambda: None, returns=returns(T1), lazy=True) + runner.add(bar, returns=returns(T1), lazy=True) def test_lazy_per_context(self): m = Mock() @@ -530,13 +532,17 @@ def job(obj): def test_lazy_add_clash(self): class T1(object): pass + def foo(): pass + def bar(): pass runner1 = Runner() - runner1.add(lambda: None, returns=returns(T1), lazy=True) + runner1.add(foo, returns=returns(T1), lazy=True) runner2 = Runner() - runner2.add(lambda: None, returns=returns(T1), lazy=True) + runner2.add(bar, returns=returns(T1), lazy=True) with ShouldRaise(TypeError( - 'both runners have lazy definitions for these resources:\n' - 'T1' + 'both runners have lazy providers for these resources:\n' + 'T1: \n' + f' {foo!r}\n' + f' {bar!r}' )): runner1 + runner2 @@ -567,6 +573,29 @@ def job2(obj): call.job2(t), ], ) + def test_lazy_with_requirement_modifier(self): + def make_data(): + return {'foo': 'bar'} + + class FromKey(Requirement): + def resolve(self, context): + return context.get('data')[self.data_key] + + def modifier(requirement): + if type(requirement) is Requirement: + # another limitation of lazy: + requirement = FromKey.make_from(requirement, + key='data', + data_key=requirement.key) + return requirement + + runner = Runner(requirement_modifier=modifier) + runner.add(make_data, returns='data', lazy=True) + runner.add(lambda foo: foo+'baz', returns='processed') + runner.add(lambda *args: args, requires(Value('data')['foo'], 'processed')) + + compare(runner(), expected=('bar', 'barbaz')) + def test_missing_from_context_no_chain(self): class T(object): pass @@ -1395,3 +1424,44 @@ def foo(): return 42 runner = Runner(foo) compare(runner(context), expected=42) + + def test_requirement_modifier(self): + + class FromRequest(Requirement): + def resolve(self, context): + return context.get('request')[self.key] + + def foo(bar): + return bar + + def modifier(requirement): + if type(requirement) is Requirement: + requirement = FromRequest.make_from(requirement) + return requirement + + runner = Runner(requirement_modifier=modifier) + runner.add(foo) + context = Context() + context.add({'bar': 'foo'}, provides='request') + compare(runner(context), expected='foo') + + def test_clone_requirement_modifier(self): + def modifier(requirement): pass + runner = Runner(requirement_modifier=modifier) + assert runner.clone().requirement_modifier is runner.requirement_modifier + + def test_add_clashing_requirement_modifier(self): + def modifier1(requirement): pass + runner1 = Runner(requirement_modifier=modifier1) + def modifier2(requirement): pass + runner2 = Runner(requirement_modifier=modifier2) + with ShouldRaise(TypeError('requirement_modifier must be identical')): + runner1 + runner2 + + def test_extend_other_runner_clashing_requirement_modifier(self): + def modifier1(requirement): pass + runner1 = Runner(requirement_modifier=modifier1) + def modifier2(requirement): pass + runner2 = Runner(requirement_modifier=modifier2) + with ShouldRaise(TypeError('requirement_modifier must be identical')): + runner1.extend(runner2) From 7f71c84ea58f11d746df0165dc35c70f4696237b Mon Sep 17 00:00:00 2001 From: Chris Withers Date: Mon, 13 Apr 2020 10:43:42 +0100 Subject: [PATCH 095/159] Requirement.key no longer needs to be optional. --- mush/requirements.py | 2 +- mush/tests/test_requirements.py | 5 +++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/mush/requirements.py b/mush/requirements.py index f17eeb0..3f36531 100644 --- a/mush/requirements.py +++ b/mush/requirements.py @@ -48,7 +48,7 @@ class Requirement: """ def __init__(self, - key: ResourceKey = None, # XXX should not be default? + key: ResourceKey, name: str = None, type_: type = None, default: Any = missing, diff --git a/mush/tests/test_requirements.py b/mush/tests/test_requirements.py index a4b55c5..29c7062 100644 --- a/mush/tests/test_requirements.py +++ b/mush/tests/test_requirements.py @@ -84,13 +84,14 @@ def test_make_from_source_has_more_attributes(self): class SampleRequirement(Requirement): def __init__(self, foo): - super().__init__() + super().__init__('bar') self.foo = foo r = SampleRequirement('it') r_ = Requirement.make_from(r) assert r_ is not r + assert r_.key == 'bar' # while this is a bit ugly, it will hopefully do no harm: assert r_.foo == 'it' @@ -116,7 +117,7 @@ def test_no_special_name_via_getattr(self): compare(v.ops, []) def test_resolve(self): - r = Requirement() + r = Requirement('foo') with ShouldRaise(NotImplementedError): r.resolve(None) From 330b0c885de31525e47c6405640307e54c73cd4c Mon Sep 17 00:00:00 2001 From: Chris Withers Date: Mon, 13 Apr 2020 11:32:04 +0100 Subject: [PATCH 096/159] add more typing information --- mush/asyncio.py | 4 ++-- mush/callpoints.py | 9 ++++++++- mush/declarations.py | 6 +++--- mush/extraction.py | 6 +++--- mush/modifier.py | 5 ++++- mush/runner.py | 25 ++++++++++++++----------- mush/types.py | 14 ++++++++++++-- 7 files changed, 46 insertions(+), 23 deletions(-) diff --git a/mush/asyncio.py b/mush/asyncio.py index 0d59359..de1d3b2 100644 --- a/mush/asyncio.py +++ b/mush/asyncio.py @@ -16,7 +16,7 @@ class AsyncFromSyncContext: def __init__(self, context, loop): - self.context = context + self.context: Context = context self.loop = loop self.remove = context.remove self.add = context.add @@ -142,7 +142,7 @@ async def __call__(self, context: Context = None): class Call(SyncCall): - async def resolve(self, context): + async def resolve(self, context: Context): result = context.get(self.key, missing) if result is missing: result = await context.call(self.key) diff --git a/mush/callpoints.py b/mush/callpoints.py index fb8627a..618f4ad 100644 --- a/mush/callpoints.py +++ b/mush/callpoints.py @@ -1,10 +1,15 @@ from collections import namedtuple +from typing import TYPE_CHECKING, Callable from .declarations import ( requires_nothing, returns as returns_declaration, returns_nothing ) from .extraction import extract_requires, extract_returns from .requirements import Call, name_or_repr +from .types import Requires, Returns + +if TYPE_CHECKING: + from .runner import Runner def do_nothing(): @@ -19,7 +24,9 @@ class CallPoint(object): next = None previous = None - def __init__(self, runner, obj, requires=None, returns=None, lazy=False): + def __init__(self, runner: 'Runner', obj: Callable, + requires: Requires = None, returns: Returns = None, + lazy: bool = False): requires = extract_requires(obj, requires, runner.modify_requirement) returns = extract_returns(obj, returns) if lazy: diff --git a/mush/declarations.py b/mush/declarations.py index 1b44bb3..5ba8b7d 100644 --- a/mush/declarations.py +++ b/mush/declarations.py @@ -4,7 +4,7 @@ from .markers import set_mush from .requirements import Requirement, Value, name_or_repr - +from .types import RequirementType, ReturnType VALID_DECORATION_TYPES = (type, str, Requirement) @@ -35,7 +35,7 @@ def __call__(self, obj): return obj -def requires(*args, **kw): +def requires(*args: RequirementType, **kw: RequirementType): """ Represents requirements for a particular callable. @@ -89,7 +89,7 @@ class returns(ReturnsType): type overridden. """ - def __init__(self, *args): + def __init__(self, *args: ReturnType): valid_decoration_types(*args) self.args = args diff --git a/mush/extraction.py b/mush/extraction.py index e9e35a4..3b0334a 100644 --- a/mush/extraction.py +++ b/mush/extraction.py @@ -14,7 +14,7 @@ ) from .requirements import Requirement, Value from .markers import missing, get_mush -from .types import RequirementModifier +from .types import RequirementModifier, Requires, Returns EMPTY = Parameter.empty #: For these types, prefer the name instead of the type. @@ -54,7 +54,7 @@ def default_requirement_type(requirement): def extract_requires(obj: Callable, - explicit: RequiresType = None, + explicit: Requires = None, modifier: RequirementModifier = default_requirement_type): # from annotations by_name = {} @@ -144,7 +144,7 @@ def extract_requires(obj: Callable, return RequiresType(by_name.values()) -def extract_returns(obj: Callable, explicit: ReturnsType = None): +def extract_returns(obj: Callable, explicit: Returns = None): if explicit is None: returns_ = get_mush(obj, 'returns', None) if returns_ is None: diff --git a/mush/modifier.py b/mush/modifier.py index 691f2fa..5217812 100644 --- a/mush/modifier.py +++ b/mush/modifier.py @@ -1,9 +1,11 @@ """ .. currentmodule:: mush """ +from typing import Callable from .callpoints import CallPoint from .markers import not_specified +from .types import Requires, Returns class Modifier(object): @@ -19,7 +21,8 @@ def __init__(self, runner, callpoint, label): else: self.labels = {label} - def add(self, obj, requires=None, returns=None, label=None, lazy=False): + def add(self, obj: Callable, requires: Requires = None, returns: Returns = None, + label: str = None, lazy: bool = False): """ :param obj: The callable to be added. diff --git a/mush/runner.py b/mush/runner.py index a5e67cc..b18d498 100644 --- a/mush/runner.py +++ b/mush/runner.py @@ -1,4 +1,4 @@ -from typing import Callable +from typing import Callable, Optional from .callpoints import CallPoint from .context import Context, ResourceError @@ -8,6 +8,7 @@ from .modifier import Modifier from .plug import Plug from .requirements import name_or_repr, Lazy +from .types import Requires, Returns class Runner(object): @@ -17,9 +18,10 @@ class Runner(object): will be called. """ - start = end = None + start: Optional[CallPoint] = None + end: Optional[CallPoint] = None - def __init__(self, *objects, requirement_modifier=default_requirement_type): + def __init__(self, *objects: Callable, requirement_modifier=default_requirement_type): self.requirement_modifier = requirement_modifier self.labels = {} self.lazy = {} @@ -31,7 +33,8 @@ def modify_requirement(self, requirement): requirement = Lazy(requirement, provider=self.lazy[requirement.key]) return requirement - def add(self, obj, requires=None, returns=None, label=None, lazy=False): + def add(self, obj: Callable, requires: Requires = None, returns: Returns = None, + label: str = None, lazy: bool = False): """ Add a callable to the runner. @@ -61,7 +64,7 @@ def add(self, obj, requires=None, returns=None, label=None, lazy=False): m.add(obj, requires, returns, label, lazy) return m - def add_label(self, label): + def add_label(self, label: str): """ Add a label to the the point currently at the end of the runner. """ @@ -108,7 +111,7 @@ def _copy_from(self, runner, start_point, end_point, added_using=None): self.end = previous_cloned_point - def extend(self, *objs): + def extend(self, *objs: Callable): """ Add the specified callables to this runner. @@ -122,9 +125,9 @@ def extend(self, *objs): self.add(obj) def clone(self, - start_label=None, end_label=None, - include_start=False, include_end=False, - added_using=None): + start_label: str = None, end_label: str = None, + include_start: bool = False, include_end: bool = False, + added_using: str = None): """ Return a copy of this :class:`Runner`. @@ -236,14 +239,14 @@ def replace(self, point = point.next - def __getitem__(self, label): + def __getitem__(self, label: str): """ Retrieve a :class:`~.modifier.Modifier` for a previous labelled point in the runner. """ return Modifier(self, self.labels[label], label) - def __add__(self, other): + def __add__(self, other: 'Runner'): """ Return a new :class:`Runner` containing the contents of the two :class:`Runner` instances being added together. diff --git a/mush/types.py b/mush/types.py index 0df1ee9..8592235 100644 --- a/mush/types.py +++ b/mush/types.py @@ -1,9 +1,19 @@ -from typing import NewType, Union, Hashable, Callable, Any, TYPE_CHECKING +from typing import NewType, Union, Hashable, Callable, Any, TYPE_CHECKING, List, Tuple if TYPE_CHECKING: from .context import Context + from .declarations import RequiresType, ReturnsType from .requirements import Requirement -ResourceKey = NewType('ResourceKey', Union[Hashable, Callable]) +RequirementType = Union['Requirement', type, str] +Requires = Union['RequiresType', + RequirementType, + List[RequirementType], + Tuple[RequirementType]] + +ReturnType = Union[type, str] +Returns = Union['ReturnsType', ReturnType, List[ReturnType], Tuple[ReturnType]] + +ResourceKey = Union[Hashable, Callable] ResourceValue = NewType('ResourceValue', Any) RequirementModifier = Callable[['Requirement'], 'Requirement'] From c18a572ba95273981ed5b3c177b500885e7a7132 Mon Sep 17 00:00:00 2001 From: Chris Withers Date: Mon, 13 Apr 2020 11:32:20 +0100 Subject: [PATCH 097/159] Don't unnecessarily recreate requires() objects --- mush/extraction.py | 9 ++++++--- mush/tests/test_runner.py | 2 +- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/mush/extraction.py b/mush/extraction.py index 3b0334a..7aa10a0 100644 --- a/mush/extraction.py +++ b/mush/extraction.py @@ -122,9 +122,12 @@ def extract_requires(obj: Callable, # explicit if explicit is not None: - if not isinstance(explicit, (list, tuple)): - explicit = (explicit,) - requires_ = requires(*explicit) + if isinstance(explicit, RequiresType): + requires_ = explicit + else: + if not isinstance(explicit, (list, tuple)): + explicit = (explicit,) + requires_ = requires(*explicit) _apply_requires(by_name, by_index, requires_) if not by_name: diff --git a/mush/tests/test_runner.py b/mush/tests/test_runner.py index b158301..0881fac 100644 --- a/mush/tests/test_runner.py +++ b/mush/tests/test_runner.py @@ -654,7 +654,7 @@ def job5(foo, bar): pass '', 'Still to call:', repr(job4)+' requires() returns_result_type() <-- 4', - repr(job5)+" requires(Value('foo'), Value('baz')) returns('bob')", + repr(job5)+" requires(Value('foo'), bar=Value('baz')) returns('bob')", )) compare(text, actual=repr(s.raised)) compare(text, actual=str(s.raised)) From 50f97b1ed03078382632637c922cdd36718f63fb Mon Sep 17 00:00:00 2001 From: Chris Withers Date: Mon, 13 Apr 2020 11:35:06 +0100 Subject: [PATCH 098/159] rename module to match stdlib one --- mush/asyncio.py | 2 +- mush/callpoints.py | 2 +- mush/context.py | 2 +- mush/declarations.py | 2 +- mush/extraction.py | 2 +- mush/modifier.py | 2 +- mush/requirements.py | 2 +- mush/runner.py | 2 +- mush/{types.py => typing.py} | 0 9 files changed, 8 insertions(+), 8 deletions(-) rename mush/{types.py => typing.py} (100%) diff --git a/mush/asyncio.py b/mush/asyncio.py index de1d3b2..8ff6be9 100644 --- a/mush/asyncio.py +++ b/mush/asyncio.py @@ -10,7 +10,7 @@ from .declarations import RequiresType, ReturnsType from .extraction import default_requirement_type from .markers import get_mush, AsyncType -from .types import RequirementModifier +from .typing import RequirementModifier class AsyncFromSyncContext: diff --git a/mush/callpoints.py b/mush/callpoints.py index 618f4ad..269f16a 100644 --- a/mush/callpoints.py +++ b/mush/callpoints.py @@ -6,7 +6,7 @@ ) from .extraction import extract_requires, extract_returns from .requirements import Call, name_or_repr -from .types import Requires, Returns +from .typing import Requires, Returns if TYPE_CHECKING: from .runner import Runner diff --git a/mush/context.py b/mush/context.py index 76057d4..f1fe604 100644 --- a/mush/context.py +++ b/mush/context.py @@ -5,7 +5,7 @@ from .extraction import extract_requires, extract_returns, default_requirement_type from .markers import missing from .requirements import Requirement -from .types import ResourceKey, ResourceValue, RequirementModifier +from .typing import ResourceKey, ResourceValue, RequirementModifier NONE_TYPE = type(None) diff --git a/mush/declarations.py b/mush/declarations.py index 5ba8b7d..ed2fa02 100644 --- a/mush/declarations.py +++ b/mush/declarations.py @@ -4,7 +4,7 @@ from .markers import set_mush from .requirements import Requirement, Value, name_or_repr -from .types import RequirementType, ReturnType +from .typing import RequirementType, ReturnType VALID_DECORATION_TYPES = (type, str, Requirement) diff --git a/mush/extraction.py b/mush/extraction.py index 7aa10a0..0638358 100644 --- a/mush/extraction.py +++ b/mush/extraction.py @@ -14,7 +14,7 @@ ) from .requirements import Requirement, Value from .markers import missing, get_mush -from .types import RequirementModifier, Requires, Returns +from .typing import RequirementModifier, Requires, Returns EMPTY = Parameter.empty #: For these types, prefer the name instead of the type. diff --git a/mush/modifier.py b/mush/modifier.py index 5217812..16fa223 100644 --- a/mush/modifier.py +++ b/mush/modifier.py @@ -5,7 +5,7 @@ from .callpoints import CallPoint from .markers import not_specified -from .types import Requires, Returns +from .typing import Requires, Returns class Modifier(object): diff --git a/mush/requirements.py b/mush/requirements.py index 3f36531..6bd49f2 100644 --- a/mush/requirements.py +++ b/mush/requirements.py @@ -1,7 +1,7 @@ from copy import copy from typing import Any, Optional, List, TYPE_CHECKING, Callable -from .types import ResourceKey +from .typing import ResourceKey from .markers import missing, nonblocking if TYPE_CHECKING: diff --git a/mush/runner.py b/mush/runner.py index b18d498..adffc34 100644 --- a/mush/runner.py +++ b/mush/runner.py @@ -8,7 +8,7 @@ from .modifier import Modifier from .plug import Plug from .requirements import name_or_repr, Lazy -from .types import Requires, Returns +from .typing import Requires, Returns class Runner(object): diff --git a/mush/types.py b/mush/typing.py similarity index 100% rename from mush/types.py rename to mush/typing.py From 1058df5e37a35fb313d1bb8fc8a9f11b7a5785d3 Mon Sep 17 00:00:00 2001 From: Chris Withers Date: Fri, 17 Apr 2020 08:24:38 +0100 Subject: [PATCH 099/159] fix infinite recursion when last callpoint is a context manager --- mush/asyncio.py | 3 ++- mush/runner.py | 3 ++- mush/tests/test_async_runner.py | 14 ++++++++++++++ mush/tests/test_runner.py | 19 +++++++++++++++++++ 4 files changed, 37 insertions(+), 2 deletions(-) diff --git a/mush/asyncio.py b/mush/asyncio.py index 8ff6be9..f751832 100644 --- a/mush/asyncio.py +++ b/mush/asyncio.py @@ -135,7 +135,8 @@ async def __call__(self, context: Context = None): # If the context manager swallows an exception, # None should be returned, not the context manager: result = None - result = await self(context) + if context.point is not None: + result = await self(context) return result diff --git a/mush/runner.py b/mush/runner.py index adffc34..8b940b6 100644 --- a/mush/runner.py +++ b/mush/runner.py @@ -295,7 +295,8 @@ def __call__(self, context: Context = None): # If the context manager swallows an exception, # None should be returned, not the context manager: result = None - result = self(context) + if context.point is not None: + result = self(context) return result diff --git a/mush/tests/test_async_runner.py b/mush/tests/test_async_runner.py index fd773d0..ac65ee3 100644 --- a/mush/tests/test_async_runner.py +++ b/mush/tests/test_async_runner.py @@ -584,3 +584,17 @@ async def func(): call.cm2.exit(e), call.cm1.exit(e), ]) + + +@pytest.mark.asyncio +async def test_context_manager_is_last_callpoint(): + m = Mock() + CM = make_cm('CM', AsyncCM, m) + + runner = Runner(CM) + + compare(await runner(), expected=None) + compare(m.mock_calls, expected=[ + call.cm.enter(), + call.cm.exit(None), + ]) diff --git a/mush/tests/test_runner.py b/mush/tests/test_runner.py index 0881fac..49f0d82 100644 --- a/mush/tests/test_runner.py +++ b/mush/tests/test_runner.py @@ -864,6 +864,25 @@ def func2(obj1, obj2): call.cm1.exit(Exception, e) ]) + def test_context_manager_is_last_callpoint(self): + m = Mock() + + class CM(object): + def __enter__(self): + m.cm.enter() + def __exit__(self, type, obj, tb): + m.cm.exit() + + runner = Runner(CM) + result = runner() + compare(result, expected=None) + + compare(m.mock_calls, expected=[ + call.cm.enter(), + call.cm.exit(), + ]) + + def test_clone(self): m = Mock() class T1(object): pass From 0e36009cb236f49864d808257c597782e5dc6492 Mon Sep 17 00:00:00 2001 From: Chris Withers Date: Sat, 18 Apr 2020 10:15:04 +0100 Subject: [PATCH 100/159] fix type definition for tuple of requirements --- mush/typing.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mush/typing.py b/mush/typing.py index 8592235..fc950f2 100644 --- a/mush/typing.py +++ b/mush/typing.py @@ -9,10 +9,10 @@ Requires = Union['RequiresType', RequirementType, List[RequirementType], - Tuple[RequirementType]] + Tuple[RequirementType, ...]] ReturnType = Union[type, str] -Returns = Union['ReturnsType', ReturnType, List[ReturnType], Tuple[ReturnType]] +Returns = Union['ReturnsType', ReturnType, List[ReturnType], Tuple[ReturnType, ...]] ResourceKey = Union[Hashable, Callable] ResourceValue = NewType('ResourceValue', Any) From 9c39177787a98c8872c64f6e3848d9fdeac4bcae Mon Sep 17 00:00:00 2001 From: Chris Withers Date: Mon, 20 Apr 2020 13:32:04 +0100 Subject: [PATCH 101/159] Move to testfixtures mock facade This means you don't have to install the mock backport if you're on Python 3 and want to use mush's testing tools. --- mush/tests/helpers.py | 2 +- mush/tests/test_async_runner.py | 2 +- mush/tests/test_callpoints.py | 26 +++++++++++++------------- mush/tests/test_context.py | 2 +- mush/tests/test_plug.py | 2 +- mush/tests/test_requirements.py | 2 +- mush/tests/test_runner.py | 2 +- setup.py | 2 +- 8 files changed, 20 insertions(+), 20 deletions(-) diff --git a/mush/tests/helpers.py b/mush/tests/helpers.py index 3abc429..b89189d 100644 --- a/mush/tests/helpers.py +++ b/mush/tests/helpers.py @@ -3,7 +3,7 @@ from contextlib import contextmanager from functools import partial -from mock import Mock +from testfixtures.mock import Mock PY_VERSION = sys.version_info[:2] diff --git a/mush/tests/test_async_runner.py b/mush/tests/test_async_runner.py index ac65ee3..b5d75b1 100644 --- a/mush/tests/test_async_runner.py +++ b/mush/tests/test_async_runner.py @@ -1,5 +1,5 @@ import asyncio -from unittest.mock import Mock, call +from testfixtures.mock import Mock, call import pytest from testfixtures import compare, ShouldRaise, Comparison as C diff --git a/mush/tests/test_callpoints.py b/mush/tests/test_callpoints.py index 6fe8830..789af2b 100644 --- a/mush/tests/test_callpoints.py +++ b/mush/tests/test_callpoints.py @@ -1,8 +1,8 @@ from functools import update_wrapper from unittest import TestCase -from mock import Mock from testfixtures import compare +from testfixtures.mock import Mock, call from mush.callpoints import CallPoint from mush.declarations import requires, returns, RequiresType @@ -30,10 +30,10 @@ def foo(a1): pass rt = returns('bar') result = CallPoint(self.runner, foo, rq, rt)(self.context) compare(result, self.context.extract.return_value) - compare(tuple(self.context.extract.mock_calls[0].args), - expected=(foo, - RequiresType([Value.make(key='foo', name='a1')]), - rt)) + compare(self.context.extract.mock_calls, + expected=[call(foo, + RequiresType([Value.make(key='foo', name='a1')]), + rt)]) def test_extract_from_decorations(self): rq = requires('foo') @@ -45,10 +45,10 @@ def foo(a1): pass result = CallPoint(self.runner, foo)(self.context) compare(result, self.context.extract.return_value) - compare(tuple(self.context.extract.mock_calls[0].args), - expected=(foo, - RequiresType([Value.make(key='foo', name='a1')]), - returns('bar'))) + compare(self.context.extract.mock_calls, + expected=[call(foo, + RequiresType([Value.make(key='foo', name='a1')]), + returns('bar'))]) def test_extract_from_decorated_class(self): @@ -84,10 +84,10 @@ def foo(a1): pass point = CallPoint(self.runner, foo, requires('baz'), returns('bob')) result = point(self.context) compare(result, self.context.extract.return_value) - compare(tuple(self.context.extract.mock_calls[0].args), - expected=(foo, - RequiresType([Value.make(key='baz', name='a1')]), - returns('bob'))) + compare(self.context.extract.mock_calls, + expected=[call(foo, + RequiresType([Value.make(key='baz', name='a1')]), + returns('bob'))]) def test_repr_minimal(self): def foo(): pass diff --git a/mush/tests/test_context.py b/mush/tests/test_context.py index 2f89f85..ae59807 100644 --- a/mush/tests/test_context.py +++ b/mush/tests/test_context.py @@ -1,8 +1,8 @@ from typing import Tuple, List from unittest import TestCase -from mock import Mock from testfixtures import ShouldRaise, compare +from testfixtures.mock import Mock from mush import ( Context, requires, returns, returns_mapping, Value, missing diff --git a/mush/tests/test_plug.py b/mush/tests/test_plug.py index 6179679..2826339 100644 --- a/mush/tests/test_plug.py +++ b/mush/tests/test_plug.py @@ -1,7 +1,7 @@ from unittest import TestCase -from mock import Mock, call from testfixtures import compare, ShouldRaise +from testfixtures.mock import Mock, call from mush import Plug, Runner, returns, requires from mush.plug import insert, ignore, append diff --git a/mush/tests/test_requirements.py b/mush/tests/test_requirements.py index 29c7062..327c3b5 100644 --- a/mush/tests/test_requirements.py +++ b/mush/tests/test_requirements.py @@ -2,8 +2,8 @@ from unittest.case import TestCase import pytest -from mock import Mock from testfixtures import compare, ShouldRaise +from testfixtures.mock import Mock from mush import Context, Call, Value, missing, requires, ResourceError from mush.requirements import Requirement, AttrOp, ItemOp, AnyOf, Like diff --git a/mush/tests/test_runner.py b/mush/tests/test_runner.py index 49f0d82..d0005cb 100644 --- a/mush/tests/test_runner.py +++ b/mush/tests/test_runner.py @@ -1,6 +1,5 @@ from unittest import TestCase -from mock import Mock, call from mush.declarations import ( requires, returns, returns_mapping, replacement, original) @@ -10,6 +9,7 @@ ShouldRaise, compare ) +from testfixtures.mock import Mock, call def verify(runner, *expected): diff --git a/setup.py b/setup.py index a44bb0c..2b8a0ef 100644 --- a/setup.py +++ b/setup.py @@ -37,7 +37,7 @@ 'pytest-asyncio', 'pytest-cov', 'sybil', - 'testfixtures>=6.13' + 'testfixtures>=6.14.1' ], build=['sphinx', 'setuptools-git', 'wheel', 'twine'] )) From 01984b93857b89907fdc6f0408c7f41da11439b0 Mon Sep 17 00:00:00 2001 From: Chris Withers Date: Thu, 23 Apr 2020 12:55:58 +0100 Subject: [PATCH 102/159] move to circleci matrix stuff --- .circleci/config.yml | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/.circleci/config.yml b/.circleci/config.yml index 339a0d8..33425cc 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -7,17 +7,17 @@ common: &common jobs: - python/pip-run-tests: - name: python36 - image: circleci/python:3.6 - - python/pip-run-tests: - name: python38 - image: circleci/python:3.8 + matrix: + parameters: + image: + - circleci/python:3.6 + - circleci/python:3.7 + - circleci/python:3.8 - python/coverage: name: coverage requires: - - python36 - - python38 + - python/pip-run-tests - python/release: name: release From 92429a43a2336009754e8dfd32e85676ce1ab734 Mon Sep 17 00:00:00 2001 From: Chris Withers Date: Tue, 28 Apr 2020 07:57:18 +0100 Subject: [PATCH 103/159] No need for requirement_modifier to be private. --- mush/context.py | 6 +++--- mush/tests/test_context.py | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/mush/context.py b/mush/context.py index f1fe604..ddb3a3b 100644 --- a/mush/context.py +++ b/mush/context.py @@ -30,7 +30,7 @@ class Context: point: CallPoint = None def __init__(self, requirement_modifier: RequirementModifier = default_requirement_type): - self._requirement_modifier = requirement_modifier + self.requirement_modifier = requirement_modifier self._store = {} self._requires_cache = {} self._returns_cache = {} @@ -92,7 +92,7 @@ def _resolve(self, obj, requires, args, kw, context): if requires is None: requires = extract_requires(obj, explicit=None, - modifier=self._requirement_modifier) + modifier=self.requirement_modifier) self._requires_cache[obj] = requires for requirement in requires: @@ -144,7 +144,7 @@ def get(self, key: ResourceKey, default=None): def nest(self, requirement_modifier: RequirementModifier = None): if requirement_modifier is None: - requirement_modifier = self._requirement_modifier + requirement_modifier = self.requirement_modifier nested = self.__class__(requirement_modifier) nested._parent = self nested._requires_cache = self._requires_cache diff --git a/mush/tests/test_context.py b/mush/tests/test_context.py index ae59807..a8321df 100644 --- a/mush/tests/test_context.py +++ b/mush/tests/test_context.py @@ -386,14 +386,14 @@ def test_nest_with_overridden_default_requirement_type(self): def modifier(): pass c1 = Context(modifier) c2 = c1.nest() - assert c2._requirement_modifier is modifier + assert c2.requirement_modifier is modifier def test_nest_with_explicit_default_requirement_type(self): def modifier1(): pass def modifier2(): pass c1 = Context(modifier1) c2 = c1.nest(modifier2) - assert c2._requirement_modifier is modifier2 + assert c2.requirement_modifier is modifier2 def test_nest_keeps_declarations_cache(self): c1 = Context() From b698f28b221cd5e23424e4919ddb1f2591e6a062 Mon Sep 17 00:00:00 2001 From: Chris Withers Date: Tue, 28 Apr 2020 13:06:59 +0100 Subject: [PATCH 104/159] None as a default turns out to be worse than raising an exception in practice --- mush/context.py | 8 ++++++-- mush/tests/test_async_context.py | 2 +- mush/tests/test_async_requirements.py | 2 +- mush/tests/test_context.py | 11 ++++++----- mush/tests/test_requirements.py | 2 +- 5 files changed, 15 insertions(+), 10 deletions(-) diff --git a/mush/context.py b/mush/context.py index ddb3a3b..ebd038c 100644 --- a/mush/context.py +++ b/mush/context.py @@ -3,11 +3,12 @@ from .callpoints import CallPoint from .declarations import RequiresType, ReturnsType from .extraction import extract_requires, extract_returns, default_requirement_type -from .markers import missing +from .markers import missing, Marker from .requirements import Requirement from .typing import ResourceKey, ResourceValue, RequirementModifier NONE_TYPE = type(None) +unspecified = Marker('unspecified') class ResourceError(Exception): @@ -128,7 +129,7 @@ def call(self, obj: Callable, requires: RequiresType = None): resolving.send(requirement.resolve(self)) return obj(*args, **kw) - def get(self, key: ResourceKey, default=None): + def get(self, key: ResourceKey, default=unspecified): context = self while context is not None: @@ -140,6 +141,9 @@ def get(self, key: ResourceKey, default=None): self._store[key] = value return value + if default is unspecified: + raise ResourceError(f'No {key!r} in context', key) + return default def nest(self, requirement_modifier: RequirementModifier = None): diff --git a/mush/tests/test_async_context.py b/mush/tests/test_async_context.py index 12fd9bc..22051ef 100644 --- a/mush/tests/test_async_context.py +++ b/mush/tests/test_async_context.py @@ -358,7 +358,7 @@ def foo(bar: Syncer('request')): context = Context() context.add({'bar': 'foo'}, provides='request') compare(await context.call(foo), expected='foo') - compare(context.get('request'), expected=None) + compare(context.get('request', default=None), expected=None) compare(context.get('response'), expected='foo') diff --git a/mush/tests/test_async_requirements.py b/mush/tests/test_async_requirements.py index e399ae5..68bd3ad 100644 --- a/mush/tests/test_async_requirements.py +++ b/mush/tests/test_async_requirements.py @@ -44,7 +44,7 @@ def bob(x: str = Call(foo, cache=False)): compare(await context.call(bob), expected='abc') compare(await context.call(bob), expected='abc') compare(called, expected=[1, 1]) - compare(context.get(foo), expected=None) + compare(context.get(foo, default=None), expected=None) @pytest.mark.asyncio async def test_parts_of_a_call(self): diff --git a/mush/tests/test_context.py b/mush/tests/test_context.py index a8321df..b3cf3ef 100644 --- a/mush/tests/test_context.py +++ b/mush/tests/test_context.py @@ -359,14 +359,15 @@ def test_get_type(self): context = Context() context.add(['bar'], provides=List[str]) compare(context.get(List[str]), expected=['bar']) - compare(context.get(List[int]), expected=None) - compare(context.get(List), expected=None) + compare(context.get(List[int], default=None), expected=None) + compare(context.get(List, default=None), expected=None) # nb: this might be surprising: - compare(context.get(list), expected=None) + compare(context.get(list, default=None), expected=None) def test_get_missing(self): context = Context() - compare(context.get('foo'), expected=None) + with ShouldRaise(ResourceError("No 'foo' in context", 'foo')): + context.get('foo') def test_nest(self): c1 = Context() @@ -379,7 +380,7 @@ def test_nest(self): compare(c2.get('b'), expected='b') compare(c2.get('c'), expected='d') compare(c1.get('a'), expected='a') - compare(c1.get('b'), expected=None) + compare(c1.get('b', default=None), expected=None) compare(c1.get('c'), expected='c') def test_nest_with_overridden_default_requirement_type(self): diff --git a/mush/tests/test_requirements.py b/mush/tests/test_requirements.py index 327c3b5..c0bee99 100644 --- a/mush/tests/test_requirements.py +++ b/mush/tests/test_requirements.py @@ -230,7 +230,7 @@ def bob(x: str = Call(foo, cache=False)): compare(context.call(bob), expected='abc') compare(context.call(bob), expected='abc') compare(called, expected=[1, 1]) - compare(context.get(foo), expected=None) + compare(context.get(foo, default=None), expected=None) def test_parts_of_a_call(self): context = Context() From 9d74b149909b4a8f151db6311516401dd4806fe1 Mon Sep 17 00:00:00 2001 From: Chris Withers Date: Tue, 28 Apr 2020 13:07:05 +0100 Subject: [PATCH 105/159] whitespace --- mush/requirements.py | 2 +- mush/tests/test_runner.py | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/mush/requirements.py b/mush/requirements.py index 6bd49f2..e5a3467 100644 --- a/mush/requirements.py +++ b/mush/requirements.py @@ -142,7 +142,7 @@ class Value(Requirement): ever use this. """ - def __init__(self, key: ResourceKey=None, *, type_: type = None, default: Any = missing): + def __init__(self, key: ResourceKey = None, *, type_: type = None, default: Any = missing): if isinstance(key, type): if type_ is not None: raise TypeError('type_ cannot be specified if key is a type') diff --git a/mush/tests/test_runner.py b/mush/tests/test_runner.py index d0005cb..757d991 100644 --- a/mush/tests/test_runner.py +++ b/mush/tests/test_runner.py @@ -882,7 +882,6 @@ def __exit__(self, type, obj, tb): call.cm.exit(), ]) - def test_clone(self): m = Mock() class T1(object): pass From 1c3efd7e8eb1f33c1d8933eae1554b579e6938f1 Mon Sep 17 00:00:00 2001 From: Chris Withers Date: Tue, 12 May 2020 08:16:26 +0100 Subject: [PATCH 106/159] Remove Call requirements. Not sure tight coupling this entails is to be encouraged. --- mush/__init__.py | 3 +- mush/asyncio.py | 17 +------- mush/callpoints.py | 2 +- mush/requirements.py | 23 +---------- mush/tests/test_async_requirements.py | 59 --------------------------- mush/tests/test_requirements.py | 54 +----------------------- 6 files changed, 6 insertions(+), 152 deletions(-) delete mode 100644 mush/tests/test_async_requirements.py diff --git a/mush/__init__.py b/mush/__init__.py index c475076..87f9d00 100755 --- a/mush/__init__.py +++ b/mush/__init__.py @@ -5,12 +5,11 @@ from .extraction import extract_requires, extract_returns, update_wrapper from .markers import missing, nonblocking, blocking from .plug import Plug -from .requirements import Requirement, Value, Call, AnyOf, Like +from .requirements import Requirement, Value, AnyOf, Like from .runner import Runner, ContextError __all__ = [ 'AnyOf', - 'Call', 'Context', 'ContextError', 'Like', diff --git a/mush/asyncio.py b/mush/asyncio.py index f751832..bfee086 100644 --- a/mush/asyncio.py +++ b/mush/asyncio.py @@ -1,11 +1,9 @@ import asyncio from functools import partial -from types import FunctionType from typing import Callable from . import ( - Context as SyncContext, Runner as SyncRunner, Call as SyncCall, - missing, ResourceError, ContextError + Context as SyncContext, Runner as SyncRunner, ResourceError, ContextError ) from .declarations import RequiresType, ReturnsType from .extraction import default_requirement_type @@ -141,15 +139,4 @@ async def __call__(self, context: Context = None): return result -class Call(SyncCall): - - async def resolve(self, context: Context): - result = context.get(self.key, missing) - if result is missing: - result = await context.call(self.key) - if self.cache: - context.add(result, provides=self.key) - return result - - -__all__ = ['Context', 'Runner', 'Call'] +__all__ = ['Context', 'Runner'] diff --git a/mush/callpoints.py b/mush/callpoints.py index 269f16a..9ef55d7 100644 --- a/mush/callpoints.py +++ b/mush/callpoints.py @@ -5,7 +5,7 @@ requires_nothing, returns as returns_declaration, returns_nothing ) from .extraction import extract_requires, extract_returns -from .requirements import Call, name_or_repr +from .requirements import name_or_repr from .typing import Requires, Returns if TYPE_CHECKING: diff --git a/mush/requirements.py b/mush/requirements.py index e5a3467..3fcd934 100644 --- a/mush/requirements.py +++ b/mush/requirements.py @@ -1,5 +1,5 @@ from copy import copy -from typing import Any, Optional, List, TYPE_CHECKING, Callable +from typing import Any, Optional, List, TYPE_CHECKING from .typing import ResourceKey from .markers import missing, nonblocking @@ -154,27 +154,6 @@ def resolve(self, context: 'Context'): return context.get(self.key, self.default) -class Call(Requirement): - """ - A requirement that is resolved by calling something. - - If ``cache`` is ``True``, then the result of that call will be cached - for the duration of the context in which this requirement is resolved. - """ - - def __init__(self, obj: Callable, *, cache: bool = True): - super().__init__(obj) - self.cache: bool = cache - - def resolve(self, context): - result = context.get(self.key, missing) - if result is missing: - result = context.call(self.key) - if self.cache: - context.add(result, provides=self.key) - return result - - class AnyOf(Requirement): """ A requirement that is resolved by any of the specified keys. diff --git a/mush/tests/test_async_requirements.py b/mush/tests/test_async_requirements.py deleted file mode 100644 index 68bd3ad..0000000 --- a/mush/tests/test_async_requirements.py +++ /dev/null @@ -1,59 +0,0 @@ -import pytest -from testfixtures import compare - -from mush.asyncio import Context, Call - - -class TestCall: - - @pytest.mark.asyncio - async def test_resolve(self): - context = Context() - - called = [] - - async def foo(bar: str): - called.append(1) - return bar+'b' - - async def bob(x: str = Call(foo)): - return x+'c' - - context.add('a', provides='bar') - - compare(await context.call(bob), expected='abc') - compare(await context.call(bob), expected='abc') - compare(called, expected=[1]) - compare(context.get(foo), expected='ab') - - @pytest.mark.asyncio - async def test_resolve_without_caching(self): - context = Context() - - called = [] - - def foo(bar: str): - called.append(1) - return bar+'b' - - def bob(x: str = Call(foo, cache=False)): - return x+'c' - - context.add('a', provides='bar') - - compare(await context.call(bob), expected='abc') - compare(await context.call(bob), expected='abc') - compare(called, expected=[1, 1]) - compare(context.get(foo, default=None), expected=None) - - @pytest.mark.asyncio - async def test_parts_of_a_call(self): - context = Context() - - async def foo(): - return {'a': 'b'} - - async def bob(x: str = Call(foo)['a']): - return x+'c' - - compare(await context.call(bob), expected='bc') diff --git a/mush/tests/test_requirements.py b/mush/tests/test_requirements.py index c0bee99..4f84041 100644 --- a/mush/tests/test_requirements.py +++ b/mush/tests/test_requirements.py @@ -5,7 +5,7 @@ from testfixtures import compare, ShouldRaise from testfixtures.mock import Mock -from mush import Context, Call, Value, missing, requires, ResourceError +from mush import Context, Value, missing, requires, ResourceError from mush.requirements import Requirement, AttrOp, ItemOp, AnyOf, Like from .helpers import Type1 @@ -192,58 +192,6 @@ def test_passed_missing(self): expected=1) -class TestCall: - - def test_resolve(self): - context = Context() - - called = [] - - def foo(bar: str): - called.append(1) - return bar+'b' - - def bob(x: str = Call(foo)): - return x+'c' - - context.add('a', provides='bar') - - compare(context.call(bob), expected='abc') - compare(context.call(bob), expected='abc') - compare(called, expected=[1]) - compare(context.get(foo), expected='ab') - - def test_resolve_without_caching(self): - context = Context() - - called = [] - - def foo(bar: str): - called.append(1) - return bar+'b' - - def bob(x: str = Call(foo, cache=False)): - return x+'c' - - context.add('a', provides='bar') - - compare(context.call(bob), expected='abc') - compare(context.call(bob), expected='abc') - compare(called, expected=[1, 1]) - compare(context.get(foo, default=None), expected=None) - - def test_parts_of_a_call(self): - context = Context() - - def foo(): - return {'a': 'b'} - - def bob(x: str = Call(foo)['a']): - return x+'c' - - compare(context.call(bob), expected='bc') - - class TestAnyOf: def test_first(self): From a20374f74b25c00e48a49688bbee0ca25ef87394 Mon Sep 17 00:00:00 2001 From: Chris Withers Date: Fri, 5 Jun 2020 07:12:06 +0100 Subject: [PATCH 107/159] ignore ellipsis examples too --- .coveragerc | 1 + 1 file changed, 1 insertion(+) diff --git a/.coveragerc b/.coveragerc index 1afc40e..e7f3765 100644 --- a/.coveragerc +++ b/.coveragerc @@ -9,6 +9,7 @@ exclude_lines = # stuff that we don't worry about pass + ... __name__ == '__main__' # circular references needed for type checking: From 964e2088fbda9401c9e062dca50839f0fc6fe58d Mon Sep 17 00:00:00 2001 From: Chris Withers Date: Fri, 5 Jun 2020 07:12:57 +0100 Subject: [PATCH 108/159] checkpoint: just starting requirement extraction --- mush/context.py | 255 ++++++------ mush/requirements.py | 50 +-- mush/tests/test_context.py | 818 +++++++++++++++++++------------------ 3 files changed, 596 insertions(+), 527 deletions(-) diff --git a/mush/context.py b/mush/context.py index ebd038c..7061055 100644 --- a/mush/context.py +++ b/mush/context.py @@ -1,11 +1,11 @@ -from typing import Optional, Callable +from typing import Optional, Callable, Hashable, Type from .callpoints import CallPoint from .declarations import RequiresType, ReturnsType from .extraction import extract_requires, extract_returns, default_requirement_type from .markers import missing, Marker from .requirements import Requirement -from .typing import ResourceKey, ResourceValue, RequirementModifier +from .typing import ResourceValue, RequirementModifier NONE_TYPE = type(None) unspecified = Marker('unspecified') @@ -16,141 +16,166 @@ class ResourceError(Exception): An exception raised when there is a problem with a `ResourceKey`. """ - def __init__(self, message: str, key: ResourceKey, requirement: Requirement = None): + def __init__(self, message: str, type_: Type = None, identifier: Hashable = None): super().__init__(message) - #: The key for the problematic resource. - self.key: ResourceKey = key - #: The requirement that caused this exception. - self.requirement: Requirement = requirement + #: The type for the problematic resource. + self.type: Type = type_ + #: The identifier for the problematic resource. + self.identifier: Hashable = identifier + # #: The requirement that caused this exception. + # self.requirement: Requirement = requirement + + +class ResourceKey(tuple): + + @property + def type(self): + return self[0] + + @property + def identifier(self): + return self[1] + + def __repr__(self): + if self.type is None: + return repr(self.identifier) + elif self.identifier is None: + return repr(self.type) + return f'{self.type!r}, {self.identifier!r}' class Context: "Stores resources for a particular run." - _parent: 'Context' = None - point: CallPoint = None + # _parent: 'Context' = None + # point: CallPoint = None - def __init__(self, requirement_modifier: RequirementModifier = default_requirement_type): - self.requirement_modifier = requirement_modifier + def __init__(self): self._store = {} - self._requires_cache = {} - self._returns_cache = {} + self._seen_types = set() + self._seen_identifiers = set() + # self._requires_cache = {} + # self._returns_cache = {} def add(self, - resource: Optional[ResourceValue] = None, - provides: Optional[ResourceKey] = None): + resource: ResourceValue, + provides: Optional[Type] = missing, + identifier: Hashable = None): """ Add a resource to the context. Optionally specify what the resource provides. """ - if provides is None: + if provides is missing: provides = type(resource) - if provides is NONE_TYPE: - raise ValueError('Cannot add None to context') - if provides in self._store: - raise ResourceError(f'Context already contains {provides!r}', provides) - self._store[provides] = resource - - def remove(self, key: ResourceKey, *, strict: bool = True): - """ - Remove the specified resource key from the context. - - If ``strict``, then a :class:`ResourceError` will be raised if the - specified resource is not present in the context. - """ - if strict and key not in self._store: - raise ResourceError(f'Context does not contain {key!r}', key) - self._store.pop(key, None) - + to_add = [ResourceKey((provides, identifier))] + if identifier and provides: + to_add.append(ResourceKey((None, identifier))) + for key in to_add: + if key in self._store: + raise ResourceError(f'Context already contains {key!r}', *key) + self._store[key] = resource + + # def remove(self, key: ResourceKey, *, strict: bool = True): + # """ + # Remove the specified resource key from the context. + # + # If ``strict``, then a :class:`ResourceError` will be raised if the + # specified resource is not present in the context. + # """ + # if strict and key not in self._store: + # raise ResourceError(f'Context does not contain {key!r}', key) + # self._store.pop(key, None) + # def __repr__(self): bits = [] - for type, value in sorted(self._store.items(), key=lambda o: repr(o)): - bits.append('\n %r: %r' % (type, value)) + for key, value in sorted(self._store.items(), key=lambda o: repr(o)): + bits.append(f'\n {key!r}: {value!r}') if bits: bits.append('\n') - return '' % ''.join(bits) - - def _process(self, obj, result, returns): - if returns is None: - returns = self._returns_cache.get(obj) - if returns is None: - returns = extract_returns(obj, explicit=None) - self._returns_cache[obj] = returns - - for type, obj in returns.process(result): - self.add(obj, type) - - def extract(self, obj: Callable, requires: RequiresType = None, returns: ReturnsType = None): - result = self.call(obj, requires) - self._process(obj, result, returns) - return result - - def _resolve(self, obj, requires, args, kw, context): - - if requires is None: - requires = self._requires_cache.get(obj) - if requires is None: - requires = extract_requires(obj, - explicit=None, - modifier=self.requirement_modifier) - self._requires_cache[obj] = requires - - for requirement in requires: - o = yield requirement - - if o is not requirement.default: - for op in requirement.ops: - o = op(o) - if o is missing: - o = requirement.default - break - - if o is missing: - key = requirement.key - if isinstance(key, type) and issubclass(key, Context): - o = context - else: - raise ResourceError(f'No {requirement!r} in context', - key, requirement) - - if requirement.target is None: - args.append(o) - else: - kw[requirement.target] = o - - yield + return f"" + # + # def _process(self, obj, result, returns): + # if returns is None: + # returns = self._returns_cache.get(obj) + # if returns is None: + # returns = extract_returns(obj, explicit=None) + # self._returns_cache[obj] = returns + # + # for type, obj in returns.process(result): + # self.add(obj, type) + # + # def extract(self, obj: Callable, requires: RequiresType = None, returns: ReturnsType = None): + # result = self.call(obj, requires) + # self._process(obj, result, returns) + # return result + # + # def _resolve(self, obj, requires, args, kw, context): + # + # if requires is None: + # requires = self._requires_cache.get(obj) + # if requires is None: + # requires = extract_requires(obj, + # explicit=None, + # modifier=self.requirement_modifier) + # self._requires_cache[obj] = requires + # + # for requirement in requires: + # o = yield requirement + # + # if o is not requirement.default: + # for op in requirement.ops: + # o = op(o) + # if o is missing: + # o = requirement.default + # break + # + # if o is missing: + # key = requirement.key + # if isinstance(key, type) and issubclass(key, Context): + # o = context + # else: + # raise ResourceError(f'No {requirement!r} in context', + # key, requirement) + # + # if requirement.target is None: + # args.append(o) + # else: + # kw[requirement.target] = o + # + # yield def call(self, obj: Callable, requires: RequiresType = None): args = [] kw = {} - resolving = self._resolve(obj, requires, args, kw, self) - for requirement in resolving: - resolving.send(requirement.resolve(self)) - return obj(*args, **kw) + # resolving = self._resolve(obj, requires, args, kw, self) + # for requirement in resolving: + # resolving.send(requirement.resolve(self)) - def get(self, key: ResourceKey, default=unspecified): - context = self - - while context is not None: - value = context._store.get(key, missing) - if value is missing: - context = context._parent - else: - if context is not self: - self._store[key] = value - return value - - if default is unspecified: - raise ResourceError(f'No {key!r} in context', key) - - return default - - def nest(self, requirement_modifier: RequirementModifier = None): - if requirement_modifier is None: - requirement_modifier = self.requirement_modifier - nested = self.__class__(requirement_modifier) - nested._parent = self - nested._requires_cache = self._requires_cache - nested._returns_cache = self._returns_cache - return nested + return obj(*args, **kw) + # + # def get(self, key: ResourceKey, default=unspecified): + # context = self + # + # while context is not None: + # value = context._store.get(key, missing) + # if value is missing: + # context = context._parent + # else: + # if context is not self: + # self._store[key] = value + # return value + # + # if default is unspecified: + # raise ResourceError(f'No {key!r} in context', key) + # + # return default + # + # def nest(self, requirement_modifier: RequirementModifier = None): + # if requirement_modifier is None: + # requirement_modifier = self.requirement_modifier + # nested = self.__class__(requirement_modifier) + # nested._parent = self + # nested._requires_cache = self._requires_cache + # nested._returns_cache = self._returns_cache + # return nested diff --git a/mush/requirements.py b/mush/requirements.py index 3fcd934..cf20631 100644 --- a/mush/requirements.py +++ b/mush/requirements.py @@ -1,7 +1,7 @@ from copy import copy from typing import Any, Optional, List, TYPE_CHECKING -from .typing import ResourceKey +# from .typing import ResourceKey from .markers import missing, nonblocking if TYPE_CHECKING: @@ -47,24 +47,24 @@ class Requirement: The requirement for an individual parameter of a callable. """ - def __init__(self, - key: ResourceKey, - name: str = None, - type_: type = None, - default: Any = missing, - target: str = None): - #: The resource key needed for this parameter. - self.key: Optional[ResourceKey] = key - #: The name of this parameter in the callable's signature. - self.name: Optional[str] = name - #: The type required for this parameter. - self.type: Optional[type] = type_ - #: The default for this parameter, should the required resource be unavailable. - self.default: Optional[Any] = default - #: Any operations to be performed on the resource after it - #: has been obtained. - self.ops: List['Op'] = [] - self.target: Optional[str] = target + # def __init__(self, + # key: ResourceKey, + # name: str = None, + # type_: type = None, + # default: Any = missing, + # target: str = None): + # #: The resource key needed for this parameter. + # self.key: Optional[ResourceKey] = key + # #: The name of this parameter in the callable's signature. + # self.name: Optional[str] = name + # #: The type required for this parameter. + # self.type: Optional[type] = type_ + # #: The default for this parameter, should the required resource be unavailable. + # self.default: Optional[Any] = default + # #: Any operations to be performed on the resource after it + # #: has been obtained. + # self.ops: List['Op'] = [] + # self.target: Optional[str] = target @classmethod def make(cls, **attrs): @@ -142,12 +142,12 @@ class Value(Requirement): ever use this. """ - def __init__(self, key: ResourceKey = None, *, type_: type = None, default: Any = missing): - if isinstance(key, type): - if type_ is not None: - raise TypeError('type_ cannot be specified if key is a type') - type_ = key - super().__init__(key, type_=type_, default=default) + # def __init__(self, key: ResourceKey = None, *, type_: type = None, default: Any = missing): + # if isinstance(key, type): + # if type_ is not None: + # raise TypeError('type_ cannot be specified if key is a type') + # type_ = key + # super().__init__(key, type_=type_, default=default) @nonblocking def resolve(self, context: 'Context'): diff --git a/mush/tests/test_context.py b/mush/tests/test_context.py index b3cf3ef..88c7d8a 100644 --- a/mush/tests/test_context.py +++ b/mush/tests/test_context.py @@ -1,450 +1,494 @@ -from typing import Tuple, List -from unittest import TestCase - +# from typing import Tuple, List +# from testfixtures import ShouldRaise, compare -from testfixtures.mock import Mock - +# from testfixtures.mock import Mock +# from mush import ( Context, requires, returns, returns_mapping, Value, missing ) from mush.context import ResourceError -from mush.declarations import RequiresType, requires_nothing, returns_nothing -from mush.requirements import Requirement +# from mush.declarations import RequiresType, requires_nothing, returns_nothing +# from mush.requirements import Requirement from .helpers import TheType -class TestContext(TestCase): +class TestContext(object): - def test_simple(self): + def test_add_by_inferred_type(self): obj = TheType() context = Context() context.add(obj) - compare(context._store, expected={TheType: obj}) + compare(context._store, expected={(TheType, None): obj}) expected = ( ": \n" "}>" ) - self.assertEqual(repr(context), expected) - self.assertEqual(str(context), expected) + compare(expected, actual=repr(context)) + compare(expected, actual=str(context)) + + def test_add_by_identifier(self): + obj = TheType() + context = Context() + context.add(obj, identifier='my label') + + compare(context._store, expected={ + (TheType, 'my label'): obj, + (None, 'my label'): obj, + }) + expected = ("\n" + " , 'my label': \n" + "}>") + compare(expected, actual=repr(context)) + compare(expected, actual=str(context)) - def test_type_as_string(self): + def test_add_by_identifier_only(self): obj = TheType() context = Context() - context.add(obj, provides='my label') + context.add(obj, provides=None, identifier='my label') + compare(context._store, expected={(None, 'my label'): obj}) expected = ("\n" "}>") - compare(context._store, expected={'my label': obj}) - self.assertEqual(repr(context), expected) - self.assertEqual(str(context), expected) + compare(expected, actual=repr(context)) + compare(expected, actual=str(context)) def test_explicit_type(self): class T2(object): pass obj = TheType() context = Context() context.add(obj, provides=T2) - compare(context._store, expected={T2: obj}) + compare(context._store, expected={(T2, None): obj}) expected = ("\n" "}>") - compare(repr(context), expected) - compare(str(context), expected) - - def test_no_resolver_or_provides(self): - context = Context() - with ShouldRaise(ValueError('Cannot add None to context')): - context.add() - compare(context._store, expected={}) + compare(expected, actual=repr(context)) + compare(expected, actual=str(context)) - def test_clash(self): + def test_clash_just_type(self): obj1 = TheType() obj2 = TheType() context = Context() context.add(obj1, TheType) - with ShouldRaise(ResourceError('Context already contains '+repr(TheType), - key=TheType)): + with ShouldRaise(ResourceError(f'Context already contains {TheType!r}', + type_=TheType)): context.add(obj2, TheType) - def test_clash_string_type(self): + def test_clash_just_identifier(self): obj1 = TheType() obj2 = TheType() context = Context() - context.add(obj1, provides='my label') + context.add(obj1, provides=None, identifier='my label') with ShouldRaise(ResourceError("Context already contains 'my label'", - key='my label')): - context.add(obj2, provides='my label') + identifier='my label')): + context.add(obj2, provides=None, identifier='my label') - def test_add_none(self): + def test_clash_identifier_only_with_identifier_plus_type(self): + obj1 = TheType() + obj2 = TheType() context = Context() - with ShouldRaise(ValueError('Cannot add None to context')): - context.add(None, type(None)) + context.add(obj1, provides=None, identifier='my label') + with ShouldRaise(ResourceError("Context already contains 'my label'", + identifier='my label')): + context.add(obj2, identifier='my label') - def test_add_none_with_type(self): + def test_clash_identifier_plus_type_with_identifier_only(self): + obj1 = TheType() + obj2 = TheType() context = Context() - context.add(None, TheType) - compare(context._store, expected={TheType: None}) + context.add(obj1, identifier='my label') + with ShouldRaise(ResourceError("Context already contains 'my label'", + identifier='my label')): + context.add(obj2, provides=None, identifier='my label') - def test_call_basic(self): + def test_call_no_params(self): def foo(): return 'bar' context = Context() - result = context.call(foo, requires_nothing) - compare(result, 'bar') - - def test_call_requires_string(self): - def foo(obj): - return obj - context = Context() - context.add('bar', 'baz') - result = context.call(foo, requires('baz')) - compare(result, 'bar') - compare({'baz': 'bar'}, actual=context._store) - - def test_call_requires_type(self): - def foo(obj): - return obj - context = Context() - context.add('bar', TheType) - result = context.call(foo, requires(TheType)) - compare(result, 'bar') - compare({TheType: 'bar'}, actual=context._store) - - def test_call_requires_missing(self): - def foo(obj): return obj - context = Context() - with ShouldRaise(ResourceError( - "No Value(TheType) in context", - key=TheType, - requirement=Value(TheType), - )): - context.call(foo, requires(TheType)) - - def test_call_requires_item_missing(self): - def foo(obj): return obj - context = Context() - context.add({}, TheType) - with ShouldRaise(ResourceError( - "No Value(TheType)['foo'] in context", - key=TheType, - requirement=Value(TheType)['foo'], - )): - context.call(foo, requires(Value(TheType)['foo'])) - - def test_call_requires_accidental_tuple(self): - def foo(obj): return obj - context = Context() - with ShouldRaise(TypeError( - "(, " - ") " - "is not a valid decoration type" - )): - context.call(foo, requires((TheType, TheType))) - - def test_call_requires_named_parameter(self): - def foo(x, y): - return x, y - context = Context() - context.add('foo', TheType) - context.add('bar', 'baz') - result = context.call(foo, requires(y='baz', x=TheType)) - compare(result, ('foo', 'bar')) - compare({TheType: 'foo', - 'baz': 'bar'}, - actual=context._store) - - def test_call_requires_optional_present(self): - def foo(x=1): - return x - context = Context() - context.add(2, TheType) - result = context.call(foo, requires(TheType)) - compare(result, 2) - compare({TheType: 2}, actual=context._store) - - def test_call_requires_optional_missing(self): - def foo(x: TheType = 1): - return x - context = Context() - result = context.call(foo) - compare(result, 1) - - def test_call_requires_optional_override_source_and_default(self): - def foo(x=1): - return x - context = Context() - context.add(2, provides='x') - result = context.call(foo, requires(x=Value('y', default=3))) - compare(result, expected=3) - - def test_call_requires_optional_string(self): - def foo(x:'foo'=1): - return x - context = Context() - context.add(2, 'foo') - result = context.call(foo) - compare(result, 2) - compare({'foo': 2}, actual=context._store) - - def test_call_requires_item(self): - def foo(x): - return x - context = Context() - context.add(dict(bar='baz'), 'foo') - result = context.call(foo, requires(Value('foo')['bar'])) - compare(result, 'baz') - - def test_call_requires_attr(self): - def foo(x): - return x - m = Mock() - context = Context() - context.add(m, 'foo') - result = context.call(foo, requires(Value('foo').bar)) - compare(result, m.bar) - - def test_call_requires_item_attr(self): - def foo(x): - return x - m = Mock() - m.bar= dict(baz='bob') - context = Context() - context.add(m, provides='foo') - result = context.call(foo, requires(Value('foo').bar['baz'])) - compare(result, 'bob') - - def test_call_requires_optional_item_missing(self): - def foo(x: str = Value('foo', default=1)['bar']): - return x - context = Context() result = context.call(foo) - compare(result, 1) + compare(result, 'bar') - def test_call_requires_optional_item_present(self): - def foo(x: str = Value('foo', default=1)['bar']): - return x + def test_call_type_from_annotation(self): + def foo(baz: str): + return baz context = Context() - context.add(dict(bar='baz'), provides='foo') + context.add('bar') result = context.call(foo) - compare(result, 'baz') + compare(result, expected='bar') - def test_call_extract_requirements(self): - def foo(param): - return param + def test_call_identifier_from_annotation(self): + def foo(baz: str): + return baz context = Context() - context.add('bar', 'param') + context.add('bar', provides=str) + context.add('bob', identifier='baz') result = context.call(foo) - compare(result, 'bar') + compare(result, expected='bob') - def test_call_extract_no_requirements(self): - def foo(): - pass + def test_call_identifier_and_type_from_annotation(self): + def foo(baz: str): + return baz context = Context() + context.add('bar', provides=str) + context.add('bob', identifier='baz') + context.add('foo', provides=str, identifier='baz') result = context.call(foo) - compare(result, expected=None) - - def test_call_caches_requires(self): - context = Context() - def foo(): pass - context.call(foo) - compare(context._requires_cache[foo], expected=RequiresType()) - - def test_call_explict_explicit_requires_no_cache(self): - context = Context() - context.add('a') - def foo(*args): - return args - result = context.call(foo, requires(str)) - compare(result, ('a',)) - compare(context._requires_cache, expected={}) - - def test_extract_minimal(self): - o = TheType() - def foo() -> TheType: - return o - context = Context() - result = context.extract(foo) - assert result is o - compare({TheType: o}, actual=context._store) - compare(context._requires_cache[foo], expected=RequiresType()) - compare(context._returns_cache[foo], expected=returns(TheType)) - - def test_extract_maximal(self): - def foo(*args): - return args - context = Context() - context.add('a') - result = context.extract(foo, requires(str), returns(Tuple[str])) - compare(result, expected=('a',)) - compare({ - str: 'a', - Tuple[str]: ('a',), - }, actual=context._store) - compare(context._requires_cache, expected={}) - compare(context._returns_cache, expected={}) - - def test_returns_single(self): - def foo(): - return 'bar' - context = Context() - result = context.extract(foo, requires_nothing, returns(TheType)) - compare(result, 'bar') - compare({TheType: 'bar'}, actual=context._store) - - def test_returns_sequence(self): - def foo(): - return 1, 2 - context = Context() - result = context.extract(foo, requires_nothing, returns('foo', 'bar')) - compare(result, (1, 2)) - compare({'foo': 1, 'bar': 2}, - actual=context._store) - - def test_returns_mapping(self): - def foo(): - return {'foo': 1, 'bar': 2} - context = Context() - result = context.extract(foo, requires_nothing, returns_mapping()) - compare(result, {'foo': 1, 'bar': 2}) - compare({'foo': 1, 'bar': 2}, - actual=context._store) - - def test_ignore_return(self): - def foo(): - return 'bar' - context = Context() - result = context.extract(foo, requires_nothing, returns_nothing) - compare(result, 'bar') - compare({}, context._store) - - def test_ignore_non_iterable_return(self): - def foo(): pass - context = Context() - result = context.extract(foo) - compare(result, expected=None) - compare(context._store, expected={}) - - def test_context_contains_itself(self): - context = Context() - def return_context(context: Context): - return context - assert context.call(return_context) is context - - def test_remove(self): - context = Context() - context.add('foo') - context.remove(str) - compare(context._store, expected={}) - - def test_remove_not_there_strict(self): - context = Context() - with ShouldRaise(ResourceError("Context does not contain 'foo'", - key='foo')): - context.remove('foo') - compare(context._store, expected={}) - - def test_remove_not_there_not_strict(self): - context = Context() - context.remove('foo', strict=False) - compare(context._store, expected={}) - - def test_get_present(self): - context = Context() - context.add('bar', provides='foo') - compare(context.get('foo'), expected='bar') - - def test_get_type(self): - context = Context() - context.add(['bar'], provides=List[str]) - compare(context.get(List[str]), expected=['bar']) - compare(context.get(List[int], default=None), expected=None) - compare(context.get(List, default=None), expected=None) - # nb: this might be surprising: - compare(context.get(list, default=None), expected=None) - - def test_get_missing(self): - context = Context() - with ShouldRaise(ResourceError("No 'foo' in context", 'foo')): - context.get('foo') - - def test_nest(self): - c1 = Context() - c1.add('a', provides='a') - c1.add('c', provides='c') - c2 = c1.nest() - c2.add('b', provides='b') - c2.add('d', provides='c') - compare(c2.get('a'), expected='a') - compare(c2.get('b'), expected='b') - compare(c2.get('c'), expected='d') - compare(c1.get('a'), expected='a') - compare(c1.get('b', default=None), expected=None) - compare(c1.get('c'), expected='c') - - def test_nest_with_overridden_default_requirement_type(self): - def modifier(): pass - c1 = Context(modifier) - c2 = c1.nest() - assert c2.requirement_modifier is modifier - - def test_nest_with_explicit_default_requirement_type(self): - def modifier1(): pass - def modifier2(): pass - c1 = Context(modifier1) - c2 = c1.nest(modifier2) - assert c2.requirement_modifier is modifier2 - - def test_nest_keeps_declarations_cache(self): - c1 = Context() - c2 = c1.nest() - assert c2._requires_cache is c1._requires_cache - assert c2._returns_cache is c1._returns_cache - - def test_custom_requirement(self): - - class FromRequest(Requirement): - def resolve(self, context): - return context.get('request')[self.key] - - def foo(bar: FromRequest('bar')): - return bar - - context = Context() - context.add({'bar': 'foo'}, provides='request') - compare(context.call(foo), expected='foo') - - def test_custom_requirement_returns_missing(self): - - class FromRequest(Requirement): - def resolve(self, context): - return context.get('request').get(self.key, missing) - - def foo(bar: FromRequest('bar')): - pass - - context = Context() - context.add({}, provides='request') - with ShouldRaise(ResourceError("No FromRequest('bar') in context", - key='bar', - requirement=FromRequest.make(key='bar', name='bar'))): - compare(context.call(foo)) - - def test_default_custom_requirement(self): - - class FromRequest(Requirement): - def resolve(self, context): - return context.get('request')[self.key] - - def foo(bar): - return bar - - def modifier(requirement): - if type(requirement) is Requirement: - requirement = FromRequest.make_from(requirement) - return requirement - - context = Context(requirement_modifier=modifier) - context.add({'bar': 'foo'}, provides='request') - compare(context.call(foo), expected='foo') + compare(result, expected='foo') + + # def test_call_requires_string(self): + # def foo(obj): + # return obj + # context = Context() + # context.add('bar', identifier='baz') + # result = context.call(foo, requires('baz')) + # compare(result, expected='bar') + # compare({'baz': 'bar'}, actual=context._store) + +# def test_call_requires_type(self): +# def foo(obj): +# return obj +# context = Context() +# context.add('bar', TheType) +# result = context.call(foo, requires(TheType)) +# compare(result, 'bar') +# compare({TheType: 'bar'}, actual=context._store) +# +# def test_call_requires_missing(self): +# def foo(obj): return obj +# context = Context() +# with ShouldRaise(ResourceError( +# "No Value(TheType) in context", +# key=TheType, +# requirement=Value(TheType), +# )): +# context.call(foo, requires(TheType)) +# +# def test_call_requires_item_missing(self): +# def foo(obj): return obj +# context = Context() +# context.add({}, TheType) +# with ShouldRaise(ResourceError( +# "No Value(TheType)['foo'] in context", +# key=TheType, +# requirement=Value(TheType)['foo'], +# )): +# context.call(foo, requires(Value(TheType)['foo'])) +# +# def test_call_requires_accidental_tuple(self): +# def foo(obj): return obj +# context = Context() +# with ShouldRaise(TypeError( +# "(, " +# ") " +# "is not a valid decoration type" +# )): +# context.call(foo, requires((TheType, TheType))) +# +# def test_call_requires_named_parameter(self): +# def foo(x, y): +# return x, y +# context = Context() +# context.add('foo', TheType) +# context.add('bar', 'baz') +# result = context.call(foo, requires(y='baz', x=TheType)) +# compare(result, ('foo', 'bar')) +# compare({TheType: 'foo', +# 'baz': 'bar'}, +# actual=context._store) +# +# def test_call_requires_optional_present(self): +# def foo(x=1): +# return x +# context = Context() +# context.add(2, TheType) +# result = context.call(foo, requires(TheType)) +# compare(result, 2) +# compare({TheType: 2}, actual=context._store) +# +# def test_call_requires_optional_missing(self): +# def foo(x: TheType = 1): +# return x +# context = Context() +# result = context.call(foo) +# compare(result, 1) +# +# def test_call_requires_optional_override_source_and_default(self): +# def foo(x=1): +# return x +# context = Context() +# context.add(2, provides='x') +# result = context.call(foo, requires(x=Value('y', default=3))) +# compare(result, expected=3) +# +# def test_call_requires_optional_string(self): +# def foo(x:'foo'=1): +# return x +# context = Context() +# context.add(2, 'foo') +# result = context.call(foo) +# compare(result, 2) +# compare({'foo': 2}, actual=context._store) +# +# def test_call_requires_item(self): +# def foo(x): +# return x +# context = Context() +# context.add(dict(bar='baz'), 'foo') +# result = context.call(foo, requires(Value('foo')['bar'])) +# compare(result, 'baz') +# +# def test_call_requires_attr(self): +# def foo(x): +# return x +# m = Mock() +# context = Context() +# context.add(m, 'foo') +# result = context.call(foo, requires(Value('foo').bar)) +# compare(result, m.bar) +# +# def test_call_requires_item_attr(self): +# def foo(x): +# return x +# m = Mock() +# m.bar= dict(baz='bob') +# context = Context() +# context.add(m, provides='foo') +# result = context.call(foo, requires(Value('foo').bar['baz'])) +# compare(result, 'bob') +# +# def test_call_requires_optional_item_missing(self): +# def foo(x: str = Value('foo', default=1)['bar']): +# return x +# context = Context() +# result = context.call(foo) +# compare(result, 1) +# +# def test_call_requires_optional_item_present(self): +# def foo(x: str = Value('foo', default=1)['bar']): +# return x +# context = Context() +# context.add(dict(bar='baz'), provides='foo') +# result = context.call(foo) +# compare(result, 'baz') +# +# def test_call_extract_requirements(self): +# def foo(param): +# return param +# context = Context() +# context.add('bar', 'param') +# result = context.call(foo) +# compare(result, 'bar') +# +# def test_call_extract_no_requirements(self): +# def foo(): +# pass +# context = Context() +# result = context.call(foo) +# compare(result, expected=None) +# +# def test_call_caches_requires(self): +# context = Context() +# def foo(): pass +# context.call(foo) +# compare(context._requires_cache[foo], expected=RequiresType()) +# +# def test_call_explict_explicit_requires_no_cache(self): +# context = Context() +# context.add('a') +# def foo(*args): +# return args +# result = context.call(foo, requires(str)) +# compare(result, ('a',)) +# compare(context._requires_cache, expected={}) +# +# def test_extract_minimal(self): +# o = TheType() +# def foo() -> TheType: +# return o +# context = Context() +# result = context.extract(foo) +# assert result is o +# compare({TheType: o}, actual=context._store) +# compare(context._requires_cache[foo], expected=RequiresType()) +# compare(context._returns_cache[foo], expected=returns(TheType)) +# +# def test_extract_maximal(self): +# def foo(*args): +# return args +# context = Context() +# context.add('a') +# result = context.extract(foo, requires(str), returns(Tuple[str])) +# compare(result, expected=('a',)) +# compare({ +# str: 'a', +# Tuple[str]: ('a',), +# }, actual=context._store) +# compare(context._requires_cache, expected={}) +# compare(context._returns_cache, expected={}) +# +# def test_returns_single(self): +# def foo(): +# return 'bar' +# context = Context() +# result = context.extract(foo, requires_nothing, returns(TheType)) +# compare(result, 'bar') +# compare({TheType: 'bar'}, actual=context._store) +# +# def test_returns_sequence(self): +# def foo(): +# return 1, 2 +# context = Context() +# result = context.extract(foo, requires_nothing, returns('foo', 'bar')) +# compare(result, (1, 2)) +# compare({'foo': 1, 'bar': 2}, +# actual=context._store) +# +# def test_returns_mapping(self): +# def foo(): +# return {'foo': 1, 'bar': 2} +# context = Context() +# result = context.extract(foo, requires_nothing, returns_mapping()) +# compare(result, {'foo': 1, 'bar': 2}) +# compare({'foo': 1, 'bar': 2}, +# actual=context._store) +# +# def test_ignore_return(self): +# def foo(): +# return 'bar' +# context = Context() +# result = context.extract(foo, requires_nothing, returns_nothing) +# compare(result, 'bar') +# compare({}, context._store) +# +# def test_ignore_non_iterable_return(self): +# def foo(): pass +# context = Context() +# result = context.extract(foo) +# compare(result, expected=None) +# compare(context._store, expected={}) +# +# def test_context_contains_itself(self): +# context = Context() +# def return_context(context: Context): +# return context +# assert context.call(return_context) is context +# +# def test_remove(self): +# context = Context() +# context.add('foo') +# context.remove(str) +# compare(context._store, expected={}) +# +# def test_remove_not_there_strict(self): +# context = Context() +# with ShouldRaise(ResourceError("Context does not contain 'foo'", +# key='foo')): +# context.remove('foo') +# compare(context._store, expected={}) +# +# def test_remove_not_there_not_strict(self): +# context = Context() +# context.remove('foo', strict=False) +# compare(context._store, expected={}) +# +# def test_get_present(self): +# context = Context() +# context.add('bar', provides='foo') +# compare(context.get('foo'), expected='bar') +# +# def test_get_type(self): +# context = Context() +# context.add(['bar'], provides=List[str]) +# compare(context.get(List[str]), expected=['bar']) +# compare(context.get(List[int], default=None), expected=None) +# compare(context.get(List, default=None), expected=None) +# # nb: this might be surprising: +# compare(context.get(list, default=None), expected=None) +# +# def test_get_missing(self): +# context = Context() +# with ShouldRaise(ResourceError("No 'foo' in context", 'foo')): +# context.get('foo') +# +# def test_nest(self): +# c1 = Context() +# c1.add('a', provides='a') +# c1.add('c', provides='c') +# c2 = c1.nest() +# c2.add('b', provides='b') +# c2.add('d', provides='c') +# compare(c2.get('a'), expected='a') +# compare(c2.get('b'), expected='b') +# compare(c2.get('c'), expected='d') +# compare(c1.get('a'), expected='a') +# compare(c1.get('b', default=None), expected=None) +# compare(c1.get('c'), expected='c') +# +# def test_nest_with_overridden_default_requirement_type(self): +# def modifier(): pass +# c1 = Context(modifier) +# c2 = c1.nest() +# assert c2.requirement_modifier is modifier +# +# def test_nest_with_explicit_default_requirement_type(self): +# def modifier1(): pass +# def modifier2(): pass +# c1 = Context(modifier1) +# c2 = c1.nest(modifier2) +# assert c2.requirement_modifier is modifier2 +# +# def test_nest_keeps_declarations_cache(self): +# c1 = Context() +# c2 = c1.nest() +# assert c2._requires_cache is c1._requires_cache +# assert c2._returns_cache is c1._returns_cache +# +# def test_custom_requirement(self): +# +# class FromRequest(Requirement): +# def resolve(self, context): +# return context.get('request')[self.key] +# +# def foo(bar: FromRequest('bar')): +# return bar +# +# context = Context() +# context.add({'bar': 'foo'}, provides='request') +# compare(context.call(foo), expected='foo') +# +# def test_custom_requirement_returns_missing(self): +# +# class FromRequest(Requirement): +# def resolve(self, context): +# return context.get('request').get(self.key, missing) +# +# def foo(bar: FromRequest('bar')): +# pass +# +# context = Context() +# context.add({}, provides='request') +# with ShouldRaise(ResourceError("No FromRequest('bar') in context", +# key='bar', +# requirement=FromRequest.make(key='bar', name='bar'))): +# compare(context.call(foo)) +# +# def test_default_custom_requirement(self): +# +# class FromRequest(Requirement): +# def resolve(self, context): +# return context.get('request')[self.key] +# +# def foo(bar): +# return bar +# +# def modifier(requirement): +# if type(requirement) is Requirement: +# requirement = FromRequest.make_from(requirement) +# return requirement +# +# context = Context(requirement_modifier=modifier) +# context.add({'bar': 'foo'}, provides='request') +# compare(context.call(foo), expected='foo') From be4ca8fef609642a482c7791e541883ea1e53762 Mon Sep 17 00:00:00 2001 From: Chris Withers Date: Thu, 17 Sep 2020 08:39:23 +0100 Subject: [PATCH 109/159] simple requirements and resources --- mush/__init__.py | 2 +- mush/context.py | 133 ++++++-------- mush/extraction.py | 170 +++++++++--------- mush/requirements.py | 222 +++++++++-------------- mush/resources.py | 32 ++++ mush/runner.py | 6 +- mush/tests/test_context.py | 199 +++++++++++---------- mush/tests/test_requirements.py | 302 ++++++++++++++++---------------- mush/typing.py | 3 +- 9 files changed, 521 insertions(+), 548 deletions(-) create mode 100644 mush/resources.py diff --git a/mush/__init__.py b/mush/__init__.py index 87f9d00..985ded1 100755 --- a/mush/__init__.py +++ b/mush/__init__.py @@ -5,7 +5,7 @@ from .extraction import extract_requires, extract_returns, update_wrapper from .markers import missing, nonblocking, blocking from .plug import Plug -from .requirements import Requirement, Value, AnyOf, Like +from .requirements import Requirement, Value#, AnyOf, Like from .runner import Runner, ContextError __all__ = [ diff --git a/mush/context.py b/mush/context.py index 7061055..baf90b0 100644 --- a/mush/context.py +++ b/mush/context.py @@ -1,11 +1,11 @@ -from typing import Optional, Callable, Hashable, Type +from typing import Optional, Callable, Hashable, Type, Sequence -from .callpoints import CallPoint -from .declarations import RequiresType, ReturnsType -from .extraction import extract_requires, extract_returns, default_requirement_type +from .declarations import RequiresType +from .extraction import extract_requires from .markers import missing, Marker from .requirements import Requirement -from .typing import ResourceValue, RequirementModifier +from .resources import ResourceKey, Resource +from .typing import ResourceValue NONE_TYPE = type(None) unspecified = Marker('unspecified') @@ -13,36 +13,9 @@ class ResourceError(Exception): """ - An exception raised when there is a problem with a `ResourceKey`. + An exception raised when there is a problem with a resource. """ - def __init__(self, message: str, type_: Type = None, identifier: Hashable = None): - super().__init__(message) - #: The type for the problematic resource. - self.type: Type = type_ - #: The identifier for the problematic resource. - self.identifier: Hashable = identifier - # #: The requirement that caused this exception. - # self.requirement: Requirement = requirement - - -class ResourceKey(tuple): - - @property - def type(self): - return self[0] - - @property - def identifier(self): - return self[1] - - def __repr__(self): - if self.type is None: - return repr(self.identifier) - elif self.identifier is None: - return repr(self.type) - return f'{self.type!r}, {self.identifier!r}' - class Context: "Stores resources for a particular run." @@ -54,7 +27,7 @@ def __init__(self): self._store = {} self._seen_types = set() self._seen_identifiers = set() - # self._requires_cache = {} + self._requires_cache = {} # self._returns_cache = {} def add(self, @@ -68,13 +41,13 @@ def add(self, """ if provides is missing: provides = type(resource) - to_add = [ResourceKey((provides, identifier))] + to_add = [ResourceKey(provides, identifier)] if identifier and provides: - to_add.append(ResourceKey((None, identifier))) + to_add.append(ResourceKey(None, identifier)) for key in to_add: if key in self._store: - raise ResourceError(f'Context already contains {key!r}', *key) - self._store[key] = resource + raise ResourceError(f'Context already contains {key}') + self._store[key] = Resource(resource) # def remove(self, key: ResourceKey, *, strict: bool = True): # """ @@ -90,7 +63,7 @@ def add(self, def __repr__(self): bits = [] for key, value in sorted(self._store.items(), key=lambda o: repr(o)): - bits.append(f'\n {key!r}: {value!r}') + bits.append(f'\n {key}: {value!r}') if bits: bits.append('\n') return f"" @@ -109,46 +82,56 @@ def __repr__(self): # result = self.call(obj, requires) # self._process(obj, result, returns) # return result - # - # def _resolve(self, obj, requires, args, kw, context): - # - # if requires is None: - # requires = self._requires_cache.get(obj) - # if requires is None: - # requires = extract_requires(obj, - # explicit=None, - # modifier=self.requirement_modifier) - # self._requires_cache[obj] = requires - # - # for requirement in requires: - # o = yield requirement - # - # if o is not requirement.default: - # for op in requirement.ops: - # o = op(o) - # if o is missing: - # o = requirement.default - # break - # - # if o is missing: - # key = requirement.key - # if isinstance(key, type) and issubclass(key, Context): - # o = context - # else: - # raise ResourceError(f'No {requirement!r} in context', - # key, requirement) - # - # if requirement.target is None: - # args.append(o) - # else: - # kw[requirement.target] = o - # - # yield + + def _resolve(self, obj, requires, args, kw, context): + + if requires is None: + requires = self._requires_cache.get(obj) + if requires is None: + requires = extract_requires(obj) + self._requires_cache[obj] = requires + + specials = {Context: self} + + for requirement in requires: + + o = missing + + for key in requirement.keys: + # how to handle context and requirement here?! + resource = self._store.get(key) + if resource is None: + o = specials.get(key[0], missing) + else: + o = resource.obj + if o is not missing: + break + + if o is missing: + o = requirement.default + + # if o is not requirement.default: + # for op in requirement.ops: + # o = op(o) + # if o is missing: + # o = requirement.default + # break + + if o is missing: + raise ResourceError(f'{requirement!r} could not be satisfied') + + # if requirement.target is None: + args.append(o) + # else: + # kw[requirement.target] = o + # + # yield def call(self, obj: Callable, requires: RequiresType = None): args = [] kw = {} - # resolving = self._resolve(obj, requires, args, kw, self) + + self._resolve(obj, requires, args, kw, self) # for requirement in resolving: # resolving.send(requirement.resolve(self)) diff --git a/mush/extraction.py b/mush/extraction.py index 0638358..336ffb8 100644 --- a/mush/extraction.py +++ b/mush/extraction.py @@ -5,16 +5,16 @@ partial ) from inspect import signature, Parameter -from typing import Callable +from typing import Callable, Iterable from .declarations import ( requires, RequiresType, ReturnsType, returns, result_type, requires_nothing ) -from .requirements import Requirement, Value +from .requirements import Value, Requirement from .markers import missing, get_mush -from .typing import RequirementModifier, Requires, Returns +from .typing import Requires, Returns EMPTY = Parameter.empty #: For these types, prefer the name instead of the type. @@ -47,104 +47,100 @@ def _apply_requires(by_name, by_index, requires_): ) -def default_requirement_type(requirement): - if type(requirement) is Requirement: - requirement = Value.make_from(requirement) - return requirement - - -def extract_requires(obj: Callable, - explicit: Requires = None, - modifier: RequirementModifier = default_requirement_type): +def extract_requires(obj: Callable) -> Iterable[Requirement]: + # explicit: Requires = None): # from annotations by_name = {} for name, p in signature(obj).parameters.items(): if p.kind in (p.VAR_POSITIONAL, p.VAR_KEYWORD): continue - # https://bugs.python.org/issue39753: - if isinstance(obj, partial) and p.name in obj.keywords: - continue - + # # https://bugs.python.org/issue39753: + # if isinstance(obj, partial) and p.name in obj.keywords: + # continue + # name = p.name - if isinstance(p.annotation, type) and not p.annotation is EMPTY: + if isinstance(p.annotation, type) and p.annotation is not EMPTY: type_ = p.annotation else: type_ = None - key = None + default = missing if p.default is EMPTY else p.default ops = [] - requirement = None - if isinstance(default, Requirement): - requirement = default - default = missing - elif isinstance(p.annotation, Requirement): - requirement = p.annotation - - if requirement is None: - requirement = Requirement(key) - if isinstance(p.annotation, str): - key = p.annotation - elif type_ is None or issubclass(type_, SIMPLE_TYPES): - key = name - else: - key = type_ - else: - requirement = requirement.make_from(requirement) - type_ = type_ if requirement.type is None else requirement.type - if requirement.key is not None: - key = requirement.key - elif type_ is None or issubclass(type_, SIMPLE_TYPES): - key = name - else: - key = type_ - default = requirement.default if requirement.default is not missing else default - ops = requirement.ops - - requirement.key = key - requirement.name = name - requirement.type = type_ - requirement.default = default - requirement.ops = ops - - if p.kind is p.KEYWORD_ONLY: - requirement.target = p.name - + requirement = Value(type_, p.name, default) + # + # requirement = None + # if isinstance(default, Requirement): + # requirement = default + # default = missing + # elif isinstance(p.annotation, Requirement): + # requirement = p.annotation + # + # if requirement is None: + # requirement = Requirement(key) + # if isinstance(p.annotation, str): + # key = p.annotation + # elif type_ is None or issubclass(type_, SIMPLE_TYPES): + # key = name + # else: + # key = type_ + # else: + # requirement = requirement.make_from(requirement) + # type_ = type_ if requirement.type is None else requirement.type + # if requirement.key is not None: + # key = requirement.key + # elif type_ is None or issubclass(type_, SIMPLE_TYPES): + # key = name + # else: + # key = type_ + # default = requirement.default if requirement.default is not missing else default + # ops = requirement.ops + # + # requirement.key = key + # requirement.name = name + # requirement.type = type_ + # requirement.default = default + # requirement.ops = ops + # + # if p.kind is p.KEYWORD_ONLY: + # requirement.target = p.name + # by_name[name] = requirement - - by_index = list(by_name) - - # from declarations - mush_requires = get_mush(obj, 'requires', None) - if mush_requires is not None: - _apply_requires(by_name, by_index, mush_requires) - - # explicit - if explicit is not None: - if isinstance(explicit, RequiresType): - requires_ = explicit - else: - if not isinstance(explicit, (list, tuple)): - explicit = (explicit,) - requires_ = requires(*explicit) - _apply_requires(by_name, by_index, requires_) - - if not by_name: - return requires_nothing - - # sort out target and apply modifier: - needs_target = False - for name, requirement in by_name.items(): - requirement_ = modifier(requirement) - if requirement_ is not requirement: - by_name[name] = requirement = requirement_ - if requirement.target is not None: - needs_target = True - elif needs_target: - requirement.target = requirement.name - - return RequiresType(by_name.values()) + # + # by_index = list(by_name) + # + # # from declarations + # mush_requires = get_mush(obj, 'requires', None) + # if mush_requires is not None: + # _apply_requires(by_name, by_index, mush_requires) + # + # # explicit + # if explicit is not None: + # if isinstance(explicit, RequiresType): + # requires_ = explicit + # else: + # if not isinstance(explicit, (list, tuple)): + # explicit = (explicit,) + # requires_ = requires(*explicit) + # _apply_requires(by_name, by_index, requires_) + # + # if not by_name: + # return requires_nothing + # + # # sort out target and apply modifier: + # needs_target = False + # for name, requirement in by_name.items(): + # requirement_ = modifier(requirement) + # if requirement_ is not requirement: + # by_name[name] = requirement = requirement_ + # if requirement.target is not None: + # needs_target = True + # elif needs_target: + # requirement.target = requirement.name + # + return by_name.values() + # return RequiresType(by_name.values()) def extract_returns(obj: Callable, explicit: Returns = None): diff --git a/mush/requirements.py b/mush/requirements.py index cf20631..fc73451 100644 --- a/mush/requirements.py +++ b/mush/requirements.py @@ -1,8 +1,7 @@ -from copy import copy -from typing import Any, Optional, List, TYPE_CHECKING +from typing import Any, List, TYPE_CHECKING, Hashable, Sequence -# from .typing import ResourceKey from .markers import missing, nonblocking +from .resources import ResourceKey if TYPE_CHECKING: from .context import Context @@ -47,85 +46,37 @@ class Requirement: The requirement for an individual parameter of a callable. """ - # def __init__(self, - # key: ResourceKey, - # name: str = None, - # type_: type = None, - # default: Any = missing, - # target: str = None): - # #: The resource key needed for this parameter. - # self.key: Optional[ResourceKey] = key - # #: The name of this parameter in the callable's signature. - # self.name: Optional[str] = name - # #: The type required for this parameter. - # self.type: Optional[type] = type_ - # #: The default for this parameter, should the required resource be unavailable. - # self.default: Optional[Any] = default - # #: Any operations to be performed on the resource after it - # #: has been obtained. - # self.ops: List['Op'] = [] + def __init__(self, default: Any, *keys: Sequence[ResourceKey]): + self.keys = keys + self.default = default + self.ops: List['Op'] = [] # self.target: Optional[str] = target - @classmethod - def make(cls, **attrs): - """ - Make a :class:`Requirement` instance with all attributes provided. - - .. note:: - - This method allows instances to be created with missing or invalid attributes. - While this is necessary for use cases such as testing :class:`Requirement` - instantiation or otherwise setting attributes that are not accessible from - a custom requirement's :meth:`__init__`, it should be used with care. - - :param attrs: - :return: - """ - obj = Requirement(attrs.pop('key')) - obj.__class__ = cls - obj.__dict__.update(attrs) - return obj - - @classmethod - def make_from(cls, source: 'Requirement', **attrs): - """ - Make a new instance of this requirement class, using attributes - from a source requirement overlaid with any additional - ``attrs`` that have been supplied. - """ - attrs_ = source.__dict__.copy() - attrs_.update(attrs) - obj = cls.make(**attrs_) - obj.ops = list(source.ops) - obj.default = copy(source.default) - return obj - - def resolve(self, context: 'Context'): - raise NotImplementedError() + def _keys_repr(self): + return ', '.join(repr(key) for key in self.keys) def __repr__(self): - key = name_or_repr(self.key) default = '' if self.default is missing else f', default={self.default!r}' ops = ''.join(repr(o) for o in self.ops) - return f"{type(self).__name__}({key}{default}){ops}" - - def attr(self, name): - """ - If you need to get an attribute called either ``attr`` or ``item`` - then you will need to call this method instead of using the - generating behaviour. - """ - self.ops.append(AttrOp(name)) - return self - - def __getattr__(self, name): - if name.startswith('__'): - raise AttributeError(name) - return self.attr(name) - - def __getitem__(self, name): - self.ops.append(ItemOp(name)) - return self + return f"{type(self).__name__}({self._keys_repr()}{default}){ops}" + # + # def attr(self, name): + # """ + # If you need to get an attribute called either ``attr`` or ``item`` + # then you will need to call this method instead of using the + # generating behaviour. + # """ + # self.ops.append(AttrOp(name)) + # return self + # + # def __getattr__(self, name): + # if name.startswith('__'): + # raise AttributeError(name) + # return self.attr(name) + # + # def __getitem__(self, name): + # self.ops.append(ItemOp(name)) + # return self class Value(Requirement): @@ -142,62 +93,63 @@ class Value(Requirement): ever use this. """ - # def __init__(self, key: ResourceKey = None, *, type_: type = None, default: Any = missing): - # if isinstance(key, type): - # if type_ is not None: - # raise TypeError('type_ cannot be specified if key is a type') - # type_ = key - # super().__init__(key, type_=type_, default=default) - - @nonblocking - def resolve(self, context: 'Context'): - return context.get(self.key, self.default) - - -class AnyOf(Requirement): - """ - A requirement that is resolved by any of the specified keys. - """ - - def __init__(self, *keys, default=missing): - super().__init__(keys, default=default) - - @nonblocking - def resolve(self, context: 'Context'): - for key in self.key: - value = context.get(key, missing) - if value is not missing: - return value - return self.default - - -class Like(Requirement): - """ - A requirements that is resolved by the specified class or - any of its base classes. - """ - - @nonblocking - def resolve(self, context: 'Context'): - for key in self.key.__mro__: - if key is object: - break - value = context.get(key, missing) - if value is not missing: - return value - return self.default - - -class Lazy(Requirement): - - def __init__(self, original, provider): - super().__init__(original.key) - self.original = original - self.provider = provider - self.ops = original.ops - - def resolve(self, context): - resource = context.get(self.key, missing) - if resource is missing: - context.extract(self.provider.obj, self.provider.requires, self.provider.returns) - return self.original.resolve(context) + def __init__(self, type_: type = None, identifier: Hashable = None, default: Any = missing): + super().__init__( + default, + ResourceKey(type_, identifier), + ResourceKey(None, identifier), + ResourceKey(type_, None), + ) + + def _keys_repr(self): + return str(self.keys[0]) + +# +# +# class AnyOf(Requirement): +# """ +# A requirement that is resolved by any of the specified keys. +# """ +# +# def __init__(self, *keys, default=missing): +# super().__init__(keys, default=default) +# +# @nonblocking +# def resolve(self, context: 'Context'): +# for key in self.key: +# value = context.get(key, missing) +# if value is not missing: +# return value +# return self.default +# +# +# class Like(Requirement): +# """ +# A requirements that is resolved by the specified class or +# any of its base classes. +# """ +# +# @nonblocking +# def resolve(self, context: 'Context'): +# for key in self.key.__mro__: +# if key is object: +# break +# value = context.get(key, missing) +# if value is not missing: +# return value +# return self.default +# +# +# class Lazy(Requirement): +# +# def __init__(self, original, provider): +# super().__init__(original.key) +# self.original = original +# self.provider = provider +# self.ops = original.ops +# +# def resolve(self, context): +# resource = context.get(self.key, missing) +# if resource is missing: +# context.extract(self.provider.obj, self.provider.requires, self.provider.returns) +# return self.original.resolve(context) diff --git a/mush/resources.py b/mush/resources.py new file mode 100644 index 0000000..4ec9f95 --- /dev/null +++ b/mush/resources.py @@ -0,0 +1,32 @@ +class ResourceKey(tuple): + + def __new__(cls, type_, identifier): + return tuple.__new__(cls, (type_, identifier)) + + @property + def type(self): + return self[0] + + @property + def identifier(self): + return self[1] + + def __str__(self): + if self.type is None: + return repr(self.identifier) + elif self.identifier is None: + return repr(self.type) + return f'{self.type!r}, {self.identifier!r}' + + +class Provider: + pass + + +class Resource: + + def __init__(self, obj): + self.obj = obj + + def __repr__(self): + return repr(self.obj) diff --git a/mush/runner.py b/mush/runner.py index 8b940b6..03652d5 100644 --- a/mush/runner.py +++ b/mush/runner.py @@ -3,11 +3,11 @@ from .callpoints import CallPoint from .context import Context, ResourceError from .declarations import DeclarationsFrom -from .extraction import extract_requires, extract_returns, default_requirement_type +from .extraction import extract_requires, extract_returns from .markers import not_specified from .modifier import Modifier from .plug import Plug -from .requirements import name_or_repr, Lazy +from .requirements import name_or_repr#, Lazy from .typing import Requires, Returns @@ -21,7 +21,7 @@ class Runner(object): start: Optional[CallPoint] = None end: Optional[CallPoint] = None - def __init__(self, *objects: Callable, requirement_modifier=default_requirement_type): + def __init__(self, *objects: Callable): self.requirement_modifier = requirement_modifier self.labels = {} self.lazy = {} diff --git a/mush/tests/test_context.py b/mush/tests/test_context.py index 88c7d8a..ef4c411 100644 --- a/mush/tests/test_context.py +++ b/mush/tests/test_context.py @@ -4,12 +4,13 @@ # from testfixtures.mock import Mock # from mush import ( - Context, requires, returns, returns_mapping, Value, missing + Context#, requires, returns, returns_mapping, Value, missing ) from mush.context import ResourceError # from mush.declarations import RequiresType, requires_nothing, returns_nothing # from mush.requirements import Requirement from .helpers import TheType +from ..resources import Resource class TestContext(object): @@ -19,7 +20,7 @@ def test_add_by_inferred_type(self): context = Context() context.add(obj) - compare(context._store, expected={(TheType, None): obj}) + compare(context._store, expected={(TheType, None): Resource(obj)}) expected = ( ": \n" @@ -34,12 +35,12 @@ def test_add_by_identifier(self): context.add(obj, identifier='my label') compare(context._store, expected={ - (TheType, 'my label'): obj, - (None, 'my label'): obj, + (TheType, 'my label'): Resource(obj), + (None, 'my label'): Resource(obj), }) expected = ("\n" " , 'my label': \n" + " 'my label': \n" "}>") compare(expected, actual=repr(context)) compare(expected, actual=str(context)) @@ -49,7 +50,7 @@ def test_add_by_identifier_only(self): context = Context() context.add(obj, provides=None, identifier='my label') - compare(context._store, expected={(None, 'my label'): obj}) + compare(context._store, expected={(None, 'my label'): Resource(obj)}) expected = ("\n" "}>") @@ -61,7 +62,7 @@ class T2(object): pass obj = TheType() context = Context() context.add(obj, provides=T2) - compare(context._store, expected={(T2, None): obj}) + compare(context._store, expected={(T2, None): Resource(obj)}) expected = ("\n" "}>") @@ -73,8 +74,7 @@ def test_clash_just_type(self): obj2 = TheType() context = Context() context.add(obj1, TheType) - with ShouldRaise(ResourceError(f'Context already contains {TheType!r}', - type_=TheType)): + with ShouldRaise(ResourceError(f'Context already contains {TheType!r}')): context.add(obj2, TheType) def test_clash_just_identifier(self): @@ -82,8 +82,7 @@ def test_clash_just_identifier(self): obj2 = TheType() context = Context() context.add(obj1, provides=None, identifier='my label') - with ShouldRaise(ResourceError("Context already contains 'my label'", - identifier='my label')): + with ShouldRaise(ResourceError("Context already contains 'my label'")): context.add(obj2, provides=None, identifier='my label') def test_clash_identifier_only_with_identifier_plus_type(self): @@ -91,8 +90,7 @@ def test_clash_identifier_only_with_identifier_plus_type(self): obj2 = TheType() context = Context() context.add(obj1, provides=None, identifier='my label') - with ShouldRaise(ResourceError("Context already contains 'my label'", - identifier='my label')): + with ShouldRaise(ResourceError("Context already contains 'my label'")): context.add(obj2, identifier='my label') def test_clash_identifier_plus_type_with_identifier_only(self): @@ -100,8 +98,7 @@ def test_clash_identifier_plus_type_with_identifier_only(self): obj2 = TheType() context = Context() context.add(obj1, identifier='my label') - with ShouldRaise(ResourceError("Context already contains 'my label'", - identifier='my label')): + with ShouldRaise(ResourceError("Context already contains 'my label'")): context.add(obj2, provides=None, identifier='my label') def test_call_no_params(self): @@ -128,16 +125,69 @@ def foo(baz: str): result = context.call(foo) compare(result, expected='bob') - def test_call_identifier_and_type_from_annotation(self): - def foo(baz: str): - return baz + def test_call_by_identifier_only(self): + def foo(param): + return param + context = Context() - context.add('bar', provides=str) - context.add('bob', identifier='baz') - context.add('foo', provides=str, identifier='baz') + context.add('bar', identifier='param') + result = context.call(foo) + compare(result, 'bar') + + def test_call_requires_missing(self): + def foo(obj: TheType): return obj + context = Context() + with ShouldRaise(ResourceError( + "Value(, 'obj') could not be satisfied" + )): + context.call(foo) + + def test_call_optional_type_present(self): + def foo(x: TheType = 1): + return x + context = Context() + context.add(2, TheType) + result = context.call(foo) + compare(result, 2) + + def test_call_optional_type_missing(self): + def foo(x: TheType = 1): + return x + context = Context() + result = context.call(foo) + compare(result, 1) + + def test_call_optional_identifier_present(self): + def foo(x=1): + return x + + context = Context() + context.add(2, identifier='x') + result = context.call(foo) + compare(result, 2) + + def test_call_optional_identifier_missing(self): + def foo(x=1): + return x + + context = Context() + context.add(2) result = context.call(foo) - compare(result, expected='foo') + compare(result, 1) + + def test_call_requires_context(self): + context = Context() + + def return_context(context_: Context): + return context_ + + assert context.call(return_context) is context + def test_call_requires_requirement(self): + # this should blow up unless we're in a provider? + pass + +# XXX - these are for explicit requires() objects: # def test_call_requires_string(self): # def foo(obj): # return obj @@ -156,16 +206,28 @@ def foo(baz: str): # compare(result, 'bar') # compare({TheType: 'bar'}, actual=context._store) # -# def test_call_requires_missing(self): -# def foo(obj): return obj + # + # def test_call_requires_accidental_tuple(self): + # def foo(obj): return obj + # context = Context() + # with ShouldRaise(TypeError( + # "(, " + # ") " + # "is not a valid decoration type" + # )): + # context.call(foo, requires((TheType, TheType))) +# +# def test_call_requires_optional_override_source_and_default(self): +# def foo(x=1): +# return x # context = Context() -# with ShouldRaise(ResourceError( -# "No Value(TheType) in context", -# key=TheType, -# requirement=Value(TheType), -# )): -# context.call(foo, requires(TheType)) +# context.add(2, provides='x') +# result = context.call(foo, requires(x=Value('y', default=3))) +# compare(result, expected=3) # + + +# XXX - these are for ops # def test_call_requires_item_missing(self): # def foo(obj): return obj # context = Context() @@ -177,16 +239,6 @@ def foo(baz: str): # )): # context.call(foo, requires(Value(TheType)['foo'])) # -# def test_call_requires_accidental_tuple(self): -# def foo(obj): return obj -# context = Context() -# with ShouldRaise(TypeError( -# "(, " -# ") " -# "is not a valid decoration type" -# )): -# context.call(foo, requires((TheType, TheType))) -# # def test_call_requires_named_parameter(self): # def foo(x, y): # return x, y @@ -199,39 +251,6 @@ def foo(baz: str): # 'baz': 'bar'}, # actual=context._store) # -# def test_call_requires_optional_present(self): -# def foo(x=1): -# return x -# context = Context() -# context.add(2, TheType) -# result = context.call(foo, requires(TheType)) -# compare(result, 2) -# compare({TheType: 2}, actual=context._store) -# -# def test_call_requires_optional_missing(self): -# def foo(x: TheType = 1): -# return x -# context = Context() -# result = context.call(foo) -# compare(result, 1) -# -# def test_call_requires_optional_override_source_and_default(self): -# def foo(x=1): -# return x -# context = Context() -# context.add(2, provides='x') -# result = context.call(foo, requires(x=Value('y', default=3))) -# compare(result, expected=3) -# -# def test_call_requires_optional_string(self): -# def foo(x:'foo'=1): -# return x -# context = Context() -# context.add(2, 'foo') -# result = context.call(foo) -# compare(result, 2) -# compare({'foo': 2}, actual=context._store) -# # def test_call_requires_item(self): # def foo(x): # return x @@ -273,21 +292,9 @@ def foo(baz: str): # context.add(dict(bar='baz'), provides='foo') # result = context.call(foo) # compare(result, 'baz') -# -# def test_call_extract_requirements(self): -# def foo(param): -# return param -# context = Context() -# context.add('bar', 'param') -# result = context.call(foo) -# compare(result, 'bar') -# -# def test_call_extract_no_requirements(self): -# def foo(): -# pass -# context = Context() -# result = context.call(foo) -# compare(result, expected=None) + + +# XXX requirements caching: # # def test_call_caches_requires(self): # context = Context() @@ -370,12 +377,6 @@ def foo(baz: str): # compare(result, expected=None) # compare(context._store, expected={}) # -# def test_context_contains_itself(self): -# context = Context() -# def return_context(context: Context): -# return context -# assert context.call(return_context) is context -# # def test_remove(self): # context = Context() # context.add('foo') @@ -445,6 +446,10 @@ def foo(baz: str): # c2 = c1.nest() # assert c2._requires_cache is c1._requires_cache # assert c2._returns_cache is c1._returns_cache + + + +# XXX "custom requirement" stuff # # def test_custom_requirement(self): # @@ -492,3 +497,9 @@ def foo(baz: str): # context = Context(requirement_modifier=modifier) # context.add({'bar': 'foo'}, provides='request') # compare(context.call(foo), expected='foo') + + def test_provider(self): + pass + + def test_provider_needs_requirement(self): + pass diff --git a/mush/tests/test_requirements.py b/mush/tests/test_requirements.py index 4f84041..d00a080 100644 --- a/mush/tests/test_requirements.py +++ b/mush/tests/test_requirements.py @@ -6,7 +6,7 @@ from testfixtures.mock import Mock from mush import Context, Value, missing, requires, ResourceError -from mush.requirements import Requirement, AttrOp, ItemOp, AnyOf, Like +from mush.requirements import Requirement, AttrOp, ItemOp#, AnyOf, Like from .helpers import Type1 @@ -131,153 +131,153 @@ def test_type_from_key(self): def test_key_and_type_cannot_disagree(self): with ShouldRaise(TypeError('type_ cannot be specified if key is a type')): Value(key=str, type_=int) - - -class TestItem: - - def test_single(self): - h = Value(Type1)['foo'] - compare(repr(h), expected="Value(Type1)['foo']") - check_ops(h, {'foo': 1}, expected=1) - - def test_multiple(self): - h = Value(Type1)['foo']['bar'] - compare(repr(h), expected="Value(Type1)['foo']['bar']") - check_ops(h, {'foo': {'bar': 1}}, expected=1) - - def test_missing_obj(self): - h = Value(Type1)['foo']['bar'] - with ShouldRaise(TypeError): - check_ops(h, object(), expected=None) - - def test_missing_key(self): - h = Value(Type1)['foo'] - check_ops(h, {}, expected=missing) - - def test_passed_missing(self): - c = Context() - c.add({}, provides='key') - compare(c.call(lambda x: x, requires(Value('key', default=1)['foo']['bar'])), - expected=1) - - def test_bad_type(self): - h = Value(Type1)['foo']['bar'] - with ShouldRaise(TypeError): - check_ops(h, [], expected=None) - - -class TestAttr(TestCase): - - def test_single(self): - h = Value(Type1).foo - compare(repr(h), "Value(Type1).foo") - m = Mock() - check_ops(h, m, expected=m.foo) - - def test_multiple(self): - h = Value(Type1).foo.bar - compare(repr(h), "Value(Type1).foo.bar") - m = Mock() - check_ops(h, m, expected=m.foo.bar) - - def test_missing(self): - h = Value(Type1).foo - compare(repr(h), "Value(Type1).foo") - check_ops(h, object(), expected=missing) - - def test_passed_missing(self): - c = Context() - c.add(object(), provides='key') - compare(c.call(lambda x: x, requires(Value('key', default=1).foo.bar)), - expected=1) - - -class TestAnyOf: - - def test_first(self): - context = Context() - context.add(('foo', )) - context.add(('bar', ), provides=Tuple[str]) - - def bob(x: str = AnyOf(tuple, Tuple[str])): - return x[0] - - compare(context.call(bob), expected='foo') - - def test_second(self): - context = Context() - context.add(('bar', ), provides=Tuple[str]) - - def bob(x: str = AnyOf(tuple, Tuple[str])): - return x[0] - - compare(context.call(bob), expected='bar') - - def test_none(self): - context = Context() - - def bob(x: str = AnyOf(tuple, Tuple[str])): - pass - - with ShouldRaise(ResourceError): - context.call(bob) - - def test_default(self): - context = Context() - - def bob(x: str = AnyOf(tuple, Tuple[str], default=(42,))): - return x[0] - - compare(context.call(bob), expected=42) - - -class Parent(object): - pass - - -class Child(Parent): - pass - - -class TestLike: - - def test_actual(self): - context = Context() - p = Parent() - c = Child() - context.add(p) - context.add(c) - - def bob(x: str = Like(Child)): - return x - - assert context.call(bob) is c - - def test_base(self): - context = Context() - p = Parent() - context.add(p) - - def bob(x: str = Like(Child)): - return x - - assert context.call(bob) is p - - def test_none(self): - context = Context() - # make sure we don't pick up object! - context.add(object()) - - def bob(x: str = Like(Child)): - pass - - with ShouldRaise(ResourceError): - context.call(bob) - - def test_default(self): - context = Context() - - def bob(x: str = Like(Child, default=42)): - return x - - compare(context.call(bob), expected=42) +# +# +# class TestItem: +# +# def test_single(self): +# h = Value(Type1)['foo'] +# compare(repr(h), expected="Value(Type1)['foo']") +# check_ops(h, {'foo': 1}, expected=1) +# +# def test_multiple(self): +# h = Value(Type1)['foo']['bar'] +# compare(repr(h), expected="Value(Type1)['foo']['bar']") +# check_ops(h, {'foo': {'bar': 1}}, expected=1) +# +# def test_missing_obj(self): +# h = Value(Type1)['foo']['bar'] +# with ShouldRaise(TypeError): +# check_ops(h, object(), expected=None) +# +# def test_missing_key(self): +# h = Value(Type1)['foo'] +# check_ops(h, {}, expected=missing) +# +# def test_passed_missing(self): +# c = Context() +# c.add({}, provides='key') +# compare(c.call(lambda x: x, requires(Value('key', default=1)['foo']['bar'])), +# expected=1) +# +# def test_bad_type(self): +# h = Value(Type1)['foo']['bar'] +# with ShouldRaise(TypeError): +# check_ops(h, [], expected=None) +# +# +# class TestAttr(TestCase): +# +# def test_single(self): +# h = Value(Type1).foo +# compare(repr(h), "Value(Type1).foo") +# m = Mock() +# check_ops(h, m, expected=m.foo) +# +# def test_multiple(self): +# h = Value(Type1).foo.bar +# compare(repr(h), "Value(Type1).foo.bar") +# m = Mock() +# check_ops(h, m, expected=m.foo.bar) +# +# def test_missing(self): +# h = Value(Type1).foo +# compare(repr(h), "Value(Type1).foo") +# check_ops(h, object(), expected=missing) +# +# def test_passed_missing(self): +# c = Context() +# c.add(object(), provides='key') +# compare(c.call(lambda x: x, requires(Value('key', default=1).foo.bar)), +# expected=1) +# +# +# class TestAnyOf: +# +# def test_first(self): +# context = Context() +# context.add(('foo', )) +# context.add(('bar', ), provides=Tuple[str]) +# +# def bob(x: str = AnyOf(tuple, Tuple[str])): +# return x[0] +# +# compare(context.call(bob), expected='foo') +# +# def test_second(self): +# context = Context() +# context.add(('bar', ), provides=Tuple[str]) +# +# def bob(x: str = AnyOf(tuple, Tuple[str])): +# return x[0] +# +# compare(context.call(bob), expected='bar') +# +# def test_none(self): +# context = Context() +# +# def bob(x: str = AnyOf(tuple, Tuple[str])): +# pass +# +# with ShouldRaise(ResourceError): +# context.call(bob) +# +# def test_default(self): +# context = Context() +# +# def bob(x: str = AnyOf(tuple, Tuple[str], default=(42,))): +# return x[0] +# +# compare(context.call(bob), expected=42) +# +# +# class Parent(object): +# pass +# +# +# class Child(Parent): +# pass +# +# +# class TestLike: +# +# def test_actual(self): +# context = Context() +# p = Parent() +# c = Child() +# context.add(p) +# context.add(c) +# +# def bob(x: str = Like(Child)): +# return x +# +# assert context.call(bob) is c +# +# def test_base(self): +# context = Context() +# p = Parent() +# context.add(p) +# +# def bob(x: str = Like(Child)): +# return x +# +# assert context.call(bob) is p +# +# def test_none(self): +# context = Context() +# # make sure we don't pick up object! +# context.add(object()) +# +# def bob(x: str = Like(Child)): +# pass +# +# with ShouldRaise(ResourceError): +# context.call(bob) +# +# def test_default(self): +# context = Context() +# +# def bob(x: str = Like(Child, default=42)): +# return x +# +# compare(context.call(bob), expected=42) diff --git a/mush/typing.py b/mush/typing.py index fc950f2..e15be7f 100644 --- a/mush/typing.py +++ b/mush/typing.py @@ -14,6 +14,5 @@ ReturnType = Union[type, str] Returns = Union['ReturnsType', ReturnType, List[ReturnType], Tuple[ReturnType, ...]] -ResourceKey = Union[Hashable, Callable] ResourceValue = NewType('ResourceValue', Any) -RequirementModifier = Callable[['Requirement'], 'Requirement'] + From d72f3ee7c67e6722ec27614718ab9bde73441efc Mon Sep 17 00:00:00 2001 From: Chris Withers Date: Mon, 21 Sep 2020 08:49:40 +0100 Subject: [PATCH 110/159] split extractions tests from declaration test --- mush/tests/test_declarations.py | 357 +------------------- mush/tests/test_requirements_extraction.py | 366 +++++++++++++++++++++ 2 files changed, 367 insertions(+), 356 deletions(-) create mode 100644 mush/tests/test_requirements_extraction.py diff --git a/mush/tests/test_declarations.py b/mush/tests/test_declarations.py index a190323..d4d34f9 100644 --- a/mush/tests/test_declarations.py +++ b/mush/tests/test_declarations.py @@ -1,29 +1,16 @@ -from functools import partial from typing import Tuple from unittest import TestCase -import pytest from testfixtures import compare, ShouldRaise from mush import Value from mush.declarations import ( requires, returns, - returns_mapping, returns_sequence, returns_result_type, - requires_nothing, - result_type, RequiresType + returns_mapping, returns_sequence, returns_result_type ) -from mush.extraction import extract_requires, extract_returns, update_wrapper -from mush.requirements import Requirement, ItemOp from .helpers import PY_36, Type1, Type2, Type3, Type4 -def check_extract(obj, expected_rq, expected_rt): - rq = extract_requires(obj, None) - rt = extract_returns(obj, None) - compare(rq, expected=expected_rq, strict=True) - compare(rt, expected=expected_rt, strict=True) - - class TestRequires(TestCase): def test_empty(self): @@ -161,345 +148,3 @@ def foo(): pass r = returns_result_type() compare(dict(r.process(foo())), {}) - - -class TestExtractDeclarations(object): - - def test_default_requirements_for_function(self): - def foo(a, b=None): pass - check_extract(foo, - expected_rq=RequiresType(( - Value.make(key='a', name='a'), - Value.make(key='b', default=None, name='b'), - )), - expected_rt=result_type) - - def test_default_requirements_for_class(self): - class MyClass(object): - def __init__(self, a, b=None): pass - check_extract(MyClass, - expected_rq=RequiresType(( - Value.make(key='a', name='a'), - Value.make(key='b', name='b', default=None), - )), - expected_rt=result_type) - - def test_extract_from_partial(self): - def foo(x, y, z, a=None): pass - p = partial(foo, 1, y=2) - check_extract( - p, - expected_rq=RequiresType(( - Value.make(key='z', name='z', target='z'), - Value.make(key='a', name='a', target='a', default=None), - )), - expected_rt=result_type - ) - - def test_extract_from_partial_default_not_in_partial(self): - def foo(a=None): pass - p = partial(foo) - check_extract( - p, - expected_rq=RequiresType(( - Value.make(key='a', name='a', default=None), - )), - expected_rt=result_type - ) - - def test_extract_from_partial_default_in_partial_arg(self): - def foo(a=None): pass - p = partial(foo, 1) - check_extract( - p, - # since a is already bound by the partial: - expected_rq=requires_nothing, - expected_rt=result_type - ) - - def test_extract_from_partial_default_in_partial_kw(self): - def foo(a=None): pass - p = partial(foo, a=1) - check_extract( - p, - expected_rq=requires_nothing, - expected_rt=result_type - ) - - def test_extract_from_partial_required_in_partial_arg(self): - def foo(a): pass - p = partial(foo, 1) - check_extract( - p, - # since a is already bound by the partial: - expected_rq=requires_nothing, - expected_rt=result_type - ) - - def test_extract_from_partial_required_in_partial_kw(self): - def foo(a): pass - p = partial(foo, a=1) - check_extract( - p, - expected_rq=requires_nothing, - expected_rt=result_type - ) - - def test_extract_from_partial_plus_one_default_not_in_partial(self): - def foo(b, a=None): pass - p = partial(foo) - check_extract( - p, - expected_rq=RequiresType(( - Value.make(key='b', name='b'), - Value.make(key='a', name='a', default=None), - )), - expected_rt=result_type - ) - - def test_extract_from_partial_plus_one_required_in_partial_arg(self): - def foo(b, a): pass - p = partial(foo, 1) - check_extract( - p, - # since b is already bound: - expected_rq=RequiresType(( - Value.make(key='a', name='a'), - )), - expected_rt=result_type - ) - - def test_extract_from_partial_plus_one_required_in_partial_kw(self): - def foo(b, a): pass - p = partial(foo, a=1) - check_extract( - p, - expected_rq=RequiresType(( - Value.make(key='b', name='b'), - )), - expected_rt=result_type - ) - - -class TestExtractDeclarationsFromTypeAnnotations(object): - - def test_extract_from_annotations(self): - def foo(a: 'foo', b, c: 'bar' = 1, d=2) -> 'bar': pass - check_extract(foo, - expected_rq=RequiresType(( - Value.make(key='foo', name='a'), - Value.make(key='b', name='b'), - Value.make(key='bar', name='c', default=1), - Value.make(key='d', name='d', default=2) - )), - expected_rt=returns('bar')) - - def test_requires_only(self): - def foo(a: 'foo'): pass - check_extract(foo, - expected_rq=RequiresType((Value.make(key='foo', name='a'),)), - expected_rt=result_type) - - def test_returns_only(self): - def foo() -> 'bar': pass - check_extract(foo, - expected_rq=requires_nothing, - expected_rt=returns('bar')) - - def test_extract_from_decorated_class(self): - - class Wrapper(object): - def __init__(self, func): - self.func = func - def __call__(self): - return 'the '+self.func() - - def my_dec(func): - return update_wrapper(Wrapper(func), func) - - @my_dec - def foo(a: 'foo' = None) -> 'bar': - return 'answer' - - compare(foo(), expected='the answer') - check_extract(foo, - expected_rq=RequiresType((Value.make(key='foo', name='a', default=None),)), - expected_rt=returns('bar')) - - def test_decorator_trumps_annotations(self): - @requires('foo') - @returns('bar') - def foo(a: 'x') -> 'y': pass - check_extract(foo, - expected_rq=RequiresType((Value.make(key='foo', name='a'),)), - expected_rt=returns('bar')) - - def test_returns_mapping(self): - rt = returns_mapping() - def foo() -> rt: pass - check_extract(foo, - expected_rq=requires_nothing, - expected_rt=rt) - - def test_returns_sequence(self): - rt = returns_sequence() - def foo() -> rt: pass - check_extract(foo, - expected_rq=requires_nothing, - expected_rt=rt) - - def test_how_instance_in_annotations(self): - def foo(a: Value('config')['db_url']): pass - check_extract(foo, - expected_rq=RequiresType(( - Value.make(key='config', name='a', ops=[ItemOp('db_url')]), - )), - expected_rt=result_type) - - def test_default_requirements(self): - def foo(a, b=1, *, c, d=None): pass - check_extract(foo, - expected_rq=RequiresType(( - Value.make(key='a', name='a'), - Value.make(key='b', name='b', default=1), - Value.make(key='c', name='c', target='c'), - Value.make(key='d', name='d', target='d', default=None) - )), - expected_rt=result_type) - - def test_type_only(self): - class T: pass - def foo(a: T): pass - check_extract(foo, - expected_rq=RequiresType((Value.make(key=T, name='a', type=T),)), - expected_rt=result_type) - - @pytest.mark.parametrize("type_", [str, int, dict, list]) - def test_simple_type_only(self, type_): - def foo(a: type_): pass - check_extract(foo, - expected_rq=RequiresType((Value.make(key='a', name='a', type=type_),)), - expected_rt=result_type) - - def test_type_plus_value(self): - def foo(a: str = Value('b')): pass - check_extract(foo, - expected_rq=RequiresType((Value.make(key='b', name='a', type=str),)), - expected_rt=result_type) - - def test_type_plus_value_with_default(self): - def foo(a: str = Value('b', default=1)): pass - check_extract(foo, - expected_rq=RequiresType(( - Value.make(key='b', name='a', type=str, default=1), - )), - expected_rt=result_type) - - def test_value_annotation_plus_default(self): - def foo(a: Value('b', type_=str) = 1): pass - check_extract(foo, - expected_rq=RequiresType(( - Value.make(key='b', name='a', type=str, default=1), - )), - expected_rt=result_type) - - def test_value_annotation_just_type_in_value_key_plus_default(self): - def foo(a: Value(str) = 1): pass - check_extract(foo, - expected_rq=RequiresType(( - Value.make(key=str, name='a', type=str, default=1), - )), - expected_rt=result_type) - - def test_value_annotation_just_type_plus_default(self): - def foo(a: Value(type_=str) = 1): pass - check_extract(foo, - expected_rq=RequiresType(( - Value.make(key='a', name='a', type=str, default=1), - )), - expected_rt=result_type) - - def test_value_unspecified_with_type(self): - class T1: pass - def foo(a: T1 = Value()): pass - check_extract(foo, - expected_rq=RequiresType((Value.make(key=T1, name='a', type=T1),)), - expected_rt=result_type) - - def test_value_unspecified_with_simple_type(self): - def foo(a: str = Value()): pass - check_extract(foo, - expected_rq=RequiresType((Value.make(key='a', name='a', type=str),)), - expected_rt=result_type) - - def test_value_unspecified(self): - def foo(a = Value()): pass - check_extract(foo, - expected_rq=RequiresType((Value.make(key='a', name='a'),)), - expected_rt=result_type) - - def test_requirement_modifier(self): - def foo(x: str = None): pass - - class FromRequest(Requirement): pass - - def modifier(requirement): - if type(requirement) is Requirement: - requirement = FromRequest.make_from(requirement) - return requirement - - rq = extract_requires(foo, modifier=modifier) - compare(rq, strict=True, expected=RequiresType(( - FromRequest(key='x', name='x', type_=str, default=None), - ))) - - -class TestDeclarationsFromMultipleSources: - - def test_declarations_from_different_sources(self): - r1 = Requirement('a') - r2 = Requirement('b') - r3 = Requirement('c') - - @requires(b=r2) - def foo(a: r1, b, c=r3): - pass - - check_extract(foo, - expected_rq=RequiresType(( - Value.make(key='a', name='a'), - Value.make(key='b', name='b', target='b'), - Value.make(key='c', name='c', target='c'), - )), - expected_rt=result_type) - - def test_declaration_priorities(self): - r1 = Requirement('a') - r2 = Requirement('b') - r3 = Requirement('c') - - @requires(a=r1) - def foo(a: r2 = r3, b: str = r2, c = r3): - pass - - check_extract(foo, - expected_rq=RequiresType(( - Value.make(key='a', name='a', target='a'), - Value.make(key='b', name='b', target='b', type=str), - Value.make(key='c', name='c', target='c'), - )), - expected_rt=result_type) - - def test_explicit_requirement_type_trumps_default_requirement_type(self): - - class FromRequest(Requirement): pass - - @requires(a=Requirement('a')) - def foo(a): - pass - - compare(actual=extract_requires(foo, requires(a=FromRequest('b'))), - strict=True, - expected=RequiresType(( - FromRequest.make(key='b', name='a', target='a'), - ))) diff --git a/mush/tests/test_requirements_extraction.py b/mush/tests/test_requirements_extraction.py new file mode 100644 index 0000000..a40f312 --- /dev/null +++ b/mush/tests/test_requirements_extraction.py @@ -0,0 +1,366 @@ +from functools import partial +from typing import Tuple +from unittest import TestCase + +import pytest +from testfixtures import compare, ShouldRaise + +from mush import Value +from mush.declarations import ( + requires, returns, + returns_mapping, returns_sequence, returns_result_type, + requires_nothing, + result_type, RequiresType +) +from mush.extraction import extract_requires#, extract_returns, update_wrapper +from mush.requirements import Requirement, ItemOp +from .helpers import PY_36, Type1, Type2, Type3, Type4 + + +def check_extract(obj, expected_rq, expected_rt): + rq = extract_requires(obj, None) + rt = extract_returns(obj, None) + compare(rq, expected=expected_rq, strict=True) + compare(rt, expected=expected_rt, strict=True) + + +class TestRequirementsExtraction(object): + + def test_default_requirements_for_function(self): + def foo(a, b=None): pass + check_extract(foo, + expected_rq=RequiresType(( + Value.make(key='a', name='a'), + Value.make(key='b', default=None, name='b'), + )), + expected_rt=result_type) + + def test_default_requirements_for_class(self): + class MyClass(object): + def __init__(self, a, b=None): pass + check_extract(MyClass, + expected_rq=RequiresType(( + Value.make(key='a', name='a'), + Value.make(key='b', name='b', default=None), + )), + expected_rt=result_type) + + def test_extract_from_partial(self): + def foo(x, y, z, a=None): pass + p = partial(foo, 1, y=2) + check_extract( + p, + expected_rq=RequiresType(( + Value.make(key='z', name='z', target='z'), + Value.make(key='a', name='a', target='a', default=None), + )), + expected_rt=result_type + ) + + def test_extract_from_partial_default_not_in_partial(self): + def foo(a=None): pass + p = partial(foo) + check_extract( + p, + expected_rq=RequiresType(( + Value.make(key='a', name='a', default=None), + )), + expected_rt=result_type + ) + + def test_extract_from_partial_default_in_partial_arg(self): + def foo(a=None): pass + p = partial(foo, 1) + check_extract( + p, + # since a is already bound by the partial: + expected_rq=requires_nothing, + expected_rt=result_type + ) + + def test_extract_from_partial_default_in_partial_kw(self): + def foo(a=None): pass + p = partial(foo, a=1) + check_extract( + p, + expected_rq=requires_nothing, + expected_rt=result_type + ) + + def test_extract_from_partial_required_in_partial_arg(self): + def foo(a): pass + p = partial(foo, 1) + check_extract( + p, + # since a is already bound by the partial: + expected_rq=requires_nothing, + expected_rt=result_type + ) + + def test_extract_from_partial_required_in_partial_kw(self): + def foo(a): pass + p = partial(foo, a=1) + check_extract( + p, + expected_rq=requires_nothing, + expected_rt=result_type + ) + + def test_extract_from_partial_plus_one_default_not_in_partial(self): + def foo(b, a=None): pass + p = partial(foo) + check_extract( + p, + expected_rq=RequiresType(( + Value.make(key='b', name='b'), + Value.make(key='a', name='a', default=None), + )), + expected_rt=result_type + ) + + def test_extract_from_partial_plus_one_required_in_partial_arg(self): + def foo(b, a): pass + p = partial(foo, 1) + check_extract( + p, + # since b is already bound: + expected_rq=RequiresType(( + Value.make(key='a', name='a'), + )), + expected_rt=result_type + ) + + def test_extract_from_partial_plus_one_required_in_partial_kw(self): + def foo(b, a): pass + p = partial(foo, a=1) + check_extract( + p, + expected_rq=RequiresType(( + Value.make(key='b', name='b'), + )), + expected_rt=result_type + ) + + +class TestExtractDeclarationsFromTypeAnnotations(object): + + def test_extract_from_annotations(self): + def foo(a: 'foo', b, c: 'bar' = 1, d=2) -> 'bar': pass + check_extract(foo, + expected_rq=RequiresType(( + Value.make(key='foo', name='a'), + Value.make(key='b', name='b'), + Value.make(key='bar', name='c', default=1), + Value.make(key='d', name='d', default=2) + )), + expected_rt=returns('bar')) + + def test_requires_only(self): + def foo(a: 'foo'): pass + check_extract(foo, + expected_rq=RequiresType((Value.make(key='foo', name='a'),)), + expected_rt=result_type) + + def test_returns_only(self): + def foo() -> 'bar': pass + check_extract(foo, + expected_rq=requires_nothing, + expected_rt=returns('bar')) + + def test_extract_from_decorated_class(self): + + class Wrapper(object): + def __init__(self, func): + self.func = func + def __call__(self): + return 'the '+self.func() + + def my_dec(func): + return update_wrapper(Wrapper(func), func) + + @my_dec + def foo(a: 'foo' = None) -> 'bar': + return 'answer' + + compare(foo(), expected='the answer') + check_extract(foo, + expected_rq=RequiresType((Value.make(key='foo', name='a', default=None),)), + expected_rt=returns('bar')) + + def test_decorator_trumps_annotations(self): + @requires('foo') + @returns('bar') + def foo(a: 'x') -> 'y': pass + check_extract(foo, + expected_rq=RequiresType((Value.make(key='foo', name='a'),)), + expected_rt=returns('bar')) + + def test_returns_mapping(self): + rt = returns_mapping() + def foo() -> rt: pass + check_extract(foo, + expected_rq=requires_nothing, + expected_rt=rt) + + def test_returns_sequence(self): + rt = returns_sequence() + def foo() -> rt: pass + check_extract(foo, + expected_rq=requires_nothing, + expected_rt=rt) + + def test_how_instance_in_annotations(self): + def foo(a: Value('config')['db_url']): pass + check_extract(foo, + expected_rq=RequiresType(( + Value.make(key='config', name='a', ops=[ItemOp('db_url')]), + )), + expected_rt=result_type) + + def test_default_requirements(self): + def foo(a, b=1, *, c, d=None): pass + check_extract(foo, + expected_rq=RequiresType(( + Value.make(key='a', name='a'), + Value.make(key='b', name='b', default=1), + Value.make(key='c', name='c', target='c'), + Value.make(key='d', name='d', target='d', default=None) + )), + expected_rt=result_type) + + def test_type_only(self): + class T: pass + def foo(a: T): pass + check_extract(foo, + expected_rq=RequiresType((Value.make(key=T, name='a', type=T),)), + expected_rt=result_type) + + @pytest.mark.parametrize("type_", [str, int, dict, list]) + def test_simple_type_only(self, type_): + def foo(a: type_): pass + check_extract(foo, + expected_rq=RequiresType((Value.make(key='a', name='a', type=type_),)), + expected_rt=result_type) + + def test_type_plus_value(self): + def foo(a: str = Value('b')): pass + check_extract(foo, + expected_rq=RequiresType((Value.make(key='b', name='a', type=str),)), + expected_rt=result_type) + + def test_type_plus_value_with_default(self): + def foo(a: str = Value('b', default=1)): pass + check_extract(foo, + expected_rq=RequiresType(( + Value.make(key='b', name='a', type=str, default=1), + )), + expected_rt=result_type) + + def test_value_annotation_plus_default(self): + def foo(a: Value('b', type_=str) = 1): pass + check_extract(foo, + expected_rq=RequiresType(( + Value.make(key='b', name='a', type=str, default=1), + )), + expected_rt=result_type) + + def test_value_annotation_just_type_in_value_key_plus_default(self): + def foo(a: Value(str) = 1): pass + check_extract(foo, + expected_rq=RequiresType(( + Value.make(key=str, name='a', type=str, default=1), + )), + expected_rt=result_type) + + def test_value_annotation_just_type_plus_default(self): + def foo(a: Value(type_=str) = 1): pass + check_extract(foo, + expected_rq=RequiresType(( + Value.make(key='a', name='a', type=str, default=1), + )), + expected_rt=result_type) + + def test_value_unspecified_with_type(self): + class T1: pass + def foo(a: T1 = Value()): pass + check_extract(foo, + expected_rq=RequiresType((Value.make(key=T1, name='a', type=T1),)), + expected_rt=result_type) + + def test_value_unspecified_with_simple_type(self): + def foo(a: str = Value()): pass + check_extract(foo, + expected_rq=RequiresType((Value.make(key='a', name='a', type=str),)), + expected_rt=result_type) + + def test_value_unspecified(self): + def foo(a=Value()): pass + check_extract(foo, + expected_rq=RequiresType((Value.make(key='a', name='a'),)), + expected_rt=result_type) + + def test_requirement_modifier(self): + def foo(x: str = None): pass + + class FromRequest(Requirement): pass + + def modifier(requirement): + if type(requirement) is Requirement: + requirement = FromRequest.make_from(requirement) + return requirement + + rq = extract_requires(foo, modifier=modifier) + compare(rq, strict=True, expected=RequiresType(( + FromRequest(key='x', name='x', type_=str, default=None), + ))) + + +class TestDeclarationsFromMultipleSources: + + def test_declarations_from_different_sources(self): + r1 = Requirement('a') + r2 = Requirement('b') + r3 = Requirement('c') + + @requires(b=r2) + def foo(a: r1, b, c=r3): + pass + + check_extract(foo, + expected_rq=RequiresType(( + Value.make(key='a', name='a'), + Value.make(key='b', name='b', target='b'), + Value.make(key='c', name='c', target='c'), + )), + expected_rt=result_type) + + def test_declaration_priorities(self): + r1 = Requirement('a') + r2 = Requirement('b') + r3 = Requirement('c') + + @requires(a=r1) + def foo(a: r2 = r3, b: str = r2, c = r3): + pass + + check_extract(foo, + expected_rq=RequiresType(( + Value.make(key='a', name='a', target='a'), + Value.make(key='b', name='b', target='b', type=str), + Value.make(key='c', name='c', target='c'), + )), + expected_rt=result_type) + + def test_explicit_requirement_type_trumps_default_requirement_type(self): + + class FromRequest(Requirement): pass + + @requires(a=Requirement('a')) + def foo(a): + pass + + compare(actual=extract_requires(foo, requires(a=FromRequest('b'))), + strict=True, + expected=RequiresType(( + FromRequest.make(key='b', name='a', target='a'), + ))) From 5997a5f15dc1c215e6e2ad1d4250cf2ea4700025 Mon Sep 17 00:00:00 2001 From: Chris Withers Date: Mon, 21 Sep 2020 08:50:49 +0100 Subject: [PATCH 111/159] comment out more stuff that may not end up being needed --- mush/__init__.py | 2 +- mush/callpoints.py | 2 +- mush/extraction.py | 116 ++++++++++++++++++++++----------------------- mush/runner.py | 2 +- 4 files changed, 61 insertions(+), 61 deletions(-) diff --git a/mush/__init__.py b/mush/__init__.py index 985ded1..c2813ef 100755 --- a/mush/__init__.py +++ b/mush/__init__.py @@ -2,7 +2,7 @@ from .declarations import ( requires, returns, returns_result_type, returns_mapping, returns_sequence, ) -from .extraction import extract_requires, extract_returns, update_wrapper +from .extraction import extract_requires#, extract_returns, update_wrapper from .markers import missing, nonblocking, blocking from .plug import Plug from .requirements import Requirement, Value#, AnyOf, Like diff --git a/mush/callpoints.py b/mush/callpoints.py index 9ef55d7..588b201 100644 --- a/mush/callpoints.py +++ b/mush/callpoints.py @@ -4,7 +4,7 @@ from .declarations import ( requires_nothing, returns as returns_declaration, returns_nothing ) -from .extraction import extract_requires, extract_returns +from .extraction import extract_requires#, extract_returns from .requirements import name_or_repr from .typing import Requires, Returns diff --git a/mush/extraction.py b/mush/extraction.py index 336ffb8..fe2a1a1 100644 --- a/mush/extraction.py +++ b/mush/extraction.py @@ -18,33 +18,33 @@ EMPTY = Parameter.empty #: For these types, prefer the name instead of the type. -SIMPLE_TYPES = (str, int, dict, list) - - -def _apply_requires(by_name, by_index, requires_): - - for i, r in enumerate(requires_): - - if r.target is None: - try: - name = by_index[i] - except IndexError: - # case where something takes *args - by_name[i] = r.make_from(r) - continue - else: - name = r.target - - existing = by_name[name] - by_name[name] = r.make_from( - r, - name=existing.name, - key=existing.key if r.key is None else r.key, - type=existing.type if r.type is None else r.type, - default=existing.default if r.default is missing else r.default, - ops=existing.ops if not r.ops else r.ops, - target=existing.target if r.target is None else r.target, - ) +# SIMPLE_TYPES = (str, int, dict, list) +# +# +# def _apply_requires(by_name, by_index, requires_): +# +# for i, r in enumerate(requires_): +# +# if r.target is None: +# try: +# name = by_index[i] +# except IndexError: +# # case where something takes *args +# by_name[i] = r.make_from(r) +# continue +# else: +# name = r.target +# +# existing = by_name[name] +# by_name[name] = r.make_from( +# r, +# name=existing.name, +# key=existing.key if r.key is None else r.key, +# type=existing.type if r.type is None else r.type, +# default=existing.default if r.default is missing else r.default, +# ops=existing.ops if not r.ops else r.ops, +# target=existing.target if r.target is None else r.target, +# ) def extract_requires(obj: Callable) -> Iterable[Requirement]: @@ -143,34 +143,34 @@ def extract_requires(obj: Callable) -> Iterable[Requirement]: # return RequiresType(by_name.values()) -def extract_returns(obj: Callable, explicit: Returns = None): - if explicit is None: - returns_ = get_mush(obj, 'returns', None) - if returns_ is None: - annotations = getattr(obj, '__annotations__', {}) - returns_ = annotations.get('return') - else: - returns_ = explicit - - if returns_ is None or isinstance(returns_, ReturnsType): - pass - elif isinstance(returns_, (list, tuple)): - returns_ = returns(*returns_) - else: - returns_ = returns(returns_) - - return returns_ or result_type - - -WRAPPER_ASSIGNMENTS = FUNCTOOLS_ASSIGNMENTS + ('__mush__',) - - -def update_wrapper(wrapper, - wrapped, - assigned=WRAPPER_ASSIGNMENTS, - updated=WRAPPER_UPDATES): - """ - An extended version of :func:`functools.update_wrapper` that - also preserves Mush's annotations. - """ - return functools_update_wrapper(wrapper, wrapped, assigned, updated) +# def extract_returns(obj: Callable, explicit: Returns = None): +# if explicit is None: +# returns_ = get_mush(obj, 'returns', None) +# if returns_ is None: +# annotations = getattr(obj, '__annotations__', {}) +# returns_ = annotations.get('return') +# else: +# returns_ = explicit +# +# if returns_ is None or isinstance(returns_, ReturnsType): +# pass +# elif isinstance(returns_, (list, tuple)): +# returns_ = returns(*returns_) +# else: +# returns_ = returns(returns_) +# +# return returns_ or result_type +# +# +# WRAPPER_ASSIGNMENTS = FUNCTOOLS_ASSIGNMENTS + ('__mush__',) +# +# +# def update_wrapper(wrapper, +# wrapped, +# assigned=WRAPPER_ASSIGNMENTS, +# updated=WRAPPER_UPDATES): +# """ +# An extended version of :func:`functools.update_wrapper` that +# also preserves Mush's annotations. +# """ +# return functools_update_wrapper(wrapper, wrapped, assigned, updated) diff --git a/mush/runner.py b/mush/runner.py index 03652d5..79dffc0 100644 --- a/mush/runner.py +++ b/mush/runner.py @@ -3,7 +3,7 @@ from .callpoints import CallPoint from .context import Context, ResourceError from .declarations import DeclarationsFrom -from .extraction import extract_requires, extract_returns +from .extraction import extract_requires#, extract_returns from .markers import not_specified from .modifier import Modifier from .plug import Plug From 2836500e409fd134c3bb7df1b7d4393733ce7dad Mon Sep 17 00:00:00 2001 From: Chris Withers Date: Mon, 21 Sep 2020 08:51:41 +0100 Subject: [PATCH 112/159] bugfix: keys typing --- mush/requirements.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mush/requirements.py b/mush/requirements.py index fc73451..c24a864 100644 --- a/mush/requirements.py +++ b/mush/requirements.py @@ -46,8 +46,8 @@ class Requirement: The requirement for an individual parameter of a callable. """ - def __init__(self, default: Any, *keys: Sequence[ResourceKey]): - self.keys = keys + def __init__(self, default: Any, *keys: ResourceKey): + self.keys: Sequence[ResourceKey] = keys self.default = default self.ops: List['Op'] = [] # self.target: Optional[str] = target From 8dee15f92b630329fe4c56ff6752034cad9a87c6 Mon Sep 17 00:00:00 2001 From: Chris Withers Date: Wed, 23 Sep 2020 08:50:05 +0100 Subject: [PATCH 113/159] kill context.get --- mush/context.py | 17 ----------------- mush/tests/test_context.py | 19 ------------------- 2 files changed, 36 deletions(-) diff --git a/mush/context.py b/mush/context.py index baf90b0..c482b80 100644 --- a/mush/context.py +++ b/mush/context.py @@ -137,23 +137,6 @@ def call(self, obj: Callable, requires: RequiresType = None): return obj(*args, **kw) # - # def get(self, key: ResourceKey, default=unspecified): - # context = self - # - # while context is not None: - # value = context._store.get(key, missing) - # if value is missing: - # context = context._parent - # else: - # if context is not self: - # self._store[key] = value - # return value - # - # if default is unspecified: - # raise ResourceError(f'No {key!r} in context', key) - # - # return default - # # def nest(self, requirement_modifier: RequirementModifier = None): # if requirement_modifier is None: # requirement_modifier = self.requirement_modifier diff --git a/mush/tests/test_context.py b/mush/tests/test_context.py index ef4c411..2aa21ad 100644 --- a/mush/tests/test_context.py +++ b/mush/tests/test_context.py @@ -395,25 +395,6 @@ def test_call_requires_requirement(self): # context.remove('foo', strict=False) # compare(context._store, expected={}) # -# def test_get_present(self): -# context = Context() -# context.add('bar', provides='foo') -# compare(context.get('foo'), expected='bar') -# -# def test_get_type(self): -# context = Context() -# context.add(['bar'], provides=List[str]) -# compare(context.get(List[str]), expected=['bar']) -# compare(context.get(List[int], default=None), expected=None) -# compare(context.get(List, default=None), expected=None) -# # nb: this might be surprising: -# compare(context.get(list, default=None), expected=None) -# -# def test_get_missing(self): -# context = Context() -# with ShouldRaise(ResourceError("No 'foo' in context", 'foo')): -# context.get('foo') -# # def test_nest(self): # c1 = Context() # c1.add('a', provides='a') From 7810b7b49cdba3bf8255b564c9569020c200782f Mon Sep 17 00:00:00 2001 From: Chris Withers Date: Wed, 23 Sep 2020 09:04:04 +0100 Subject: [PATCH 114/159] Implement providers --- mush/context.py | 96 ++++++++----- mush/resources.py | 30 +++- mush/tests/test_context.py | 271 +++++++++++++++++++++++++++++++------ 3 files changed, 317 insertions(+), 80 deletions(-) diff --git a/mush/context.py b/mush/context.py index c482b80..ba31765 100644 --- a/mush/context.py +++ b/mush/context.py @@ -1,10 +1,11 @@ -from typing import Optional, Callable, Hashable, Type, Sequence +from inspect import signature +from typing import Optional, Callable, Hashable, Type, Union, Mapping, Any, Dict +from .requirements import Requirement from .declarations import RequiresType from .extraction import extract_requires from .markers import missing, Marker -from .requirements import Requirement -from .resources import ResourceKey, Resource +from .resources import ResourceKey, Resource, Provider from .typing import ResourceValue NONE_TYPE = type(None) @@ -25,29 +26,47 @@ class Context: def __init__(self): self._store = {} - self._seen_types = set() - self._seen_identifiers = set() - self._requires_cache = {} + # self._requires_cache = {} # self._returns_cache = {} def add(self, - resource: ResourceValue, + obj: Union[Provider, ResourceValue], provides: Optional[Type] = missing, identifier: Hashable = None): """ Add a resource to the context. Optionally specify what the resource provides. + + ``provides`` can be explicitly specified as ``None`` to only register against the identifier """ - if provides is missing: - provides = type(resource) - to_add = [ResourceKey(provides, identifier)] - if identifier and provides: + if isinstance(obj, Provider): + resource = obj + if provides is missing: + sig = signature(obj.provider) + annotation = sig.return_annotation + if annotation is sig.empty: + if identifier is None: + raise ResourceError( + f'Could not determine what is provided by {obj.provider}' + ) + else: + provides = annotation + + else: + resource = Resource(obj) + if provides is missing: + provides = type(obj) + + to_add = [] + if provides is not missing: + to_add.append(ResourceKey(provides, identifier)) + if not (identifier is None or provides is None): to_add.append(ResourceKey(None, identifier)) for key in to_add: if key in self._store: raise ResourceError(f'Context already contains {key}') - self._store[key] = Resource(resource) + self._store[key] = resource # def remove(self, key: ResourceKey, *, strict: bool = True): # """ @@ -83,27 +102,46 @@ def __repr__(self): # self._process(obj, result, returns) # return result - def _resolve(self, obj, requires, args, kw, context): + def _find_resource(self, key): + if not isinstance(key[0], type): + return self._store.get(key) + type_, identifier = key + exact = True + for type__ in type_.__mro__: + resource = self._store.get((type__, identifier)) + if resource is not None and (exact or resource.provides_subclasses): + return resource + exact = False + + def _resolve(self, obj, specials = None): + if specials is None: + specials: Dict[type, Any] = {Context: self} - if requires is None: - requires = self._requires_cache.get(obj) - if requires is None: - requires = extract_requires(obj) - self._requires_cache[obj] = requires + requires = extract_requires(obj) - specials = {Context: self} + args = [] + kw = {} for requirement in requires: o = missing for key in requirement.keys: - # how to handle context and requirement here?! - resource = self._store.get(key) + + resource = self._find_resource(key) + if resource is None: o = specials.get(key[0], missing) else: - o = resource.obj + if resource.obj is missing: + specials_ = specials.copy() + specials_[Requirement] = requirement + o = self._resolve(resource.provider, specials_) + if resource.cache: + resource.obj = o + else: + o = resource.obj + if o is not missing: break @@ -122,20 +160,12 @@ def _resolve(self, obj, requires, args, kw, context): # if requirement.target is None: args.append(o) - # else: - # kw[requirement.target] = o - # - # yield - def call(self, obj: Callable, requires: RequiresType = None): - args = [] - kw = {} + return obj(*args, **kw) - self._resolve(obj, requires, args, kw, self) - # for requirement in resolving: - # resolving.send(requirement.resolve(self)) + def call(self, obj: Callable, requires: RequiresType = None): + return self._resolve(obj) - return obj(*args, **kw) # # def nest(self, requirement_modifier: RequirementModifier = None): # if requirement_modifier is None: diff --git a/mush/resources.py b/mush/resources.py index 4ec9f95..66b541e 100644 --- a/mush/resources.py +++ b/mush/resources.py @@ -1,3 +1,9 @@ +from typing import Callable, Optional + +from .markers import missing +from .typing import ResourceValue + + class ResourceKey(tuple): def __new__(cls, type_, identifier): @@ -19,14 +25,28 @@ def __str__(self): return f'{self.type!r}, {self.identifier!r}' -class Provider: - pass - - class Resource: - def __init__(self, obj): + provider: Optional[Callable] = None + provides_subclasses: bool = False + + def __init__(self, obj: ResourceValue): self.obj = obj def __repr__(self): return repr(self.obj) + + +class Provider(Resource): + + def __init__(self, obj: Callable, *, cache: bool = True, provides_subclasses: bool = False): + super().__init__(missing) + self.provider = obj + self.cache = cache + self.provides_subclasses = provides_subclasses + + def __repr__(self): + obj_repr = '' if self.obj is missing else f'cached={self.obj!r}, ' + return (f'Provider({self.provider}, {obj_repr}' + f'cache={self.cache}, ' + f'provides_subclasses={self.provides_subclasses})') diff --git a/mush/tests/test_context.py b/mush/tests/test_context.py index 2aa21ad..675e139 100644 --- a/mush/tests/test_context.py +++ b/mush/tests/test_context.py @@ -1,21 +1,24 @@ # from typing import Tuple, List # +from typing import NewType + from testfixtures import ShouldRaise, compare + # from testfixtures.mock import Mock # from mush import ( - Context#, requires, returns, returns_mapping, Value, missing + Context, Requirement # , requires, returns, returns_mapping, Value, missing ) from mush.context import ResourceError # from mush.declarations import RequiresType, requires_nothing, returns_nothing # from mush.requirements import Requirement -from .helpers import TheType -from ..resources import Resource +from .helpers import TheType, Type1, Type2 +from ..resources import Resource, Provider -class TestContext(object): +class TestAdd: - def test_add_by_inferred_type(self): + def test_by_inferred_type(self): obj = TheType() context = Context() context.add(obj) @@ -29,7 +32,7 @@ def test_add_by_inferred_type(self): compare(expected, actual=repr(context)) compare(expected, actual=str(context)) - def test_add_by_identifier(self): + def test_by_identifier(self): obj = TheType() context = Context() context.add(obj, identifier='my label') @@ -45,7 +48,7 @@ def test_add_by_identifier(self): compare(expected, actual=repr(context)) compare(expected, actual=str(context)) - def test_add_by_identifier_only(self): + def test_by_identifier_only(self): obj = TheType() context = Context() context.add(obj, provides=None, identifier='my label') @@ -101,14 +104,17 @@ def test_clash_identifier_plus_type_with_identifier_only(self): with ShouldRaise(ResourceError("Context already contains 'my label'")): context.add(obj2, provides=None, identifier='my label') - def test_call_no_params(self): + +class TestCall: + + def test_no_params(self): def foo(): return 'bar' context = Context() result = context.call(foo) compare(result, 'bar') - def test_call_type_from_annotation(self): + def test_type_from_annotation(self): def foo(baz: str): return baz context = Context() @@ -116,7 +122,7 @@ def foo(baz: str): result = context.call(foo) compare(result, expected='bar') - def test_call_identifier_from_annotation(self): + def test_identifier_from_annotation(self): def foo(baz: str): return baz context = Context() @@ -125,7 +131,7 @@ def foo(baz: str): result = context.call(foo) compare(result, expected='bob') - def test_call_by_identifier_only(self): + def test_by_identifier_only(self): def foo(param): return param @@ -134,7 +140,7 @@ def foo(param): result = context.call(foo) compare(result, 'bar') - def test_call_requires_missing(self): + def test_requires_missing(self): def foo(obj: TheType): return obj context = Context() with ShouldRaise(ResourceError( @@ -142,7 +148,7 @@ def foo(obj: TheType): return obj )): context.call(foo) - def test_call_optional_type_present(self): + def tes_optional_type_present(self): def foo(x: TheType = 1): return x context = Context() @@ -150,14 +156,14 @@ def foo(x: TheType = 1): result = context.call(foo) compare(result, 2) - def test_call_optional_type_missing(self): + def test_optional_type_missing(self): def foo(x: TheType = 1): return x context = Context() result = context.call(foo) compare(result, 1) - def test_call_optional_identifier_present(self): + def test_optional_identifier_present(self): def foo(x=1): return x @@ -166,7 +172,7 @@ def foo(x=1): result = context.call(foo) compare(result, 2) - def test_call_optional_identifier_missing(self): + def test_optional_identifier_missing(self): def foo(x=1): return x @@ -175,7 +181,7 @@ def foo(x=1): result = context.call(foo) compare(result, 1) - def test_call_requires_context(self): + def test_requires_context(self): context = Context() def return_context(context_: Context): @@ -183,7 +189,7 @@ def return_context(context_: Context): assert context.call(return_context) is context - def test_call_requires_requirement(self): + def test_base_class_should_not_match(self): # this should blow up unless we're in a provider? pass @@ -311,6 +317,9 @@ def test_call_requires_requirement(self): # compare(result, ('a',)) # compare(context._requires_cache, expected={}) # + + # XXX extract + # def test_extract_minimal(self): # o = TheType() # def foo() -> TheType: @@ -377,6 +386,9 @@ def test_call_requires_requirement(self): # compare(result, expected=None) # compare(context._store, expected={}) # + + # XXX - remove + # def test_remove(self): # context = Context() # context.add('foo') @@ -395,6 +407,8 @@ def test_call_requires_requirement(self): # context.remove('foo', strict=False) # compare(context._store, expected={}) # +# XXX - nest +# # def test_nest(self): # c1 = Context() # c1.add('a', provides='a') @@ -427,8 +441,205 @@ def test_call_requires_requirement(self): # c2 = c1.nest() # assert c2._requires_cache is c1._requires_cache # assert c2._returns_cache is c1._returns_cache +# - XXX nesting versus cached providers! + +class TestProviders: + + def test_cached(self): + items = [] + + def provider(): + items.append(1) + return sum(items) + + context = Context() + context.add(Provider(provider), provides=int) + + def returner(obj: int): + return obj + + compare(context.call(returner), expected=1) + compare(context.call(returner), expected=1) + + def test_not_cached(self): + items = [] + + def provider(): + items.append(1) + return sum(items) + + context = Context() + context.add(Provider(provider, cache=False), provides=int) + + def returner(obj: int): + return obj + + compare(context.call(returner), expected=1) + compare(context.call(returner), expected=2) + + def test_needs_resources(self): + def provider(start: int): + return start*2 + + context = Context() + context.add(Provider(provider), provides=int) + context.add(4, identifier='start') + + def returner(obj: int): + return obj + + compare(context.call(returner), expected=8) + + def test_needs_requirement(self): + def provider(requirement: Requirement): + return requirement.keys[0].identifier + + context = Context() + context.add(Provider(provider), provides=str) + + def returner(obj: str): + return obj + + compare(context.call(returner), expected='obj') + + def test_provides_subclasses(self): + class Base: pass + + class TheType(Base): pass + + def provider(requirement: Requirement): + return requirement.keys[0].type() + + def foo(bar: TheType): + return bar + + context = Context() + context.add(Provider(provider, provides_subclasses=True), provides=Base) + + assert isinstance(context.call(foo), TheType) + + def test_does_not_provide_subclasses(self): + def foo(obj: TheType): pass + + context = Context() + context.add(Provider(lambda: None), provides=object) + + with ShouldRaise(ResourceError( + "Value(, 'obj') could not be satisfied" + )): + context.call(foo) + + def test_multiple_providers_using_requirement(self): + def provider(requirement: Requirement): + return requirement.keys[0].type() + + def foo(t1: Type1, t2: Type2): + return t1, t2 + + context = Context() + context.add(Provider(provider), provides=Type1) + context.add(Provider(provider), provides=Type2) + + t1, t2 = context.call(foo) + assert isinstance(t1, Type1) + assert isinstance(t2, Type2) + + def test_nested_providers_using_requirement(self): + class Base1: pass + + class Type1(Base1): pass + def provider1(requirement: Requirement): + return requirement.keys[0].type() + class Base2: + def __init__(self, x): + self.x = x + + class Type2(Base2): pass + + # order here is important + def provider2(t1: Type1, requirement: Requirement): + return requirement.keys[0].type(t1) + + def foo(t2: Type2): + return t2 + + context = Context() + context.add(Provider(provider1, provides_subclasses=True), provides=Base1) + context.add(Provider(provider2, provides_subclasses=True), provides=Base2) + + t2 = context.call(foo) + assert isinstance(t2, Type2) + assert isinstance(t2.x, Type1) + + def test_from_return_type_annotation(self): + def provider() -> Type1: + return Type1() + + context = Context() + context.add(Provider(provider)) + + def returner(obj: Type1): + return obj + + assert isinstance(context.call(returner), Type1) + + def test_no_provides(self): + def provider(): pass + context = Context() + with ShouldRaise(ResourceError(f'Could not determine what is provided by {provider}')): + context.add(Provider(provider)) + + def test_identifier(self): + def provider() -> str: + return 'some foo' + + context = Context() + context.add(Provider(provider), identifier='param') + + def foo(param): + return param + + compare(context.call(foo), expected='some foo') + + def test_identifier_only(self): + def provider(): + return 'some foo' + + context = Context() + context.add(Provider(provider), identifier='param') + + def foo(param): + return param + + compare(context.call(foo), expected='some foo') + + def test_minimal_representation(self): + def provider(): pass + context = Context() + context.add(Provider(provider), provides=str) + expected = (": Provider({provider}, " + f"cache=True, provides_subclasses=False)\n" + "}>") + compare(expected, actual=repr(context)) + compare(expected, actual=str(context)) + + def test_maximal_representation(self): + def provider() -> str: pass + p = Provider(provider, cache=False, provides_subclasses=True) + p.obj = 'it' + context = Context() + context.add(p, provides=str, identifier='the id') + expected = (", 'the id': Provider({provider}, " + f"cached='it', cache=False, provides_subclasses=True)\n" + f" 'the id': Provider({provider}, " + f"cached='it', cache=False, provides_subclasses=True)\n" + "}>") + compare(expected, actual=repr(context)) + compare(expected, actual=str(context)) # XXX "custom requirement" stuff # @@ -460,27 +671,3 @@ def test_call_requires_requirement(self): # key='bar', # requirement=FromRequest.make(key='bar', name='bar'))): # compare(context.call(foo)) -# -# def test_default_custom_requirement(self): -# -# class FromRequest(Requirement): -# def resolve(self, context): -# return context.get('request')[self.key] -# -# def foo(bar): -# return bar -# -# def modifier(requirement): -# if type(requirement) is Requirement: -# requirement = FromRequest.make_from(requirement) -# return requirement -# -# context = Context(requirement_modifier=modifier) -# context.add({'bar': 'foo'}, provides='request') -# compare(context.call(foo), expected='foo') - - def test_provider(self): - pass - - def test_provider_needs_requirement(self): - pass From 891ac02ce4f63396adedf3c26b0d7abfbb65771a Mon Sep 17 00:00:00 2001 From: Chris Withers Date: Wed, 23 Sep 2020 22:14:09 +0100 Subject: [PATCH 115/159] Better repr for NewTypes and further typing tests. --- mush/resources.py | 10 +++++-- mush/tests/test_context.py | 56 ++++++++++++++++++++++++++++++++++++-- 2 files changed, 60 insertions(+), 6 deletions(-) diff --git a/mush/resources.py b/mush/resources.py index 66b541e..94d0fc5 100644 --- a/mush/resources.py +++ b/mush/resources.py @@ -20,9 +20,13 @@ def identifier(self): def __str__(self): if self.type is None: return repr(self.identifier) - elif self.identifier is None: - return repr(self.type) - return f'{self.type!r}, {self.identifier!r}' + if hasattr(self.type, '__supertype__'): + type_repr = f'NewType({self.type.__name__}, {self.type.__supertype__})' + else: + type_repr = repr(self.type) + if self.identifier is None: + return type_repr + return f'{type_repr}, {self.identifier!r}' class Resource: diff --git a/mush/tests/test_context.py b/mush/tests/test_context.py index 675e139..5a52c50 100644 --- a/mush/tests/test_context.py +++ b/mush/tests/test_context.py @@ -1,6 +1,6 @@ # from typing import Tuple, List # -from typing import NewType +from typing import NewType, Mapping, Any from testfixtures import ShouldRaise, compare @@ -190,8 +190,58 @@ def return_context(context_: Context): assert context.call(return_context) is context def test_base_class_should_not_match(self): - # this should blow up unless we're in a provider? - pass + def foo(obj: TheType): return obj + context = Context() + context.add(object()) + with ShouldRaise(ResourceError( + "Value(, 'obj') could not be satisfied" + )): + context.call(foo) + + def test_requires_typing(self): + Request = NewType('Request', dict) + context = Context() + request = {} + context.add(request, provides=Request) + + def returner(request_: Request): + return request_ + + assert context.call(returner) is request + + def test_requires_typing_missing_typing(self): + context = Context() + + def returner(request_: Mapping[str, Any]): + return request_ + + with ShouldRaise(ResourceError( + "Value(typing.Mapping[str, typing.Any], 'request_') could not be satisfied" + )): + context.call(returner) + + def test_requires_typing_missing_new_type(self): + Request = NewType('Request', dict) + context = Context() + + def returner(request_: Request): + return request_ + + with ShouldRaise(ResourceError( + "Value(NewType(Request, ), 'request_') could not be satisfied" + )): + context.call(returner) + + def test_requires_requirement(self): + context = Context() + + def foo(requirement: Requirement): pass + + with ShouldRaise(ResourceError( + "Value(, 'requirement') " + "could not be satisfied" + )): + context.call(foo) # XXX - these are for explicit requires() objects: # def test_call_requires_string(self): From 07b413bdc7deab98b4ff4de4e18c253a569bb5ce Mon Sep 17 00:00:00 2001 From: Chris Withers Date: Thu, 24 Sep 2020 07:45:05 +0100 Subject: [PATCH 116/159] add skips so the whole test suite can be run --- docs/conftest.py | 4 ++-- mush/tests/test_async_context.py | 6 ++++-- mush/tests/test_async_runner.py | 8 +++++--- mush/tests/test_callpoints.py | 4 +++- mush/tests/test_declarations.py | 1 + mush/tests/test_example_with_mush_clone.py | 4 +++- mush/tests/test_example_with_mush_factory.py | 4 +++- mush/tests/test_plug.py | 2 ++ mush/tests/test_requirements.py | 1 + mush/tests/test_requirements_extraction.py | 1 + mush/tests/test_runner.py | 1 + 11 files changed, 26 insertions(+), 10 deletions(-) diff --git a/docs/conftest.py b/docs/conftest.py index 291ab9e..187b090 100644 --- a/docs/conftest.py +++ b/docs/conftest.py @@ -15,5 +15,5 @@ ).pytest() -def pytest_collect_file(parent, path): - return sybil_collector(parent, path) +# def pytest_collect_file(parent, path): +# return sybil_collector(parent, path) diff --git a/mush/tests/test_async_context.py b/mush/tests/test_async_context.py index 22051ef..1b14c29 100644 --- a/mush/tests/test_async_context.py +++ b/mush/tests/test_async_context.py @@ -1,3 +1,5 @@ +import pytest; pytestmark = pytest.mark.skip("WIP") + import asyncio from functools import partial from typing import Tuple @@ -6,9 +8,9 @@ from testfixtures import compare, ShouldRaise from mush import Value, requires, returns, Context as SyncContext, blocking, nonblocking -from mush.asyncio import Context +# from mush.asyncio import Context from mush.declarations import RequiresType -from mush.requirements import Requirement, AnyOf, Like +# from mush.requirements import Requirement, AnyOf, Like from .helpers import TheType, no_threads, must_run_in_thread from ..markers import AsyncType diff --git a/mush/tests/test_async_runner.py b/mush/tests/test_async_runner.py index b5d75b1..b9a2799 100644 --- a/mush/tests/test_async_runner.py +++ b/mush/tests/test_async_runner.py @@ -1,3 +1,5 @@ +import pytest; pytestmark = pytest.mark.skip("WIP") + import asyncio from testfixtures.mock import Mock, call @@ -5,7 +7,7 @@ from testfixtures import compare, ShouldRaise, Comparison as C from mush import ContextError, requires, returns -from mush.asyncio import Runner, Context +# from mush.asyncio import Runner, Context from .helpers import no_threads, must_run_in_thread @@ -67,7 +69,7 @@ class AsyncCM(CommonCM): async def __aenter__(self): self.m.enter() - if self.context is 'self': + if self.context == 'self': return self return self.context @@ -80,7 +82,7 @@ class SyncCM(CommonCM): def __enter__(self): self.m.enter() - if self.context is 'self': + if self.context == 'self': return self return self.context diff --git a/mush/tests/test_callpoints.py b/mush/tests/test_callpoints.py index 789af2b..8b00fa4 100644 --- a/mush/tests/test_callpoints.py +++ b/mush/tests/test_callpoints.py @@ -1,3 +1,5 @@ +import pytest; pytestmark = pytest.mark.skip("WIP") + from functools import update_wrapper from unittest import TestCase @@ -6,7 +8,7 @@ from mush.callpoints import CallPoint from mush.declarations import requires, returns, RequiresType -from mush.extraction import update_wrapper +# from mush.extraction import update_wrapper from mush.requirements import Value from mush.runner import Runner diff --git a/mush/tests/test_declarations.py b/mush/tests/test_declarations.py index d4d34f9..7dba6a7 100644 --- a/mush/tests/test_declarations.py +++ b/mush/tests/test_declarations.py @@ -1,3 +1,4 @@ +import pytest; pytestmark = pytest.mark.skip("WIP") from typing import Tuple from unittest import TestCase diff --git a/mush/tests/test_example_with_mush_clone.py b/mush/tests/test_example_with_mush_clone.py index eefcc22..0364779 100644 --- a/mush/tests/test_example_with_mush_clone.py +++ b/mush/tests/test_example_with_mush_clone.py @@ -1,4 +1,4 @@ -from .example_with_mush_clone import DatabaseHandler, main, do, setup_logging +# from .example_with_mush_clone import DatabaseHandler, main, do, setup_logging from unittest import TestCase from testfixtures import TempDirectory from testfixtures import Replacer @@ -6,6 +6,8 @@ from testfixtures import ShouldRaise import sqlite3 +import pytest; pytestmark = pytest.mark.skip("WIP") + class Tests(TestCase): def test_main(self): diff --git a/mush/tests/test_example_with_mush_factory.py b/mush/tests/test_example_with_mush_factory.py index 1687ac0..392207f 100644 --- a/mush/tests/test_example_with_mush_factory.py +++ b/mush/tests/test_example_with_mush_factory.py @@ -1,4 +1,6 @@ -from .example_with_mush_factory import main +# from .example_with_mush_factory import main +import pytest; pytestmark = pytest.mark.skip("WIP") + from unittest import TestCase from testfixtures import TempDirectory, Replacer import sqlite3 diff --git a/mush/tests/test_plug.py b/mush/tests/test_plug.py index 2826339..e45513e 100644 --- a/mush/tests/test_plug.py +++ b/mush/tests/test_plug.py @@ -1,3 +1,5 @@ +import pytest; pytestmark = pytest.mark.skip("WIP") + from unittest import TestCase from testfixtures import compare, ShouldRaise diff --git a/mush/tests/test_requirements.py b/mush/tests/test_requirements.py index d00a080..c34f0ed 100644 --- a/mush/tests/test_requirements.py +++ b/mush/tests/test_requirements.py @@ -1,3 +1,4 @@ +import pytest; pytestmark = pytest.mark.skip("WIP") from typing import Tuple from unittest.case import TestCase diff --git a/mush/tests/test_requirements_extraction.py b/mush/tests/test_requirements_extraction.py index a40f312..5a56c06 100644 --- a/mush/tests/test_requirements_extraction.py +++ b/mush/tests/test_requirements_extraction.py @@ -1,3 +1,4 @@ +import pytest; pytestmark = pytest.mark.skip("WIP") from functools import partial from typing import Tuple from unittest import TestCase diff --git a/mush/tests/test_runner.py b/mush/tests/test_runner.py index 757d991..d2e5747 100644 --- a/mush/tests/test_runner.py +++ b/mush/tests/test_runner.py @@ -1,3 +1,4 @@ +import pytest; pytestmark = pytest.mark.skip("WIP") from unittest import TestCase from mush.declarations import ( From 1ad898e9464bcce24227ad6a7145a027a16ed307 Mon Sep 17 00:00:00 2001 From: Chris Withers Date: Thu, 24 Sep 2020 07:55:29 +0100 Subject: [PATCH 117/159] pattern for test suites that will only compile on a minimum version of python --- mush/tests/conftest.py | 9 +++++++++ 1 file changed, 9 insertions(+) create mode 100644 mush/tests/conftest.py diff --git a/mush/tests/conftest.py b/mush/tests/conftest.py new file mode 100644 index 0000000..62cd59b --- /dev/null +++ b/mush/tests/conftest.py @@ -0,0 +1,9 @@ +import sys +from re import search + + +def pytest_ignore_collect(path): + file_min_version_match = search(r'_py(\d)(\d)$', path.purebasename) + if file_min_version_match: + file_min_version = tuple(int(d) for d in file_min_version_match.groups()) + return sys.version_info < file_min_version From 98bcda00c6e80f05f99d969cdcb384b2d31620e3 Mon Sep 17 00:00:00 2001 From: Chris Withers Date: Thu, 24 Sep 2020 08:01:42 +0100 Subject: [PATCH 118/159] keyword-only and positional-only arguments --- mush/context.py | 6 ++++-- mush/extraction.py | 10 +++++++--- mush/requirements.py | 2 +- mush/tests/test_context.py | 9 +++++++++ mush/tests/test_context_py38.py | 23 +++++++++++++++++++++++ 5 files changed, 44 insertions(+), 6 deletions(-) create mode 100644 mush/tests/test_context_py38.py diff --git a/mush/context.py b/mush/context.py index ba31765..3a7a78d 100644 --- a/mush/context.py +++ b/mush/context.py @@ -158,8 +158,10 @@ def _resolve(self, obj, specials = None): if o is missing: raise ResourceError(f'{requirement!r} could not be satisfied') - # if requirement.target is None: - args.append(o) + if requirement.target is None: + args.append(o) + else: + kw[requirement.target] = o return obj(*args, **kw) diff --git a/mush/extraction.py b/mush/extraction.py index fe2a1a1..b789450 100644 --- a/mush/extraction.py +++ b/mush/extraction.py @@ -16,7 +16,6 @@ from .markers import missing, get_mush from .typing import Requires, Returns -EMPTY = Parameter.empty #: For these types, prefer the name instead of the type. # SIMPLE_TYPES = (str, int, dict, list) # @@ -60,12 +59,12 @@ def extract_requires(obj: Callable) -> Iterable[Requirement]: # continue # name = p.name - if isinstance(p.annotation, type) and p.annotation is not EMPTY: + if p.annotation is not p.empty: type_ = p.annotation else: type_ = None - default = missing if p.default is EMPTY else p.default + default = missing if p.default is p.empty else p.default ops = [] requirement = Value(type_, p.name, default) @@ -106,6 +105,10 @@ def extract_requires(obj: Callable) -> Iterable[Requirement]: # if p.kind is p.KEYWORD_ONLY: # requirement.target = p.name # + + if p.kind is p.KEYWORD_ONLY: + requirement.target = p.name + by_name[name] = requirement # # by_index = list(by_name) @@ -139,6 +142,7 @@ def extract_requires(obj: Callable) -> Iterable[Requirement]: # elif needs_target: # requirement.target = requirement.name # + return by_name.values() # return RequiresType(by_name.values()) diff --git a/mush/requirements.py b/mush/requirements.py index c24a864..6682a6a 100644 --- a/mush/requirements.py +++ b/mush/requirements.py @@ -50,7 +50,7 @@ def __init__(self, default: Any, *keys: ResourceKey): self.keys: Sequence[ResourceKey] = keys self.default = default self.ops: List['Op'] = [] - # self.target: Optional[str] = target + self.target: Optional[str] = None def _keys_repr(self): return ', '.join(repr(key) for key in self.keys) diff --git a/mush/tests/test_context.py b/mush/tests/test_context.py index 5a52c50..29c5dad 100644 --- a/mush/tests/test_context.py +++ b/mush/tests/test_context.py @@ -243,6 +243,15 @@ def foo(requirement: Requirement): pass )): context.call(foo) + def test_keyword_only(self): + def foo(*, x: int): + return x + + context = Context() + context.add(2) + result = context.call(foo) + compare(result, expected=2) + # XXX - these are for explicit requires() objects: # def test_call_requires_string(self): # def foo(obj): diff --git a/mush/tests/test_context_py38.py b/mush/tests/test_context_py38.py new file mode 100644 index 0000000..a3dc935 --- /dev/null +++ b/mush/tests/test_context_py38.py @@ -0,0 +1,23 @@ +from testfixtures import compare + +from mush import Context + + +class TestCall: + + def test_positional_only(self): + def foo(x:int, /): + return x + + context = Context() + context.add(2) + result = context.call(foo) + compare(result, expected=2) + + def test_positional_only_with_default(self): + def foo(x:int = 1, /): + return x + + context = Context() + result = context.call(foo) + compare(result, expected=1) From 50d5bc15811bed9174bb98e63e7653b4d33362e3 Mon Sep 17 00:00:00 2001 From: Chris Withers Date: Fri, 25 Sep 2020 07:25:42 +0100 Subject: [PATCH 119/159] Lose the Type suffix --- mush/asyncio.py | 12 ++--- mush/context.py | 4 +- mush/declarations.py | 18 ++++--- mush/extraction.py | 2 +- mush/tests/test_async_context.py | 6 +-- mush/tests/test_callpoints.py | 16 +++---- mush/tests/test_requirements_extraction.py | 56 +++++++++++----------- mush/typing.py | 6 +-- 8 files changed, 62 insertions(+), 58 deletions(-) diff --git a/mush/asyncio.py b/mush/asyncio.py index bfee086..2d7f994 100644 --- a/mush/asyncio.py +++ b/mush/asyncio.py @@ -5,7 +5,7 @@ from . import ( Context as SyncContext, Runner as SyncRunner, ResourceError, ContextError ) -from .declarations import RequiresType, ReturnsType +from .declarations import Requirements, Return from .extraction import default_requirement_type from .markers import get_mush, AsyncType from .typing import RequirementModifier @@ -20,12 +20,12 @@ def __init__(self, context, loop): self.add = context.add self.get = context.get - def call(self, obj: Callable, requires: RequiresType = None): + def call(self, obj: Callable, requires: Requirements = None): coro = self.context.call(obj, requires) future = asyncio.run_coroutine_threadsafe(coro, self.loop) return future.result() - def extract(self, obj: Callable, requires: RequiresType = None, returns: ReturnsType = None): + def extract(self, obj: Callable, requires: Requirements = None, returns: Return = None): coro = self.context.extract(obj, requires, returns) future = asyncio.run_coroutine_threadsafe(coro, self.loop) return future.result() @@ -70,7 +70,7 @@ async def _ensure_async(self, func, *args, **kw): def _context_for(self, obj): return self if asyncio.iscoroutinefunction(obj) else self._sync_context - async def call(self, obj: Callable, requires: RequiresType = None): + async def call(self, obj: Callable, requires: Requirements = None): args = [] kw = {} resolving = self._resolve(obj, requires, args, kw, self._context_for(obj)) @@ -82,8 +82,8 @@ async def call(self, obj: Callable, requires: RequiresType = None): async def extract(self, obj: Callable, - requires: RequiresType = None, - returns: ReturnsType = None): + requires: Requirements = None, + returns: Return = None): result = await self.call(obj, requires) self._process(obj, result, returns) return result diff --git a/mush/context.py b/mush/context.py index 3a7a78d..2dbacb5 100644 --- a/mush/context.py +++ b/mush/context.py @@ -2,7 +2,7 @@ from typing import Optional, Callable, Hashable, Type, Union, Mapping, Any, Dict from .requirements import Requirement -from .declarations import RequiresType +from .declarations import Requirements from .extraction import extract_requires from .markers import missing, Marker from .resources import ResourceKey, Resource, Provider @@ -165,7 +165,7 @@ def _resolve(self, obj, specials = None): return obj(*args, **kw) - def call(self, obj: Callable, requires: RequiresType = None): + def call(self, obj: Callable): return self._resolve(obj) # diff --git a/mush/declarations.py b/mush/declarations.py index ed2fa02..fa566f4 100644 --- a/mush/declarations.py +++ b/mush/declarations.py @@ -23,7 +23,11 @@ def valid_decoration_types(*objs): ) -class RequiresType(list): +class Parameter: + __slots__ = 'target', 'requirement' + + +class Requirements(list): def __repr__(self): parts = (repr(r) if r.target is None else f'{r.target}={r!r}' @@ -46,7 +50,7 @@ def requires(*args: RequirementType, **kw: RequirementType): String names for resources must be used instead of types where the callable returning those resources is configured to return the named resource. """ - requires_ = RequiresType() + requires_ = Requirements() valid_decoration_types(*args) valid_decoration_types(*kw.values()) for target, possible in chain( @@ -65,10 +69,10 @@ def requires(*args: RequirementType, **kw: RequirementType): return requires_ -requires_nothing = RequiresType() +requires_nothing = Requirements() -class ReturnsType(object): +class Return(object): def __call__(self, obj): set_mush(obj, 'returns', self) @@ -78,7 +82,7 @@ def __repr__(self): return self.__class__.__name__ + '()' -class returns(ReturnsType): +class returns(Return): """ Declaration that specifies names for returned resources or overrides the type of a returned resource. @@ -105,7 +109,7 @@ def __repr__(self): return self.__class__.__name__ + '(' + args_repr + ')' -class returns_result_type(ReturnsType): +class returns_result_type(Return): """ Default declaration that indicates a callable's return value should be used as a resource based on the type of the object returned. @@ -118,7 +122,7 @@ def process(self, obj): yield obj.__class__, obj -class returns_mapping(ReturnsType): +class returns_mapping(Return): """ Declaration that indicates a callable returns a mapping of type or name to resource. diff --git a/mush/extraction.py b/mush/extraction.py index b789450..3f2abd4 100644 --- a/mush/extraction.py +++ b/mush/extraction.py @@ -8,7 +8,7 @@ from typing import Callable, Iterable from .declarations import ( - requires, RequiresType, ReturnsType, + requires, Requires, Returns, returns, result_type, requires_nothing ) diff --git a/mush/tests/test_async_context.py b/mush/tests/test_async_context.py index 1b14c29..4a43799 100644 --- a/mush/tests/test_async_context.py +++ b/mush/tests/test_async_context.py @@ -9,7 +9,7 @@ from mush import Value, requires, returns, Context as SyncContext, blocking, nonblocking # from mush.asyncio import Context -from mush.declarations import RequiresType +from mush.declarations import Requirements # from mush.requirements import Requirement, AnyOf, Like from .helpers import TheType, no_threads, must_run_in_thread from ..markers import AsyncType @@ -165,7 +165,7 @@ async def test_call_cache_requires(): context = Context() def foo(): pass await context.call(foo) - compare(context._requires_cache[foo], expected=RequiresType()) + compare(context._requires_cache[foo], expected=Requirements()) @pytest.mark.asyncio @@ -219,7 +219,7 @@ def foo() -> TheType: result = await context.extract(foo) assert result is o compare({TheType: o}, actual=context._store) - compare(context._requires_cache[foo], expected=RequiresType()) + compare(context._requires_cache[foo], expected=Requirements()) compare(context._returns_cache[foo], expected=returns(TheType)) diff --git a/mush/tests/test_callpoints.py b/mush/tests/test_callpoints.py index 8b00fa4..b1687ec 100644 --- a/mush/tests/test_callpoints.py +++ b/mush/tests/test_callpoints.py @@ -7,7 +7,7 @@ from testfixtures.mock import Mock, call from mush.callpoints import CallPoint -from mush.declarations import requires, returns, RequiresType +from mush.declarations import requires, returns, Requirements # from mush.extraction import update_wrapper from mush.requirements import Value from mush.runner import Runner @@ -34,7 +34,7 @@ def foo(a1): pass compare(result, self.context.extract.return_value) compare(self.context.extract.mock_calls, expected=[call(foo, - RequiresType([Value.make(key='foo', name='a1')]), + Requirements([Value.make(key='foo', name='a1')]), rt)]) def test_extract_from_decorations(self): @@ -49,7 +49,7 @@ def foo(a1): pass compare(result, self.context.extract.return_value) compare(self.context.extract.mock_calls, expected=[call(foo, - RequiresType([Value.make(key='foo', name='a1')]), + Requirements([Value.make(key='foo', name='a1')]), returns('bar'))]) def test_extract_from_decorated_class(self): @@ -75,7 +75,7 @@ def foo(prefix): self.context.extract.side_effect = lambda func, rq, rt: (func(), rq, rt) result = CallPoint(self.runner, foo)(self.context) compare(result, expected=('the answer', - RequiresType([Value.make(key='foo', name='prefix')]), + Requirements([Value.make(key='foo', name='prefix')]), rt)) def test_explicit_trumps_decorators(self): @@ -88,7 +88,7 @@ def foo(a1): pass compare(result, self.context.extract.return_value) compare(self.context.extract.mock_calls, expected=[call(foo, - RequiresType([Value.make(key='baz', name='a1')]), + Requirements([Value.make(key='baz', name='a1')]), returns('bob'))]) def test_repr_minimal(self): @@ -107,7 +107,7 @@ def foo(a1): pass def test_convert_to_requires_and_returns(self): def foo(baz): pass point = CallPoint(self.runner, foo, requires='foo', returns='bar') - self.assertTrue(isinstance(point.requires, RequiresType)) + self.assertTrue(isinstance(point.requires, Requirements)) self.assertTrue(isinstance(point.returns, returns)) compare(repr(foo)+" requires(Value('foo')) returns('bar')", repr(point)) @@ -118,7 +118,7 @@ def foo(a1, a2): pass foo, requires=('foo', 'bar'), returns=('baz', 'bob')) - self.assertTrue(isinstance(point.requires, RequiresType)) + self.assertTrue(isinstance(point.requires, Requirements)) self.assertTrue(isinstance(point.returns, returns)) compare(repr(foo)+" requires(Value('foo'), Value('bar')) returns('baz', 'bob')", repr(point)) @@ -129,7 +129,7 @@ def foo(a1, a2): pass foo, requires=['foo', 'bar'], returns=['baz', 'bob']) - self.assertTrue(isinstance(point.requires, RequiresType)) + self.assertTrue(isinstance(point.requires, Requirements)) self.assertTrue(isinstance(point.returns, returns)) compare(repr(foo)+" requires(Value('foo'), Value('bar')) returns('baz', 'bob')", repr(point)) diff --git a/mush/tests/test_requirements_extraction.py b/mush/tests/test_requirements_extraction.py index 5a56c06..d85e31c 100644 --- a/mush/tests/test_requirements_extraction.py +++ b/mush/tests/test_requirements_extraction.py @@ -11,7 +11,7 @@ requires, returns, returns_mapping, returns_sequence, returns_result_type, requires_nothing, - result_type, RequiresType + result_type, Requirements ) from mush.extraction import extract_requires#, extract_returns, update_wrapper from mush.requirements import Requirement, ItemOp @@ -30,7 +30,7 @@ class TestRequirementsExtraction(object): def test_default_requirements_for_function(self): def foo(a, b=None): pass check_extract(foo, - expected_rq=RequiresType(( + expected_rq=Requirements(( Value.make(key='a', name='a'), Value.make(key='b', default=None, name='b'), )), @@ -40,7 +40,7 @@ def test_default_requirements_for_class(self): class MyClass(object): def __init__(self, a, b=None): pass check_extract(MyClass, - expected_rq=RequiresType(( + expected_rq=Requirements(( Value.make(key='a', name='a'), Value.make(key='b', name='b', default=None), )), @@ -51,7 +51,7 @@ def foo(x, y, z, a=None): pass p = partial(foo, 1, y=2) check_extract( p, - expected_rq=RequiresType(( + expected_rq=Requirements(( Value.make(key='z', name='z', target='z'), Value.make(key='a', name='a', target='a', default=None), )), @@ -63,7 +63,7 @@ def foo(a=None): pass p = partial(foo) check_extract( p, - expected_rq=RequiresType(( + expected_rq=Requirements(( Value.make(key='a', name='a', default=None), )), expected_rt=result_type @@ -112,7 +112,7 @@ def foo(b, a=None): pass p = partial(foo) check_extract( p, - expected_rq=RequiresType(( + expected_rq=Requirements(( Value.make(key='b', name='b'), Value.make(key='a', name='a', default=None), )), @@ -125,7 +125,7 @@ def foo(b, a): pass check_extract( p, # since b is already bound: - expected_rq=RequiresType(( + expected_rq=Requirements(( Value.make(key='a', name='a'), )), expected_rt=result_type @@ -136,7 +136,7 @@ def foo(b, a): pass p = partial(foo, a=1) check_extract( p, - expected_rq=RequiresType(( + expected_rq=Requirements(( Value.make(key='b', name='b'), )), expected_rt=result_type @@ -148,7 +148,7 @@ class TestExtractDeclarationsFromTypeAnnotations(object): def test_extract_from_annotations(self): def foo(a: 'foo', b, c: 'bar' = 1, d=2) -> 'bar': pass check_extract(foo, - expected_rq=RequiresType(( + expected_rq=Requirements(( Value.make(key='foo', name='a'), Value.make(key='b', name='b'), Value.make(key='bar', name='c', default=1), @@ -159,7 +159,7 @@ def foo(a: 'foo', b, c: 'bar' = 1, d=2) -> 'bar': pass def test_requires_only(self): def foo(a: 'foo'): pass check_extract(foo, - expected_rq=RequiresType((Value.make(key='foo', name='a'),)), + expected_rq=Requirements((Value.make(key='foo', name='a'),)), expected_rt=result_type) def test_returns_only(self): @@ -185,7 +185,7 @@ def foo(a: 'foo' = None) -> 'bar': compare(foo(), expected='the answer') check_extract(foo, - expected_rq=RequiresType((Value.make(key='foo', name='a', default=None),)), + expected_rq=Requirements((Value.make(key='foo', name='a', default=None),)), expected_rt=returns('bar')) def test_decorator_trumps_annotations(self): @@ -193,7 +193,7 @@ def test_decorator_trumps_annotations(self): @returns('bar') def foo(a: 'x') -> 'y': pass check_extract(foo, - expected_rq=RequiresType((Value.make(key='foo', name='a'),)), + expected_rq=Requirements((Value.make(key='foo', name='a'),)), expected_rt=returns('bar')) def test_returns_mapping(self): @@ -213,7 +213,7 @@ def foo() -> rt: pass def test_how_instance_in_annotations(self): def foo(a: Value('config')['db_url']): pass check_extract(foo, - expected_rq=RequiresType(( + expected_rq=Requirements(( Value.make(key='config', name='a', ops=[ItemOp('db_url')]), )), expected_rt=result_type) @@ -221,7 +221,7 @@ def foo(a: Value('config')['db_url']): pass def test_default_requirements(self): def foo(a, b=1, *, c, d=None): pass check_extract(foo, - expected_rq=RequiresType(( + expected_rq=Requirements(( Value.make(key='a', name='a'), Value.make(key='b', name='b', default=1), Value.make(key='c', name='c', target='c'), @@ -233,26 +233,26 @@ def test_type_only(self): class T: pass def foo(a: T): pass check_extract(foo, - expected_rq=RequiresType((Value.make(key=T, name='a', type=T),)), + expected_rq=Requirements((Value.make(key=T, name='a', type=T),)), expected_rt=result_type) @pytest.mark.parametrize("type_", [str, int, dict, list]) def test_simple_type_only(self, type_): def foo(a: type_): pass check_extract(foo, - expected_rq=RequiresType((Value.make(key='a', name='a', type=type_),)), + expected_rq=Requirements((Value.make(key='a', name='a', type=type_),)), expected_rt=result_type) def test_type_plus_value(self): def foo(a: str = Value('b')): pass check_extract(foo, - expected_rq=RequiresType((Value.make(key='b', name='a', type=str),)), + expected_rq=Requirements((Value.make(key='b', name='a', type=str),)), expected_rt=result_type) def test_type_plus_value_with_default(self): def foo(a: str = Value('b', default=1)): pass check_extract(foo, - expected_rq=RequiresType(( + expected_rq=Requirements(( Value.make(key='b', name='a', type=str, default=1), )), expected_rt=result_type) @@ -260,7 +260,7 @@ def foo(a: str = Value('b', default=1)): pass def test_value_annotation_plus_default(self): def foo(a: Value('b', type_=str) = 1): pass check_extract(foo, - expected_rq=RequiresType(( + expected_rq=Requirements(( Value.make(key='b', name='a', type=str, default=1), )), expected_rt=result_type) @@ -268,7 +268,7 @@ def foo(a: Value('b', type_=str) = 1): pass def test_value_annotation_just_type_in_value_key_plus_default(self): def foo(a: Value(str) = 1): pass check_extract(foo, - expected_rq=RequiresType(( + expected_rq=Requirements(( Value.make(key=str, name='a', type=str, default=1), )), expected_rt=result_type) @@ -276,7 +276,7 @@ def foo(a: Value(str) = 1): pass def test_value_annotation_just_type_plus_default(self): def foo(a: Value(type_=str) = 1): pass check_extract(foo, - expected_rq=RequiresType(( + expected_rq=Requirements(( Value.make(key='a', name='a', type=str, default=1), )), expected_rt=result_type) @@ -285,19 +285,19 @@ def test_value_unspecified_with_type(self): class T1: pass def foo(a: T1 = Value()): pass check_extract(foo, - expected_rq=RequiresType((Value.make(key=T1, name='a', type=T1),)), + expected_rq=Requirements((Value.make(key=T1, name='a', type=T1),)), expected_rt=result_type) def test_value_unspecified_with_simple_type(self): def foo(a: str = Value()): pass check_extract(foo, - expected_rq=RequiresType((Value.make(key='a', name='a', type=str),)), + expected_rq=Requirements((Value.make(key='a', name='a', type=str),)), expected_rt=result_type) def test_value_unspecified(self): def foo(a=Value()): pass check_extract(foo, - expected_rq=RequiresType((Value.make(key='a', name='a'),)), + expected_rq=Requirements((Value.make(key='a', name='a'),)), expected_rt=result_type) def test_requirement_modifier(self): @@ -311,7 +311,7 @@ def modifier(requirement): return requirement rq = extract_requires(foo, modifier=modifier) - compare(rq, strict=True, expected=RequiresType(( + compare(rq, strict=True, expected=Requirements(( FromRequest(key='x', name='x', type_=str, default=None), ))) @@ -328,7 +328,7 @@ def foo(a: r1, b, c=r3): pass check_extract(foo, - expected_rq=RequiresType(( + expected_rq=Requirements(( Value.make(key='a', name='a'), Value.make(key='b', name='b', target='b'), Value.make(key='c', name='c', target='c'), @@ -345,7 +345,7 @@ def foo(a: r2 = r3, b: str = r2, c = r3): pass check_extract(foo, - expected_rq=RequiresType(( + expected_rq=Requirements(( Value.make(key='a', name='a', target='a'), Value.make(key='b', name='b', target='b', type=str), Value.make(key='c', name='c', target='c'), @@ -362,6 +362,6 @@ def foo(a): compare(actual=extract_requires(foo, requires(a=FromRequest('b'))), strict=True, - expected=RequiresType(( + expected=Requirements(( FromRequest.make(key='b', name='a', target='a'), ))) diff --git a/mush/typing.py b/mush/typing.py index e15be7f..58298af 100644 --- a/mush/typing.py +++ b/mush/typing.py @@ -2,17 +2,17 @@ if TYPE_CHECKING: from .context import Context - from .declarations import RequiresType, ReturnsType + from .declarations import Requirements, Return from .requirements import Requirement RequirementType = Union['Requirement', type, str] -Requires = Union['RequiresType', +Requires = Union['Requirements', RequirementType, List[RequirementType], Tuple[RequirementType, ...]] ReturnType = Union[type, str] -Returns = Union['ReturnsType', ReturnType, List[ReturnType], Tuple[ReturnType, ...]] +Returns = Union['Return', ReturnType, List[ReturnType], Tuple[ReturnType, ...]] ResourceValue = NewType('ResourceValue', Any) From ee4c5617878dd63ff5bf22c343d292ea8a5fb775 Mon Sep 17 00:00:00 2001 From: Chris Withers Date: Fri, 25 Sep 2020 08:18:05 +0100 Subject: [PATCH 120/159] Better typing --- mush/context.py | 15 +++++++-------- mush/resources.py | 16 ++++++++-------- mush/tests/test_context.py | 12 ++++++------ mush/typing.py | 7 +++---- 4 files changed, 24 insertions(+), 26 deletions(-) diff --git a/mush/context.py b/mush/context.py index 2dbacb5..9bd27ff 100644 --- a/mush/context.py +++ b/mush/context.py @@ -1,12 +1,11 @@ from inspect import signature -from typing import Optional, Callable, Hashable, Type, Union, Mapping, Any, Dict +from typing import Optional, Callable, Type, Union, Any, Dict -from .requirements import Requirement -from .declarations import Requirements from .extraction import extract_requires from .markers import missing, Marker -from .resources import ResourceKey, Resource, Provider -from .typing import ResourceValue +from .requirements import Requirement +from .resources import ResourceKey, ResourceValue, Provider +from .typing import Resource, Identifier NONE_TYPE = type(None) unspecified = Marker('unspecified') @@ -30,9 +29,9 @@ def __init__(self): # self._returns_cache = {} def add(self, - obj: Union[Provider, ResourceValue], + obj: Union[Provider, Resource], provides: Optional[Type] = missing, - identifier: Hashable = None): + identifier: Identifier = None): """ Add a resource to the context. @@ -54,7 +53,7 @@ def add(self, provides = annotation else: - resource = Resource(obj) + resource = ResourceValue(obj) if provides is missing: provides = type(obj) diff --git a/mush/resources.py b/mush/resources.py index 94d0fc5..d762380 100644 --- a/mush/resources.py +++ b/mush/resources.py @@ -1,20 +1,20 @@ -from typing import Callable, Optional +from typing import Callable, Optional, Type from .markers import missing -from .typing import ResourceValue +from .typing import Resource, Identifier class ResourceKey(tuple): - def __new__(cls, type_, identifier): + def __new__(cls, type_: Type, identifier: Identifier): return tuple.__new__(cls, (type_, identifier)) @property - def type(self): + def type(self) -> Type: return self[0] @property - def identifier(self): + def identifier(self) -> Identifier: return self[1] def __str__(self): @@ -29,19 +29,19 @@ def __str__(self): return f'{type_repr}, {self.identifier!r}' -class Resource: +class ResourceValue: provider: Optional[Callable] = None provides_subclasses: bool = False - def __init__(self, obj: ResourceValue): + def __init__(self, obj: Resource): self.obj = obj def __repr__(self): return repr(self.obj) -class Provider(Resource): +class Provider(ResourceValue): def __init__(self, obj: Callable, *, cache: bool = True, provides_subclasses: bool = False): super().__init__(missing) diff --git a/mush/tests/test_context.py b/mush/tests/test_context.py index 29c5dad..a061ba3 100644 --- a/mush/tests/test_context.py +++ b/mush/tests/test_context.py @@ -13,7 +13,7 @@ # from mush.declarations import RequiresType, requires_nothing, returns_nothing # from mush.requirements import Requirement from .helpers import TheType, Type1, Type2 -from ..resources import Resource, Provider +from ..resources import ResourceValue, Provider class TestAdd: @@ -23,7 +23,7 @@ def test_by_inferred_type(self): context = Context() context.add(obj) - compare(context._store, expected={(TheType, None): Resource(obj)}) + compare(context._store, expected={(TheType, None): ResourceValue(obj)}) expected = ( ": \n" @@ -38,8 +38,8 @@ def test_by_identifier(self): context.add(obj, identifier='my label') compare(context._store, expected={ - (TheType, 'my label'): Resource(obj), - (None, 'my label'): Resource(obj), + (TheType, 'my label'): ResourceValue(obj), + (None, 'my label'): ResourceValue(obj), }) expected = (", 'my label': \n" @@ -53,7 +53,7 @@ def test_by_identifier_only(self): context = Context() context.add(obj, provides=None, identifier='my label') - compare(context._store, expected={(None, 'my label'): Resource(obj)}) + compare(context._store, expected={(None, 'my label'): ResourceValue(obj)}) expected = ("\n" "}>") @@ -65,7 +65,7 @@ class T2(object): pass obj = TheType() context = Context() context.add(obj, provides=T2) - compare(context._store, expected={(T2, None): Resource(obj)}) + compare(context._store, expected={(T2, None): ResourceValue(obj)}) expected = ("\n" "}>") diff --git a/mush/typing.py b/mush/typing.py index 58298af..d2df70f 100644 --- a/mush/typing.py +++ b/mush/typing.py @@ -1,7 +1,6 @@ -from typing import NewType, Union, Hashable, Callable, Any, TYPE_CHECKING, List, Tuple +from typing import NewType, Union, Hashable, Any, TYPE_CHECKING, List, Tuple if TYPE_CHECKING: - from .context import Context from .declarations import Requirements, Return from .requirements import Requirement @@ -14,5 +13,5 @@ ReturnType = Union[type, str] Returns = Union['Return', ReturnType, List[ReturnType], Tuple[ReturnType, ...]] -ResourceValue = NewType('ResourceValue', Any) - +Resource = NewType('Resource', Any) +Identifier = Hashable From e114e88e626fcbfd9b110691dc114f0450e5ef93 Mon Sep 17 00:00:00 2001 From: Chris Withers Date: Fri, 25 Sep 2020 08:31:43 +0100 Subject: [PATCH 121/159] Add another layer so we should never have to modify a Requirement. --- mush/context.py | 9 +++++---- mush/declarations.py | 7 +++++-- mush/extraction.py | 47 +++++++++++++++++--------------------------- 3 files changed, 28 insertions(+), 35 deletions(-) diff --git a/mush/context.py b/mush/context.py index 9bd27ff..243b1c7 100644 --- a/mush/context.py +++ b/mush/context.py @@ -121,7 +121,8 @@ def _resolve(self, obj, specials = None): args = [] kw = {} - for requirement in requires: + for parameter in requires: + requirement = parameter.requirement o = missing @@ -145,7 +146,7 @@ def _resolve(self, obj, specials = None): break if o is missing: - o = requirement.default + o = parameter.default # if o is not requirement.default: # for op in requirement.ops: @@ -157,10 +158,10 @@ def _resolve(self, obj, specials = None): if o is missing: raise ResourceError(f'{requirement!r} could not be satisfied') - if requirement.target is None: + if parameter.target is None: args.append(o) else: - kw[requirement.target] = o + kw[parameter.target] = o return obj(*args, **kw) diff --git a/mush/declarations.py b/mush/declarations.py index fa566f4..a6c5945 100644 --- a/mush/declarations.py +++ b/mush/declarations.py @@ -1,6 +1,6 @@ from enum import Enum, auto from itertools import chain -from typing import _type_check +from typing import _type_check, Any from .markers import set_mush from .requirements import Requirement, Value, name_or_repr @@ -24,7 +24,10 @@ def valid_decoration_types(*objs): class Parameter: - __slots__ = 'target', 'requirement' + def __init__(self, requirement: Requirement, target:str, default: Any): + self.requirement = requirement + self.target = target + self.default = default class Requirements(list): diff --git a/mush/extraction.py b/mush/extraction.py index 3f2abd4..d46d2b4 100644 --- a/mush/extraction.py +++ b/mush/extraction.py @@ -4,17 +4,16 @@ update_wrapper as functools_update_wrapper, partial ) -from inspect import signature, Parameter +from inspect import signature from typing import Callable, Iterable from .declarations import ( - requires, Requires, Returns, + requires, Parameter, Requirements, Return, returns, result_type, requires_nothing ) from .requirements import Value, Requirement from .markers import missing, get_mush -from .typing import Requires, Returns #: For these types, prefer the name instead of the type. # SIMPLE_TYPES = (str, int, dict, list) @@ -46,8 +45,7 @@ # ) -def extract_requires(obj: Callable) -> Iterable[Requirement]: - # explicit: Requires = None): +def extract_requires(obj: Callable) -> Requirements: # from annotations by_name = {} for name, p in signature(obj).parameters.items(): @@ -59,13 +57,13 @@ def extract_requires(obj: Callable) -> Iterable[Requirement]: # continue # name = p.name + if p.annotation is not p.empty: type_ = p.annotation else: type_ = None default = missing if p.default is p.empty else p.default - ops = [] requirement = Value(type_, p.name, default) # @@ -94,22 +92,18 @@ def extract_requires(obj: Callable) -> Iterable[Requirement]: # else: # key = type_ # default = requirement.default if requirement.default is not missing else default - # ops = requirement.ops # # requirement.key = key # requirement.name = name # requirement.type = type_ # requirement.default = default - # requirement.ops = ops - # - # if p.kind is p.KEYWORD_ONLY: - # requirement.target = p.name - # - if p.kind is p.KEYWORD_ONLY: - requirement.target = p.name + by_name[name] = Parameter( + requirement, + target=p.name if p.kind is p.KEYWORD_ONLY else None, + default=requirement.default + ) - by_name[name] = requirement # # by_index = list(by_name) # @@ -130,21 +124,16 @@ def extract_requires(obj: Callable) -> Iterable[Requirement]: # # if not by_name: # return requires_nothing - # - # # sort out target and apply modifier: - # needs_target = False - # for name, requirement in by_name.items(): - # requirement_ = modifier(requirement) - # if requirement_ is not requirement: - # by_name[name] = requirement = requirement_ - # if requirement.target is not None: - # needs_target = True - # elif needs_target: - # requirement.target = requirement.name - # - return by_name.values() - # return RequiresType(by_name.values()) + # sort out target: + needs_target = False + for name, parameter in by_name.items(): + if parameter.target is not None: + needs_target = True + elif needs_target: + parameter.target = name + + return Requirements(by_name.values()) # def extract_returns(obj: Callable, explicit: Returns = None): From f95bda5dc083e8a246e2d5fe44eb6e8b502a43c7 Mon Sep 17 00:00:00 2001 From: Chris Withers Date: Mon, 28 Sep 2020 07:33:50 +0100 Subject: [PATCH 122/159] requirements extraction using parameter objects --- mush/declarations.py | 14 +- mush/extraction.py | 175 ++++++----------- mush/requirements.py | 22 ++- mush/resources.py | 2 +- ...ments_extraction.py => test_extraction.py} | 182 ++++++++---------- 5 files changed, 161 insertions(+), 234 deletions(-) rename mush/tests/{test_requirements_extraction.py => test_extraction.py} (60%) diff --git a/mush/declarations.py b/mush/declarations.py index a6c5945..7453ed8 100644 --- a/mush/declarations.py +++ b/mush/declarations.py @@ -2,7 +2,7 @@ from itertools import chain from typing import _type_check, Any -from .markers import set_mush +from .markers import set_mush, missing from .requirements import Requirement, Value, name_or_repr from .typing import RequirementType, ReturnType @@ -24,7 +24,7 @@ def valid_decoration_types(*objs): class Parameter: - def __init__(self, requirement: Requirement, target:str, default: Any): + def __init__(self, requirement: Requirement, target: str = None, default: Any = missing): self.requirement = requirement self.target = target self.default = default @@ -61,14 +61,10 @@ def requires(*args: RequirementType, **kw: RequirementType): kw.items(), ): if isinstance(possible, Requirement): - possible = possible.make_from(possible, target=target) - requirement = possible + parameter = Parameter(possible, target, default=possible.default) else: - requirement = Value(possible) - requirement.type = None if isinstance(possible, str) else possible - requirement.name = target - requirement.target = target - requires_.append(requirement) + parameter = Parameter(Value(possible), target) + requires_.append(parameter) return requires_ diff --git a/mush/extraction.py b/mush/extraction.py index d46d2b4..4656601 100644 --- a/mush/extraction.py +++ b/mush/extraction.py @@ -5,125 +5,77 @@ partial ) from inspect import signature -from typing import Callable, Iterable +from typing import Callable, get_type_hints from .declarations import ( - requires, Parameter, Requirements, Return, - returns, result_type, + Parameter, Requirements, Return, requires_nothing ) -from .requirements import Value, Requirement from .markers import missing, get_mush +from .requirements import Value, Requirement -#: For these types, prefer the name instead of the type. -# SIMPLE_TYPES = (str, int, dict, list) -# -# -# def _apply_requires(by_name, by_index, requires_): -# -# for i, r in enumerate(requires_): -# -# if r.target is None: -# try: -# name = by_index[i] -# except IndexError: -# # case where something takes *args -# by_name[i] = r.make_from(r) -# continue -# else: -# name = r.target -# -# existing = by_name[name] -# by_name[name] = r.make_from( -# r, -# name=existing.name, -# key=existing.key if r.key is None else r.key, -# type=existing.type if r.type is None else r.type, -# default=existing.default if r.default is missing else r.default, -# ops=existing.ops if not r.ops else r.ops, -# target=existing.target if r.target is None else r.target, -# ) + +def _apply_requires(by_name, by_index, requires_): + + for i, p in enumerate(requires_): + + if p.target is None: + try: + name = by_index[i] + except IndexError: + # case where something takes *args + by_name[i] = p + continue + else: + name = p.target + + by_name[name] = p def extract_requires(obj: Callable) -> Requirements: - # from annotations by_name = {} + + # from annotations + try: + annotations = get_type_hints(obj) + except TypeError: + annotations = {} + for name, p in signature(obj).parameters.items(): if p.kind in (p.VAR_POSITIONAL, p.VAR_KEYWORD): continue - # # https://bugs.python.org/issue39753: - # if isinstance(obj, partial) and p.name in obj.keywords: - # continue - # - name = p.name - - if p.annotation is not p.empty: - type_ = p.annotation - else: - type_ = None + # https://bugs.python.org/issue39753: + if isinstance(obj, partial) and p.name in obj.keywords: + continue + type_ = annotations.get(name) default = missing if p.default is p.empty else p.default - requirement = Value(type_, p.name, default) - # - # requirement = None - # if isinstance(default, Requirement): - # requirement = default - # default = missing - # elif isinstance(p.annotation, Requirement): - # requirement = p.annotation - # - # if requirement is None: - # requirement = Requirement(key) - # if isinstance(p.annotation, str): - # key = p.annotation - # elif type_ is None or issubclass(type_, SIMPLE_TYPES): - # key = name - # else: - # key = type_ - # else: - # requirement = requirement.make_from(requirement) - # type_ = type_ if requirement.type is None else requirement.type - # if requirement.key is not None: - # key = requirement.key - # elif type_ is None or issubclass(type_, SIMPLE_TYPES): - # key = name - # else: - # key = type_ - # default = requirement.default if requirement.default is not missing else default - # - # requirement.key = key - # requirement.name = name - # requirement.type = type_ - # requirement.default = default + if isinstance(default, Requirement): + requirement = default + default = requirement.default + elif isinstance(p.annotation, Requirement): + requirement = p.annotation + default = default if requirement.default is missing else requirement.default + else: + requirement = Value(type_, p.name, default) by_name[name] = Parameter( requirement, target=p.name if p.kind is p.KEYWORD_ONLY else None, - default=requirement.default + default=default ) - # - # by_index = list(by_name) - # - # # from declarations - # mush_requires = get_mush(obj, 'requires', None) - # if mush_requires is not None: - # _apply_requires(by_name, by_index, mush_requires) - # - # # explicit - # if explicit is not None: - # if isinstance(explicit, RequiresType): - # requires_ = explicit - # else: - # if not isinstance(explicit, (list, tuple)): - # explicit = (explicit,) - # requires_ = requires(*explicit) - # _apply_requires(by_name, by_index, requires_) - # - # if not by_name: - # return requires_nothing + by_index = list(by_name) + + # from declarations + mush_requires = get_mush(obj, 'requires', None) + if mush_requires is not None: + _apply_requires(by_name, by_index, mush_requires) + + if not by_name: + return requires_nothing # sort out target: needs_target = False @@ -136,7 +88,8 @@ def extract_requires(obj: Callable) -> Requirements: return Requirements(by_name.values()) -# def extract_returns(obj: Callable, explicit: Returns = None): +def extract_returns(obj: Callable, explicit: Return = None): + return None # if explicit is None: # returns_ = get_mush(obj, 'returns', None) # if returns_ is None: @@ -153,17 +106,17 @@ def extract_requires(obj: Callable) -> Requirements: # returns_ = returns(returns_) # # return returns_ or result_type -# -# -# WRAPPER_ASSIGNMENTS = FUNCTOOLS_ASSIGNMENTS + ('__mush__',) -# -# -# def update_wrapper(wrapper, -# wrapped, -# assigned=WRAPPER_ASSIGNMENTS, -# updated=WRAPPER_UPDATES): -# """ -# An extended version of :func:`functools.update_wrapper` that -# also preserves Mush's annotations. -# """ -# return functools_update_wrapper(wrapper, wrapped, assigned, updated) + + +WRAPPER_ASSIGNMENTS = FUNCTOOLS_ASSIGNMENTS + ('__mush__',) + + +def update_wrapper(wrapper, + wrapped, + assigned=WRAPPER_ASSIGNMENTS, + updated=WRAPPER_UPDATES): + """ + An extended version of :func:`functools.update_wrapper` that + also preserves Mush's annotations. + """ + return functools_update_wrapper(wrapper, wrapped, assigned, updated) diff --git a/mush/requirements.py b/mush/requirements.py index 6682a6a..f9ebd46 100644 --- a/mush/requirements.py +++ b/mush/requirements.py @@ -1,10 +1,8 @@ -from typing import Any, List, TYPE_CHECKING, Hashable, Sequence +from typing import Any, List, Hashable, Sequence, Optional, Union -from .markers import missing, nonblocking +from .markers import missing from .resources import ResourceKey - -if TYPE_CHECKING: - from .context import Context +from .typing import Identifier def name_or_repr(obj): @@ -50,7 +48,6 @@ def __init__(self, default: Any, *keys: ResourceKey): self.keys: Sequence[ResourceKey] = keys self.default = default self.ops: List['Op'] = [] - self.target: Optional[str] = None def _keys_repr(self): return ', '.join(repr(key) for key in self.keys) @@ -93,7 +90,18 @@ class Value(Requirement): ever use this. """ - def __init__(self, type_: type = None, identifier: Hashable = None, default: Any = missing): + def __init__(self, + type_or_identifier: Union[type, Identifier] = None, + identifier: Identifier = None, + default: Any = missing): + if identifier is None: + if isinstance(type_or_identifier, type): + type_ = type_or_identifier + else: + identifier = type_or_identifier + type_ = None + else: + type_ = type_or_identifier super().__init__( default, ResourceKey(type_, identifier), diff --git a/mush/resources.py b/mush/resources.py index d762380..13d7c43 100644 --- a/mush/resources.py +++ b/mush/resources.py @@ -6,7 +6,7 @@ class ResourceKey(tuple): - def __new__(cls, type_: Type, identifier: Identifier): + def __new__(cls, type_: Type = None, identifier: Identifier = None): return tuple.__new__(cls, (type_, identifier)) @property diff --git a/mush/tests/test_requirements_extraction.py b/mush/tests/test_extraction.py similarity index 60% rename from mush/tests/test_requirements_extraction.py rename to mush/tests/test_extraction.py index d85e31c..08be598 100644 --- a/mush/tests/test_requirements_extraction.py +++ b/mush/tests/test_extraction.py @@ -1,28 +1,29 @@ -import pytest; pytestmark = pytest.mark.skip("WIP") from functools import partial -from typing import Tuple +from typing import Tuple, get_type_hints from unittest import TestCase import pytest from testfixtures import compare, ShouldRaise -from mush import Value +from mush import Value, missing from mush.declarations import ( requires, returns, returns_mapping, returns_sequence, returns_result_type, requires_nothing, - result_type, Requirements + result_type, Requirements, Parameter ) -from mush.extraction import extract_requires#, extract_returns, update_wrapper +from mush.extraction import extract_requires, extract_returns, update_wrapper from mush.requirements import Requirement, ItemOp from .helpers import PY_36, Type1, Type2, Type3, Type4 +from ..resources import ResourceKey def check_extract(obj, expected_rq, expected_rt): - rq = extract_requires(obj, None) + rq = extract_requires(obj) rt = extract_returns(obj, None) compare(rq, expected=expected_rq, strict=True) - compare(rt, expected=expected_rt, strict=True) + assert rt is None + # compare(rt, expected=expected_rt, strict=True) class TestRequirementsExtraction(object): @@ -31,8 +32,8 @@ def test_default_requirements_for_function(self): def foo(a, b=None): pass check_extract(foo, expected_rq=Requirements(( - Value.make(key='a', name='a'), - Value.make(key='b', default=None, name='b'), + Parameter(Value(identifier='a')), + Parameter(Value(identifier='b', default=None), default=None), )), expected_rt=result_type) @@ -41,8 +42,8 @@ class MyClass(object): def __init__(self, a, b=None): pass check_extract(MyClass, expected_rq=Requirements(( - Value.make(key='a', name='a'), - Value.make(key='b', name='b', default=None), + Parameter(Value(identifier='a')), + Parameter(Value(identifier='b', default=None), default=None), )), expected_rt=result_type) @@ -52,8 +53,8 @@ def foo(x, y, z, a=None): pass check_extract( p, expected_rq=Requirements(( - Value.make(key='z', name='z', target='z'), - Value.make(key='a', name='a', target='a', default=None), + Parameter(Value(identifier='z'), target='z'), + Parameter(Value(identifier='a', default=None), target='a', default=None), )), expected_rt=result_type ) @@ -64,7 +65,7 @@ def foo(a=None): pass check_extract( p, expected_rq=Requirements(( - Value.make(key='a', name='a', default=None), + Parameter(Value(identifier='a', default=None), default=None), )), expected_rt=result_type ) @@ -113,8 +114,8 @@ def foo(b, a=None): pass check_extract( p, expected_rq=Requirements(( - Value.make(key='b', name='b'), - Value.make(key='a', name='a', default=None), + Parameter(Value(identifier='b')), + Parameter(Value(identifier='a', default=None), default=None), )), expected_rt=result_type ) @@ -126,7 +127,7 @@ def foo(b, a): pass p, # since b is already bound: expected_rq=Requirements(( - Value.make(key='a', name='a'), + Parameter(Value(identifier='a')), )), expected_rt=result_type ) @@ -137,36 +138,47 @@ def foo(b, a): pass check_extract( p, expected_rq=Requirements(( - Value.make(key='b', name='b'), + Parameter(Value(identifier='b')), )), expected_rt=result_type ) +# https://bugs.python.org/issue41872 +def foo(a: 'Foo') -> 'Bar': pass +class Foo: pass +class Bar: pass + + class TestExtractDeclarationsFromTypeAnnotations(object): def test_extract_from_annotations(self): - def foo(a: 'foo', b, c: 'bar' = 1, d=2) -> 'bar': pass + def foo(a: Type1, b, c: Type2 = 1, d=2) -> Type3: pass check_extract(foo, expected_rq=Requirements(( - Value.make(key='foo', name='a'), - Value.make(key='b', name='b'), - Value.make(key='bar', name='c', default=1), - Value.make(key='d', name='d', default=2) + Parameter(Value(Type1, identifier='a')), + Parameter(Value(identifier='b')), + Parameter(Value(Type2, identifier='c', default=1), default=1), + Parameter(Value(identifier='d', default=2), default=2), )), expected_rt=returns('bar')) + def test_forward_type_references(self): + check_extract(foo, + expected_rq=Requirements((Parameter(Value(Foo, identifier='a')),)), + expected_rt=returns(Bar)) + def test_requires_only(self): - def foo(a: 'foo'): pass + def foo(a: Type1): pass check_extract(foo, - expected_rq=Requirements((Value.make(key='foo', name='a'),)), + expected_rq=Requirements((Parameter(Value(Type1, identifier='a')),)), expected_rt=result_type) def test_returns_only(self): - def foo() -> 'bar': pass + def foo() -> Type1: pass check_extract(foo, expected_rq=requires_nothing, - expected_rt=returns('bar')) + expected_rt=returns(Type1)) def test_extract_from_decorated_class(self): @@ -180,20 +192,26 @@ def my_dec(func): return update_wrapper(Wrapper(func), func) @my_dec - def foo(a: 'foo' = None) -> 'bar': + @requires(a=Value('foo')) + @returns('bar') + def foo(a=None): return 'answer' compare(foo(), expected='the answer') check_extract(foo, - expected_rq=Requirements((Value.make(key='foo', name='a', default=None),)), + expected_rq=Requirements(( + Parameter(Value(identifier='foo'), target='a'), + )), expected_rt=returns('bar')) def test_decorator_trumps_annotations(self): @requires('foo') @returns('bar') - def foo(a: 'x') -> 'y': pass + def foo(a: Type1) -> Type2: pass check_extract(foo, - expected_rq=Requirements((Value.make(key='foo', name='a'),)), + expected_rq=Requirements(( + Parameter(Value(identifier='foo')),) + ), expected_rt=returns('bar')) def test_returns_mapping(self): @@ -214,7 +232,7 @@ def test_how_instance_in_annotations(self): def foo(a: Value('config')['db_url']): pass check_extract(foo, expected_rq=Requirements(( - Value.make(key='config', name='a', ops=[ItemOp('db_url')]), + Parameter(Value(identifier='config')['db_url']), )), expected_rt=result_type) @@ -222,10 +240,10 @@ def test_default_requirements(self): def foo(a, b=1, *, c, d=None): pass check_extract(foo, expected_rq=Requirements(( - Value.make(key='a', name='a'), - Value.make(key='b', name='b', default=1), - Value.make(key='c', name='c', target='c'), - Value.make(key='d', name='d', target='d', default=None) + Parameter(Value(identifier='a')), + Parameter(Value(identifier='b', default=1), default=1), + Parameter(Value(identifier='c'), target='c'), + Parameter(Value(identifier='d', default=None), target='d', default=None) )), expected_rt=result_type) @@ -233,88 +251,54 @@ def test_type_only(self): class T: pass def foo(a: T): pass check_extract(foo, - expected_rq=Requirements((Value.make(key=T, name='a', type=T),)), + expected_rq=Requirements((Parameter(Value(T, identifier='a')),)), expected_rt=result_type) @pytest.mark.parametrize("type_", [str, int, dict, list]) def test_simple_type_only(self, type_): def foo(a: type_): pass check_extract(foo, - expected_rq=Requirements((Value.make(key='a', name='a', type=type_),)), + expected_rq=Requirements((Parameter(Value(type_, identifier='a')),)), expected_rt=result_type) def test_type_plus_value(self): def foo(a: str = Value('b')): pass check_extract(foo, - expected_rq=Requirements((Value.make(key='b', name='a', type=str),)), + expected_rq=Requirements((Parameter(Value(identifier='b')),)), expected_rt=result_type) def test_type_plus_value_with_default(self): def foo(a: str = Value('b', default=1)): pass check_extract(foo, expected_rq=Requirements(( - Value.make(key='b', name='a', type=str, default=1), + Parameter(Value(identifier='b', default=1), default=1), )), expected_rt=result_type) def test_value_annotation_plus_default(self): - def foo(a: Value('b', type_=str) = 1): pass + def foo(a: Value(str, identifier='b') = 1): pass check_extract(foo, expected_rq=Requirements(( - Value.make(key='b', name='a', type=str, default=1), + Parameter(Value(str, identifier='b'), default=1), )), expected_rt=result_type) - def test_value_annotation_just_type_in_value_key_plus_default(self): - def foo(a: Value(str) = 1): pass + def test_requirement_default_preferred_to_annotation_default(self): + def foo(a: Value(str, identifier='b', default=2) = 1): pass check_extract(foo, expected_rq=Requirements(( - Value.make(key=str, name='a', type=str, default=1), + Parameter(Value(str, identifier='b', default=2), default=2), )), expected_rt=result_type) - def test_value_annotation_just_type_plus_default(self): - def foo(a: Value(type_=str) = 1): pass + def test_value_annotation_just_type_in_value_key_plus_default(self): + def foo(a: Value(str) = 1): pass check_extract(foo, expected_rq=Requirements(( - Value.make(key='a', name='a', type=str, default=1), + Parameter(Value(str), default=1), )), expected_rt=result_type) - def test_value_unspecified_with_type(self): - class T1: pass - def foo(a: T1 = Value()): pass - check_extract(foo, - expected_rq=Requirements((Value.make(key=T1, name='a', type=T1),)), - expected_rt=result_type) - - def test_value_unspecified_with_simple_type(self): - def foo(a: str = Value()): pass - check_extract(foo, - expected_rq=Requirements((Value.make(key='a', name='a', type=str),)), - expected_rt=result_type) - - def test_value_unspecified(self): - def foo(a=Value()): pass - check_extract(foo, - expected_rq=Requirements((Value.make(key='a', name='a'),)), - expected_rt=result_type) - - def test_requirement_modifier(self): - def foo(x: str = None): pass - - class FromRequest(Requirement): pass - - def modifier(requirement): - if type(requirement) is Requirement: - requirement = FromRequest.make_from(requirement) - return requirement - - rq = extract_requires(foo, modifier=modifier) - compare(rq, strict=True, expected=Requirements(( - FromRequest(key='x', name='x', type_=str, default=None), - ))) - class TestDeclarationsFromMultipleSources: @@ -329,39 +313,25 @@ def foo(a: r1, b, c=r3): check_extract(foo, expected_rq=Requirements(( - Value.make(key='a', name='a'), - Value.make(key='b', name='b', target='b'), - Value.make(key='c', name='c', target='c'), + Parameter(Requirement(default='a'), default='a'), + Parameter(Requirement(default='b'), default='b', target='b'), + Parameter(Requirement(default='c'), default='c', target='c'), )), expected_rt=result_type) def test_declaration_priorities(self): - r1 = Requirement('a') - r2 = Requirement('b') - r3 = Requirement('c') + r1 = Requirement(missing, ResourceKey(identifier='x')) + r2 = Requirement(missing, ResourceKey(identifier='y')) + r3 = Requirement(missing, ResourceKey(identifier='z')) @requires(a=r1) - def foo(a: r2 = r3, b: str = r2, c = r3): + def foo(a: r2 = r3, b: str = r2, c=r3): pass check_extract(foo, expected_rq=Requirements(( - Value.make(key='a', name='a', target='a'), - Value.make(key='b', name='b', target='b', type=str), - Value.make(key='c', name='c', target='c'), + Parameter(r1, target='a'), + Parameter(r2, target='b'), + Parameter(r3, target='c'), )), expected_rt=result_type) - - def test_explicit_requirement_type_trumps_default_requirement_type(self): - - class FromRequest(Requirement): pass - - @requires(a=Requirement('a')) - def foo(a): - pass - - compare(actual=extract_requires(foo, requires(a=FromRequest('b'))), - strict=True, - expected=Requirements(( - FromRequest.make(key='b', name='a', target='a'), - ))) From dad492246e314f029069b0ec2b08c5933b2b35d4 Mon Sep 17 00:00:00 2001 From: Chris Withers Date: Mon, 28 Sep 2020 07:46:36 +0100 Subject: [PATCH 123/159] re-instate ops --- mush/context.py | 14 ++-- mush/requirements.py | 36 ++++----- mush/tests/test_context.py | 119 +++++++++++++---------------- mush/tests/test_requirements.py | 130 +++++++------------------------- 4 files changed, 107 insertions(+), 192 deletions(-) diff --git a/mush/context.py b/mush/context.py index 243b1c7..a810c00 100644 --- a/mush/context.py +++ b/mush/context.py @@ -112,7 +112,7 @@ def _find_resource(self, key): return resource exact = False - def _resolve(self, obj, specials = None): + def _resolve(self, obj, specials=None): if specials is None: specials: Dict[type, Any] = {Context: self} @@ -148,12 +148,12 @@ def _resolve(self, obj, specials = None): if o is missing: o = parameter.default - # if o is not requirement.default: - # for op in requirement.ops: - # o = op(o) - # if o is missing: - # o = requirement.default - # break + if o is not requirement.default: + for op in requirement.ops: + o = op(o) + if o is missing: + o = requirement.default + break if o is missing: raise ResourceError(f'{requirement!r} could not be satisfied') diff --git a/mush/requirements.py b/mush/requirements.py index f9ebd46..cae1541 100644 --- a/mush/requirements.py +++ b/mush/requirements.py @@ -56,24 +56,24 @@ def __repr__(self): default = '' if self.default is missing else f', default={self.default!r}' ops = ''.join(repr(o) for o in self.ops) return f"{type(self).__name__}({self._keys_repr()}{default}){ops}" - # - # def attr(self, name): - # """ - # If you need to get an attribute called either ``attr`` or ``item`` - # then you will need to call this method instead of using the - # generating behaviour. - # """ - # self.ops.append(AttrOp(name)) - # return self - # - # def __getattr__(self, name): - # if name.startswith('__'): - # raise AttributeError(name) - # return self.attr(name) - # - # def __getitem__(self, name): - # self.ops.append(ItemOp(name)) - # return self + + def attr(self, name): + """ + If you need to get an attribute called either ``attr`` or ``item`` + then you will need to call this method instead of using the + generating behaviour. + """ + self.ops.append(AttrOp(name)) + return self + + def __getattr__(self, name): + if name.startswith('__'): + raise AttributeError(name) + return self.attr(name) + + def __getitem__(self, name): + self.ops.append(ItemOp(name)) + return self class Value(Requirement): diff --git a/mush/tests/test_context.py b/mush/tests/test_context.py index a061ba3..06ba7c6 100644 --- a/mush/tests/test_context.py +++ b/mush/tests/test_context.py @@ -292,71 +292,60 @@ def foo(*, x: int): # -# XXX - these are for ops -# def test_call_requires_item_missing(self): -# def foo(obj): return obj -# context = Context() -# context.add({}, TheType) -# with ShouldRaise(ResourceError( -# "No Value(TheType)['foo'] in context", -# key=TheType, -# requirement=Value(TheType)['foo'], -# )): -# context.call(foo, requires(Value(TheType)['foo'])) -# -# def test_call_requires_named_parameter(self): -# def foo(x, y): -# return x, y -# context = Context() -# context.add('foo', TheType) -# context.add('bar', 'baz') -# result = context.call(foo, requires(y='baz', x=TheType)) -# compare(result, ('foo', 'bar')) -# compare({TheType: 'foo', -# 'baz': 'bar'}, -# actual=context._store) -# -# def test_call_requires_item(self): -# def foo(x): -# return x -# context = Context() -# context.add(dict(bar='baz'), 'foo') -# result = context.call(foo, requires(Value('foo')['bar'])) -# compare(result, 'baz') -# -# def test_call_requires_attr(self): -# def foo(x): -# return x -# m = Mock() -# context = Context() -# context.add(m, 'foo') -# result = context.call(foo, requires(Value('foo').bar)) -# compare(result, m.bar) -# -# def test_call_requires_item_attr(self): -# def foo(x): -# return x -# m = Mock() -# m.bar= dict(baz='bob') -# context = Context() -# context.add(m, provides='foo') -# result = context.call(foo, requires(Value('foo').bar['baz'])) -# compare(result, 'bob') -# -# def test_call_requires_optional_item_missing(self): -# def foo(x: str = Value('foo', default=1)['bar']): -# return x -# context = Context() -# result = context.call(foo) -# compare(result, 1) -# -# def test_call_requires_optional_item_present(self): -# def foo(x: str = Value('foo', default=1)['bar']): -# return x -# context = Context() -# context.add(dict(bar='baz'), provides='foo') -# result = context.call(foo) -# compare(result, 'baz') +class TestOps: + + def test_call_requires_item(self): + def foo(x: str = Value(identifier='foo')['bar']): + return x + context = Context() + context.add(dict(bar='baz'), identifier='foo') + result = context.call(foo) + compare(result, 'baz') + + def test_call_requires_item_missing(self): + def foo(obj: str = Value(dict)['foo']): pass + context = Context() + context.add({}) + with ShouldRaise(ResourceError( + "Value()['foo'] could not be satisfied", + )): + context.call(foo) + + def test_call_requires_attr(self): + @requires(Value('foo').bar) + def foo(x): + return x + m = Mock() + context = Context() + context.add(m, identifier='foo') + result = context.call(foo) + compare(result, m.bar) + + def test_call_requires_item_attr(self): + @requires(Value('foo').bar['baz']) + def foo(x): + return x + m = Mock() + m.bar= dict(baz='bob') + context = Context() + context.add(m, identifier='foo') + result = context.call(foo) + compare(result, 'bob') + + def test_call_requires_optional_item_missing(self): + def foo(x: str = Value('foo', default=1)['bar']): + return x + context = Context() + result = context.call(foo) + compare(result, 1) + + def test_call_requires_optional_item_present(self): + def foo(x: str = Value('foo', default=1)['bar']): + return x + context = Context() + context.add(dict(bar='baz'), identifier='foo') + result = context.call(foo) + compare(result, 'baz') # XXX requirements caching: diff --git a/mush/tests/test_requirements.py b/mush/tests/test_requirements.py index c34f0ed..ac2ee40 100644 --- a/mush/tests/test_requirements.py +++ b/mush/tests/test_requirements.py @@ -1,14 +1,8 @@ -import pytest; pytestmark = pytest.mark.skip("WIP") -from typing import Tuple -from unittest.case import TestCase - import pytest from testfixtures import compare, ShouldRaise -from testfixtures.mock import Mock -from mush import Context, Value, missing, requires, ResourceError -from mush.requirements import Requirement, AttrOp, ItemOp#, AnyOf, Like -from .helpers import Type1 +from mush import Value +from mush.requirements import Requirement, AttrOp, ItemOp # , AnyOf, Like def check_ops(value, data, *, expected): @@ -19,84 +13,17 @@ def check_ops(value, data, *, expected): class TestRequirement: - def test_repr_minimal(self): - compare(repr(Requirement('foo')), - expected="Requirement('foo')") - - def test_repr_maximal(self): - r = Requirement('foo', name='n', type_='ty', default=None, target='ta') - r.ops.append(AttrOp('bar')) - compare(repr(r), - expected="Requirement('foo', default=None).bar") - - def test_make_allows_params_not_passed_to_constructor(self): - r = Value.make(key='x', target='a') - assert type(r) is Value - compare(r.key, expected='x') - compare(r.target, expected='a') - - def test_make_can_create_invalid_objects(self): - # So be careful! - - class SampleRequirement(Requirement): - def __init__(self, foo): - super().__init__(key='y') - self.foo = foo - - r = SampleRequirement('it') - compare(r.foo, expected='it') - - r = SampleRequirement.make(key='x') - assert 'foo' not in r.__dict__ - # ...when it really should be! - - def test_clone_using_make_from(self): - r = Value('foo').bar.requirement - r_ = r.make_from(r) - assert r_ is not r - assert r_.ops is not r.ops - compare(r_, expected=r) - - def test_make_from_with_mutable_default(self): - r = Requirement('foo', default=[]) - r_ = r.make_from(r) - assert r_ is not r - assert r_.default is not r.default - compare(r_, expected=r) - - def test_make_from_into_new_type(self): - r = Requirement('foo').bar.requirement - r_ = Value.make_from(r) - compare(r_, expected=Value('foo').bar.requirement) - - def test_make_from_with_required_constructor_parameter(self): - - class SampleRequirement(Requirement): - def __init__(self, foo): - super().__init__('foo') - self.foo = foo - - r = Requirement('foo') - r_ = SampleRequirement.make_from(r, foo='it') - assert r_ is not r - compare(r_, expected=SampleRequirement(foo='it')) - - def test_make_from_source_has_more_attributes(self): - - class SampleRequirement(Requirement): - def __init__(self, foo): - super().__init__('bar') - self.foo = foo - - r = SampleRequirement('it') - r_ = Requirement.make_from(r) - assert r_ is not r - - assert r_.key == 'bar' - # while this is a bit ugly, it will hopefully do no harm: - assert r_.foo == 'it' - - special_names = ['attr', 'ops', 'target'] + # def test_repr_minimal(self): + # compare(repr(Requirement('foo')), + # expected="Requirement('foo')") + # + # def test_repr_maximal(self): + # r = Requirement('foo', name='n', type_='ty', default=None, target='ta') + # r.ops.append(AttrOp('bar')) + # compare(repr(r), + # expected="Requirement('foo', default=None).bar") + # + special_names = ['attr', 'ops'] @pytest.mark.parametrize("name", special_names) def test_attr_special_name(self, name): @@ -116,22 +43,21 @@ def test_no_special_name_via_getattr(self): with ShouldRaise(AttributeError): assert v.__len__ compare(v.ops, []) - - def test_resolve(self): - r = Requirement('foo') - with ShouldRaise(NotImplementedError): - r.resolve(None) - - -class TestValue: - - def test_type_from_key(self): - v = Value(str) - compare(v.requirement.type, expected=str) - - def test_key_and_type_cannot_disagree(self): - with ShouldRaise(TypeError('type_ cannot be specified if key is a type')): - Value(key=str, type_=int) +# +# +# class TestValue: +# +# def test_type_from_key(self): +# v = Value(str) +# compare(v.requirement.type, expected=str) +# +# def test_key_and_type_cannot_disagree(self): +# with ShouldRaise(TypeError('type_ cannot be specified if key is a type')): +# Value(key=str, type_=int) +# +# def test_at_least_one_param_must_be_specified(self): +# with ShouldRaise(TypeError('xx')): +# Value() # # # class TestItem: From 26ecf23f8398684fb901b7890a16785bf8243dec Mon Sep 17 00:00:00 2001 From: Chris Withers Date: Tue, 29 Sep 2020 08:56:05 +0100 Subject: [PATCH 124/159] more succinct repr of ResourceKeys --- mush/resources.py | 15 +++++++++++---- mush/tests/test_context.py | 30 ++++++++++++++---------------- 2 files changed, 25 insertions(+), 20 deletions(-) diff --git a/mush/resources.py b/mush/resources.py index 13d7c43..cbf972f 100644 --- a/mush/resources.py +++ b/mush/resources.py @@ -1,3 +1,4 @@ +from types import FunctionType from typing import Callable, Optional, Type from .markers import missing @@ -18,16 +19,22 @@ def identifier(self) -> Identifier: return self[1] def __str__(self): - if self.type is None: + type_ = self.type + if type_ is None: return repr(self.identifier) - if hasattr(self.type, '__supertype__'): - type_repr = f'NewType({self.type.__name__}, {self.type.__supertype__})' + if isinstance(type_, type): + type_repr = type_.__qualname__ + elif isinstance(type_, FunctionType): + type_repr = type_.__name__ else: - type_repr = repr(self.type) + type_repr = repr(type_) if self.identifier is None: return type_repr return f'{type_repr}, {self.identifier!r}' + def __repr__(self): + return f'ResourceKey({self})' + class ResourceValue: diff --git a/mush/tests/test_context.py b/mush/tests/test_context.py index 06ba7c6..b5535a3 100644 --- a/mush/tests/test_context.py +++ b/mush/tests/test_context.py @@ -26,7 +26,7 @@ def test_by_inferred_type(self): compare(context._store, expected={(TheType, None): ResourceValue(obj)}) expected = ( ": \n" + " TheType: \n" "}>" ) compare(expected, actual=repr(context)) @@ -42,8 +42,8 @@ def test_by_identifier(self): (None, 'my label'): ResourceValue(obj), }) expected = (", 'my label': \n" " 'my label': \n" + " TheType, 'my label': \n" "}>") compare(expected, actual=repr(context)) compare(expected, actual=str(context)) @@ -61,13 +61,12 @@ def test_by_identifier_only(self): compare(expected, actual=str(context)) def test_explicit_type(self): - class T2(object): pass obj = TheType() context = Context() - context.add(obj, provides=T2) - compare(context._store, expected={(T2, None): ResourceValue(obj)}) + context.add(obj, provides=Type2) + compare(context._store, expected={(Type2, None): ResourceValue(obj)}) expected = ("\n" + " Type2: \n" "}>") compare(expected, actual=repr(context)) compare(expected, actual=str(context)) @@ -77,7 +76,7 @@ def test_clash_just_type(self): obj2 = TheType() context = Context() context.add(obj1, TheType) - with ShouldRaise(ResourceError(f'Context already contains {TheType!r}')): + with ShouldRaise(ResourceError(f'Context already contains TheType')): context.add(obj2, TheType) def test_clash_just_identifier(self): @@ -144,7 +143,7 @@ def test_requires_missing(self): def foo(obj: TheType): return obj context = Context() with ShouldRaise(ResourceError( - "Value(, 'obj') could not be satisfied" + "Value(TheType, 'obj') could not be satisfied" )): context.call(foo) @@ -194,7 +193,7 @@ def foo(obj: TheType): return obj context = Context() context.add(object()) with ShouldRaise(ResourceError( - "Value(, 'obj') could not be satisfied" + "Value(TheType, 'obj') could not be satisfied" )): context.call(foo) @@ -228,7 +227,7 @@ def returner(request_: Request): return request_ with ShouldRaise(ResourceError( - "Value(NewType(Request, ), 'request_') could not be satisfied" + "Value(Request, 'request_') could not be satisfied" )): context.call(returner) @@ -238,7 +237,7 @@ def test_requires_requirement(self): def foo(requirement: Requirement): pass with ShouldRaise(ResourceError( - "Value(, 'requirement') " + "Value(Requirement, 'requirement') " "could not be satisfied" )): context.call(foo) @@ -573,7 +572,7 @@ def foo(obj: TheType): pass context.add(Provider(lambda: None), provides=object) with ShouldRaise(ResourceError( - "Value(, 'obj') could not be satisfied" + "Value(TheType, 'obj') could not be satisfied" )): context.call(foo) @@ -668,8 +667,7 @@ def provider(): pass context = Context() context.add(Provider(provider), provides=str) expected = (": Provider({provider}, " - f"cache=True, provides_subclasses=False)\n" + f" str: Provider({provider}, cache=True, provides_subclasses=False)\n" "}>") compare(expected, actual=repr(context)) compare(expected, actual=str(context)) @@ -681,10 +679,10 @@ def provider() -> str: pass context = Context() context.add(p, provides=str, identifier='the id') expected = (", 'the id': Provider({provider}, " - f"cached='it', cache=False, provides_subclasses=True)\n" f" 'the id': Provider({provider}, " f"cached='it', cache=False, provides_subclasses=True)\n" + f" str, 'the id': Provider({provider}, " + f"cached='it', cache=False, provides_subclasses=True)\n" "}>") compare(expected, actual=repr(context)) compare(expected, actual=str(context)) From 3d578d0e181dae2d4ea1102bfac0e685f881eaf3 Mon Sep 17 00:00:00 2001 From: Chris Withers Date: Wed, 30 Sep 2020 07:09:24 +0100 Subject: [PATCH 125/159] __copy__ no longer needed on Markers --- mush/markers.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/mush/markers.py b/mush/markers.py index c03a04d..7738d48 100644 --- a/mush/markers.py +++ b/mush/markers.py @@ -10,9 +10,6 @@ def __init__(self, name): def __repr__(self): return '' % self.name - def __copy__(self): - return self - not_specified = Marker('not_specified') From 05af1d4b0a60e59a37bfc36288e6f2b14a6f8986 Mon Sep 17 00:00:00 2001 From: Chris Withers Date: Wed, 30 Sep 2020 07:10:56 +0100 Subject: [PATCH 126/159] Have a type_repr instead of name_or_repr --- mush/callpoints.py | 1 - mush/declarations.py | 2 +- mush/requirements.py | 4 ---- mush/resources.py | 20 ++++++++++++-------- mush/runner.py | 1 - mush/tests/test_context.py | 2 +- mush/tests/test_requirements.py | 28 ++++++++++++++++++---------- 7 files changed, 32 insertions(+), 26 deletions(-) diff --git a/mush/callpoints.py b/mush/callpoints.py index 588b201..47d9c95 100644 --- a/mush/callpoints.py +++ b/mush/callpoints.py @@ -5,7 +5,6 @@ requires_nothing, returns as returns_declaration, returns_nothing ) from .extraction import extract_requires#, extract_returns -from .requirements import name_or_repr from .typing import Requires, Returns if TYPE_CHECKING: diff --git a/mush/declarations.py b/mush/declarations.py index 7453ed8..bb3c2b4 100644 --- a/mush/declarations.py +++ b/mush/declarations.py @@ -3,7 +3,7 @@ from typing import _type_check, Any from .markers import set_mush, missing -from .requirements import Requirement, Value, name_or_repr +from .requirements import Requirement, Value from .typing import RequirementType, ReturnType VALID_DECORATION_TYPES = (type, str, Requirement) diff --git a/mush/requirements.py b/mush/requirements.py index cae1541..108eeda 100644 --- a/mush/requirements.py +++ b/mush/requirements.py @@ -5,10 +5,6 @@ from .typing import Identifier -def name_or_repr(obj): - return getattr(obj, '__name__', None) or repr(obj) - - class Op: def __init__(self, name): diff --git a/mush/resources.py b/mush/resources.py index cbf972f..1b43bd7 100644 --- a/mush/resources.py +++ b/mush/resources.py @@ -5,6 +5,15 @@ from .typing import Resource, Identifier +def type_repr(type_): + if isinstance(type_, type): + return type_.__qualname__ + elif isinstance(type_, FunctionType): + return type_.__name__ + else: + return repr(type_) + + class ResourceKey(tuple): def __new__(cls, type_: Type = None, identifier: Identifier = None): @@ -22,15 +31,10 @@ def __str__(self): type_ = self.type if type_ is None: return repr(self.identifier) - if isinstance(type_, type): - type_repr = type_.__qualname__ - elif isinstance(type_, FunctionType): - type_repr = type_.__name__ - else: - type_repr = repr(type_) + type_repr_ = type_repr(type_) if self.identifier is None: - return type_repr - return f'{type_repr}, {self.identifier!r}' + return type_repr_ + return f'{type_repr_}, {self.identifier!r}' def __repr__(self): return f'ResourceKey({self})' diff --git a/mush/runner.py b/mush/runner.py index 79dffc0..97293f4 100644 --- a/mush/runner.py +++ b/mush/runner.py @@ -7,7 +7,6 @@ from .markers import not_specified from .modifier import Modifier from .plug import Plug -from .requirements import name_or_repr#, Lazy from .typing import Requires, Returns diff --git a/mush/tests/test_context.py b/mush/tests/test_context.py index b5535a3..752a091 100644 --- a/mush/tests/test_context.py +++ b/mush/tests/test_context.py @@ -306,7 +306,7 @@ def foo(obj: str = Value(dict)['foo']): pass context = Context() context.add({}) with ShouldRaise(ResourceError( - "Value()['foo'] could not be satisfied", + "Value(dict)['foo'] could not be satisfied", )): context.call(foo) diff --git a/mush/tests/test_requirements.py b/mush/tests/test_requirements.py index ac2ee40..f308d69 100644 --- a/mush/tests/test_requirements.py +++ b/mush/tests/test_requirements.py @@ -13,16 +13,24 @@ def check_ops(value, data, *, expected): class TestRequirement: - # def test_repr_minimal(self): - # compare(repr(Requirement('foo')), - # expected="Requirement('foo')") - # - # def test_repr_maximal(self): - # r = Requirement('foo', name='n', type_='ty', default=None, target='ta') - # r.ops.append(AttrOp('bar')) - # compare(repr(r), - # expected="Requirement('foo', default=None).bar") - # + def test_repr_minimal(self): + compare(repr(Requirement((), default=missing)), + expected="Requirement()") + + def test_repr_maximal(self): + r = Requirement( + keys=( + ResourceKey(type_=str), + ResourceKey(identifier='foo'), + ResourceKey(type_=int, identifier='bar') + ), + default=None + ) + r.ops.append(AttrOp('bar')) + compare(repr(r), + expected="Requirement(ResourceKey(str), ResourceKey('foo'), " + "ResourceKey(int, 'bar'), default=None).bar") + special_names = ['attr', 'ops'] @pytest.mark.parametrize("name", special_names) From 5a3d4009841cd2c9baee9fcf745424a816b75ef3 Mon Sep 17 00:00:00 2001 From: Chris Withers Date: Wed, 30 Sep 2020 07:15:17 +0100 Subject: [PATCH 127/159] Introduce the Annotation requirement type. Used when there's no explicit Requirement provided for a key. --- mush/extraction.py | 9 +++-- mush/requirements.py | 36 ++++++++++++----- mush/tests/test_context.py | 13 +++--- mush/tests/test_extraction.py | 72 ++++++++++++++++----------------- mush/tests/test_requirements.py | 46 ++++++++++++++++++++- 5 files changed, 118 insertions(+), 58 deletions(-) diff --git a/mush/extraction.py b/mush/extraction.py index 4656601..d6ad95b 100644 --- a/mush/extraction.py +++ b/mush/extraction.py @@ -12,7 +12,8 @@ requires_nothing ) from .markers import missing, get_mush -from .requirements import Value, Requirement +from .requirements import Value, Requirement, Annotation +from .resources import ResourceKey def _apply_requires(by_name, by_index, requires_): @@ -49,7 +50,6 @@ def extract_requires(obj: Callable) -> Requirements: if isinstance(obj, partial) and p.name in obj.keywords: continue - type_ = annotations.get(name) default = missing if p.default is p.empty else p.default if isinstance(default, Requirement): @@ -57,9 +57,10 @@ def extract_requires(obj: Callable) -> Requirements: default = requirement.default elif isinstance(p.annotation, Requirement): requirement = p.annotation - default = default if requirement.default is missing else requirement.default + if requirement.default is not missing: + default = requirement.default else: - requirement = Value(type_, p.name, default) + requirement = Annotation(p.name, annotations.get(name), default) by_name[name] = Parameter( requirement, diff --git a/mush/requirements.py b/mush/requirements.py index 108eeda..8981785 100644 --- a/mush/requirements.py +++ b/mush/requirements.py @@ -1,7 +1,7 @@ -from typing import Any, List, Hashable, Sequence, Optional, Union +from typing import Any, List, Sequence, Optional, Union from .markers import missing -from .resources import ResourceKey +from .resources import ResourceKey, type_repr from .typing import Identifier @@ -40,7 +40,7 @@ class Requirement: The requirement for an individual parameter of a callable. """ - def __init__(self, default: Any, *keys: ResourceKey): + def __init__(self, keys: Sequence[ResourceKey], default: Optional[Any] = missing): self.keys: Sequence[ResourceKey] = keys self.default = default self.ops: List['Op'] = [] @@ -72,6 +72,29 @@ def __getitem__(self, name): return self +class Annotation(Requirement): + + def __init__(self, name: str, type_: type = None, default: Any = missing): + if type_ is None: + keys = [ResourceKey(None, name)] + else: + keys = [ + ResourceKey(type_, name), + ResourceKey(None, name), + ResourceKey(type_, None), + ] + super().__init__(keys, default) + + def __repr__(self): + type_, name = self.keys[0] + r = name + if type_ is not None: + r += f': {type_repr(type_)}' + if self.default is not missing: + r += f' = {self.default!r}' + return r + + class Value(Requirement): """ Declaration indicating that the specified resource key is required. @@ -98,12 +121,7 @@ def __init__(self, type_ = None else: type_ = type_or_identifier - super().__init__( - default, - ResourceKey(type_, identifier), - ResourceKey(None, identifier), - ResourceKey(type_, None), - ) + super().__init__([ResourceKey(type_, identifier)], default) def _keys_repr(self): return str(self.keys[0]) diff --git a/mush/tests/test_context.py b/mush/tests/test_context.py index 752a091..44450af 100644 --- a/mush/tests/test_context.py +++ b/mush/tests/test_context.py @@ -143,7 +143,7 @@ def test_requires_missing(self): def foo(obj: TheType): return obj context = Context() with ShouldRaise(ResourceError( - "Value(TheType, 'obj') could not be satisfied" + "obj: TheType could not be satisfied" )): context.call(foo) @@ -193,7 +193,7 @@ def foo(obj: TheType): return obj context = Context() context.add(object()) with ShouldRaise(ResourceError( - "Value(TheType, 'obj') could not be satisfied" + "obj: TheType could not be satisfied" )): context.call(foo) @@ -215,7 +215,7 @@ def returner(request_: Mapping[str, Any]): return request_ with ShouldRaise(ResourceError( - "Value(typing.Mapping[str, typing.Any], 'request_') could not be satisfied" + "request_: typing.Mapping[str, typing.Any] could not be satisfied" )): context.call(returner) @@ -227,7 +227,7 @@ def returner(request_: Request): return request_ with ShouldRaise(ResourceError( - "Value(Request, 'request_') could not be satisfied" + "request_: Request could not be satisfied" )): context.call(returner) @@ -237,8 +237,7 @@ def test_requires_requirement(self): def foo(requirement: Requirement): pass with ShouldRaise(ResourceError( - "Value(Requirement, 'requirement') " - "could not be satisfied" + "requirement: Requirement could not be satisfied" )): context.call(foo) @@ -572,7 +571,7 @@ def foo(obj: TheType): pass context.add(Provider(lambda: None), provides=object) with ShouldRaise(ResourceError( - "Value(TheType, 'obj') could not be satisfied" + "obj: TheType could not be satisfied" )): context.call(foo) diff --git a/mush/tests/test_extraction.py b/mush/tests/test_extraction.py index 08be598..a2cc57f 100644 --- a/mush/tests/test_extraction.py +++ b/mush/tests/test_extraction.py @@ -13,7 +13,7 @@ result_type, Requirements, Parameter ) from mush.extraction import extract_requires, extract_returns, update_wrapper -from mush.requirements import Requirement, ItemOp +from mush.requirements import Requirement, ItemOp, Annotation from .helpers import PY_36, Type1, Type2, Type3, Type4 from ..resources import ResourceKey @@ -32,8 +32,8 @@ def test_default_requirements_for_function(self): def foo(a, b=None): pass check_extract(foo, expected_rq=Requirements(( - Parameter(Value(identifier='a')), - Parameter(Value(identifier='b', default=None), default=None), + Parameter(Annotation('a')), + Parameter(Annotation('b', default=None), default=None), )), expected_rt=result_type) @@ -42,8 +42,8 @@ class MyClass(object): def __init__(self, a, b=None): pass check_extract(MyClass, expected_rq=Requirements(( - Parameter(Value(identifier='a')), - Parameter(Value(identifier='b', default=None), default=None), + Parameter(Annotation('a')), + Parameter(Annotation('b', default=None), default=None), )), expected_rt=result_type) @@ -53,8 +53,8 @@ def foo(x, y, z, a=None): pass check_extract( p, expected_rq=Requirements(( - Parameter(Value(identifier='z'), target='z'), - Parameter(Value(identifier='a', default=None), target='a', default=None), + Parameter(Annotation('z'), target='z'), + Parameter(Annotation('a', default=None), target='a', default=None), )), expected_rt=result_type ) @@ -65,7 +65,7 @@ def foo(a=None): pass check_extract( p, expected_rq=Requirements(( - Parameter(Value(identifier='a', default=None), default=None), + Parameter(Annotation('a', default=None), default=None), )), expected_rt=result_type ) @@ -114,8 +114,8 @@ def foo(b, a=None): pass check_extract( p, expected_rq=Requirements(( - Parameter(Value(identifier='b')), - Parameter(Value(identifier='a', default=None), default=None), + Parameter(Annotation('b')), + Parameter(Annotation('a', default=None), default=None), )), expected_rt=result_type ) @@ -127,7 +127,7 @@ def foo(b, a): pass p, # since b is already bound: expected_rq=Requirements(( - Parameter(Value(identifier='a')), + Parameter(Annotation('a')), )), expected_rt=result_type ) @@ -138,7 +138,7 @@ def foo(b, a): pass check_extract( p, expected_rq=Requirements(( - Parameter(Value(identifier='b')), + Parameter(Annotation('b')), )), expected_rt=result_type ) @@ -156,22 +156,22 @@ def test_extract_from_annotations(self): def foo(a: Type1, b, c: Type2 = 1, d=2) -> Type3: pass check_extract(foo, expected_rq=Requirements(( - Parameter(Value(Type1, identifier='a')), - Parameter(Value(identifier='b')), - Parameter(Value(Type2, identifier='c', default=1), default=1), - Parameter(Value(identifier='d', default=2), default=2), + Parameter(Annotation('a', Type1)), + Parameter(Annotation('b')), + Parameter(Annotation('c', Type2, default=1), default=1), + Parameter(Annotation('d', default=2), default=2), )), expected_rt=returns('bar')) def test_forward_type_references(self): check_extract(foo, - expected_rq=Requirements((Parameter(Value(Foo, identifier='a')),)), + expected_rq=Requirements((Parameter(Annotation('a', Foo)),)), expected_rt=returns(Bar)) def test_requires_only(self): def foo(a: Type1): pass check_extract(foo, - expected_rq=Requirements((Parameter(Value(Type1, identifier='a')),)), + expected_rq=Requirements((Parameter(Annotation('a', Type1)),)), expected_rt=result_type) def test_returns_only(self): @@ -240,10 +240,10 @@ def test_default_requirements(self): def foo(a, b=1, *, c, d=None): pass check_extract(foo, expected_rq=Requirements(( - Parameter(Value(identifier='a')), - Parameter(Value(identifier='b', default=1), default=1), - Parameter(Value(identifier='c'), target='c'), - Parameter(Value(identifier='d', default=None), target='d', default=None) + Parameter(Annotation('a')), + Parameter(Annotation('b', default=1), default=1), + Parameter(Annotation('c'), target='c'), + Parameter(Annotation('d', default=None), target='d', default=None) )), expected_rt=result_type) @@ -251,14 +251,14 @@ def test_type_only(self): class T: pass def foo(a: T): pass check_extract(foo, - expected_rq=Requirements((Parameter(Value(T, identifier='a')),)), + expected_rq=Requirements((Parameter(Annotation('a', T)),)), expected_rt=result_type) @pytest.mark.parametrize("type_", [str, int, dict, list]) def test_simple_type_only(self, type_): def foo(a: type_): pass check_extract(foo, - expected_rq=Requirements((Parameter(Value(type_, identifier='a')),)), + expected_rq=Requirements((Parameter(Annotation('a', type_)),)), expected_rt=result_type) def test_type_plus_value(self): @@ -303,9 +303,9 @@ def foo(a: Value(str) = 1): pass class TestDeclarationsFromMultipleSources: def test_declarations_from_different_sources(self): - r1 = Requirement('a') - r2 = Requirement('b') - r3 = Requirement('c') + r1 = Requirement(keys=(), default='a') + r2 = Requirement(keys=(), default='b') + r3 = Requirement(keys=(), default='c') @requires(b=r2) def foo(a: r1, b, c=r3): @@ -313,16 +313,16 @@ def foo(a: r1, b, c=r3): check_extract(foo, expected_rq=Requirements(( - Parameter(Requirement(default='a'), default='a'), - Parameter(Requirement(default='b'), default='b', target='b'), - Parameter(Requirement(default='c'), default='c', target='c'), + Parameter(Requirement((), default='a'), default='a'), + Parameter(Requirement((), default='b'), default='b', target='b'), + Parameter(Requirement((), default='c'), default='c', target='c'), )), expected_rt=result_type) def test_declaration_priorities(self): - r1 = Requirement(missing, ResourceKey(identifier='x')) - r2 = Requirement(missing, ResourceKey(identifier='y')) - r3 = Requirement(missing, ResourceKey(identifier='z')) + r1 = Requirement([ResourceKey(identifier='x')]) + r2 = Requirement([ResourceKey(identifier='y')]) + r3 = Requirement([ResourceKey(identifier='z')]) @requires(a=r1) def foo(a: r2 = r3, b: str = r2, c=r3): @@ -330,8 +330,8 @@ def foo(a: r2 = r3, b: str = r2, c=r3): check_extract(foo, expected_rq=Requirements(( - Parameter(r1, target='a'), - Parameter(r2, target='b'), - Parameter(r3, target='c'), + Parameter(Requirement([ResourceKey(identifier='x')]), target='a'), + Parameter(Requirement([ResourceKey(identifier='y')]), target='b'), + Parameter(Requirement([ResourceKey(identifier='z')]), target='c'), )), expected_rt=result_type) diff --git a/mush/tests/test_requirements.py b/mush/tests/test_requirements.py index f308d69..2953b94 100644 --- a/mush/tests/test_requirements.py +++ b/mush/tests/test_requirements.py @@ -1,8 +1,14 @@ +from typing import Text + +from testfixtures.mock import Mock + import pytest from testfixtures import compare, ShouldRaise -from mush import Value -from mush.requirements import Requirement, AttrOp, ItemOp # , AnyOf, Like +from mush import Value, missing +from mush.requirements import Requirement, AttrOp, ItemOp, AnyOf, Like, Annotation +from mush.resources import ResourceKey +from mush.tests.helpers import Type1 def check_ops(value, data, *, expected): @@ -126,6 +132,42 @@ def test_no_special_name_via_getattr(self): # compare(c.call(lambda x: x, requires(Value('key', default=1).foo.bar)), # expected=1) # + + +class TestAnnotation: + + def test_name_only(self): + r = Annotation('x', None, missing) + compare(r.keys, expected=[ + ResourceKey(None, 'x') + ]) + compare(r.default, expected=missing) + + def test_name_and_type(self): + r = Annotation('x', str, missing) + compare(r.keys, expected=[ + ResourceKey(str, 'x'), + ResourceKey(None, 'x'), + ResourceKey(str, None), + ]) + compare(r.default, expected=missing) + + def test_all(self): + r = Annotation('x', str, 'default') + compare(r.keys, expected=[ + ResourceKey(str, 'x'), + ResourceKey(None, 'x'), + ResourceKey(str, None), + ]) + compare(r.default, expected='default') + + def test_repr_min(self): + compare(repr(Annotation('x', None, missing)), + expected="x") + + def test_repr_max(self): + compare(repr(Annotation('x', str, 'default')), + expected="x: str = 'default'") # # class TestAnyOf: # From d15b5308d8d99a8e88cfcbd317d8be28e5d82ca9 Mon Sep 17 00:00:00 2001 From: Chris Withers Date: Wed, 30 Sep 2020 07:23:51 +0100 Subject: [PATCH 128/159] Finish off Value implementation. --- mush/requirements.py | 2 + mush/tests/test_context.py | 63 +++++++++---- mush/tests/test_requirements.py | 154 ++++++++++++++++---------------- 3 files changed, 129 insertions(+), 90 deletions(-) diff --git a/mush/requirements.py b/mush/requirements.py index 8981785..1a826a0 100644 --- a/mush/requirements.py +++ b/mush/requirements.py @@ -116,6 +116,8 @@ def __init__(self, if identifier is None: if isinstance(type_or_identifier, type): type_ = type_or_identifier + elif type_or_identifier is None: + raise TypeError('type or identifier must be supplied') else: identifier = type_or_identifier type_ = None diff --git a/mush/tests/test_context.py b/mush/tests/test_context.py index 44450af..ffbadb5 100644 --- a/mush/tests/test_context.py +++ b/mush/tests/test_context.py @@ -1,13 +1,14 @@ # from typing import Tuple, List # from typing import NewType, Mapping, Any +from testfixtures.mock import Mock from testfixtures import ShouldRaise, compare # from testfixtures.mock import Mock # from mush import ( - Context, Requirement # , requires, returns, returns_mapping, Value, missing + Context, Requirement, Value, requires ) from mush.context import ResourceError # from mush.declarations import RequiresType, requires_nothing, returns_nothing @@ -298,7 +299,7 @@ def foo(x: str = Value(identifier='foo')['bar']): context = Context() context.add(dict(bar='baz'), identifier='foo') result = context.call(foo) - compare(result, 'baz') + compare(result, expected='baz') def test_call_requires_item_missing(self): def foo(obj: str = Value(dict)['foo']): pass @@ -309,6 +310,21 @@ def foo(obj: str = Value(dict)['foo']): pass )): context.call(foo) + def test_call_requires_optional_item_missing(self): + def foo(x: str = Value('foo', default=1)['bar']): + return x + context = Context() + result = context.call(foo) + compare(result, expected=1) + + def test_call_requires_optional_item_present(self): + def foo(x: str = Value('foo', default=1)['bar']): + return x + context = Context() + context.add(dict(bar='baz'), identifier='foo') + result = context.call(foo) + compare(result, expected='baz') + def test_call_requires_attr(self): @requires(Value('foo').bar) def foo(x): @@ -319,31 +335,48 @@ def foo(x): result = context.call(foo) compare(result, m.bar) - def test_call_requires_item_attr(self): - @requires(Value('foo').bar['baz']) + def test_call_requires_attr_missing(self): + @requires(Value('foo').bar) def foo(x): return x - m = Mock() - m.bar= dict(baz='bob') + o = object() context = Context() - context.add(m, identifier='foo') + context.add(o, identifier='foo') + with ShouldRaise(ResourceError( + "Value('foo').bar could not be satisfied", + )): + context.call(foo) + + def test_call_requires_optional_attr_missing(self): + @requires(Value('foo', default=1).bar) + def foo(x): + return x + o = object() + context = Context() + context.add(o, identifier='foo') result = context.call(foo) - compare(result, 'bob') + compare(result, expected=1) - def test_call_requires_optional_item_missing(self): - def foo(x: str = Value('foo', default=1)['bar']): + def test_call_requires_optional_attr_present(self): + @requires(Value('foo', default=1).bar) + def foo(x): return x + m = Mock() context = Context() + context.add(m, identifier='foo') result = context.call(foo) - compare(result, 1) + compare(result, expected=m.bar) - def test_call_requires_optional_item_present(self): - def foo(x: str = Value('foo', default=1)['bar']): + def test_call_requires_item_attr(self): + @requires(Value('foo').bar['baz']) + def foo(x): return x + m = Mock() + m.bar = dict(baz='bob') context = Context() - context.add(dict(bar='baz'), identifier='foo') + context.add(m, identifier='foo') result = context.call(foo) - compare(result, 'baz') + compare(result, expected='bob') # XXX requirements caching: diff --git a/mush/tests/test_requirements.py b/mush/tests/test_requirements.py index 2953b94..8d993dd 100644 --- a/mush/tests/test_requirements.py +++ b/mush/tests/test_requirements.py @@ -57,81 +57,53 @@ def test_no_special_name_via_getattr(self): with ShouldRaise(AttributeError): assert v.__len__ compare(v.ops, []) -# -# -# class TestValue: -# -# def test_type_from_key(self): -# v = Value(str) -# compare(v.requirement.type, expected=str) -# -# def test_key_and_type_cannot_disagree(self): -# with ShouldRaise(TypeError('type_ cannot be specified if key is a type')): -# Value(key=str, type_=int) -# -# def test_at_least_one_param_must_be_specified(self): -# with ShouldRaise(TypeError('xx')): -# Value() -# -# -# class TestItem: -# -# def test_single(self): -# h = Value(Type1)['foo'] -# compare(repr(h), expected="Value(Type1)['foo']") -# check_ops(h, {'foo': 1}, expected=1) -# -# def test_multiple(self): -# h = Value(Type1)['foo']['bar'] -# compare(repr(h), expected="Value(Type1)['foo']['bar']") -# check_ops(h, {'foo': {'bar': 1}}, expected=1) -# -# def test_missing_obj(self): -# h = Value(Type1)['foo']['bar'] -# with ShouldRaise(TypeError): -# check_ops(h, object(), expected=None) -# -# def test_missing_key(self): -# h = Value(Type1)['foo'] -# check_ops(h, {}, expected=missing) -# -# def test_passed_missing(self): -# c = Context() -# c.add({}, provides='key') -# compare(c.call(lambda x: x, requires(Value('key', default=1)['foo']['bar'])), -# expected=1) -# -# def test_bad_type(self): -# h = Value(Type1)['foo']['bar'] -# with ShouldRaise(TypeError): -# check_ops(h, [], expected=None) -# -# -# class TestAttr(TestCase): -# -# def test_single(self): -# h = Value(Type1).foo -# compare(repr(h), "Value(Type1).foo") -# m = Mock() -# check_ops(h, m, expected=m.foo) -# -# def test_multiple(self): -# h = Value(Type1).foo.bar -# compare(repr(h), "Value(Type1).foo.bar") -# m = Mock() -# check_ops(h, m, expected=m.foo.bar) -# -# def test_missing(self): -# h = Value(Type1).foo -# compare(repr(h), "Value(Type1).foo") -# check_ops(h, object(), expected=missing) -# -# def test_passed_missing(self): -# c = Context() -# c.add(object(), provides='key') -# compare(c.call(lambda x: x, requires(Value('key', default=1).foo.bar)), -# expected=1) -# + + +class TestItem: + + def test_single(self): + h = Value(Type1)['foo'] + compare(repr(h), expected="Value(Type1)['foo']") + check_ops(h, {'foo': 1}, expected=1) + + def test_multiple(self): + h = Value(Type1)['foo']['bar'] + compare(repr(h), expected="Value(Type1)['foo']['bar']") + check_ops(h, {'foo': {'bar': 1}}, expected=1) + + def test_missing_obj(self): + h = Value(Type1)['foo']['bar'] + with ShouldRaise(TypeError): + check_ops(h, object(), expected=None) + + def test_missing_key(self): + h = Value(Type1)['foo'] + check_ops(h, {}, expected=missing) + + def test_bad_type(self): + h = Value(Type1)['foo']['bar'] + with ShouldRaise(TypeError): + check_ops(h, [], expected=None) + + +class TestAttr: + + def test_single(self): + h = Value(Type1).foo + compare(repr(h), "Value(Type1).foo") + m = Mock() + check_ops(h, m, expected=m.foo) + + def test_multiple(self): + h = Value(Type1).foo.bar + compare(repr(h), "Value(Type1).foo.bar") + m = Mock() + check_ops(h, m, expected=m.foo.bar) + + def test_missing(self): + h = Value(Type1).foo + compare(repr(h), "Value(Type1).foo") + check_ops(h, object(), expected=missing) class TestAnnotation: @@ -168,6 +140,38 @@ def test_repr_min(self): def test_repr_max(self): compare(repr(Annotation('x', str, 'default')), expected="x: str = 'default'") + + +class TestValue: + + def test_type_only(self): + v = Value(str) + compare(v.keys, expected=[ResourceKey(str, None)]) + + def test_typing_only(self): + v = Value(Text) + compare(v.keys, expected=[ResourceKey(Text, None)]) + + def test_identifier_only(self): + v = Value('foo') + compare(v.keys, expected=[ResourceKey(None, 'foo')]) + + def test_type_and_identifier(self): + v = Value(str, 'foo') + compare(v.keys, expected=[ResourceKey(str, 'foo')]) + + def test_nothing_specified(self): + with ShouldRaise(TypeError('type or identifier must be supplied')): + Value() + + def test_repr_min(self): + compare(repr(Value(Type1)), + expected="Value(Type1)") + + def test_repr_max(self): + compare(repr(Value(Type1, 'foo')['bar'].baz), + expected="Value(Type1, 'foo')['bar'].baz") + # # class TestAnyOf: # From 2ac4e4801c3c946716791163d61c0a2096bd15b2 Mon Sep 17 00:00:00 2001 From: Chris Withers Date: Wed, 30 Sep 2020 08:40:49 +0100 Subject: [PATCH 129/159] tighten up typing around resource types --- mush/context.py | 6 +++--- mush/requirements.py | 20 ++++++++++---------- mush/resources.py | 15 +++++++++++---- mush/tests/test_requirements.py | 11 ++++++++++- mush/typing.py | 10 ++++++---- 5 files changed, 40 insertions(+), 22 deletions(-) diff --git a/mush/context.py b/mush/context.py index a810c00..52a4575 100644 --- a/mush/context.py +++ b/mush/context.py @@ -1,11 +1,11 @@ from inspect import signature -from typing import Optional, Callable, Type, Union, Any, Dict +from typing import Optional, Callable, Union, Any, Dict from .extraction import extract_requires from .markers import missing, Marker from .requirements import Requirement from .resources import ResourceKey, ResourceValue, Provider -from .typing import Resource, Identifier +from .typing import Resource, Identifier, Type_ NONE_TYPE = type(None) unspecified = Marker('unspecified') @@ -30,7 +30,7 @@ def __init__(self): def add(self, obj: Union[Provider, Resource], - provides: Optional[Type] = missing, + provides: Optional[Type_] = missing, identifier: Identifier = None): """ Add a resource to the context. diff --git a/mush/requirements.py b/mush/requirements.py index 1a826a0..6dcf2f8 100644 --- a/mush/requirements.py +++ b/mush/requirements.py @@ -1,8 +1,8 @@ -from typing import Any, List, Sequence, Optional, Union +from typing import Any, List, Sequence, Optional, Union, Type from .markers import missing -from .resources import ResourceKey, type_repr -from .typing import Identifier +from .resources import ResourceKey, type_repr, is_type +from .typing import Identifier, Type_ class Op: @@ -74,7 +74,7 @@ def __getitem__(self, name): class Annotation(Requirement): - def __init__(self, name: str, type_: type = None, default: Any = missing): + def __init__(self, name: str, type_: Type_ = None, default: Any = missing): if type_ is None: keys = [ResourceKey(None, name)] else: @@ -110,19 +110,19 @@ class Value(Requirement): """ def __init__(self, - type_or_identifier: Union[type, Identifier] = None, + key: Union[Type_, Identifier] = None, identifier: Identifier = None, default: Any = missing): if identifier is None: - if isinstance(type_or_identifier, type): - type_ = type_or_identifier - elif type_or_identifier is None: + if is_type(key): + type_ = key + elif key is None: raise TypeError('type or identifier must be supplied') else: - identifier = type_or_identifier + identifier = key type_ = None else: - type_ = type_or_identifier + type_ = key super().__init__([ResourceKey(type_, identifier)], default) def _keys_repr(self): diff --git a/mush/resources.py b/mush/resources.py index 1b43bd7..81b3dc5 100644 --- a/mush/resources.py +++ b/mush/resources.py @@ -1,8 +1,8 @@ from types import FunctionType -from typing import Callable, Optional, Type +from typing import Callable, Optional, _GenericAlias from .markers import missing -from .typing import Resource, Identifier +from .typing import Resource, Identifier, Type_ def type_repr(type_): @@ -14,13 +14,20 @@ def type_repr(type_): return repr(type_) +def is_type(obj): + return ( + isinstance(obj, (type, _GenericAlias)) or + (callable(obj) and hasattr(obj, '__supertype__')) + ) + + class ResourceKey(tuple): - def __new__(cls, type_: Type = None, identifier: Identifier = None): + def __new__(cls, type_: Type_ = None, identifier: Identifier = None): return tuple.__new__(cls, (type_, identifier)) @property - def type(self) -> Type: + def type(self) -> Type_: return self[0] @property diff --git a/mush/tests/test_requirements.py b/mush/tests/test_requirements.py index 8d993dd..819a3c3 100644 --- a/mush/tests/test_requirements.py +++ b/mush/tests/test_requirements.py @@ -1,4 +1,4 @@ -from typing import Text +from typing import Text, Tuple, NewType from testfixtures.mock import Mock @@ -152,6 +152,15 @@ def test_typing_only(self): v = Value(Text) compare(v.keys, expected=[ResourceKey(Text, None)]) + def test_typing_generic_alias(self): + v = Value(Tuple[str]) + compare(v.keys, expected=[ResourceKey(Tuple[str], None)]) + + def test_typing_new_type(self): + Type = NewType('Type', str) + v = Value(Type) + compare(v.keys, expected=[ResourceKey(Type, None)]) + def test_identifier_only(self): v = Value('foo') compare(v.keys, expected=[ResourceKey(None, 'foo')]) diff --git a/mush/typing.py b/mush/typing.py index d2df70f..dd432c6 100644 --- a/mush/typing.py +++ b/mush/typing.py @@ -1,17 +1,19 @@ -from typing import NewType, Union, Hashable, Any, TYPE_CHECKING, List, Tuple +from typing import NewType, Union, Hashable, Any, TYPE_CHECKING, List, Tuple, Type, _GenericAlias if TYPE_CHECKING: from .declarations import Requirements, Return from .requirements import Requirement -RequirementType = Union['Requirement', type, str] +Type_ = Union[type, Type, _GenericAlias] +Identifier = Hashable + +RequirementType = Union['Requirement', Type_, str] Requires = Union['Requirements', RequirementType, List[RequirementType], Tuple[RequirementType, ...]] -ReturnType = Union[type, str] +ReturnType = Union[Type_, str] Returns = Union['Return', ReturnType, List[ReturnType], Tuple[ReturnType, ...]] Resource = NewType('Resource', Any) -Identifier = Hashable From bb53fcf9e857f9eaf5e020bbf7d24dbf412d4fd3 Mon Sep 17 00:00:00 2001 From: Chris Withers Date: Wed, 30 Sep 2020 08:42:36 +0100 Subject: [PATCH 130/159] AnyOf --- mush/requirements.py | 40 ++++++++++-------- mush/tests/test_requirements.py | 75 ++++++++++++++++----------------- 2 files changed, 60 insertions(+), 55 deletions(-) diff --git a/mush/requirements.py b/mush/requirements.py index 6dcf2f8..fc7f1c2 100644 --- a/mush/requirements.py +++ b/mush/requirements.py @@ -128,23 +128,6 @@ def __init__(self, def _keys_repr(self): return str(self.keys[0]) -# -# -# class AnyOf(Requirement): -# """ -# A requirement that is resolved by any of the specified keys. -# """ -# -# def __init__(self, *keys, default=missing): -# super().__init__(keys, default=default) -# -# @nonblocking -# def resolve(self, context: 'Context'): -# for key in self.key: -# value = context.get(key, missing) -# if value is not missing: -# return value -# return self.default # # # class Like(Requirement): @@ -177,3 +160,26 @@ def _keys_repr(self): # if resource is missing: # context.extract(self.provider.obj, self.provider.requires, self.provider.returns) # return self.original.resolve(context) + +class AnyOf(Requirement): + """ + A requirement that is resolved by any of the specified keys. + + A key may either be a :class:`type` or an :class:`Identifier` + """ + + def __init__(self, *keys: Union[Type_, Identifier], default: Any = missing): + if not keys: + raise TypeError('at least one key must be specified') + resource_keys = [] + for key in keys: + type_ = identifier = None + if is_type(key): + type_ = key + else: + identifier = key + resource_keys.append(ResourceKey(type_, identifier)) + super().__init__(resource_keys, default) + + def _keys_repr(self): + return ', '.join(str(key) for key in self.keys) diff --git a/mush/tests/test_requirements.py b/mush/tests/test_requirements.py index 819a3c3..2338b2f 100644 --- a/mush/tests/test_requirements.py +++ b/mush/tests/test_requirements.py @@ -181,44 +181,43 @@ def test_repr_max(self): compare(repr(Value(Type1, 'foo')['bar'].baz), expected="Value(Type1, 'foo')['bar'].baz") -# -# class TestAnyOf: -# -# def test_first(self): -# context = Context() -# context.add(('foo', )) -# context.add(('bar', ), provides=Tuple[str]) -# -# def bob(x: str = AnyOf(tuple, Tuple[str])): -# return x[0] -# -# compare(context.call(bob), expected='foo') -# -# def test_second(self): -# context = Context() -# context.add(('bar', ), provides=Tuple[str]) -# -# def bob(x: str = AnyOf(tuple, Tuple[str])): -# return x[0] -# -# compare(context.call(bob), expected='bar') -# -# def test_none(self): -# context = Context() -# -# def bob(x: str = AnyOf(tuple, Tuple[str])): -# pass -# -# with ShouldRaise(ResourceError): -# context.call(bob) -# -# def test_default(self): -# context = Context() -# -# def bob(x: str = AnyOf(tuple, Tuple[str], default=(42,))): -# return x[0] -# -# compare(context.call(bob), expected=42) + +class TestAnyOf: + + def test_types_and_typing(self): + r = AnyOf(tuple, Tuple[str]) + compare(r.keys, expected=[ + ResourceKey(tuple, None), + ResourceKey(Tuple[str], None), + ]) + compare(r.default, expected=missing) + + def test_identifiers(self): + r = AnyOf('a', 'b') + compare(r.keys, expected=[ + ResourceKey(None, 'a'), + ResourceKey(None, 'b'), + ]) + compare(r.default, expected=missing) + + def test_default(self): + r = AnyOf(tuple, default='x') + compare(r.keys, expected=[ + ResourceKey(tuple, None), + ]) + compare(r.default, expected='x') + + def test_none(self): + with ShouldRaise(TypeError('at least one key must be specified')): + AnyOf() + + def test_repr_min(self): + compare(repr(AnyOf(Type1)), + expected="AnyOf(Type1)") + + def test_repr_max(self): + compare(repr(AnyOf(Type1, 'foo', default='baz')['bob'].bar), + expected="AnyOf(Type1, 'foo', default='baz')['bob'].bar") # # # class Parent(object): From c5699153d8c36144d356510742ce672cb038272c Mon Sep 17 00:00:00 2001 From: Chris Withers Date: Wed, 30 Sep 2020 08:49:51 +0100 Subject: [PATCH 131/159] Like --- mush/requirements.py | 35 +++++++------- mush/tests/test_requirements.py | 86 +++++++++++++-------------------- 2 files changed, 52 insertions(+), 69 deletions(-) diff --git a/mush/requirements.py b/mush/requirements.py index fc7f1c2..351febf 100644 --- a/mush/requirements.py +++ b/mush/requirements.py @@ -128,23 +128,6 @@ def __init__(self, def _keys_repr(self): return str(self.keys[0]) -# -# -# class Like(Requirement): -# """ -# A requirements that is resolved by the specified class or -# any of its base classes. -# """ -# -# @nonblocking -# def resolve(self, context: 'Context'): -# for key in self.key.__mro__: -# if key is object: -# break -# value = context.get(key, missing) -# if value is not missing: -# return value -# return self.default # # # class Lazy(Requirement): @@ -183,3 +166,21 @@ def __init__(self, *keys: Union[Type_, Identifier], default: Any = missing): def _keys_repr(self): return ', '.join(str(key) for key in self.keys) + + +class Like(Requirement): + """ + A requirements that is resolved by the specified class or + any of its base classes. + """ + + def __init__(self, type_: type, default: Any = missing): + keys = [] + for type__ in type_.__mro__: + if type__ is object: + break + keys.append(ResourceKey(type__, None)) + super().__init__(keys, default) + + def _keys_repr(self): + return str(self.keys[0]) diff --git a/mush/tests/test_requirements.py b/mush/tests/test_requirements.py index 2338b2f..1deb078 100644 --- a/mush/tests/test_requirements.py +++ b/mush/tests/test_requirements.py @@ -218,55 +218,37 @@ def test_repr_min(self): def test_repr_max(self): compare(repr(AnyOf(Type1, 'foo', default='baz')['bob'].bar), expected="AnyOf(Type1, 'foo', default='baz')['bob'].bar") -# -# -# class Parent(object): -# pass -# -# -# class Child(Parent): -# pass -# -# -# class TestLike: -# -# def test_actual(self): -# context = Context() -# p = Parent() -# c = Child() -# context.add(p) -# context.add(c) -# -# def bob(x: str = Like(Child)): -# return x -# -# assert context.call(bob) is c -# -# def test_base(self): -# context = Context() -# p = Parent() -# context.add(p) -# -# def bob(x: str = Like(Child)): -# return x -# -# assert context.call(bob) is p -# -# def test_none(self): -# context = Context() -# # make sure we don't pick up object! -# context.add(object()) -# -# def bob(x: str = Like(Child)): -# pass -# -# with ShouldRaise(ResourceError): -# context.call(bob) -# -# def test_default(self): -# context = Context() -# -# def bob(x: str = Like(Child, default=42)): -# return x -# -# compare(context.call(bob), expected=42) + + +class Parent(object): + pass + + +class Child(Parent): + pass + + +class TestLike: + + def test_simple(self): + r = Like(Child) + compare(r.keys, expected=[ + ResourceKey(Child, None), + ResourceKey(Parent, None), + ]) + compare(r.default, expected=missing) + + def test_default(self): + r = Like(Parent, default='foo') + compare(r.keys, expected=[ + ResourceKey(Parent, None), + ]) + compare(r.default, expected='foo') + + def test_repr_min(self): + compare(repr(Like(Type1)), + expected="Like(Type1)") + + def test_repr_max(self): + compare(repr(Like(Type1, default='baz')['bob'].bar), + expected="Like(Type1, default='baz')['bob'].bar") From be7bf183d9b0177febc967911a3311a70827dc51 Mon Sep 17 00:00:00 2001 From: Chris Withers Date: Thu, 1 Oct 2020 08:29:44 +0100 Subject: [PATCH 132/159] Rename declaration types. --- mush/asyncio.py | 12 +++--- mush/declarations.py | 16 ++++---- mush/extraction.py | 8 ++-- mush/tests/test_async_context.py | 6 +-- mush/tests/test_callpoints.py | 16 ++++---- mush/tests/test_extraction.py | 66 +++++++++++++++----------------- mush/typing.py | 6 +-- 7 files changed, 63 insertions(+), 67 deletions(-) diff --git a/mush/asyncio.py b/mush/asyncio.py index 2d7f994..f8bc4c9 100644 --- a/mush/asyncio.py +++ b/mush/asyncio.py @@ -5,7 +5,7 @@ from . import ( Context as SyncContext, Runner as SyncRunner, ResourceError, ContextError ) -from .declarations import Requirements, Return +from .declarations import RequirementsDeclaration, ReturnsDeclaration from .extraction import default_requirement_type from .markers import get_mush, AsyncType from .typing import RequirementModifier @@ -20,12 +20,12 @@ def __init__(self, context, loop): self.add = context.add self.get = context.get - def call(self, obj: Callable, requires: Requirements = None): + def call(self, obj: Callable, requires: RequirementsDeclaration = None): coro = self.context.call(obj, requires) future = asyncio.run_coroutine_threadsafe(coro, self.loop) return future.result() - def extract(self, obj: Callable, requires: Requirements = None, returns: Return = None): + def extract(self, obj: Callable, requires: RequirementsDeclaration = None, returns: ReturnsDeclaration = None): coro = self.context.extract(obj, requires, returns) future = asyncio.run_coroutine_threadsafe(coro, self.loop) return future.result() @@ -70,7 +70,7 @@ async def _ensure_async(self, func, *args, **kw): def _context_for(self, obj): return self if asyncio.iscoroutinefunction(obj) else self._sync_context - async def call(self, obj: Callable, requires: Requirements = None): + async def call(self, obj: Callable, requires: RequirementsDeclaration = None): args = [] kw = {} resolving = self._resolve(obj, requires, args, kw, self._context_for(obj)) @@ -82,8 +82,8 @@ async def call(self, obj: Callable, requires: Requirements = None): async def extract(self, obj: Callable, - requires: Requirements = None, - returns: Return = None): + requires: RequirementsDeclaration = None, + returns: ReturnsDeclaration = None): result = await self.call(obj, requires) self._process(obj, result, returns) return result diff --git a/mush/declarations.py b/mush/declarations.py index bb3c2b4..3f6ee26 100644 --- a/mush/declarations.py +++ b/mush/declarations.py @@ -30,17 +30,17 @@ def __init__(self, requirement: Requirement, target: str = None, default: Any = self.default = default -class Requirements(list): +class RequirementsDeclaration(list): + + def __call__(self, obj): + set_mush(obj, 'requires', self) + return obj def __repr__(self): parts = (repr(r) if r.target is None else f'{r.target}={r!r}' for r in self) return f"requires({', '.join(parts)})" - def __call__(self, obj): - set_mush(obj, 'requires', self) - return obj - def requires(*args: RequirementType, **kw: RequirementType): """ @@ -53,7 +53,7 @@ def requires(*args: RequirementType, **kw: RequirementType): String names for resources must be used instead of types where the callable returning those resources is configured to return the named resource. """ - requires_ = Requirements() + requires_ = RequirementsDeclaration() valid_decoration_types(*args) valid_decoration_types(*kw.values()) for target, possible in chain( @@ -68,10 +68,10 @@ def requires(*args: RequirementType, **kw: RequirementType): return requires_ -requires_nothing = Requirements() +requires_nothing = RequirementsDeclaration() -class Return(object): +class ReturnsDeclaration(object): def __call__(self, obj): set_mush(obj, 'returns', self) diff --git a/mush/extraction.py b/mush/extraction.py index d6ad95b..9dfd971 100644 --- a/mush/extraction.py +++ b/mush/extraction.py @@ -8,8 +8,8 @@ from typing import Callable, get_type_hints from .declarations import ( - Parameter, Requirements, Return, requires_nothing + Parameter, RequirementsDeclaration, ReturnsDeclaration, ) from .markers import missing, get_mush from .requirements import Value, Requirement, Annotation @@ -33,7 +33,7 @@ def _apply_requires(by_name, by_index, requires_): by_name[name] = p -def extract_requires(obj: Callable) -> Requirements: +def extract_requires(obj: Callable) -> RequirementsDeclaration: by_name = {} # from annotations @@ -86,11 +86,11 @@ def extract_requires(obj: Callable) -> Requirements: elif needs_target: parameter.target = name - return Requirements(by_name.values()) + return RequirementsDeclaration(by_name.values()) -def extract_returns(obj: Callable, explicit: Return = None): return None +def extract_returns(obj: Callable, explicit: ReturnsDeclaration = None): # if explicit is None: # returns_ = get_mush(obj, 'returns', None) # if returns_ is None: diff --git a/mush/tests/test_async_context.py b/mush/tests/test_async_context.py index 4a43799..eec5885 100644 --- a/mush/tests/test_async_context.py +++ b/mush/tests/test_async_context.py @@ -9,7 +9,7 @@ from mush import Value, requires, returns, Context as SyncContext, blocking, nonblocking # from mush.asyncio import Context -from mush.declarations import Requirements +from mush.declarations import RequirementsDeclaration # from mush.requirements import Requirement, AnyOf, Like from .helpers import TheType, no_threads, must_run_in_thread from ..markers import AsyncType @@ -165,7 +165,7 @@ async def test_call_cache_requires(): context = Context() def foo(): pass await context.call(foo) - compare(context._requires_cache[foo], expected=Requirements()) + compare(context._requires_cache[foo], expected=RequirementsDeclaration()) @pytest.mark.asyncio @@ -219,7 +219,7 @@ def foo() -> TheType: result = await context.extract(foo) assert result is o compare({TheType: o}, actual=context._store) - compare(context._requires_cache[foo], expected=Requirements()) + compare(context._requires_cache[foo], expected=RequirementsDeclaration()) compare(context._returns_cache[foo], expected=returns(TheType)) diff --git a/mush/tests/test_callpoints.py b/mush/tests/test_callpoints.py index b1687ec..ca37c2e 100644 --- a/mush/tests/test_callpoints.py +++ b/mush/tests/test_callpoints.py @@ -7,7 +7,7 @@ from testfixtures.mock import Mock, call from mush.callpoints import CallPoint -from mush.declarations import requires, returns, Requirements +from mush.declarations import requires, returns, RequirementsDeclaration # from mush.extraction import update_wrapper from mush.requirements import Value from mush.runner import Runner @@ -34,7 +34,7 @@ def foo(a1): pass compare(result, self.context.extract.return_value) compare(self.context.extract.mock_calls, expected=[call(foo, - Requirements([Value.make(key='foo', name='a1')]), + RequirementsDeclaration([Value.make(key='foo', name='a1')]), rt)]) def test_extract_from_decorations(self): @@ -49,7 +49,7 @@ def foo(a1): pass compare(result, self.context.extract.return_value) compare(self.context.extract.mock_calls, expected=[call(foo, - Requirements([Value.make(key='foo', name='a1')]), + RequirementsDeclaration([Value.make(key='foo', name='a1')]), returns('bar'))]) def test_extract_from_decorated_class(self): @@ -75,7 +75,7 @@ def foo(prefix): self.context.extract.side_effect = lambda func, rq, rt: (func(), rq, rt) result = CallPoint(self.runner, foo)(self.context) compare(result, expected=('the answer', - Requirements([Value.make(key='foo', name='prefix')]), + RequirementsDeclaration([Value.make(key='foo', name='prefix')]), rt)) def test_explicit_trumps_decorators(self): @@ -88,7 +88,7 @@ def foo(a1): pass compare(result, self.context.extract.return_value) compare(self.context.extract.mock_calls, expected=[call(foo, - Requirements([Value.make(key='baz', name='a1')]), + RequirementsDeclaration([Value.make(key='baz', name='a1')]), returns('bob'))]) def test_repr_minimal(self): @@ -107,7 +107,7 @@ def foo(a1): pass def test_convert_to_requires_and_returns(self): def foo(baz): pass point = CallPoint(self.runner, foo, requires='foo', returns='bar') - self.assertTrue(isinstance(point.requires, Requirements)) + self.assertTrue(isinstance(point.requires, RequirementsDeclaration)) self.assertTrue(isinstance(point.returns, returns)) compare(repr(foo)+" requires(Value('foo')) returns('bar')", repr(point)) @@ -118,7 +118,7 @@ def foo(a1, a2): pass foo, requires=('foo', 'bar'), returns=('baz', 'bob')) - self.assertTrue(isinstance(point.requires, Requirements)) + self.assertTrue(isinstance(point.requires, RequirementsDeclaration)) self.assertTrue(isinstance(point.returns, returns)) compare(repr(foo)+" requires(Value('foo'), Value('bar')) returns('baz', 'bob')", repr(point)) @@ -129,7 +129,7 @@ def foo(a1, a2): pass foo, requires=['foo', 'bar'], returns=['baz', 'bob']) - self.assertTrue(isinstance(point.requires, Requirements)) + self.assertTrue(isinstance(point.requires, RequirementsDeclaration)) self.assertTrue(isinstance(point.returns, returns)) compare(repr(foo)+" requires(Value('foo'), Value('bar')) returns('baz', 'bob')", repr(point)) diff --git a/mush/tests/test_extraction.py b/mush/tests/test_extraction.py index a2cc57f..9ca1550 100644 --- a/mush/tests/test_extraction.py +++ b/mush/tests/test_extraction.py @@ -1,20 +1,17 @@ from functools import partial -from typing import Tuple, get_type_hints -from unittest import TestCase import pytest -from testfixtures import compare, ShouldRaise +from testfixtures import compare -from mush import Value, missing +from mush import Value from mush.declarations import ( requires, returns, - returns_mapping, returns_sequence, returns_result_type, - requires_nothing, - result_type, Requirements, Parameter + returns_mapping, returns_sequence, requires_nothing, + result_type, RequirementsDeclaration, Parameter ) from mush.extraction import extract_requires, extract_returns, update_wrapper -from mush.requirements import Requirement, ItemOp, Annotation -from .helpers import PY_36, Type1, Type2, Type3, Type4 +from mush.requirements import Requirement, Annotation +from .helpers import Type1, Type2, Type3 from ..resources import ResourceKey @@ -22,8 +19,7 @@ def check_extract(obj, expected_rq, expected_rt): rq = extract_requires(obj) rt = extract_returns(obj, None) compare(rq, expected=expected_rq, strict=True) - assert rt is None - # compare(rt, expected=expected_rt, strict=True) + compare(rt, expected=expected_rt, strict=True) class TestRequirementsExtraction(object): @@ -31,7 +27,7 @@ class TestRequirementsExtraction(object): def test_default_requirements_for_function(self): def foo(a, b=None): pass check_extract(foo, - expected_rq=Requirements(( + expected_rq=RequirementsDeclaration(( Parameter(Annotation('a')), Parameter(Annotation('b', default=None), default=None), )), @@ -41,7 +37,7 @@ def test_default_requirements_for_class(self): class MyClass(object): def __init__(self, a, b=None): pass check_extract(MyClass, - expected_rq=Requirements(( + expected_rq=RequirementsDeclaration(( Parameter(Annotation('a')), Parameter(Annotation('b', default=None), default=None), )), @@ -52,7 +48,7 @@ def foo(x, y, z, a=None): pass p = partial(foo, 1, y=2) check_extract( p, - expected_rq=Requirements(( + expected_rq=RequirementsDeclaration(( Parameter(Annotation('z'), target='z'), Parameter(Annotation('a', default=None), target='a', default=None), )), @@ -64,7 +60,7 @@ def foo(a=None): pass p = partial(foo) check_extract( p, - expected_rq=Requirements(( + expected_rq=RequirementsDeclaration(( Parameter(Annotation('a', default=None), default=None), )), expected_rt=result_type @@ -113,7 +109,7 @@ def foo(b, a=None): pass p = partial(foo) check_extract( p, - expected_rq=Requirements(( + expected_rq=RequirementsDeclaration(( Parameter(Annotation('b')), Parameter(Annotation('a', default=None), default=None), )), @@ -126,7 +122,7 @@ def foo(b, a): pass check_extract( p, # since b is already bound: - expected_rq=Requirements(( + expected_rq=RequirementsDeclaration(( Parameter(Annotation('a')), )), expected_rt=result_type @@ -137,7 +133,7 @@ def foo(b, a): pass p = partial(foo, a=1) check_extract( p, - expected_rq=Requirements(( + expected_rq=RequirementsDeclaration(( Parameter(Annotation('b')), )), expected_rt=result_type @@ -155,23 +151,23 @@ class TestExtractDeclarationsFromTypeAnnotations(object): def test_extract_from_annotations(self): def foo(a: Type1, b, c: Type2 = 1, d=2) -> Type3: pass check_extract(foo, - expected_rq=Requirements(( + expected_rq=RequirementsDeclaration(( Parameter(Annotation('a', Type1)), Parameter(Annotation('b')), Parameter(Annotation('c', Type2, default=1), default=1), Parameter(Annotation('d', default=2), default=2), )), - expected_rt=returns('bar')) + expected_rt=returns(Type3)) def test_forward_type_references(self): check_extract(foo, - expected_rq=Requirements((Parameter(Annotation('a', Foo)),)), + expected_rq=RequirementsDeclaration((Parameter(Annotation('a', Foo)),)), expected_rt=returns(Bar)) def test_requires_only(self): def foo(a: Type1): pass check_extract(foo, - expected_rq=Requirements((Parameter(Annotation('a', Type1)),)), + expected_rq=RequirementsDeclaration((Parameter(Annotation('a', Type1)),)), expected_rt=result_type) def test_returns_only(self): @@ -199,7 +195,7 @@ def foo(a=None): compare(foo(), expected='the answer') check_extract(foo, - expected_rq=Requirements(( + expected_rq=RequirementsDeclaration(( Parameter(Value(identifier='foo'), target='a'), )), expected_rt=returns('bar')) @@ -209,7 +205,7 @@ def test_decorator_trumps_annotations(self): @returns('bar') def foo(a: Type1) -> Type2: pass check_extract(foo, - expected_rq=Requirements(( + expected_rq=RequirementsDeclaration(( Parameter(Value(identifier='foo')),) ), expected_rt=returns('bar')) @@ -231,7 +227,7 @@ def foo() -> rt: pass def test_how_instance_in_annotations(self): def foo(a: Value('config')['db_url']): pass check_extract(foo, - expected_rq=Requirements(( + expected_rq=RequirementsDeclaration(( Parameter(Value(identifier='config')['db_url']), )), expected_rt=result_type) @@ -239,7 +235,7 @@ def foo(a: Value('config')['db_url']): pass def test_default_requirements(self): def foo(a, b=1, *, c, d=None): pass check_extract(foo, - expected_rq=Requirements(( + expected_rq=RequirementsDeclaration(( Parameter(Annotation('a')), Parameter(Annotation('b', default=1), default=1), Parameter(Annotation('c'), target='c'), @@ -251,26 +247,26 @@ def test_type_only(self): class T: pass def foo(a: T): pass check_extract(foo, - expected_rq=Requirements((Parameter(Annotation('a', T)),)), + expected_rq=RequirementsDeclaration((Parameter(Annotation('a', T)),)), expected_rt=result_type) @pytest.mark.parametrize("type_", [str, int, dict, list]) def test_simple_type_only(self, type_): def foo(a: type_): pass check_extract(foo, - expected_rq=Requirements((Parameter(Annotation('a', type_)),)), + expected_rq=RequirementsDeclaration((Parameter(Annotation('a', type_)),)), expected_rt=result_type) def test_type_plus_value(self): def foo(a: str = Value('b')): pass check_extract(foo, - expected_rq=Requirements((Parameter(Value(identifier='b')),)), + expected_rq=RequirementsDeclaration((Parameter(Value(identifier='b')),)), expected_rt=result_type) def test_type_plus_value_with_default(self): def foo(a: str = Value('b', default=1)): pass check_extract(foo, - expected_rq=Requirements(( + expected_rq=RequirementsDeclaration(( Parameter(Value(identifier='b', default=1), default=1), )), expected_rt=result_type) @@ -278,7 +274,7 @@ def foo(a: str = Value('b', default=1)): pass def test_value_annotation_plus_default(self): def foo(a: Value(str, identifier='b') = 1): pass check_extract(foo, - expected_rq=Requirements(( + expected_rq=RequirementsDeclaration(( Parameter(Value(str, identifier='b'), default=1), )), expected_rt=result_type) @@ -286,7 +282,7 @@ def foo(a: Value(str, identifier='b') = 1): pass def test_requirement_default_preferred_to_annotation_default(self): def foo(a: Value(str, identifier='b', default=2) = 1): pass check_extract(foo, - expected_rq=Requirements(( + expected_rq=RequirementsDeclaration(( Parameter(Value(str, identifier='b', default=2), default=2), )), expected_rt=result_type) @@ -294,7 +290,7 @@ def foo(a: Value(str, identifier='b', default=2) = 1): pass def test_value_annotation_just_type_in_value_key_plus_default(self): def foo(a: Value(str) = 1): pass check_extract(foo, - expected_rq=Requirements(( + expected_rq=RequirementsDeclaration(( Parameter(Value(str), default=1), )), expected_rt=result_type) @@ -312,7 +308,7 @@ def foo(a: r1, b, c=r3): pass check_extract(foo, - expected_rq=Requirements(( + expected_rq=RequirementsDeclaration(( Parameter(Requirement((), default='a'), default='a'), Parameter(Requirement((), default='b'), default='b', target='b'), Parameter(Requirement((), default='c'), default='c', target='c'), @@ -329,7 +325,7 @@ def foo(a: r2 = r3, b: str = r2, c=r3): pass check_extract(foo, - expected_rq=Requirements(( + expected_rq=RequirementsDeclaration(( Parameter(Requirement([ResourceKey(identifier='x')]), target='a'), Parameter(Requirement([ResourceKey(identifier='y')]), target='b'), Parameter(Requirement([ResourceKey(identifier='z')]), target='c'), diff --git a/mush/typing.py b/mush/typing.py index dd432c6..750eee5 100644 --- a/mush/typing.py +++ b/mush/typing.py @@ -1,19 +1,19 @@ from typing import NewType, Union, Hashable, Any, TYPE_CHECKING, List, Tuple, Type, _GenericAlias if TYPE_CHECKING: - from .declarations import Requirements, Return + from .declarations import RequirementsDeclaration, ReturnsDeclaration from .requirements import Requirement Type_ = Union[type, Type, _GenericAlias] Identifier = Hashable RequirementType = Union['Requirement', Type_, str] -Requires = Union['Requirements', +Requires = Union['RequirementDeclaraction', RequirementType, List[RequirementType], Tuple[RequirementType, ...]] ReturnType = Union[Type_, str] -Returns = Union['Return', ReturnType, List[ReturnType], Tuple[ReturnType, ...]] +Returns = Union['ReturnsDeclaration', ReturnType] Resource = NewType('Resource', Any) From 6f71f7bfe54c897e00a5d6ba109fe2601a4ccd66 Mon Sep 17 00:00:00 2001 From: Chris Withers Date: Fri, 2 Oct 2020 08:38:03 +0100 Subject: [PATCH 133/159] move update_wrapper to a more appropriate module --- mush/__init__.py | 1 + mush/declarations.py | 19 +++++++++++++++++++ mush/extraction.py | 17 ----------------- 3 files changed, 20 insertions(+), 17 deletions(-) diff --git a/mush/__init__.py b/mush/__init__.py index c2813ef..f2a12aa 100755 --- a/mush/__init__.py +++ b/mush/__init__.py @@ -3,6 +3,7 @@ requires, returns, returns_result_type, returns_mapping, returns_sequence, ) from .extraction import extract_requires#, extract_returns, update_wrapper +from .declarations import requires, returns, update_wrapper from .markers import missing, nonblocking, blocking from .plug import Plug from .requirements import Requirement, Value#, AnyOf, Like diff --git a/mush/declarations.py b/mush/declarations.py index 3f6ee26..5ba1fdb 100644 --- a/mush/declarations.py +++ b/mush/declarations.py @@ -1,4 +1,9 @@ from enum import Enum, auto +from functools import ( + WRAPPER_ASSIGNMENTS as FUNCTOOLS_ASSIGNMENTS, + WRAPPER_UPDATES, + update_wrapper as functools_update_wrapper +) from itertools import chain from typing import _type_check, Any @@ -160,3 +165,17 @@ class DeclarationsFrom(Enum): original = DeclarationsFrom.original #: Use declarations from the replacement callable. replacement = DeclarationsFrom.replacement + + +WRAPPER_ASSIGNMENTS = FUNCTOOLS_ASSIGNMENTS + ('__mush__',) + + +def update_wrapper(wrapper, + wrapped, + assigned=WRAPPER_ASSIGNMENTS, + updated=WRAPPER_UPDATES): + """ + An extended version of :func:`functools.update_wrapper` that + also preserves Mush's annotations. + """ + return functools_update_wrapper(wrapper, wrapped, assigned, updated) diff --git a/mush/extraction.py b/mush/extraction.py index 9dfd971..fbd23b3 100644 --- a/mush/extraction.py +++ b/mush/extraction.py @@ -1,7 +1,4 @@ from functools import ( - WRAPPER_ASSIGNMENTS as FUNCTOOLS_ASSIGNMENTS, - WRAPPER_UPDATES, - update_wrapper as functools_update_wrapper, partial ) from inspect import signature @@ -107,17 +104,3 @@ def extract_returns(obj: Callable, explicit: ReturnsDeclaration = None): # returns_ = returns(returns_) # # return returns_ or result_type - - -WRAPPER_ASSIGNMENTS = FUNCTOOLS_ASSIGNMENTS + ('__mush__',) - - -def update_wrapper(wrapper, - wrapped, - assigned=WRAPPER_ASSIGNMENTS, - updated=WRAPPER_UPDATES): - """ - An extended version of :func:`functools.update_wrapper` that - also preserves Mush's annotations. - """ - return functools_update_wrapper(wrapper, wrapped, assigned, updated) From bf7ea99a1767dea2dc67f8636ebb59fc9bbbc5b4 Mon Sep 17 00:00:00 2001 From: Chris Withers Date: Mon, 5 Oct 2020 10:57:48 +0100 Subject: [PATCH 134/159] Polish up requirements extraction and specification. --- mush/declarations.py | 15 +++++++------- mush/extraction.py | 6 +++--- mush/requirements.py | 12 ++++------- mush/resources.py | 9 +++++++++ mush/tests/test_declarations.py | 35 ++++++++++++++++++--------------- mush/typing.py | 2 +- 6 files changed, 44 insertions(+), 35 deletions(-) diff --git a/mush/declarations.py b/mush/declarations.py index 5ba1fdb..0774e71 100644 --- a/mush/declarations.py +++ b/mush/declarations.py @@ -5,16 +5,17 @@ update_wrapper as functools_update_wrapper ) from itertools import chain -from typing import _type_check, Any +from typing import _type_check, Any, List, Set from .markers import set_mush, missing from .requirements import Requirement, Value +from .resources import ResourceKey from .typing import RequirementType, ReturnType VALID_DECORATION_TYPES = (type, str, Requirement) -def valid_decoration_types(*objs): +def check_decoration_types(*objs): for obj in objs: if isinstance(obj, VALID_DECORATION_TYPES): continue @@ -35,15 +36,15 @@ def __init__(self, requirement: Requirement, target: str = None, default: Any = self.default = default -class RequirementsDeclaration(list): +class RequirementsDeclaration(List[Parameter]): def __call__(self, obj): set_mush(obj, 'requires', self) return obj def __repr__(self): - parts = (repr(r) if r.target is None else f'{r.target}={r!r}' - for r in self) + parts = (repr(p.requirement) if p.target is None else f'{p.target}={p.requirement!r}' + for p in self) return f"requires({', '.join(parts)})" @@ -59,8 +60,8 @@ def requires(*args: RequirementType, **kw: RequirementType): returning those resources is configured to return the named resource. """ requires_ = RequirementsDeclaration() - valid_decoration_types(*args) - valid_decoration_types(*kw.values()) + check_decoration_types(*args) + check_decoration_types(*kw.values()) for target, possible in chain( ((None, arg) for arg in args), kw.items(), diff --git a/mush/extraction.py b/mush/extraction.py index fbd23b3..e758d3f 100644 --- a/mush/extraction.py +++ b/mush/extraction.py @@ -35,9 +35,9 @@ def extract_requires(obj: Callable) -> RequirementsDeclaration: # from annotations try: - annotations = get_type_hints(obj) + hints = get_type_hints(obj) except TypeError: - annotations = {} + hints = {} for name, p in signature(obj).parameters.items(): if p.kind in (p.VAR_POSITIONAL, p.VAR_KEYWORD): @@ -57,7 +57,7 @@ def extract_requires(obj: Callable) -> RequirementsDeclaration: if requirement.default is not missing: default = requirement.default else: - requirement = Annotation(p.name, annotations.get(name), default) + requirement = Annotation(p.name, hints.get(name), default) by_name[name] = Parameter( requirement, diff --git a/mush/requirements.py b/mush/requirements.py index 351febf..62568d0 100644 --- a/mush/requirements.py +++ b/mush/requirements.py @@ -114,16 +114,12 @@ def __init__(self, identifier: Identifier = None, default: Any = missing): if identifier is None: - if is_type(key): - type_ = key - elif key is None: + if key is None: raise TypeError('type or identifier must be supplied') - else: - identifier = key - type_ = None + resource_key = ResourceKey.guess(key) else: - type_ = key - super().__init__([ResourceKey(type_, identifier)], default) + resource_key = ResourceKey(key, identifier) + super().__init__([resource_key], default) def _keys_repr(self): return str(self.keys[0]) diff --git a/mush/resources.py b/mush/resources.py index 81b3dc5..78f408f 100644 --- a/mush/resources.py +++ b/mush/resources.py @@ -26,6 +26,15 @@ class ResourceKey(tuple): def __new__(cls, type_: Type_ = None, identifier: Identifier = None): return tuple.__new__(cls, (type_, identifier)) + @classmethod + def guess(cls, key): + type_ = identifier = None + if is_type(key): + type_ = key + else: + identifier = key + return cls(type_, identifier) + @property def type(self) -> Type_: return self[0] diff --git a/mush/tests/test_declarations.py b/mush/tests/test_declarations.py index 7dba6a7..816824a 100644 --- a/mush/tests/test_declarations.py +++ b/mush/tests/test_declarations.py @@ -1,15 +1,13 @@ -import pytest; pytestmark = pytest.mark.skip("WIP") from typing import Tuple from unittest import TestCase from testfixtures import compare, ShouldRaise -from mush import Value -from mush.declarations import ( - requires, returns, - returns_mapping, returns_sequence, returns_result_type -) +from mush import Value, AnyOf +from mush.declarations import requires, returns, Parameter, RequirementsDeclaration, \ + ReturnsDeclaration from .helpers import PY_36, Type1, Type2, Type3, Type4 +from ..resources import ResourceKey class TestRequires(TestCase): @@ -23,27 +21,27 @@ def test_types(self): r_ = requires(Type1, Type2, x=Type3, y=Type4) compare(repr(r_), 'requires(Value(Type1), Value(Type2), x=Value(Type3), y=Value(Type4))') compare(r_, expected=[ - Value(Type1), - Value(Type2), - Value.make(key=Type3, type=Type3, name='x', target='x'), - Value.make(key=Type4, type=Type4, name='y', target='y'), + Parameter(Value(Type1)), + Parameter(Value(Type2)), + Parameter(Value(Type3), target='x'), + Parameter(Value(Type4), target='y'), ]) def test_strings(self): r_ = requires('1', '2', x='3', y='4') compare(repr(r_), "requires(Value('1'), Value('2'), x=Value('3'), y=Value('4'))") compare(r_, expected=[ - Value('1'), - Value('2'), - Value.make(key='3', name='x', target='x'), - Value.make(key='4', name='y', target='y'), + Parameter(Value('1')), + Parameter(Value('2')), + Parameter(Value('3'), target='x'), + Parameter(Value('4'), target='y'), ]) def test_typing(self): r_ = requires(Tuple[str]) text = 'Tuple' if PY_36 else 'typing.Tuple[str]' compare(repr(r_),expected=f"requires(Value({text}))") - compare(r_, expected=[Value.make(key=Tuple[str], type=Tuple[str])]) + compare(r_, expected=[Parameter(Value(Tuple[str]))]) def test_tuple_arg(self): with ShouldRaise(TypeError("('1', '2') is not a valid decoration type")): @@ -58,9 +56,14 @@ def test_decorator_paranoid(self): def foo(): return 'bar' - compare(foo.__mush__['requires'], expected=[Value(Type1)]) + compare(foo.__mush__['requires'], expected=[Parameter(Value(Type1))]) compare(foo(), 'bar') + def test_requirement_instance(self): + compare(requires(x=AnyOf('foo', 'bar')), + expected=RequirementsDeclaration([Parameter(AnyOf('foo', 'bar'), target='x')]), + strict=True) + class TestReturns(TestCase): diff --git a/mush/typing.py b/mush/typing.py index 750eee5..f60f1f4 100644 --- a/mush/typing.py +++ b/mush/typing.py @@ -7,7 +7,7 @@ Type_ = Union[type, Type, _GenericAlias] Identifier = Hashable -RequirementType = Union['Requirement', Type_, str] +RequirementType = Union['Requirement', Type_, Identifier] Requires = Union['RequirementDeclaraction', RequirementType, List[RequirementType], From 7bb5377af1d17c91c2f624a4ca700ce72446da9d Mon Sep 17 00:00:00 2001 From: Chris Withers Date: Mon, 5 Oct 2020 11:05:53 +0100 Subject: [PATCH 135/159] Re-implement returns declaration and extraction. Loose complicated flexibility and use a simple set of resource keys instead. --- mush/__init__.py | 10 +-- mush/declarations.py | 71 ++-------------- mush/extraction.py | 54 +++++++----- mush/tests/test_context.py | 18 ---- mush/tests/test_declarations.py | 75 ++++------------ mush/tests/test_extraction.py | 146 ++++++++++++++++---------------- mush/tests/test_runner.py | 91 -------------------- 7 files changed, 132 insertions(+), 333 deletions(-) diff --git a/mush/__init__.py b/mush/__init__.py index f2a12aa..b07df0c 100755 --- a/mush/__init__.py +++ b/mush/__init__.py @@ -1,12 +1,9 @@ from .context import Context, ResourceError -from .declarations import ( - requires, returns, returns_result_type, returns_mapping, returns_sequence, -) -from .extraction import extract_requires#, extract_returns, update_wrapper from .declarations import requires, returns, update_wrapper +from .extraction import extract_requires, extract_returns from .markers import missing, nonblocking, blocking from .plug import Plug -from .requirements import Requirement, Value#, AnyOf, Like +from .requirements import Requirement, Value, AnyOf, Like from .runner import Runner, ContextError __all__ = [ @@ -24,8 +21,5 @@ 'nonblocking', 'requires', 'returns', - 'returns_mapping', - 'returns_result_type', - 'returns_sequence', 'update_wrapper', ] diff --git a/mush/declarations.py b/mush/declarations.py index 0774e71..2d8464c 100644 --- a/mush/declarations.py +++ b/mush/declarations.py @@ -77,84 +77,25 @@ def requires(*args: RequirementType, **kw: RequirementType): requires_nothing = RequirementsDeclaration() -class ReturnsDeclaration(object): +class ReturnsDeclaration(Set[ResourceKey]): def __call__(self, obj): set_mush(obj, 'returns', self) return obj def __repr__(self): - return self.__class__.__name__ + '()' + return f"returns({', '.join(str(k) for k in sorted(self, key=lambda o: str(o)))})" -class returns(Return): +def returns(*keys: ReturnType): """ - Declaration that specifies names for returned resources or overrides - the type of a returned resource. - - This declaration can be used to indicate the type or name of a single - returned resource or, if multiple arguments are passed, that the callable - will return a sequence of values where each one should be named or have its - type overridden. - """ - - def __init__(self, *args: ReturnType): - valid_decoration_types(*args) - self.args = args - - def process(self, obj): - if len(self.args) == 1: - yield self.args[0], obj - elif self.args: - for t, o in zip(self.args, obj): - yield t, o - - def __repr__(self): - args_repr = ', '.join(name_or_repr(arg) for arg in self.args) - return self.__class__.__name__ + '(' + args_repr + ')' - - -class returns_result_type(Return): - """ - Default declaration that indicates a callable's return value - should be used as a resource based on the type of the object returned. - - ``None`` is ignored as a return value, as are context managers - """ - - def process(self, obj): - if not (obj is None or hasattr(obj, '__enter__') or hasattr(obj, '__aenter__')): - yield obj.__class__, obj - - -class returns_mapping(Return): """ - Declaration that indicates a callable returns a mapping of type or name - to resource. - """ - - def process(self, mapping): - return mapping.items() - - -class returns_sequence(returns_result_type): - """ - Declaration that indicates a callable's returns a sequence of values - that should be used as a resources based on the type of the object returned. - - Any ``None`` values in the sequence are ignored. - """ - - def process(self, sequence): - super_process = super(returns_sequence, self).process - for obj in sequence: - for pair in super_process(obj): - yield pair + check_decoration_types(*keys) + return ReturnsDeclaration(ResourceKey.guess(k) for k in keys) -returns_nothing = returns() +returns_nothing = ignore_return = ReturnsDeclaration() -result_type = returns_result_type() class DeclarationsFrom(Enum): diff --git a/mush/extraction.py b/mush/extraction.py index e758d3f..70f1f58 100644 --- a/mush/extraction.py +++ b/mush/extraction.py @@ -5,11 +5,11 @@ from typing import Callable, get_type_hints from .declarations import ( - requires_nothing Parameter, RequirementsDeclaration, ReturnsDeclaration, + requires_nothing ) from .markers import missing, get_mush -from .requirements import Value, Requirement, Annotation +from .requirements import Requirement, Annotation from .resources import ResourceKey @@ -86,21 +86,35 @@ def extract_requires(obj: Callable) -> RequirementsDeclaration: return RequirementsDeclaration(by_name.values()) - return None -def extract_returns(obj: Callable, explicit: ReturnsDeclaration = None): -# if explicit is None: -# returns_ = get_mush(obj, 'returns', None) -# if returns_ is None: -# annotations = getattr(obj, '__annotations__', {}) -# returns_ = annotations.get('return') -# else: -# returns_ = explicit -# -# if returns_ is None or isinstance(returns_, ReturnsType): -# pass -# elif isinstance(returns_, (list, tuple)): -# returns_ = returns(*returns_) -# else: -# returns_ = returns(returns_) -# -# return returns_ or result_type +def extract_returns(obj: Callable): + returns_ = get_mush(obj, 'returns', None) + if returns_ is not None: + return returns_ + + returns_ = ReturnsDeclaration() + try: + type_ = get_type_hints(obj).get('return') + except TypeError: + type_ = None + else: + if type_ is type(None): + return returns_ + + if type_ is None and isinstance(obj, type): + type_ = obj + + if isinstance(obj, partial): + obj = obj.func + identifier = getattr(obj, '__name__', None) + + type_supplied = type_ is not None + identifier_supplied = identifier is not None + + if type_supplied: + returns_.add(ResourceKey(type_, None)) + if identifier_supplied: + returns_.add(ResourceKey(None, identifier)) + if type_supplied and identifier_supplied: + returns_.add(ResourceKey(type_, identifier)) + + return returns_ diff --git a/mush/tests/test_context.py b/mush/tests/test_context.py index ffbadb5..67ae13f 100644 --- a/mush/tests/test_context.py +++ b/mush/tests/test_context.py @@ -432,24 +432,6 @@ def foo(x): # compare(result, 'bar') # compare({TheType: 'bar'}, actual=context._store) # -# def test_returns_sequence(self): -# def foo(): -# return 1, 2 -# context = Context() -# result = context.extract(foo, requires_nothing, returns('foo', 'bar')) -# compare(result, (1, 2)) -# compare({'foo': 1, 'bar': 2}, -# actual=context._store) -# -# def test_returns_mapping(self): -# def foo(): -# return {'foo': 1, 'bar': 2} -# context = Context() -# result = context.extract(foo, requires_nothing, returns_mapping()) -# compare(result, {'foo': 1, 'bar': 2}) -# compare({'foo': 1, 'bar': 2}, -# actual=context._store) -# # def test_ignore_return(self): # def foo(): # return 'bar' diff --git a/mush/tests/test_declarations.py b/mush/tests/test_declarations.py index 816824a..a862f60 100644 --- a/mush/tests/test_declarations.py +++ b/mush/tests/test_declarations.py @@ -6,7 +6,7 @@ from mush import Value, AnyOf from mush.declarations import requires, returns, Parameter, RequirementsDeclaration, \ ReturnsDeclaration -from .helpers import PY_36, Type1, Type2, Type3, Type4 +from .helpers import PY_36, Type1, Type2, Type3, Type4, TheType from ..resources import ResourceKey @@ -64,30 +64,32 @@ def test_requirement_instance(self): expected=RequirementsDeclaration([Parameter(AnyOf('foo', 'bar'), target='x')]), strict=True) + def test_accidental_tuple(self): + with ShouldRaise(TypeError( + "(, " + ") " + "is not a valid decoration type" + )): + requires((TheType, TheType)) + class TestReturns(TestCase): def test_type(self): r = returns(Type1) compare(repr(r), 'returns(Type1)') - compare(dict(r.process('foo')), {Type1: 'foo'}) + compare(r, expected=ReturnsDeclaration((ResourceKey(Type1),))) def test_string(self): r = returns('bar') compare(repr(r), "returns('bar')") - compare(dict(r.process('foo')), {'bar': 'foo'}) + compare(r, expected=ReturnsDeclaration((ResourceKey(identifier='bar'),))) def test_typing(self): r = returns(Tuple[str]) text = 'Tuple' if PY_36 else 'typing.Tuple[str]' compare(repr(r), f'returns({text})') - compare(dict(r.process('foo')), {Tuple[str]: 'foo'}) - - def test_sequence(self): - r = returns(Type1, 'bar') - compare(repr(r), "returns(Type1, 'bar')") - compare(dict(r.process(('foo', 'baz'))), - {Type1: 'foo', 'bar': 'baz'}) + compare(r, expected=ReturnsDeclaration((ResourceKey(Tuple[str]),))) def test_decorator(self): @returns(Type1) @@ -95,7 +97,7 @@ def foo(): return 'foo' r = foo.__mush__['returns'] compare(repr(r), 'returns(Type1)') - compare(dict(r.process(foo())), {Type1: 'foo'}) + compare(r, expected=ReturnsDeclaration((ResourceKey(Type1),))) def test_bad_type(self): with ShouldRaise(TypeError( @@ -104,51 +106,6 @@ def test_bad_type(self): @returns([]) def foo(): pass - -class TestReturnsMapping(TestCase): - - def test_it(self): - @returns_mapping() - def foo(): - return {Type1: 'foo', 'bar': 'baz'} - r = foo.__mush__['returns'] - compare(repr(r), 'returns_mapping()') - compare(dict(r.process(foo())), - {Type1: 'foo', 'bar': 'baz'}) - - -class TestReturnsSequence(TestCase): - - def test_it(self): - t1 = Type1() - t2 = Type2() - @returns_sequence() - def foo(): - return t1, t2 - r = foo.__mush__['returns'] - compare(repr(r), 'returns_sequence()') - compare(dict(r.process(foo())), - {Type1: t1, Type2: t2}) - - -class TestReturnsResultType(TestCase): - - def test_basic(self): - @returns_result_type() - def foo(): - return 'foo' - r = foo.__mush__['returns'] - compare(repr(r), 'returns_result_type()') - compare(dict(r.process(foo())), {str: 'foo'}) - - def test_old_style_class(self): - class Type: pass - obj = Type() - r = returns_result_type() - compare(dict(r.process(obj)), {Type: obj}) - - def test_returns_nothing(self): - def foo(): - pass - r = returns_result_type() - compare(dict(r.process(foo())), {}) + def test_keys_are_orderable(self): + r = returns(Type1, 'foo') + compare(repr(r), expected="returns('foo', Type1)") diff --git a/mush/tests/test_extraction.py b/mush/tests/test_extraction.py index 9ca1550..c47e44d 100644 --- a/mush/tests/test_extraction.py +++ b/mush/tests/test_extraction.py @@ -1,23 +1,27 @@ from functools import partial +from typing import Optional +from testfixtures.mock import Mock import pytest from testfixtures import compare -from mush import Value +from mush import Value, update_wrapper from mush.declarations import ( - requires, returns, - returns_mapping, returns_sequence, requires_nothing, - result_type, RequirementsDeclaration, Parameter + requires, returns, requires_nothing, RequirementsDeclaration, Parameter, ReturnsDeclaration, + returns_nothing ) -from mush.extraction import extract_requires, extract_returns, update_wrapper +from mush.extraction import extract_requires, extract_returns from mush.requirements import Requirement, Annotation from .helpers import Type1, Type2, Type3 from ..resources import ResourceKey -def check_extract(obj, expected_rq, expected_rt): +returns_foo = ReturnsDeclaration([ResourceKey(identifier='foo')]) + + +def check_extract(obj, expected_rq, expected_rt=returns_foo): rq = extract_requires(obj) - rt = extract_returns(obj, None) + rt = extract_returns(obj) compare(rq, expected=expected_rq, strict=True) compare(rt, expected=expected_rt, strict=True) @@ -30,8 +34,7 @@ def foo(a, b=None): pass expected_rq=RequirementsDeclaration(( Parameter(Annotation('a')), Parameter(Annotation('b', default=None), default=None), - )), - expected_rt=result_type) + ))) def test_default_requirements_for_class(self): class MyClass(object): @@ -41,7 +44,11 @@ def __init__(self, a, b=None): pass Parameter(Annotation('a')), Parameter(Annotation('b', default=None), default=None), )), - expected_rt=result_type) + expected_rt=ReturnsDeclaration([ + ResourceKey(MyClass), + ResourceKey(identifier='MyClass'), + ResourceKey(MyClass, 'MyClass'), + ])) def test_extract_from_partial(self): def foo(x, y, z, a=None): pass @@ -51,8 +58,7 @@ def foo(x, y, z, a=None): pass expected_rq=RequirementsDeclaration(( Parameter(Annotation('z'), target='z'), Parameter(Annotation('a', default=None), target='a', default=None), - )), - expected_rt=result_type + )) ) def test_extract_from_partial_default_not_in_partial(self): @@ -62,8 +68,7 @@ def foo(a=None): pass p, expected_rq=RequirementsDeclaration(( Parameter(Annotation('a', default=None), default=None), - )), - expected_rt=result_type + )) ) def test_extract_from_partial_default_in_partial_arg(self): @@ -72,8 +77,7 @@ def foo(a=None): pass check_extract( p, # since a is already bound by the partial: - expected_rq=requires_nothing, - expected_rt=result_type + expected_rq=requires_nothing ) def test_extract_from_partial_default_in_partial_kw(self): @@ -81,8 +85,7 @@ def foo(a=None): pass p = partial(foo, a=1) check_extract( p, - expected_rq=requires_nothing, - expected_rt=result_type + expected_rq=requires_nothing ) def test_extract_from_partial_required_in_partial_arg(self): @@ -91,8 +94,7 @@ def foo(a): pass check_extract( p, # since a is already bound by the partial: - expected_rq=requires_nothing, - expected_rt=result_type + expected_rq=requires_nothing ) def test_extract_from_partial_required_in_partial_kw(self): @@ -100,8 +102,7 @@ def foo(a): pass p = partial(foo, a=1) check_extract( p, - expected_rq=requires_nothing, - expected_rt=result_type + expected_rq=requires_nothing ) def test_extract_from_partial_plus_one_default_not_in_partial(self): @@ -112,8 +113,7 @@ def foo(b, a=None): pass expected_rq=RequirementsDeclaration(( Parameter(Annotation('b')), Parameter(Annotation('a', default=None), default=None), - )), - expected_rt=result_type + )) ) def test_extract_from_partial_plus_one_required_in_partial_arg(self): @@ -124,8 +124,7 @@ def foo(b, a): pass # since b is already bound: expected_rq=RequirementsDeclaration(( Parameter(Annotation('a')), - )), - expected_rt=result_type + )) ) def test_extract_from_partial_plus_one_required_in_partial_kw(self): @@ -135,13 +134,20 @@ def foo(b, a): pass p, expected_rq=RequirementsDeclaration(( Parameter(Annotation('b')), - )), - expected_rt=result_type + )) + ) + + def test_extract_from_mock(self): + foo = Mock() + check_extract( + foo, + expected_rq=requires_nothing, + expected_rt=returns_nothing, ) # https://bugs.python.org/issue41872 -def foo(a: 'Foo') -> 'Bar': pass +def foo_(a: 'Foo') -> 'Bar': pass class Foo: pass class Bar: pass @@ -157,24 +163,41 @@ def foo(a: Type1, b, c: Type2 = 1, d=2) -> Type3: pass Parameter(Annotation('c', Type2, default=1), default=1), Parameter(Annotation('d', default=2), default=2), )), - expected_rt=returns(Type3)) + expected_rt=ReturnsDeclaration([ + ResourceKey(Type3), + ResourceKey(identifier='foo'), + ResourceKey(Type3, 'foo'), + ])) def test_forward_type_references(self): - check_extract(foo, + check_extract(foo_, expected_rq=RequirementsDeclaration((Parameter(Annotation('a', Foo)),)), - expected_rt=returns(Bar)) + expected_rt=ReturnsDeclaration([ + ResourceKey(Bar), + ResourceKey(identifier='foo_'), + ResourceKey(Bar, 'foo_'), + ])) def test_requires_only(self): def foo(a: Type1): pass check_extract(foo, - expected_rq=RequirementsDeclaration((Parameter(Annotation('a', Type1)),)), - expected_rt=result_type) + expected_rq=RequirementsDeclaration((Parameter(Annotation('a', Type1)),))) def test_returns_only(self): def foo() -> Type1: pass check_extract(foo, expected_rq=requires_nothing, - expected_rt=returns(Type1)) + expected_rt=ReturnsDeclaration([ + ResourceKey(Type1), + ResourceKey(identifier='foo'), + ResourceKey(Type1, 'foo'), + ])) + + def test_returns_nothing(self): + def foo() -> None: pass + check_extract(foo, + expected_rq=requires_nothing, + expected_rt=ReturnsDeclaration()) def test_extract_from_decorated_class(self): @@ -198,9 +221,9 @@ def foo(a=None): expected_rq=RequirementsDeclaration(( Parameter(Value(identifier='foo'), target='a'), )), - expected_rt=returns('bar')) + expected_rt=ReturnsDeclaration([ResourceKey(identifier='bar')])) - def test_decorator_trumps_annotations(self): + def test_decorator_preferred_to_annotations(self): @requires('foo') @returns('bar') def foo(a: Type1) -> Type2: pass @@ -208,29 +231,14 @@ def foo(a: Type1) -> Type2: pass expected_rq=RequirementsDeclaration(( Parameter(Value(identifier='foo')),) ), - expected_rt=returns('bar')) - - def test_returns_mapping(self): - rt = returns_mapping() - def foo() -> rt: pass - check_extract(foo, - expected_rq=requires_nothing, - expected_rt=rt) - - def test_returns_sequence(self): - rt = returns_sequence() - def foo() -> rt: pass - check_extract(foo, - expected_rq=requires_nothing, - expected_rt=rt) + expected_rt=ReturnsDeclaration([ResourceKey(identifier='bar')])) def test_how_instance_in_annotations(self): def foo(a: Value('config')['db_url']): pass check_extract(foo, expected_rq=RequirementsDeclaration(( Parameter(Value(identifier='config')['db_url']), - )), - expected_rt=result_type) + ))) def test_default_requirements(self): def foo(a, b=1, *, c, d=None): pass @@ -240,52 +248,46 @@ def foo(a, b=1, *, c, d=None): pass Parameter(Annotation('b', default=1), default=1), Parameter(Annotation('c'), target='c'), Parameter(Annotation('d', default=None), target='d', default=None) - )), - expected_rt=result_type) + ))) def test_type_only(self): class T: pass def foo(a: T): pass check_extract(foo, expected_rq=RequirementsDeclaration((Parameter(Annotation('a', T)),)), - expected_rt=result_type) + expected_rt=ReturnsDeclaration([ResourceKey(identifier='foo')])) @pytest.mark.parametrize("type_", [str, int, dict, list]) def test_simple_type_only(self, type_): def foo(a: type_): pass check_extract(foo, - expected_rq=RequirementsDeclaration((Parameter(Annotation('a', type_)),)), - expected_rt=result_type) + expected_rq=RequirementsDeclaration((Parameter(Annotation('a', type_)),))) def test_type_plus_value(self): def foo(a: str = Value('b')): pass check_extract(foo, - expected_rq=RequirementsDeclaration((Parameter(Value(identifier='b')),)), - expected_rt=result_type) + expected_rq=RequirementsDeclaration((Parameter(Value(identifier='b')),))) def test_type_plus_value_with_default(self): def foo(a: str = Value('b', default=1)): pass check_extract(foo, expected_rq=RequirementsDeclaration(( Parameter(Value(identifier='b', default=1), default=1), - )), - expected_rt=result_type) + ))) def test_value_annotation_plus_default(self): def foo(a: Value(str, identifier='b') = 1): pass check_extract(foo, expected_rq=RequirementsDeclaration(( Parameter(Value(str, identifier='b'), default=1), - )), - expected_rt=result_type) + ))) def test_requirement_default_preferred_to_annotation_default(self): def foo(a: Value(str, identifier='b', default=2) = 1): pass check_extract(foo, expected_rq=RequirementsDeclaration(( Parameter(Value(str, identifier='b', default=2), default=2), - )), - expected_rt=result_type) + ))) def test_value_annotation_just_type_in_value_key_plus_default(self): def foo(a: Value(str) = 1): pass @@ -293,7 +295,7 @@ def foo(a: Value(str) = 1): pass expected_rq=RequirementsDeclaration(( Parameter(Value(str), default=1), )), - expected_rt=result_type) + expected_rt=ReturnsDeclaration([ResourceKey(identifier='foo')])) class TestDeclarationsFromMultipleSources: @@ -312,8 +314,7 @@ def foo(a: r1, b, c=r3): Parameter(Requirement((), default='a'), default='a'), Parameter(Requirement((), default='b'), default='b', target='b'), Parameter(Requirement((), default='c'), default='c', target='c'), - )), - expected_rt=result_type) + ))) def test_declaration_priorities(self): r1 = Requirement([ResourceKey(identifier='x')]) @@ -321,7 +322,8 @@ def test_declaration_priorities(self): r3 = Requirement([ResourceKey(identifier='z')]) @requires(a=r1) - def foo(a: r2 = r3, b: str = r2, c=r3): + @returns('bar') + def foo(a: r2 = r3, b: str = r2, c=r3) -> Optional[Type1]: pass check_extract(foo, @@ -330,4 +332,4 @@ def foo(a: r2 = r3, b: str = r2, c=r3): Parameter(Requirement([ResourceKey(identifier='y')]), target='b'), Parameter(Requirement([ResourceKey(identifier='z')]), target='c'), )), - expected_rt=result_type) + expected_rt=ReturnsDeclaration([ResourceKey(identifier='bar')])) diff --git a/mush/tests/test_runner.py b/mush/tests/test_runner.py index d2e5747..31d65ef 100644 --- a/mush/tests/test_runner.py +++ b/mush/tests/test_runner.py @@ -265,97 +265,6 @@ def job4(t2_): call.job4(t2), ], m.mock_calls) - def test_returns_type_mapping(self): - m = Mock() - class T1(object): pass - class T2(object): pass - t = T1() - - @returns_mapping() - def job1(): - m.job1() - return {T2:t} - - @requires(T2) - def job2(obj): - m.job2(obj) - - Runner(job1, job2)() - - compare([ - call.job1(), - call.job2(t), - ], m.mock_calls) - - def test_returns_type_mapping_of_none(self): - m = Mock() - class T2(object): pass - - @returns_mapping() - def job1(): - m.job1() - return {T2:None} - - @requires(T2) - def job2(obj): - m.job2(obj) - - Runner(job1, job2)() - - compare([ - call.job1(), - call.job2(None), - ], m.mock_calls) - - def test_returns_tuple(self): - m = Mock() - class T1(object): pass - class T2(object): pass - - t1 = T1() - t2 = T2() - - @returns(T1, T2) - def job1(): - m.job1() - return t1, t2 - - @requires(T1, T2) - def job2(obj1, obj2): - m.job2(obj1, obj2) - - Runner(job1, job2)() - - compare([ - call.job1(), - call.job2(t1, t2), - ], m.mock_calls) - - def test_returns_list(self): - m = Mock() - class T1(object): pass - class T2(object): pass - - t1 = T1() - t2 = T2() - - def job1(): - m.job1() - return [t1, t2] - - @requires(obj1=T1, obj2=T2) - def job2(obj1, obj2): - m.job2(obj1, obj2) - - runner = Runner() - runner.add(job1, returns=returns(T1, T2)) - runner.add(job2) - runner() - - compare([ - call.job1(), - call.job2(t1, t2), - ], m.mock_calls) def test_return_type_specified_decorator(self): m = Mock() From 5fedbaa4411f7e901ee083b2f96c83cbbb1370ac Mon Sep 17 00:00:00 2001 From: Chris Withers Date: Mon, 5 Oct 2020 11:09:50 +0100 Subject: [PATCH 136/159] Bring back context.extract() --- mush/context.py | 63 +++++++++++------------ mush/tests/test_context.py | 101 ++++++++++++++++++------------------- 2 files changed, 81 insertions(+), 83 deletions(-) diff --git a/mush/context.py b/mush/context.py index 52a4575..adfd265 100644 --- a/mush/context.py +++ b/mush/context.py @@ -28,6 +28,12 @@ def __init__(self): # self._requires_cache = {} # self._returns_cache = {} + def add_by_keys(self, resource: ResourceValue, keys: Iterable[ResourceKey]): + for key in keys: + if key in self._store: + raise ResourceError(f'Context already contains {key}') + self._store[key] = resource + def add(self, obj: Union[Provider, Resource], provides: Optional[Type_] = missing, @@ -39,33 +45,28 @@ def add(self, ``provides`` can be explicitly specified as ``None`` to only register against the identifier """ + keys = set() if isinstance(obj, Provider): resource = obj if provides is missing: - sig = signature(obj.provider) - annotation = sig.return_annotation - if annotation is sig.empty: - if identifier is None: - raise ResourceError( - f'Could not determine what is provided by {obj.provider}' - ) - else: - provides = annotation + keys.update(extract_returns(resource.provider)) else: resource = ResourceValue(obj) if provides is missing: provides = type(obj) - to_add = [] if provides is not missing: - to_add.append(ResourceKey(provides, identifier)) + keys.add(ResourceKey(provides, identifier)) if not (identifier is None or provides is None): - to_add.append(ResourceKey(None, identifier)) - for key in to_add: - if key in self._store: - raise ResourceError(f'Context already contains {key}') - self._store[key] = resource + keys.add(ResourceKey(None, identifier)) + + if not keys: + raise ResourceError( + f'Could not determine what is provided by {resource}' + ) + + self.add_by_keys(resource, keys) # def remove(self, key: ResourceKey, *, strict: bool = True): # """ @@ -85,21 +86,21 @@ def __repr__(self): if bits: bits.append('\n') return f"" - # - # def _process(self, obj, result, returns): - # if returns is None: - # returns = self._returns_cache.get(obj) - # if returns is None: - # returns = extract_returns(obj, explicit=None) - # self._returns_cache[obj] = returns - # - # for type, obj in returns.process(result): - # self.add(obj, type) - # - # def extract(self, obj: Callable, requires: RequiresType = None, returns: ReturnsType = None): - # result = self.call(obj, requires) - # self._process(obj, result, returns) - # return result + + def _process(self, obj, result, returns): + if returns is None: + # returns = self._returns_cache.get(obj) + # if returns is None: + returns = extract_returns(obj) + # self._returns_cache[obj] = returns + + if returns: + self.add_by_keys(ResourceValue(result), returns) + + def extract(self, obj: Callable):#, requires: RequiresType = None, returns: ReturnsType = None): + result = self.call(obj) + self._process(obj, result, returns=None) + return result def _find_resource(self, key): if not isinstance(key[0], type): diff --git a/mush/tests/test_context.py b/mush/tests/test_context.py index 67ae13f..da6c49b 100644 --- a/mush/tests/test_context.py +++ b/mush/tests/test_context.py @@ -397,56 +397,50 @@ def foo(x): # compare(context._requires_cache, expected={}) # - # XXX extract -# def test_extract_minimal(self): -# o = TheType() -# def foo() -> TheType: -# return o -# context = Context() -# result = context.extract(foo) -# assert result is o -# compare({TheType: o}, actual=context._store) -# compare(context._requires_cache[foo], expected=RequiresType()) -# compare(context._returns_cache[foo], expected=returns(TheType)) -# -# def test_extract_maximal(self): -# def foo(*args): -# return args -# context = Context() -# context.add('a') -# result = context.extract(foo, requires(str), returns(Tuple[str])) -# compare(result, expected=('a',)) -# compare({ -# str: 'a', -# Tuple[str]: ('a',), -# }, actual=context._store) -# compare(context._requires_cache, expected={}) -# compare(context._returns_cache, expected={}) -# -# def test_returns_single(self): -# def foo(): -# return 'bar' -# context = Context() -# result = context.extract(foo, requires_nothing, returns(TheType)) -# compare(result, 'bar') -# compare({TheType: 'bar'}, actual=context._store) -# -# def test_ignore_return(self): -# def foo(): -# return 'bar' -# context = Context() -# result = context.extract(foo, requires_nothing, returns_nothing) -# compare(result, 'bar') -# compare({}, context._store) -# -# def test_ignore_non_iterable_return(self): -# def foo(): pass -# context = Context() -# result = context.extract(foo) -# compare(result, expected=None) -# compare(context._store, expected={}) -# +class TestExtract: + + def test_extract_minimal(self): + o = TheType() + def foo(): + return o + context = Context() + result = context.extract(foo) + assert result is o + compare({ResourceKey(identifier='foo'): ResourceValue(o)}, actual=context._store) + + def test_extract_maximal(self): + def foo(o: str) -> Tuple[str, ...]: + return o, o + context = Context() + context.add('a') + result = context.extract(foo) + compare(result, expected=('a', 'a')) + compare({ + ResourceKey(str): ResourceValue('a'), + ResourceKey(identifier='foo'): ResourceValue(result), + ResourceKey(Tuple[str, ...], 'foo'): ResourceValue(result), + ResourceKey(Tuple[str, ...]): ResourceValue(result), + }, actual=context._store) + + def test_ignore_return(self): + @ignore_return + def foo(): + return 'bar' + context = Context() + result = context.extract(foo) + compare(result, 'bar') + compare({}, context._store) + + def test_returns_none(self): + def foo(): pass + context = Context() + result = context.extract(foo) + compare(result, expected=None) + compare(context._store, expected={ + ResourceKey(identifier='foo'): ResourceValue(None), + }) + # XXX - remove @@ -647,10 +641,13 @@ def returner(obj: Type1): assert isinstance(context.call(returner), Type1) def test_no_provides(self): - def provider(): pass + provider = Mock() context = Context() - with ShouldRaise(ResourceError(f'Could not determine what is provided by {provider}')): - context.add(Provider(provider)) + with ShouldRaise(ResourceError( + f'Could not determine what is provided by ' + f'Provider(functools.partial({provider}), cache=True, provides_subclasses=False)' + )): + context.add(Provider(partial(provider))) def test_identifier(self): def provider() -> str: From 245ae9cb94102b5ade0cd6b80bca17eff58d4238 Mon Sep 17 00:00:00 2001 From: Chris Withers Date: Thu, 8 Oct 2020 07:58:53 +0100 Subject: [PATCH 137/159] Allow additional declarations to be explicitly supplied. --- mush/extraction.py | 23 ++++++++-- mush/tests/test_context.py | 79 +++++++++++++++++------------------ mush/tests/test_extraction.py | 55 +++++++++++++++++++++++- 3 files changed, 111 insertions(+), 46 deletions(-) diff --git a/mush/extraction.py b/mush/extraction.py index 70f1f58..f535053 100644 --- a/mush/extraction.py +++ b/mush/extraction.py @@ -6,11 +6,12 @@ from .declarations import ( Parameter, RequirementsDeclaration, ReturnsDeclaration, - requires_nothing + requires_nothing, returns, requires ) from .markers import missing, get_mush from .requirements import Requirement, Annotation from .resources import ResourceKey +from .typing import Requires, Returns def _apply_requires(by_name, by_index, requires_): @@ -30,7 +31,10 @@ def _apply_requires(by_name, by_index, requires_): by_name[name] = p -def extract_requires(obj: Callable) -> RequirementsDeclaration: +def extract_requires( + obj: Callable, + explicit: Requires = None, +) -> RequirementsDeclaration: by_name = {} # from annotations @@ -72,6 +76,14 @@ def extract_requires(obj: Callable) -> RequirementsDeclaration: if mush_requires is not None: _apply_requires(by_name, by_index, mush_requires) + # explicit + if explicit is not None: + if not isinstance(explicit, RequirementsDeclaration): + if not isinstance(explicit, (list, tuple)): + explicit = (explicit,) + explicit = requires(*explicit) + _apply_requires(by_name, by_index, explicit) + if not by_name: return requires_nothing @@ -86,7 +98,12 @@ def extract_requires(obj: Callable) -> RequirementsDeclaration: return RequirementsDeclaration(by_name.values()) -def extract_returns(obj: Callable): +def extract_returns(obj: Callable, explicit: Returns = None): + if explicit is not None: + if not isinstance(explicit, ReturnsDeclaration): + return returns(explicit) + return explicit + returns_ = get_mush(obj, 'returns', None) if returns_ is not None: return returns_ diff --git a/mush/tests/test_context.py b/mush/tests/test_context.py index da6c49b..9d7e59c 100644 --- a/mush/tests/test_context.py +++ b/mush/tests/test_context.py @@ -1,9 +1,10 @@ # from typing import Tuple, List # -from typing import NewType, Mapping, Any -from testfixtures.mock import Mock +from functools import partial +from typing import NewType, Mapping, Any, Tuple from testfixtures import ShouldRaise, compare +from testfixtures.mock import Mock # from testfixtures.mock import Mock # @@ -14,7 +15,9 @@ # from mush.declarations import RequiresType, requires_nothing, returns_nothing # from mush.requirements import Requirement from .helpers import TheType, Type1, Type2 -from ..resources import ResourceValue, Provider +from ..declarations import ignore_return +from ..requirements import ItemOp +from ..resources import ResourceValue, Provider, ResourceKey class TestAdd: @@ -251,44 +254,38 @@ def foo(*, x: int): result = context.call(foo) compare(result, expected=2) -# XXX - these are for explicit requires() objects: - # def test_call_requires_string(self): - # def foo(obj): - # return obj - # context = Context() - # context.add('bar', identifier='baz') - # result = context.call(foo, requires('baz')) - # compare(result, expected='bar') - # compare({'baz': 'bar'}, actual=context._store) - -# def test_call_requires_type(self): -# def foo(obj): -# return obj -# context = Context() -# context.add('bar', TheType) -# result = context.call(foo, requires(TheType)) -# compare(result, 'bar') -# compare({TheType: 'bar'}, actual=context._store) -# - # - # def test_call_requires_accidental_tuple(self): - # def foo(obj): return obj - # context = Context() - # with ShouldRaise(TypeError( - # "(, " - # ") " - # "is not a valid decoration type" - # )): - # context.call(foo, requires((TheType, TheType))) -# -# def test_call_requires_optional_override_source_and_default(self): -# def foo(x=1): -# return x -# context = Context() -# context.add(2, provides='x') -# result = context.call(foo, requires(x=Value('y', default=3))) -# compare(result, expected=3) -# + def test_call_requires_string(self): + def foo(obj): + return obj + context = Context() + context.add('bar', identifier='baz') + result = context.call(foo, requires('baz')) + compare(result, expected='bar') + + def test_call_requires_type(self): + def foo(obj): + return obj + context = Context() + context.add('bar', TheType) + result = context.call(foo, requires(TheType)) + compare(result, 'bar') + + def test_call_requires_optional_override_source_and_default(self): + def foo(x=1): + return x + context = Context() + context.add(2, provides='x') + result = context.call(foo, requires(x=Value('y', default=3))) + compare(result, expected=3) + + def test_kw_parameter(self): + def foo(x, y): + return x, y + context = Context() + context.add('foo', TheType) + context.add('bar', identifier='baz') + result = context.call(foo, requires(y='baz', x=TheType)) + compare(result, expected=('foo', 'bar')) class TestOps: diff --git a/mush/tests/test_extraction.py b/mush/tests/test_extraction.py index c47e44d..b2512db 100644 --- a/mush/tests/test_extraction.py +++ b/mush/tests/test_extraction.py @@ -26,7 +26,7 @@ def check_extract(obj, expected_rq, expected_rt=returns_foo): compare(rt, expected=expected_rt, strict=True) -class TestRequirementsExtraction(object): +class TestRequirementsExtraction: def test_default_requirements_for_function(self): def foo(a, b=None): pass @@ -152,7 +152,7 @@ class Foo: pass class Bar: pass -class TestExtractDeclarationsFromTypeAnnotations(object): +class TestExtractDeclarationsFromTypeAnnotations: def test_extract_from_annotations(self): def foo(a: Type1, b, c: Type2 = 1, d=2) -> Type3: pass @@ -298,6 +298,57 @@ def foo(a: Value(str) = 1): pass expected_rt=ReturnsDeclaration([ResourceKey(identifier='foo')])) +def it(): + pass + + +class TestExplicitDeclarations: + + def test_requires_from_string(self): + compare(extract_requires(it, 'bar'), strict=True, expected=RequirementsDeclaration(( + Parameter(Value(identifier='bar')), + ))) + + def test_requires_from_type(self): + compare(extract_requires(it, Type1), strict=True, expected=RequirementsDeclaration(( + Parameter(Value(Type1)), + ))) + + def test_requires_requirement(self): + compare(extract_requires(it, Value(Type1, 'bar')), strict=True, expected=RequirementsDeclaration(( + Parameter(Value(Type1, 'bar')), + ))) + + def test_requires_from_tuple(self): + compare(extract_requires(it, ('bar', 'baz')), strict=True, expected=RequirementsDeclaration(( + Parameter(Value(identifier='bar')), + Parameter(Value(identifier='baz')), + ))) + + def test_requires_from_list(self): + compare(extract_requires(it, ['bar', 'baz']), strict=True, expected=RequirementsDeclaration(( + Parameter(Value(identifier='bar')), + Parameter(Value(identifier='baz')), + ))) + + def test_explicit_requires_where_parameter_has_default(self): + def foo(x=1): pass + compare(extract_requires(foo, 'bar'), strict=True, expected=RequirementsDeclaration(( + # default is not longer considered: + Parameter(Value(identifier='bar')), + ))) + + def test_returns_from_string(self): + compare(extract_returns(it, 'bar'), strict=True, expected=ReturnsDeclaration([ + ResourceKey(identifier='bar') + ])) + + def test_returns_from_type(self): + compare(extract_returns(it, Type1), strict=True, expected=ReturnsDeclaration([ + ResourceKey(Type1) + ])) + + class TestDeclarationsFromMultipleSources: def test_declarations_from_different_sources(self): From d922f70fb9d5fc92071b25b86828b9cff22fcc21 Mon Sep 17 00:00:00 2001 From: Chris Withers Date: Thu, 8 Oct 2020 08:09:26 +0100 Subject: [PATCH 138/159] Turns out it's better to use type in preference to identifier. --- mush/requirements.py | 2 +- mush/tests/test_requirements.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/mush/requirements.py b/mush/requirements.py index 62568d0..975c206 100644 --- a/mush/requirements.py +++ b/mush/requirements.py @@ -80,8 +80,8 @@ def __init__(self, name: str, type_: Type_ = None, default: Any = missing): else: keys = [ ResourceKey(type_, name), - ResourceKey(None, name), ResourceKey(type_, None), + ResourceKey(None, name), ] super().__init__(keys, default) diff --git a/mush/tests/test_requirements.py b/mush/tests/test_requirements.py index 1deb078..07e281e 100644 --- a/mush/tests/test_requirements.py +++ b/mush/tests/test_requirements.py @@ -119,8 +119,8 @@ def test_name_and_type(self): r = Annotation('x', str, missing) compare(r.keys, expected=[ ResourceKey(str, 'x'), - ResourceKey(None, 'x'), ResourceKey(str, None), + ResourceKey(None, 'x'), ]) compare(r.default, expected=missing) @@ -128,8 +128,8 @@ def test_all(self): r = Annotation('x', str, 'default') compare(r.keys, expected=[ ResourceKey(str, 'x'), - ResourceKey(None, 'x'), ResourceKey(str, None), + ResourceKey(None, 'x'), ]) compare(r.default, expected='default') From 992fc1fd1819e4e5785c864f02f94594b30d9472 Mon Sep 17 00:00:00 2001 From: Chris Withers Date: Thu, 8 Oct 2020 08:12:03 +0100 Subject: [PATCH 139/159] Bring back runners. Lazy removed in favour of using providers. --- mush/callpoints.py | 41 +-- mush/context.py | 34 +- mush/modifier.py | 7 +- mush/requirements.py | 15 - mush/runner.py | 51 +-- mush/tests/example_with_mush_clone.py | 7 +- mush/tests/test_callpoints.py | 107 ++---- mush/tests/test_example_with_mush_clone.py | 6 +- mush/tests/test_example_with_mush_factory.py | 3 +- mush/tests/test_example_without_mush.py | 1 + mush/tests/test_plug.py | 5 +- mush/tests/test_runner.py | 361 +++++-------------- 12 files changed, 177 insertions(+), 461 deletions(-) diff --git a/mush/callpoints.py b/mush/callpoints.py index 47d9c95..b5bc898 100644 --- a/mush/callpoints.py +++ b/mush/callpoints.py @@ -1,21 +1,10 @@ -from collections import namedtuple from typing import TYPE_CHECKING, Callable -from .declarations import ( - requires_nothing, returns as returns_declaration, returns_nothing -) -from .extraction import extract_requires#, extract_returns +from .extraction import extract_requires, extract_returns from .typing import Requires, Returns if TYPE_CHECKING: - from .runner import Runner - - -def do_nothing(): - pass - - -LazyProvider = namedtuple('LazyProvider', ['obj', 'requires', 'returns']) + from . import Context class CallPoint(object): @@ -23,36 +12,20 @@ class CallPoint(object): next = None previous = None - def __init__(self, runner: 'Runner', obj: Callable, - requires: Requires = None, returns: Returns = None, - lazy: bool = False): - requires = extract_requires(obj, requires, runner.modify_requirement) - returns = extract_returns(obj, returns) - if lazy: - if not (type(returns) is returns_declaration and len(returns.args) == 1): - raise TypeError('a single return type must be explicitly specified') - key = returns.args[0] - if key in runner.lazy: - raise TypeError( - f'{name_or_repr(key)} has more than one lazy provider:\n' - f'{runner.lazy[key].obj}\n' - f'{obj}' - ) - runner.lazy[key] = LazyProvider(obj, requires, returns) - obj = do_nothing - requires = requires_nothing - returns = returns_nothing + def __init__(self, obj: Callable, requires: Requires = None, returns: Returns = None): self.obj = obj self.requires = requires self.returns = returns self.labels = set() self.added_using = set() - def __call__(self, context): + def __call__(self, context: 'Context'): return context.extract(self.obj, self.requires, self.returns) def __repr__(self): - txt = '%r %r %r' % (self.obj, self.requires, self.returns) + requires = extract_requires(self.obj, self.requires) + returns = extract_returns(self.obj, self.returns) + txt = f'{self.obj!r} {requires!r} {returns!r}' if self.labels: txt += (' <-- ' + ', '.join(sorted(self.labels))) return txt diff --git a/mush/context.py b/mush/context.py index adfd265..c839007 100644 --- a/mush/context.py +++ b/mush/context.py @@ -1,11 +1,12 @@ from inspect import signature -from typing import Optional, Callable, Union, Any, Dict +from typing import Optional, Callable, Union, Any, Dict, Iterable -from .extraction import extract_requires +from .callpoints import CallPoint +from .extraction import extract_requires, extract_returns from .markers import missing, Marker from .requirements import Requirement from .resources import ResourceKey, ResourceValue, Provider -from .typing import Resource, Identifier, Type_ +from .typing import Resource, Identifier, Type_, Requires, Returns NONE_TYPE = type(None) unspecified = Marker('unspecified') @@ -21,7 +22,7 @@ class Context: "Stores resources for a particular run." # _parent: 'Context' = None - # point: CallPoint = None + point: CallPoint = None def __init__(self): self._store = {} @@ -87,19 +88,11 @@ def __repr__(self): bits.append('\n') return f"" - def _process(self, obj, result, returns): - if returns is None: - # returns = self._returns_cache.get(obj) - # if returns is None: - returns = extract_returns(obj) - # self._returns_cache[obj] = returns - + def extract(self, obj: Callable, requires: Requires = None, returns: Returns = None): + result = self.call(obj, requires) + returns = extract_returns(obj, returns) if returns: self.add_by_keys(ResourceValue(result), returns) - - def extract(self, obj: Callable):#, requires: RequiresType = None, returns: ReturnsType = None): - result = self.call(obj) - self._process(obj, result, returns=None) return result def _find_resource(self, key): @@ -113,11 +106,11 @@ def _find_resource(self, key): return resource exact = False - def _resolve(self, obj, specials=None): + def _resolve(self, obj, requires=None, specials=None): if specials is None: specials: Dict[type, Any] = {Context: self} - requires = extract_requires(obj) + requires = extract_requires(obj, requires) args = [] kw = {} @@ -137,7 +130,7 @@ def _resolve(self, obj, specials=None): if resource.obj is missing: specials_ = specials.copy() specials_[Requirement] = requirement - o = self._resolve(resource.provider, specials_) + o = self._resolve(resource.provider, specials=specials_) if resource.cache: resource.obj = o else: @@ -150,6 +143,7 @@ def _resolve(self, obj, specials=None): o = parameter.default if o is not requirement.default: + # move to requirement.process? for op in requirement.ops: o = op(o) if o is missing: @@ -166,8 +160,8 @@ def _resolve(self, obj, specials=None): return obj(*args, **kw) - def call(self, obj: Callable): - return self._resolve(obj) + def call(self, obj: Callable, requires: Requires = None): + return self._resolve(obj, requires) # # def nest(self, requirement_modifier: RequirementModifier = None): diff --git a/mush/modifier.py b/mush/modifier.py index 16fa223..4ac684a 100644 --- a/mush/modifier.py +++ b/mush/modifier.py @@ -22,7 +22,7 @@ def __init__(self, runner, callpoint, label): self.labels = {label} def add(self, obj: Callable, requires: Requires = None, returns: Returns = None, - label: str = None, lazy: bool = False): + label: str = None): """ :param obj: The callable to be added. @@ -40,9 +40,6 @@ def add(self, obj: Callable, requires: Requires = None, returns: Returns = None, point where ``obj`` is added that can later be retrieved with :meth:`Runner.__getitem__`. - :param lazy: If true, ``obj`` will only be called the first time it - is needed. - If no label is specified but the point which this :class:`~.modifier.Modifier` represents has any labels, those labels will be moved to the newly inserted point. @@ -51,7 +48,7 @@ def add(self, obj: Callable, requires: Requires = None, returns: Returns = None, raise ValueError('%r already points to %r' % ( label, self.runner.labels[label] )) - callpoint = CallPoint(self.runner, obj, requires, returns, lazy) + callpoint = CallPoint(obj, requires, returns) if label: self.add_label(label, callpoint) diff --git a/mush/requirements.py b/mush/requirements.py index 975c206..f92dd5b 100644 --- a/mush/requirements.py +++ b/mush/requirements.py @@ -124,21 +124,6 @@ def __init__(self, def _keys_repr(self): return str(self.keys[0]) -# -# -# class Lazy(Requirement): -# -# def __init__(self, original, provider): -# super().__init__(original.key) -# self.original = original -# self.provider = provider -# self.ops = original.ops -# -# def resolve(self, context): -# resource = context.get(self.key, missing) -# if resource is missing: -# context.extract(self.provider.obj, self.provider.requires, self.provider.returns) -# return self.original.resolve(context) class AnyOf(Requirement): """ diff --git a/mush/runner.py b/mush/runner.py index 97293f4..8c7807b 100644 --- a/mush/runner.py +++ b/mush/runner.py @@ -3,7 +3,7 @@ from .callpoints import CallPoint from .context import Context, ResourceError from .declarations import DeclarationsFrom -from .extraction import extract_requires#, extract_returns +from .extraction import extract_requires, extract_returns # , extract_returns from .markers import not_specified from .modifier import Modifier from .plug import Plug @@ -21,19 +21,11 @@ class Runner(object): end: Optional[CallPoint] = None def __init__(self, *objects: Callable): - self.requirement_modifier = requirement_modifier self.labels = {} - self.lazy = {} self.extend(*objects) - def modify_requirement(self, requirement): - requirement = self.requirement_modifier(requirement) - if requirement.key in self.lazy: - requirement = Lazy(requirement, provider=self.lazy[requirement.key]) - return requirement - def add(self, obj: Callable, requires: Requires = None, returns: Returns = None, - label: str = None, lazy: bool = False): + label: str = None): """ Add a callable to the runner. @@ -52,15 +44,12 @@ def add(self, obj: Callable, requires: Requires = None, returns: Returns = None, :param label: If specified, this is a string that adds a label to the point where ``obj`` is added that can later be retrieved with :meth:`Runner.__getitem__`. - - :param lazy: If true, ``obj`` will only be called the first time it - is needed. """ if isinstance(obj, Plug): obj.add_to(self) else: m = Modifier(self, self.end, not_specified) - m.add(obj, requires, returns, label, lazy) + m.add(obj, requires, returns, label) return m def add_label(self, label: str): @@ -71,26 +60,14 @@ def add_label(self, label: str): m.add_label(label) return m - def _copy_from(self, runner, start_point, end_point, added_using=None): - if self.requirement_modifier is not runner.requirement_modifier: - raise TypeError('requirement_modifier must be identical') - - lazy_clash = set(self.lazy) & set(runner.lazy) - if lazy_clash: - raise TypeError( - 'both runners have lazy providers for these resources:\n' + - '\n'.join(f'{name_or_repr(key)}: \n' - f' {self.lazy[key].obj}\n' - f' {runner.lazy[key].obj}' for key in lazy_clash) - ) - self.lazy.update(runner.lazy) + def _copy_from(self, start_point, end_point, added_using=None): previous_cloned_point = self.end point = start_point while point: if added_using is None or added_using in point.added_using: - cloned_point = CallPoint(self, point.obj, point.requires, point.returns) + cloned_point = CallPoint(point.obj, point.requires, point.returns) cloned_point.labels = set(point.labels) for label in cloned_point.labels: self.labels[label] = cloned_point @@ -119,7 +96,7 @@ def extend(self, *objs: Callable): """ for obj in objs: if isinstance(obj, Runner): - self._copy_from(obj, obj.start, obj.end) + self._copy_from(obj.start, obj.end) else: self.add(obj) @@ -149,7 +126,7 @@ def clone(self, label specified in this option should be cloned. This filtering is applied in addition to the above options. """ - runner = self.__class__(requirement_modifier=self.requirement_modifier) + runner = self.__class__() if start_label: start = self.labels[start_label] @@ -176,7 +153,7 @@ def clone(self, return runner point = point.previous - runner._copy_from(self, start, end, added_using) + runner._copy_from(start, end, added_using) return runner def replace(self, @@ -213,13 +190,13 @@ def replace(self, if requires_from is DeclarationsFrom.replacement: requires = extract_requires(replacement) else: - requires = point.requires + requires = extract_requires(point.obj, point.requires) if returns_from is DeclarationsFrom.replacement: returns = extract_returns(replacement) else: - returns = point.returns + returns = extract_returns(point.obj, point.returns) - new_point = CallPoint(self, replacement, requires, returns) + new_point = CallPoint(replacement, requires, returns) if point.previous is None: self.start = new_point @@ -252,7 +229,7 @@ def __add__(self, other: 'Runner'): """ runner = self.__class__() for r in self, other: - runner._copy_from(r, r.start, r.end) + runner._copy_from(r.start, r.end) return runner def __call__(self, context: Context = None): @@ -289,7 +266,7 @@ def __call__(self, context: Context = None): if getattr(result, '__enter__', None): with result as managed: - if managed is not None: + if managed is not None and managed is not result: context.add(managed) # If the context manager swallows an exception, # None should be returned, not the context manager: @@ -314,7 +291,7 @@ class ContextError(Exception): """ Errors likely caused by incorrect building of a runner. """ - def __init__(self, text: str, point: CallPoint=None, context: Context = None): + def __init__(self, text: str, point: CallPoint = None, context: Context = None): self.text: str = text self.point: CallPoint = point self.context: Context = context diff --git a/mush/tests/example_with_mush_clone.py b/mush/tests/example_with_mush_clone.py index d8fbae3..6387d84 100644 --- a/mush/tests/example_with_mush_clone.py +++ b/mush/tests/example_with_mush_clone.py @@ -1,6 +1,6 @@ from argparse import ArgumentParser, Namespace from configparser import RawConfigParser -from mush import Runner, requires, Value +from mush import Runner, requires, Value, returns import logging, os, sqlite3, sys log = logging.getLogger() @@ -12,10 +12,11 @@ def base_options(parser: ArgumentParser): parser.add_argument('--verbose', action='store_true', help='Log more to the console') -def parse_args(parser: ArgumentParser): +def parse_args(parser: ArgumentParser) -> Namespace: return parser.parse_args() -def parse_config(args: Namespace) -> 'config': +@returns('config') +def parse_config(args: Namespace): config = RawConfigParser() config.read(args.config) return dict(config.items('main')) diff --git a/mush/tests/test_callpoints.py b/mush/tests/test_callpoints.py index ca37c2e..4984052 100644 --- a/mush/tests/test_callpoints.py +++ b/mush/tests/test_callpoints.py @@ -1,43 +1,38 @@ -import pytest; pytestmark = pytest.mark.skip("WIP") - -from functools import update_wrapper -from unittest import TestCase - from testfixtures import compare from testfixtures.mock import Mock, call +import pytest from mush.callpoints import CallPoint -from mush.declarations import requires, returns, RequirementsDeclaration -# from mush.extraction import update_wrapper +from mush.declarations import ( + requires, returns, RequirementsDeclaration, ReturnsDeclaration, update_wrapper +) from mush.requirements import Value -from mush.runner import Runner -class TestCallPoints(TestCase): +@pytest.fixture() +def context(): + return Mock() + - def setUp(self): - self.context = Mock() - self.runner = Runner() +class TestCallPoints: def test_passive_attributes(self): # these are managed by Modifiers - point = CallPoint(self.runner, Mock()) + point = CallPoint(Mock()) compare(point.previous, None) compare(point.next, None) compare(point.labels, set()) - def test_supplied_explicitly(self): + def test_supplied_explicitly(self, context): def foo(a1): pass rq = requires('foo') rt = returns('bar') - result = CallPoint(self.runner, foo, rq, rt)(self.context) - compare(result, self.context.extract.return_value) - compare(self.context.extract.mock_calls, - expected=[call(foo, - RequirementsDeclaration([Value.make(key='foo', name='a1')]), - rt)]) - - def test_extract_from_decorations(self): + result = CallPoint(foo, rq, rt)(context) + compare(result, context.extract.return_value) + compare(context.extract.mock_calls, + expected=[call(foo, rq, rt)]) + + def test_extract_from_decorations(self, context): rq = requires('foo') rt = returns('bar') @@ -45,14 +40,12 @@ def test_extract_from_decorations(self): @rt def foo(a1): pass - result = CallPoint(self.runner, foo)(self.context) - compare(result, self.context.extract.return_value) - compare(self.context.extract.mock_calls, - expected=[call(foo, - RequirementsDeclaration([Value.make(key='foo', name='a1')]), - returns('bar'))]) + result = CallPoint(foo)(context) + compare(result, context.extract.return_value) + compare(context.extract.mock_calls, + expected=[call(foo, None, None)]) - def test_extract_from_decorated_class(self): + def test_extract_from_decorated_class(self, context): rq = requires('foo') rt = returns('bar') @@ -72,33 +65,18 @@ def my_dec(func): def foo(prefix): return prefix+'answer' - self.context.extract.side_effect = lambda func, rq, rt: (func(), rq, rt) - result = CallPoint(self.runner, foo)(self.context) - compare(result, expected=('the answer', - RequirementsDeclaration([Value.make(key='foo', name='prefix')]), - rt)) - - def test_explicit_trumps_decorators(self): - @requires('foo') - @returns('bar') - def foo(a1): pass - - point = CallPoint(self.runner, foo, requires('baz'), returns('bob')) - result = point(self.context) - compare(result, self.context.extract.return_value) - compare(self.context.extract.mock_calls, - expected=[call(foo, - RequirementsDeclaration([Value.make(key='baz', name='a1')]), - returns('bob'))]) + context.extract.side_effect = lambda func, rq, rt: (func(), rq, rt) + result = CallPoint(foo)(context) + compare(result, expected=('the answer', None, None)) def test_repr_minimal(self): def foo(): pass - point = CallPoint(self.runner, foo) - compare(repr(foo)+" requires() returns_result_type()", repr(point)) + point = CallPoint(foo) + compare(repr(foo)+" requires() returns('foo')", repr(point)) def test_repr_maximal(self): def foo(a1): pass - point = CallPoint(self.runner, foo, requires('foo'), returns('bar')) + point = CallPoint(foo, requires('foo'), returns('bar')) point.labels.add('baz') point.labels.add('bob') compare(expected=repr(foo)+" requires(Value('foo')) returns('bar') <-- baz, bob", @@ -106,30 +84,9 @@ def foo(a1): pass def test_convert_to_requires_and_returns(self): def foo(baz): pass - point = CallPoint(self.runner, foo, requires='foo', returns='bar') - self.assertTrue(isinstance(point.requires, RequirementsDeclaration)) - self.assertTrue(isinstance(point.returns, returns)) + point = CallPoint(foo, requires='foo', returns='bar') + # this is deferred until later + assert isinstance(point.requires, str) + assert isinstance(point.returns, str) compare(repr(foo)+" requires(Value('foo')) returns('bar')", repr(point)) - - def test_convert_to_requires_and_returns_tuple(self): - def foo(a1, a2): pass - point = CallPoint(self.runner, - foo, - requires=('foo', 'bar'), - returns=('baz', 'bob')) - self.assertTrue(isinstance(point.requires, RequirementsDeclaration)) - self.assertTrue(isinstance(point.returns, returns)) - compare(repr(foo)+" requires(Value('foo'), Value('bar')) returns('baz', 'bob')", - repr(point)) - - def test_convert_to_requires_and_returns_list(self): - def foo(a1, a2): pass - point = CallPoint(self.runner, - foo, - requires=['foo', 'bar'], - returns=['baz', 'bob']) - self.assertTrue(isinstance(point.requires, RequirementsDeclaration)) - self.assertTrue(isinstance(point.returns, returns)) - compare(repr(foo)+" requires(Value('foo'), Value('bar')) returns('baz', 'bob')", - repr(point)) diff --git a/mush/tests/test_example_with_mush_clone.py b/mush/tests/test_example_with_mush_clone.py index 0364779..2f60a38 100644 --- a/mush/tests/test_example_with_mush_clone.py +++ b/mush/tests/test_example_with_mush_clone.py @@ -1,4 +1,4 @@ -# from .example_with_mush_clone import DatabaseHandler, main, do, setup_logging +from .example_with_mush_clone import DatabaseHandler, main, do, setup_logging from unittest import TestCase from testfixtures import TempDirectory from testfixtures import Replacer @@ -6,7 +6,6 @@ from testfixtures import ShouldRaise import sqlite3 -import pytest; pytestmark = pytest.mark.skip("WIP") class Tests(TestCase): @@ -58,7 +57,8 @@ def test_setup_logging(self): with TempDirectory() as dir: with LogCapture(): setup_logging(dir.getpath('test.log'), verbose=True) - + + class DatabaseHandlerTests(TestCase): def setUp(self): diff --git a/mush/tests/test_example_with_mush_factory.py b/mush/tests/test_example_with_mush_factory.py index 392207f..58f24ba 100644 --- a/mush/tests/test_example_with_mush_factory.py +++ b/mush/tests/test_example_with_mush_factory.py @@ -1,5 +1,4 @@ -# from .example_with_mush_factory import main -import pytest; pytestmark = pytest.mark.skip("WIP") +from .example_with_mush_factory import main from unittest import TestCase from testfixtures import TempDirectory, Replacer diff --git a/mush/tests/test_example_without_mush.py b/mush/tests/test_example_without_mush.py index afa4df2..f9f1ee9 100644 --- a/mush/tests/test_example_without_mush.py +++ b/mush/tests/test_example_without_mush.py @@ -3,6 +3,7 @@ from testfixtures import TempDirectory, Replacer, OutputCapture import sqlite3 + class Tests(TestCase): def test_main(self): diff --git a/mush/tests/test_plug.py b/mush/tests/test_plug.py index e45513e..f91824b 100644 --- a/mush/tests/test_plug.py +++ b/mush/tests/test_plug.py @@ -1,6 +1,3 @@ -import pytest; pytestmark = pytest.mark.skip("WIP") - -from unittest import TestCase from testfixtures import compare, ShouldRaise from testfixtures.mock import Mock, call @@ -10,7 +7,7 @@ from mush.tests.test_runner import verify -class TestPlug(TestCase): +class TestPlug: def test_simple(self): m = Mock() diff --git a/mush/tests/test_runner.py b/mush/tests/test_runner.py index 31d65ef..5a582fa 100644 --- a/mush/tests/test_runner.py +++ b/mush/tests/test_runner.py @@ -1,10 +1,8 @@ -import pytest; pytestmark = pytest.mark.skip("WIP") -from unittest import TestCase +import pytest -from mush.declarations import ( - requires, returns, returns_mapping, - replacement, original) +from mush.declarations import requires, returns, replacement, original from mush import Value, ContextError, Context, Requirement +from mush.resources import Provider from mush.runner import Runner from testfixtures import ( ShouldRaise, @@ -39,7 +37,7 @@ def verify(runner, *expected): compare(seen_labels, runner.labels.keys()) -class RunnerTests(TestCase): +class TestRunner: def test_simple(self): m = Mock() @@ -199,23 +197,21 @@ def test_runner_add_label(self): def test_declarative(self): m = Mock() - class T1(object): pass - class T2(object): pass + class T1: pass + class T2: pass t1 = T1() t2 = T2() - def job1(): + def job1() -> T1: m.job1() return t1 - @requires(T1) - def job2(obj): + def job2(obj: T1) -> T2: m.job2(obj) return t2 - @requires(T2) - def job3(obj): + def job3(obj: T2) -> None: m.job3(obj) runner = Runner(job1, job2, job3) @@ -229,8 +225,8 @@ def job3(obj): def test_imperative(self): m = Mock() - class T1(object): pass - class T2(object): pass + class T1: pass + class T2: pass t1 = T1() t2 = T2() @@ -246,14 +242,14 @@ def job2(obj): def job3(t2_): m.job3(t2_) - # imperative config trumps declarative + # imperative config overrides decorator @requires(T1) def job4(t2_): m.job4(t2_) runner = Runner() - runner.add(job1) - runner.add(job2, requires(T1)) + runner.add(job1, returns=T1) + runner.add(job2, requires(T1), returns(T2)) runner.add(job3, requires(t2_=T2)) runner.add(job4, requires(T2)) runner() @@ -265,11 +261,10 @@ def job4(t2_): call.job4(t2), ], m.mock_calls) - def test_return_type_specified_decorator(self): m = Mock() - class T1(object): pass - class T2(object): pass + class T1: pass + class T2: pass t = T1() @returns(T2) @@ -290,8 +285,8 @@ def job2(obj): def test_return_type_specified_imperative(self): m = Mock() - class T1(object): pass - class T2(object): pass + class T1: pass + class T2: pass t = T1() def job1(): @@ -314,8 +309,8 @@ def job2(obj): def test_lazy(self): m = Mock() - class T1(object): pass - class T2(object): pass + class T1: pass + class T2: pass t = T1() def lazy_used(): @@ -325,189 +320,25 @@ def lazy_used(): def lazy_unused(): raise AssertionError('should not be called') # pragma: no cover - def job(obj): - m.job(obj) - - runner = Runner() - runner.add(lazy_used, returns=returns(T1), lazy=True) - runner.add(lazy_unused, returns=returns(T2), lazy=True) - runner.add(job, requires(T1)) - runner() - - compare(m.mock_calls, expected=[ - call.lazy_used(), - call.job(t), - ], ) - - def test_lazy_no_return_type_specified(self): - runner = Runner() - with ShouldRaise( - TypeError('a single return type must be explicitly specified') - ): - runner.add(lambda: None, lazy=True) - - def test_returns_more_than_one_type(self): - class T1(object): pass - class T2(object): pass - runner = Runner() - with ShouldRaise( - TypeError('a single return type must be explicitly specified') - ): - runner.add(lambda: None, returns=returns(T1, T2), lazy=True) - - def test_lazy_two_callable_provide_same_type(self): - class T1(object): pass - def foo(): pass - def bar(): pass - runner = Runner() - runner.add(foo, returns=returns(T1), lazy=True) - with ShouldRaise(TypeError( - 'T1 has more than one lazy provider:\n' - f'{foo!r}\n' - f'{bar!r}' - )): - runner.add(bar, returns=returns(T1), lazy=True) - - def test_lazy_per_context(self): - m = Mock() - class T1(object): pass - t = T1() - - def lazy(): - m.lazy_used() - return t + def providers(context: Context): + context.add(Provider(lazy_used), provides=T1) + context.add(Provider(lazy_unused), provides=T2) def job(obj): m.job(obj) runner = Runner() - runner.add(lazy, returns=returns(T1), lazy=True) + runner.add(providers) runner.add(job, requires(T1)) runner() - runner() - - compare(m.mock_calls, expected=[ - call.lazy_used(), - call.job(t), - call.lazy_used(), - call.job(t), - ], ) - - def test_lazy_after_clone(self): - m = Mock() - class T1(object): pass - t = T1() - - def lazy(): - m.lazy_used() - return t - - def job(obj): - m.job(obj) - - runner = Runner() - runner.add(lazy, returns=returns(T1), lazy=True) - runner_ = runner.clone() - runner_.add(job, requires(T1)) - runner_() - - compare(m.mock_calls, expected=[ - call.lazy_used(), - call.job(t), - ], ) - - def test_lazy_after_add(self): - m = Mock() - class T1(object): pass - t = T1() - - def lazy(): - m.lazy_used() - return t - - def job(obj): - m.job(obj) - - runner1 = Runner() - runner1.add(lazy, returns=returns(T1), lazy=True) - runner2 = Runner() - runner2.add(job, requires(T1)) - runner = runner1 + runner2 - runner() compare(m.mock_calls, expected=[ call.lazy_used(), call.job(t), ], ) - def test_lazy_add_clash(self): - class T1(object): pass - def foo(): pass - def bar(): pass - runner1 = Runner() - runner1.add(foo, returns=returns(T1), lazy=True) - runner2 = Runner() - runner2.add(bar, returns=returns(T1), lazy=True) - with ShouldRaise(TypeError( - 'both runners have lazy providers for these resources:\n' - 'T1: \n' - f' {foo!r}\n' - f' {bar!r}' - )): - runner1 + runner2 - - def test_lazy_only_resolved_once(self): - m = Mock() - class T1(object): pass - t = T1() - - def lazy_used(): - m.lazy_used() - return t - - def job1(obj): - m.job1(obj) - - def job2(obj): - m.job2(obj) - - runner = Runner() - runner.add(lazy_used, returns=returns(T1), lazy=True) - runner.add(job1, requires(T1)) - runner.add(job2, requires(T1)) - runner() - - compare(m.mock_calls, expected=[ - call.lazy_used(), - call.job1(t), - call.job2(t), - ], ) - - def test_lazy_with_requirement_modifier(self): - def make_data(): - return {'foo': 'bar'} - - class FromKey(Requirement): - def resolve(self, context): - return context.get('data')[self.data_key] - - def modifier(requirement): - if type(requirement) is Requirement: - # another limitation of lazy: - requirement = FromKey.make_from(requirement, - key='data', - data_key=requirement.key) - return requirement - - runner = Runner(requirement_modifier=modifier) - runner.add(make_data, returns='data', lazy=True) - runner.add(lambda foo: foo+'baz', returns='processed') - runner.add(lambda *args: args, requires(Value('data')['foo'], 'processed')) - - compare(runner(), expected=('bar', 'barbaz')) - def test_missing_from_context_no_chain(self): - class T(object): pass + class T: pass @requires(T) def job(arg): @@ -518,20 +349,21 @@ def job(arg): with ShouldRaise(ContextError) as s: runner() + t_str = 'TestRunner.test_missing_from_context_no_chain..T' text = '\n'.join(( - 'While calling: '+repr(job)+' requires(Value(T)) returns_result_type()', + f"While calling: {job!r} requires(Value({t_str})) returns('job')", 'with :', '', - "No Value(T) in context", + f"Value({t_str}) could not be satisfied", )) compare(text, actual=repr(s.raised)) compare(text, actual=str(s.raised)) def test_missing_from_context_with_chain(self): - class T(object): pass + class T: pass - def job1(): pass - def job2(): pass + def job1() -> None: pass + def job2() -> None: pass @requires(T) def job3(arg): @@ -550,20 +382,22 @@ def job5(foo, bar): pass with ShouldRaise(ContextError) as s: runner() + t_str = 'TestRunner.test_missing_from_context_with_chain..T' + text = '\n'.join(( '', '', 'Already called:', - repr(job1)+' requires() returns_result_type() <-- 1', - repr(job2)+' requires() returns_result_type()', + repr(job1)+' requires() returns() <-- 1', + repr(job2)+' requires() returns()', '', - 'While calling: '+repr(job3)+' requires(Value(T)) returns_result_type()', + f"While calling: {job3!r} requires(Value({t_str})) returns('job3')", 'with :', '', - "No Value(T) in context", + f"Value({t_str}) could not be satisfied", '', 'Still to call:', - repr(job4)+' requires() returns_result_type() <-- 4', + repr(job4)+" requires() returns('job4') <-- 4", repr(job5)+" requires(Value('foo'), bar=Value('baz')) returns('bob')", )) compare(text, actual=repr(s.raised)) @@ -575,29 +409,37 @@ def job(arg): runner = Runner(job) with ShouldRaise(ContextError) as s: runner() - compare(s.raised.text, expected="No Value('arg') in context") + compare(s.raised.text, expected='arg could not be satisfied') def test_already_in_context(self): - class T(object): pass + class T: pass t1 = T() + t2 = T() + ts = [t2, t1] - @returns(T, T) + @returns(T) def job(): - return t1, T() + return ts.pop() - runner = Runner(job) + runner = Runner(job, job) with ShouldRaise(ContextError) as s: runner() + t_str = 'TestRunner.test_already_in_context..T' text = '\n'.join(( - 'While calling: '+repr(job)+' requires() returns(T, T)', + '', + '', + 'Already called:', + f"{job!r} requires() returns({t_str})", + '', + f"While calling: {job!r} requires() returns({t_str})", 'with :', '', - 'Context already contains '+repr(T), + f'Context already contains {t_str}', )) compare(text, repr(s.raised)) compare(text, str(s.raised)) @@ -619,7 +461,7 @@ def job1(): def job2(obj): m.job2(obj) runner = Runner() - runner.add(job1) + runner.add(job1, returns=T) runner.add(job2, requires(Value(T).foo)) runner() @@ -641,7 +483,7 @@ def job1(): def job2(obj): m.job2(obj) runner = Runner() - runner.add(job1) + runner.add(job1, returns=T) runner.add(job2, requires(Value(T).foo.bar)) runner() @@ -661,7 +503,7 @@ def job1(): def job2(obj): m.job2(obj) runner = Runner() - runner.add(job1) + runner.add(job1, returns=MyDict) runner.add(job2, requires(Value(MyDict)['the_thing'])) runner() compare([ @@ -680,7 +522,7 @@ def job1(): def job2(obj): m.job2(obj) runner = Runner() - runner.add(job1) + runner.add(job1, returns=MyDict) runner.add(job2, requires(Value(MyDict)['the_thing']['other_thing'])) runner() compare([ @@ -688,7 +530,7 @@ def job2(obj): call.job2(m.the_thing), ], m.mock_calls) - def test_nested(self): + def test_item_of_attr(self): class T(object): foo = dict(baz='bar') m = Mock() @@ -698,7 +540,7 @@ def job1(): def job2(obj): m.job2(obj) runner = Runner() - runner.add(job1) + runner.add(job1, returns=T) runner.add(job2, requires(Value(T).foo['baz'])) runner() @@ -984,17 +826,15 @@ class T2(object): pass t1 = T1() t2 = T2() - def job1(): + def job1() -> T1: m.job1() return t1 - @requires(T1) - def job2(obj): + def job2(obj: T1) -> T2: m.job2(obj) return t2 - @requires(T2) - def job3(obj): + def job3(obj: T2): m.job3(obj) runner = Runner() @@ -1043,17 +883,15 @@ class T2(object): pass t1 = T1() t2 = T2() - def job1(): + def job1() -> T1: m.job1() return t1 - @requires(T1) - def job2(obj): + def job2(obj: T1) -> T2: m.job2(obj) return t2 - @requires(T2) - def job3(obj): + def job3(obj: T2): m.job3(obj) runner1 = Runner(job1) @@ -1083,15 +921,13 @@ class T2(object): pass t1 = T1() t2 = T2() - def job1(): + def job1() -> T1: raise Exception() # pragma: nocover - @requires(T1) - def job2(obj): + def job2(obj: T1) -> T2: raise Exception() # pragma: nocover - @requires(T2) - def job3(obj): + def job3(obj: T2): raise Exception() # pragma: nocover runner = Runner(job1, job2, job3) @@ -1116,7 +952,8 @@ class T3(object): pass class T4(object): pass t2 = T2() - def job0(): + + def job0() -> T2: return t2 @requires(T1) @@ -1145,7 +982,8 @@ class T3(object): pass class T4(object): pass t2 = T2() - def job0(): + + def job0() -> T2: return t2 @requires(T1) @@ -1233,7 +1071,7 @@ def test_replace_explicit_at_end(self): call.jobnew2(), ], actual=m.mock_calls) - def test_replace_keep_explicit_requirements(self): + def test_replace_keep_explicit_requires(self): def foo(): return 'bar' def barbar(sheep): @@ -1247,6 +1085,20 @@ def barbar(sheep): runner.replace(barbar, lambda dog: None, requires_from=original) compare(runner(), expected=None) + def test_replace_keep_explicit_returns(self): + def foo(): + return 'bar' + def barbar(sheep): + return sheep*2 + + runner = Runner() + runner.add(foo, returns='flossy') + runner.add(barbar, requires='flossy') + compare(runner(), expected='barbar') + + runner.replace(foo, lambda: 'woof') + compare(runner(), expected='woofwoof') + def test_modifier_changes_endpoint(self): m = Mock() runner = Runner(m.job1) @@ -1288,7 +1140,7 @@ def test_duplicate_label_runner_add(self): runner.add(m.job2) with ShouldRaise(ValueError( "'label' already points to "+repr(m.job1)+" requires() " - "returns_result_type() <-- label" + "returns() <-- label" )): runner.add(m.job3, label='label') verify(runner, @@ -1302,7 +1154,7 @@ def test_duplicate_label_runner_next_add(self): runner.add(m.job1, label='label') with ShouldRaise(ValueError( "'label' already points to "+repr(m.job1)+" requires() " - "returns_result_type() <-- label" + "returns() <-- label" )): runner.add(m.job2, label='label') verify(runner, @@ -1317,7 +1169,7 @@ def test_duplicate_label_modifier(self): mod.add(m.job2, label='label2') with ShouldRaise(ValueError( "'label1' already points to "+repr(m.job1)+" requires() " - "returns_result_type() <-- label1" + "returns() <-- label1" )): mod.add(m.job3, label='label1') verify(runner, @@ -1334,11 +1186,14 @@ class T2: pass runner.add(m.job2, requires('foo', T1), returns(T2), label='label2') runner.add(m.job3) + t1_str = 'TestRunner.test_repr..T1' + t2_str = 'TestRunner.test_repr..T2' + compare('\n'.join(( '', - ' '+repr(m.job1)+' requires() returns_result_type() <-- label1', - ' '+repr(m.job2)+" requires(Value('foo'), Value(T1)) returns(T2) <-- label2", - ' '+repr(m.job3)+' requires() returns_result_type()', + f' {m.job1!r} requires() returns() <-- label1', + f" {m.job2!r} requires(Value('foo'), Value({t1_str})) returns({t2_str}) <-- label2", + f' {m.job3!r} requires() returns()', '' )), repr(runner)) @@ -1353,6 +1208,7 @@ def foo(): runner = Runner(foo) compare(runner(context), expected=42) + @pytest.mark.skip('need another approach') def test_requirement_modifier(self): class FromRequest(Requirement): @@ -1372,24 +1228,3 @@ def modifier(requirement): context = Context() context.add({'bar': 'foo'}, provides='request') compare(runner(context), expected='foo') - - def test_clone_requirement_modifier(self): - def modifier(requirement): pass - runner = Runner(requirement_modifier=modifier) - assert runner.clone().requirement_modifier is runner.requirement_modifier - - def test_add_clashing_requirement_modifier(self): - def modifier1(requirement): pass - runner1 = Runner(requirement_modifier=modifier1) - def modifier2(requirement): pass - runner2 = Runner(requirement_modifier=modifier2) - with ShouldRaise(TypeError('requirement_modifier must be identical')): - runner1 + runner2 - - def test_extend_other_runner_clashing_requirement_modifier(self): - def modifier1(requirement): pass - runner1 = Runner(requirement_modifier=modifier1) - def modifier2(requirement): pass - runner2 = Runner(requirement_modifier=modifier2) - with ShouldRaise(TypeError('requirement_modifier must be identical')): - runner1.extend(runner2) From 4abf7593c78e35722441fc2c6fb2f888fce80700 Mon Sep 17 00:00:00 2001 From: Chris Withers Date: Thu, 8 Oct 2020 08:35:20 +0100 Subject: [PATCH 140/159] Allow requirements to post-process the resource. --- mush/context.py | 7 +--- mush/requirements.py | 11 +++++++ mush/tests/test_context.py | 65 +++++++++++++++----------------------- 3 files changed, 37 insertions(+), 46 deletions(-) diff --git a/mush/context.py b/mush/context.py index c839007..f82362c 100644 --- a/mush/context.py +++ b/mush/context.py @@ -143,12 +143,7 @@ def _resolve(self, obj, requires=None, specials=None): o = parameter.default if o is not requirement.default: - # move to requirement.process? - for op in requirement.ops: - o = op(o) - if o is missing: - o = requirement.default - break + o = requirement.process(o) if o is missing: raise ResourceError(f'{requirement!r} could not be satisfied') diff --git a/mush/requirements.py b/mush/requirements.py index f92dd5b..6d41442 100644 --- a/mush/requirements.py +++ b/mush/requirements.py @@ -10,6 +10,9 @@ class Op: def __init__(self, name): self.name = name + def __call__(self, o): # pragma: no cover + raise NotImplementedError() + class AttrOp(Op): @@ -71,6 +74,14 @@ def __getitem__(self, name): self.ops.append(ItemOp(name)) return self + def process(self, obj): + for op in self.ops: + obj = op(obj) + if obj is missing: + obj = self.default + break + return obj + class Annotation(Requirement): diff --git a/mush/tests/test_context.py b/mush/tests/test_context.py index 9d7e59c..76b798d 100644 --- a/mush/tests/test_context.py +++ b/mush/tests/test_context.py @@ -1,22 +1,13 @@ -# from typing import Tuple, List -# from functools import partial from typing import NewType, Mapping, Any, Tuple from testfixtures import ShouldRaise, compare from testfixtures.mock import Mock -# from testfixtures.mock import Mock -# -from mush import ( - Context, Requirement, Value, requires -) +from mush import Context, Requirement, Value, requires, missing from mush.context import ResourceError -# from mush.declarations import RequiresType, requires_nothing, returns_nothing -# from mush.requirements import Requirement from .helpers import TheType, Type1, Type2 from ..declarations import ignore_return -from ..requirements import ItemOp from ..resources import ResourceValue, Provider, ResourceKey @@ -695,33 +686,27 @@ def provider() -> str: pass compare(expected, actual=repr(context)) compare(expected, actual=str(context)) -# XXX "custom requirement" stuff -# -# def test_custom_requirement(self): -# -# class FromRequest(Requirement): -# def resolve(self, context): -# return context.get('request')[self.key] -# -# def foo(bar: FromRequest('bar')): -# return bar -# -# context = Context() -# context.add({'bar': 'foo'}, provides='request') -# compare(context.call(foo), expected='foo') -# -# def test_custom_requirement_returns_missing(self): -# -# class FromRequest(Requirement): -# def resolve(self, context): -# return context.get('request').get(self.key, missing) -# -# def foo(bar: FromRequest('bar')): -# pass -# -# context = Context() -# context.add({}, provides='request') -# with ShouldRaise(ResourceError("No FromRequest('bar') in context", -# key='bar', -# requirement=FromRequest.make(key='bar', name='bar'))): -# compare(context.call(foo)) + def test_custom_requirement(self): + + class FromRequest(Requirement): + + def __init__(self, name): + super().__init__([ResourceKey(identifier='request')]) + self.name = name + + def process(self, obj): + # this example doesn't show it, but this is a method so + # there can be conditional stuff in here: + return obj.get(self.name, missing) + + def foo(bar: str): + return bar + + context = Context() + context.add({'bar': 'foo'}, identifier='request') + compare(context.call(foo, requires=FromRequest('bar')), expected='foo') + # real world, FromRequest would have a decent repr: + with ShouldRaise(ResourceError( + "FromRequest(ResourceKey('request')) could not be satisfied" + )): + context.call(foo, requires=FromRequest('baz')) From 5f6075c282c702f85454f605d53970ac7bacfc21 Mon Sep 17 00:00:00 2001 From: Chris Withers Date: Thu, 8 Oct 2020 08:41:59 +0100 Subject: [PATCH 141/159] uncomment remaining tests and mark skipped with reasons --- mush/tests/test_context.py | 163 ++++++++++++++++++++----------------- 1 file changed, 89 insertions(+), 74 deletions(-) diff --git a/mush/tests/test_context.py b/mush/tests/test_context.py index 76b798d..d01d0c3 100644 --- a/mush/tests/test_context.py +++ b/mush/tests/test_context.py @@ -1,6 +1,7 @@ from functools import partial from typing import NewType, Mapping, Any, Tuple +import pytest from testfixtures import ShouldRaise, compare from testfixtures.mock import Mock @@ -367,25 +368,6 @@ def foo(x): compare(result, expected='bob') -# XXX requirements caching: -# -# def test_call_caches_requires(self): -# context = Context() -# def foo(): pass -# context.call(foo) -# compare(context._requires_cache[foo], expected=RequiresType()) -# -# def test_call_explict_explicit_requires_no_cache(self): -# context = Context() -# context.add('a') -# def foo(*args): -# return args -# result = context.call(foo, requires(str)) -# compare(result, ('a',)) -# compare(context._requires_cache, expected={}) -# - - class TestExtract: def test_extract_minimal(self): @@ -430,61 +412,50 @@ def foo(): pass }) - # XXX - remove - -# def test_remove(self): -# context = Context() -# context.add('foo') -# context.remove(str) -# compare(context._store, expected={}) -# -# def test_remove_not_there_strict(self): -# context = Context() -# with ShouldRaise(ResourceError("Context does not contain 'foo'", -# key='foo')): -# context.remove('foo') -# compare(context._store, expected={}) -# -# def test_remove_not_there_not_strict(self): -# context = Context() -# context.remove('foo', strict=False) -# compare(context._store, expected={}) -# -# XXX - nest -# -# def test_nest(self): -# c1 = Context() -# c1.add('a', provides='a') -# c1.add('c', provides='c') -# c2 = c1.nest() -# c2.add('b', provides='b') -# c2.add('d', provides='c') -# compare(c2.get('a'), expected='a') -# compare(c2.get('b'), expected='b') -# compare(c2.get('c'), expected='d') -# compare(c1.get('a'), expected='a') -# compare(c1.get('b', default=None), expected=None) -# compare(c1.get('c'), expected='c') -# -# def test_nest_with_overridden_default_requirement_type(self): -# def modifier(): pass -# c1 = Context(modifier) -# c2 = c1.nest() -# assert c2.requirement_modifier is modifier -# -# def test_nest_with_explicit_default_requirement_type(self): -# def modifier1(): pass -# def modifier2(): pass -# c1 = Context(modifier1) -# c2 = c1.nest(modifier2) -# assert c2.requirement_modifier is modifier2 -# -# def test_nest_keeps_declarations_cache(self): -# c1 = Context() -# c2 = c1.nest() -# assert c2._requires_cache is c1._requires_cache -# assert c2._returns_cache is c1._returns_cache -# - XXX nesting versus cached providers! +@pytest.mark.skip('requirements/returns caching') +class TestExtractionCaching: + + def test_call_caches_requires(self): + context = Context() + + def foo(): pass + + context.call(foo) + compare(context._requires_cache[foo], expected=RequiresType()) + + def test_call_explict_explicit_requires_no_cache(self): + context = Context() + context.add('a') + + def foo(*args): + return args + + result = context.call(foo, requires(str)) + compare(result, ('a',)) + compare(context._requires_cache, expected={}) + + +@pytest.mark.skip('remove') +class TestRemove: + + def test_remove(self): + context = Context() + context.add('foo') + context.remove(str) + compare(context._store, expected={}) + + def test_remove_not_there_strict(self): + context = Context() + with ShouldRaise(ResourceError("Context does not contain 'foo'", + key='foo')): + context.remove('foo') + compare(context._store, expected={}) + + def test_remove_not_there_not_strict(self): + context = Context() + context.remove('foo', strict=False) + compare(context._store, expected={}) + class TestProviders: @@ -710,3 +681,47 @@ def foo(bar: str): "FromRequest(ResourceKey('request')) could not be satisfied" )): context.call(foo, requires=FromRequest('baz')) + + +@pytest.mark.skip('remove') +class TestNesting: + + def test_nest(self): + c1 = Context() + c1.add('a', provides='a') + c1.add('c', provides='c') + c2 = c1.nest() + c2.add('b', provides='b') + c2.add('d', provides='c') + compare(c2.get('a'), expected='a') + compare(c2.get('b'), expected='b') + compare(c2.get('c'), expected='d') + compare(c1.get('a'), expected='a') + compare(c1.get('b', default=None), expected=None) + compare(c1.get('c'), expected='c') + + def test_nest_with_overridden_default_requirement_type(self): + def modifier(): pass + + c1 = Context(modifier) + c2 = c1.nest() + assert c2.requirement_modifier is modifier + + def test_nest_with_explicit_default_requirement_type(self): + def modifier1(): pass + + def modifier2(): pass + + c1 = Context(modifier1) + c2 = c1.nest(modifier2) + assert c2.requirement_modifier is modifier2 + + def test_nest_keeps_declarations_cache(self): + c1 = Context() + c2 = c1.nest() + assert c2._requires_cache is c1._requires_cache + assert c2._returns_cache is c1._returns_cache + + def test_test_versus_caching_providers(self): + # should the nested context get the cache? + pass From fa2db9238562f6da85cc476855e6e5b6b6180e90 Mon Sep 17 00:00:00 2001 From: Chris Withers Date: Thu, 8 Oct 2020 09:09:46 +0100 Subject: [PATCH 142/159] Implement default requirement types for context. --- mush/context.py | 10 +++++----- mush/extraction.py | 5 +++-- mush/tests/test_runner.py | 23 +++++++++++------------ mush/typing.py | 8 +++++++- 4 files changed, 26 insertions(+), 20 deletions(-) diff --git a/mush/context.py b/mush/context.py index f82362c..42310c2 100644 --- a/mush/context.py +++ b/mush/context.py @@ -1,12 +1,11 @@ -from inspect import signature from typing import Optional, Callable, Union, Any, Dict, Iterable from .callpoints import CallPoint from .extraction import extract_requires, extract_returns from .markers import missing, Marker -from .requirements import Requirement +from .requirements import Requirement, Annotation from .resources import ResourceKey, ResourceValue, Provider -from .typing import Resource, Identifier, Type_, Requires, Returns +from .typing import Resource, Identifier, Type_, Requires, Returns, DefaultRequirement NONE_TYPE = type(None) unspecified = Marker('unspecified') @@ -24,8 +23,9 @@ class Context: # _parent: 'Context' = None point: CallPoint = None - def __init__(self): + def __init__(self, default_requirement: DefaultRequirement = Annotation): self._store = {} + self._default_requirement = default_requirement # self._requires_cache = {} # self._returns_cache = {} @@ -110,7 +110,7 @@ def _resolve(self, obj, requires=None, specials=None): if specials is None: specials: Dict[type, Any] = {Context: self} - requires = extract_requires(obj, requires) + requires = extract_requires(obj, requires, self._default_requirement) args = [] kw = {} diff --git a/mush/extraction.py b/mush/extraction.py index f535053..d19ddb8 100644 --- a/mush/extraction.py +++ b/mush/extraction.py @@ -11,7 +11,7 @@ from .markers import missing, get_mush from .requirements import Requirement, Annotation from .resources import ResourceKey -from .typing import Requires, Returns +from .typing import Requires, Returns, DefaultRequirement def _apply_requires(by_name, by_index, requires_): @@ -34,6 +34,7 @@ def _apply_requires(by_name, by_index, requires_): def extract_requires( obj: Callable, explicit: Requires = None, + default_requirement: DefaultRequirement = Annotation ) -> RequirementsDeclaration: by_name = {} @@ -61,7 +62,7 @@ def extract_requires( if requirement.default is not missing: default = requirement.default else: - requirement = Annotation(p.name, hints.get(name), default) + requirement = default_requirement(p.name, hints.get(name), default) by_name[name] = Parameter( requirement, diff --git a/mush/tests/test_runner.py b/mush/tests/test_runner.py index 5a582fa..c613684 100644 --- a/mush/tests/test_runner.py +++ b/mush/tests/test_runner.py @@ -2,7 +2,8 @@ from mush.declarations import requires, returns, replacement, original from mush import Value, ContextError, Context, Requirement -from mush.resources import Provider +from mush.requirements import ItemOp +from mush.resources import Provider, ResourceKey from mush.runner import Runner from testfixtures import ( ShouldRaise, @@ -1208,23 +1209,21 @@ def foo(): runner = Runner(foo) compare(runner(context), expected=42) - @pytest.mark.skip('need another approach') - def test_requirement_modifier(self): + def test_default_requirement(self): class FromRequest(Requirement): - def resolve(self, context): - return context.get('request')[self.key] + + def __init__(self, name, type_, default): + keys = [ResourceKey(None, 'request')] + super().__init__(keys, default) + self.ops.append(ItemOp(name)) def foo(bar): return bar - def modifier(requirement): - if type(requirement) is Requirement: - requirement = FromRequest.make_from(requirement) - return requirement + context = Context(default_requirement=FromRequest) + context.add({'bar': 'foo'}, identifier='request') - runner = Runner(requirement_modifier=modifier) + runner = Runner() runner.add(foo) - context = Context() - context.add({'bar': 'foo'}, provides='request') compare(runner(context), expected='foo') diff --git a/mush/typing.py b/mush/typing.py index f60f1f4..a4ed8d9 100644 --- a/mush/typing.py +++ b/mush/typing.py @@ -1,4 +1,7 @@ -from typing import NewType, Union, Hashable, Any, TYPE_CHECKING, List, Tuple, Type, _GenericAlias +from typing import ( + NewType, Union, Hashable, Any, TYPE_CHECKING, List, Tuple, Type, _GenericAlias, + Callable, Optional +) if TYPE_CHECKING: from .declarations import RequirementsDeclaration, ReturnsDeclaration @@ -17,3 +20,6 @@ Returns = Union['ReturnsDeclaration', ReturnType] Resource = NewType('Resource', Any) + + +DefaultRequirement = Callable[[str, Optional[Type], Any], 'Requirement'] From 17cd2968404c30011ece823ee73b1eb57d0bbdf2 Mon Sep 17 00:00:00 2001 From: Chris Withers Date: Sun, 11 Oct 2020 10:56:39 +0100 Subject: [PATCH 143/159] Add support for requirement completion and drop support for requirements in annotations. --- mush/declarations.py | 7 +- mush/extraction.py | 20 +++--- mush/requirements.py | 3 + mush/tests/test_extraction.py | 118 +++++++++++++++++++--------------- 4 files changed, 86 insertions(+), 62 deletions(-) diff --git a/mush/declarations.py b/mush/declarations.py index 2d8464c..bbafc6a 100644 --- a/mush/declarations.py +++ b/mush/declarations.py @@ -10,7 +10,7 @@ from .markers import set_mush, missing from .requirements import Requirement, Value from .resources import ResourceKey -from .typing import RequirementType, ReturnType +from .typing import RequirementType, ReturnType, Type_ VALID_DECORATION_TYPES = (type, str, Requirement) @@ -30,10 +30,12 @@ def check_decoration_types(*objs): class Parameter: - def __init__(self, requirement: Requirement, target: str = None, default: Any = missing): + def __init__(self, requirement: Requirement, target: str = None, + type_: Type_ = None, default: Any = missing): self.requirement = requirement self.target = target self.default = default + self.type = type_ class RequirementsDeclaration(List[Parameter]): @@ -97,7 +99,6 @@ def returns(*keys: ReturnType): returns_nothing = ignore_return = ReturnsDeclaration() - class DeclarationsFrom(Enum): original = auto() replacement = auto() diff --git a/mush/extraction.py b/mush/extraction.py index d19ddb8..1f32e83 100644 --- a/mush/extraction.py +++ b/mush/extraction.py @@ -23,12 +23,15 @@ def _apply_requires(by_name, by_index, requires_): name = by_index[i] except IndexError: # case where something takes *args - by_name[i] = p + by_name[i] = Parameter(p.requirement, p.target, p.type, p.default) continue else: name = p.target - by_name[name] = p + original_p = by_name[name] + original_p.requirement = p.requirement + original_p.target = p.target + original_p.default = p.default def extract_requires( @@ -52,22 +55,20 @@ def extract_requires( if isinstance(obj, partial) and p.name in obj.keywords: continue + type_ = hints.get(name) default = missing if p.default is p.empty else p.default if isinstance(default, Requirement): requirement = default default = requirement.default - elif isinstance(p.annotation, Requirement): - requirement = p.annotation - if requirement.default is not missing: - default = requirement.default else: - requirement = default_requirement(p.name, hints.get(name), default) + requirement = default_requirement(p.name, type_, default) by_name[name] = Parameter( requirement, target=p.name if p.kind is p.KEYWORD_ONLY else None, - default=default + default=default, + type_=type_ ) by_index = list(by_name) @@ -95,6 +96,9 @@ def extract_requires( needs_target = True elif needs_target: parameter.target = name + parameter.requirement = parameter.requirement.complete( + name, parameter.type, parameter.default + ) return RequirementsDeclaration(by_name.values()) diff --git a/mush/requirements.py b/mush/requirements.py index 6d41442..75b6001 100644 --- a/mush/requirements.py +++ b/mush/requirements.py @@ -74,6 +74,9 @@ def __getitem__(self, name): self.ops.append(ItemOp(name)) return self + def complete(self, name: str, type_: Type_, default: Any): + return self + def process(self, obj): for op in self.ops: obj = op(obj) diff --git a/mush/tests/test_extraction.py b/mush/tests/test_extraction.py index b2512db..afa7516 100644 --- a/mush/tests/test_extraction.py +++ b/mush/tests/test_extraction.py @@ -1,5 +1,5 @@ from functools import partial -from typing import Optional +from typing import Optional, Any from testfixtures.mock import Mock import pytest @@ -14,7 +14,7 @@ from mush.requirements import Requirement, Annotation from .helpers import Type1, Type2, Type3 from ..resources import ResourceKey - +from ..typing import Type_ returns_foo = ReturnsDeclaration([ResourceKey(identifier='foo')]) @@ -158,9 +158,9 @@ def test_extract_from_annotations(self): def foo(a: Type1, b, c: Type2 = 1, d=2) -> Type3: pass check_extract(foo, expected_rq=RequirementsDeclaration(( - Parameter(Annotation('a', Type1)), + Parameter(Annotation('a', Type1), type_=Type1), Parameter(Annotation('b')), - Parameter(Annotation('c', Type2, default=1), default=1), + Parameter(Annotation('c', Type2, default=1), type_=Type2, default=1), Parameter(Annotation('d', default=2), default=2), )), expected_rt=ReturnsDeclaration([ @@ -171,7 +171,9 @@ def foo(a: Type1, b, c: Type2 = 1, d=2) -> Type3: pass def test_forward_type_references(self): check_extract(foo_, - expected_rq=RequirementsDeclaration((Parameter(Annotation('a', Foo)),)), + expected_rq=RequirementsDeclaration(( + Parameter(Annotation('a', Foo), type_=Foo), + )), expected_rt=ReturnsDeclaration([ ResourceKey(Bar), ResourceKey(identifier='foo_'), @@ -181,7 +183,9 @@ def test_forward_type_references(self): def test_requires_only(self): def foo(a: Type1): pass check_extract(foo, - expected_rq=RequirementsDeclaration((Parameter(Annotation('a', Type1)),))) + expected_rq=RequirementsDeclaration(( + Parameter(Annotation('a', Type1), type_=Type1), + ))) def test_returns_only(self): def foo() -> Type1: pass @@ -229,17 +233,10 @@ def test_decorator_preferred_to_annotations(self): def foo(a: Type1) -> Type2: pass check_extract(foo, expected_rq=RequirementsDeclaration(( - Parameter(Value(identifier='foo')),) + Parameter(Value(identifier='foo'), type_=Type1),) ), expected_rt=ReturnsDeclaration([ResourceKey(identifier='bar')])) - def test_how_instance_in_annotations(self): - def foo(a: Value('config')['db_url']): pass - check_extract(foo, - expected_rq=RequirementsDeclaration(( - Parameter(Value(identifier='config')['db_url']), - ))) - def test_default_requirements(self): def foo(a, b=1, *, c, d=None): pass check_extract(foo, @@ -254,48 +251,66 @@ def test_type_only(self): class T: pass def foo(a: T): pass check_extract(foo, - expected_rq=RequirementsDeclaration((Parameter(Annotation('a', T)),)), + expected_rq=RequirementsDeclaration(( + Parameter(Annotation('a', T), type_=T), + )), expected_rt=ReturnsDeclaration([ResourceKey(identifier='foo')])) @pytest.mark.parametrize("type_", [str, int, dict, list]) def test_simple_type_only(self, type_): def foo(a: type_): pass check_extract(foo, - expected_rq=RequirementsDeclaration((Parameter(Annotation('a', type_)),))) + expected_rq=RequirementsDeclaration(( + Parameter(Annotation('a', type_), type_=type_), + ))) def test_type_plus_value(self): def foo(a: str = Value('b')): pass check_extract(foo, - expected_rq=RequirementsDeclaration((Parameter(Value(identifier='b')),))) + expected_rq=RequirementsDeclaration(( + Parameter(Value(identifier='b'), type_=str), + ))) def test_type_plus_value_with_default(self): def foo(a: str = Value('b', default=1)): pass check_extract(foo, expected_rq=RequirementsDeclaration(( - Parameter(Value(identifier='b', default=1), default=1), + Parameter(Value(identifier='b', default=1), type_=str, default=1), ))) - def test_value_annotation_plus_default(self): - def foo(a: Value(str, identifier='b') = 1): pass - check_extract(foo, - expected_rq=RequirementsDeclaration(( - Parameter(Value(str, identifier='b'), default=1), - ))) - def test_requirement_default_preferred_to_annotation_default(self): - def foo(a: Value(str, identifier='b', default=2) = 1): pass - check_extract(foo, - expected_rq=RequirementsDeclaration(( - Parameter(Value(str, identifier='b', default=2), default=2), - ))) +class Path(Requirement): + + def __init__(self, name=None, type_=None): + super().__init__(()) + self.name=name + self.type=type_ + + def complete(self, name: str, type_: Type_, default: Any): + return type(self)(name=name, type_=type_) - def test_value_annotation_just_type_in_value_key_plus_default(self): - def foo(a: Value(str) = 1): pass - check_extract(foo, - expected_rq=RequirementsDeclaration(( - Parameter(Value(str), default=1), - )), - expected_rt=ReturnsDeclaration([ResourceKey(identifier='foo')])) + +class TestCustomRequirementCompletion: + + def test_use_name(self): + def foo(bar=Path()): pass + check_extract(foo, RequirementsDeclaration(( + Parameter(Path(name='bar', type_=None)), + ))) + + def test_use_type(self): + def foo(bar: str = Path()): pass + check_extract(foo, RequirementsDeclaration(( + Parameter(Path(name='bar', type_=str), type_=str), + ))) + + def test_precedence(self): + class PathSubclass(Path): pass + @requires(PathSubclass()) + def foo(bar: str = Path()): pass + check_extract(foo, RequirementsDeclaration(( + Parameter(PathSubclass(name='bar', type_=str), type_=str), + ))) def it(): @@ -352,17 +367,16 @@ def test_returns_from_type(self): class TestDeclarationsFromMultipleSources: def test_declarations_from_different_sources(self): - r1 = Requirement(keys=(), default='a') - r2 = Requirement(keys=(), default='b') - r3 = Requirement(keys=(), default='c') + r1 = Requirement(keys=(), default='b') + r2 = Requirement(keys=(), default='c') - @requires(b=r2) - def foo(a: r1, b, c=r3): + @requires(b=r1) + def foo(a: str, b, c=r2): pass check_extract(foo, expected_rq=RequirementsDeclaration(( - Parameter(Requirement((), default='a'), default='a'), + Parameter(Annotation('a', str), type_=str), Parameter(Requirement((), default='b'), default='b', target='b'), Parameter(Requirement((), default='c'), default='c', target='c'), ))) @@ -374,13 +388,15 @@ def test_declaration_priorities(self): @requires(a=r1) @returns('bar') - def foo(a: r2 = r3, b: str = r2, c=r3) -> Optional[Type1]: + def foo(a: int = r3, b: str = r2, c=r3) -> Optional[Type1]: pass - check_extract(foo, - expected_rq=RequirementsDeclaration(( - Parameter(Requirement([ResourceKey(identifier='x')]), target='a'), - Parameter(Requirement([ResourceKey(identifier='y')]), target='b'), - Parameter(Requirement([ResourceKey(identifier='z')]), target='c'), - )), - expected_rt=ReturnsDeclaration([ResourceKey(identifier='bar')])) + check_extract( + foo, + expected_rq=RequirementsDeclaration(( + Parameter(Requirement([ResourceKey(identifier='x')]), type_=int, target='a'), + Parameter(Requirement([ResourceKey(identifier='y')]), type_=str, target='b'), + Parameter(Requirement([ResourceKey(identifier='z')]), target='c'), + )), + expected_rt=ReturnsDeclaration([ResourceKey(identifier='bar')]) + ) From edb4343f87f56d200c450cfe2aad626e1e6e2a48 Mon Sep 17 00:00:00 2001 From: Chris Withers Date: Sun, 11 Oct 2020 13:12:35 +0100 Subject: [PATCH 144/159] drop remove until there's a real need --- mush/context.py | 11 ----------- mush/tests/test_context.py | 22 ---------------------- 2 files changed, 33 deletions(-) diff --git a/mush/context.py b/mush/context.py index 42310c2..4bfbc14 100644 --- a/mush/context.py +++ b/mush/context.py @@ -69,17 +69,6 @@ def add(self, self.add_by_keys(resource, keys) - # def remove(self, key: ResourceKey, *, strict: bool = True): - # """ - # Remove the specified resource key from the context. - # - # If ``strict``, then a :class:`ResourceError` will be raised if the - # specified resource is not present in the context. - # """ - # if strict and key not in self._store: - # raise ResourceError(f'Context does not contain {key!r}', key) - # self._store.pop(key, None) - # def __repr__(self): bits = [] for key, value in sorted(self._store.items(), key=lambda o: repr(o)): diff --git a/mush/tests/test_context.py b/mush/tests/test_context.py index d01d0c3..4a8b95c 100644 --- a/mush/tests/test_context.py +++ b/mush/tests/test_context.py @@ -435,28 +435,6 @@ def foo(*args): compare(context._requires_cache, expected={}) -@pytest.mark.skip('remove') -class TestRemove: - - def test_remove(self): - context = Context() - context.add('foo') - context.remove(str) - compare(context._store, expected={}) - - def test_remove_not_there_strict(self): - context = Context() - with ShouldRaise(ResourceError("Context does not contain 'foo'", - key='foo')): - context.remove('foo') - compare(context._store, expected={}) - - def test_remove_not_there_not_strict(self): - context = Context() - context.remove('foo', strict=False) - compare(context._store, expected={}) - - class TestProviders: def test_cached(self): From fdca8a17cb4fa532e45f92232b0e2e91461ee36a Mon Sep 17 00:00:00 2001 From: Chris Withers Date: Sun, 11 Oct 2020 13:20:31 +0100 Subject: [PATCH 145/159] drop caching until there's something we can profile --- mush/context.py | 4 ---- mush/tests/test_context.py | 23 ----------------------- 2 files changed, 27 deletions(-) diff --git a/mush/context.py b/mush/context.py index 4bfbc14..dfc2559 100644 --- a/mush/context.py +++ b/mush/context.py @@ -26,8 +26,6 @@ class Context: def __init__(self, default_requirement: DefaultRequirement = Annotation): self._store = {} self._default_requirement = default_requirement - # self._requires_cache = {} - # self._returns_cache = {} def add_by_keys(self, resource: ResourceValue, keys: Iterable[ResourceKey]): for key in keys: @@ -153,6 +151,4 @@ def call(self, obj: Callable, requires: Requires = None): # requirement_modifier = self.requirement_modifier # nested = self.__class__(requirement_modifier) # nested._parent = self - # nested._requires_cache = self._requires_cache - # nested._returns_cache = self._returns_cache # return nested diff --git a/mush/tests/test_context.py b/mush/tests/test_context.py index 4a8b95c..d0eca76 100644 --- a/mush/tests/test_context.py +++ b/mush/tests/test_context.py @@ -412,29 +412,6 @@ def foo(): pass }) -@pytest.mark.skip('requirements/returns caching') -class TestExtractionCaching: - - def test_call_caches_requires(self): - context = Context() - - def foo(): pass - - context.call(foo) - compare(context._requires_cache[foo], expected=RequiresType()) - - def test_call_explict_explicit_requires_no_cache(self): - context = Context() - context.add('a') - - def foo(*args): - return args - - result = context.call(foo, requires(str)) - compare(result, ('a',)) - compare(context._requires_cache, expected={}) - - class TestProviders: def test_cached(self): From f879551733e8c49d15d063ec2c129b4499529ab7 Mon Sep 17 00:00:00 2001 From: Chris Withers Date: Tue, 13 Oct 2020 07:38:20 +0100 Subject: [PATCH 146/159] Bring back nesting. --- mush/context.py | 58 ++++++++++------- mush/tests/test_context.py | 124 ++++++++++++++++++++++++++++--------- 2 files changed, 130 insertions(+), 52 deletions(-) diff --git a/mush/context.py b/mush/context.py index dfc2559..14a8d13 100644 --- a/mush/context.py +++ b/mush/context.py @@ -20,7 +20,7 @@ class ResourceError(Exception): class Context: "Stores resources for a particular run." - # _parent: 'Context' = None + _parent: 'Context' = None point: CallPoint = None def __init__(self, default_requirement: DefaultRequirement = Annotation): @@ -28,6 +28,7 @@ def __init__(self, default_requirement: DefaultRequirement = Annotation): self._default_requirement = default_requirement def add_by_keys(self, resource: ResourceValue, keys: Iterable[ResourceKey]): + keys_ = keys for key in keys: if key in self._store: raise ResourceError(f'Context already contains {key}') @@ -83,15 +84,16 @@ def extract(self, obj: Callable, requires: Requires = None, returns: Returns = N return result def _find_resource(self, key): + exact = True if not isinstance(key[0], type): - return self._store.get(key) + return self._store.get(key), exact type_, identifier = key - exact = True for type__ in type_.__mro__: resource = self._store.get((type__, identifier)) if resource is not None and (exact or resource.provides_subclasses): - return resource + return resource, exact exact = False + return None, exact def _resolve(self, obj, requires=None, specials=None): if specials is None: @@ -109,19 +111,32 @@ def _resolve(self, obj, requires=None, specials=None): for key in requirement.keys: - resource = self._find_resource(key) - - if resource is None: - o = specials.get(key[0], missing) - else: - if resource.obj is missing: - specials_ = specials.copy() - specials_[Requirement] = requirement - o = self._resolve(resource.provider, specials=specials_) - if resource.cache: - resource.obj = o + context = self + + while True: + resource, exact = context._find_resource(key) + + if resource is None: + o = specials.get(key[0], missing) else: - o = resource.obj + if resource.obj is missing: + specials_ = specials.copy() + specials_[Requirement] = requirement + o = context._resolve(resource.provider, specials=specials_) + if resource.cache: + if exact and context is self: + resource.obj = o + else: + self.add_by_keys(ResourceValue(o), (key,)) + else: + o = resource.obj + + if o is not missing: + break + + context = context._parent + if context is None: + break if o is not missing: break @@ -145,10 +160,7 @@ def _resolve(self, obj, requires=None, specials=None): def call(self, obj: Callable, requires: Requires = None): return self._resolve(obj, requires) - # - # def nest(self, requirement_modifier: RequirementModifier = None): - # if requirement_modifier is None: - # requirement_modifier = self.requirement_modifier - # nested = self.__class__(requirement_modifier) - # nested._parent = self - # return nested + def nest(self): + nested = self.__class__(self._default_requirement) + nested._parent = self + return nested diff --git a/mush/tests/test_context.py b/mush/tests/test_context.py index d0eca76..dc26a0f 100644 --- a/mush/tests/test_context.py +++ b/mush/tests/test_context.py @@ -487,6 +487,31 @@ def foo(bar: TheType): assert isinstance(context.call(foo), TheType) + def test_provides_subclasses_caching(self): + class Base: pass + class Type1(Base): pass + class Type2(Base): pass + + t1 = Type1() + t2 = Type2() + instances = {Type1: t1, Type2: t2} + + def provider(requirement: Requirement): + # .pop so each instance can only be obtained once! + return instances.pop(requirement.keys[0].type) + + def foo(bar): + return bar + + context = Context() + context.add(Provider(provider, cache=True, provides_subclasses=True), provides=Base) + + assert context.call(foo, requires=Type1) is t1 + # cached: + assert context.call(foo, requires=Type1) is t1 + assert context.call(foo, requires=Type2) is t2 + assert context.call(foo, requires=Type2) is t2 + def test_does_not_provide_subclasses(self): def foo(obj: TheType): pass @@ -638,45 +663,86 @@ def foo(bar: str): context.call(foo, requires=FromRequest('baz')) -@pytest.mark.skip('remove') class TestNesting: def test_nest(self): c1 = Context() - c1.add('a', provides='a') - c1.add('c', provides='c') + c1.add('c1a', identifier='a') + c1.add('c1c', identifier='c') + c2 = c1.nest() + c2.add('c2b', identifier='b') + c2.add('c2c', identifier='c') + + def foo(a, b=None, c=None): + return a, b, c + + compare(c2.call(foo), expected=('c1a', 'c2b', 'c2c')) + compare(c1.call(foo), expected=('c1a', None, 'c1c')) + + def test_uses_existing_cached_value(self): + class X: pass + + x_ = X() + + xs = [x_] + + def make_x(): + return xs.pop() + + c1 = Context() + c1.add(Provider(make_x, cache=True), identifier='x') + + assert c1.call(lambda x: x) is x_ c2 = c1.nest() - c2.add('b', provides='b') - c2.add('d', provides='c') - compare(c2.get('a'), expected='a') - compare(c2.get('b'), expected='b') - compare(c2.get('c'), expected='d') - compare(c1.get('a'), expected='a') - compare(c1.get('b', default=None), expected=None) - compare(c1.get('c'), expected='c') - - def test_nest_with_overridden_default_requirement_type(self): - def modifier(): pass - - c1 = Context(modifier) + assert c2.call(lambda x: x) is x_ + + assert c2.call(lambda x: x) is x_ + assert c1.call(lambda x: x) is x_ + + def test_stored_cached_value_in_nested_context(self): + class X: pass + + x1 = X() + x2 = X() + + xs = [x2, x1] + + def make_x(): + return xs.pop() + + c1 = Context() + c1.add(Provider(make_x, cache=True), identifier='x') + c2 = c1.nest() - assert c2.requirement_modifier is modifier + assert c2.call(lambda x: x) is x1 + assert c1.call(lambda x: x) is x2 + + assert c1.call(lambda x: x) is x2 + assert c2.call(lambda x: x) is x1 + + def test_no_cache_in_nested(self): + class X: pass - def test_nest_with_explicit_default_requirement_type(self): - def modifier1(): pass + x1 = X() + x2 = X() - def modifier2(): pass + xs = [x2, x1] - c1 = Context(modifier1) - c2 = c1.nest(modifier2) - assert c2.requirement_modifier is modifier2 + def make_x(): + return xs.pop() - def test_nest_keeps_declarations_cache(self): c1 = Context() + c1.add(Provider(make_x, cache=False), identifier='x') + c2 = c1.nest() - assert c2._requires_cache is c1._requires_cache - assert c2._returns_cache is c1._returns_cache + assert c2.call(lambda x: x) is x1 + assert c2.call(lambda x: x) is x2 + + def test_with_default_requirement(self): - def test_test_versus_caching_providers(self): - # should the nested context get the cache? - pass + def make_requirement(name, type_, default) -> Requirement: + pass + + c1 = Context(default_requirement=make_requirement) + c2 = c1.nest() + assert c2._default_requirement is make_requirement From 712e80f3fcafbdbd88ddc643e4686f2f25fec95a Mon Sep 17 00:00:00 2001 From: Chris Withers Date: Wed, 14 Oct 2020 07:31:21 +0100 Subject: [PATCH 147/159] Add the first key from the requirement as a resource. --- mush/context.py | 4 ++++ mush/requirements.py | 1 + mush/tests/test_context.py | 12 ++++++++++++ 3 files changed, 17 insertions(+) diff --git a/mush/context.py b/mush/context.py index 14a8d13..01438f0 100644 --- a/mush/context.py +++ b/mush/context.py @@ -108,8 +108,11 @@ def _resolve(self, obj, requires=None, specials=None): requirement = parameter.requirement o = missing + first_key = None for key in requirement.keys: + if first_key is None: + first_key = key context = self @@ -122,6 +125,7 @@ def _resolve(self, obj, requires=None, specials=None): if resource.obj is missing: specials_ = specials.copy() specials_[Requirement] = requirement + specials_[ResourceKey] = first_key o = context._resolve(resource.provider, specials=specials_) if resource.cache: if exact and context is self: diff --git a/mush/requirements.py b/mush/requirements.py index 75b6001..f5bf3f6 100644 --- a/mush/requirements.py +++ b/mush/requirements.py @@ -44,6 +44,7 @@ class Requirement: """ def __init__(self, keys: Sequence[ResourceKey], default: Optional[Any] = missing): + #: Note that the first key returned should be the "most specific" self.keys: Sequence[ResourceKey] = keys self.default = default self.ops: List['Op'] = [] diff --git a/mush/tests/test_context.py b/mush/tests/test_context.py index dc26a0f..1e2a7d6 100644 --- a/mush/tests/test_context.py +++ b/mush/tests/test_context.py @@ -471,6 +471,18 @@ def returner(obj: str): compare(context.call(returner), expected='obj') + def test_needs_resource_key(self): + def provider(key: ResourceKey): + return key.type, key.identifier + + context = Context() + context.add(Provider(provider), provides=tuple) + + def returner(obj: tuple): + return obj + + compare(context.call(returner), expected=(tuple, 'obj')) + def test_provides_subclasses(self): class Base: pass From 8488a4726e0829988d8df5507b490a3ec73424d9 Mon Sep 17 00:00:00 2001 From: Chris Withers Date: Wed, 14 Oct 2020 07:37:42 +0100 Subject: [PATCH 148/159] Bring back asyncio --- mush/asyncio.py | 90 ++++++----- mush/context.py | 25 ++- mush/requirements.py | 3 + mush/tests/test_async_context.py | 257 ++++++++++++------------------- mush/tests/test_async_runner.py | 6 +- 5 files changed, 179 insertions(+), 202 deletions(-) diff --git a/mush/asyncio.py b/mush/asyncio.py index f8bc4c9..4800026 100644 --- a/mush/asyncio.py +++ b/mush/asyncio.py @@ -1,14 +1,15 @@ import asyncio from functools import partial -from typing import Callable +from typing import Callable, Dict, Any from . import ( - Context as SyncContext, Runner as SyncRunner, ResourceError, ContextError + Context as SyncContext, Runner as SyncRunner, ResourceError, ContextError, extract_returns ) from .declarations import RequirementsDeclaration, ReturnsDeclaration -from .extraction import default_requirement_type from .markers import get_mush, AsyncType -from .typing import RequirementModifier +from .requirements import Annotation +from .resources import ResourceValue +from .typing import DefaultRequirement class AsyncFromSyncContext: @@ -16,50 +17,59 @@ class AsyncFromSyncContext: def __init__(self, context, loop): self.context: Context = context self.loop = loop - self.remove = context.remove self.add = context.add - self.get = context.get def call(self, obj: Callable, requires: RequirementsDeclaration = None): coro = self.context.call(obj, requires) future = asyncio.run_coroutine_threadsafe(coro, self.loop) return future.result() - def extract(self, obj: Callable, requires: RequirementsDeclaration = None, returns: ReturnsDeclaration = None): + def extract( + self, + obj: Callable, + requires: RequirementsDeclaration = None, + returns: ReturnsDeclaration = None + ): coro = self.context.extract(obj, requires, returns) future = asyncio.run_coroutine_threadsafe(coro, self.loop) return future.result() +def async_behaviour(callable_): + to_check = callable_ + if isinstance(callable_, partial): + to_check = callable_.func + if asyncio.iscoroutinefunction(to_check): + return AsyncType.async_ + elif asyncio.iscoroutinefunction(to_check.__call__): + return AsyncType.async_ + else: + async_type = get_mush(callable_, 'async', default=None) + if async_type is None: + if isinstance(callable_, type): + return AsyncType.nonblocking + else: + return AsyncType.blocking + else: + return async_type + + class Context(SyncContext): - def __init__(self, requirement_modifier: RequirementModifier = default_requirement_type): - super().__init__(requirement_modifier) + def __init__(self, default_requirement: DefaultRequirement = Annotation): + super().__init__(default_requirement) self._sync_context = AsyncFromSyncContext(self, asyncio.get_event_loop()) self._async_cache = {} async def _ensure_async(self, func, *args, **kw): - async_type = self._async_cache.get(func) - if async_type is None: - to_check = func - if isinstance(func, partial): - to_check = func.func - if asyncio.iscoroutinefunction(to_check): - async_type = AsyncType.async_ - elif asyncio.iscoroutinefunction(to_check.__call__): - async_type = AsyncType.async_ - else: - async_type = get_mush(func, 'async', default=None) - if async_type is None: - if isinstance(func, type): - async_type = AsyncType.nonblocking - else: - async_type = AsyncType.blocking - self._async_cache[func] = async_type + behaviour = self._async_cache.get(func) + if behaviour is None: + behaviour = async_behaviour(func) + self._async_cache[func] = behaviour - if async_type is AsyncType.nonblocking: + if behaviour is AsyncType.nonblocking: return func(*args, **kw) - elif async_type is AsyncType.blocking: + elif behaviour is AsyncType.blocking: if kw: func = partial(func, **kw) loop = asyncio.get_event_loop() @@ -67,25 +77,25 @@ async def _ensure_async(self, func, *args, **kw): else: return await func(*args, **kw) - def _context_for(self, obj): - return self if asyncio.iscoroutinefunction(obj) else self._sync_context + def _specials(self) -> Dict[type, Any]: + return {Context: self, SyncContext: self._sync_context} async def call(self, obj: Callable, requires: RequirementsDeclaration = None): - args = [] - kw = {} - resolving = self._resolve(obj, requires, args, kw, self._context_for(obj)) - for requirement in resolving: - r = requirement.resolve - o = await self._ensure_async(r, self._context_for(r)) - resolving.send(o) - return await self._ensure_async(obj, *args, **kw) + resolving = self._resolve(obj, requires) + for call in resolving: + result = await self._ensure_async(call.obj, *call.args, **call.kw) + if call.send: + resolving.send(result) + return result async def extract(self, obj: Callable, requires: RequirementsDeclaration = None, returns: ReturnsDeclaration = None): result = await self.call(obj, requires) - self._process(obj, result, returns) + returns = extract_returns(obj, returns) + if returns: + self.add_by_keys(ResourceValue(result), returns) return result @@ -128,7 +138,7 @@ async def __call__(self, context: Context = None): if getattr(manager, '__aenter__', None): async with manager as managed: - if managed is not None: + if managed is not None and managed is not result: context.add(managed) # If the context manager swallows an exception, # None should be returned, not the context manager: diff --git a/mush/context.py b/mush/context.py index 01438f0..3d1de2e 100644 --- a/mush/context.py +++ b/mush/context.py @@ -1,3 +1,4 @@ +from collections import namedtuple from typing import Optional, Callable, Union, Any, Dict, Iterable from .callpoints import CallPoint @@ -17,6 +18,9 @@ class ResourceError(Exception): """ +Call = namedtuple('Call', ('obj', 'args', 'kw', 'send')) + + class Context: "Stores resources for a particular run." @@ -95,9 +99,12 @@ def _find_resource(self, key): exact = False return None, exact + def _specials(self) -> Dict[type, Any]: + return {Context: self} + def _resolve(self, obj, requires=None, specials=None): if specials is None: - specials: Dict[type, Any] = {Context: self} + specials = self._specials() requires = extract_requires(obj, requires, self._default_requirement) @@ -127,6 +134,13 @@ def _resolve(self, obj, requires=None, specials=None): specials_[Requirement] = requirement specials_[ResourceKey] = first_key o = context._resolve(resource.provider, specials=specials_) + provider = resource.provider + resolving = context._resolve(provider, specials=specials_) + for call in resolving: + o = yield Call(call.obj, call.args, call.kw, send=True) + yield + if call.send: + resolving.send(o) if resource.cache: if exact and context is self: resource.obj = o @@ -159,10 +173,15 @@ def _resolve(self, obj, requires=None, specials=None): else: kw[parameter.target] = o - return obj(*args, **kw) + yield Call(obj, args, kw, send=False) def call(self, obj: Callable, requires: Requires = None): - return self._resolve(obj, requires) + resolving = self._resolve(obj, requires) + for call in resolving: + result = call.obj(*call.args, **call.kw) + if call.send: + resolving.send(result) + return result def nest(self): nested = self.__class__(self._default_requirement) diff --git a/mush/requirements.py b/mush/requirements.py index f5bf3f6..a953fd0 100644 --- a/mush/requirements.py +++ b/mush/requirements.py @@ -79,6 +79,9 @@ def complete(self, name: str, type_: Type_, default: Any): return self def process(self, obj): + """ + .. warning:: This must not block when used with an async context! + """ for op in self.ops: obj = op(obj) if obj is missing: diff --git a/mush/tests/test_async_context.py b/mush/tests/test_async_context.py index eec5885..a3690a0 100644 --- a/mush/tests/test_async_context.py +++ b/mush/tests/test_async_context.py @@ -1,5 +1,3 @@ -import pytest; pytestmark = pytest.mark.skip("WIP") - import asyncio from functools import partial from typing import Tuple @@ -7,12 +5,12 @@ import pytest from testfixtures import compare, ShouldRaise -from mush import Value, requires, returns, Context as SyncContext, blocking, nonblocking -# from mush.asyncio import Context -from mush.declarations import RequirementsDeclaration -# from mush.requirements import Requirement, AnyOf, Like +from mush import requires, returns, Context as SyncContext, blocking, nonblocking +from mush.asyncio import Context +from mush.requirements import Requirement, AnyOf, Like from .helpers import TheType, no_threads, must_run_in_thread from ..markers import AsyncType +from ..resources import ResourceKey, Provider @pytest.mark.asyncio @@ -29,7 +27,7 @@ def it(): @pytest.mark.asyncio async def test_call_async(): context = Context() - context.add('1', provides='a') + context.add('1', identifier='a') async def it(a, b='2'): return a+b with no_threads(): @@ -59,28 +57,20 @@ async def it(): compare(await context.call(partial(it)), expected=42) -@pytest.mark.asyncio -async def test_call_async_requires_context(): - context = Context() - context.add('bar', provides='baz') - async def it(context: SyncContext): - return context.get('baz') - compare(await context.call(it), expected='bar') - - @pytest.mark.asyncio async def test_call_async_requires_async_context(): context = Context() - context.add('bar', provides='baz') + async def baz(): + return 'bar' async def it(context: Context): - return context.get('baz') - compare(await context.call(it), expected='bar') + return await context.call(baz) + 'bob' + compare(await context.call(it), expected='barbob') @pytest.mark.asyncio async def test_call_sync(): context = Context() - context.add('foo', provides='baz') + context.add('foo', identifier='baz') def it(*, baz): return baz+'bar' with must_run_in_thread(it): @@ -90,19 +80,66 @@ def it(*, baz): @pytest.mark.asyncio async def test_call_sync_requires_context(): context = Context() - context.add('bar', provides='baz') - def it(context: Context): - return context.get('baz') - compare(await context.call(it), expected='bar') + # NB: this is intentionally async to test calling async + # in a sync context: + async def baz(): + return 'bar' + # sync method, so needs a sync context: + def it(context: SyncContext): + return context.call(baz) + 'bob' + compare(await context.call(it), expected='barbob') + + +@pytest.mark.asyncio +async def test_async_provider_async_user(): + o = TheType() + lookup = {TheType: o} + async def provider(key: ResourceKey): + return lookup[key.type] + context = Context() + context.add(Provider(provider), provides=TheType) + async def returner(obj: TheType): + return obj + assert await context.call(returner) is o + + +@pytest.mark.asyncio +async def test_async_provider_sync_user(): + o = TheType() + lookup = {TheType: o} + async def provider(key: ResourceKey): + return lookup[key.type] + context = Context() + context.add(Provider(provider), provides=TheType) + def returner(obj: TheType): + return obj + assert await context.call(returner) is o + + +@pytest.mark.asyncio +async def test_sync_provider_async_user(): + o = TheType() + lookup = {TheType: o} + def provider(key: ResourceKey): + return lookup[key.type] + context = Context() + context.add(Provider(provider), provides=TheType) + async def returner(obj: TheType): + return obj + assert await context.call(returner) is o @pytest.mark.asyncio -async def test_call_sync_requires_async_context(): +async def test_sync_provider_sync_user(): + o = TheType() + lookup = {TheType: o} + def provider(key: ResourceKey): + return lookup[key.type] context = Context() - context.add('bar', provides='baz') - def it(context: Context): - return context.get('baz') - compare(await context.call(it), expected='bar') + context.add(Provider(provider), provides=TheType) + def returner(obj: TheType): + return obj + assert await context.call(returner) is o @pytest.mark.asyncio @@ -160,14 +197,6 @@ async def test_call_async_function_explicitly_marked_as_blocking(): async def foo(): pass -@pytest.mark.asyncio -async def test_call_cache_requires(): - context = Context() - def foo(): pass - await context.call(foo) - compare(context._requires_cache[foo], expected=RequirementsDeclaration()) - - @pytest.mark.asyncio async def test_call_caches_asyncness(): async def foo(): @@ -185,29 +214,39 @@ def it(): result = context.extract(it, requires(), returns('baz')) assert asyncio.iscoroutine(result) compare(await result, expected='bar') - compare(context.get('baz'), expected='bar') + async def returner(baz): + return baz + compare(await context.call(returner), expected='bar') @pytest.mark.asyncio async def test_extract_async(): context = Context() - context.add('foo', provides='bar') + async def bob(): + return 'foo' async def it(context): - return context.get('bar')+'bar' + return await context.extract(bob)+'bar' result = context.extract(it, requires(Context), returns('baz')) compare(await result, expected='foobar') - compare(context.get('baz'), expected='foobar') + async def returner(bob): + return bob + compare(await context.call(returner), expected='foo') @pytest.mark.asyncio async def test_extract_sync(): context = Context() - context.add('foo', provides='bar') + # NB: this is intentionally async to test calling async + # in a sync context: + def bob(): + return 'foo' def it(context): - return context.get('bar')+'bar' - result = context.extract(it, requires(Context), returns('baz')) + return context.extract(bob)+'bar' + result = context.extract(it, requires(SyncContext), returns('baz')) compare(await result, expected='foobar') - compare(context.get('baz'), expected='foobar') + def returner(bob): + return bob + compare(await context.call(returner), expected='foo') @pytest.mark.asyncio @@ -218,9 +257,9 @@ def foo() -> TheType: context = Context() result = await context.extract(foo) assert result is o - compare({TheType: o}, actual=context._store) - compare(context._requires_cache[foo], expected=RequirementsDeclaration()) - compare(context._returns_cache[foo], expected=returns(TheType)) + async def returner(x: TheType): + return x + compare(await context.call(returner), expected=o) @pytest.mark.asyncio @@ -231,19 +270,16 @@ def foo(*args): context.add('a') result = await context.extract(foo, requires(str), returns(Tuple[str])) compare(result, expected=('a',)) - compare({ - str: 'a', - Tuple[str]: ('a',), - }, actual=context._store) - compare(context._requires_cache, expected={}) - compare(context._returns_cache, expected={}) + async def returner(x: Tuple[str]): + return x + compare(await context.call(returner), expected=('a',)) @pytest.mark.asyncio async def test_value_resolve_does_not_run_in_thread(): with no_threads(): context = Context() - context.add('foo', provides='baz') + context.add('foo', identifier='baz') async def it(baz): return baz+'bar' @@ -276,109 +312,20 @@ async def bob(x: str = Like(TheType)): assert await context.call(bob) is o -@pytest.mark.asyncio -async def test_custom_requirement_async_resolve(): - - class FromRequest(Requirement): - async def resolve(self, context): - return (context.get('request'))[self.key] - - def foo(bar: FromRequest('bar')): - return bar - - context = Context() - context.add({'bar': 'foo'}, provides='request') - compare(await context.call(foo), expected='foo') - - -@pytest.mark.asyncio -async def test_custom_requirement_sync_resolve_get(): - - class FromRequest(Requirement): - def resolve(self, context): - return context.get('request')[self.key] - - def foo(bar: FromRequest('bar')): - return bar - - context = Context() - context.add({'bar': 'foo'}, provides='request') - compare(await context.call(foo), expected='foo') - - -@pytest.mark.asyncio -async def test_custom_requirement_sync_resolve_call(): - - async def baz(request: dict = Value('request')): - return request['bar'] - - class Syncer(Requirement): - def resolve(self, context): - return context.call(self.key) - - def foo(bar: Syncer(baz)): - return bar - - context = Context() - context.add({'bar': 'foo'}, provides='request') - compare(await context.call(foo), expected='foo') - - -@pytest.mark.asyncio -async def test_custom_requirement_sync_resolve_extract(): - - @returns('response') - async def baz(request: dict = Value('request')): - return request['bar'] - - class Syncer(Requirement): - def resolve(self, context): - return context.extract(self.key) - - def foo(bar: Syncer(baz)): - return bar - - context = Context() - context.add({'bar': 'foo'}, provides='request') - compare(await context.call(foo), expected='foo') - compare(context.get('response'), expected='foo') - - -@pytest.mark.asyncio -async def test_custom_requirement_sync_resolve_add_remove(): - - class Syncer(Requirement): - def resolve(self, context): - request = context.get('request') - context.remove('request') - context.add(request['bar'], provides='response') - return request['bar'] - - def foo(bar: Syncer('request')): - return bar - - context = Context() - context.add({'bar': 'foo'}, provides='request') - compare(await context.call(foo), expected='foo') - compare(context.get('request', default=None), expected=None) - compare(context.get('response'), expected='foo') - - @pytest.mark.asyncio async def test_default_custom_requirement(): class FromRequest(Requirement): - async def resolve(self, context): - return (context.get('request'))[self.key] - - def default_requirement_type(requirement): - if type(requirement) is Requirement: - requirement = FromRequest.make_from(requirement) - return requirement - - def foo(bar): + def __init__(self, name, type_, default): + self.name = name + self.type = type_ + super().__init__(keys=[ResourceKey(identifier='request')], default=default) + def process(self, obj): + return self.type(obj[self.name]) + + def foo(bar: int): return bar - context = Context(default_requirement_type) - context.add({'bar': 'foo'}, provides='request') - compare(await context.call(foo), expected='foo') + context = Context(FromRequest) + context.add({'bar': '42'}, identifier='request') + compare(await context.call(foo), expected=42) diff --git a/mush/tests/test_async_runner.py b/mush/tests/test_async_runner.py index b9a2799..dad3a50 100644 --- a/mush/tests/test_async_runner.py +++ b/mush/tests/test_async_runner.py @@ -1,4 +1,3 @@ -import pytest; pytestmark = pytest.mark.skip("WIP") import asyncio from testfixtures.mock import Mock, call @@ -7,7 +6,7 @@ from testfixtures import compare, ShouldRaise, Comparison as C from mush import ContextError, requires, returns -# from mush.asyncio import Runner, Context +from mush.asyncio import Runner, Context from .helpers import no_threads, must_run_in_thread @@ -47,9 +46,8 @@ def it(): async def test_addition_still_async(): async def foo(): return 'foo' - @requires(str) @returns() - async def bar(foo): + async def bar(foo: str): return foo+'bar' r1 = Runner(foo) r2 = Runner(bar) From 860d213a83408b1a632f6c27ad3ad4c794dc83a7 Mon Sep 17 00:00:00 2001 From: Chris Withers Date: Wed, 14 Oct 2020 07:49:14 +0100 Subject: [PATCH 149/159] Python 3.6 compatibility. --- mush/compat.py | 11 +++++++++++ mush/resources.py | 3 ++- mush/tests/test_context.py | 10 +++++++--- mush/typing.py | 4 +++- 4 files changed, 23 insertions(+), 5 deletions(-) create mode 100644 mush/compat.py diff --git a/mush/compat.py b/mush/compat.py new file mode 100644 index 0000000..4bc4471 --- /dev/null +++ b/mush/compat.py @@ -0,0 +1,11 @@ +import sys + +PY_VERSION = sys.version_info[:2] + +PY_37_PLUS = PY_VERSION >= (3, 7) + +try: + from typing import _GenericAlias +except ImportError: + class _GenericAlias: + pass diff --git a/mush/resources.py b/mush/resources.py index 78f408f..2a6805f 100644 --- a/mush/resources.py +++ b/mush/resources.py @@ -1,6 +1,7 @@ from types import FunctionType -from typing import Callable, Optional, _GenericAlias +from typing import Callable, Optional +from .compat import _GenericAlias from .markers import missing from .typing import Resource, Identifier, Type_ diff --git a/mush/tests/test_context.py b/mush/tests/test_context.py index 1e2a7d6..0e5f9ce 100644 --- a/mush/tests/test_context.py +++ b/mush/tests/test_context.py @@ -8,6 +8,7 @@ from mush import Context, Requirement, Value, requires, missing from mush.context import ResourceError from .helpers import TheType, Type1, Type2 +from ..compat import PY_37_PLUS from ..declarations import ignore_return from ..resources import ResourceValue, Provider, ResourceKey @@ -210,9 +211,12 @@ def test_requires_typing_missing_typing(self): def returner(request_: Mapping[str, Any]): return request_ - with ShouldRaise(ResourceError( - "request_: typing.Mapping[str, typing.Any] could not be satisfied" - )): + if PY_37_PLUS: + expected = "request_: typing.Mapping[str, typing.Any] could not be satisfied" + else: + expected = "request_: Mapping could not be satisfied" + + with ShouldRaise(ResourceError(expected)): context.call(returner) def test_requires_typing_missing_new_type(self): diff --git a/mush/typing.py b/mush/typing.py index a4ed8d9..0c80e3b 100644 --- a/mush/typing.py +++ b/mush/typing.py @@ -1,8 +1,10 @@ from typing import ( - NewType, Union, Hashable, Any, TYPE_CHECKING, List, Tuple, Type, _GenericAlias, + NewType, Union, Hashable, Any, TYPE_CHECKING, List, Tuple, Type, Callable, Optional ) +from .compat import _GenericAlias + if TYPE_CHECKING: from .declarations import RequirementsDeclaration, ReturnsDeclaration from .requirements import Requirement From 10288e06be795e699cc034f5d18149f8f7343290 Mon Sep 17 00:00:00 2001 From: Chris Withers Date: Wed, 14 Oct 2020 07:49:40 +0100 Subject: [PATCH 150/159] Test min and max versions. --- .circleci/config.yml | 6 +++--- setup.py | 1 + 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/.circleci/config.yml b/.circleci/config.yml index 33425cc..b21fbbc 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -1,7 +1,7 @@ version: 2.1 orbs: - python: cjw296/python-ci@1.2 + python: cjw296/python-ci@2.1 common: &common jobs: @@ -11,11 +11,11 @@ common: &common parameters: image: - circleci/python:3.6 - - circleci/python:3.7 - - circleci/python:3.8 + - circleci/python:3.9 - python/coverage: name: coverage + image: circleci/python:3.9 requires: - python/pip-run-tests diff --git a/setup.py b/setup.py index 2b8a0ef..3c8db51 100644 --- a/setup.py +++ b/setup.py @@ -25,6 +25,7 @@ 'Programming Language :: Python :: 3.6', 'Programming Language :: Python :: 3.7', 'Programming Language :: Python :: 3.8', + 'Programming Language :: Python :: 3.9', ], packages=find_packages(), zip_safe=False, From 962aa6d3e10bb81dbb297d523e68263f4494b549 Mon Sep 17 00:00:00 2001 From: Chris Withers Date: Wed, 14 Oct 2020 09:35:19 +0100 Subject: [PATCH 151/159] https://github.com/nedbat/coveragepy/issues/1042 --- .coveragerc | 1 - 1 file changed, 1 deletion(-) diff --git a/.coveragerc b/.coveragerc index e7f3765..1afc40e 100644 --- a/.coveragerc +++ b/.coveragerc @@ -9,7 +9,6 @@ exclude_lines = # stuff that we don't worry about pass - ... __name__ == '__main__' # circular references needed for type checking: From ca3298eee75b21cdb5532c939115da7a741bfd68 Mon Sep 17 00:00:00 2001 From: Chris Withers Date: Wed, 14 Oct 2020 09:38:41 +0100 Subject: [PATCH 152/159] sort out coverage --- mush/tests/test_context.py | 8 ++++---- mush/tests/test_declarations.py | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/mush/tests/test_context.py b/mush/tests/test_context.py index 0e5f9ce..d35ab2b 100644 --- a/mush/tests/test_context.py +++ b/mush/tests/test_context.py @@ -144,7 +144,7 @@ def foo(obj: TheType): return obj )): context.call(foo) - def tes_optional_type_present(self): + def test_optional_type_present(self): def foo(x: TheType = 1): return x context = Context() @@ -209,7 +209,7 @@ def test_requires_typing_missing_typing(self): context = Context() def returner(request_: Mapping[str, Any]): - return request_ + pass if PY_37_PLUS: expected = "request_: typing.Mapping[str, typing.Any] could not be satisfied" @@ -224,7 +224,7 @@ def test_requires_typing_missing_new_type(self): context = Context() def returner(request_: Request): - return request_ + pass with ShouldRaise(ResourceError( "request_: Request could not be satisfied" @@ -331,7 +331,7 @@ def foo(x): def test_call_requires_attr_missing(self): @requires(Value('foo').bar) def foo(x): - return x + pass o = object() context = Context() context.add(o, identifier='foo') diff --git a/mush/tests/test_declarations.py b/mush/tests/test_declarations.py index a862f60..4638721 100644 --- a/mush/tests/test_declarations.py +++ b/mush/tests/test_declarations.py @@ -94,7 +94,7 @@ def test_typing(self): def test_decorator(self): @returns(Type1) def foo(): - return 'foo' + pass r = foo.__mush__['returns'] compare(repr(r), 'returns(Type1)') compare(r, expected=ReturnsDeclaration((ResourceKey(Type1),))) From 8edf7b9911e6ef7c3a5c70411188df75cc57264c Mon Sep 17 00:00:00 2001 From: Chris Withers Date: Thu, 15 Oct 2020 07:37:58 +0100 Subject: [PATCH 153/159] Better representation of failed runners. --- mush/callpoints.py | 3 ++- mush/runner.py | 17 +++++++++-------- mush/tests/test_callpoints.py | 19 +++++++++++++++---- mush/tests/test_runner.py | 25 +++++++++++++++++-------- 4 files changed, 43 insertions(+), 21 deletions(-) diff --git a/mush/callpoints.py b/mush/callpoints.py index b5bc898..ce50669 100644 --- a/mush/callpoints.py +++ b/mush/callpoints.py @@ -25,7 +25,8 @@ def __call__(self, context: 'Context'): def __repr__(self): requires = extract_requires(self.obj, self.requires) returns = extract_returns(self.obj, self.returns) - txt = f'{self.obj!r} {requires!r} {returns!r}' + name = getattr(self.obj, '__qualname__', repr(self.obj)) + txt = f'{name} {requires!r} {returns!r}' if self.labels: txt += (' <-- ' + ', '.join(sorted(self.labels))) return txt diff --git a/mush/runner.py b/mush/runner.py index 8c7807b..0698cc3 100644 --- a/mush/runner.py +++ b/mush/runner.py @@ -297,22 +297,23 @@ def __init__(self, text: str, point: CallPoint = None, context: Context = None): self.context: Context = context def __str__(self): - rows = [] + rows = ['', ''] if self.point: + already_called = [] point = self.point.previous while point: - rows.append(repr(point)) + already_called.append(repr(point)) point = point.previous - if rows: + if already_called: rows.append('Already called:') - rows.append('') - rows.append('') - rows.reverse() + rows.extend(reversed(already_called)) rows.append('') - rows.append('While calling: '+repr(self.point)) + rows.append('While calling:') + rows.append(repr(self.point)) + rows.append('') if self.context is not None: - rows.append('with '+repr(self.context)+':') + rows.append(f'with {self.context!r}:') rows.append('') rows.append(self.text) diff --git a/mush/tests/test_callpoints.py b/mush/tests/test_callpoints.py index 4984052..911843e 100644 --- a/mush/tests/test_callpoints.py +++ b/mush/tests/test_callpoints.py @@ -1,3 +1,5 @@ +from functools import partial + from testfixtures import compare from testfixtures.mock import Mock, call import pytest @@ -72,14 +74,22 @@ def foo(prefix): def test_repr_minimal(self): def foo(): pass point = CallPoint(foo) - compare(repr(foo)+" requires() returns('foo')", repr(point)) + compare("TestCallPoints.test_repr_minimal..foo requires() returns('foo')", + actual=repr(point)) + + def test_repr_partial(self): + def foo(): pass + point = CallPoint(partial(foo)) + compare(f"functools.partial({foo!r}) requires() returns('foo')", + actual=repr(point)) def test_repr_maximal(self): def foo(a1): pass point = CallPoint(foo, requires('foo'), returns('bar')) point.labels.add('baz') point.labels.add('bob') - compare(expected=repr(foo)+" requires(Value('foo')) returns('bar') <-- baz, bob", + compare("TestCallPoints.test_repr_maximal..foo " + "requires(Value('foo')) returns('bar') <-- baz, bob", actual=repr(point)) def test_convert_to_requires_and_returns(self): @@ -88,5 +98,6 @@ def foo(baz): pass # this is deferred until later assert isinstance(point.requires, str) assert isinstance(point.returns, str) - compare(repr(foo)+" requires(Value('foo')) returns('bar')", - repr(point)) + compare("TestCallPoints.test_convert_to_requires_and_returns..foo " + "requires(Value('foo')) returns('bar')", + expected=repr(point)) diff --git a/mush/tests/test_runner.py b/mush/tests/test_runner.py index c613684..c96c0f2 100644 --- a/mush/tests/test_runner.py +++ b/mush/tests/test_runner.py @@ -352,7 +352,11 @@ def job(arg): t_str = 'TestRunner.test_missing_from_context_no_chain..T' text = '\n'.join(( - f"While calling: {job!r} requires(Value({t_str})) returns('job')", + '', + '', + "While calling:", + f"{job.__qualname__} requires(Value({t_str})) returns('job')", + '', 'with :', '', f"Value({t_str}) could not be satisfied", @@ -389,17 +393,20 @@ def job5(foo, bar): pass '', '', 'Already called:', - repr(job1)+' requires() returns() <-- 1', - repr(job2)+' requires() returns()', + f'{job1.__qualname__} requires() returns() <-- 1', + f'{job2.__qualname__} requires() returns()', + '', + "While calling:", + f"{job3.__qualname__} requires(Value({t_str})) returns('job3')", '', - f"While calling: {job3!r} requires(Value({t_str})) returns('job3')", 'with :', '', f"Value({t_str}) could not be satisfied", '', 'Still to call:', - repr(job4)+" requires() returns('job4') <-- 4", - repr(job5)+" requires(Value('foo'), bar=Value('baz')) returns('bob')", + f'' + f"{job4.__qualname__} requires() returns('job4') <-- 4", + f"{job5.__qualname__} requires(Value('foo'), bar=Value('baz')) returns('bob')", )) compare(text, actual=repr(s.raised)) compare(text, actual=str(s.raised)) @@ -433,9 +440,11 @@ def job(): '', '', 'Already called:', - f"{job!r} requires() returns({t_str})", + f"{job.__qualname__} requires() returns({t_str})", + '', + "While calling:", + f"{job.__qualname__} requires() returns({t_str})", '', - f"While calling: {job!r} requires() returns({t_str})", 'with :', From eb5e5fd7fca11593da24a22270f98afa018a48bb Mon Sep 17 00:00:00 2001 From: Chris Withers Date: Thu, 15 Oct 2020 09:01:11 +0100 Subject: [PATCH 154/159] Safer repr for requirements to help out with subclasses. --- mush/requirements.py | 6 +++++- mush/tests/test_context.py | 3 +-- mush/tests/test_requirements.py | 9 +++++++++ 3 files changed, 15 insertions(+), 3 deletions(-) diff --git a/mush/requirements.py b/mush/requirements.py index a953fd0..7bc833d 100644 --- a/mush/requirements.py +++ b/mush/requirements.py @@ -54,8 +54,12 @@ def _keys_repr(self): def __repr__(self): default = '' if self.default is missing else f', default={self.default!r}' + other = ', '.join(f'{n}={v!r}' for n, v in vars(self).items() + if n not in ('keys', 'default', 'ops')) + if other: + other = ', '+other ops = ''.join(repr(o) for o in self.ops) - return f"{type(self).__name__}({self._keys_repr()}{default}){ops}" + return f"{type(self).__name__}({self._keys_repr()}{default}{other}){ops}" def attr(self, name): """ diff --git a/mush/tests/test_context.py b/mush/tests/test_context.py index d35ab2b..40e72c7 100644 --- a/mush/tests/test_context.py +++ b/mush/tests/test_context.py @@ -672,9 +672,8 @@ def foo(bar: str): context = Context() context.add({'bar': 'foo'}, identifier='request') compare(context.call(foo, requires=FromRequest('bar')), expected='foo') - # real world, FromRequest would have a decent repr: with ShouldRaise(ResourceError( - "FromRequest(ResourceKey('request')) could not be satisfied" + "FromRequest(ResourceKey('request'), name='baz') could not be satisfied" )): context.call(foo, requires=FromRequest('baz')) diff --git a/mush/tests/test_requirements.py b/mush/tests/test_requirements.py index 07e281e..9426a57 100644 --- a/mush/tests/test_requirements.py +++ b/mush/tests/test_requirements.py @@ -39,6 +39,15 @@ def test_repr_maximal(self): special_names = ['attr', 'ops'] + def test_repr_subclass(self): + class SubClass(Requirement): + def __init__(self): + self.foo = 42 + self.bar = 'baz' + super().__init__([ResourceKey(str)], missing) + compare(repr(SubClass()), + expected="SubClass(ResourceKey(str), foo=42, bar='baz')") + @pytest.mark.parametrize("name", special_names) def test_attr_special_name(self, name): v = Requirement('foo') From 1649bf5258b768e6f3c9825fcf7486043801c21e Mon Sep 17 00:00:00 2001 From: Chris Withers Date: Tue, 27 Oct 2020 08:08:47 +0000 Subject: [PATCH 155/159] Fix bug where providers in a base context couldn't use resources found in a nested context. --- mush/context.py | 2 +- mush/tests/test_context.py | 13 +++++++++++++ 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/mush/context.py b/mush/context.py index 3d1de2e..e0d3f48 100644 --- a/mush/context.py +++ b/mush/context.py @@ -135,7 +135,7 @@ def _resolve(self, obj, requires=None, specials=None): specials_[ResourceKey] = first_key o = context._resolve(resource.provider, specials=specials_) provider = resource.provider - resolving = context._resolve(provider, specials=specials_) + resolving = self._resolve(provider, specials=specials_) for call in resolving: o = yield Call(call.obj, call.args, call.kw, send=True) yield diff --git a/mush/tests/test_context.py b/mush/tests/test_context.py index 40e72c7..9e430f7 100644 --- a/mush/tests/test_context.py +++ b/mush/tests/test_context.py @@ -753,6 +753,19 @@ def make_x(): assert c2.call(lambda x: x) is x1 assert c2.call(lambda x: x) is x2 + def test_provider_uses_resources_from_nested_context(self): + + def expanded(it: str): + return it*2 + + c1 = Context() + c1.add(Provider(expanded)) + + c2 = c1.nest() + c2.add('foo') + + compare(c2.call(lambda expanded: expanded), expected='foofoo') + def test_with_default_requirement(self): def make_requirement(name, type_, default) -> Requirement: From 4622b80c1744fc14eddc247b5046bb34d96ddb18 Mon Sep 17 00:00:00 2001 From: Chris Withers Date: Wed, 28 Oct 2020 07:22:29 +0000 Subject: [PATCH 156/159] move this test to the correct suite --- mush/tests/test_context.py | 48 +++++++++++++++++++------------------- 1 file changed, 24 insertions(+), 24 deletions(-) diff --git a/mush/tests/test_context.py b/mush/tests/test_context.py index 9e430f7..6df2d6c 100644 --- a/mush/tests/test_context.py +++ b/mush/tests/test_context.py @@ -283,6 +283,30 @@ def foo(x, y): result = context.call(foo, requires(y='baz', x=TheType)) compare(result, expected=('foo', 'bar')) + def test_custom_requirement(self): + + class FromRequest(Requirement): + + def __init__(self, name): + super().__init__([ResourceKey(identifier='request')]) + self.name = name + + def process(self, obj): + # this example doesn't show it, but this is a method so + # there can be conditional stuff in here: + return obj.get(self.name, missing) + + def foo(bar: str): + return bar + + context = Context() + context.add({'bar': 'foo'}, identifier='request') + compare(context.call(foo, requires=FromRequest('bar')), expected='foo') + with ShouldRaise(ResourceError( + "FromRequest(ResourceKey('request'), name='baz') could not be satisfied" + )): + context.call(foo, requires=FromRequest('baz')) + class TestOps: @@ -653,30 +677,6 @@ def provider() -> str: pass compare(expected, actual=repr(context)) compare(expected, actual=str(context)) - def test_custom_requirement(self): - - class FromRequest(Requirement): - - def __init__(self, name): - super().__init__([ResourceKey(identifier='request')]) - self.name = name - - def process(self, obj): - # this example doesn't show it, but this is a method so - # there can be conditional stuff in here: - return obj.get(self.name, missing) - - def foo(bar: str): - return bar - - context = Context() - context.add({'bar': 'foo'}, identifier='request') - compare(context.call(foo, requires=FromRequest('bar')), expected='foo') - with ShouldRaise(ResourceError( - "FromRequest(ResourceKey('request'), name='baz') could not be satisfied" - )): - context.call(foo, requires=FromRequest('baz')) - class TestNesting: From 76bab1d33f81b9c5db646ee2283d904b5135ee2e Mon Sep 17 00:00:00 2001 From: Chris Withers Date: Tue, 9 Mar 2021 08:55:16 +0000 Subject: [PATCH 157/159] scorched earth... --- mush/__init__.py | 25 - mush/asyncio.py | 152 --- mush/callpoints.py | 32 - mush/compat.py | 11 - mush/context.py | 189 --- mush/declarations.py | 124 -- mush/extraction.py | 142 -- mush/markers.py | 56 - mush/modifier.py | 90 -- mush/plug.py | 71 - mush/requirements.py | 189 --- mush/resources.py | 84 -- mush/runner.py | 332 ----- mush/tests/__init__.py | 0 mush/tests/conftest.py | 9 - mush/tests/example_with_mush_clone.py | 71 - mush/tests/example_with_mush_factory.py | 35 - mush/tests/example_without_mush.py | 44 - mush/tests/helpers.py | 58 - mush/tests/test_async_context.py | 331 ----- mush/tests/test_async_runner.py | 600 --------- mush/tests/test_callpoints.py | 103 -- mush/tests/test_context.py | 776 ----------- mush/tests/test_context_py38.py | 23 - mush/tests/test_declarations.py | 111 -- mush/tests/test_example_with_mush_clone.py | 98 -- mush/tests/test_example_with_mush_factory.py | 31 - mush/tests/test_example_without_mush.py | 75 -- mush/tests/test_extraction.py | 402 ------ mush/tests/test_marker.py | 6 - mush/tests/test_plug.py | 235 ---- mush/tests/test_requirements.py | 263 ---- mush/tests/test_runner.py | 1238 ------------------ mush/typing.py | 27 - 34 files changed, 6033 deletions(-) delete mode 100755 mush/__init__.py delete mode 100644 mush/asyncio.py delete mode 100644 mush/callpoints.py delete mode 100644 mush/compat.py delete mode 100644 mush/context.py delete mode 100644 mush/declarations.py delete mode 100644 mush/extraction.py delete mode 100644 mush/markers.py delete mode 100644 mush/modifier.py delete mode 100644 mush/plug.py delete mode 100644 mush/requirements.py delete mode 100644 mush/resources.py delete mode 100644 mush/runner.py delete mode 100644 mush/tests/__init__.py delete mode 100644 mush/tests/conftest.py delete mode 100644 mush/tests/example_with_mush_clone.py delete mode 100644 mush/tests/example_with_mush_factory.py delete mode 100644 mush/tests/example_without_mush.py delete mode 100644 mush/tests/helpers.py delete mode 100644 mush/tests/test_async_context.py delete mode 100644 mush/tests/test_async_runner.py delete mode 100644 mush/tests/test_callpoints.py delete mode 100644 mush/tests/test_context.py delete mode 100644 mush/tests/test_context_py38.py delete mode 100644 mush/tests/test_declarations.py delete mode 100644 mush/tests/test_example_with_mush_clone.py delete mode 100644 mush/tests/test_example_with_mush_factory.py delete mode 100644 mush/tests/test_example_without_mush.py delete mode 100644 mush/tests/test_extraction.py delete mode 100644 mush/tests/test_marker.py delete mode 100644 mush/tests/test_plug.py delete mode 100644 mush/tests/test_requirements.py delete mode 100644 mush/tests/test_runner.py delete mode 100644 mush/typing.py diff --git a/mush/__init__.py b/mush/__init__.py deleted file mode 100755 index b07df0c..0000000 --- a/mush/__init__.py +++ /dev/null @@ -1,25 +0,0 @@ -from .context import Context, ResourceError -from .declarations import requires, returns, update_wrapper -from .extraction import extract_requires, extract_returns -from .markers import missing, nonblocking, blocking -from .plug import Plug -from .requirements import Requirement, Value, AnyOf, Like -from .runner import Runner, ContextError - -__all__ = [ - 'AnyOf', - 'Context', - 'ContextError', - 'Like', - 'Plug', - 'Requirement', - 'ResourceError', - 'Runner', - 'Value', - 'blocking', - 'missing', - 'nonblocking', - 'requires', - 'returns', - 'update_wrapper', -] diff --git a/mush/asyncio.py b/mush/asyncio.py deleted file mode 100644 index 4800026..0000000 --- a/mush/asyncio.py +++ /dev/null @@ -1,152 +0,0 @@ -import asyncio -from functools import partial -from typing import Callable, Dict, Any - -from . import ( - Context as SyncContext, Runner as SyncRunner, ResourceError, ContextError, extract_returns -) -from .declarations import RequirementsDeclaration, ReturnsDeclaration -from .markers import get_mush, AsyncType -from .requirements import Annotation -from .resources import ResourceValue -from .typing import DefaultRequirement - - -class AsyncFromSyncContext: - - def __init__(self, context, loop): - self.context: Context = context - self.loop = loop - self.add = context.add - - def call(self, obj: Callable, requires: RequirementsDeclaration = None): - coro = self.context.call(obj, requires) - future = asyncio.run_coroutine_threadsafe(coro, self.loop) - return future.result() - - def extract( - self, - obj: Callable, - requires: RequirementsDeclaration = None, - returns: ReturnsDeclaration = None - ): - coro = self.context.extract(obj, requires, returns) - future = asyncio.run_coroutine_threadsafe(coro, self.loop) - return future.result() - - -def async_behaviour(callable_): - to_check = callable_ - if isinstance(callable_, partial): - to_check = callable_.func - if asyncio.iscoroutinefunction(to_check): - return AsyncType.async_ - elif asyncio.iscoroutinefunction(to_check.__call__): - return AsyncType.async_ - else: - async_type = get_mush(callable_, 'async', default=None) - if async_type is None: - if isinstance(callable_, type): - return AsyncType.nonblocking - else: - return AsyncType.blocking - else: - return async_type - - -class Context(SyncContext): - - def __init__(self, default_requirement: DefaultRequirement = Annotation): - super().__init__(default_requirement) - self._sync_context = AsyncFromSyncContext(self, asyncio.get_event_loop()) - self._async_cache = {} - - async def _ensure_async(self, func, *args, **kw): - behaviour = self._async_cache.get(func) - if behaviour is None: - behaviour = async_behaviour(func) - self._async_cache[func] = behaviour - - if behaviour is AsyncType.nonblocking: - return func(*args, **kw) - elif behaviour is AsyncType.blocking: - if kw: - func = partial(func, **kw) - loop = asyncio.get_event_loop() - return await loop.run_in_executor(None, func, *args) - else: - return await func(*args, **kw) - - def _specials(self) -> Dict[type, Any]: - return {Context: self, SyncContext: self._sync_context} - - async def call(self, obj: Callable, requires: RequirementsDeclaration = None): - resolving = self._resolve(obj, requires) - for call in resolving: - result = await self._ensure_async(call.obj, *call.args, **call.kw) - if call.send: - resolving.send(result) - return result - - async def extract(self, - obj: Callable, - requires: RequirementsDeclaration = None, - returns: ReturnsDeclaration = None): - result = await self.call(obj, requires) - returns = extract_returns(obj, returns) - if returns: - self.add_by_keys(ResourceValue(result), returns) - return result - - -class SyncContextManagerWrapper: - - def __init__(self, sync_manager): - self.sync_manager = sync_manager - self.loop = asyncio.get_event_loop() - - async def __aenter__(self): - return await self.loop.run_in_executor(None, self.sync_manager.__enter__) - - async def __aexit__(self, exc_type, exc_val, exc_tb): - return await self.loop.run_in_executor(None, self.sync_manager.__exit__, - exc_type, exc_val, exc_tb) - - -class Runner(SyncRunner): - - async def __call__(self, context: Context = None): - if context is None: - context = Context() - if context.point is None: - context.point = self.start - - result = None - - while context.point: - - point = context.point - context.point = point.next - - try: - result = manager = await point(context) - except ResourceError as e: - raise ContextError(str(e), point, context) - - if getattr(result, '__enter__', None): - manager = SyncContextManagerWrapper(result) - - if getattr(manager, '__aenter__', None): - async with manager as managed: - if managed is not None and managed is not result: - context.add(managed) - # If the context manager swallows an exception, - # None should be returned, not the context manager: - result = None - if context.point is not None: - result = await self(context) - - return result - - -__all__ = ['Context', 'Runner'] diff --git a/mush/callpoints.py b/mush/callpoints.py deleted file mode 100644 index ce50669..0000000 --- a/mush/callpoints.py +++ /dev/null @@ -1,32 +0,0 @@ -from typing import TYPE_CHECKING, Callable - -from .extraction import extract_requires, extract_returns -from .typing import Requires, Returns - -if TYPE_CHECKING: - from . import Context - - -class CallPoint(object): - - next = None - previous = None - - def __init__(self, obj: Callable, requires: Requires = None, returns: Returns = None): - self.obj = obj - self.requires = requires - self.returns = returns - self.labels = set() - self.added_using = set() - - def __call__(self, context: 'Context'): - return context.extract(self.obj, self.requires, self.returns) - - def __repr__(self): - requires = extract_requires(self.obj, self.requires) - returns = extract_returns(self.obj, self.returns) - name = getattr(self.obj, '__qualname__', repr(self.obj)) - txt = f'{name} {requires!r} {returns!r}' - if self.labels: - txt += (' <-- ' + ', '.join(sorted(self.labels))) - return txt diff --git a/mush/compat.py b/mush/compat.py deleted file mode 100644 index 4bc4471..0000000 --- a/mush/compat.py +++ /dev/null @@ -1,11 +0,0 @@ -import sys - -PY_VERSION = sys.version_info[:2] - -PY_37_PLUS = PY_VERSION >= (3, 7) - -try: - from typing import _GenericAlias -except ImportError: - class _GenericAlias: - pass diff --git a/mush/context.py b/mush/context.py deleted file mode 100644 index e0d3f48..0000000 --- a/mush/context.py +++ /dev/null @@ -1,189 +0,0 @@ -from collections import namedtuple -from typing import Optional, Callable, Union, Any, Dict, Iterable - -from .callpoints import CallPoint -from .extraction import extract_requires, extract_returns -from .markers import missing, Marker -from .requirements import Requirement, Annotation -from .resources import ResourceKey, ResourceValue, Provider -from .typing import Resource, Identifier, Type_, Requires, Returns, DefaultRequirement - -NONE_TYPE = type(None) -unspecified = Marker('unspecified') - - -class ResourceError(Exception): - """ - An exception raised when there is a problem with a resource. - """ - - -Call = namedtuple('Call', ('obj', 'args', 'kw', 'send')) - - -class Context: - "Stores resources for a particular run." - - _parent: 'Context' = None - point: CallPoint = None - - def __init__(self, default_requirement: DefaultRequirement = Annotation): - self._store = {} - self._default_requirement = default_requirement - - def add_by_keys(self, resource: ResourceValue, keys: Iterable[ResourceKey]): - keys_ = keys - for key in keys: - if key in self._store: - raise ResourceError(f'Context already contains {key}') - self._store[key] = resource - - def add(self, - obj: Union[Provider, Resource], - provides: Optional[Type_] = missing, - identifier: Identifier = None): - """ - Add a resource to the context. - - Optionally specify what the resource provides. - - ``provides`` can be explicitly specified as ``None`` to only register against the identifier - """ - keys = set() - if isinstance(obj, Provider): - resource = obj - if provides is missing: - keys.update(extract_returns(resource.provider)) - - else: - resource = ResourceValue(obj) - if provides is missing: - provides = type(obj) - - if provides is not missing: - keys.add(ResourceKey(provides, identifier)) - if not (identifier is None or provides is None): - keys.add(ResourceKey(None, identifier)) - - if not keys: - raise ResourceError( - f'Could not determine what is provided by {resource}' - ) - - self.add_by_keys(resource, keys) - - def __repr__(self): - bits = [] - for key, value in sorted(self._store.items(), key=lambda o: repr(o)): - bits.append(f'\n {key}: {value!r}') - if bits: - bits.append('\n') - return f"" - - def extract(self, obj: Callable, requires: Requires = None, returns: Returns = None): - result = self.call(obj, requires) - returns = extract_returns(obj, returns) - if returns: - self.add_by_keys(ResourceValue(result), returns) - return result - - def _find_resource(self, key): - exact = True - if not isinstance(key[0], type): - return self._store.get(key), exact - type_, identifier = key - for type__ in type_.__mro__: - resource = self._store.get((type__, identifier)) - if resource is not None and (exact or resource.provides_subclasses): - return resource, exact - exact = False - return None, exact - - def _specials(self) -> Dict[type, Any]: - return {Context: self} - - def _resolve(self, obj, requires=None, specials=None): - if specials is None: - specials = self._specials() - - requires = extract_requires(obj, requires, self._default_requirement) - - args = [] - kw = {} - - for parameter in requires: - requirement = parameter.requirement - - o = missing - first_key = None - - for key in requirement.keys: - if first_key is None: - first_key = key - - context = self - - while True: - resource, exact = context._find_resource(key) - - if resource is None: - o = specials.get(key[0], missing) - else: - if resource.obj is missing: - specials_ = specials.copy() - specials_[Requirement] = requirement - specials_[ResourceKey] = first_key - o = context._resolve(resource.provider, specials=specials_) - provider = resource.provider - resolving = self._resolve(provider, specials=specials_) - for call in resolving: - o = yield Call(call.obj, call.args, call.kw, send=True) - yield - if call.send: - resolving.send(o) - if resource.cache: - if exact and context is self: - resource.obj = o - else: - self.add_by_keys(ResourceValue(o), (key,)) - else: - o = resource.obj - - if o is not missing: - break - - context = context._parent - if context is None: - break - - if o is not missing: - break - - if o is missing: - o = parameter.default - - if o is not requirement.default: - o = requirement.process(o) - - if o is missing: - raise ResourceError(f'{requirement!r} could not be satisfied') - - if parameter.target is None: - args.append(o) - else: - kw[parameter.target] = o - - yield Call(obj, args, kw, send=False) - - def call(self, obj: Callable, requires: Requires = None): - resolving = self._resolve(obj, requires) - for call in resolving: - result = call.obj(*call.args, **call.kw) - if call.send: - resolving.send(result) - return result - - def nest(self): - nested = self.__class__(self._default_requirement) - nested._parent = self - return nested diff --git a/mush/declarations.py b/mush/declarations.py deleted file mode 100644 index bbafc6a..0000000 --- a/mush/declarations.py +++ /dev/null @@ -1,124 +0,0 @@ -from enum import Enum, auto -from functools import ( - WRAPPER_ASSIGNMENTS as FUNCTOOLS_ASSIGNMENTS, - WRAPPER_UPDATES, - update_wrapper as functools_update_wrapper -) -from itertools import chain -from typing import _type_check, Any, List, Set - -from .markers import set_mush, missing -from .requirements import Requirement, Value -from .resources import ResourceKey -from .typing import RequirementType, ReturnType, Type_ - -VALID_DECORATION_TYPES = (type, str, Requirement) - - -def check_decoration_types(*objs): - for obj in objs: - if isinstance(obj, VALID_DECORATION_TYPES): - continue - try: - _type_check(obj, '') - continue - except TypeError: - pass - raise TypeError( - repr(obj)+" is not a valid decoration type" - ) - - -class Parameter: - def __init__(self, requirement: Requirement, target: str = None, - type_: Type_ = None, default: Any = missing): - self.requirement = requirement - self.target = target - self.default = default - self.type = type_ - - -class RequirementsDeclaration(List[Parameter]): - - def __call__(self, obj): - set_mush(obj, 'requires', self) - return obj - - def __repr__(self): - parts = (repr(p.requirement) if p.target is None else f'{p.target}={p.requirement!r}' - for p in self) - return f"requires({', '.join(parts)})" - - -def requires(*args: RequirementType, **kw: RequirementType): - """ - Represents requirements for a particular callable. - - The passed in ``args`` and ``kw`` should map to the types, including - any required :class:`~.declarations.how`, for the matching - arguments or keyword parameters the callable requires. - - String names for resources must be used instead of types where the callable - returning those resources is configured to return the named resource. - """ - requires_ = RequirementsDeclaration() - check_decoration_types(*args) - check_decoration_types(*kw.values()) - for target, possible in chain( - ((None, arg) for arg in args), - kw.items(), - ): - if isinstance(possible, Requirement): - parameter = Parameter(possible, target, default=possible.default) - else: - parameter = Parameter(Value(possible), target) - requires_.append(parameter) - return requires_ - - -requires_nothing = RequirementsDeclaration() - - -class ReturnsDeclaration(Set[ResourceKey]): - - def __call__(self, obj): - set_mush(obj, 'returns', self) - return obj - - def __repr__(self): - return f"returns({', '.join(str(k) for k in sorted(self, key=lambda o: str(o)))})" - - -def returns(*keys: ReturnType): - """ - """ - check_decoration_types(*keys) - return ReturnsDeclaration(ResourceKey.guess(k) for k in keys) - - -returns_nothing = ignore_return = ReturnsDeclaration() - - -class DeclarationsFrom(Enum): - original = auto() - replacement = auto() - - -#: Use declarations from the original callable. -original = DeclarationsFrom.original -#: Use declarations from the replacement callable. -replacement = DeclarationsFrom.replacement - - -WRAPPER_ASSIGNMENTS = FUNCTOOLS_ASSIGNMENTS + ('__mush__',) - - -def update_wrapper(wrapper, - wrapped, - assigned=WRAPPER_ASSIGNMENTS, - updated=WRAPPER_UPDATES): - """ - An extended version of :func:`functools.update_wrapper` that - also preserves Mush's annotations. - """ - return functools_update_wrapper(wrapper, wrapped, assigned, updated) diff --git a/mush/extraction.py b/mush/extraction.py deleted file mode 100644 index 1f32e83..0000000 --- a/mush/extraction.py +++ /dev/null @@ -1,142 +0,0 @@ -from functools import ( - partial -) -from inspect import signature -from typing import Callable, get_type_hints - -from .declarations import ( - Parameter, RequirementsDeclaration, ReturnsDeclaration, - requires_nothing, returns, requires -) -from .markers import missing, get_mush -from .requirements import Requirement, Annotation -from .resources import ResourceKey -from .typing import Requires, Returns, DefaultRequirement - - -def _apply_requires(by_name, by_index, requires_): - - for i, p in enumerate(requires_): - - if p.target is None: - try: - name = by_index[i] - except IndexError: - # case where something takes *args - by_name[i] = Parameter(p.requirement, p.target, p.type, p.default) - continue - else: - name = p.target - - original_p = by_name[name] - original_p.requirement = p.requirement - original_p.target = p.target - original_p.default = p.default - - -def extract_requires( - obj: Callable, - explicit: Requires = None, - default_requirement: DefaultRequirement = Annotation -) -> RequirementsDeclaration: - by_name = {} - - # from annotations - try: - hints = get_type_hints(obj) - except TypeError: - hints = {} - - for name, p in signature(obj).parameters.items(): - if p.kind in (p.VAR_POSITIONAL, p.VAR_KEYWORD): - continue - - # https://bugs.python.org/issue39753: - if isinstance(obj, partial) and p.name in obj.keywords: - continue - - type_ = hints.get(name) - default = missing if p.default is p.empty else p.default - - if isinstance(default, Requirement): - requirement = default - default = requirement.default - else: - requirement = default_requirement(p.name, type_, default) - - by_name[name] = Parameter( - requirement, - target=p.name if p.kind is p.KEYWORD_ONLY else None, - default=default, - type_=type_ - ) - - by_index = list(by_name) - - # from declarations - mush_requires = get_mush(obj, 'requires', None) - if mush_requires is not None: - _apply_requires(by_name, by_index, mush_requires) - - # explicit - if explicit is not None: - if not isinstance(explicit, RequirementsDeclaration): - if not isinstance(explicit, (list, tuple)): - explicit = (explicit,) - explicit = requires(*explicit) - _apply_requires(by_name, by_index, explicit) - - if not by_name: - return requires_nothing - - # sort out target: - needs_target = False - for name, parameter in by_name.items(): - if parameter.target is not None: - needs_target = True - elif needs_target: - parameter.target = name - parameter.requirement = parameter.requirement.complete( - name, parameter.type, parameter.default - ) - - return RequirementsDeclaration(by_name.values()) - - -def extract_returns(obj: Callable, explicit: Returns = None): - if explicit is not None: - if not isinstance(explicit, ReturnsDeclaration): - return returns(explicit) - return explicit - - returns_ = get_mush(obj, 'returns', None) - if returns_ is not None: - return returns_ - - returns_ = ReturnsDeclaration() - try: - type_ = get_type_hints(obj).get('return') - except TypeError: - type_ = None - else: - if type_ is type(None): - return returns_ - - if type_ is None and isinstance(obj, type): - type_ = obj - - if isinstance(obj, partial): - obj = obj.func - identifier = getattr(obj, '__name__', None) - - type_supplied = type_ is not None - identifier_supplied = identifier is not None - - if type_supplied: - returns_.add(ResourceKey(type_, None)) - if identifier_supplied: - returns_.add(ResourceKey(None, identifier)) - if type_supplied and identifier_supplied: - returns_.add(ResourceKey(type_, identifier)) - - return returns_ diff --git a/mush/markers.py b/mush/markers.py deleted file mode 100644 index 7738d48..0000000 --- a/mush/markers.py +++ /dev/null @@ -1,56 +0,0 @@ -import asyncio -from enum import Enum, auto - - -class Marker(object): - - def __init__(self, name): - self.name = name - - def __repr__(self): - return '' % self.name - - -not_specified = Marker('not_specified') - -#: A sentinel object to indicate that a value is missing. -missing = Marker('missing') - - -def set_mush(obj, key, value): - if not hasattr(obj, '__mush__'): - obj.__mush__ = {} - obj.__mush__[key] = value - - -def get_mush(obj, key, default): - __mush__ = getattr(obj, '__mush__', missing) - if __mush__ is missing: - return default - return __mush__.get(key, default) - - -class AsyncType(Enum): - blocking = auto() - nonblocking = auto() - async_ = auto() - - -def nonblocking(obj): - """ - A decorator to mark a callable as not requiring running - in a thread, even though it's not async. - """ - set_mush(obj, 'async', AsyncType.nonblocking) - return obj - - -def blocking(obj): - """ - A decorator to explicitly mark a callable as requiring running - in a thread. - """ - if asyncio.iscoroutinefunction(obj): - raise TypeError('cannot mark an async function as blocking') - set_mush(obj, 'async', AsyncType.blocking) - return obj diff --git a/mush/modifier.py b/mush/modifier.py deleted file mode 100644 index 4ac684a..0000000 --- a/mush/modifier.py +++ /dev/null @@ -1,90 +0,0 @@ -""" -.. currentmodule:: mush -""" -from typing import Callable - -from .callpoints import CallPoint -from .markers import not_specified -from .typing import Requires, Returns - - -class Modifier(object): - """ - Used to make changes at a particular point in a runner. - These are returned by :meth:`Runner.add` and :meth:`Runner.__getitem__`. - """ - def __init__(self, runner, callpoint, label): - self.runner = runner - self.callpoint = callpoint - if label is not_specified: - self.labels = set() - else: - self.labels = {label} - - def add(self, obj: Callable, requires: Requires = None, returns: Returns = None, - label: str = None): - """ - :param obj: The callable to be added. - - :param requires: The resources to required as parameters when calling - ``obj``. These can be specified by passing a single - type, a string name or a :class:`requires` object. - - :param returns: The resources that ``obj`` will return. - These can be specified as a single - type, a string name or a :class:`returns`, - :class:`returns_mapping`, :class:`returns_sequence` - object. - - :param label: If specified, this is a string that adds a label to the - point where ``obj`` is added that can later be retrieved - with :meth:`Runner.__getitem__`. - - If no label is specified but the point which this - :class:`~.modifier.Modifier` represents has any labels, those labels - will be moved to the newly inserted point. - """ - if label in self.runner.labels: - raise ValueError('%r already points to %r' % ( - label, self.runner.labels[label] - )) - callpoint = CallPoint(obj, requires, returns) - - if label: - self.add_label(label, callpoint) - - callpoint.previous = self.callpoint - - if self.callpoint: - - callpoint.next = self.callpoint.next - if self.callpoint.next: - self.callpoint.next.previous = callpoint - self.callpoint.next = callpoint - - if not label: - for label in self.labels: - self.add_label(label, callpoint) - callpoint.added_using.add(label) - else: - self.runner.start = callpoint - - if self.callpoint is self.runner.end or self.runner.end is None: - self.runner.end = callpoint - - self.callpoint = callpoint - - def add_label(self, label, callpoint=None): - """ - Add a label to the point represented by this - :class:`~.modifier.Modifier`. - - :param callpoint: For internal use only. - """ - callpoint = callpoint or self.callpoint - callpoint.labels.add(label) - old_callpoint = self.runner.labels.get(label) - if old_callpoint: - old_callpoint.labels.remove(label) - self.runner.labels[label] = callpoint - self.labels.add(label) diff --git a/mush/plug.py b/mush/plug.py deleted file mode 100644 index 0f06ebb..0000000 --- a/mush/plug.py +++ /dev/null @@ -1,71 +0,0 @@ -from .markers import set_mush, get_mush - - -class ignore(object): - """ - A decorator to explicitly mark that a method of a :class:`~mush.Plug` should - not be added to a runner by :meth:`~mush.Plug.add_to` - """ - def __call__(self, method): - set_mush(method, 'plug', self) - return method - - def apply(self, runner, obj): - pass - - -class insert(ignore): - """ - A decorator to explicitly mark that a method of a :class:`~mush.Plug` should - be added to a runner by :meth:`~mush.Plug.add_to`. The ``label`` parameter - can be used to indicate a different label at which to add the method, - instead of using the name of the method. - """ - def __init__(self, label=None): - self.label = label - - def apply(self, runner, obj): - runner[self.label or obj.__name__].add(obj) - -class append(ignore): - """ - A decorator to mark that this method of a :class:`~mush.Plug` should - be added to the end of a runner by :meth:`~mush.Plug.add_to`. - """ - - def apply(self, runner, obj): - runner.add(obj) - - -class Plug(object): - """ - Base class for a 'plug' that can add to several points in a runner. - """ - - #: Control whether methods need to be decorated with :class:`insert` - #: in order to be added by this :class:`~mush.Plug`. - explicit = False - - @ignore() - def add_to(self, runner): - """ - Add methods of the instance to the supplied runner. - By default, all methods will be added and the name of the method will be - used as the label in the runner at which the method will be added. - If no such label exists, a :class:`KeyError` will be raised. - - If :attr:`explicit` is ``True``, then only methods decorated with an - :class:`~mush.plug.insert` will be added. - """ - - if self.explicit: - default_action = ignore() - else: - default_action = insert() - - for name in dir(self): - if not name.startswith('_'): - obj = getattr(self, name) - if callable(obj): - action = get_mush(obj, 'plug', default_action) - action.apply(runner, obj) diff --git a/mush/requirements.py b/mush/requirements.py deleted file mode 100644 index 7bc833d..0000000 --- a/mush/requirements.py +++ /dev/null @@ -1,189 +0,0 @@ -from typing import Any, List, Sequence, Optional, Union, Type - -from .markers import missing -from .resources import ResourceKey, type_repr, is_type -from .typing import Identifier, Type_ - - -class Op: - - def __init__(self, name): - self.name = name - - def __call__(self, o): # pragma: no cover - raise NotImplementedError() - - -class AttrOp(Op): - - def __call__(self, o): - try: - return getattr(o, self.name) - except AttributeError: - return missing - - def __repr__(self): - return f'.{self.name}' - - -class ItemOp(Op): - - def __call__(self, o): - try: - return o[self.name] - except KeyError: - return missing - - def __repr__(self): - return f'[{self.name!r}]' - - -class Requirement: - """ - The requirement for an individual parameter of a callable. - """ - - def __init__(self, keys: Sequence[ResourceKey], default: Optional[Any] = missing): - #: Note that the first key returned should be the "most specific" - self.keys: Sequence[ResourceKey] = keys - self.default = default - self.ops: List['Op'] = [] - - def _keys_repr(self): - return ', '.join(repr(key) for key in self.keys) - - def __repr__(self): - default = '' if self.default is missing else f', default={self.default!r}' - other = ', '.join(f'{n}={v!r}' for n, v in vars(self).items() - if n not in ('keys', 'default', 'ops')) - if other: - other = ', '+other - ops = ''.join(repr(o) for o in self.ops) - return f"{type(self).__name__}({self._keys_repr()}{default}{other}){ops}" - - def attr(self, name): - """ - If you need to get an attribute called either ``attr`` or ``item`` - then you will need to call this method instead of using the - generating behaviour. - """ - self.ops.append(AttrOp(name)) - return self - - def __getattr__(self, name): - if name.startswith('__'): - raise AttributeError(name) - return self.attr(name) - - def __getitem__(self, name): - self.ops.append(ItemOp(name)) - return self - - def complete(self, name: str, type_: Type_, default: Any): - return self - - def process(self, obj): - """ - .. warning:: This must not block when used with an async context! - """ - for op in self.ops: - obj = op(obj) - if obj is missing: - obj = self.default - break - return obj - - -class Annotation(Requirement): - - def __init__(self, name: str, type_: Type_ = None, default: Any = missing): - if type_ is None: - keys = [ResourceKey(None, name)] - else: - keys = [ - ResourceKey(type_, name), - ResourceKey(type_, None), - ResourceKey(None, name), - ] - super().__init__(keys, default) - - def __repr__(self): - type_, name = self.keys[0] - r = name - if type_ is not None: - r += f': {type_repr(type_)}' - if self.default is not missing: - r += f' = {self.default!r}' - return r - - -class Value(Requirement): - """ - Declaration indicating that the specified resource key is required. - - Values are generative, so they can be used to indicate attributes or - items from a resource are required. - - A default may be specified, which will be used if the specified - resource is not available. - - A type may also be explicitly specified, but you probably shouldn't - ever use this. - """ - - def __init__(self, - key: Union[Type_, Identifier] = None, - identifier: Identifier = None, - default: Any = missing): - if identifier is None: - if key is None: - raise TypeError('type or identifier must be supplied') - resource_key = ResourceKey.guess(key) - else: - resource_key = ResourceKey(key, identifier) - super().__init__([resource_key], default) - - def _keys_repr(self): - return str(self.keys[0]) - - -class AnyOf(Requirement): - """ - A requirement that is resolved by any of the specified keys. - - A key may either be a :class:`type` or an :class:`Identifier` - """ - - def __init__(self, *keys: Union[Type_, Identifier], default: Any = missing): - if not keys: - raise TypeError('at least one key must be specified') - resource_keys = [] - for key in keys: - type_ = identifier = None - if is_type(key): - type_ = key - else: - identifier = key - resource_keys.append(ResourceKey(type_, identifier)) - super().__init__(resource_keys, default) - - def _keys_repr(self): - return ', '.join(str(key) for key in self.keys) - - -class Like(Requirement): - """ - A requirements that is resolved by the specified class or - any of its base classes. - """ - - def __init__(self, type_: type, default: Any = missing): - keys = [] - for type__ in type_.__mro__: - if type__ is object: - break - keys.append(ResourceKey(type__, None)) - super().__init__(keys, default) - - def _keys_repr(self): - return str(self.keys[0]) diff --git a/mush/resources.py b/mush/resources.py deleted file mode 100644 index 2a6805f..0000000 --- a/mush/resources.py +++ /dev/null @@ -1,84 +0,0 @@ -from types import FunctionType -from typing import Callable, Optional - -from .compat import _GenericAlias -from .markers import missing -from .typing import Resource, Identifier, Type_ - - -def type_repr(type_): - if isinstance(type_, type): - return type_.__qualname__ - elif isinstance(type_, FunctionType): - return type_.__name__ - else: - return repr(type_) - - -def is_type(obj): - return ( - isinstance(obj, (type, _GenericAlias)) or - (callable(obj) and hasattr(obj, '__supertype__')) - ) - - -class ResourceKey(tuple): - - def __new__(cls, type_: Type_ = None, identifier: Identifier = None): - return tuple.__new__(cls, (type_, identifier)) - - @classmethod - def guess(cls, key): - type_ = identifier = None - if is_type(key): - type_ = key - else: - identifier = key - return cls(type_, identifier) - - @property - def type(self) -> Type_: - return self[0] - - @property - def identifier(self) -> Identifier: - return self[1] - - def __str__(self): - type_ = self.type - if type_ is None: - return repr(self.identifier) - type_repr_ = type_repr(type_) - if self.identifier is None: - return type_repr_ - return f'{type_repr_}, {self.identifier!r}' - - def __repr__(self): - return f'ResourceKey({self})' - - -class ResourceValue: - - provider: Optional[Callable] = None - provides_subclasses: bool = False - - def __init__(self, obj: Resource): - self.obj = obj - - def __repr__(self): - return repr(self.obj) - - -class Provider(ResourceValue): - - def __init__(self, obj: Callable, *, cache: bool = True, provides_subclasses: bool = False): - super().__init__(missing) - self.provider = obj - self.cache = cache - self.provides_subclasses = provides_subclasses - - def __repr__(self): - obj_repr = '' if self.obj is missing else f'cached={self.obj!r}, ' - return (f'Provider({self.provider}, {obj_repr}' - f'cache={self.cache}, ' - f'provides_subclasses={self.provides_subclasses})') diff --git a/mush/runner.py b/mush/runner.py deleted file mode 100644 index 0698cc3..0000000 --- a/mush/runner.py +++ /dev/null @@ -1,332 +0,0 @@ -from typing import Callable, Optional - -from .callpoints import CallPoint -from .context import Context, ResourceError -from .declarations import DeclarationsFrom -from .extraction import extract_requires, extract_returns # , extract_returns -from .markers import not_specified -from .modifier import Modifier -from .plug import Plug -from .typing import Requires, Returns - - -class Runner(object): - """ - A chain of callables along with declarations of their required and - returned resources along with tools to manage the order in which they - will be called. - """ - - start: Optional[CallPoint] = None - end: Optional[CallPoint] = None - - def __init__(self, *objects: Callable): - self.labels = {} - self.extend(*objects) - - def add(self, obj: Callable, requires: Requires = None, returns: Returns = None, - label: str = None): - """ - Add a callable to the runner. - - :param obj: The callable to be added. - - :param requires: The resources to required as parameters when calling - ``obj``. These can be specified by passing a single - type, a string name or a :class:`requires` object. - - :param returns: The resources that ``obj`` will return. - These can be specified as a single - type, a string name or a :class:`returns`, - :class:`returns_mapping`, :class:`returns_sequence` - object. - - :param label: If specified, this is a string that adds a label to the - point where ``obj`` is added that can later be retrieved - with :meth:`Runner.__getitem__`. - """ - if isinstance(obj, Plug): - obj.add_to(self) - else: - m = Modifier(self, self.end, not_specified) - m.add(obj, requires, returns, label) - return m - - def add_label(self, label: str): - """ - Add a label to the the point currently at the end of the runner. - """ - m = Modifier(self, self.end, not_specified) - m.add_label(label) - return m - - def _copy_from(self, start_point, end_point, added_using=None): - - previous_cloned_point = self.end - point = start_point - - while point: - if added_using is None or added_using in point.added_using: - cloned_point = CallPoint(point.obj, point.requires, point.returns) - cloned_point.labels = set(point.labels) - for label in cloned_point.labels: - self.labels[label] = cloned_point - - if self.start is None: - self.start = cloned_point - - if previous_cloned_point: - previous_cloned_point.next = cloned_point - cloned_point.previous = previous_cloned_point - - previous_cloned_point = cloned_point - - point = point.next - if point and point.previous is end_point: - break - - self.end = previous_cloned_point - - def extend(self, *objs: Callable): - """ - Add the specified callables to this runner. - - If any of the objects passed is a :class:`Runner`, the contents of that - runner will be added to this runner. - """ - for obj in objs: - if isinstance(obj, Runner): - self._copy_from(obj.start, obj.end) - else: - self.add(obj) - - def clone(self, - start_label: str = None, end_label: str = None, - include_start: bool = False, include_end: bool = False, - added_using: str = None): - """ - Return a copy of this :class:`Runner`. - - :param start_label: - An optional string specifying the point at which to start cloning. - - :param end_label: - An optional string specifying the point at which to stop cloning. - - :param include_start: - If ``True``, the point specified in ``start_label`` will be included - in the cloned runner. - - :param include_end: - If ``True``, the point specified in ``end_label`` will be included - in the cloned runner. - - :param added_using: - An optional string specifying that only points added using the - label specified in this option should be cloned. - This filtering is applied in addition to the above options. - """ - runner = self.__class__() - - if start_label: - start = self.labels[start_label] - if not include_start: - start = start.next - else: - start = self.start - - if end_label: - end = self.labels[end_label] - if not include_end: - end = end.previous - else: - end = self.end - - # check start point is before end_point - if start is not None: - point = start.previous - else: - point = None - - while point: - if point is end: - return runner - point = point.previous - - runner._copy_from(start, end, added_using) - return runner - - def replace(self, - original: Callable, - replacement: Callable, - requires_from: DeclarationsFrom = DeclarationsFrom.replacement, - returns_from: DeclarationsFrom = DeclarationsFrom.original): - """ - Replace all instances of one callable with another. - - :param original: The callable to replaced. - - :param replacement: The callable use instead. - - :param requires_from: - - Which :class:`requires` to use. - If :attr:`~mush.declarations.DeclarationsFrom.original`, - the existing ones will be used. - If :attr:`~mush.declarations.DeclarationsFrom.replacement`, - they will be extracted from the supplied replacements. - - :param returns_from: - - Which :class:`returns` to use. - If :attr:`~mush.declarations.DeclarationsFrom.original`, - the existing ones will be used. - If :attr:`~mush.declarations.DeclarationsFrom.replacement`, - they will be extracted from the supplied replacements. - """ - point = self.start - while point: - if point.obj is original: - if requires_from is DeclarationsFrom.replacement: - requires = extract_requires(replacement) - else: - requires = extract_requires(point.obj, point.requires) - if returns_from is DeclarationsFrom.replacement: - returns = extract_returns(replacement) - else: - returns = extract_returns(point.obj, point.returns) - - new_point = CallPoint(replacement, requires, returns) - - if point.previous is None: - self.start = new_point - else: - point.previous.next = new_point - if point.next is None: - self.end = new_point - else: - point.next.previous = new_point - new_point.next = point.next - - for label in point.labels: - self.labels[label] = new_point - new_point.labels.add(label) - new_point.added_using = set(point.added_using) - - point = point.next - - def __getitem__(self, label: str): - """ - Retrieve a :class:`~.modifier.Modifier` for a previous labelled point in - the runner. - """ - return Modifier(self, self.labels[label], label) - - def __add__(self, other: 'Runner'): - """ - Return a new :class:`Runner` containing the contents of the two - :class:`Runner` instances being added together. - """ - runner = self.__class__() - for r in self, other: - runner._copy_from(r.start, r.end) - return runner - - def __call__(self, context: Context = None): - """ - Execute the callables in this runner in the required order - storing objects that are returned and providing them as - arguments or keyword parameters when required. - - A runner may be called multiple times. Each time a new - :class:`~.context.Context` will be created meaning that no required - objects are kept between calls and all callables will be - called each time. - - :param context: - Used for passing a context when context managers are used. - You should never need to pass this parameter. - """ - if context is None: - context = Context() - if context.point is None: - context.point = self.start - - result = None - - while context.point: - - point = context.point - context.point = point.next - - try: - result = point(context) - except ResourceError as e: - raise ContextError(str(e), point, context) - - if getattr(result, '__enter__', None): - with result as managed: - if managed is not None and managed is not result: - context.add(managed) - # If the context manager swallows an exception, - # None should be returned, not the context manager: - result = None - if context.point is not None: - result = self(context) - - return result - - def __repr__(self): - bits = [] - point = self.start - while point: - bits.append('\n ' + repr(point)) - point = point.next - if bits: - bits.append('\n') - return '%s' % ''.join(bits) - - -class ContextError(Exception): - """ - Errors likely caused by incorrect building of a runner. - """ - def __init__(self, text: str, point: CallPoint = None, context: Context = None): - self.text: str = text - self.point: CallPoint = point - self.context: Context = context - - def __str__(self): - rows = ['', ''] - if self.point: - already_called = [] - point = self.point.previous - while point: - already_called.append(repr(point)) - point = point.previous - if already_called: - rows.append('Already called:') - rows.extend(reversed(already_called)) - rows.append('') - - rows.append('While calling:') - rows.append(repr(self.point)) - rows.append('') - if self.context is not None: - rows.append(f'with {self.context!r}:') - rows.append('') - - rows.append(self.text) - - if self.point: - point = self.point.next - if point: - rows.append('') - rows.append('Still to call:') - while point: - rows.append(repr(point)) - point = point.next - - return '\n'.join(rows) - - __repr__ = __str__ diff --git a/mush/tests/__init__.py b/mush/tests/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/mush/tests/conftest.py b/mush/tests/conftest.py deleted file mode 100644 index 62cd59b..0000000 --- a/mush/tests/conftest.py +++ /dev/null @@ -1,9 +0,0 @@ -import sys -from re import search - - -def pytest_ignore_collect(path): - file_min_version_match = search(r'_py(\d)(\d)$', path.purebasename) - if file_min_version_match: - file_min_version = tuple(int(d) for d in file_min_version_match.groups()) - return sys.version_info < file_min_version diff --git a/mush/tests/example_with_mush_clone.py b/mush/tests/example_with_mush_clone.py deleted file mode 100644 index 6387d84..0000000 --- a/mush/tests/example_with_mush_clone.py +++ /dev/null @@ -1,71 +0,0 @@ -from argparse import ArgumentParser, Namespace -from configparser import RawConfigParser -from mush import Runner, requires, Value, returns -import logging, os, sqlite3, sys - -log = logging.getLogger() - -def base_options(parser: ArgumentParser): - parser.add_argument('config', help='Path to .ini file') - parser.add_argument('--quiet', action='store_true', - help='Log less to the console') - parser.add_argument('--verbose', action='store_true', - help='Log more to the console') - -def parse_args(parser: ArgumentParser) -> Namespace: - return parser.parse_args() - -@returns('config') -def parse_config(args: Namespace): - config = RawConfigParser() - config.read(args.config) - return dict(config.items('main')) - -def setup_logging(log_path, quiet=False, verbose=False): - handler = logging.FileHandler(log_path) - handler.setLevel(logging.DEBUG) - log.addHandler(handler) - if not quiet: - handler = logging.StreamHandler(sys.stderr) - handler.setLevel(logging.DEBUG if verbose else logging.INFO) - log.addHandler(handler) - -class DatabaseHandler: - def __init__(self, db_path): - self.conn = sqlite3.connect(db_path) - def __enter__(self): - return self - def __exit__(self, type, obj, tb): - if type: - log.exception('Something went wrong') - self.conn.rollback() - -base_runner = Runner(ArgumentParser) -base_runner.add(base_options, label='args') -base_runner.extend(parse_args, parse_config) -base_runner.add(setup_logging, requires( - log_path = Value('config')['log'], - quiet = Value(Namespace).quiet, - verbose = Value(Namespace).verbose, -)) - - -def args(parser): - parser.add_argument('path', help='Path to the file to process') - -def do(conn, path): - filename = os.path.basename(path) - with open(path) as source: - conn.execute('insert into notes values (?, ?)', - (filename, source.read())) - conn.commit() - log.info('Successfully added %r', filename) - -main = base_runner.clone() -main['args'].add(args, requires=ArgumentParser) -main.add(DatabaseHandler, requires=Value('config')['db']) -main.add(do, - requires(Value(DatabaseHandler).conn, Value(Namespace).path)) - -if __name__ == '__main__': - main() diff --git a/mush/tests/example_with_mush_factory.py b/mush/tests/example_with_mush_factory.py deleted file mode 100644 index 2757faa..0000000 --- a/mush/tests/example_with_mush_factory.py +++ /dev/null @@ -1,35 +0,0 @@ -from mush import Runner, requires, Value -from argparse import ArgumentParser, Namespace - -from .example_with_mush_clone import ( - DatabaseHandler, parse_args, parse_config, do, - setup_logging - ) - - -def options(parser): - parser.add_argument('config', help='Path to .ini file') - parser.add_argument('--quiet', action='store_true', - help='Log less to the console') - parser.add_argument('--verbose', action='store_true', - help='Log more to the console') - parser.add_argument('path', help='Path to the file to process') - -def make_runner(do): - runner = Runner(ArgumentParser) - runner.add(options, requires=ArgumentParser) - runner.add(parse_args, requires=ArgumentParser) - runner.add(parse_config, requires=Namespace) - runner.add(setup_logging, requires( - log_path=Value('config')['log'], - quiet=Value(Namespace).quiet, - verbose=Value(Namespace).verbose, - )) - runner.add(DatabaseHandler, requires=Value('config')['db']) - runner.add( - do, - requires(Value(DatabaseHandler).conn, Value(Namespace).path) - ) - return runner - -main = make_runner(do) diff --git a/mush/tests/example_without_mush.py b/mush/tests/example_without_mush.py deleted file mode 100644 index df454ac..0000000 --- a/mush/tests/example_without_mush.py +++ /dev/null @@ -1,44 +0,0 @@ -from argparse import ArgumentParser -from configparser import RawConfigParser -import logging, os, sqlite3, sys - -log = logging.getLogger() - -def main(): - parser = ArgumentParser() - parser.add_argument('config', help='Path to .ini file') - parser.add_argument('--quiet', action='store_true', - help='Log less to the console') - parser.add_argument('--verbose', action='store_true', - help='Log more to the console') - parser.add_argument('path', help='Path to the file to process') - - args = parser.parse_args() - - config = RawConfigParser() - config.read(args.config) - - handler = logging.FileHandler(config.get('main', 'log')) - handler.setLevel(logging.DEBUG) - log.addHandler(handler) - log.setLevel(logging.DEBUG) - - if not args.quiet: - handler = logging.StreamHandler(sys.stderr) - handler.setLevel(logging.DEBUG if args.verbose else logging.INFO) - log.addHandler(handler) - - conn = sqlite3.connect(config.get('main', 'db')) - - try: - filename = os.path.basename(args.path) - with open(args.path) as source: - conn.execute('insert into notes values (?, ?)', - (filename, source.read())) - conn.commit() - log.info('Successfully added %r', filename) - except: - log.exception('Something went wrong') - -if __name__ == '__main__': - main() diff --git a/mush/tests/helpers.py b/mush/tests/helpers.py deleted file mode 100644 index b89189d..0000000 --- a/mush/tests/helpers.py +++ /dev/null @@ -1,58 +0,0 @@ -import asyncio -import sys -from contextlib import contextmanager -from functools import partial - -from testfixtures.mock import Mock - - -PY_VERSION = sys.version_info[:2] - -PY_36 = PY_VERSION == (3, 6) - - -class Type1(object): pass -class Type2(object): pass -class Type3(object): pass -class Type4(object): pass - - -class TheType(object): - def __repr__(self): - return '' - - -@contextmanager -def no_threads(): - loop = asyncio.get_event_loop() - original = loop.run_in_executor - loop.run_in_executor = Mock(side_effect=Exception('threads used when they should not be')) - try: - yield - finally: - loop.run_in_executor = original - - -@contextmanager -def must_run_in_thread(*expected): - seen = set() - loop = asyncio.get_event_loop() - original = loop.run_in_executor - - def recording_run_in_executor(executor, func, *args): - if isinstance(func, partial): - to_record = func.func - else: - # get the underlying method for bound methods: - to_record = getattr(func, '__func__', func) - seen.add(to_record) - return original(executor, func, *args) - - loop.run_in_executor = recording_run_in_executor - try: - yield - finally: - loop.run_in_executor = original - - not_seen = set(expected) - seen - assert not not_seen, f'{not_seen} not run in a thread, seen: {seen}' diff --git a/mush/tests/test_async_context.py b/mush/tests/test_async_context.py deleted file mode 100644 index a3690a0..0000000 --- a/mush/tests/test_async_context.py +++ /dev/null @@ -1,331 +0,0 @@ -import asyncio -from functools import partial -from typing import Tuple - -import pytest -from testfixtures import compare, ShouldRaise - -from mush import requires, returns, Context as SyncContext, blocking, nonblocking -from mush.asyncio import Context -from mush.requirements import Requirement, AnyOf, Like -from .helpers import TheType, no_threads, must_run_in_thread -from ..markers import AsyncType -from ..resources import ResourceKey, Provider - - -@pytest.mark.asyncio -async def test_call_is_async(): - context = Context() - def it(): - return 'bar' - result = context.call(it) - assert asyncio.iscoroutine(result) - with must_run_in_thread(it): - compare(await result, expected='bar') - - -@pytest.mark.asyncio -async def test_call_async(): - context = Context() - context.add('1', identifier='a') - async def it(a, b='2'): - return a+b - with no_threads(): - compare(await context.call(it), expected='12') - - -@pytest.mark.asyncio -async def test_call_async_callable_object(): - context = Context() - - class AsyncCallable: - async def __call__(self): - return 42 - - with no_threads(): - compare(await context.call(AsyncCallable()), expected=42) - - -@pytest.mark.asyncio -async def test_call_partial_around_async(): - context = Context() - - async def it(): - return 42 - - with no_threads(): - compare(await context.call(partial(it)), expected=42) - - -@pytest.mark.asyncio -async def test_call_async_requires_async_context(): - context = Context() - async def baz(): - return 'bar' - async def it(context: Context): - return await context.call(baz) + 'bob' - compare(await context.call(it), expected='barbob') - - -@pytest.mark.asyncio -async def test_call_sync(): - context = Context() - context.add('foo', identifier='baz') - def it(*, baz): - return baz+'bar' - with must_run_in_thread(it): - compare(await context.call(it), expected='foobar') - - -@pytest.mark.asyncio -async def test_call_sync_requires_context(): - context = Context() - # NB: this is intentionally async to test calling async - # in a sync context: - async def baz(): - return 'bar' - # sync method, so needs a sync context: - def it(context: SyncContext): - return context.call(baz) + 'bob' - compare(await context.call(it), expected='barbob') - - -@pytest.mark.asyncio -async def test_async_provider_async_user(): - o = TheType() - lookup = {TheType: o} - async def provider(key: ResourceKey): - return lookup[key.type] - context = Context() - context.add(Provider(provider), provides=TheType) - async def returner(obj: TheType): - return obj - assert await context.call(returner) is o - - -@pytest.mark.asyncio -async def test_async_provider_sync_user(): - o = TheType() - lookup = {TheType: o} - async def provider(key: ResourceKey): - return lookup[key.type] - context = Context() - context.add(Provider(provider), provides=TheType) - def returner(obj: TheType): - return obj - assert await context.call(returner) is o - - -@pytest.mark.asyncio -async def test_sync_provider_async_user(): - o = TheType() - lookup = {TheType: o} - def provider(key: ResourceKey): - return lookup[key.type] - context = Context() - context.add(Provider(provider), provides=TheType) - async def returner(obj: TheType): - return obj - assert await context.call(returner) is o - - -@pytest.mark.asyncio -async def test_sync_provider_sync_user(): - o = TheType() - lookup = {TheType: o} - def provider(key: ResourceKey): - return lookup[key.type] - context = Context() - context.add(Provider(provider), provides=TheType) - def returner(obj: TheType): - return obj - assert await context.call(returner) is o - - -@pytest.mark.asyncio -async def test_call_class_defaults_to_non_blocking(): - context = Context() - with no_threads(): - obj = await context.call(TheType) - assert isinstance(obj, TheType) - - -@pytest.mark.asyncio -async def test_call_class_explicitly_marked_as_blocking(): - @blocking - class BlockingType: pass - context = Context() - with must_run_in_thread(BlockingType): - obj = await context.call(BlockingType) - assert isinstance(obj, BlockingType) - - -@pytest.mark.asyncio -async def test_call_function_defaults_to_blocking(): - def foo(): - return 42 - context = Context() - with must_run_in_thread(foo): - compare(await context.call(foo), expected=42) - - -@pytest.mark.asyncio -async def test_call_function_explicitly_marked_as_non_blocking(): - @nonblocking - def foo(): - return 42 - context = Context() - with no_threads(): - compare(await context.call(foo), expected=42) - - -@pytest.mark.asyncio -async def test_call_async_function_explicitly_marked_as_non_blocking(): - # sure, I mean, whatever... - @nonblocking - async def foo(): - return 42 - context = Context() - with no_threads(): - compare(await context.call(foo), expected=42) - - -@pytest.mark.asyncio -async def test_call_async_function_explicitly_marked_as_blocking(): - with ShouldRaise(TypeError('cannot mark an async function as blocking')): - @blocking - async def foo(): pass - - -@pytest.mark.asyncio -async def test_call_caches_asyncness(): - async def foo(): - return 42 - context = Context() - await context.call(foo) - compare(context._async_cache[foo], expected=AsyncType.async_) - - -@pytest.mark.asyncio -async def test_extract_is_async(): - context = Context() - def it(): - return 'bar' - result = context.extract(it, requires(), returns('baz')) - assert asyncio.iscoroutine(result) - compare(await result, expected='bar') - async def returner(baz): - return baz - compare(await context.call(returner), expected='bar') - - -@pytest.mark.asyncio -async def test_extract_async(): - context = Context() - async def bob(): - return 'foo' - async def it(context): - return await context.extract(bob)+'bar' - result = context.extract(it, requires(Context), returns('baz')) - compare(await result, expected='foobar') - async def returner(bob): - return bob - compare(await context.call(returner), expected='foo') - - -@pytest.mark.asyncio -async def test_extract_sync(): - context = Context() - # NB: this is intentionally async to test calling async - # in a sync context: - def bob(): - return 'foo' - def it(context): - return context.extract(bob)+'bar' - result = context.extract(it, requires(SyncContext), returns('baz')) - compare(await result, expected='foobar') - def returner(bob): - return bob - compare(await context.call(returner), expected='foo') - - -@pytest.mark.asyncio -async def test_extract_minimal(): - o = TheType() - def foo() -> TheType: - return o - context = Context() - result = await context.extract(foo) - assert result is o - async def returner(x: TheType): - return x - compare(await context.call(returner), expected=o) - - -@pytest.mark.asyncio -async def test_extract_maximal(): - def foo(*args): - return args - context = Context() - context.add('a') - result = await context.extract(foo, requires(str), returns(Tuple[str])) - compare(result, expected=('a',)) - async def returner(x: Tuple[str]): - return x - compare(await context.call(returner), expected=('a',)) - - -@pytest.mark.asyncio -async def test_value_resolve_does_not_run_in_thread(): - with no_threads(): - context = Context() - context.add('foo', identifier='baz') - - async def it(baz): - return baz+'bar' - - compare(await context.call(it), expected='foobar') - - -@pytest.mark.asyncio -async def test_anyof_resolve_does_not_run_in_thread(): - with no_threads(): - context = Context() - context.add(('foo', )) - - async def bob(x: str = AnyOf(tuple, Tuple[str])): - return x[0] - - compare(await context.call(bob), expected='foo') - - -@pytest.mark.asyncio -async def test_like_resolve_does_not_run_in_thread(): - with no_threads(): - o = TheType() - context = Context() - context.add(o) - - async def bob(x: str = Like(TheType)): - return x - - assert await context.call(bob) is o - - -@pytest.mark.asyncio -async def test_default_custom_requirement(): - - class FromRequest(Requirement): - def __init__(self, name, type_, default): - self.name = name - self.type = type_ - super().__init__(keys=[ResourceKey(identifier='request')], default=default) - def process(self, obj): - return self.type(obj[self.name]) - - def foo(bar: int): - return bar - - context = Context(FromRequest) - context.add({'bar': '42'}, identifier='request') - compare(await context.call(foo), expected=42) diff --git a/mush/tests/test_async_runner.py b/mush/tests/test_async_runner.py deleted file mode 100644 index dad3a50..0000000 --- a/mush/tests/test_async_runner.py +++ /dev/null @@ -1,600 +0,0 @@ - -import asyncio -from testfixtures.mock import Mock, call - -import pytest -from testfixtures import compare, ShouldRaise, Comparison as C - -from mush import ContextError, requires, returns -from mush.asyncio import Runner, Context -from .helpers import no_threads, must_run_in_thread - - -@pytest.mark.asyncio -async def test_call_is_async(): - def it(): - return 'bar' - runner = Runner(it) - result = runner() - assert asyncio.iscoroutine(result) - with must_run_in_thread(it): - compare(await result, expected='bar') - - -@pytest.mark.asyncio -async def test_resource_missing(): - def it(foo): - pass - runner = Runner(it) - context = Context() - with ShouldRaise(ContextError(C(str), runner.start, context)): - await runner(context) - - -@pytest.mark.asyncio -async def test_cloned_still_async(): - def it(): - return 'bar' - runner = Runner(it) - runner_ = runner.clone() - result = runner_() - assert asyncio.iscoroutine(result) - compare(await result, expected='bar') - - -@pytest.mark.asyncio -async def test_addition_still_async(): - async def foo(): - return 'foo' - @returns() - async def bar(foo: str): - return foo+'bar' - r1 = Runner(foo) - r2 = Runner(bar) - runner = r1 + r2 - result = runner() - assert asyncio.iscoroutine(result) - compare(await result, expected='foobar') - - -class CommonCM: - m = None - context = None - swallow_exceptions = None - - -class AsyncCM(CommonCM): - - async def __aenter__(self): - self.m.enter() - if self.context == 'self': - return self - return self.context - - async def __aexit__(self, type, obj, tb): - self.m.exit(obj) - return self.swallow_exceptions - - -class SyncCM(CommonCM): - - def __enter__(self): - self.m.enter() - if self.context == 'self': - return self - return self.context - - def __exit__(self, type, obj, tb): - self.m.exit(obj) - return self.swallow_exceptions - - -def make_cm(name, type_, m, context=None, swallow_exceptions=None): - return type(name, - (type_,), - {'m': getattr(m, name.lower()), - 'context': context, - 'swallow_exceptions': swallow_exceptions}) - - -@pytest.mark.asyncio -async def test_async_context_manager(): - m = Mock() - CM = make_cm('CM', AsyncCM, m) - - async def func(): - m.func() - - runner = Runner(CM, func) - - with no_threads(): - await runner() - - compare(m.mock_calls, expected=[ - call.cm.enter(), - call.func(), - call.cm.exit(None) - ]) - - -@pytest.mark.asyncio -async def test_async_context_manager_inner_requires_cm(): - m = Mock() - CM = make_cm('CM', AsyncCM, m, context='self') - - @requires(CM) - async def func(obj): - m.func(type(obj)) - - runner = Runner(CM, func) - - with no_threads(): - await runner() - - compare(m.mock_calls, expected=[ - call.cm.enter(), - call.func(CM), - call.cm.exit(None) - ]) - - -@pytest.mark.asyncio -async def test_async_context_manager_inner_requires_context(): - m = Mock() - class CMContext: pass - cm_context = CMContext() - CM = make_cm('CM', AsyncCM, m, context=cm_context) - - @requires(CMContext) - async def func(obj): - m.func(obj) - - runner = Runner(CM, func) - - with no_threads(): - await runner() - - compare(m.mock_calls, expected=[ - call.cm.enter(), - call.func(cm_context), - call.cm.exit(None) - ]) - - -@pytest.mark.asyncio -async def test_async_context_manager_nested(): - m = Mock() - CM1 = make_cm('CM1', AsyncCM, m) - CM2 = make_cm('CM2', AsyncCM, m) - - async def func(): - m.func() - - runner = Runner(CM1, CM2, func) - - with no_threads(): - await runner() - - compare(m.mock_calls, expected=[ - call.cm1.enter(), - call.cm2.enter(), - call.func(), - call.cm2.exit(None), - call.cm1.exit(None), - ]) - - -@pytest.mark.asyncio -async def test_async_context_manager_nested_exception_inner_handles(): - m = Mock() - CM1 = make_cm('CM1', AsyncCM, m) - CM2 = make_cm('CM2', AsyncCM, m, swallow_exceptions=True) - - e = Exception() - async def func(): - raise e - - runner = Runner(CM1, CM2, func) - - with no_threads(): - await runner() - - compare(m.mock_calls, expected=[ - call.cm1.enter(), - call.cm2.enter(), - call.cm2.exit(e), - call.cm1.exit(None), - ]) - - -@pytest.mark.asyncio -async def test_async_context_manager_nested_exception_outer_handles(): - m = Mock() - CM1 = make_cm('CM1', AsyncCM, m, swallow_exceptions=True) - CM2 = make_cm('CM2', AsyncCM, m) - - e = Exception() - async def func(): - raise e - - runner = Runner(CM1, CM2, func) - - with no_threads(): - await runner() - - compare(m.mock_calls, expected=[ - call.cm1.enter(), - call.cm2.enter(), - call.cm2.exit(e), - call.cm1.exit(e), - ]) - - -@pytest.mark.asyncio -async def test_async_context_manager_exception_not_handled(): - m = Mock() - CM = make_cm('CM', AsyncCM, m) - - e = Exception('foo') - - async def func(): - raise e - - runner = Runner(CM, func) - - with no_threads(), ShouldRaise(e): - await runner() - - compare(m.mock_calls, expected=[ - call.cm.enter(), - call.cm.exit(e) - ]) - - -@pytest.mark.asyncio -async def test_sync_context_manager(): - m = Mock() - CM = make_cm('CM', SyncCM, m) - - async def func(): - m.func() - - runner = Runner(CM, func) - - with must_run_in_thread(CM.__enter__, CM.__exit__): - await runner() - - compare(m.mock_calls, expected=[ - call.cm.enter(), - call.func(), - call.cm.exit(None) - ]) - - -@pytest.mark.asyncio -async def test_sync_context_manager_inner_requires_cm(): - m = Mock() - CM = make_cm('CM', SyncCM, m, context='self') - - @requires(CM) - async def func(obj): - m.func(type(obj)) - - runner = Runner(CM, func) - - with must_run_in_thread(CM.__enter__, CM.__exit__): - await runner() - - compare(m.mock_calls, expected=[ - call.cm.enter(), - call.func(CM), - call.cm.exit(None) - ]) - - -@pytest.mark.asyncio -async def test_sync_context_manager_inner_requires_context(): - m = Mock() - class CMContext: pass - cm_context = CMContext() - CM = make_cm('CM', SyncCM, m, context=cm_context) - - @requires(CMContext) - async def func(obj): - m.func(obj) - - runner = Runner(CM, func) - - with must_run_in_thread(CM.__enter__, CM.__exit__): - await runner() - - compare(m.mock_calls, expected=[ - call.cm.enter(), - call.func(cm_context), - call.cm.exit(None) - ]) - - -@pytest.mark.asyncio -async def test_sync_context_manager_nested(): - m = Mock() - CM1 = make_cm('CM1', SyncCM, m) - CM2 = make_cm('CM2', SyncCM, m) - - async def func(): - m.func() - - runner = Runner(CM1, CM2, func) - - with must_run_in_thread(CM1.__enter__, CM1.__exit__, CM2.__enter__, CM2.__exit__): - await runner() - - compare(m.mock_calls, expected=[ - call.cm1.enter(), - call.cm2.enter(), - call.func(), - call.cm2.exit(None), - call.cm1.exit(None), - ]) - - -@pytest.mark.asyncio -async def test_sync_context_manager_nested_exception_inner_handles(): - m = Mock() - CM1 = make_cm('CM1', SyncCM, m) - CM2 = make_cm('CM2', SyncCM, m, swallow_exceptions=True) - - e = Exception() - async def func(): - raise e - - runner = Runner(CM1, CM2, func) - - with must_run_in_thread(CM1.__enter__, CM1.__exit__, CM2.__enter__, CM2.__exit__): - await runner() - - compare(m.mock_calls, expected=[ - call.cm1.enter(), - call.cm2.enter(), - call.cm2.exit(e), - call.cm1.exit(None), - ]) - - -@pytest.mark.asyncio -async def test_sync_context_manager_nested_exception_outer_handles(): - m = Mock() - CM1 = make_cm('CM1', SyncCM, m, swallow_exceptions=True) - CM2 = make_cm('CM2', SyncCM, m) - - e = Exception() - async def func(): - raise e - - runner = Runner(CM1, CM2, func) - - with must_run_in_thread(CM1.__enter__, CM1.__exit__, CM2.__enter__, CM2.__exit__): - await runner() - - compare(m.mock_calls, expected=[ - call.cm1.enter(), - call.cm2.enter(), - call.cm2.exit(e), - call.cm1.exit(e), - ]) - - -@pytest.mark.asyncio -async def test_sync_context_manager_exception_not_handled(): - m = Mock() - CM = make_cm('CM', SyncCM, m) - - e = Exception('foo') - - async def func(): - raise e - - runner = Runner(CM, func) - - with must_run_in_thread(CM.__enter__, CM.__exit__), ShouldRaise(e): - await runner() - - compare(m.mock_calls, expected=[ - call.cm.enter(), - call.cm.exit(e) - ]) - -@pytest.mark.asyncio -async def test_sync_context_then_async_context(): - m = Mock() - CM1 = make_cm('CM1', SyncCM, m) - CM2 = make_cm('CM2', AsyncCM, m) - - async def func(): - return 42 - - runner = Runner(CM1, CM2, func) - - compare(await runner(), expected=42) - - compare(m.mock_calls, expected=[ - call.cm1.enter(), - call.cm2.enter(), - call.cm2.exit(None), - call.cm1.exit(None), - ]) - - -@pytest.mark.asyncio -async def test_async_context_then_sync_context(): - m = Mock() - CM1 = make_cm('CM1', AsyncCM, m) - CM2 = make_cm('CM2', SyncCM, m) - - async def func(): - return 42 - - runner = Runner(CM1, CM2, func) - - compare(await runner(), expected=42) - - compare(m.mock_calls, expected=[ - call.cm1.enter(), - call.cm2.enter(), - call.cm2.exit(None), - call.cm1.exit(None), - ]) - - -@pytest.mark.asyncio -async def test_sync_context_then_async_context_exception_handled_inner(): - m = Mock() - CM1 = make_cm('CM1', SyncCM, m) - CM2 = make_cm('CM2', AsyncCM, m, swallow_exceptions=True) - - e = Exception() - async def func(): - raise e - - runner = Runner(CM1, CM2, func) - - # if something goes wrong *and handled by a CM*, you get None - compare(await runner(), expected=None) - - compare(m.mock_calls, expected=[ - call.cm1.enter(), - call.cm2.enter(), - call.cm2.exit(e), - call.cm1.exit(None), - ]) - - -@pytest.mark.asyncio -async def test_sync_context_then_async_context_exception_handled_outer(): - m = Mock() - CM1 = make_cm('CM1', SyncCM, m, swallow_exceptions=True) - CM2 = make_cm('CM2', AsyncCM, m) - - e = Exception() - async def func(): - raise e - - runner = Runner(CM1, CM2, func) - - # if something goes wrong *and handled by a CM*, you get None - compare(await runner(), expected=None) - - compare(m.mock_calls, expected=[ - call.cm1.enter(), - call.cm2.enter(), - call.cm2.exit(e), - call.cm1.exit(e), - ]) - - -@pytest.mark.asyncio -async def test_sync_context_then_async_context_exception_not_handled(): - m = Mock() - CM1 = make_cm('CM1', SyncCM, m) - CM2 = make_cm('CM2', AsyncCM, m) - - e = Exception('foo') - - async def func(): - raise e - - runner = Runner(CM1, CM2, func) - - with ShouldRaise(e): - await runner() - - compare(m.mock_calls, expected=[ - call.cm1.enter(), - call.cm2.enter(), - call.cm2.exit(e), - call.cm1.exit(e), - ]) - - -@pytest.mark.asyncio -async def test_async_context_then_sync_context_exception_handled_inner(): - m = Mock() - CM1 = make_cm('CM1', AsyncCM, m) - CM2 = make_cm('CM2', SyncCM, m, swallow_exceptions=True) - - e = Exception() - async def func(): - raise e - - runner = Runner(CM1, CM2, func) - - # if something goes wrong *and handled by a CM*, you get None - compare(await runner(), expected=None) - - compare(m.mock_calls, expected=[ - call.cm1.enter(), - call.cm2.enter(), - call.cm2.exit(e), - call.cm1.exit(None), - ]) - - -@pytest.mark.asyncio -async def test_async_context_then_sync_context_exception_handled_outer(): - m = Mock() - CM1 = make_cm('CM1', AsyncCM, m, swallow_exceptions=True) - CM2 = make_cm('CM2', SyncCM, m) - - e = Exception() - async def func(): - raise e - - runner = Runner(CM1, CM2, func) - - # if something goes wrong *and handled by a CM*, you get None - compare(await runner(), expected=None) - - compare(m.mock_calls, expected=[ - call.cm1.enter(), - call.cm2.enter(), - call.cm2.exit(e), - call.cm1.exit(e), - ]) - - -@pytest.mark.asyncio -async def test_async_context_then_sync_context_exception_not_handled(): - m = Mock() - CM1 = make_cm('CM1', AsyncCM, m) - CM2 = make_cm('CM2', SyncCM, m) - - e = Exception('foo') - - async def func(): - raise e - - runner = Runner(CM1, CM2, func) - - with ShouldRaise(e): - await runner() - - compare(m.mock_calls, expected=[ - call.cm1.enter(), - call.cm2.enter(), - call.cm2.exit(e), - call.cm1.exit(e), - ]) - - -@pytest.mark.asyncio -async def test_context_manager_is_last_callpoint(): - m = Mock() - CM = make_cm('CM', AsyncCM, m) - - runner = Runner(CM) - - compare(await runner(), expected=None) - compare(m.mock_calls, expected=[ - call.cm.enter(), - call.cm.exit(None), - ]) diff --git a/mush/tests/test_callpoints.py b/mush/tests/test_callpoints.py deleted file mode 100644 index 911843e..0000000 --- a/mush/tests/test_callpoints.py +++ /dev/null @@ -1,103 +0,0 @@ -from functools import partial - -from testfixtures import compare -from testfixtures.mock import Mock, call -import pytest - -from mush.callpoints import CallPoint -from mush.declarations import ( - requires, returns, RequirementsDeclaration, ReturnsDeclaration, update_wrapper -) -from mush.requirements import Value - - -@pytest.fixture() -def context(): - return Mock() - - -class TestCallPoints: - - def test_passive_attributes(self): - # these are managed by Modifiers - point = CallPoint(Mock()) - compare(point.previous, None) - compare(point.next, None) - compare(point.labels, set()) - - def test_supplied_explicitly(self, context): - def foo(a1): pass - rq = requires('foo') - rt = returns('bar') - result = CallPoint(foo, rq, rt)(context) - compare(result, context.extract.return_value) - compare(context.extract.mock_calls, - expected=[call(foo, rq, rt)]) - - def test_extract_from_decorations(self, context): - rq = requires('foo') - rt = returns('bar') - - @rq - @rt - def foo(a1): pass - - result = CallPoint(foo)(context) - compare(result, context.extract.return_value) - compare(context.extract.mock_calls, - expected=[call(foo, None, None)]) - - def test_extract_from_decorated_class(self, context): - - rq = requires('foo') - rt = returns('bar') - - class Wrapper(object): - def __init__(self, func): - self.func = func - def __call__(self): - return self.func('the ') - - def my_dec(func): - return update_wrapper(Wrapper(func), func) - - @my_dec - @rq - @rt - def foo(prefix): - return prefix+'answer' - - context.extract.side_effect = lambda func, rq, rt: (func(), rq, rt) - result = CallPoint(foo)(context) - compare(result, expected=('the answer', None, None)) - - def test_repr_minimal(self): - def foo(): pass - point = CallPoint(foo) - compare("TestCallPoints.test_repr_minimal..foo requires() returns('foo')", - actual=repr(point)) - - def test_repr_partial(self): - def foo(): pass - point = CallPoint(partial(foo)) - compare(f"functools.partial({foo!r}) requires() returns('foo')", - actual=repr(point)) - - def test_repr_maximal(self): - def foo(a1): pass - point = CallPoint(foo, requires('foo'), returns('bar')) - point.labels.add('baz') - point.labels.add('bob') - compare("TestCallPoints.test_repr_maximal..foo " - "requires(Value('foo')) returns('bar') <-- baz, bob", - actual=repr(point)) - - def test_convert_to_requires_and_returns(self): - def foo(baz): pass - point = CallPoint(foo, requires='foo', returns='bar') - # this is deferred until later - assert isinstance(point.requires, str) - assert isinstance(point.returns, str) - compare("TestCallPoints.test_convert_to_requires_and_returns..foo " - "requires(Value('foo')) returns('bar')", - expected=repr(point)) diff --git a/mush/tests/test_context.py b/mush/tests/test_context.py deleted file mode 100644 index 6df2d6c..0000000 --- a/mush/tests/test_context.py +++ /dev/null @@ -1,776 +0,0 @@ -from functools import partial -from typing import NewType, Mapping, Any, Tuple - -import pytest -from testfixtures import ShouldRaise, compare -from testfixtures.mock import Mock - -from mush import Context, Requirement, Value, requires, missing -from mush.context import ResourceError -from .helpers import TheType, Type1, Type2 -from ..compat import PY_37_PLUS -from ..declarations import ignore_return -from ..resources import ResourceValue, Provider, ResourceKey - - -class TestAdd: - - def test_by_inferred_type(self): - obj = TheType() - context = Context() - context.add(obj) - - compare(context._store, expected={(TheType, None): ResourceValue(obj)}) - expected = ( - "\n" - "}>" - ) - compare(expected, actual=repr(context)) - compare(expected, actual=str(context)) - - def test_by_identifier(self): - obj = TheType() - context = Context() - context.add(obj, identifier='my label') - - compare(context._store, expected={ - (TheType, 'my label'): ResourceValue(obj), - (None, 'my label'): ResourceValue(obj), - }) - expected = ("\n" - " TheType, 'my label': \n" - "}>") - compare(expected, actual=repr(context)) - compare(expected, actual=str(context)) - - def test_by_identifier_only(self): - obj = TheType() - context = Context() - context.add(obj, provides=None, identifier='my label') - - compare(context._store, expected={(None, 'my label'): ResourceValue(obj)}) - expected = ("\n" - "}>") - compare(expected, actual=repr(context)) - compare(expected, actual=str(context)) - - def test_explicit_type(self): - obj = TheType() - context = Context() - context.add(obj, provides=Type2) - compare(context._store, expected={(Type2, None): ResourceValue(obj)}) - expected = ("\n" - "}>") - compare(expected, actual=repr(context)) - compare(expected, actual=str(context)) - - def test_clash_just_type(self): - obj1 = TheType() - obj2 = TheType() - context = Context() - context.add(obj1, TheType) - with ShouldRaise(ResourceError(f'Context already contains TheType')): - context.add(obj2, TheType) - - def test_clash_just_identifier(self): - obj1 = TheType() - obj2 = TheType() - context = Context() - context.add(obj1, provides=None, identifier='my label') - with ShouldRaise(ResourceError("Context already contains 'my label'")): - context.add(obj2, provides=None, identifier='my label') - - def test_clash_identifier_only_with_identifier_plus_type(self): - obj1 = TheType() - obj2 = TheType() - context = Context() - context.add(obj1, provides=None, identifier='my label') - with ShouldRaise(ResourceError("Context already contains 'my label'")): - context.add(obj2, identifier='my label') - - def test_clash_identifier_plus_type_with_identifier_only(self): - obj1 = TheType() - obj2 = TheType() - context = Context() - context.add(obj1, identifier='my label') - with ShouldRaise(ResourceError("Context already contains 'my label'")): - context.add(obj2, provides=None, identifier='my label') - - -class TestCall: - - def test_no_params(self): - def foo(): - return 'bar' - context = Context() - result = context.call(foo) - compare(result, 'bar') - - def test_type_from_annotation(self): - def foo(baz: str): - return baz - context = Context() - context.add('bar') - result = context.call(foo) - compare(result, expected='bar') - - def test_identifier_from_annotation(self): - def foo(baz: str): - return baz - context = Context() - context.add('bar', provides=str) - context.add('bob', identifier='baz') - result = context.call(foo) - compare(result, expected='bob') - - def test_by_identifier_only(self): - def foo(param): - return param - - context = Context() - context.add('bar', identifier='param') - result = context.call(foo) - compare(result, 'bar') - - def test_requires_missing(self): - def foo(obj: TheType): return obj - context = Context() - with ShouldRaise(ResourceError( - "obj: TheType could not be satisfied" - )): - context.call(foo) - - def test_optional_type_present(self): - def foo(x: TheType = 1): - return x - context = Context() - context.add(2, TheType) - result = context.call(foo) - compare(result, 2) - - def test_optional_type_missing(self): - def foo(x: TheType = 1): - return x - context = Context() - result = context.call(foo) - compare(result, 1) - - def test_optional_identifier_present(self): - def foo(x=1): - return x - - context = Context() - context.add(2, identifier='x') - result = context.call(foo) - compare(result, 2) - - def test_optional_identifier_missing(self): - def foo(x=1): - return x - - context = Context() - context.add(2) - result = context.call(foo) - compare(result, 1) - - def test_requires_context(self): - context = Context() - - def return_context(context_: Context): - return context_ - - assert context.call(return_context) is context - - def test_base_class_should_not_match(self): - def foo(obj: TheType): return obj - context = Context() - context.add(object()) - with ShouldRaise(ResourceError( - "obj: TheType could not be satisfied" - )): - context.call(foo) - - def test_requires_typing(self): - Request = NewType('Request', dict) - context = Context() - request = {} - context.add(request, provides=Request) - - def returner(request_: Request): - return request_ - - assert context.call(returner) is request - - def test_requires_typing_missing_typing(self): - context = Context() - - def returner(request_: Mapping[str, Any]): - pass - - if PY_37_PLUS: - expected = "request_: typing.Mapping[str, typing.Any] could not be satisfied" - else: - expected = "request_: Mapping could not be satisfied" - - with ShouldRaise(ResourceError(expected)): - context.call(returner) - - def test_requires_typing_missing_new_type(self): - Request = NewType('Request', dict) - context = Context() - - def returner(request_: Request): - pass - - with ShouldRaise(ResourceError( - "request_: Request could not be satisfied" - )): - context.call(returner) - - def test_requires_requirement(self): - context = Context() - - def foo(requirement: Requirement): pass - - with ShouldRaise(ResourceError( - "requirement: Requirement could not be satisfied" - )): - context.call(foo) - - def test_keyword_only(self): - def foo(*, x: int): - return x - - context = Context() - context.add(2) - result = context.call(foo) - compare(result, expected=2) - - def test_call_requires_string(self): - def foo(obj): - return obj - context = Context() - context.add('bar', identifier='baz') - result = context.call(foo, requires('baz')) - compare(result, expected='bar') - - def test_call_requires_type(self): - def foo(obj): - return obj - context = Context() - context.add('bar', TheType) - result = context.call(foo, requires(TheType)) - compare(result, 'bar') - - def test_call_requires_optional_override_source_and_default(self): - def foo(x=1): - return x - context = Context() - context.add(2, provides='x') - result = context.call(foo, requires(x=Value('y', default=3))) - compare(result, expected=3) - - def test_kw_parameter(self): - def foo(x, y): - return x, y - context = Context() - context.add('foo', TheType) - context.add('bar', identifier='baz') - result = context.call(foo, requires(y='baz', x=TheType)) - compare(result, expected=('foo', 'bar')) - - def test_custom_requirement(self): - - class FromRequest(Requirement): - - def __init__(self, name): - super().__init__([ResourceKey(identifier='request')]) - self.name = name - - def process(self, obj): - # this example doesn't show it, but this is a method so - # there can be conditional stuff in here: - return obj.get(self.name, missing) - - def foo(bar: str): - return bar - - context = Context() - context.add({'bar': 'foo'}, identifier='request') - compare(context.call(foo, requires=FromRequest('bar')), expected='foo') - with ShouldRaise(ResourceError( - "FromRequest(ResourceKey('request'), name='baz') could not be satisfied" - )): - context.call(foo, requires=FromRequest('baz')) - - -class TestOps: - - def test_call_requires_item(self): - def foo(x: str = Value(identifier='foo')['bar']): - return x - context = Context() - context.add(dict(bar='baz'), identifier='foo') - result = context.call(foo) - compare(result, expected='baz') - - def test_call_requires_item_missing(self): - def foo(obj: str = Value(dict)['foo']): pass - context = Context() - context.add({}) - with ShouldRaise(ResourceError( - "Value(dict)['foo'] could not be satisfied", - )): - context.call(foo) - - def test_call_requires_optional_item_missing(self): - def foo(x: str = Value('foo', default=1)['bar']): - return x - context = Context() - result = context.call(foo) - compare(result, expected=1) - - def test_call_requires_optional_item_present(self): - def foo(x: str = Value('foo', default=1)['bar']): - return x - context = Context() - context.add(dict(bar='baz'), identifier='foo') - result = context.call(foo) - compare(result, expected='baz') - - def test_call_requires_attr(self): - @requires(Value('foo').bar) - def foo(x): - return x - m = Mock() - context = Context() - context.add(m, identifier='foo') - result = context.call(foo) - compare(result, m.bar) - - def test_call_requires_attr_missing(self): - @requires(Value('foo').bar) - def foo(x): - pass - o = object() - context = Context() - context.add(o, identifier='foo') - with ShouldRaise(ResourceError( - "Value('foo').bar could not be satisfied", - )): - context.call(foo) - - def test_call_requires_optional_attr_missing(self): - @requires(Value('foo', default=1).bar) - def foo(x): - return x - o = object() - context = Context() - context.add(o, identifier='foo') - result = context.call(foo) - compare(result, expected=1) - - def test_call_requires_optional_attr_present(self): - @requires(Value('foo', default=1).bar) - def foo(x): - return x - m = Mock() - context = Context() - context.add(m, identifier='foo') - result = context.call(foo) - compare(result, expected=m.bar) - - def test_call_requires_item_attr(self): - @requires(Value('foo').bar['baz']) - def foo(x): - return x - m = Mock() - m.bar = dict(baz='bob') - context = Context() - context.add(m, identifier='foo') - result = context.call(foo) - compare(result, expected='bob') - - -class TestExtract: - - def test_extract_minimal(self): - o = TheType() - def foo(): - return o - context = Context() - result = context.extract(foo) - assert result is o - compare({ResourceKey(identifier='foo'): ResourceValue(o)}, actual=context._store) - - def test_extract_maximal(self): - def foo(o: str) -> Tuple[str, ...]: - return o, o - context = Context() - context.add('a') - result = context.extract(foo) - compare(result, expected=('a', 'a')) - compare({ - ResourceKey(str): ResourceValue('a'), - ResourceKey(identifier='foo'): ResourceValue(result), - ResourceKey(Tuple[str, ...], 'foo'): ResourceValue(result), - ResourceKey(Tuple[str, ...]): ResourceValue(result), - }, actual=context._store) - - def test_ignore_return(self): - @ignore_return - def foo(): - return 'bar' - context = Context() - result = context.extract(foo) - compare(result, 'bar') - compare({}, context._store) - - def test_returns_none(self): - def foo(): pass - context = Context() - result = context.extract(foo) - compare(result, expected=None) - compare(context._store, expected={ - ResourceKey(identifier='foo'): ResourceValue(None), - }) - - -class TestProviders: - - def test_cached(self): - items = [] - - def provider(): - items.append(1) - return sum(items) - - context = Context() - context.add(Provider(provider), provides=int) - - def returner(obj: int): - return obj - - compare(context.call(returner), expected=1) - compare(context.call(returner), expected=1) - - def test_not_cached(self): - items = [] - - def provider(): - items.append(1) - return sum(items) - - context = Context() - context.add(Provider(provider, cache=False), provides=int) - - def returner(obj: int): - return obj - - compare(context.call(returner), expected=1) - compare(context.call(returner), expected=2) - - def test_needs_resources(self): - def provider(start: int): - return start*2 - - context = Context() - context.add(Provider(provider), provides=int) - context.add(4, identifier='start') - - def returner(obj: int): - return obj - - compare(context.call(returner), expected=8) - - def test_needs_requirement(self): - def provider(requirement: Requirement): - return requirement.keys[0].identifier - - context = Context() - context.add(Provider(provider), provides=str) - - def returner(obj: str): - return obj - - compare(context.call(returner), expected='obj') - - def test_needs_resource_key(self): - def provider(key: ResourceKey): - return key.type, key.identifier - - context = Context() - context.add(Provider(provider), provides=tuple) - - def returner(obj: tuple): - return obj - - compare(context.call(returner), expected=(tuple, 'obj')) - - def test_provides_subclasses(self): - class Base: pass - - class TheType(Base): pass - - def provider(requirement: Requirement): - return requirement.keys[0].type() - - def foo(bar: TheType): - return bar - - context = Context() - context.add(Provider(provider, provides_subclasses=True), provides=Base) - - assert isinstance(context.call(foo), TheType) - - def test_provides_subclasses_caching(self): - class Base: pass - class Type1(Base): pass - class Type2(Base): pass - - t1 = Type1() - t2 = Type2() - instances = {Type1: t1, Type2: t2} - - def provider(requirement: Requirement): - # .pop so each instance can only be obtained once! - return instances.pop(requirement.keys[0].type) - - def foo(bar): - return bar - - context = Context() - context.add(Provider(provider, cache=True, provides_subclasses=True), provides=Base) - - assert context.call(foo, requires=Type1) is t1 - # cached: - assert context.call(foo, requires=Type1) is t1 - assert context.call(foo, requires=Type2) is t2 - assert context.call(foo, requires=Type2) is t2 - - def test_does_not_provide_subclasses(self): - def foo(obj: TheType): pass - - context = Context() - context.add(Provider(lambda: None), provides=object) - - with ShouldRaise(ResourceError( - "obj: TheType could not be satisfied" - )): - context.call(foo) - - def test_multiple_providers_using_requirement(self): - def provider(requirement: Requirement): - return requirement.keys[0].type() - - def foo(t1: Type1, t2: Type2): - return t1, t2 - - context = Context() - context.add(Provider(provider), provides=Type1) - context.add(Provider(provider), provides=Type2) - - t1, t2 = context.call(foo) - assert isinstance(t1, Type1) - assert isinstance(t2, Type2) - - def test_nested_providers_using_requirement(self): - class Base1: pass - - class Type1(Base1): pass - - def provider1(requirement: Requirement): - return requirement.keys[0].type() - - class Base2: - def __init__(self, x): - self.x = x - - class Type2(Base2): pass - - # order here is important - def provider2(t1: Type1, requirement: Requirement): - return requirement.keys[0].type(t1) - - def foo(t2: Type2): - return t2 - - context = Context() - context.add(Provider(provider1, provides_subclasses=True), provides=Base1) - context.add(Provider(provider2, provides_subclasses=True), provides=Base2) - - t2 = context.call(foo) - assert isinstance(t2, Type2) - assert isinstance(t2.x, Type1) - - def test_from_return_type_annotation(self): - def provider() -> Type1: - return Type1() - - context = Context() - context.add(Provider(provider)) - - def returner(obj: Type1): - return obj - - assert isinstance(context.call(returner), Type1) - - def test_no_provides(self): - provider = Mock() - context = Context() - with ShouldRaise(ResourceError( - f'Could not determine what is provided by ' - f'Provider(functools.partial({provider}), cache=True, provides_subclasses=False)' - )): - context.add(Provider(partial(provider))) - - def test_identifier(self): - def provider() -> str: - return 'some foo' - - context = Context() - context.add(Provider(provider), identifier='param') - - def foo(param): - return param - - compare(context.call(foo), expected='some foo') - - def test_identifier_only(self): - def provider(): - return 'some foo' - - context = Context() - context.add(Provider(provider), identifier='param') - - def foo(param): - return param - - compare(context.call(foo), expected='some foo') - - def test_minimal_representation(self): - def provider(): pass - context = Context() - context.add(Provider(provider), provides=str) - expected = ("") - compare(expected, actual=repr(context)) - compare(expected, actual=str(context)) - - def test_maximal_representation(self): - def provider() -> str: pass - p = Provider(provider, cache=False, provides_subclasses=True) - p.obj = 'it' - context = Context() - context.add(p, provides=str, identifier='the id') - expected = ("") - compare(expected, actual=repr(context)) - compare(expected, actual=str(context)) - - -class TestNesting: - - def test_nest(self): - c1 = Context() - c1.add('c1a', identifier='a') - c1.add('c1c', identifier='c') - c2 = c1.nest() - c2.add('c2b', identifier='b') - c2.add('c2c', identifier='c') - - def foo(a, b=None, c=None): - return a, b, c - - compare(c2.call(foo), expected=('c1a', 'c2b', 'c2c')) - compare(c1.call(foo), expected=('c1a', None, 'c1c')) - - def test_uses_existing_cached_value(self): - class X: pass - - x_ = X() - - xs = [x_] - - def make_x(): - return xs.pop() - - c1 = Context() - c1.add(Provider(make_x, cache=True), identifier='x') - - assert c1.call(lambda x: x) is x_ - c2 = c1.nest() - assert c2.call(lambda x: x) is x_ - - assert c2.call(lambda x: x) is x_ - assert c1.call(lambda x: x) is x_ - - def test_stored_cached_value_in_nested_context(self): - class X: pass - - x1 = X() - x2 = X() - - xs = [x2, x1] - - def make_x(): - return xs.pop() - - c1 = Context() - c1.add(Provider(make_x, cache=True), identifier='x') - - c2 = c1.nest() - assert c2.call(lambda x: x) is x1 - assert c1.call(lambda x: x) is x2 - - assert c1.call(lambda x: x) is x2 - assert c2.call(lambda x: x) is x1 - - def test_no_cache_in_nested(self): - class X: pass - - x1 = X() - x2 = X() - - xs = [x2, x1] - - def make_x(): - return xs.pop() - - c1 = Context() - c1.add(Provider(make_x, cache=False), identifier='x') - - c2 = c1.nest() - assert c2.call(lambda x: x) is x1 - assert c2.call(lambda x: x) is x2 - - def test_provider_uses_resources_from_nested_context(self): - - def expanded(it: str): - return it*2 - - c1 = Context() - c1.add(Provider(expanded)) - - c2 = c1.nest() - c2.add('foo') - - compare(c2.call(lambda expanded: expanded), expected='foofoo') - - def test_with_default_requirement(self): - - def make_requirement(name, type_, default) -> Requirement: - pass - - c1 = Context(default_requirement=make_requirement) - c2 = c1.nest() - assert c2._default_requirement is make_requirement diff --git a/mush/tests/test_context_py38.py b/mush/tests/test_context_py38.py deleted file mode 100644 index a3dc935..0000000 --- a/mush/tests/test_context_py38.py +++ /dev/null @@ -1,23 +0,0 @@ -from testfixtures import compare - -from mush import Context - - -class TestCall: - - def test_positional_only(self): - def foo(x:int, /): - return x - - context = Context() - context.add(2) - result = context.call(foo) - compare(result, expected=2) - - def test_positional_only_with_default(self): - def foo(x:int = 1, /): - return x - - context = Context() - result = context.call(foo) - compare(result, expected=1) diff --git a/mush/tests/test_declarations.py b/mush/tests/test_declarations.py deleted file mode 100644 index 4638721..0000000 --- a/mush/tests/test_declarations.py +++ /dev/null @@ -1,111 +0,0 @@ -from typing import Tuple -from unittest import TestCase - -from testfixtures import compare, ShouldRaise - -from mush import Value, AnyOf -from mush.declarations import requires, returns, Parameter, RequirementsDeclaration, \ - ReturnsDeclaration -from .helpers import PY_36, Type1, Type2, Type3, Type4, TheType -from ..resources import ResourceKey - - -class TestRequires(TestCase): - - def test_empty(self): - r = requires() - compare(repr(r), 'requires()') - compare(r, expected=[]) - - def test_types(self): - r_ = requires(Type1, Type2, x=Type3, y=Type4) - compare(repr(r_), 'requires(Value(Type1), Value(Type2), x=Value(Type3), y=Value(Type4))') - compare(r_, expected=[ - Parameter(Value(Type1)), - Parameter(Value(Type2)), - Parameter(Value(Type3), target='x'), - Parameter(Value(Type4), target='y'), - ]) - - def test_strings(self): - r_ = requires('1', '2', x='3', y='4') - compare(repr(r_), "requires(Value('1'), Value('2'), x=Value('3'), y=Value('4'))") - compare(r_, expected=[ - Parameter(Value('1')), - Parameter(Value('2')), - Parameter(Value('3'), target='x'), - Parameter(Value('4'), target='y'), - ]) - - def test_typing(self): - r_ = requires(Tuple[str]) - text = 'Tuple' if PY_36 else 'typing.Tuple[str]' - compare(repr(r_),expected=f"requires(Value({text}))") - compare(r_, expected=[Parameter(Value(Tuple[str]))]) - - def test_tuple_arg(self): - with ShouldRaise(TypeError("('1', '2') is not a valid decoration type")): - requires(('1', '2')) - - def test_tuple_kw(self): - with ShouldRaise(TypeError("('1', '2') is not a valid decoration type")): - requires(foo=('1', '2')) - - def test_decorator_paranoid(self): - @requires(Type1) - def foo(): - return 'bar' - - compare(foo.__mush__['requires'], expected=[Parameter(Value(Type1))]) - compare(foo(), 'bar') - - def test_requirement_instance(self): - compare(requires(x=AnyOf('foo', 'bar')), - expected=RequirementsDeclaration([Parameter(AnyOf('foo', 'bar'), target='x')]), - strict=True) - - def test_accidental_tuple(self): - with ShouldRaise(TypeError( - "(, " - ") " - "is not a valid decoration type" - )): - requires((TheType, TheType)) - - -class TestReturns(TestCase): - - def test_type(self): - r = returns(Type1) - compare(repr(r), 'returns(Type1)') - compare(r, expected=ReturnsDeclaration((ResourceKey(Type1),))) - - def test_string(self): - r = returns('bar') - compare(repr(r), "returns('bar')") - compare(r, expected=ReturnsDeclaration((ResourceKey(identifier='bar'),))) - - def test_typing(self): - r = returns(Tuple[str]) - text = 'Tuple' if PY_36 else 'typing.Tuple[str]' - compare(repr(r), f'returns({text})') - compare(r, expected=ReturnsDeclaration((ResourceKey(Tuple[str]),))) - - def test_decorator(self): - @returns(Type1) - def foo(): - pass - r = foo.__mush__['returns'] - compare(repr(r), 'returns(Type1)') - compare(r, expected=ReturnsDeclaration((ResourceKey(Type1),))) - - def test_bad_type(self): - with ShouldRaise(TypeError( - '[] is not a valid decoration type' - )): - @returns([]) - def foo(): pass - - def test_keys_are_orderable(self): - r = returns(Type1, 'foo') - compare(repr(r), expected="returns('foo', Type1)") diff --git a/mush/tests/test_example_with_mush_clone.py b/mush/tests/test_example_with_mush_clone.py deleted file mode 100644 index 2f60a38..0000000 --- a/mush/tests/test_example_with_mush_clone.py +++ /dev/null @@ -1,98 +0,0 @@ -from .example_with_mush_clone import DatabaseHandler, main, do, setup_logging -from unittest import TestCase -from testfixtures import TempDirectory -from testfixtures import Replacer -from testfixtures import LogCapture -from testfixtures import ShouldRaise -import sqlite3 - - -class Tests(TestCase): - - def test_main(self): - with TempDirectory() as d: - # setup db - db_path = d.getpath('sqlite.db') - conn = sqlite3.connect(db_path) - conn.execute('create table notes (filename varchar, text varchar)') - conn.commit() - # setup config - config = d.write('config.ini', ''' -[main] -db = %s -log = %s -''' % (db_path, d.getpath('script.log')), 'ascii') - # setup file to read - source = d.write('test.txt', 'some text', 'ascii') - with Replacer() as r: - r.replace('sys.argv', ['script.py', config, source, '--quiet']) - main() - # check results - self.assertEqual( - conn.execute('select * from notes').fetchall(), - [('test.txt', 'some text')] - ) - - # coverage.py says no test of branch to log.check call! - def test_do(self): - # setup db - conn = sqlite3.connect(':memory:') - conn.execute('create table notes (filename varchar, text varchar)') - conn.commit() - with TempDirectory() as d: - # setup file to read - source = d.write('test.txt', 'some text', 'ascii') - with LogCapture() as log: - # call the function under test - do(conn, source) # pragma: no branch (coverage.py bug) - # check results - self.assertEqual( - conn.execute('select * from notes').fetchall(), - [('test.txt', 'some text')] - ) - # check logging - log.check(('root', 'INFO', "Successfully added 'test.txt'")) - - def test_setup_logging(self): - with TempDirectory() as dir: - with LogCapture(): - setup_logging(dir.getpath('test.log'), verbose=True) - - -class DatabaseHandlerTests(TestCase): - - def setUp(self): - self.dir = TempDirectory() - self.addCleanup(self.dir.cleanup) - self.db_path = self.dir.getpath('test.db') - self.conn = sqlite3.connect(self.db_path) - self.conn.execute('create table notes ' - '(filename varchar, text varchar)') - self.conn.commit() - self.log = LogCapture() - self.addCleanup(self.log.uninstall) - - def test_normal(self): - with DatabaseHandler(self.db_path) as handler: - handler.conn.execute('insert into notes values (?, ?)', - ('test.txt', 'a note')) - handler.conn.commit() - # check the row was inserted and committed - curs = self.conn.cursor() - curs.execute('select * from notes') - self.assertEqual(curs.fetchall(), [('test.txt', 'a note')]) - # check there was no logging - self.log.check() - - def test_exception(self): - with ShouldRaise(Exception('foo')): - with DatabaseHandler(self.db_path) as handler: - handler.conn.execute('insert into notes values (?, ?)', - ('test.txt', 'a note')) - raise Exception('foo') - # check the row not inserted and the transaction was rolled back - curs = handler.conn.cursor() - curs.execute('select * from notes') - self.assertEqual(curs.fetchall(), []) - # check the error was logged - self.log.check(('root', 'ERROR', 'Something went wrong')) diff --git a/mush/tests/test_example_with_mush_factory.py b/mush/tests/test_example_with_mush_factory.py deleted file mode 100644 index 58f24ba..0000000 --- a/mush/tests/test_example_with_mush_factory.py +++ /dev/null @@ -1,31 +0,0 @@ -from .example_with_mush_factory import main - -from unittest import TestCase -from testfixtures import TempDirectory, Replacer -import sqlite3 - -class Tests(TestCase): - - def test_main(self): - with TempDirectory() as d: - # setup db - db_path = d.getpath('sqlite.db') - conn = sqlite3.connect(db_path) - conn.execute('create table notes (filename varchar, text varchar)') - conn.commit() - # setup config - config = d.write('config.ini', ''' -[main] -db = %s -log = %s -''' % (db_path, d.getpath('script.log')), 'ascii') - # setup file to read - source = d.write('test.txt', 'some text', 'ascii') - with Replacer() as r: - r.replace('sys.argv', ['script.py', config, source, '--quiet']) - main() - # check results - self.assertEqual( - conn.execute('select * from notes').fetchall(), - [('test.txt', 'some text')] - ) diff --git a/mush/tests/test_example_without_mush.py b/mush/tests/test_example_without_mush.py deleted file mode 100644 index f9f1ee9..0000000 --- a/mush/tests/test_example_without_mush.py +++ /dev/null @@ -1,75 +0,0 @@ -from .example_without_mush import main -from unittest import TestCase -from testfixtures import TempDirectory, Replacer, OutputCapture -import sqlite3 - - -class Tests(TestCase): - - def test_main(self): - with TempDirectory() as d: - # setup db - db_path = d.getpath('sqlite.db') - conn = sqlite3.connect(db_path) - conn.execute('create table notes (filename varchar, text varchar)') - conn.commit() - # setup config - config = d.write('config.ini', ''' -[main] -db = %s -log = %s -''' % (db_path, d.getpath('script.log')), 'ascii') - # setup file to read - source = d.write('test.txt', 'some text', 'ascii') - with Replacer() as r: - r.replace('sys.argv', ['script.py', config, source, '--quiet']) - main() - # check results - self.assertEqual( - conn.execute('select * from notes').fetchall(), - [('test.txt', 'some text')] - ) - - def test_main_verbose(self): - with TempDirectory() as d: - # setup db - db_path = d.getpath('sqlite.db') - conn = sqlite3.connect(db_path) - conn.execute('create table notes (filename varchar, text varchar)') - conn.commit() - # setup config - config = d.write('config.ini', ''' -[main] -db = %s -log = %s -''' % (db_path, d.getpath('script.log')), 'ascii') - # setup file to read - source = d.write('test.txt', 'some text', 'ascii') - with Replacer() as r: - r.replace('sys.argv', ['script.py', config, source]) - with OutputCapture() as output: - main() - output.compare("Successfully added 'test.txt'") - - def test_main_exception(self): - with TempDirectory() as d: - from testfixtures import OutputCapture - # setup db - db_path = d.getpath('sqlite.db') - conn = sqlite3.connect(db_path) - # don't create the table so we get at exception - conn.commit() - # setup config - config = d.write('config.ini', ''' -[main] -db = %s -log = %s -''' % (db_path, d.getpath('script.log')), 'ascii') - # setup file to read - source = d.write('bad.txt', 'some text', 'ascii') - with Replacer() as r: - r.replace('sys.argv', ['script.py', config, source]) - with OutputCapture() as output: - main() - self.assertTrue('OperationalError' in output.captured, - output.captured) diff --git a/mush/tests/test_extraction.py b/mush/tests/test_extraction.py deleted file mode 100644 index afa7516..0000000 --- a/mush/tests/test_extraction.py +++ /dev/null @@ -1,402 +0,0 @@ -from functools import partial -from typing import Optional, Any -from testfixtures.mock import Mock - -import pytest -from testfixtures import compare - -from mush import Value, update_wrapper -from mush.declarations import ( - requires, returns, requires_nothing, RequirementsDeclaration, Parameter, ReturnsDeclaration, - returns_nothing -) -from mush.extraction import extract_requires, extract_returns -from mush.requirements import Requirement, Annotation -from .helpers import Type1, Type2, Type3 -from ..resources import ResourceKey -from ..typing import Type_ - -returns_foo = ReturnsDeclaration([ResourceKey(identifier='foo')]) - - -def check_extract(obj, expected_rq, expected_rt=returns_foo): - rq = extract_requires(obj) - rt = extract_returns(obj) - compare(rq, expected=expected_rq, strict=True) - compare(rt, expected=expected_rt, strict=True) - - -class TestRequirementsExtraction: - - def test_default_requirements_for_function(self): - def foo(a, b=None): pass - check_extract(foo, - expected_rq=RequirementsDeclaration(( - Parameter(Annotation('a')), - Parameter(Annotation('b', default=None), default=None), - ))) - - def test_default_requirements_for_class(self): - class MyClass(object): - def __init__(self, a, b=None): pass - check_extract(MyClass, - expected_rq=RequirementsDeclaration(( - Parameter(Annotation('a')), - Parameter(Annotation('b', default=None), default=None), - )), - expected_rt=ReturnsDeclaration([ - ResourceKey(MyClass), - ResourceKey(identifier='MyClass'), - ResourceKey(MyClass, 'MyClass'), - ])) - - def test_extract_from_partial(self): - def foo(x, y, z, a=None): pass - p = partial(foo, 1, y=2) - check_extract( - p, - expected_rq=RequirementsDeclaration(( - Parameter(Annotation('z'), target='z'), - Parameter(Annotation('a', default=None), target='a', default=None), - )) - ) - - def test_extract_from_partial_default_not_in_partial(self): - def foo(a=None): pass - p = partial(foo) - check_extract( - p, - expected_rq=RequirementsDeclaration(( - Parameter(Annotation('a', default=None), default=None), - )) - ) - - def test_extract_from_partial_default_in_partial_arg(self): - def foo(a=None): pass - p = partial(foo, 1) - check_extract( - p, - # since a is already bound by the partial: - expected_rq=requires_nothing - ) - - def test_extract_from_partial_default_in_partial_kw(self): - def foo(a=None): pass - p = partial(foo, a=1) - check_extract( - p, - expected_rq=requires_nothing - ) - - def test_extract_from_partial_required_in_partial_arg(self): - def foo(a): pass - p = partial(foo, 1) - check_extract( - p, - # since a is already bound by the partial: - expected_rq=requires_nothing - ) - - def test_extract_from_partial_required_in_partial_kw(self): - def foo(a): pass - p = partial(foo, a=1) - check_extract( - p, - expected_rq=requires_nothing - ) - - def test_extract_from_partial_plus_one_default_not_in_partial(self): - def foo(b, a=None): pass - p = partial(foo) - check_extract( - p, - expected_rq=RequirementsDeclaration(( - Parameter(Annotation('b')), - Parameter(Annotation('a', default=None), default=None), - )) - ) - - def test_extract_from_partial_plus_one_required_in_partial_arg(self): - def foo(b, a): pass - p = partial(foo, 1) - check_extract( - p, - # since b is already bound: - expected_rq=RequirementsDeclaration(( - Parameter(Annotation('a')), - )) - ) - - def test_extract_from_partial_plus_one_required_in_partial_kw(self): - def foo(b, a): pass - p = partial(foo, a=1) - check_extract( - p, - expected_rq=RequirementsDeclaration(( - Parameter(Annotation('b')), - )) - ) - - def test_extract_from_mock(self): - foo = Mock() - check_extract( - foo, - expected_rq=requires_nothing, - expected_rt=returns_nothing, - ) - - -# https://bugs.python.org/issue41872 -def foo_(a: 'Foo') -> 'Bar': pass -class Foo: pass -class Bar: pass - - -class TestExtractDeclarationsFromTypeAnnotations: - - def test_extract_from_annotations(self): - def foo(a: Type1, b, c: Type2 = 1, d=2) -> Type3: pass - check_extract(foo, - expected_rq=RequirementsDeclaration(( - Parameter(Annotation('a', Type1), type_=Type1), - Parameter(Annotation('b')), - Parameter(Annotation('c', Type2, default=1), type_=Type2, default=1), - Parameter(Annotation('d', default=2), default=2), - )), - expected_rt=ReturnsDeclaration([ - ResourceKey(Type3), - ResourceKey(identifier='foo'), - ResourceKey(Type3, 'foo'), - ])) - - def test_forward_type_references(self): - check_extract(foo_, - expected_rq=RequirementsDeclaration(( - Parameter(Annotation('a', Foo), type_=Foo), - )), - expected_rt=ReturnsDeclaration([ - ResourceKey(Bar), - ResourceKey(identifier='foo_'), - ResourceKey(Bar, 'foo_'), - ])) - - def test_requires_only(self): - def foo(a: Type1): pass - check_extract(foo, - expected_rq=RequirementsDeclaration(( - Parameter(Annotation('a', Type1), type_=Type1), - ))) - - def test_returns_only(self): - def foo() -> Type1: pass - check_extract(foo, - expected_rq=requires_nothing, - expected_rt=ReturnsDeclaration([ - ResourceKey(Type1), - ResourceKey(identifier='foo'), - ResourceKey(Type1, 'foo'), - ])) - - def test_returns_nothing(self): - def foo() -> None: pass - check_extract(foo, - expected_rq=requires_nothing, - expected_rt=ReturnsDeclaration()) - - def test_extract_from_decorated_class(self): - - class Wrapper(object): - def __init__(self, func): - self.func = func - def __call__(self): - return 'the '+self.func() - - def my_dec(func): - return update_wrapper(Wrapper(func), func) - - @my_dec - @requires(a=Value('foo')) - @returns('bar') - def foo(a=None): - return 'answer' - - compare(foo(), expected='the answer') - check_extract(foo, - expected_rq=RequirementsDeclaration(( - Parameter(Value(identifier='foo'), target='a'), - )), - expected_rt=ReturnsDeclaration([ResourceKey(identifier='bar')])) - - def test_decorator_preferred_to_annotations(self): - @requires('foo') - @returns('bar') - def foo(a: Type1) -> Type2: pass - check_extract(foo, - expected_rq=RequirementsDeclaration(( - Parameter(Value(identifier='foo'), type_=Type1),) - ), - expected_rt=ReturnsDeclaration([ResourceKey(identifier='bar')])) - - def test_default_requirements(self): - def foo(a, b=1, *, c, d=None): pass - check_extract(foo, - expected_rq=RequirementsDeclaration(( - Parameter(Annotation('a')), - Parameter(Annotation('b', default=1), default=1), - Parameter(Annotation('c'), target='c'), - Parameter(Annotation('d', default=None), target='d', default=None) - ))) - - def test_type_only(self): - class T: pass - def foo(a: T): pass - check_extract(foo, - expected_rq=RequirementsDeclaration(( - Parameter(Annotation('a', T), type_=T), - )), - expected_rt=ReturnsDeclaration([ResourceKey(identifier='foo')])) - - @pytest.mark.parametrize("type_", [str, int, dict, list]) - def test_simple_type_only(self, type_): - def foo(a: type_): pass - check_extract(foo, - expected_rq=RequirementsDeclaration(( - Parameter(Annotation('a', type_), type_=type_), - ))) - - def test_type_plus_value(self): - def foo(a: str = Value('b')): pass - check_extract(foo, - expected_rq=RequirementsDeclaration(( - Parameter(Value(identifier='b'), type_=str), - ))) - - def test_type_plus_value_with_default(self): - def foo(a: str = Value('b', default=1)): pass - check_extract(foo, - expected_rq=RequirementsDeclaration(( - Parameter(Value(identifier='b', default=1), type_=str, default=1), - ))) - - -class Path(Requirement): - - def __init__(self, name=None, type_=None): - super().__init__(()) - self.name=name - self.type=type_ - - def complete(self, name: str, type_: Type_, default: Any): - return type(self)(name=name, type_=type_) - - -class TestCustomRequirementCompletion: - - def test_use_name(self): - def foo(bar=Path()): pass - check_extract(foo, RequirementsDeclaration(( - Parameter(Path(name='bar', type_=None)), - ))) - - def test_use_type(self): - def foo(bar: str = Path()): pass - check_extract(foo, RequirementsDeclaration(( - Parameter(Path(name='bar', type_=str), type_=str), - ))) - - def test_precedence(self): - class PathSubclass(Path): pass - @requires(PathSubclass()) - def foo(bar: str = Path()): pass - check_extract(foo, RequirementsDeclaration(( - Parameter(PathSubclass(name='bar', type_=str), type_=str), - ))) - - -def it(): - pass - - -class TestExplicitDeclarations: - - def test_requires_from_string(self): - compare(extract_requires(it, 'bar'), strict=True, expected=RequirementsDeclaration(( - Parameter(Value(identifier='bar')), - ))) - - def test_requires_from_type(self): - compare(extract_requires(it, Type1), strict=True, expected=RequirementsDeclaration(( - Parameter(Value(Type1)), - ))) - - def test_requires_requirement(self): - compare(extract_requires(it, Value(Type1, 'bar')), strict=True, expected=RequirementsDeclaration(( - Parameter(Value(Type1, 'bar')), - ))) - - def test_requires_from_tuple(self): - compare(extract_requires(it, ('bar', 'baz')), strict=True, expected=RequirementsDeclaration(( - Parameter(Value(identifier='bar')), - Parameter(Value(identifier='baz')), - ))) - - def test_requires_from_list(self): - compare(extract_requires(it, ['bar', 'baz']), strict=True, expected=RequirementsDeclaration(( - Parameter(Value(identifier='bar')), - Parameter(Value(identifier='baz')), - ))) - - def test_explicit_requires_where_parameter_has_default(self): - def foo(x=1): pass - compare(extract_requires(foo, 'bar'), strict=True, expected=RequirementsDeclaration(( - # default is not longer considered: - Parameter(Value(identifier='bar')), - ))) - - def test_returns_from_string(self): - compare(extract_returns(it, 'bar'), strict=True, expected=ReturnsDeclaration([ - ResourceKey(identifier='bar') - ])) - - def test_returns_from_type(self): - compare(extract_returns(it, Type1), strict=True, expected=ReturnsDeclaration([ - ResourceKey(Type1) - ])) - - -class TestDeclarationsFromMultipleSources: - - def test_declarations_from_different_sources(self): - r1 = Requirement(keys=(), default='b') - r2 = Requirement(keys=(), default='c') - - @requires(b=r1) - def foo(a: str, b, c=r2): - pass - - check_extract(foo, - expected_rq=RequirementsDeclaration(( - Parameter(Annotation('a', str), type_=str), - Parameter(Requirement((), default='b'), default='b', target='b'), - Parameter(Requirement((), default='c'), default='c', target='c'), - ))) - - def test_declaration_priorities(self): - r1 = Requirement([ResourceKey(identifier='x')]) - r2 = Requirement([ResourceKey(identifier='y')]) - r3 = Requirement([ResourceKey(identifier='z')]) - - @requires(a=r1) - @returns('bar') - def foo(a: int = r3, b: str = r2, c=r3) -> Optional[Type1]: - pass - - check_extract( - foo, - expected_rq=RequirementsDeclaration(( - Parameter(Requirement([ResourceKey(identifier='x')]), type_=int, target='a'), - Parameter(Requirement([ResourceKey(identifier='y')]), type_=str, target='b'), - Parameter(Requirement([ResourceKey(identifier='z')]), target='c'), - )), - expected_rt=ReturnsDeclaration([ResourceKey(identifier='bar')]) - ) diff --git a/mush/tests/test_marker.py b/mush/tests/test_marker.py deleted file mode 100644 index 3a60f74..0000000 --- a/mush/tests/test_marker.py +++ /dev/null @@ -1,6 +0,0 @@ -from mush.markers import Marker -from testfixtures import compare - - -def test_repr(): - compare(repr(Marker('foo')), expected='') diff --git a/mush/tests/test_plug.py b/mush/tests/test_plug.py deleted file mode 100644 index f91824b..0000000 --- a/mush/tests/test_plug.py +++ /dev/null @@ -1,235 +0,0 @@ - -from testfixtures import compare, ShouldRaise -from testfixtures.mock import Mock, call - -from mush import Plug, Runner, returns, requires -from mush.plug import insert, ignore, append -from mush.tests.test_runner import verify - - -class TestPlug: - - def test_simple(self): - m = Mock() - - runner = Runner() - runner.add(m.job1, label='one') - runner.add(m.job2) - runner.add(m.job3, label='three') - runner.add(m.job4) - - class MyPlug(Plug): - - def one(self): - m.plug_one() - - def three(self): - m.plug_two() - - plug = MyPlug() - plug.add_to(runner) - - runner() - - compare([ - call.job1(), call.plug_one(), call.job2(), - call.job3(), call.plug_two(), call.job4() - ], m.mock_calls) - - verify(runner, - (m.job1, set()), - (plug.one, {'one'}), - (m.job2, set()), - (m.job3, set()), - (plug.three, {'three'}), - (m.job4, set()), - ) - - def test_label_not_there(self): - runner = Runner() - - class MyPlug(Plug): - def not_there(self): pass - - with ShouldRaise(KeyError('not_there')): - MyPlug().add_to(runner) - - def test_requirements_and_returns(self): - m = Mock() - - @returns('r1') - def job1(): - m.job1() - return 1 - - @requires('r2') - def job3(r): - m.job3(r) - - runner = Runner() - runner.add(job1, label='point') - runner.add(job3) - - class MyPlug(Plug): - - @requires('r1') - @returns('r2') - def point(self, r): - m.point(r) - return 2 - - plug = MyPlug() - plug.add_to(runner) - - runner() - - compare([ - call.job1(), call.point(1), call.job3(2), - ], m.mock_calls) - - verify(runner, - (job1, set()), - (plug.point, {'point'}), - (job3, set()), - ) - - def test_explict(self): - m = Mock() - - runner = Runner() - runner.add(m.job1, label='one') - - class MyPlug(Plug): - - explicit = True - - def helper(self): - m.plug_one() - - @insert() - def one(self): - self.helper() - - plug = MyPlug() - plug.add_to(runner) - - runner() - - compare([ - call.job1(), - call.plug_one() - ], actual=m.mock_calls) - - verify(runner, - (m.job1, set()), - (plug.one, {'one'}), - ) - - def test_ignore(self): - m = Mock() - - runner = Runner() - runner.add(m.job1, label='one') - - class MyPlug(Plug): - - @ignore() - def helper(self): # pragma: no cover - m.plug_bad() - - def one(self): - m.plug_good() - - plug = MyPlug() - plug.add_to(runner) - - runner() - - compare([ - call.job1(), - call.plug_good() - ], actual=m.mock_calls) - - verify(runner, - (m.job1, set()), - (plug.one, {'one'}), - ) - - def test_remap_name(self): - m = Mock() - - runner = Runner() - runner.add(m.job1, label='one') - - class MyPlug(Plug): - - @insert(label='one') - def run_plug(self): - m.plug_one() - - plug = MyPlug() - plug.add_to(runner) - - runner() - - compare([ - call.job1(), - call.plug_one() - ], m.mock_calls) - - verify(runner, - (m.job1, set()), - (plug.run_plug, {'one'}), - ) - - def test_append(self): - m = Mock() - - runner = Runner() - runner.add(m.job1, label='one') - - class MyPlug(Plug): - - @append() - def run_plug(self): - m.do_it() - - plug = MyPlug() - plug.add_to(runner) - - runner() - - compare([ - call.job1(), - call.do_it() - ], actual=m.mock_calls) - - verify(runner, - (m.job1, {'one'}), - (plug.run_plug, set()), - ) - - def test_add_plug(self): - m = Mock() - - runner = Runner() - runner.add(m.job1, label='one') - - class MyPlug(Plug): - def one(self): - m.plug_one() - - plug = MyPlug() - runner.add(plug) - - runner() - - compare([ - call.job1(), call.plug_one() - ], m.mock_calls) - - verify(runner, - (m.job1, set()), - (plug.one, {'one'}), - ) - diff --git a/mush/tests/test_requirements.py b/mush/tests/test_requirements.py deleted file mode 100644 index 9426a57..0000000 --- a/mush/tests/test_requirements.py +++ /dev/null @@ -1,263 +0,0 @@ -from typing import Text, Tuple, NewType - -from testfixtures.mock import Mock - -import pytest -from testfixtures import compare, ShouldRaise - -from mush import Value, missing -from mush.requirements import Requirement, AttrOp, ItemOp, AnyOf, Like, Annotation -from mush.resources import ResourceKey -from mush.tests.helpers import Type1 - - -def check_ops(value, data, *, expected): - for op in value.ops: - data = op(data) - compare(expected, actual=data) - - -class TestRequirement: - - def test_repr_minimal(self): - compare(repr(Requirement((), default=missing)), - expected="Requirement()") - - def test_repr_maximal(self): - r = Requirement( - keys=( - ResourceKey(type_=str), - ResourceKey(identifier='foo'), - ResourceKey(type_=int, identifier='bar') - ), - default=None - ) - r.ops.append(AttrOp('bar')) - compare(repr(r), - expected="Requirement(ResourceKey(str), ResourceKey('foo'), " - "ResourceKey(int, 'bar'), default=None).bar") - - special_names = ['attr', 'ops'] - - def test_repr_subclass(self): - class SubClass(Requirement): - def __init__(self): - self.foo = 42 - self.bar = 'baz' - super().__init__([ResourceKey(str)], missing) - compare(repr(SubClass()), - expected="SubClass(ResourceKey(str), foo=42, bar='baz')") - - @pytest.mark.parametrize("name", special_names) - def test_attr_special_name(self, name): - v = Requirement('foo') - assert getattr(v, name) is not self - assert v.attr(name) is v - compare(v.ops, expected=[AttrOp(name)]) - - @pytest.mark.parametrize("name", special_names) - def test_item_special_name(self, name): - v = Requirement('foo') - assert v[name] is v - compare(v.ops, expected=[ItemOp(name)]) - - def test_no_special_name_via_getattr(self): - v = Requirement('foo') - with ShouldRaise(AttributeError): - assert v.__len__ - compare(v.ops, []) - - -class TestItem: - - def test_single(self): - h = Value(Type1)['foo'] - compare(repr(h), expected="Value(Type1)['foo']") - check_ops(h, {'foo': 1}, expected=1) - - def test_multiple(self): - h = Value(Type1)['foo']['bar'] - compare(repr(h), expected="Value(Type1)['foo']['bar']") - check_ops(h, {'foo': {'bar': 1}}, expected=1) - - def test_missing_obj(self): - h = Value(Type1)['foo']['bar'] - with ShouldRaise(TypeError): - check_ops(h, object(), expected=None) - - def test_missing_key(self): - h = Value(Type1)['foo'] - check_ops(h, {}, expected=missing) - - def test_bad_type(self): - h = Value(Type1)['foo']['bar'] - with ShouldRaise(TypeError): - check_ops(h, [], expected=None) - - -class TestAttr: - - def test_single(self): - h = Value(Type1).foo - compare(repr(h), "Value(Type1).foo") - m = Mock() - check_ops(h, m, expected=m.foo) - - def test_multiple(self): - h = Value(Type1).foo.bar - compare(repr(h), "Value(Type1).foo.bar") - m = Mock() - check_ops(h, m, expected=m.foo.bar) - - def test_missing(self): - h = Value(Type1).foo - compare(repr(h), "Value(Type1).foo") - check_ops(h, object(), expected=missing) - - -class TestAnnotation: - - def test_name_only(self): - r = Annotation('x', None, missing) - compare(r.keys, expected=[ - ResourceKey(None, 'x') - ]) - compare(r.default, expected=missing) - - def test_name_and_type(self): - r = Annotation('x', str, missing) - compare(r.keys, expected=[ - ResourceKey(str, 'x'), - ResourceKey(str, None), - ResourceKey(None, 'x'), - ]) - compare(r.default, expected=missing) - - def test_all(self): - r = Annotation('x', str, 'default') - compare(r.keys, expected=[ - ResourceKey(str, 'x'), - ResourceKey(str, None), - ResourceKey(None, 'x'), - ]) - compare(r.default, expected='default') - - def test_repr_min(self): - compare(repr(Annotation('x', None, missing)), - expected="x") - - def test_repr_max(self): - compare(repr(Annotation('x', str, 'default')), - expected="x: str = 'default'") - - -class TestValue: - - def test_type_only(self): - v = Value(str) - compare(v.keys, expected=[ResourceKey(str, None)]) - - def test_typing_only(self): - v = Value(Text) - compare(v.keys, expected=[ResourceKey(Text, None)]) - - def test_typing_generic_alias(self): - v = Value(Tuple[str]) - compare(v.keys, expected=[ResourceKey(Tuple[str], None)]) - - def test_typing_new_type(self): - Type = NewType('Type', str) - v = Value(Type) - compare(v.keys, expected=[ResourceKey(Type, None)]) - - def test_identifier_only(self): - v = Value('foo') - compare(v.keys, expected=[ResourceKey(None, 'foo')]) - - def test_type_and_identifier(self): - v = Value(str, 'foo') - compare(v.keys, expected=[ResourceKey(str, 'foo')]) - - def test_nothing_specified(self): - with ShouldRaise(TypeError('type or identifier must be supplied')): - Value() - - def test_repr_min(self): - compare(repr(Value(Type1)), - expected="Value(Type1)") - - def test_repr_max(self): - compare(repr(Value(Type1, 'foo')['bar'].baz), - expected="Value(Type1, 'foo')['bar'].baz") - - -class TestAnyOf: - - def test_types_and_typing(self): - r = AnyOf(tuple, Tuple[str]) - compare(r.keys, expected=[ - ResourceKey(tuple, None), - ResourceKey(Tuple[str], None), - ]) - compare(r.default, expected=missing) - - def test_identifiers(self): - r = AnyOf('a', 'b') - compare(r.keys, expected=[ - ResourceKey(None, 'a'), - ResourceKey(None, 'b'), - ]) - compare(r.default, expected=missing) - - def test_default(self): - r = AnyOf(tuple, default='x') - compare(r.keys, expected=[ - ResourceKey(tuple, None), - ]) - compare(r.default, expected='x') - - def test_none(self): - with ShouldRaise(TypeError('at least one key must be specified')): - AnyOf() - - def test_repr_min(self): - compare(repr(AnyOf(Type1)), - expected="AnyOf(Type1)") - - def test_repr_max(self): - compare(repr(AnyOf(Type1, 'foo', default='baz')['bob'].bar), - expected="AnyOf(Type1, 'foo', default='baz')['bob'].bar") - - -class Parent(object): - pass - - -class Child(Parent): - pass - - -class TestLike: - - def test_simple(self): - r = Like(Child) - compare(r.keys, expected=[ - ResourceKey(Child, None), - ResourceKey(Parent, None), - ]) - compare(r.default, expected=missing) - - def test_default(self): - r = Like(Parent, default='foo') - compare(r.keys, expected=[ - ResourceKey(Parent, None), - ]) - compare(r.default, expected='foo') - - def test_repr_min(self): - compare(repr(Like(Type1)), - expected="Like(Type1)") - - def test_repr_max(self): - compare(repr(Like(Type1, default='baz')['bob'].bar), - expected="Like(Type1, default='baz')['bob'].bar") diff --git a/mush/tests/test_runner.py b/mush/tests/test_runner.py deleted file mode 100644 index c96c0f2..0000000 --- a/mush/tests/test_runner.py +++ /dev/null @@ -1,1238 +0,0 @@ -import pytest - -from mush.declarations import requires, returns, replacement, original -from mush import Value, ContextError, Context, Requirement -from mush.requirements import ItemOp -from mush.resources import Provider, ResourceKey -from mush.runner import Runner -from testfixtures import ( - ShouldRaise, - compare -) -from testfixtures.mock import Mock, call - - -def verify(runner, *expected): - seen_labels = set() - - actual = [] - point = runner.start - while point: - actual.append((point.obj, point.labels)) - for label in point.labels: - if label in seen_labels: # pragma: no cover - raise AssertionError('%s occurs more than once' % label) - seen_labels.add(label) - compare(runner.labels[label], point) - point = point.next - - compare(expected=expected, actual=actual) - - actual_reverse = [] - point = runner.end - while point: - actual_reverse.append((point.obj, point.labels)) - point = point.previous - - compare(actual, reversed(actual_reverse)) - compare(seen_labels, runner.labels.keys()) - - -class TestRunner: - - def test_simple(self): - m = Mock() - def job(): - m.job() - - runner = Runner() - point = runner.add(job).callpoint - - compare(job, point.obj) - compare(runner.start, point) - compare(runner.end, point) - runner() - - compare([ - call.job() - ], m.mock_calls) - - verify(runner, (job, set())) - - def test_constructor(self): - m = Mock() - def job1(): - m.job1() - def job2(): - m.job2() - - runner = Runner(job1, job2) - compare(job1, runner.start.obj) - compare(job2, runner.end.obj) - - runner() - - compare([ - call.job1(), - call.job2(), - ], m.mock_calls) - - verify(runner, - (job1, set()), - (job2, set())) - - def test_return_value(self): - def job(): - return 42 - runner = Runner(job) - compare(runner(), 42) - - def test_return_value_empty(self): - runner = Runner() - compare(runner(), None) - - def test_add_with_label(self): - def job1(): pass - def job2(): pass - - runner = Runner() - - point1 = runner.add(job1, label='1').callpoint - point2 = runner.add(job2, label='2').callpoint - - compare(point1.obj, job1) - compare(point2.obj, job2) - - compare(runner['1'].callpoint, point1) - compare(runner['2'].callpoint, point2) - - compare({'1'}, point1.labels) - compare({'2'}, point2.labels) - - verify(runner, - (job1, {'1'}), - (job2, {'2'})) - - def test_modifier_add_moves_label(self): - def job1(): pass - def job2(): pass - - runner = Runner() - - runner.add(job1, label='the label') - runner['the label'].add(job2) - - verify(runner, - (job1, set()), - (job2, {'the label'})) - - def test_runner_add_does_not_move_label(self): - def job1(): pass - def job2(): pass - - runner = Runner() - - runner.add(job1, label='the label') - runner.add(job2) - - verify(runner, - (job1, {'the label'}), - (job2, set())) - - def test_modifier_moves_only_explicit_label(self): - def job1(): pass - def job2(): pass - - runner = Runner() - - mod = runner.add(job1) - mod.add_label('1') - mod.add_label('2') - - verify(runner, - (job1, {'1', '2'})) - - runner['2'].add(job2) - - verify(runner, - (job1, {'1'}), - (job2, {'2'})) - - def test_modifier_add_with_label(self): - def job1(): pass - def job2(): pass - - runner = Runner() - - mod = runner.add(job1) - mod.add_label('1') - - runner['1'].add(job2, label='2') - - verify(runner, - (job1, {'1'}), - (job2, {'2'})) - - def test_runner_add_label(self): - m = Mock() - - runner = Runner() - runner.add(m.job1) - runner.add_label('label') - runner.add(m.job3) - - runner['label'].add(m.job2) - - verify( - runner, - (m.job1, set()), - (m.job2, {'label'}), - (m.job3, set()) - ) - - cloned = runner.clone(added_using='label') - verify( - cloned, - (m.job2, {'label'}), - ) - - def test_declarative(self): - m = Mock() - class T1: pass - class T2: pass - - t1 = T1() - t2 = T2() - - def job1() -> T1: - m.job1() - return t1 - - def job2(obj: T1) -> T2: - m.job2(obj) - return t2 - - def job3(obj: T2) -> None: - m.job3(obj) - - runner = Runner(job1, job2, job3) - runner() - - compare([ - call.job1(), - call.job2(t1), - call.job3(t2), - ], m.mock_calls) - - def test_imperative(self): - m = Mock() - class T1: pass - class T2: pass - - t1 = T1() - t2 = T2() - - def job1(): - m.job1() - return t1 - - def job2(obj): - m.job2(obj) - return t2 - - def job3(t2_): - m.job3(t2_) - - # imperative config overrides decorator - @requires(T1) - def job4(t2_): - m.job4(t2_) - - runner = Runner() - runner.add(job1, returns=T1) - runner.add(job2, requires(T1), returns(T2)) - runner.add(job3, requires(t2_=T2)) - runner.add(job4, requires(T2)) - runner() - - compare([ - call.job1(), - call.job2(t1), - call.job3(t2), - call.job4(t2), - ], m.mock_calls) - - def test_return_type_specified_decorator(self): - m = Mock() - class T1: pass - class T2: pass - t = T1() - - @returns(T2) - def job1(): - m.job1() - return t - - @requires(T2) - def job2(obj): - m.job2(obj) - - Runner(job1, job2)() - - compare([ - call.job1(), - call.job2(t), - ], m.mock_calls) - - def test_return_type_specified_imperative(self): - m = Mock() - class T1: pass - class T2: pass - t = T1() - - def job1(): - m.job1() - return t - - @requires(T2) - def job2(obj): - m.job2(obj) - - runner = Runner() - runner.add(job1, returns=returns(T2)) - runner.add(job2, requires(T2)) - runner() - - compare([ - call.job1(), - call.job2(t), - ], m.mock_calls) - - def test_lazy(self): - m = Mock() - class T1: pass - class T2: pass - t = T1() - - def lazy_used(): - m.lazy_used() - return t - - def lazy_unused(): - raise AssertionError('should not be called') # pragma: no cover - - def providers(context: Context): - context.add(Provider(lazy_used), provides=T1) - context.add(Provider(lazy_unused), provides=T2) - - def job(obj): - m.job(obj) - - runner = Runner() - runner.add(providers) - runner.add(job, requires(T1)) - runner() - - compare(m.mock_calls, expected=[ - call.lazy_used(), - call.job(t), - ], ) - - def test_missing_from_context_no_chain(self): - class T: pass - - @requires(T) - def job(arg): - pass # pragma: nocover - - runner = Runner(job) - - with ShouldRaise(ContextError) as s: - runner() - - t_str = 'TestRunner.test_missing_from_context_no_chain..T' - text = '\n'.join(( - '', - '', - "While calling:", - f"{job.__qualname__} requires(Value({t_str})) returns('job')", - '', - 'with :', - '', - f"Value({t_str}) could not be satisfied", - )) - compare(text, actual=repr(s.raised)) - compare(text, actual=str(s.raised)) - - def test_missing_from_context_with_chain(self): - class T: pass - - def job1() -> None: pass - def job2() -> None: pass - - @requires(T) - def job3(arg): - pass # pragma: nocover - - def job4(): pass - def job5(foo, bar): pass - - runner = Runner() - runner.add(job1, label='1') - runner.add(job2) - runner.add(job3) - runner.add(job4, label='4') - runner.add(job5, requires('foo', bar='baz'), returns('bob')) - - with ShouldRaise(ContextError) as s: - runner() - - t_str = 'TestRunner.test_missing_from_context_with_chain..T' - - text = '\n'.join(( - '', - '', - 'Already called:', - f'{job1.__qualname__} requires() returns() <-- 1', - f'{job2.__qualname__} requires() returns()', - '', - "While calling:", - f"{job3.__qualname__} requires(Value({t_str})) returns('job3')", - '', - 'with :', - '', - f"Value({t_str}) could not be satisfied", - '', - 'Still to call:', - f'' - f"{job4.__qualname__} requires() returns('job4') <-- 4", - f"{job5.__qualname__} requires(Value('foo'), bar=Value('baz')) returns('bob')", - )) - compare(text, actual=repr(s.raised)) - compare(text, actual=str(s.raised)) - - def test_job_called_badly(self): - def job(arg): - pass # pragma: nocover - runner = Runner(job) - with ShouldRaise(ContextError) as s: - runner() - compare(s.raised.text, expected='arg could not be satisfied') - - def test_already_in_context(self): - class T: pass - - t1 = T() - t2 = T() - ts = [t2, t1] - - @returns(T) - def job(): - return ts.pop() - - runner = Runner(job, job) - - with ShouldRaise(ContextError) as s: - runner() - - t_str = 'TestRunner.test_already_in_context..T' - text = '\n'.join(( - '', - '', - 'Already called:', - f"{job.__qualname__} requires() returns({t_str})", - '', - "While calling:", - f"{job.__qualname__} requires() returns({t_str})", - '', - 'with :', - '', - f'Context already contains {t_str}', - )) - compare(text, repr(s.raised)) - compare(text, str(s.raised)) - - def test_job_error(self): - def job(): - raise Exception('huh?') - runner = Runner(job) - with ShouldRaise(Exception('huh?')): - runner() - - def test_attr(self): - class T(object): - foo = 'bar' - m = Mock() - def job1(): - m.job1() - return T() - def job2(obj): - m.job2(obj) - runner = Runner() - runner.add(job1, returns=T) - runner.add(job2, requires(Value(T).foo)) - runner() - - compare([ - call.job1(), - call.job2('bar'), - ], m.mock_calls) - - def test_attr_multiple(self): - class T2: - bar = 'baz' - class T: - foo = T2() - - m = Mock() - def job1(): - m.job1() - return T() - def job2(obj): - m.job2(obj) - runner = Runner() - runner.add(job1, returns=T) - runner.add(job2, requires(Value(T).foo.bar)) - runner() - - compare([ - call.job1(), - call.job2('baz'), - ], m.mock_calls) - - def test_item(self): - class MyDict(dict): pass - m = Mock() - def job1(): - m.job1() - obj = MyDict() - obj['the_thing'] = m.the_thing - return obj - def job2(obj): - m.job2(obj) - runner = Runner() - runner.add(job1, returns=MyDict) - runner.add(job2, requires(Value(MyDict)['the_thing'])) - runner() - compare([ - call.job1(), - call.job2(m.the_thing), - ], m.mock_calls) - - def test_item_multiple(self): - class MyDict(dict): pass - m = Mock() - def job1(): - m.job1() - obj = MyDict() - obj['the_thing'] = dict(other_thing=m.the_thing) - return obj - def job2(obj): - m.job2(obj) - runner = Runner() - runner.add(job1, returns=MyDict) - runner.add(job2, requires(Value(MyDict)['the_thing']['other_thing'])) - runner() - compare([ - call.job1(), - call.job2(m.the_thing), - ], m.mock_calls) - - def test_item_of_attr(self): - class T(object): - foo = dict(baz='bar') - m = Mock() - def job1(): - m.job1() - return T() - def job2(obj): - m.job2(obj) - runner = Runner() - runner.add(job1, returns=T) - runner.add(job2, requires(Value(T).foo['baz'])) - runner() - - compare([ - call.job1(), - call.job2('bar'), - ], m.mock_calls) - - def test_context_manager(self): - m = Mock() - - class CM1(object): - def __enter__(self): - m.cm1.enter() - return self - def __exit__(self, type, obj, tb): - m.cm1.exit(type, obj) - return True - - class CM2Context(object): pass - - class CM2(object): - def __enter__(self): - m.cm2.enter() - return CM2Context() - - def __exit__(self, type, obj, tb): - m.cm2.exit(type, obj) - - @requires(CM1) - def func1(obj): - m.func1(type(obj)) - - @requires(CM1, CM2Context) - def func2(obj1, obj2): - m.func2(type(obj1), - type(obj2)) - return '2' - - runner = Runner( - CM1, - CM2, - func1, - func2, - ) - - result = runner() - compare(result, '2') - - compare(m.mock_calls, expected=[ - call.cm1.enter(), - call.cm2.enter(), - call.func1(CM1), - call.func2(CM1, CM2Context), - call.cm2.exit(None, None), - call.cm1.exit(None, None) - ]) - - # now check with an exception - m.reset_mock() - m.func2.side_effect = e = Exception() - result = runner() - - # if something goes wrong, you get None - compare(None, result) - - compare(m.mock_calls, expected=[ - call.cm1.enter(), - call.cm2.enter(), - call.func1(CM1), - call.func2(CM1, CM2Context), - call.cm2.exit(Exception, e), - call.cm1.exit(Exception, e) - ]) - - def test_context_manager_is_last_callpoint(self): - m = Mock() - - class CM(object): - def __enter__(self): - m.cm.enter() - def __exit__(self, type, obj, tb): - m.cm.exit() - - runner = Runner(CM) - result = runner() - compare(result, expected=None) - - compare(m.mock_calls, expected=[ - call.cm.enter(), - call.cm.exit(), - ]) - - def test_clone(self): - m = Mock() - class T1(object): pass - class T2(object): pass - def f1(): m.f1() - def n1(): - m.n1() - return T1(), T2() - def l1(): m.l1() - def t1(obj): m.t1() - def t2(obj): m.t2() - # original - runner1 = Runner() - runner1.add(f1, label='first') - runner1.add(n1, returns=returns(T1, T2), label='normal') - runner1.add(l1, label='last') - runner1.add(t1, requires(T1)) - runner1.add(t2, requires(T2)) - # now clone and add bits - def f2(): m.f2() - def n2(): m.n2() - def l2(): m.l2() - def tn(obj): m.tn() - runner2 = runner1.clone() - runner2['first'].add(f2) - runner2['normal'].add(n2) - runner2['last'].add(l2) - # make sure types stay in order - runner2.add(tn, requires(T2)) - - # now run both, and make sure we only get what we should - - runner1() - verify(runner1, - (f1, {'first'}), - (n1, {'normal'}), - (l1, {'last'}), - (t1, set()), - (t2, set()), - ) - compare([ - call.f1(), - call.n1(), - call.l1(), - call.t1(), - call.t2(), - ], m.mock_calls) - - m.reset_mock() - - runner2() - verify(runner2, - (f1, set()), - (f2, {'first'}), - (n1, set()), - (n2, {'normal'}), - (l1, set()), - (l2, {'last'}), - (t1, set()), - (t2, set()), - (tn, set()), - ) - compare([ - call.f1(), - call.f2(), - call.n1(), - call.n2(), - call.l1(), - call.l2(), - call.t1(), - call.t2(), - call.tn() - ], m.mock_calls) - - def test_clone_end_label(self): - m = Mock() - runner1 = Runner() - runner1.add(m.f1, label='first') - runner1.add(m.f2, label='second') - runner1.add(m.f3, label='third') - - runner2 = runner1.clone(end_label='third') - verify(runner2, - (m.f1, {'first'}), - (m.f2, {'second'}), - ) - - def test_clone_end_label_include(self): - m = Mock() - runner1 = Runner() - runner1.add(m.f1, label='first') - runner1.add(m.f2, label='second') - runner1.add(m.f3, label='third') - - runner2 = runner1.clone(end_label='second', include_end=True) - verify(runner2, - (m.f1, {'first'}), - (m.f2, {'second'}), - ) - - def test_clone_start_label(self): - m = Mock() - runner1 = Runner() - runner1.add(m.f1, label='first') - runner1.add(m.f2, label='second') - runner1.add(m.f3, label='third') - - runner2 = runner1.clone(start_label='first') - verify(runner2, - (m.f2, {'second'}), - (m.f3, {'third'}), - ) - - def test_clone_start_label_include(self): - m = Mock() - runner1 = Runner() - runner1.add(m.f1, label='first') - runner1.add(m.f2, label='second') - runner1.add(m.f3, label='third') - - runner2 = runner1.clone(start_label='second', include_start=True) - verify(runner2, - (m.f2, {'second'}), - (m.f3, {'third'}), - ) - - def test_clone_between(self): - m = Mock() - runner1 = Runner() - runner1.add(m.f1, label='first') - runner1.add(m.f2, label='second') - runner1.add(m.f3, label='third') - runner1.add(m.f4, label='fourth') - - runner2 = runner1.clone(start_label='first', end_label='fourth') - verify(runner2, - (m.f2, {'second'}), - (m.f3, {'third'}), - ) - - def test_clone_between_one_item(self): - m = Mock() - runner1 = Runner() - runner1.add(m.f1, label='first') - runner1.add(m.f2, label='second') - runner1.add(m.f3, label='third') - - runner2 = runner1.clone(start_label='first', end_label='third') - verify(runner2, - (m.f2, {'second'}), - ) - - def test_clone_between_empty(self): - m = Mock() - runner1 = Runner() - runner1.add(m.f1, label='first') - runner1.add(m.f2, label='second') - - runner2 = runner1.clone(start_label='first', end_label='second') - verify(runner2) - - def test_clone_added_using(self): - runner1 = Runner() - m = Mock() - runner1.add(m.f1) - runner1.add(m.f2, label='the_label') - runner1.add(m.f3) - - runner1['the_label'].add(m.f6) - runner1['the_label'].add(m.f7) - - runner2 = runner1.clone(added_using='the_label') - verify(runner2, - (m.f6, set()), - (m.f7, {'the_label'}), - ) - - def test_clone_empty(self): - runner1 = Runner() - runner2 = runner1.clone() - # this gets set by the clone on runner 2, it's a class variable on runner1: - runner1.end = None - compare(expected=runner1, actual=runner2) - - def test_extend(self): - m = Mock() - class T1(object): pass - class T2(object): pass - - t1 = T1() - t2 = T2() - - def job1() -> T1: - m.job1() - return t1 - - def job2(obj: T1) -> T2: - m.job2(obj) - return t2 - - def job3(obj: T2): - m.job3(obj) - - runner = Runner() - runner.extend(job1, job2, job3) - runner() - - compare([ - call.job1(), - call.job2(t1), - call.job3(t2), - ], m.mock_calls) - - def test_addition(self): - m = Mock() - - def job1(): - m.job1() - - def job2(): - m.job2() - - def job3(): - m.job3() - - runner1 = Runner(job1, job2) - runner2 = Runner(job3) - runner = runner1 + runner2 - runner() - - verify(runner, - (job1, set()), - (job2, set()), - (job3, set()), - ) - compare([ - call.job1(), - call.job2(), - call.job3(), - ], m.mock_calls) - - def test_extend_with_runners(self): - m = Mock() - class T1(object): pass - class T2(object): pass - - t1 = T1() - t2 = T2() - - def job1() -> T1: - m.job1() - return t1 - - def job2(obj: T1) -> T2: - m.job2(obj) - return t2 - - def job3(obj: T2): - m.job3(obj) - - runner1 = Runner(job1) - runner2 = Runner(job2) - runner3 = Runner(job3) - - runner = Runner(runner1) - runner.extend(runner2, runner3) - runner() - - verify(runner, - (job1, set()), - (job2, set()), - (job3, set()), - ) - compare([ - call.job1(), - call.job2(t1), - call.job3(t2), - ], m.mock_calls) - - def test_replace_for_testing(self): - m = Mock() - class T1(object): pass - class T2(object): pass - - t1 = T1() - t2 = T2() - - def job1() -> T1: - raise Exception() # pragma: nocover - - def job2(obj: T1) -> T2: - raise Exception() # pragma: nocover - - def job3(obj: T2): - raise Exception() # pragma: nocover - - runner = Runner(job1, job2, job3) - runner.replace(job1, m.job1) - m.job1.return_value = t1 - runner.replace(job2, m.job2, requires_from=original) - m.job2.return_value = t2 - runner.replace(job3, m.job3, requires_from=original) - runner() - - compare([ - call.job1(), - call.job2(t1), - call.job3(t2), - ], m.mock_calls) - - def test_replace_for_behaviour(self): - m = Mock() - class T1(object): pass - class T2(object): pass - class T3(object): pass - class T4(object): pass - - t2 = T2() - - def job0() -> T2: - return t2 - - @requires(T1) - @returns(T3) - def job1(obj): - raise Exception() # pragma: nocover - - job2 = requires(T4)(m.job2) - runner = Runner(job0, job1, job2) - - runner.replace(job1, - requires(T2)(returns(T4)(m.job1)), - returns_from=replacement) - runner() - - compare([ - call.job1(t2), - call.job2(m.job1.return_value), - ], actual=m.mock_calls) - - def test_replace_explicit_requires_returns(self): - m = Mock() - class T1(object): pass - class T2(object): pass - class T3(object): pass - class T4(object): pass - - t2 = T2() - - def job0() -> T2: - return t2 - - @requires(T1) - @returns(T3) - def job1(obj): - raise Exception() # pragma: nocover - - job2 = requires(T4)(m.job2) - runner = Runner(job0, job1, job2) - - runner.replace(job1, requires(T2)(returns(T4)(m.job1)), - returns_from=replacement) - runner() - - compare([ - call.job1(t2), - call.job2(m.job1.return_value), - ], actual=m.mock_calls) - - def test_replace_explicit_with_labels(self): - m = Mock() - - runner = Runner(m.job0) - runner.add_label('foo') - runner['foo'].add(m.job1) - runner['foo'].add(m.job2) - - runner.replace(m.job2, - returns('mock')(m.jobnew), - returns_from=replacement) - - runner() - - compare([ - call.job0(), - call.job1(), - call.jobnew() - ], m.mock_calls) - - # check added_using is handled correctly - m.reset_mock() - runner2 = runner.clone(added_using='foo') - runner2() - - compare([ - call.job1(), - call.jobnew() - ], actual=m.mock_calls) - - # check runner's label pointer is sane - m.reset_mock() - runner['foo'].add(m.job3) - runner() - - compare([ - call.job0(), - call.job1(), - call.jobnew(), - call.job3() - ], actual=m.mock_calls) - - def test_replace_explicit_at_start(self): - m = returns('mock')(Mock()) - runner = Runner(m.job1, m.job2) - - runner.replace(m.job1, m.jobnew, returns_from=replacement) - runner() - - compare([ - call.jobnew(), - call.job2(), - ], actual=m.mock_calls) - - def test_replace_explicit_at_end(self): - m = returns('mock')(Mock()) - runner = Runner(m.job1, m.job2) - - runner.replace(m.job2, m.jobnew, returns_from=replacement) - runner.add(m.jobnew2) - runner() - - compare([ - call.job1(), - call.jobnew(), - call.jobnew2(), - ], actual=m.mock_calls) - - def test_replace_keep_explicit_requires(self): - def foo(): - return 'bar' - def barbar(sheep): - return sheep*2 - - runner = Runner() - runner.add(foo, returns='flossy') - runner.add(barbar, requires='flossy') - compare(runner(), expected='barbar') - - runner.replace(barbar, lambda dog: None, requires_from=original) - compare(runner(), expected=None) - - def test_replace_keep_explicit_returns(self): - def foo(): - return 'bar' - def barbar(sheep): - return sheep*2 - - runner = Runner() - runner.add(foo, returns='flossy') - runner.add(barbar, requires='flossy') - compare(runner(), expected='barbar') - - runner.replace(foo, lambda: 'woof') - compare(runner(), expected='woofwoof') - - def test_modifier_changes_endpoint(self): - m = Mock() - runner = Runner(m.job1) - compare(runner.end.obj, m.job1) - verify(runner, - (m.job1, set()), - ) - - mod = runner.add(m.job2, label='foo') - compare(runner.end.obj, m.job2) - verify(runner, - (m.job1, set()), - (m.job2, {'foo'}), - ) - - mod.add(m.job3) - compare(runner.end.obj, m.job3) - compare(runner.end.labels, {'foo'}) - verify(runner, - (m.job1, set()), - (m.job2, set()), - (m.job3, {'foo'}), - ) - - runner.add(m.job4) - compare(runner.end.obj, m.job4) - compare(runner.end.labels, set()) - verify(runner, - (m.job1, set()), - (m.job2, set()), - (m.job3, {'foo'}), - (m.job4, set()), - ) - - def test_duplicate_label_runner_add(self): - m = Mock() - runner = Runner() - runner.add(m.job1, label='label') - runner.add(m.job2) - with ShouldRaise(ValueError( - "'label' already points to "+repr(m.job1)+" requires() " - "returns() <-- label" - )): - runner.add(m.job3, label='label') - verify(runner, - (m.job1, {'label'}), - (m.job2, set()), - ) - - def test_duplicate_label_runner_next_add(self): - m = Mock() - runner = Runner() - runner.add(m.job1, label='label') - with ShouldRaise(ValueError( - "'label' already points to "+repr(m.job1)+" requires() " - "returns() <-- label" - )): - runner.add(m.job2, label='label') - verify(runner, - (m.job1, {'label'}), - ) - - def test_duplicate_label_modifier(self): - m = Mock() - runner = Runner() - runner.add(m.job1, label='label1') - mod = runner['label1'] - mod.add(m.job2, label='label2') - with ShouldRaise(ValueError( - "'label1' already points to "+repr(m.job1)+" requires() " - "returns() <-- label1" - )): - mod.add(m.job3, label='label1') - verify(runner, - (m.job1, {'label1'}), - (m.job2, {'label2'}), - ) - - def test_repr(self): - class T1: pass - class T2: pass - m = Mock() - runner = Runner() - runner.add(m.job1, label='label1') - runner.add(m.job2, requires('foo', T1), returns(T2), label='label2') - runner.add(m.job3) - - t1_str = 'TestRunner.test_repr..T1' - t2_str = 'TestRunner.test_repr..T2' - - compare('\n'.join(( - '', - f' {m.job1!r} requires() returns() <-- label1', - f" {m.job2!r} requires(Value('foo'), Value({t1_str})) returns({t2_str}) <-- label2", - f' {m.job3!r} requires() returns()', - '' - - )), repr(runner)) - - def test_repr_empty(self): - compare('', repr(Runner())) - - def test_passed_in_context_with_no_point(self): - context = Context() - def foo(): - return 42 - runner = Runner(foo) - compare(runner(context), expected=42) - - def test_default_requirement(self): - - class FromRequest(Requirement): - - def __init__(self, name, type_, default): - keys = [ResourceKey(None, 'request')] - super().__init__(keys, default) - self.ops.append(ItemOp(name)) - - def foo(bar): - return bar - - context = Context(default_requirement=FromRequest) - context.add({'bar': 'foo'}, identifier='request') - - runner = Runner() - runner.add(foo) - compare(runner(context), expected='foo') diff --git a/mush/typing.py b/mush/typing.py deleted file mode 100644 index 0c80e3b..0000000 --- a/mush/typing.py +++ /dev/null @@ -1,27 +0,0 @@ -from typing import ( - NewType, Union, Hashable, Any, TYPE_CHECKING, List, Tuple, Type, - Callable, Optional -) - -from .compat import _GenericAlias - -if TYPE_CHECKING: - from .declarations import RequirementsDeclaration, ReturnsDeclaration - from .requirements import Requirement - -Type_ = Union[type, Type, _GenericAlias] -Identifier = Hashable - -RequirementType = Union['Requirement', Type_, Identifier] -Requires = Union['RequirementDeclaraction', - RequirementType, - List[RequirementType], - Tuple[RequirementType, ...]] - -ReturnType = Union[Type_, str] -Returns = Union['ReturnsDeclaration', ReturnType] - -Resource = NewType('Resource', Any) - - -DefaultRequirement = Callable[[str, Optional[Type], Any], 'Requirement'] From e287270e2eda84e4dcf53cfc90d008590a7b92c4 Mon Sep 17 00:00:00 2001 From: Chris Withers Date: Wed, 31 Mar 2021 10:09:56 +0100 Subject: [PATCH 158/159] Initial Context and Paradigm(s) implementation. --- mush/__init__.py | 1 + mush/context.py | 22 ++++++++++ mush/paradigms/__init__.py | 14 +++++++ mush/paradigms/asyncio_.py | 28 +++++++++++++ mush/paradigms/normal_.py | 19 +++++++++ mush/paradigms/paradigm.py | 27 +++++++++++++ mush/paradigms/paradigms.py | 54 +++++++++++++++++++++++++ mush/typing.py | 7 ++++ tests/__init__.py | 0 tests/test_context.py | 13 ++++++ tests/test_paradigm_asyncio.py | 73 ++++++++++++++++++++++++++++++++++ tests/test_paradigm_normal.py | 45 +++++++++++++++++++++ tests/test_paradigms.py | 63 +++++++++++++++++++++++++++++ 13 files changed, 366 insertions(+) create mode 100644 mush/__init__.py create mode 100644 mush/context.py create mode 100644 mush/paradigms/__init__.py create mode 100644 mush/paradigms/asyncio_.py create mode 100644 mush/paradigms/normal_.py create mode 100644 mush/paradigms/paradigm.py create mode 100644 mush/paradigms/paradigms.py create mode 100644 mush/typing.py create mode 100644 tests/__init__.py create mode 100644 tests/test_context.py create mode 100644 tests/test_paradigm_asyncio.py create mode 100644 tests/test_paradigm_normal.py create mode 100644 tests/test_paradigms.py diff --git a/mush/__init__.py b/mush/__init__.py new file mode 100644 index 0000000..0cebf2b --- /dev/null +++ b/mush/__init__.py @@ -0,0 +1 @@ +from .context import Context diff --git a/mush/context.py b/mush/context.py new file mode 100644 index 0000000..aa06ef4 --- /dev/null +++ b/mush/context.py @@ -0,0 +1,22 @@ +from typing import Callable + +from .paradigms import Call, Paradigm, Paradigms, paradigms +from .typing import Calls + + +class Context: + + paradigms: Paradigms = paradigms + + def __init__(self, paradigm: Paradigm = None): + self.paradigm = paradigm + + def _resolve(self, obj: Callable, target_paradigm: Paradigm) -> Calls: + caller = self.paradigms.find_caller(obj, target_paradigm) + yield Call(caller, (), {}) + + def call(self, obj: Callable, *, paradigm: Paradigm = None): + paradigm = paradigm or self.paradigm or self.paradigms.find_paradigm(obj) + calls = self._resolve(obj, paradigm) + return paradigm.process(calls) + diff --git a/mush/paradigms/__init__.py b/mush/paradigms/__init__.py new file mode 100644 index 0000000..bd0cc17 --- /dev/null +++ b/mush/paradigms/__init__.py @@ -0,0 +1,14 @@ +from collections import namedtuple + +from .paradigm import Paradigm +from .paradigms import Paradigms + + +Call = namedtuple('Call', ('obj', 'args', 'kw')) + +paradigms = Paradigms() + +normal = paradigms.register_if_possible('mush.paradigms.normal_', 'Normal') +asyncio = paradigms.register_if_possible('mush.paradigms.asyncio_', 'AsyncIO') + +paradigms.add_shifter_if_possible(normal, asyncio, 'mush.paradigms.asyncio_', 'asyncio_to_normal') diff --git a/mush/paradigms/asyncio_.py b/mush/paradigms/asyncio_.py new file mode 100644 index 0000000..4a49c36 --- /dev/null +++ b/mush/paradigms/asyncio_.py @@ -0,0 +1,28 @@ +import asyncio +from functools import partial +from typing import Callable + +from .paradigm import Paradigm +from ..typing import Calls + + +class AsyncIO(Paradigm): + + def claim(self, obj: Callable) -> bool: + if asyncio.iscoroutinefunction(obj): + return True + + async def process(self, calls: Calls): + call = next(calls) + try: + while True: + result = await call.obj(*call.args, **call.kw) + call = calls.send(result) + except StopIteration: + return result + + +async def asyncio_to_normal(obj, *args, **kw): + loop = asyncio.get_event_loop() + obj_ = partial(obj, *args, **kw) + return await loop.run_in_executor(None, obj_) diff --git a/mush/paradigms/normal_.py b/mush/paradigms/normal_.py new file mode 100644 index 0000000..5a30f17 --- /dev/null +++ b/mush/paradigms/normal_.py @@ -0,0 +1,19 @@ +from typing import Callable + +from .paradigm import Paradigm +from ..typing import Calls + + +class Normal(Paradigm): + + def claim(self, obj: Callable) -> bool: + return True + + def process(self, calls: Calls): + call = next(calls) + try: + while True: + result = call.obj(*call.args, **call.kw) + call = calls.send(result) + except StopIteration: + return result diff --git a/mush/paradigms/paradigm.py b/mush/paradigms/paradigm.py new file mode 100644 index 0000000..3e5dd18 --- /dev/null +++ b/mush/paradigms/paradigm.py @@ -0,0 +1,27 @@ +from abc import ABC, abstractmethod +from typing import Callable +from ..typing import Calls + + +class Paradigm(ABC): + + @abstractmethod + def claim(self, obj: Callable) -> bool: + ... + + @abstractmethod + def process(self, calls: Calls): + ... + + +class MissingParadigm(Paradigm): + + def __init__(self, exception): + super().__init__() + self.exception = exception + + def claim(self, obj: Callable) -> bool: + raise self.exception + + def process(self, calls: Calls): + raise self.exception diff --git a/mush/paradigms/paradigms.py b/mush/paradigms/paradigms.py new file mode 100644 index 0000000..193d487 --- /dev/null +++ b/mush/paradigms/paradigms.py @@ -0,0 +1,54 @@ +from functools import partial +from importlib import import_module +from typing import Callable, List, Optional, Dict, Tuple + +from .paradigm import Paradigm, MissingParadigm + + + +def missing_shifter(exception, obj): + raise exception + + +class Paradigms: + + def __init__(self): + self._paradigms: List[Paradigm] = [] + self._shifters: Dict[Tuple['Paradigm', 'Paradigm'], Callable] = {} + + def register(self, paradigm: Paradigm) -> None: + self._paradigms.insert(0, paradigm) + + def register_if_possible(self, module_path: str, class_name: str) -> Paradigm: + try: + module = import_module(module_path) + except ModuleNotFoundError as e: + paradigm = MissingParadigm(e) + else: + paradigm = getattr(module, class_name)() + self.register(paradigm) + return paradigm + + def add_shifter_if_possible( + self, source: Paradigm, target: Paradigm, module_path: str, callable_name: str + ) -> None: + try: + module = import_module(module_path) + except ModuleNotFoundError as e: + shifter = partial(missing_shifter, e) + else: + shifter = getattr(module, callable_name) + self._shifters[source, target] = shifter + + def find_paradigm(self, obj: Callable) -> Paradigm: + for paradigm in self._paradigms: + if paradigm.claim(obj): + return paradigm + raise Exception('No paradigm') + + def find_caller(self, obj: Callable, target_paradigm: Paradigm) -> Callable: + source_paradigm = self.find_paradigm(obj) + if source_paradigm is target_paradigm: + return obj + else: + return partial(self._shifters[source_paradigm, target_paradigm], obj) diff --git a/mush/typing.py b/mush/typing.py new file mode 100644 index 0000000..95d2f9f --- /dev/null +++ b/mush/typing.py @@ -0,0 +1,7 @@ +from typing import Generator, Any, TYPE_CHECKING + +if TYPE_CHECKING: + from .paradigms import Call + + +Calls = Generator['Call', Any, None] diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_context.py b/tests/test_context.py new file mode 100644 index 0000000..640cfd1 --- /dev/null +++ b/tests/test_context.py @@ -0,0 +1,13 @@ +from testfixtures import compare + +from mush.context import Context + + +class TestCall: + + def test_no_params(self): + def foo(): + return 'bar' + context = Context() + result = context.call(foo) + compare(result, 'bar') diff --git a/tests/test_paradigm_asyncio.py b/tests/test_paradigm_asyncio.py new file mode 100644 index 0000000..c405e79 --- /dev/null +++ b/tests/test_paradigm_asyncio.py @@ -0,0 +1,73 @@ +import asyncio +from contextlib import contextmanager +from functools import partial +from unittest.mock import Mock + +import pytest +from testfixtures import compare + +from mush import Context +from mush.paradigms import Call +from mush import paradigms +from mush.paradigms.asyncio_ import AsyncIO + + +@contextmanager +def no_threads(): + loop = asyncio.get_event_loop() + original = loop.run_in_executor + loop.run_in_executor = Mock(side_effect=Exception('threads used when they should not be')) + try: + yield + finally: + loop.run_in_executor = original + + +@contextmanager +def must_run_in_thread(*expected): + seen = set() + loop = asyncio.get_event_loop() + original = loop.run_in_executor + + def recording_run_in_executor(executor, func, *args): + if isinstance(func, partial): + to_record = func.func + else: + # get the underlying method for bound methods: + to_record = getattr(func, '__func__', func) + seen.add(to_record) + return original(executor, func, *args) + + loop.run_in_executor = recording_run_in_executor + try: + yield + finally: + loop.run_in_executor = original + + not_seen = set(expected) - seen + assert not not_seen, f'{not_seen} not run in a thread, seen: {seen}' + + +class TestContext: + + @pytest.mark.asyncio + async def test_call_is_async(self): + context = Context(paradigm=paradigms.asyncio) + + def it(): + return 'bar' + + result = context.call(it) + assert asyncio.iscoroutine(result) + with must_run_in_thread(it): + compare(await result, expected='bar') + + @pytest.mark.asyncio + async def test_call_async(self): + context = Context() + + async def it(): + return 'bar' + + with no_threads(): + compare(await context.call(it), expected='bar') diff --git a/tests/test_paradigm_normal.py b/tests/test_paradigm_normal.py new file mode 100644 index 0000000..0767bd5 --- /dev/null +++ b/tests/test_paradigm_normal.py @@ -0,0 +1,45 @@ +from unittest.mock import Mock + +from testfixtures import compare + +from mush.paradigms import Call +from mush.paradigms.normal_ import Normal + + +class TestParadigm: + + def test_claim(self): + # Since this is the "backstop" paradigm, it always claims things + p = Normal() + assert p.claim(lambda x: None) + + def test_process_single(self): + obj = Mock() + + def calls(): + yield Call(obj, ('a',), {'b': 'c'}) + + p = Normal() + + compare(p.process(calls()), expected=obj.return_value) + + obj.assert_called_with('a', b='c') + + def test_process_multiple(self): + mocks = Mock() + + results = [] + + def calls(): + results.append((yield Call(mocks.obj1, ('a',), {}))) + results.append((yield Call(mocks.obj2, ('b',), {}))) + yield Call(mocks.obj3, ('c',), {}) + + p = Normal() + + compare(p.process(calls()), expected=mocks.obj3.return_value) + + compare(results, expected=[ + mocks.obj1.return_value, + mocks.obj2.return_value, + ]) diff --git a/tests/test_paradigms.py b/tests/test_paradigms.py new file mode 100644 index 0000000..ba3fc5b --- /dev/null +++ b/tests/test_paradigms.py @@ -0,0 +1,63 @@ +from testfixtures import ShouldRaise +from testfixtures.mock import Mock + +from mush.paradigms import Paradigms + + +class TestCollection: + + def test_register_not_importable(self): + p = Paradigms() + obj = p.register_if_possible('mush.badname', 'ParadigmClass') + + with ShouldRaise(ModuleNotFoundError("No module named 'mush.badname'")): + obj.claim(lambda: None) + + with ShouldRaise(ModuleNotFoundError("No module named 'mush.badname'")): + obj.process((o for o in [])) + + def test_register_class_missing(self): + p = Paradigms() + with ShouldRaise(AttributeError( + "module 'mush.paradigms.normal_' has no attribute 'BadName'" + )): + p.register_if_possible('mush.paradigms.normal_', 'BadName') + + def test_shifter_not_importable(self): + p1 = Mock() + p2 = Mock() + p = Paradigms() + p.register(p1) + p.add_shifter_if_possible(p1, p2, 'mush.badname', 'shifter') + + caller = p.find_caller(lambda: None, target_paradigm=p2) + with ShouldRaise(ModuleNotFoundError("No module named 'mush.badname'")): + caller() + + def test_shifter_callable_missing(self): + p = Paradigms() + with ShouldRaise(AttributeError( + "module 'mush.paradigms.normal_' has no attribute 'bad_name'" + )): + p.add_shifter_if_possible(Mock(), Mock(), 'mush.paradigms.normal_', 'bad_name') + + def test_find_paradigm(self): + p1 = Mock() + p2 = Mock() + p = Paradigms() + p.register(p1) + p.register(p2) + + assert p.find_paradigm(lambda: None) is p2 + + p2.claim.return_value = False + + assert p.find_paradigm(lambda: None) is p1 + + def test_no_paradigm_claimed(self): + p_ = Mock() + p_.claim.return_value = False + p = Paradigms() + p.register(p_) + with ShouldRaise(Exception('No paradigm')): + p.find_paradigm(lambda: None) From 4e883d99e171bac6ec58b66d80824242d7f5c024 Mon Sep 17 00:00:00 2001 From: Chris Withers Date: Wed, 31 Mar 2021 10:10:05 +0100 Subject: [PATCH 159/159] adjust coverage config --- .coveragerc | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.coveragerc b/.coveragerc index 1afc40e..23d833b 100644 --- a/.coveragerc +++ b/.coveragerc @@ -1,5 +1,5 @@ [run] -source = mush +source = mush,tests [report] exclude_lines = @@ -9,6 +9,7 @@ exclude_lines = # stuff that we don't worry about pass + \.\.\. __name__ == '__main__' # circular references needed for type checking: