Skip to content

Commit

Permalink
feat: Add a nat type and make int/float core types
Browse files Browse the repository at this point in the history
  • Loading branch information
mark-koch committed May 28, 2024
1 parent b25b51c commit 8c04951
Show file tree
Hide file tree
Showing 7 changed files with 162 additions and 52 deletions.
19 changes: 19 additions & 0 deletions guppylang/checker/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,11 @@
from guppylang.tys.builtin import (
bool_type_def,
callable_type_def,
float_type_def,
int_type_def,
linst_type_def,
list_type_def,
nat_type_def,
none_type_def,
tuple_type_def,
)
Expand All @@ -24,6 +27,7 @@
ExistentialTypeVar,
FunctionType,
NoneType,
NumericType,
OpaqueType,
StructType,
SumType,
Expand Down Expand Up @@ -67,6 +71,9 @@ def default() -> "Globals":
tuple_type_def,
none_type_def,
bool_type_def,
nat_type_def,
int_type_def,
float_type_def,
list_type_def,
linst_type_def,
]
Expand All @@ -85,6 +92,18 @@ def get_instance_func(self, ty: Type | TypeDef, name: str) -> CallableDef | None
pass
case BoundTypeVar() | ExistentialTypeVar() | SumType():
return None
case NumericType(kind):
match kind:
case NumericType.Kind.Bool:
type_defn = bool_type_def
case NumericType.Kind.Nat:
type_defn = nat_type_def
case NumericType.Kind.Int:
type_defn = int_type_def
case NumericType.Kind.Float:
type_defn = float_type_def
case kind:
return assert_never(kind)
case FunctionType():
type_defn = callable_type_def
case OpaqueType() as ty:
Expand Down
41 changes: 7 additions & 34 deletions guppylang/prelude/_internal.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,36 +12,13 @@
CustomFunctionDef,
DefaultCallChecker,
)
from guppylang.definition.ty import TypeDef
from guppylang.definition.value import CallableDef
from guppylang.error import GuppyError, GuppyTypeError
from guppylang.hugr_builder.hugr import UNDEFINED, OutPortV
from guppylang.nodes import GlobalCall
from guppylang.tys.builtin import bool_type, list_type
from guppylang.tys.subst import Subst
from guppylang.tys.ty import FunctionType, OpaqueType, Type, unify

INT_WIDTH = 6 # 2^6 = 64 bit


hugr_int_type = tys.Type(
tys.Opaque(
extension="arithmetic.int.types",
id="int",
args=[tys.TypeArg(tys.BoundedNatArg(n=INT_WIDTH))],
bound=tys.TypeBound.Eq,
)
)


hugr_float_type = tys.Type(
tys.Opaque(
extension="arithmetic.float.types",
id="float64",
args=[],
bound=tys.TypeBound.Copyable,
)
)
from guppylang.tys.ty import FunctionType, NumericType, Type, unify


class ConstInt(BaseModel):
Expand Down Expand Up @@ -77,9 +54,9 @@ def int_value(i: int) -> ops.Value:
return ops.Value(
ops.ExtensionValue(
extensions=["arithmetic.int.types"],
typ=hugr_int_type,
typ=NumericType(NumericType.Kind.Nat).to_hugr(),
value=ops.CustomConst(
c="ConstInt", v=ConstInt(log_width=INT_WIDTH, value=i)
c="ConstInt", v=ConstInt(log_width=NumericType.INT_WIDTH, value=i)
),
)
)
Expand All @@ -90,7 +67,7 @@ def float_value(f: float) -> ops.Value:
return ops.Value(
ops.ExtensionValue(
extensions=["arithmetic.float.types"],
typ=hugr_float_type,
typ=NumericType(NumericType.Kind.Float).to_hugr(),
value=ops.CustomConst(c="ConstF64", v=ConstF64(value=f)),
)
)
Expand Down Expand Up @@ -124,7 +101,7 @@ def int_op(
ops.CustomOp(
extension=ext,
op_name=op_name,
args=num_params * [tys.TypeArg(tys.BoundedNatArg(n=INT_WIDTH))],
args=num_params * [tys.TypeArg(tys.BoundedNatArg(n=NumericType.INT_WIDTH))],
parent=UNDEFINED,
)
)
Expand All @@ -145,16 +122,12 @@ def synthesize(self, args: list[ast.expr]) -> tuple[ast.expr, Type]:

for i in range(len(args)):
args[i], ty = ExprSynthesizer(self.ctx).synthesize(args[i])
if isinstance(ty, OpaqueType) and ty.defn == self.ctx.globals["int"]:
if isinstance(ty, NumericType) and ty.kind == NumericType.Kind.Int:
call = with_loc(
self.node,
GlobalCall(def_id=Int.__float__.id, args=[args[i]], type_args=[]),
)
float_defn = self.ctx.globals["float"]
assert isinstance(float_defn, TypeDef)
args[i] = with_type(
float_defn.check_instantiate([], self.ctx.globals), call
)
args[i] = with_type(NumericType(NumericType.Kind.Float), call)
return super().synthesize(args)


Expand Down
18 changes: 13 additions & 5 deletions guppylang/prelude/builtins.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,19 +21,27 @@
ReversingChecker,
UnsupportedChecker,
float_op,
hugr_float_type,
hugr_int_type,
int_op,
logic_op,
)
from guppylang.tys.builtin import bool_type_def, linst_type_def, list_type_def
from guppylang.tys.builtin import (
bool_type_def,
float_type_def,
int_type_def,
linst_type_def,
list_type_def, nat_type_def,
)

builtins = GuppyModule("builtins", import_builtins=False)

T = guppy.type_var(builtins, "T")
L = guppy.type_var(builtins, "L", linear=True)


# Define the nat type so scripts can import it
nat = nat_type_def


@guppy.extend_type(builtins, bool_type_def)
class Bool:
@guppy.hugr_op(builtins, logic_op("And", [tys.TypeArg(tys.BoundedNatArg(n=2))]))
Expand All @@ -52,7 +60,7 @@ def __new__(x): ...
def __or__(self: bool, other: bool) -> bool: ...


@guppy.type(builtins, hugr_int_type, name="int")
@guppy.extend_type(builtins, int_type_def)
class Int:
@guppy.hugr_op(builtins, int_op("iabs")) # TODO: Maybe wrong? (signed vs unsigned!)
def __abs__(self: int) -> int: ...
Expand Down Expand Up @@ -191,7 +199,7 @@ def __trunc__(self: int) -> int: ...
def __xor__(self: int, other: int) -> int: ...


@guppy.type(builtins, hugr_float_type, name="float", bound=tys.TypeBound.Copyable)
@guppy.extend_type(builtins, float_type_def)
class Float:
@guppy.hugr_op(builtins, float_op("fabs"), CoercingChecker())
def __abs__(self: float) -> float: ...
Expand Down
50 changes: 39 additions & 11 deletions guppylang/tys/builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,14 @@
from guppylang.error import GuppyError
from guppylang.tys.arg import Argument, TypeArg
from guppylang.tys.param import TypeParam
from guppylang.tys.ty import FunctionType, NoneType, OpaqueType, TupleType, Type
from guppylang.tys.ty import (
FunctionType,
NoneType,
NumericType,
OpaqueType,
TupleType,
Type,
)

if TYPE_CHECKING:
from guppylang.checker.core import Globals
Expand Down Expand Up @@ -79,6 +86,23 @@ def check_instantiate(
return NoneType()


@dataclass(frozen=True)
class _NumericTypeDef(TypeDef):
"""Type definition associated with the builtin `None` type.
Any impls on None can be registered with this definition.
"""

ty: NumericType

def check_instantiate(
self, args: Sequence[Argument], globals: "Globals", loc: AstNode | None = None
) -> NumericType:
if args:
raise GuppyError(f"Type `{self.name}` is not parameterized", loc)
return self.ty


@dataclass(frozen=True)
class _ListTypeDef(OpaqueTypeDef):
"""Type definition associated with the builtin `list` type.
Expand Down Expand Up @@ -115,13 +139,17 @@ def _list_to_hugr(args: Sequence[Argument]) -> tys.Type:
callable_type_def = _CallableTypeDef(DefId.fresh(), None)
tuple_type_def = _TupleTypeDef(DefId.fresh(), None)
none_type_def = _NoneTypeDef(DefId.fresh(), None)
bool_type_def = OpaqueTypeDef(
id=DefId.fresh(),
name="bool",
defined_at=None,
params=[],
always_linear=False,
to_hugr=lambda _: tys.Type(tys.SumType(tys.UnitSum(size=2))),
bool_type_def = _NumericTypeDef(
DefId.fresh(), "bool", None, NumericType(NumericType.Kind.Bool)
)
nat_type_def = _NumericTypeDef(
DefId.fresh(), "nat", None, NumericType(NumericType.Kind.Nat)
)
int_type_def = _NumericTypeDef(
DefId.fresh(), "int", None, NumericType(NumericType.Kind.Int)
)
float_type_def = _NumericTypeDef(
DefId.fresh(), "float", None, NumericType(NumericType.Kind.Float)
)
linst_type_def = OpaqueTypeDef(
id=DefId.fresh(),
Expand All @@ -141,8 +169,8 @@ def _list_to_hugr(args: Sequence[Argument]) -> tys.Type:
)


def bool_type() -> OpaqueType:
return OpaqueType([], bool_type_def)
def bool_type() -> NumericType:
return NumericType(NumericType.Kind.Bool)


def list_type(element_ty: Type) -> OpaqueType:
Expand All @@ -154,7 +182,7 @@ def linst_type(element_ty: Type) -> OpaqueType:


def is_bool_type(ty: Type) -> bool:
return isinstance(ty, OpaqueType) and ty.defn == bool_type_def
return isinstance(ty, NumericType) and ty.kind == NumericType.Kind.Bool


def is_list_type(ty: Type) -> bool:
Expand Down
5 changes: 5 additions & 0 deletions guppylang/tys/printing.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from guppylang.tys.ty import (
FunctionType,
NoneType,
NumericType,
OpaqueType,
StructType,
SumType,
Expand Down Expand Up @@ -106,6 +107,10 @@ def _visit_SumType(self, ty: SumType, inside_row: bool) -> str:
def _visit_NoneType(self, ty: NoneType, inside_row: bool) -> str:
return "None"

@_visit.register
def _visit_NumericType(self, ty: NumericType, inside_row: bool) -> str:
return ty.kind.value

@_visit.register
def _visit_TypeParam(self, param: TypeParam, inside_row: bool) -> str:
# TODO: Print linearity?
Expand Down
72 changes: 70 additions & 2 deletions guppylang/tys/ty.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from abc import ABC, abstractmethod
from collections.abc import Sequence
from dataclasses import dataclass, field
from enum import Enum
from functools import cached_property
from typing import TYPE_CHECKING, TypeAlias, cast
from typing import TYPE_CHECKING, ClassVar, TypeAlias, cast

from hugr.serialization import tys
from hugr.serialization.tys import TypeBound
Expand Down Expand Up @@ -234,6 +235,69 @@ def transform(self, transformer: Transformer) -> "Type":
return transformer.transform(self) or self


@dataclass(frozen=True)
class NumericType(TypeBase):
"""Numeric types like `int` and `float`."""

kind: "Kind"

class Kind(Enum):
"""The different kinds of numeric types."""

Bool = "bool"
Nat = "nat"
Int = "int"
Float = "float"

INT_WIDTH: ClassVar[int] = 6

@property
def linear(self) -> bool:
"""Whether this type should be treated linearly."""
return False

def to_hugr(self) -> tys.Type:
"""Computes the Hugr representation of the type."""
match self.kind:
case NumericType.Kind.Bool:
return SumType([NoneType(), NoneType()]).to_hugr()
case NumericType.Kind.Nat | NumericType.Kind.Int:
return tys.Type(
tys.Opaque(
extension="arithmetic.int.types",
id="int",
args=[tys.TypeArg(tys.BoundedNatArg(n=NumericType.INT_WIDTH))],
bound=tys.TypeBound.Eq,
)
)
case NumericType.Kind.Float:
return tys.Type(
tys.Opaque(
extension="arithmetic.float.types",
id="float64",
args=[],
bound=tys.TypeBound.Copyable,
)
)

@property
def hugr_bound(self) -> tys.TypeBound:
"""The Hugr bound of this type, i.e. `Any`, `Copyable`, or `Equatable`."""
match self.kind:
case NumericType.Kind.Float:
return tys.TypeBound.Copyable
case _:
return tys.TypeBound.Eq

def visit(self, visitor: Visitor) -> None:
"""Accepts a visitor on this type."""
visitor.visit(self)

def transform(self, transformer: Transformer) -> "Type":
"""Accepts a transformer on this type."""
return transformer.transform(self) or self


@dataclass(frozen=True, init=False)
class FunctionType(ParametrizedTypeBase):
"""Type of (potentially generic) functions."""
Expand Down Expand Up @@ -493,7 +557,9 @@ def transform(self, transformer: Transformer) -> "Type":
#: This might become obsolete in case the @sealed decorator is added:
#: * https://peps.python.org/pep-0622/#sealed-classes-as-algebraic-data-types
#: * https://github.com/johnthagen/sealed-typing-pep
Type: TypeAlias = BoundTypeVar | ExistentialTypeVar | NoneType | ParametrizedType
Type: TypeAlias = (
BoundTypeVar | ExistentialTypeVar | NumericType | NoneType | ParametrizedType
)

#: An immutable row of Guppy types.
TypeRow: TypeAlias = Sequence[Type]
Expand Down Expand Up @@ -545,6 +611,8 @@ def unify(s: Type, t: Type, subst: "Subst | None") -> "Subst | None":
return _unify_var(t, s, subst)
case BoundTypeVar(idx=s_idx), BoundTypeVar(idx=t_idx) if s_idx == t_idx:
return subst
case NumericType(kind=s_kind), NumericType(kind=t_kind) if s_kind == t_kind:
return subst
case NoneType(), NoneType():
return subst
case FunctionType() as s, FunctionType() as t if s.params == t.params:
Expand Down
Loading

0 comments on commit 8c04951

Please sign in to comment.