Skip to content

Commit

Permalink
Add AnyOf requirement.
Browse files Browse the repository at this point in the history
  • Loading branch information
cjw296 committed Mar 10, 2020
1 parent 3284664 commit 0971763
Show file tree
Hide file tree
Showing 3 changed files with 72 additions and 3 deletions.
17 changes: 17 additions & 0 deletions mush/requirements.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,3 +175,20 @@ def resolve(self, context):
if self.cache:
context.add(result, provides=self.key)
return result


class AnyOf(Requirement):
"""
A requirement that is resolved by any of the specified keys.
"""

def __init__(self, *keys, default=missing):
super().__init__(keys, default=default)

@nonblocking
def resolve(self, context: 'Context'):
for key in self.key:
value = context.get(key, missing)
if value is not missing:
return value
return self.default
14 changes: 13 additions & 1 deletion mush/tests/test_async_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from mush import Context, Value, requires, returns
from mush.asyncio import Context
from mush.declarations import RequiresType
from mush.requirements import Requirement
from mush.requirements import Requirement, AnyOf
from testfixtures import compare

from mush.tests.test_context import TheType
Expand Down Expand Up @@ -175,6 +175,18 @@ async def it(baz):
compare(await context.call(it), expected='foobar')


@pytest.mark.asyncio
async def test_anyof_resolve_does_not_run_in_thread(no_threads):
with no_threads:
context = Context()
context.add(('foo', ))

async def bob(x: str = AnyOf(tuple, Tuple[str])):
return x[0]

compare(await context.call(bob), expected='foo')


@pytest.mark.asyncio
async def test_custom_requirement_async_resolve():

Expand Down
44 changes: 42 additions & 2 deletions mush/tests/test_requirements.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from typing import Tuple
from unittest.case import TestCase

import pytest
from mock import Mock
from testfixtures import compare, ShouldRaise

from mush import Context, Call, Value, missing, requires
from mush.requirements import Requirement, AttrOp, ItemOp
from mush import Context, Call, Value, missing, requires, ResourceError
from mush.requirements import Requirement, AttrOp, ItemOp, AnyOf
from .helpers import Type1


Expand Down Expand Up @@ -182,3 +183,42 @@ def bob(x: str = Call(foo)['a']):
return x+'c'

compare(context.call(bob), expected='bc')


class TestAnyOf:

def test_first(self):
context = Context()
context.add(('foo', ))
context.add(('bar', ), provides=Tuple[str])

def bob(x: str = AnyOf(tuple, Tuple[str])):
return x[0]

compare(context.call(bob), expected='foo')

def test_second(self):
context = Context()
context.add(('bar', ), provides=Tuple[str])

def bob(x: str = AnyOf(tuple, Tuple[str])):
return x[0]

compare(context.call(bob), expected='bar')

def test_none(self):
context = Context()

def bob(x: str = AnyOf(tuple, Tuple[str])):
pass

with ShouldRaise(ResourceError):
context.call(bob)

def test_default(self):
context = Context()

def bob(x: str = AnyOf(tuple, Tuple[str], default=(42,))):
return x[0]

compare(context.call(bob), expected=42)

0 comments on commit 0971763

Please sign in to comment.