Skip to content

Commit

Permalink
Merge pull request #21 from stealthrocket/experimental-tidy
Browse files Browse the repository at this point in the history
Consistent formatting / coverage across repo
  • Loading branch information
chriso authored Feb 3, 2024
2 parents cbe54c6 + 8cdc60b commit ebe60d9
Show file tree
Hide file tree
Showing 15 changed files with 613 additions and 141 deletions.
30 changes: 30 additions & 0 deletions src/dispatch/experimental/durable/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,33 @@
""""A decorator that makes generators serializable.
This module defines a @durable decorator that can be applied to generator
functions. The resulting generators can be pickled.
Example usage:
import pickle
from dispatch.experimental.durable import durable
@durable
def my_generator():
for i in range(3):
yield i
# Run the generator to its first yield point:
g = my_generator()
print(next(g)) # 0
# Make a copy, and consume the remaining items:
b = pickle.dumps(g)
g2 = pickle.loads(b)
print(next(g2)) # 1
print(next(g2)) # 2
# The original is not affected:
print(next(g)) # 1
print(next(g)) # 2
"""

from .durable import durable

__all__ = ["durable"]
30 changes: 20 additions & 10 deletions src/dispatch/experimental/durable/durable.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,22 +4,32 @@
from .registry import register_function


def durable(fn):
"""A decorator for a generator that makes it pickle-able."""
return DurableFunction(fn)


class DurableFunction:
"""A durable generator function that can be pickled."""
"""A wrapper for a generator function that wraps its generator instances
with a DurableGenerator.
Attributes:
fn: A generator function.
key: A key that uniquely identifies the function.
"""

def __init__(self, fn: FunctionType):
self.fn = fn
self.key = register_function(fn)

def __call__(self, *args, **kwargs):
result = self.fn(*args, **kwargs)
if isinstance(result, GeneratorType):
return DurableGenerator(result, self.key, args, kwargs)
if not isinstance(result, GeneratorType):
raise NotImplementedError(
"only synchronous generator functions are supported"
)
return DurableGenerator(result, self.key, args, kwargs)

# TODO: support native coroutines
raise NotImplementedError

def durable(fn) -> DurableFunction:
"""Returns a "durable" function that creates serializable generators.
Args:
fn: A generator function.
"""
return DurableFunction(fn)
32 changes: 24 additions & 8 deletions src/dispatch/experimental/durable/generator.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from types import CodeType, FrameType, GeneratorType, TracebackType
from typing import Generator, TypeVar
from typing import Any, Generator, TypeVar

from . import frame as ext
from .registry import lookup_function
Expand All @@ -10,13 +10,29 @@


class DurableGenerator(Generator[_YieldT, _SendT, _ReturnT]):
"""A generator that can be pickled."""

def __init__(self, gen: GeneratorType, key, args, kwargs):
self.generator = gen

# Capture the information necessary to be able to create a
# new instance of the generator.
"""A wrapper for a generator that makes it serializable (can be pickled).
Instances behave like the generators they wrap.
Attributes:
generator: The wrapped generator.
key: A unique identifier for the function that created this generator.
args: Positional arguments to the function that created this generator.
kwargs: Keyword arguments to the function that created this generator.
"""

generator: GeneratorType
key: str
args: list[Any]
kwargs: dict[str, Any]

def __init__(
self,
generator: GeneratorType,
key: str,
args: list[Any],
kwargs: dict[str, Any],
):
self.generator = generator
self.key = key
self.args = args
self.kwargs = kwargs
Expand Down
23 changes: 22 additions & 1 deletion src/dispatch/experimental/durable/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,17 @@


def register_function(fn: FunctionType) -> str:
"""Register a generator function.
Args:
fn: The function to register.
Returns:
str: Unique identifier for the function.
Raises:
ValueError: The function conflicts with another registered function.
"""
# We need to be able to refer to the function in the serialized
# representation, and the key needs to be stable across interpreter
# invocations. Use the code object's fully-qualified name for now.
Expand All @@ -15,9 +26,19 @@ def register_function(fn: FunctionType) -> str:
raise ValueError(f"durable function already registered with key {key}")

_REGISTRY[key] = fn

return key


def lookup_function(key: str) -> FunctionType:
"""Lookup a previously registered function.
Args:
key: Unique identifier for the function.
Returns:
FunctionType: The associated function.
Raises:
KeyError: A function has not been registered with this key.
"""
return _REGISTRY[key]
10 changes: 8 additions & 2 deletions src/dispatch/experimental/multicolor/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,10 @@
from .compile import compile_function
from .compile import NoSourceError, compile_function
from .yields import CustomYield, GeneratorYield, yields

__all__ = ["compile_function", "yields", "CustomYield", "GeneratorYield"]
__all__ = [
"compile_function",
"yields",
"CustomYield",
"GeneratorYield",
"NoSourceError",
]
102 changes: 62 additions & 40 deletions src/dispatch/experimental/multicolor/compile.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
import ast
import inspect
import os
import types
import textwrap
from enum import Enum
from types import FunctionType, GeneratorType, MethodType
from typing import cast

from .desugar import desugar_function
from .generator import empty_generator, is_generator
from .parse import NoSourceError, parse_function, repair_indentation
from .parse import NoSourceError, parse_function
from .template import rewrite_template
from .yields import CustomYield, GeneratorYield

Expand All @@ -19,7 +19,7 @@ def compile_function(
fn: FunctionType, decorator: FunctionType | None = None, cache_key: str = "default"
) -> FunctionType | MethodType:
"""Compile a regular function into a generator that yields data passed
to functions marked with the @multicolor.yields decorator. Decorated
to functions marked with the @multicolor.yields decorator. Decorated yield
functions can be called from anywhere in the call stack, and functions
in between do not have to be generators or async functions (coroutines).
Expand All @@ -29,7 +29,7 @@ def compile_function(
def sleep(seconds): ...
def parent():
sleep(3) # yield point!
sleep(3) # yield point
def grandparent():
parent()
Expand All @@ -42,7 +42,7 @@ def grandparent():
Two-way data flow works as expected. At a yield point, generator.send(value)
can be used to send data back to the yield point and to resume execution.
The data sent back will be the return value of the function decorated with
@multicolor.yields.
@multicolor.yields:
@multicolor.yields(type="add")
def add(a: int, b: int) -> int:
Expand Down Expand Up @@ -77,17 +77,39 @@ def adder(a: int, b: int) -> int:
The default implementation could also raise an error, to ensure that
the function is only ever called from a compiled function.
fn: FunctionType, decorator: FunctionType | None = None, cache_key: str = "default"
Args:
fn: The function to compile.
decorator: An optional decorator to apply to the compiled function.
cache_key: Cache key to use when caching compiled functions.
Returns:
FunctionType: A compiled generator function.
"""
compiled_fn, _ = compile_internal(fn, decorator, cache_key)
compiled_fn, _ = _compile_internal(fn, decorator, cache_key)
return compiled_fn


class FunctionColor(Enum):
"""Color (aka. type/flavor) of a function.
There are four colors of functions in Python:
* regular (e.g. def fn(): pass)
* generator (e.g. def fn(): yield)
* async (e.g. async def fn(): pass)
* async generator (e.g. async def fn(): yield)
Only the first two colors are supported at this time.
"""

REGULAR_FUNCTION = 0
GENERATOR_FUNCTION = 1


def compile_internal(
def _compile_internal(
fn: FunctionType, decorator: FunctionType | None, cache_key: str
) -> tuple[FunctionType | MethodType, FunctionColor]:
if hasattr(fn, "_multicolor_yield_type"):
Expand All @@ -99,7 +121,7 @@ def compile_internal(
# Check if the function has already been compiled.
cache_holder = fn
if isinstance(fn, MethodType):
cache_holder = fn.__self__
cache_holder = fn.__self__.__class__
if hasattr(cache_holder, "_multicolor_cache"):
try:
compiled_fn, color = cache_holder._multicolor_cache[fn_name]
Expand All @@ -121,6 +143,8 @@ def compile_internal(
return fn, FunctionColor.GENERATOR_FUNCTION
except TypeError:
raise e
else:
raise

# Determine what type of function we're working with.
color = FunctionColor.REGULAR_FUNCTION
Expand All @@ -130,7 +154,7 @@ def compile_internal(
if TRACE:
print("\n-------------------------------------------------")
print("[MULTICOLOR] COMPILING:")
print(repair_indentation(inspect.getsource(fn)).rstrip())
print(textwrap.dedent(inspect.getsource(fn)).rstrip())

fn_def.name = fn_name

Expand Down Expand Up @@ -171,7 +195,7 @@ def compile_internal(
namespace["_multicolor_no_source_error"] = NoSourceError
namespace["_multicolor_custom_yield"] = CustomYield
namespace["_multicolor_generator_yield"] = GeneratorYield
namespace["_multicolor_compile"] = compile_internal
namespace["_multicolor_compile"] = _compile_internal
namespace["_multicolor_generator_type"] = GeneratorType
namespace["_multicolor_decorator"] = decorator
namespace["_multicolor_cache_key"] = cache_key
Expand Down Expand Up @@ -269,38 +293,36 @@ def _build_call_gadget(

result = rewrite_template(
"""
if hasattr(__fn__, "_multicolor_yield_type"):
_multicolor_result = yield _multicolor_custom_yield(type=__fn__._multicolor_yield_type, args=__args__, kwargs=__kwargs__)
__assign_result__
else:
_multicolor_result = None
try:
__compiled_fn__, _multicolor_color = _multicolor_compile(__fn__, _multicolor_decorator, _multicolor_cache_key)
except _multicolor_no_source_error:
_multicolor_result = __fn_call__
else:
_multicolor_generator = __compiled_fn_call__
if _multicolor_color == _multicolor_generator_color:
_multicolor_result = []
for _multicolor_yield in _multicolor_generator:
if isinstance(_multicolor_yield, _multicolor_generator_yield):
_multicolor_result.append(_multicolor_yield.value)
if hasattr(__fn__, "_multicolor_yield_type"):
_multicolor_result = yield _multicolor_custom_yield(type=__fn__._multicolor_yield_type, args=__args__, kwargs=__kwargs__)
__assign_result__
else:
_multicolor_result = None
try:
__compiled_fn__, _multicolor_color = _multicolor_compile(__fn__, _multicolor_decorator, _multicolor_cache_key)
except _multicolor_no_source_error:
_multicolor_result = __fn_call__
else:
yield _multicolor_yield
else:
_multicolor_result = yield from _multicolor_generator
finally:
__assign_result__
_multicolor_generator = __compiled_fn_call__
if _multicolor_color == _multicolor_generator_color:
_multicolor_result = []
for _multicolor_yield in _multicolor_generator:
if isinstance(_multicolor_yield, _multicolor_generator_yield):
_multicolor_result.append(_multicolor_yield.value)
else:
yield _multicolor_yield
else:
_multicolor_result = yield from _multicolor_generator
finally:
__assign_result__
""",
expressions=dict(
__fn__=fn,
__fn_call__=fn_call,
__args__=args,
__kwargs__=kwargs,
__compiled_fn__=compiled_fn,
__compiled_fn_call__=compiled_fn_call,
),
statements=dict(__assign_result__=assign_result),
__fn__=fn,
__fn_call__=fn_call,
__args__=args,
__kwargs__=kwargs,
__compiled_fn__=compiled_fn,
__compiled_fn_call__=compiled_fn_call,
__assign_result__=assign_result,
)

return result[0]
Loading

0 comments on commit ebe60d9

Please sign in to comment.