From 0971763624a9b945921d63d8f6800c566db16b33 Mon Sep 17 00:00:00 2001 From: Chris Withers Date: Tue, 10 Mar 2020 11:54:16 +0000 Subject: [PATCH] Add AnyOf requirement. --- mush/requirements.py | 17 ++++++++++++ mush/tests/test_async_context.py | 14 +++++++++- mush/tests/test_requirements.py | 44 ++++++++++++++++++++++++++++++-- 3 files changed, 72 insertions(+), 3 deletions(-) diff --git a/mush/requirements.py b/mush/requirements.py index 810fc6b..6e85b86 100644 --- a/mush/requirements.py +++ b/mush/requirements.py @@ -175,3 +175,20 @@ def resolve(self, context): if self.cache: context.add(result, provides=self.key) return result + + +class AnyOf(Requirement): + """ + A requirement that is resolved by any of the specified keys. + """ + + def __init__(self, *keys, default=missing): + super().__init__(keys, default=default) + + @nonblocking + def resolve(self, context: 'Context'): + for key in self.key: + value = context.get(key, missing) + if value is not missing: + return value + return self.default diff --git a/mush/tests/test_async_context.py b/mush/tests/test_async_context.py index 26cdcd1..0206a92 100644 --- a/mush/tests/test_async_context.py +++ b/mush/tests/test_async_context.py @@ -8,7 +8,7 @@ from mush import Context, Value, requires, returns from mush.asyncio import Context from mush.declarations import RequiresType -from mush.requirements import Requirement +from mush.requirements import Requirement, AnyOf from testfixtures import compare from mush.tests.test_context import TheType @@ -175,6 +175,18 @@ async def it(baz): compare(await context.call(it), expected='foobar') +@pytest.mark.asyncio +async def test_anyof_resolve_does_not_run_in_thread(no_threads): + with no_threads: + context = Context() + context.add(('foo', )) + + async def bob(x: str = AnyOf(tuple, Tuple[str])): + return x[0] + + compare(await context.call(bob), expected='foo') + + @pytest.mark.asyncio async def test_custom_requirement_async_resolve(): diff --git a/mush/tests/test_requirements.py b/mush/tests/test_requirements.py index 2f60268..2493521 100644 --- a/mush/tests/test_requirements.py +++ b/mush/tests/test_requirements.py @@ -1,11 +1,12 @@ +from typing import Tuple from unittest.case import TestCase import pytest from mock import Mock from testfixtures import compare, ShouldRaise -from mush import Context, Call, Value, missing, requires -from mush.requirements import Requirement, AttrOp, ItemOp +from mush import Context, Call, Value, missing, requires, ResourceError +from mush.requirements import Requirement, AttrOp, ItemOp, AnyOf from .helpers import Type1 @@ -182,3 +183,42 @@ def bob(x: str = Call(foo)['a']): return x+'c' compare(context.call(bob), expected='bc') + + +class TestAnyOf: + + def test_first(self): + context = Context() + context.add(('foo', )) + context.add(('bar', ), provides=Tuple[str]) + + def bob(x: str = AnyOf(tuple, Tuple[str])): + return x[0] + + compare(context.call(bob), expected='foo') + + def test_second(self): + context = Context() + context.add(('bar', ), provides=Tuple[str]) + + def bob(x: str = AnyOf(tuple, Tuple[str])): + return x[0] + + compare(context.call(bob), expected='bar') + + def test_none(self): + context = Context() + + def bob(x: str = AnyOf(tuple, Tuple[str])): + pass + + with ShouldRaise(ResourceError): + context.call(bob) + + def test_default(self): + context = Context() + + def bob(x: str = AnyOf(tuple, Tuple[str], default=(42,))): + return x[0] + + compare(context.call(bob), expected=42)