Skip to content

Commit

Permalink
Undo turning bool into a numeric type
Browse files Browse the repository at this point in the history
  • Loading branch information
mark-koch committed Jun 24, 2024
1 parent 13154dd commit cf8a529
Show file tree
Hide file tree
Showing 10 changed files with 23 additions and 194 deletions.
2 changes: 0 additions & 2 deletions guppylang/checker/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,8 +92,6 @@ def get_instance_func(self, ty: Type | TypeDef, name: str) -> CallableDef | None
return None
case NumericType(kind):
match kind:
case NumericType.Kind.Bool:
type_defn = bool_type_def
case NumericType.Kind.Int:
type_defn = int_type_def
case NumericType.Kind.Float:
Expand Down
41 changes: 2 additions & 39 deletions guppylang/prelude/_internal.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,7 @@

from guppylang.ast_util import AstNode, get_type, with_loc
from guppylang.checker.core import Context
from guppylang.checker.expr_checker import (
ExprSynthesizer,
check_call,
check_num_args,
synthesize_call,
)
from guppylang.checker.expr_checker import ExprSynthesizer, check_num_args
from guppylang.definition.custom import (
CustomCallChecker,
CustomCallCompiler,
Expand All @@ -21,7 +16,7 @@
from guppylang.error import GuppyError, GuppyTypeError
from guppylang.hugr_builder.hugr import UNDEFINED, OutPortV
from guppylang.nodes import GlobalCall
from guppylang.tys.builtin import bool_type, int_type, list_type
from guppylang.tys.builtin import bool_type, list_type
from guppylang.tys.subst import Subst
from guppylang.tys.ty import FunctionType, NumericType, Type, unify

Expand Down Expand Up @@ -246,38 +241,6 @@ def check(self, args: list[ast.expr], ty: Type) -> tuple[ast.expr, Subst]:
return args, subst


class BoolArithChecker(DefaultCallChecker):
"""Function call checker for arithmetic operations on bools.
Converts all bools into ints and calls the corresponding int arithmetic method with
the same name.
"""

def _prepare_args(self, args: list[ast.expr]) -> list[ast.expr]:
# Cast all inputs to int
to_int = self.ctx.globals.get_instance_func(bool_type(), "__int__")
assert to_int is not None
return [to_int.synthesize_call([arg], arg, self.ctx)[0] for arg in args]

def _get_func(self) -> CallableDef:
# Get the int function with the same name
func = self.ctx.globals.get_instance_func(int_type(), self.func.name)
assert func is not None
return func

def synthesize(self, args: list[ast.expr]) -> tuple[ast.expr, Type]:
args, _, inst = synthesize_call(self.func.ty, args, self.node, self.ctx)
assert not inst # `self.func.ty` is not generic
args = self._prepare_args(args)
return self._get_func().synthesize_call(args, self.node, self.ctx)

def check(self, args: list[ast.expr], ty: Type) -> tuple[ast.expr, Subst]:
args, _, inst = check_call(self.func.ty, args, ty, self.node, self.ctx)
assert not inst # `self.func.ty` is not generic
args = self._prepare_args(args)
return self._get_func().check_call(args, ty, self.node, self.ctx)


class IntTruedivCompiler(CustomCallCompiler):
"""Compiler for the `int.__truediv__` method."""

Expand Down
128 changes: 1 addition & 127 deletions guppylang/prelude/builtins.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
from guppylang.hugr_builder.hugr import DummyOp
from guppylang.module import GuppyModule
from guppylang.prelude._internal import (
BoolArithChecker,
CallableChecker,
CoercingChecker,
DunderChecker,
Expand Down Expand Up @@ -53,146 +52,21 @@ def py(*_args: Any) -> Any:

@guppy.extend_type(builtins, bool_type_def)
class Bool:
@guppy.custom(builtins, NoopCompiler())
def __abs__(self: int) -> int: ...

@guppy.custom(builtins, checker=BoolArithChecker())
def __add__(self: bool, other: bool) -> int: ...

@guppy.hugr_op(builtins, logic_op("And", [tys.TypeArg(tys.BoundedNatArg(n=2))]))
def __and__(self: bool, other: bool) -> bool: ...

@guppy.custom(builtins, NoopCompiler())
def __bool__(self: bool) -> bool: ...

@guppy.custom(builtins, checker=BoolArithChecker())
def __ceil__(self: bool) -> int: ...

@guppy.custom(builtins, checker=BoolArithChecker())
def __divmod__(self: bool, other: bool) -> tuple[int, int]: ...

@guppy.custom(builtins, checker=BoolArithChecker())
def __eq__(self: bool, other: bool) -> bool: ...

@guppy.custom(builtins, checker=BoolArithChecker())
def __float__(self: bool) -> float: ...

@guppy.custom(builtins, checker=BoolArithChecker())
def __floor__(self: bool) -> int: ...

@guppy.custom(builtins, checker=BoolArithChecker())
def __floordiv__(self: bool, other: bool) -> int: ...

@guppy.custom(builtins, checker=BoolArithChecker())
def __ge__(self: bool, other: bool) -> bool: ...

@guppy.custom(builtins, checker=BoolArithChecker())
def __gt__(self: bool, other: bool) -> bool: ...

@guppy.hugr_op(builtins, DummyOp("ifrombool")) # TODO: Widen to INT_WIDTH
@guppy.hugr_op(builtins, int_op("ifrombool"))
def __int__(self: bool) -> int: ...

@guppy.custom(builtins, checker=BoolArithChecker())
def __invert__(self: bool) -> int: ...

@guppy.custom(builtins, checker=BoolArithChecker())
def __le__(self: bool, other: bool) -> bool: ...

@guppy.custom(builtins, checker=BoolArithChecker())
def __lshift__(self: bool, other: bool) -> int: ...

@guppy.custom(builtins, checker=BoolArithChecker())
def __lt__(self: bool, other: bool) -> bool: ...

@guppy.custom(builtins, checker=BoolArithChecker())
def __mod__(self: bool, other: bool) -> bool: ...

@guppy.custom(builtins, checker=BoolArithChecker())
def __mul__(self: bool, other: bool) -> bool: ...

@guppy.custom(builtins, checker=BoolArithChecker())
def __ne__(self: bool, other: bool) -> bool: ...

@guppy.custom(builtins, checker=BoolArithChecker())
def __neg__(self: bool) -> int: ...

@guppy.custom(builtins, checker=DunderChecker("__bool__"), higher_order_value=False)
def __new__(x): ...

@guppy.hugr_op(builtins, logic_op("Or", [tys.TypeArg(tys.BoundedNatArg(n=2))]))
def __or__(self: bool, other: bool) -> bool: ...

@guppy.custom(builtins, checker=BoolArithChecker())
def __pos__(self: bool) -> int: ...

@guppy.custom(builtins, checker=BoolArithChecker())
def __pow__(self: bool, other: bool) -> int: ...

@guppy.custom(builtins, checker=BoolArithChecker())
def __radd__(self: bool, other: bool) -> int: ...

@guppy.hugr_op(
builtins,
logic_op("And", [tys.TypeArg(tys.BoundedNatArg(n=2))]),
ReversingChecker(),
)
def __rand__(self: bool, other: bool) -> int: ...

@guppy.custom(builtins, checker=BoolArithChecker())
def __rdivmod__(self: bool, other: bool) -> tuple[int, int]: ...

@guppy.custom(builtins, checker=BoolArithChecker())
def __rfloordiv__(self: bool, other: bool) -> int: ...

@guppy.custom(builtins, checker=BoolArithChecker())
def __rlshift__(self: bool, other: bool) -> int: ...

@guppy.custom(builtins, checker=BoolArithChecker())
def __rmod__(self: bool, other: bool) -> int: ...

@guppy.custom(builtins, checker=BoolArithChecker())
def __rmul__(self: bool, other: bool) -> int: ...

@guppy.hugr_op(
builtins,
logic_op("Or", [tys.TypeArg(tys.BoundedNatArg(n=2))]),
ReversingChecker(),
)
def __ror__(self: bool, other: bool) -> int: ...

@guppy.custom(builtins, checker=BoolArithChecker())
def __round__(self: bool) -> int: ...

@guppy.custom(builtins, checker=BoolArithChecker())
def __rpow__(self: bool, other: bool) -> int: ...

@guppy.custom(builtins, checker=BoolArithChecker())
def __rrshift__(self: bool, other: bool) -> int: ...

@guppy.custom(builtins, checker=BoolArithChecker())
def __rshift__(self: bool, other: bool) -> int: ...

@guppy.custom(builtins, checker=BoolArithChecker())
def __rsub__(self: bool, other: bool) -> int: ...

@guppy.custom(builtins, checker=BoolArithChecker())
def __rtruediv__(self: bool, other: bool) -> float: ...

@guppy.hugr_op(builtins, DummyOp("Xor"), ReversingChecker()) # TODO
def __rxor__(self: bool, other: bool) -> bool: ...

@guppy.custom(builtins, checker=BoolArithChecker())
def __sub__(self: bool, other: bool) -> int: ...

@guppy.custom(builtins, checker=BoolArithChecker())
def __truediv__(self: bool, other: bool) -> float: ...

@guppy.custom(builtins, checker=BoolArithChecker())
def __trunc__(self: bool) -> int: ...

@guppy.hugr_op(builtins, DummyOp("Xor")) # TODO
def __xor__(self: bool, other: bool) -> bool: ...


@guppy.extend_type(builtins, int_type_def)
class Int:
Expand Down
15 changes: 10 additions & 5 deletions guppylang/tys/builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,8 +139,13 @@ def _list_to_hugr(args: Sequence[Argument]) -> tys.Type:
callable_type_def = _CallableTypeDef(DefId.fresh(), None)
tuple_type_def = _TupleTypeDef(DefId.fresh(), None)
none_type_def = _NoneTypeDef(DefId.fresh(), None)
bool_type_def = _NumericTypeDef(
DefId.fresh(), "bool", None, NumericType(NumericType.Kind.Bool)
bool_type_def = OpaqueTypeDef(
id=DefId.fresh(),
name="bool",
defined_at=None,
params=[],
always_linear=False,
to_hugr=lambda _: tys.Type(tys.SumType(tys.UnitSum(size=2))),
)
int_type_def = _NumericTypeDef(
DefId.fresh(), "int", None, NumericType(NumericType.Kind.Int)
Expand All @@ -166,8 +171,8 @@ def _list_to_hugr(args: Sequence[Argument]) -> tys.Type:
)


def bool_type() -> NumericType:
return NumericType(NumericType.Kind.Bool)
def bool_type() -> OpaqueType:
return OpaqueType([], bool_type_def)


def int_type() -> NumericType:
Expand All @@ -183,7 +188,7 @@ def linst_type(element_ty: Type) -> OpaqueType:


def is_bool_type(ty: Type) -> bool:
return isinstance(ty, NumericType) and ty.kind == NumericType.Kind.Bool
return isinstance(ty, OpaqueType) and ty.defn == bool_type_def


def is_list_type(ty: Type) -> bool:
Expand Down
3 changes: 0 additions & 3 deletions guppylang/tys/ty.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,6 @@ class NumericType(TypeBase):
class Kind(Enum):
"""The different kinds of numeric types."""

Bool = "bool"
Int = "int"
Float = "float"

Expand All @@ -258,8 +257,6 @@ def linear(self) -> bool:
def to_hugr(self) -> tys.Type:
"""Computes the Hugr representation of the type."""
match self.kind:
case NumericType.Kind.Bool:
return SumType([NoneType(), NoneType()]).to_hugr()
case NumericType.Kind.Int:
return tys.Type(
tys.Opaque(
Expand Down
6 changes: 3 additions & 3 deletions tests/error/type_errors/invert_not_int.err
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,6 @@ Guppy compilation failed. Error in file $FILE:6

4: @compile_guppy
5: def foo() -> int:
6: return ~()
^^
GuppyTypeError: Unary operator `~` not defined for argument of type `()`
6: return ~True
^^^^
GuppyTypeError: Unary operator `~` not defined for argument of type `bool`
2 changes: 1 addition & 1 deletion tests/error/type_errors/invert_not_int.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,4 @@

@compile_guppy
def foo() -> int:
return ~()
return ~True
6 changes: 3 additions & 3 deletions tests/error/type_errors/unary_not_arith.err
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,6 @@ Guppy compilation failed. Error in file $FILE:6

4: @compile_guppy
5: def foo() -> int:
6: return -()
^^
GuppyTypeError: Unary operator `-` not defined for argument of type `()`
6: return -True
^^^^
GuppyTypeError: Unary operator `-` not defined for argument of type `bool`
2 changes: 1 addition & 1 deletion tests/error/type_errors/unary_not_arith.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,4 @@

@compile_guppy
def foo() -> int:
return -()
return -True
12 changes: 2 additions & 10 deletions tests/integration/test_arithmetic.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,18 +26,10 @@ def add(x: int) -> int:
validate(add)


def test_bool(validate):
@compile_guppy
def add(x: bool, y: bool) -> int:
return x + y

validate(add)


def test_float_coercion(validate):
@compile_guppy
def coerce(x: int, y: float, z: bool) -> float:
return x * y + z
def coerce(x: int, y: float) -> float:
return x * y

validate(coerce)

Expand Down

0 comments on commit cf8a529

Please sign in to comment.