diff --git a/crates/erg_compiler/codegen.rs b/crates/erg_compiler/codegen.rs index 2a0bc3dfc..24384eef1 100644 --- a/crates/erg_compiler/codegen.rs +++ b/crates/erg_compiler/codegen.rs @@ -184,6 +184,7 @@ pub struct PyCodeGenerator { control_loaded: bool, convertors_loaded: bool, operators_loaded: bool, + union_loaded: bool, abc_loaded: bool, unit_size: usize, units: PyCodeGenStack, @@ -204,6 +205,7 @@ impl PyCodeGenerator { control_loaded: false, convertors_loaded: false, operators_loaded: false, + union_loaded: false, abc_loaded: false, unit_size: 0, units: PyCodeGenStack::empty(), @@ -224,6 +226,7 @@ impl PyCodeGenerator { control_loaded: false, convertors_loaded: false, operators_loaded: false, + union_loaded: false, abc_loaded: false, unit_size: 0, units: PyCodeGenStack::empty(), @@ -244,6 +247,7 @@ impl PyCodeGenerator { self.control_loaded = false; self.convertors_loaded = false; self.operators_loaded = false; + self.union_loaded = false; self.abc_loaded = false; } @@ -1566,6 +1570,16 @@ impl PyCodeGenerator { self.emit_push_null(); self.emit_load_name_instr(Identifier::public("OpenRange")); } + // From 3.10, `or` can be used for types. + // But Erg supports Python 3.7~, so we should use `typing.Union`. + TokenKind::OrOp if bin.lhs.ref_t().is_type() => { + self.load_union(); + // self.emit_push_null(); + self.emit_load_name_instr(Identifier::private("#Union")); + let args = Args::pos_only(vec![PosArg::new(*bin.lhs), PosArg::new(*bin.rhs)], None); + self.emit_index_args(args); + return; + } TokenKind::ContainsOp => { // if no-std, always `x contains y == True` if self.cfg.no_std { @@ -3517,6 +3531,16 @@ impl PyCodeGenerator { ); } + fn load_union(&mut self) { + self.emit_global_import_items( + Identifier::public("typing"), + vec![( + Identifier::public("Union"), + Some(Identifier::private("#Union")), + )], + ); + } + fn load_module_type(&mut self) { self.emit_global_import_items( Identifier::public("types"), diff --git a/crates/erg_compiler/lib/std/_erg_array.py b/crates/erg_compiler/lib/std/_erg_array.py index ee127e8c9..0284b8e1d 100644 --- a/crates/erg_compiler/lib/std/_erg_array.py +++ b/crates/erg_compiler/lib/std/_erg_array.py @@ -53,7 +53,7 @@ def __getitem__(self, index_or_slice): def type_check(self, t: type) -> bool: if isinstance(t, list): - if len(t) != len(self): + if len(t) < len(self): return False for (inner_t, elem) in zip(t, self): if not contains_operator(inner_t, elem): diff --git a/crates/erg_compiler/lib/std/_erg_contains_operator.py b/crates/erg_compiler/lib/std/_erg_contains_operator.py index 3b11c4aaf..b247d7be7 100644 --- a/crates/erg_compiler/lib/std/_erg_contains_operator.py +++ b/crates/erg_compiler/lib/std/_erg_contains_operator.py @@ -1,5 +1,6 @@ from _erg_result import is_ok from _erg_range import Range +from _erg_type import is_type, isinstance from collections import namedtuple @@ -8,7 +9,7 @@ def contains_operator(y, elem) -> bool: if hasattr(elem, "type_check"): return elem.type_check(y) # 1 in Int - elif type(y) == type: + elif is_type(y): if isinstance(elem, y): return True elif hasattr(y, "try_new") and is_ok(y.try_new(elem)): @@ -17,26 +18,33 @@ def contains_operator(y, elem) -> bool: return False # [1] in [Int] elif isinstance(y, list) and isinstance(elem, list) and ( - type(y[0]) == type or isinstance(y[0], Range) + len(y) == 0 or is_type(y[0]) or isinstance(y[0], Range) ): - # FIXME: - type_check = contains_operator(y[0], elem[0]) - len_check = len(elem) == len(y) + type_check = all(map(lambda x: contains_operator(x[0], x[1]), zip(y, elem))) + len_check = len(elem) <= len(y) return type_check and len_check # (1, 2) in (Int, Int) elif isinstance(y, tuple) and isinstance(elem, tuple) and ( - type(y[0]) == type or isinstance(y[0], Range) + len(y) == 0 or is_type(y[0]) or isinstance(y[0], Range) ): if not hasattr(elem, "__iter__"): return False type_check = all(map(lambda x: contains_operator(x[0], x[1]), zip(y, elem))) - len_check = len(elem) == len(y) + len_check = len(elem) <= len(y) return type_check and len_check # {1: 2} in {Int: Int} - elif isinstance(y, dict) and isinstance(elem, dict) and isinstance(next(iter(y.keys())), type): + elif isinstance(y, dict) and isinstance(elem, dict) and ( + len(y) == 0 or is_type(next(iter(y.keys()))) + ): + if len(y) == 1: + key = next(iter(y.keys())) + key_check = all([contains_operator(key, el) for el in elem.keys()]) + value = next(iter(y.values())) + value_check = all([contains_operator(value, el) for el in elem.values()]) + return key_check and value_check # TODO: type_check = True # contains_operator(next(iter(y.keys())), x[next(iter(x.keys()))]) - len_check = len(elem) >= len(y) + len_check = True # It can be True even if either elem or y has the larger number of elems return type_check and len_check elif isinstance(elem, list): from _erg_array import Array diff --git a/crates/erg_compiler/lib/std/_erg_type.py b/crates/erg_compiler/lib/std/_erg_type.py new file mode 100644 index 000000000..26af8622c --- /dev/null +++ b/crates/erg_compiler/lib/std/_erg_type.py @@ -0,0 +1,21 @@ +from typing import _GenericAlias, Union +try: + from types import UnionType +except ImportError: + class UnionType: + __args__: list # list[type] + def __init__(self, *args): + self.__args__ = args + +def is_type(x) -> bool: + return isinstance(x, type) or \ + isinstance(x, _GenericAlias) or \ + isinstance(x, UnionType) + +instanceof = isinstance +# The behavior of `builtins.isinstance` depends on the Python version. +def isinstance(obj, classinfo) -> bool: + if instanceof(classinfo, _GenericAlias) and classinfo.__origin__ == Union: + return any(instanceof(obj, t) for t in classinfo.__args__) + else: + return instanceof(obj, classinfo) diff --git a/crates/erg_compiler/transpile.rs b/crates/erg_compiler/transpile.rs index dcd1c1aa7..714f77b86 100644 --- a/crates/erg_compiler/transpile.rs +++ b/crates/erg_compiler/transpile.rs @@ -371,6 +371,7 @@ impl PyScriptGenerator { .replace("from _erg_result import is_ok", "") .replace("from _erg_control import then__", "") .replace("from _erg_contains_operator import contains_operator", "") + .replace("from _erg_type import is_type, isinstance", "") } fn load_namedtuple_if_not(&mut self) { diff --git a/tests/should_ok/assert_cast.er b/tests/should_ok/assert_cast.er index ee8a3f9f4..589f023e8 100644 --- a/tests/should_ok/assert_cast.er +++ b/tests/should_ok/assert_cast.er @@ -12,6 +12,10 @@ assert j["a"] in Array(Int) assert j["a"] notin Array(Str) _: Array(Int) = j["a"] +dic = {"a": "b", "c": "d"} +assert dic in {Str: {"b", "d"}} +assert dic in {Str: Str} + .f dic: {Str: Str or Array(Str)} = assert dic["key"] in Str # Required to pass the check on the next line assert dic["key"] in {"a", "b", "c"} diff --git a/tests/should_ok/dyn_type_check.er b/tests/should_ok/dyn_type_check.er new file mode 100644 index 000000000..3a15832aa --- /dev/null +++ b/tests/should_ok/dyn_type_check.er @@ -0,0 +1,21 @@ +assert 1 in (Int or Str) +assert 1.2 notin (Int or Str) + +dic = {:} +assert dic in {:} +assert dic in {Str: Int} +assert dic in {Str: Str} +dic2 = {"a": 1} +assert dic2 in {Str or Int: Int} +assert dic2 in {Str: Int or Str} +assert dic2 notin {Int: Int} + +tup = () +assert tup in () +assert tup in (Int, Int) +assert tup in (Int, Str) + +arr = [] +assert arr in [] +assert arr in [Int] +assert arr in [Str] diff --git a/tests/test.rs b/tests/test.rs index d46142dfb..8b86acf14 100644 --- a/tests/test.rs +++ b/tests/test.rs @@ -100,6 +100,11 @@ fn exec_dict_test() -> Result<(), ()> { expect_success("tests/should_ok/dict.er", 0) } +#[test] +fn exec_empty_check() -> Result<(), ()> { + expect_success("tests/should_ok/dyn_type_check.er", 0) +} + #[test] fn exec_external() -> Result<(), ()> { let py_command = opt_which_python().unwrap();