Skip to content

Commit

Permalink
[mypyc] Implement additional ircheck checks (#12191)
Browse files Browse the repository at this point in the history
* Implement more checks in ircheck

Check op/register validity and check type coercions for return
and assign.

* Add checks for call operations, control ops, and literal loads

* Add check for duplicate ops

In particular, this will catch the case where builder.add() is
called twice, which causes very weird bogus IR.
  • Loading branch information
jhance authored Feb 18, 2022
1 parent d02db50 commit c7365ef
Show file tree
Hide file tree
Showing 5 changed files with 360 additions and 47 deletions.
226 changes: 207 additions & 19 deletions mypyc/analysis/ircheck.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,20 @@
"""Utilities for checking that internal ir is valid and consistent."""
from typing import List, Union
from typing import List, Union, Set, Tuple
from mypyc.ir.pprint import format_func
from mypyc.ir.ops import (
OpVisitor, BasicBlock, Op, ControlOp, Goto, Branch, Return, Unreachable,
Assign, AssignMulti, LoadErrorValue, LoadLiteral, GetAttr, SetAttr, LoadStatic,
InitStatic, TupleGet, TupleSet, IncRef, DecRef, Call, MethodCall, Cast,
Box, Unbox, RaiseStandardError, CallC, Truncate, LoadGlobal, IntOp, ComparisonOp,
LoadMem, SetMem, GetElementPtr, LoadAddress, KeepAlive
LoadMem, SetMem, GetElementPtr, LoadAddress, KeepAlive, Register, Integer,
BaseAssign
)
from mypyc.ir.func_ir import FuncIR
from mypyc.ir.rtypes import (
RType, RPrimitive, RUnion, is_object_rprimitive, RInstance, RArray,
int_rprimitive, list_rprimitive, dict_rprimitive, set_rprimitive,
range_rprimitive, str_rprimitive, bytes_rprimitive, tuple_rprimitive
)
from mypyc.ir.func_ir import FuncIR, FUNC_STATICMETHOD


class FnError(object):
Expand All @@ -17,8 +23,11 @@ def __init__(self, source: Union[Op, BasicBlock], desc: str) -> None:
self.desc = desc

def __eq__(self, other: object) -> bool:
    """Return True when other is an FnError with an equal source and desc."""
    return (
        isinstance(other, FnError)
        and self.source == other.source
        and self.desc == other.desc
    )

def __repr__(self) -> str:
return f"FnError(source={self.source}, desc={self.desc})"
Expand All @@ -28,19 +37,44 @@ def check_func_ir(fn: FuncIR) -> List[FnError]:
"""Applies validations to a given function ir and returns a list of errors found."""
errors = []

op_set = set()

for block in fn.blocks:
if not block.terminated:
errors.append(FnError(
source=block.ops[-1] if block.ops else block,
desc="Block not terminated",
))
errors.append(
FnError(
source=block.ops[-1] if block.ops else block,
desc="Block not terminated",
)
)
for op in block.ops[:-1]:
if isinstance(op, ControlOp):
errors.append(
FnError(
source=op,
desc="Block has operations after control op",
)
)

if op in op_set:
errors.append(
FnError(
source=op,
desc="Func has a duplicate op",
)
)
op_set.add(op)

errors.extend(check_op_sources_valid(fn))
if errors:
return errors

op_checker = OpChecker(fn)
for block in fn.blocks:
for op in block.ops:
op.accept(op_checker)

return errors + op_checker.errors
return op_checker.errors


class IrCheckException(Exception):
Expand All @@ -50,11 +84,90 @@ class IrCheckException(Exception):
def assert_func_ir_valid(fn: FuncIR) -> None:
    """Raise IrCheckException if check_func_ir reports any error for fn.

    The exception message is a fixed header followed by the lines that
    format_func produces for fn when given each error's (source, desc) pair.
    """
    errors = check_func_ir(fn)
    if not errors:
        return
    annotated_lines = format_func(fn, [(e.source, e.desc) for e in errors])
    raise IrCheckException(
        "Internal error: Generated invalid IR: \n" + "\n".join(annotated_lines)
    )


def check_op_sources_valid(fn: FuncIR) -> List[FnError]:
    """Check that every source of every op refers to something fn defines.

    An Op source must itself appear in one of fn's blocks. A Register
    source must be an argument register of fn or the dest of some
    BaseAssign op in fn. Integer sources are accepted unconditionally.

    Returns one FnError per offending source, attached to the op that
    uses it.
    """
    known_ops: Set[Op] = set()
    known_registers: Set[Register] = set(fn.arg_regs)

    # First pass: collect everything that may legally be referenced.
    for block in fn.blocks:
        for op in block.ops:
            known_ops.add(op)
            if isinstance(op, BaseAssign):
                known_registers.add(op.dest)

    # Second pass: look for references to anything outside those sets.
    problems: List[FnError] = []
    for block in fn.blocks:
        for op in block.ops:
            for source in op.sources():
                if isinstance(source, Integer):
                    continue
                if isinstance(source, Op):
                    if source not in known_ops:
                        kind = type(source).__name__
                        problems.append(
                            FnError(
                                source=op,
                                desc=f"Invalid op reference to op of type {kind}",
                            )
                        )
                elif isinstance(source, Register):
                    if source not in known_registers:
                        problems.append(
                            FnError(
                                source=op,
                                desc=f"Invalid op reference to register {source.name}",
                            )
                        )

    return problems


# Names of primitive types that can_coerce_to treats as mutually
# incompatible: two types from this set are coercible only if their
# names are equal.
disjoint_types = {
    int_rprimitive.name,
    bytes_rprimitive.name,
    str_rprimitive.name,
    dict_rprimitive.name,
    list_rprimitive.name,
    set_rprimitive.name,
    tuple_rprimitive.name,
    range_rprimitive.name,
}


def can_coerce_to(src: RType, dest: RType) -> bool:
    """Check whether a value of type src can be assigned to type dest.

    The check is deliberately permissive: it is currently acceptable
    for it to have false positives.
    """
    if isinstance(dest, RUnion):
        # Being coercible to any single member of the union is enough.
        for member in dest.items:
            if can_coerce_to(src, member):
                return True
        return False

    if not isinstance(dest, RPrimitive):
        # Non-primitive, non-union destinations are not checked.
        return True

    if isinstance(src, RPrimitive):
        both_disjoint = src.name in disjoint_types and dest.name in disjoint_types
        if both_disjoint:
            # Two disjoint primitives are compatible only if identical.
            return src.name == dest.name
        # Otherwise only the sizes are compared.
        return src.size == dest.size
    if isinstance(src, RInstance):
        return is_object_rprimitive(dest)
    if isinstance(src, RUnion):
        # IR doesn't have the ability to narrow unions based on
        # control flow, so this cannot be a strict all().
        return any(can_coerce_to(member, dest) for member in src.items)
    return False


class OpChecker(OpVisitor[None]):
def __init__(self, parent_fn: FuncIR) -> None:
self.parent_fn = parent_fn
Expand All @@ -66,7 +179,16 @@ def fail(self, source: Op, desc: str) -> None:
def check_control_op_targets(self, op: ControlOp) -> None:
    """Report each target of op that is not a block of the parent function.

    One failure is recorded per invalid target, naming the target's label.
    """
    for target in op.targets():
        if target not in self.parent_fn.blocks:
            self.fail(
                source=op, desc=f"Invalid control operation target: {target.label}"
            )

def check_type_coercion(self, op: Op, src: RType, dest: RType) -> None:
    """Record a failure on op when can_coerce_to rejects src -> dest."""
    if can_coerce_to(src, dest):
        return
    message = f"Cannot coerce source type {src.name} to dest type {dest.name}"
    self.fail(source=op, desc=message)

def visit_goto(self, op: Goto) -> None:
    """Check the targets of a Goto via check_control_op_targets."""
    self.check_control_op_targets(op)
Expand All @@ -75,52 +197,118 @@ def visit_branch(self, op: Branch) -> None:
self.check_control_op_targets(op)

def visit_return(self, op: Return) -> None:
    """Check the returned value's type against the declared return type.

    The declared type is read from the parent function's signature.
    """
    declared = self.parent_fn.decl.sig.ret_type
    self.check_type_coercion(op, op.value.type, declared)

def visit_unreachable(self, op: Unreachable) -> None:
    """Do nothing.

    Unreachables are checked at a higher level, since validating them
    requires access to the entire basic block.
    """

def visit_assign(self, op: Assign) -> None:
    """Check that the source value's type can be coerced to the dest register's type."""
    self.check_type_coercion(op, op.src.type, op.dest.type)

def visit_assign_multi(self, op: AssignMulti) -> None:
    """Check each source value's type against the dest array's item type."""
    for src in op.src:
        # The dest register is expected to hold an RArray; the assert
        # also narrows the type for the attribute access below.
        assert isinstance(op.dest.type, RArray)
        self.check_type_coercion(op, src.type, op.dest.type.item_type)

def visit_load_error_value(self, op: LoadErrorValue) -> None:
    """Do nothing.

    Currently it is assumed that all types have an error value. Once
    that is fixed, this can validate that the rtype here actually has
    an error value.
    """

def check_tuple_items_valid_literals(
    self, op: LoadLiteral, t: Tuple[object, ...]
) -> None:
    """Report every item of tuple literal t that is not a valid literal value.

    An item is valid when it is None or an instance of str, bytes, bool,
    int, float, complex or tuple. Nested tuples are checked recursively.
    Each invalid item produces one failure attached to op.
    """
    for x in t:
        if x is not None and not isinstance(
            x, (str, bytes, bool, int, float, complex, tuple)
        ):
            # The message previously ended with a stray ")".
            self.fail(op, f"Invalid type for item of tuple literal: {type(x)}")
        if isinstance(x, tuple):
            self.check_tuple_items_valid_literals(op, x)

def visit_load_literal(self, op: LoadLiteral) -> None:
    """Check that the literal's Python value agrees with the op's type.

    The expected type name is derived from the value. The op is accepted
    when its type name equals the expected name or "builtins.object".
    Tuple values additionally have their items checked by
    check_tuple_items_valid_literals.

    Raises AssertionError for a value of an unhandled Python type.
    """
    expected_type = None
    if op.value is None:
        expected_type = "builtins.object"
    elif isinstance(op.value, bool):
        # bool is a subclass of int, so it must be tested before int;
        # otherwise this branch can never be taken.
        expected_type = "builtins.object"
    elif isinstance(op.value, int):
        expected_type = "builtins.int"
    elif isinstance(op.value, str):
        expected_type = "builtins.str"
    elif isinstance(op.value, bytes):
        expected_type = "builtins.bytes"
    elif isinstance(op.value, float):
        expected_type = "builtins.float"
    elif isinstance(op.value, complex):
        expected_type = "builtins.object"
    elif isinstance(op.value, tuple):
        expected_type = "builtins.tuple"
        self.check_tuple_items_valid_literals(op, op.value)

    assert expected_type is not None, "Missed a case for LoadLiteral check"

    if op.type.name not in [expected_type, "builtins.object"]:
        self.fail(
            op,
            f"Invalid literal value for type: value has "
            f"type {expected_type}, but op has type {op.type.name}",
        )

def visit_get_attr(self, op: GetAttr) -> None:
    """Do nothing: no validation is performed for GetAttr ops."""

def visit_set_attr(self, op: SetAttr) -> None:
    """Do nothing: no validation is performed for SetAttr ops."""

# Static operations cannot be checked at the function level.
def visit_load_static(self, op: LoadStatic) -> None:
    """Do nothing: static operations cannot be checked at the function level."""

def visit_init_static(self, op: InitStatic) -> None:
    """Do nothing: static operations cannot be checked at the function level."""

def visit_tuple_get(self, op: TupleGet) -> None:
    """Do nothing: no validation is performed for TupleGet ops."""

def visit_tuple_set(self, op: TupleSet) -> None:
    """Do nothing: no validation is performed for TupleSet ops."""

def visit_inc_ref(self, op: IncRef) -> None:
    """Do nothing: no validation is performed for IncRef ops."""

def visit_dec_ref(self, op: DecRef) -> None:
    """Do nothing: no validation is performed for DecRef ops."""

def visit_call(self, op: Call) -> None:
    """Check each call argument's type against the callee's declared argument type.

    Arguments are paired positionally with op.fn.sig.args.
    """
    # Length is checked in constructor, and return type is set
    # in a way that can't be incorrect
    for arg_value, arg_runtime in zip(op.args, op.fn.sig.args):
        self.check_type_coercion(op, arg_value.type, arg_runtime.type)

def visit_method_call(self, op: MethodCall) -> None:
    """Check the argument count and argument types of a method call.

    The method declaration is looked up on the receiver's class. For
    anything other than a static method, the first declared argument
    (the receiver) is skipped when matching against op.args.
    """
    # Similar to visit_call, but we must look up the method first.
    method_decl = op.receiver_type.class_ir.method_decl(op.method)
    if method_decl.kind == FUNC_STATICMETHOD:
        decl_index = 0
    else:
        decl_index = 1

    if len(op.args) + decl_index != len(method_decl.sig.args):
        self.fail(op, "Incorrect number of args for method call.")

    # Skip the receiver argument (self)
    for arg_value, arg_runtime in zip(op.args, method_decl.sig.args[decl_index:]):
        self.check_type_coercion(op, arg_value.type, arg_runtime.type)

def visit_cast(self, op: Cast) -> None:
    """Do nothing: no validation is performed for Cast ops."""
Expand Down
18 changes: 12 additions & 6 deletions mypyc/ir/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,15 +214,21 @@ def accept(self, visitor: 'OpVisitor[T]') -> T:
pass


class Assign(Op):
class BaseAssign(Op):
    """Base class for ops that assign to a register."""

    def __init__(self, dest: Register, line: int = -1) -> None:
        """Initialize the op for the given line and remember the dest register."""
        super().__init__(line)
        # The register this op assigns to.
        self.dest = dest


class Assign(BaseAssign):
"""Assign a value to a Register (dest = src)."""

error_kind = ERR_NEVER

def __init__(self, dest: Register, src: Value, line: int = -1) -> None:
    """Create an assignment of src to dest.

    The base-class initializer receives dest and line; only src is
    stored here.
    """
    super().__init__(dest, line)
    self.src = src

def sources(self) -> List[Value]:
    """Return a one-element list holding the assigned source value."""
    only_source = self.src
    return [only_source]
Expand All @@ -234,7 +240,7 @@ def accept(self, visitor: 'OpVisitor[T]') -> T:
return visitor.visit_assign(self)


class AssignMulti(Op):
class AssignMulti(BaseAssign):
"""Assign multiple values to a Register (dest = src1, src2, ...).
This is used to initialize RArray values. It's provided to avoid
Expand All @@ -248,12 +254,11 @@ class AssignMulti(Op):
error_kind = ERR_NEVER

def __init__(self, dest: Register, src: List[Value], line: int = -1) -> None:
    """Create an assignment of several values to an array-typed register.

    Asserts that src is non-empty, that dest has an RArray type, and
    that the array length equals the number of source values. The
    base-class initializer receives dest and line.
    """
    super().__init__(dest, line)
    assert src
    assert isinstance(dest.type, RArray)
    assert dest.type.length == len(src)
    self.src = src

def sources(self) -> List[Value]:
    """Return a shallow copy of the source values."""
    copied = self.src[:]
    return copied
Expand Down Expand Up @@ -490,6 +495,7 @@ def __init__(self, fn: 'FuncDecl', args: Sequence[Value], line: int) -> None:
super().__init__(line)
self.fn = fn
self.args = list(args)
assert len(self.args) == len(fn.sig.args)
self.type = fn.sig.ret_type

def sources(self) -> List[Value]:
Expand Down
Loading

0 comments on commit c7365ef

Please sign in to comment.