diff --git a/pyproject.toml b/pyproject.toml index 026bb864a6..1dc9ecb698 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -29,7 +29,7 @@ dev = [ "nbconvert>=7.7.2,<8.0.0", "textual-dev==1.7.0", "pytest-asyncio==0.25.3", - "pyright==1.1.393", + "pyright==1.1.394", "sympy==1.13.3", ] docs = [ diff --git a/uv.lock b/uv.lock index 458da14f59..d1418e73e3 100644 --- a/uv.lock +++ b/uv.lock @@ -1874,15 +1874,15 @@ wheels = [ [[package]] name = "pyright" -version = "1.1.393" +version = "1.1.394" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "nodeenv" }, { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/f4/c1/aede6c74e664ab103673e4f1b7fd3d058fef32276be5c43572f4067d4a8e/pyright-1.1.393.tar.gz", hash = "sha256:aeeb7ff4e0364775ef416a80111613f91a05c8e01e58ecfefc370ca0db7aed9c", size = 3790430 } +sdist = { url = "https://files.pythonhosted.org/packages/b1/e4/79f4d8a342eed6790fdebdb500e95062f319ee3d7d75ae27304ff995ae8c/pyright-1.1.394.tar.gz", hash = "sha256:56f2a3ab88c5214a451eb71d8f2792b7700434f841ea219119ade7f42ca93608", size = 3809348 } wheels = [ - { url = "https://files.pythonhosted.org/packages/92/47/f0dd0f8afce13d92e406421ecac6df0990daee84335fc36717678577d3e0/pyright-1.1.393-py3-none-any.whl", hash = "sha256:8320629bb7a44ca90944ba599390162bf59307f3d9fb6e27da3b7011b8c17ae5", size = 5646057 }, + { url = "https://files.pythonhosted.org/packages/d6/4c/50c74e3d589517a9712a61a26143b587dba6285434a17aebf2ce6b82d2c3/pyright-1.1.394-py3-none-any.whl", hash = "sha256:5f74cce0a795a295fb768759bbeeec62561215dea657edcaab48a932b031ddbb", size = 5679540 }, ] [[package]] @@ -2810,7 +2810,7 @@ requires-dist = [ { name = "ordered-set", specifier = "==4.1.0" }, { name = "pre-commit", marker = "extra == 'dev'", specifier = "==4.1.0" }, { name = "pyclip", marker = "extra == 'gui'", specifier = "==0.7" }, - { name = "pyright", marker = "extra == 'dev'", specifier = "==1.1.393" }, + { name = "pyright", marker = "extra == 'dev'", specifier = "==1.1.394" }, { name = "pytest", marker = "extra == 'dev'", specifier = "<8.4" }, { name = "pytest-asyncio", marker = "extra == 'dev'", specifier = "==0.25.3" }, { name = "pytest-cov", marker = "extra == 'dev'" }, diff --git a/xdsl/backend/riscv/prologue_epilogue_insertion.py b/xdsl/backend/riscv/prologue_epilogue_insertion.py index 85fa00594a..9036ea2aad 100644 --- a/xdsl/backend/riscv/prologue_epilogue_insertion.py +++ b/xdsl/backend/riscv/prologue_epilogue_insertion.py @@ -6,6 +6,7 @@ from xdsl.context import MLContext from xdsl.dialects import builtin, riscv, riscv_func from xdsl.dialects.riscv import ( + FloatRegisterType, IntRegisterType, Registers, RISCVRegisterType, @@ -39,6 +40,7 @@ def _process_function(self, func: riscv_func.FuncOp) -> None: for op in func.walk() if not isinstance(op, riscv.GetRegisterOp | riscv.GetFloatRegisterOp) for res in op.results + if isinstance(res.type, IntRegisterType | FloatRegisterType) if res.type in Registers.S or res.type in Registers.FS ) diff --git a/xdsl/backend/riscv/riscv_register_queue.py b/xdsl/backend/riscv/riscv_register_queue.py index 6d48f96f69..69302f1ee2 100644 --- a/xdsl/backend/riscv/riscv_register_queue.py +++ b/xdsl/backend/riscv/riscv_register_queue.py @@ -131,7 +131,7 @@ def exclude_register(self, reg: IntRegisterType | FloatRegisterType) -> None: """ Removes register from available set, if present. """ - if reg in self.available_int_registers: + if isinstance(reg, IntRegisterType) and reg in self.available_int_registers: self.available_int_registers.remove(reg) - if reg in self.available_float_registers: + if isinstance(reg, FloatRegisterType) and reg in self.available_float_registers: self.available_float_registers.remove(reg) diff --git a/xdsl/transforms/convert_stencil_to_csl_stencil.py b/xdsl/transforms/convert_stencil_to_csl_stencil.py index 23150d2e6b..4b0f30805b 100644 --- a/xdsl/transforms/convert_stencil_to_csl_stencil.py +++ b/xdsl/transforms/convert_stencil_to_csl_stencil.py @@ -302,7 +302,9 @@ def split_ops( rem.remove(use.operation) # find constants in `a` needed outside of `a` - cnst_exports = [cnst for cnst in a_exports if isinstance(cnst, arith.ConstantOp)] + cnst_exports = tuple( + cnst for cnst in a_exports if isinstance(cnst, arith.ConstantOp) + ) # `a` exports one value plus any number of constants - duplicate exported constants and return op split if len(a_exports) == 1 + len(cnst_exports): @@ -310,7 +312,7 @@ def split_ops( for op in ops: if op in a: recv_chunk_ops.append(op) - if op in cnst_exports: + if op in cnst_exports and isinstance(op, arith.ConstantOp): # create a copy of the constant in the second region done_exch_ops.append(cln := op.clone()) # rewire ops of the second region to use the copied constant diff --git a/xdsl/transforms/lower_csl_stencil.py b/xdsl/transforms/lower_csl_stencil.py index 9384b9eb33..bf2cf02157 100644 --- a/xdsl/transforms/lower_csl_stencil.py +++ b/xdsl/transforms/lower_csl_stencil.py @@ -503,10 +503,9 @@ def match_and_rewrite(self, op: csl_stencil.ApplyOp, rewriter: PatternRewriter, rewriter.erase_op(e, safe_erase=False) # housekeeping: this strategy requires zeroing out the accumulator iff the apply is inside a loop - assert (elem_t := accumulator.type.get_element_type()) in [ - Float16Type(), - Float32Type(), - ] + assert isinstance( + (elem_t := accumulator.type.get_element_type()), Float16Type | Float32Type + ) zero = arith.ConstantOp(FloatAttr(0.0, elem_t)) mov_op = csl.FmovsOp if elem_t == Float32Type() else csl.FmovhOp rewriter.insert_op(