Skip to content

Commit

Permalink
First attempt
Browse files Browse the repository at this point in the history
  • Loading branch information
chriso committed Feb 10, 2024
1 parent 43402f7 commit 20b88a3
Show file tree
Hide file tree
Showing 2 changed files with 98 additions and 24 deletions.
101 changes: 77 additions & 24 deletions src/dispatch/experimental/multicolor/compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@

TRACE = os.getenv("MULTICOLOR_TRACE", False)


logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -106,8 +105,8 @@ class FunctionColor(Enum):
Only the first two colors are supported at this time.
"""

REGULAR_FUNCTION = 0
GENERATOR_FUNCTION = 1
REGULAR = 0
GENERATOR = 1


def _compile_internal(
Expand All @@ -116,8 +115,6 @@ def _compile_internal(
if hasattr(fn, "_multicolor_yield_type"):
raise ValueError("cannot compile a yield point directly")

logger.debug("compiling function %s", fn.__name__)

# Give the function a unique name.
fn_name = fn.__name__ + "__multicolor_" + cache_key

Expand All @@ -143,20 +140,23 @@ def _compile_internal(
# This can occur when compiling a nested function definition
# that was created by the desugaring pass.
if inspect.getsourcefile(fn) == "<multicolor>":
return fn, FunctionColor.GENERATOR_FUNCTION
return fn, FunctionColor.GENERATOR
except TypeError:
raise e
else:
raise

logger.debug("compiling function %s", fn.__name__)

# Determine what type of function we're working with.
color = FunctionColor.REGULAR_FUNCTION
color = FunctionColor.REGULAR
if is_generator(fn_def):
color = FunctionColor.GENERATOR_FUNCTION
color = FunctionColor.GENERATOR

if TRACE:
func_or_method = "METHOD" if isinstance(fn, MethodType) else "FUNCTION"
print("\n-------------------------------------------------")
print("[MULTICOLOR] COMPILING:")
print(f"[MULTICOLOR] COMPILING {color.name} {func_or_method}:")
print(textwrap.dedent(inspect.getsource(fn)).rstrip())

fn_def.name = fn_name
Expand All @@ -172,9 +172,9 @@ def _compile_internal(
generator_transformer = GeneratorTransformer()
root = generator_transformer.visit(root)

# Replace explicit function calls with a gadget that resembles yield from.
call_transformer = CallTransformer()
root = call_transformer.visit(root)
# Rewrite function calls.
root = ClassInitTransformer().visit(root)
root = YieldFromTransformer().visit(root)

# If the function never yields it won't be considered a generator.
# Patch the function if necessary to yield from an empty generator, which
Expand Down Expand Up @@ -202,7 +202,7 @@ def _compile_internal(
namespace["_multicolor_generator_type"] = GeneratorType
namespace["_multicolor_decorator"] = decorator
namespace["_multicolor_cache_key"] = cache_key
namespace["_multicolor_generator_color"] = FunctionColor.GENERATOR_FUNCTION
namespace["_multicolor_generator_color"] = FunctionColor.GENERATOR

# Re-compile.
code = compile(root, filename="<multicolor>", mode="exec")
Expand Down Expand Up @@ -231,7 +231,7 @@ def _compile_internal(


class GeneratorTransformer(ast.NodeTransformer):
"""Wrap ast.Yield values in a GeneratorYield container."""
"""Wrap ast.Yield values with GeneratorYield."""

def visit_Yield(self, node: ast.Yield) -> ast.Yield:
value = node.value
Expand All @@ -247,9 +247,7 @@ def visit_Yield(self, node: ast.Yield) -> ast.Yield:


class CallTransformer(ast.NodeTransformer):
"""Replace explicit function calls with a gadget that recursively compiles
functions into generators and then replaces the function call with a
yield from.
"""Base class for transformations that replace ast.Call with an ast.stmt.
The transformations are only valid for ASTs that have passed through the
desugaring pass; only ast.Expr(value=ast.Call(...)) and
Expand All @@ -260,14 +258,71 @@ def visit_Assign(self, node: ast.Assign) -> ast.stmt:
if not isinstance(node.value, ast.Call):
return node
assign_stmt = ast.Assign(targets=node.targets)
return self._build_call_gadget(node.value, assign_stmt)
return self.replace_call(node.value, assign_stmt)

def visit_Expr(self, node: ast.Expr) -> ast.stmt:
if not isinstance(node.value, ast.Call):
return node
return self._build_call_gadget(node.value)
return self.replace_call(node.value)

def replace_call(
self, fn_call: ast.Call, assign: ast.Assign | None = None
) -> ast.stmt:
raise NotImplementedError


class ClassInitTransformer(CallTransformer):
"""Replace class instantiations with a lower-level form.
This allows us to avoid having to create a special case for function calls
that are class instantiations, and can instead rewrite __init__ method
calls directly."""

def replace_call(
self, fn_call: ast.Call, assign: ast.Assign | None = None
) -> ast.stmt:
fn = fn_call.func
init = ast.Expr(
value=ast.Call(
func=ast.Attribute(
value=ast.Name(id="_multicolor_instance", ctx=ast.Load()),
attr="__init__",
ctx=ast.Load(),
),
args=fn_call.args,
keywords=fn_call.keywords,
)
)

def _build_call_gadget(
if assign:
assign.value = ast.Name(id="_multicolor_instance", ctx=ast.Load())
assign_result: ast.stmt = assign
else:
assign_result = ast.Pass()

return rewrite_template(
"""
if isinstance(__fn__, type) and hasattr(__fn__, "__init__"):
_multicolor_instance = __fn__.__new__(__fn__)
__init__
__assign_result__
else:
_multicolor_instance = __fn_call__
__assign_result__
""",
__fn__=fn,
__fn_call__=fn_call,
__init__=init,
__assign_result__=assign_result,
)[0]


class YieldFromTransformer(CallTransformer):
"""Replace explicit function calls with a gadget that recursively compiles
functions into generators and then replaces the function call with a
yield from."""

def replace_call(
self, fn_call: ast.Call, assign: ast.Assign | None = None
) -> ast.stmt:
fn = fn_call.func
Expand All @@ -294,7 +349,7 @@ def _build_call_gadget(
else:
assign_result = ast.Pass()

result = rewrite_template(
return rewrite_template(
"""
if hasattr(__fn__, "_multicolor_yield_type"):
_multicolor_result = yield _multicolor_custom_yield(type=__fn__._multicolor_yield_type, args=__args__, kwargs=__kwargs__)
Expand Down Expand Up @@ -326,6 +381,4 @@ def _build_call_gadget(
__compiled_fn__=compiled_fn,
__compiled_fn_call__=compiled_fn_call,
__assign_result__=assign_result,
)

return result[0]
)[0]
21 changes: 21 additions & 0 deletions tests/dispatch/experimental/multicolor/test_compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,27 @@ def fn():
]
self.assert_yields(fn, yields=yields, returns=3)

def test_class_init(self):
class Foo:
def __init__(self, a, b):
self.result = add(a, b)

def init_foo(a, b):
foo = Foo(a, b)
return foo.result

init_foo.__globals__["Foo"] = Foo

self.assert_yields(
init_foo,
args=[1, 2],
yields=[
CustomYield(type=YieldTypes.ADD, args=[1, 2]),
],
sends=[3],
returns=3,
)

def test_class_method(self):
class Foo:
def sleep_then_fma(self, m, a, b):
Expand Down

0 comments on commit 20b88a3

Please sign in to comment.