Skip to content

Commit

Permalink
More refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
tatiana-s committed Feb 7, 2025
1 parent b308f91 commit 024a09e
Show file tree
Hide file tree
Showing 4 changed files with 43 additions and 46 deletions.
10 changes: 6 additions & 4 deletions guppylang/compiler/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,17 @@
@dataclass(frozen=True)
class GlobalConstId:
id: int
base_name: str

_fresh_ids = itertools.count()

@staticmethod
def fresh() -> "GlobalConstId":
return GlobalConstId(next(GlobalConstId._fresh_ids))
def fresh(base_name: str) -> "GlobalConstId":
return GlobalConstId(next(GlobalConstId._fresh_ids), base_name)

def name(self, base: str) -> str:
return f"{base}.{self.id}"
@property
def name(self) -> str:
return f"{self.base_name}.{self.id}"


class CompilerContext:
Expand Down
4 changes: 1 addition & 3 deletions guppylang/compiler/expr_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -367,9 +367,7 @@ def visit_GlobalCall(self, node: GlobalCall) -> Wire:
assert isinstance(func, CompiledCallableDef)

args = [self.visit(arg) for arg in node.args]
rets = func.compile_call(
args, list(node.type_args), self.dfg, self.ctx, node
)
rets = func.compile_call(args, list(node.type_args), self.dfg, self.ctx, node)
if isinstance(func, CustomFunctionDef) and not func.has_signature:
func_ty = FunctionType(
[FuncInput(get_type(arg), InputFlags.NoFlags) for arg in node.args],
Expand Down
4 changes: 1 addition & 3 deletions guppylang/definition/value.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,7 @@ class CompiledValueDef(ValueDef, CompiledDef):
"""Abstract base class for compiled definitions that represent values."""

@abstractmethod
def load(
self, dfg: "DFContainer", ctx: "CompilerContext", node: AstNode
) -> Wire:
def load(self, dfg: "DFContainer", ctx: "CompilerContext", node: AstNode) -> Wire:
"""Loads the defined value into a local Hugr dataflow graph."""


Expand Down
71 changes: 35 additions & 36 deletions guppylang/std/_internal/compiler/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,8 @@

from __future__ import annotations

from typing import Final
from typing import TYPE_CHECKING, Final

import hugr.build.function as hf
from hugr import Node, Wire, ops
from hugr import tys as ht
from hugr.std.collections.array import EXTENSION
Expand All @@ -15,7 +14,7 @@
from guppylang.definition.custom import CustomCallCompiler
from guppylang.definition.value import CallReturnWires
from guppylang.error import InternalGuppyError
from guppylang.std._internal.compiler.arithmetic import INT_T, convert_itousize
from guppylang.std._internal.compiler.arithmetic import convert_itousize
from guppylang.std._internal.compiler.prelude import (
build_expect_none,
build_unwrap,
Expand All @@ -25,6 +24,9 @@
from guppylang.tys.arg import ConstArg, TypeArg
from guppylang.tys.builtin import int_type

if TYPE_CHECKING:
from hugr.build import function as hf

# ------------------------------------------------------
# --------------- std.array operations -----------------
# ------------------------------------------------------
Expand Down Expand Up @@ -187,23 +189,31 @@ def compile(self, args: list[Wire]) -> list[Wire]:
return [self.build_classical_array(args)]


ARRAY_GETITEM_CLASSICAL: Final[GlobalConstId] = GlobalConstId.fresh()
ARRAY_GETITEM_LINEAR: Final[GlobalConstId] = GlobalConstId.fresh()
ARRAY_SETITEM_CLASSICAL: Final[GlobalConstId] = GlobalConstId.fresh()
ARRAY_SETITEM_LINEAR: Final[GlobalConstId] = GlobalConstId.fresh()
ARRAY_GETITEM_CLASSICAL: Final[GlobalConstId] = GlobalConstId.fresh(
"array.__getitem__.classical"
)
ARRAY_GETITEM_LINEAR: Final[GlobalConstId] = GlobalConstId.fresh(
"array.__getitem__.linear"
)
ARRAY_SETITEM_CLASSICAL: Final[GlobalConstId] = GlobalConstId.fresh(
"array.__setitem__.classical"
)
ARRAY_SETITEM_LINEAR: Final[GlobalConstId] = GlobalConstId.fresh(
"array.__setitem__.linear"
)


class ArrayGetitemCompiler(ArrayCompiler):
"""Compiler for the `array.__getitem__` function."""

def _getitem_ty(self, bound: ht.TypeBound) -> ht.PolyFuncType:
def _getitem_func(self, bound: ht.TypeBound, name: str) -> hf.Function:
"""Constructs a polymorphic function type for `__getitem__`"""
# a(Option(T), N), int -> T, a(Option(T), N)
# Array element type parameter
elem_ty_param = ht.TypeTypeParam(bound)
# Array length parameter
length_param = ht.BoundedNatParam()
return ht.PolyFuncType(
func_ty = ht.PolyFuncType(
params=[elem_ty_param, length_param],
body=ht.FunctionType(
input=[
Expand All @@ -222,17 +232,17 @@ def _getitem_ty(self, bound: ht.TypeBound) -> ht.PolyFuncType:
],
),
)

def _build_classical_getitem(self, name: str) -> Wire:
"""Constructs a generic function for `__getitem__` for classical arrays."""
func_ty = self._getitem_ty(ht.TypeBound.Copyable)
func = self.ctx.module.define_function(
return self.ctx.module.define_function(
name=name,
input_types=func_ty.body.input,
output_types=func_ty.body.output,
type_params=func_ty.params,
)

def _build_classical_getitem(self, name: str) -> Wire:
"""Constructs a generic function for `__getitem__` for classical arrays."""
func = self._getitem_func(ht.TypeBound.Copyable, name)

elem_ty = ht.Variable(0, ht.TypeBound.Copyable)
length = ht.VariableArg(1, ht.BoundedNatParam())

Expand All @@ -253,13 +263,7 @@ def _build_classical_getitem(self, name: str) -> Wire:

def _build_linear_getitem(self, name: str) -> Wire:
"""Constructs function to call `array.__getitem__` for linear arrays."""
func_ty = self._getitem_ty(ht.TypeBound.Any)
func = self.ctx.module.define_function(
name=name,
input_types=func_ty.body.input,
output_types=func_ty.body.output,
type_params=func_ty.params,
)
func = self._getitem_func(ht.TypeBound.Any, name)

elem_ty = ht.Variable(0, ht.TypeBound.Any)
length = ht.VariableArg(1, ht.BoundedNatParam())
Expand Down Expand Up @@ -289,7 +293,10 @@ def _build_call_getitem(
) -> CallReturnWires:
"""Inserts a call to `array.__getitem__`."""
concrete_func_ty = ht.FunctionType(
input=[array_type(ht.Option(self.elem_ty), self.length), int_type().to_hugr()],
input=[
array_type(ht.Option(self.elem_ty), self.length),
int_type().to_hugr(),
],
output=[self.elem_ty, array_type(ht.Option(self.elem_ty), self.length)],
)
type_args = [ht.TypeTypeArg(self.elem_ty), self.length]
Expand All @@ -311,9 +318,7 @@ def compile_classical_getitem(self, array: Wire, idx: Wire) -> CallReturnWires:
"""Lowers a call to `array.__getitem__` for classical arrays."""
if ARRAY_GETITEM_CLASSICAL not in self.ctx.global_consts:
self.ctx.global_consts[ARRAY_GETITEM_CLASSICAL] = (
self._build_classical_getitem(
name=ARRAY_GETITEM_CLASSICAL.name("array.__getitem__.classical")
)
self._build_classical_getitem(name=ARRAY_GETITEM_CLASSICAL.name)
)
return self._build_call_getitem(
func=self.ctx.global_consts[ARRAY_GETITEM_CLASSICAL],
Expand All @@ -324,10 +329,8 @@ def compile_classical_getitem(self, array: Wire, idx: Wire) -> CallReturnWires:
def compile_linear_getitem(self, array: Wire, idx: Wire) -> CallReturnWires:
"""Lowers a call to `array.__getitem__` for classical arrays."""
if ARRAY_GETITEM_LINEAR not in self.ctx.global_consts:
self.ctx.global_consts[ARRAY_GETITEM_LINEAR] = (
self._build_linear_getitem(
name=ARRAY_GETITEM_LINEAR.name("array.__getitem__.linear")
)
self.ctx.global_consts[ARRAY_GETITEM_LINEAR] = self._build_linear_getitem(
name=ARRAY_GETITEM_LINEAR.name
)
return self._build_call_getitem(
func=self.ctx.global_consts[ARRAY_GETITEM_LINEAR],
Expand Down Expand Up @@ -468,9 +471,7 @@ def compile_classical_setitem(
"""Lowers a call to `array.__setitem__` for classical arrays."""
if ARRAY_SETITEM_CLASSICAL not in self.ctx.global_consts:
self.ctx.global_consts[ARRAY_SETITEM_CLASSICAL] = (
self._build_classical_setitem(
name=ARRAY_SETITEM_CLASSICAL.name("array.__setitem__.classical")
)
self._build_classical_setitem(name=ARRAY_SETITEM_CLASSICAL.name)
)
return self._build_call_setitem(
func=self.ctx.global_consts[ARRAY_SETITEM_CLASSICAL],
Expand All @@ -484,10 +485,8 @@ def compile_linear_setitem(
) -> CallReturnWires:
"""Lowers a call to `array.__setitem__` for linear arrays."""
if ARRAY_SETITEM_LINEAR not in self.ctx.global_consts:
self.ctx.global_consts[ARRAY_SETITEM_LINEAR] = (
self._build_linear_setitem(
name=ARRAY_SETITEM_LINEAR.name("array.__setitem__.linear")
)
self.ctx.global_consts[ARRAY_SETITEM_LINEAR] = self._build_linear_setitem(
name=ARRAY_SETITEM_LINEAR.name
)
return self._build_call_setitem(
func=self.ctx.global_consts[ARRAY_SETITEM_LINEAR],
Expand Down

0 comments on commit 024a09e

Please sign in to comment.