Skip to content

Commit

Permalink
feat: Turn int and float into core types (#225)
Browse files Browse the repository at this point in the history
This PR adds a new `NumericType` to the type hierarchy that now
represents `int` and `float`.
  • Loading branch information
mark-koch authored Jun 24, 2024
1 parent b2901d8 commit 99217dc
Show file tree
Hide file tree
Showing 6 changed files with 141 additions and 56 deletions.
13 changes: 13 additions & 0 deletions guppylang/checker/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
from guppylang.tys.builtin import (
bool_type_def,
callable_type_def,
float_type_def,
int_type_def,
linst_type_def,
list_type_def,
none_type_def,
Expand All @@ -24,6 +26,7 @@
ExistentialTypeVar,
FunctionType,
NoneType,
NumericType,
OpaqueType,
StructType,
SumType,
Expand Down Expand Up @@ -67,6 +70,8 @@ def default() -> "Globals":
tuple_type_def,
none_type_def,
bool_type_def,
int_type_def,
float_type_def,
list_type_def,
linst_type_def,
]
Expand All @@ -85,6 +90,14 @@ def get_instance_func(self, ty: Type | TypeDef, name: str) -> CallableDef | None
pass
case BoundTypeVar() | ExistentialTypeVar() | SumType():
return None
case NumericType(kind):
match kind:
case NumericType.Kind.Int:
type_defn = int_type_def
case NumericType.Kind.Float:
type_defn = float_type_def
case kind:
return assert_never(kind)
case FunctionType():
type_defn = callable_type_def
case OpaqueType() as ty:
Expand Down
63 changes: 16 additions & 47 deletions guppylang/prelude/_internal.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from hugr.serialization import ops, tys
from pydantic import BaseModel

from guppylang.ast_util import AstNode, get_type, with_loc, with_type
from guppylang.ast_util import AstNode, get_type, with_loc
from guppylang.checker.core import Context
from guppylang.checker.expr_checker import ExprSynthesizer, check_num_args
from guppylang.definition.custom import (
Expand All @@ -12,36 +12,13 @@
CustomFunctionDef,
DefaultCallChecker,
)
from guppylang.definition.ty import TypeDef
from guppylang.definition.value import CallableDef
from guppylang.error import GuppyError, GuppyTypeError
from guppylang.hugr_builder.hugr import UNDEFINED, OutPortV
from guppylang.nodes import GlobalCall
from guppylang.tys.builtin import bool_type, list_type
from guppylang.tys.subst import Subst
from guppylang.tys.ty import FunctionType, OpaqueType, Type, unify

INT_WIDTH = 6 # 2^6 = 64 bit


hugr_int_type = tys.Type(
tys.Opaque(
extension="arithmetic.int.types",
id="int",
args=[tys.TypeArg(tys.BoundedNatArg(n=INT_WIDTH))],
bound=tys.TypeBound.Eq,
)
)


hugr_float_type = tys.Type(
tys.Opaque(
extension="arithmetic.float.types",
id="float64",
args=[],
bound=tys.TypeBound.Copyable,
)
)
from guppylang.tys.ty import FunctionType, NumericType, Type, unify


class ConstInt(BaseModel):
Expand Down Expand Up @@ -76,9 +53,9 @@ def int_value(i: int) -> ops.Value:
return ops.Value(
ops.ExtensionValue(
extensions=["arithmetic.int.types"],
typ=hugr_int_type,
typ=NumericType(NumericType.Kind.Int).to_hugr(),
value=ops.CustomConst(
c="ConstInt", v=ConstInt(log_width=INT_WIDTH, value=i)
c="ConstInt", v=ConstInt(log_width=NumericType.INT_WIDTH, value=i)
),
)
)
Expand All @@ -89,7 +66,7 @@ def float_value(f: float) -> ops.Value:
return ops.Value(
ops.ExtensionValue(
extensions=["arithmetic.float.types"],
typ=hugr_float_type,
typ=NumericType(NumericType.Kind.Float).to_hugr(),
value=ops.CustomConst(c="ConstF64", v=ConstF64(value=f)),
)
)
Expand All @@ -116,16 +93,16 @@ def logic_op(op_name: str, args: list[tys.TypeArg] | None = None) -> ops.OpType:


def int_op(
op_name: str, ext: str = "arithmetic.int", num_params: int = 1
op_name: str,
ext: str = "arithmetic.int",
args: list[tys.TypeArg] | None = None,
num_params: int = 1,
) -> ops.OpType:
"""Utility method to create Hugr integer arithmetic ops."""
if args is None:
args = num_params * [tys.TypeArg(tys.BoundedNatArg(n=NumericType.INT_WIDTH))]
return ops.OpType(
ops.CustomOp(
extension=ext,
op_name=op_name,
args=num_params * [tys.TypeArg(tys.BoundedNatArg(n=INT_WIDTH))],
parent=UNDEFINED,
)
ops.CustomOp(extension=ext, op_name=op_name, args=args, parent=UNDEFINED)
)


Expand All @@ -140,20 +117,12 @@ class CoercingChecker(DefaultCallChecker):
"""Function call type checker that automatically coerces arguments to float."""

def synthesize(self, args: list[ast.expr]) -> tuple[ast.expr, Type]:
from .builtins import Int

for i in range(len(args)):
args[i], ty = ExprSynthesizer(self.ctx).synthesize(args[i])
if isinstance(ty, OpaqueType) and ty.defn == self.ctx.globals["int"]:
call = with_loc(
self.node,
GlobalCall(def_id=Int.__float__.id, args=[args[i]], type_args=[]),
)
float_defn = self.ctx.globals["float"]
assert isinstance(float_defn, TypeDef)
args[i] = with_type(
float_defn.check_instantiate([], self.ctx.globals), call
)
if isinstance(ty, NumericType) and ty.kind != NumericType.Kind.Float:
to_float = self.ctx.globals.get_instance_func(ty, "__float__")
assert to_float is not None
args[i], _ = to_float.synthesize_call([args[i]], self.node, self.ctx)
return super().synthesize(args)


Expand Down
16 changes: 10 additions & 6 deletions guppylang/prelude/builtins.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,16 @@
ReversingChecker,
UnsupportedChecker,
float_op,
hugr_float_type,
hugr_int_type,
int_op,
logic_op,
)
from guppylang.tys.builtin import bool_type_def, linst_type_def, list_type_def
from guppylang.tys.builtin import (
bool_type_def,
float_type_def,
int_type_def,
linst_type_def,
list_type_def,
)

builtins = GuppyModule("builtins", import_builtins=False)

Expand Down Expand Up @@ -64,7 +68,7 @@ def __new__(x): ...
def __or__(self: bool, other: bool) -> bool: ...


@guppy.type(builtins, hugr_int_type, name="int")
@guppy.extend_type(builtins, int_type_def)
class Int:
@guppy.hugr_op(builtins, int_op("iabs")) # TODO: Maybe wrong? (signed vs unsigned!)
def __abs__(self: int) -> int: ...
Expand Down Expand Up @@ -106,7 +110,7 @@ def __gt__(self: int, other: int) -> bool: ...
def __int__(self: int) -> int: ...

@guppy.hugr_op(builtins, int_op("inot"))
def __invert__(self: int) -> bool: ...
def __invert__(self: int) -> int: ...

@guppy.hugr_op(builtins, int_op("ile_s"))
def __le__(self: int, other: int) -> bool: ...
Expand Down Expand Up @@ -203,7 +207,7 @@ def __trunc__(self: int) -> int: ...
def __xor__(self: int, other: int) -> int: ...


@guppy.type(builtins, hugr_float_type, name="float", bound=tys.TypeBound.Copyable)
@guppy.extend_type(builtins, float_type_def)
class Float:
@guppy.hugr_op(builtins, float_op("fabs"), CoercingChecker())
def __abs__(self: float) -> float: ...
Expand Down
32 changes: 31 additions & 1 deletion guppylang/tys/builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,14 @@
from guppylang.error import GuppyError
from guppylang.tys.arg import Argument, TypeArg
from guppylang.tys.param import TypeParam
from guppylang.tys.ty import FunctionType, NoneType, OpaqueType, TupleType, Type
from guppylang.tys.ty import (
FunctionType,
NoneType,
NumericType,
OpaqueType,
TupleType,
Type,
)

if TYPE_CHECKING:
from guppylang.checker.core import Globals
Expand Down Expand Up @@ -79,6 +86,23 @@ def check_instantiate(
return NoneType()


@dataclass(frozen=True)
class _NumericTypeDef(TypeDef):
"""Type definition associated with the builtin numeric types.
Any impls on numerics can be registered with these definitions.
"""

ty: NumericType

def check_instantiate(
self, args: Sequence[Argument], globals: "Globals", loc: AstNode | None = None
) -> NumericType:
if args:
raise GuppyError(f"Type `{self.name}` is not parameterized", loc)
return self.ty


@dataclass(frozen=True)
class _ListTypeDef(OpaqueTypeDef):
"""Type definition associated with the builtin `list` type.
Expand Down Expand Up @@ -123,6 +147,12 @@ def _list_to_hugr(args: Sequence[Argument]) -> tys.Type:
always_linear=False,
to_hugr=lambda _: tys.Type(tys.SumType(tys.UnitSum(size=2))),
)
int_type_def = _NumericTypeDef(
DefId.fresh(), "int", None, NumericType(NumericType.Kind.Int)
)
float_type_def = _NumericTypeDef(
DefId.fresh(), "float", None, NumericType(NumericType.Kind.Float)
)
linst_type_def = OpaqueTypeDef(
id=DefId.fresh(),
name="linst",
Expand Down
5 changes: 5 additions & 0 deletions guppylang/tys/printing.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from guppylang.tys.ty import (
FunctionType,
NoneType,
NumericType,
OpaqueType,
StructType,
SumType,
Expand Down Expand Up @@ -106,6 +107,10 @@ def _visit_SumType(self, ty: SumType, inside_row: bool) -> str:
def _visit_NoneType(self, ty: NoneType, inside_row: bool) -> str:
return "None"

@_visit.register
def _visit_NumericType(self, ty: NumericType, inside_row: bool) -> str:
return ty.kind.value

@_visit.register
def _visit_TypeParam(self, param: TypeParam, inside_row: bool) -> str:
# TODO: Print linearity?
Expand Down
68 changes: 66 additions & 2 deletions guppylang/tys/ty.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from abc import ABC, abstractmethod
from collections.abc import Sequence
from dataclasses import dataclass, field
from enum import Enum
from functools import cached_property
from typing import TYPE_CHECKING, TypeAlias, cast
from typing import TYPE_CHECKING, ClassVar, TypeAlias, cast

from hugr.serialization import tys
from hugr.serialization.tys import TypeBound
Expand Down Expand Up @@ -234,6 +235,65 @@ def transform(self, transformer: Transformer) -> "Type":
return transformer.transform(self) or self


@dataclass(frozen=True)
class NumericType(TypeBase):
"""Numeric types like `int` and `float`."""

kind: "Kind"

class Kind(Enum):
"""The different kinds of numeric types."""

Int = "int"
Float = "float"

INT_WIDTH: ClassVar[int] = 6

@property
def linear(self) -> bool:
"""Whether this type should be treated linearly."""
return False

def to_hugr(self) -> tys.Type:
"""Computes the Hugr representation of the type."""
match self.kind:
case NumericType.Kind.Int:
return tys.Type(
tys.Opaque(
extension="arithmetic.int.types",
id="int",
args=[tys.TypeArg(tys.BoundedNatArg(n=NumericType.INT_WIDTH))],
bound=tys.TypeBound.Eq,
)
)
case NumericType.Kind.Float:
return tys.Type(
tys.Opaque(
extension="arithmetic.float.types",
id="float64",
args=[],
bound=tys.TypeBound.Copyable,
)
)

@property
def hugr_bound(self) -> tys.TypeBound:
"""The Hugr bound of this type, i.e. `Any`, `Copyable`, or `Equatable`."""
match self.kind:
case NumericType.Kind.Float:
return tys.TypeBound.Copyable
case _:
return tys.TypeBound.Eq

def visit(self, visitor: Visitor) -> None:
"""Accepts a visitor on this type."""
visitor.visit(self)

def transform(self, transformer: Transformer) -> "Type":
"""Accepts a transformer on this type."""
return transformer.transform(self) or self


@dataclass(frozen=True, init=False)
class FunctionType(ParametrizedTypeBase):
"""Type of (potentially generic) functions."""
Expand Down Expand Up @@ -493,7 +553,9 @@ def transform(self, transformer: Transformer) -> "Type":
#: This might become obsolete in case the @sealed decorator is added:
#: * https://peps.python.org/pep-0622/#sealed-classes-as-algebraic-data-types
#: * https://github.com/johnthagen/sealed-typing-pep
Type: TypeAlias = BoundTypeVar | ExistentialTypeVar | NoneType | ParametrizedType
Type: TypeAlias = (
BoundTypeVar | ExistentialTypeVar | NumericType | NoneType | ParametrizedType
)

#: An immutable row of Guppy types.
TypeRow: TypeAlias = Sequence[Type]
Expand Down Expand Up @@ -545,6 +607,8 @@ def unify(s: Type, t: Type, subst: "Subst | None") -> "Subst | None":
return _unify_var(t, s, subst)
case BoundTypeVar(idx=s_idx), BoundTypeVar(idx=t_idx) if s_idx == t_idx:
return subst
case NumericType(kind=s_kind), NumericType(kind=t_kind) if s_kind == t_kind:
return subst
case NoneType(), NoneType():
return subst
case FunctionType() as s, FunctionType() as t if s.params == t.params:
Expand Down

0 comments on commit 99217dc

Please sign in to comment.