From 5477486c817fe6217adba91506d6e726f0d2fe8e Mon Sep 17 00:00:00 2001 From: Melf Date: Fri, 4 Oct 2024 14:56:45 +0100 Subject: [PATCH 1/5] add ruff config file --- ruff.toml | 44 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 44 insertions(+) create mode 100644 ruff.toml diff --git a/ruff.toml b/ruff.toml new file mode 100644 index 0000000..f2146f6 --- /dev/null +++ b/ruff.toml @@ -0,0 +1,44 @@ +target-version = "py39" + +line-length = 88 + +extend-exclude = ["examples"] + +select = [ + "E", # pycodestyle Errors + "W", # pycodestyle Warnings + + # "A", # flake8-builtins + # "B", # flake8-Bugbear + # "C4", # flake8-comprehensions + # "COM", # flake8-commas + # "EXE", # flake8-executable + "F", # pyFlakes + # "FA", # flake8-future-annotations + # "FIX", # flake8-fixme + # "FLY", # flynt + "I", # isort + # "INP", # flake8-no-pep420 + # "ISC", # flake8-implicit-str-concat + # "N", # pep8-Naming + # "NPY", # NumPy-specific + # "PERF", # Perflint + # "PGH", # pygrep-hooks + # "PIE", # flake8-pie + # "PL", # pylint + # "PT", # flake8-pytest-style + # "RSE", # flake8-raise + # "RUF", # Ruff-specific + # "S", # flake8-bandit (Security) + "SIM", # flake8-simplify + # "SLF", # flake8-self + "T20", # flake8-print + "TCH", # flake8-type-checking + # "TRY", # tryceratops + "UP", # pyupgrade + # "YTT", # flake8-2020 +] + +[per-file-ignores] +".github/workflows/docs/conf.py" = ["E402"] +"__init__.py" = ["F401"] # module imported but unused (6) From b0aa539829c50e9a66bc9854417a6e6031d60a64 Mon Sep 17 00:00:00 2001 From: Melf Date: Fri, 4 Oct 2024 14:58:46 +0100 Subject: [PATCH 2/5] add ruff to lint ci check --- .github/workflows/lint.yml | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index 418b3c3..fef6380 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -22,10 +22,13 @@ jobs: - name: Update pip run: pip install --upgrade pip - name: Install black and pylint - run: pip install black pylint + run: pip install black pylint ruff - name: Check files are formatted with black run: | black --check . + - name: Run ruff + run: | + ruff check . - name: Run pylint run: | pylint */ From 16bb15419a3c47b6a6f74f46c7153f4711e96297 Mon Sep 17 00:00:00 2001 From: Melf Date: Fri, 4 Oct 2024 15:01:45 +0100 Subject: [PATCH 3/5] ruff fixes I --- pytket/extensions/qujax/__init__.py | 10 +++++----- pytket/extensions/qujax/qujax_convert.py | 10 ++++++---- setup.py | 5 +++-- tests/test_tket.py | 16 +++++++++------- tests/test_tket_symbolic.py | 8 +++++--- 5 files changed, 28 insertions(+), 21 deletions(-) diff --git a/pytket/extensions/qujax/__init__.py b/pytket/extensions/qujax/__init__.py index 8b51fa9..a893c0c 100644 --- a/pytket/extensions/qujax/__init__.py +++ b/pytket/extensions/qujax/__init__.py @@ -16,12 +16,12 @@ """ # _metadata.py is copied to the folder after installation. -from ._metadata import __extension_version__, __extension_name__ +from ._metadata import __extension_name__, __extension_version__ from .qujax_convert import ( - tk_to_qujax, - tk_to_qujax_args, - tk_to_param, + _tk_qubits_to_inds, print_circuit, qujax_args_to_tk, - _tk_qubits_to_inds, + tk_to_param, + tk_to_qujax, + tk_to_qujax_args, ) diff --git a/pytket/extensions/qujax/qujax_convert.py b/pytket/extensions/qujax/qujax_convert.py index 0e63f09..e478c5b 100644 --- a/pytket/extensions/qujax/qujax_convert.py +++ b/pytket/extensions/qujax/qujax_convert.py @@ -16,13 +16,15 @@ Methods to allow conversion between qujax and pytket """ -from typing import Tuple, Sequence, Optional, List, Union, Callable, Any +from collections.abc import Sequence from functools import wraps +from typing import Any, Callable, List, Optional, Tuple, Union -import qujax # type: ignore from jax import numpy as jnp -from sympy import lambdify, Symbol -from pytket import Qubit, Circuit # type: ignore +from sympy import Symbol, lambdify + +import qujax # type: ignore +from pytket import Circuit, Qubit # type: ignore from pytket._tket.circuit import Command diff --git a/setup.py b/setup.py index 9015564..b695eb4 100644 --- a/setup.py +++ b/setup.py @@ -12,9 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -import shutil import os -from setuptools import setup, find_namespace_packages # type: ignore +import shutil + +from setuptools import find_namespace_packages, setup # type: ignore metadata: dict = {} with open("_metadata.py") as fp: diff --git a/tests/test_tket.py b/tests/test_tket.py index 0630797..aa163cb 100644 --- a/tests/test_tket.py +++ b/tests/test_tket.py @@ -12,20 +12,22 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Union, Any -from jax import numpy as jnp, jit, grad, random -import qujax # type: ignore +from typing import Any, Union + import pytest +import qujax # type: ignore +from jax import grad, jit, random +from jax import numpy as jnp from pytket.circuit import Circuit, Qubit -from pytket.pauli import Pauli, QubitPauliString -from pytket.utils import QubitPauliOperator from pytket.extensions.qujax import ( - tk_to_qujax, - tk_to_qujax_args, qujax_args_to_tk, tk_to_param, + tk_to_qujax, + tk_to_qujax_args, ) +from pytket.pauli import Pauli, QubitPauliString +from pytket.utils import QubitPauliOperator def _test_circuit( diff --git a/tests/test_tket_symbolic.py b/tests/test_tket_symbolic.py index e66ca2d..b986cf4 100644 --- a/tests/test_tket_symbolic.py +++ b/tests/test_tket_symbolic.py @@ -12,16 +12,18 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Sequence +from collections.abc import Sequence + import pytest +from jax import grad, jit, random +from jax import numpy as jnp from sympy import Symbol -from jax import numpy as jnp, jit, grad, random from pytket.circuit import Circuit, OpType from pytket.extensions.qujax import ( + qujax_args_to_tk, tk_to_qujax, tk_to_qujax_args, - qujax_args_to_tk, ) From 943fb4a6301cf5d2fdb56159bb35f822058ccc41 Mon Sep 17 00:00:00 2001 From: Melf Date: Fri, 4 Oct 2024 15:04:13 +0100 Subject: [PATCH 4/5] ruff fix II --- pytket/extensions/qujax/qujax_convert.py | 10 +++++----- tests/test_tket.py | 9 +++++++-- tests/test_tket_symbolic.py | 9 +++++++-- 3 files changed, 19 insertions(+), 9 deletions(-) diff --git a/pytket/extensions/qujax/qujax_convert.py b/pytket/extensions/qujax/qujax_convert.py index e478c5b..d535e36 100644 --- a/pytket/extensions/qujax/qujax_convert.py +++ b/pytket/extensions/qujax/qujax_convert.py @@ -18,7 +18,7 @@ from collections.abc import Sequence from functools import wraps -from typing import Any, Callable, List, Optional, Tuple, Union +from typing import Any, Callable, Optional, Union from jax import numpy as jnp from sympy import Symbol, lambdify @@ -28,7 +28,7 @@ from pytket._tket.circuit import Command -def _tk_qubits_to_inds(tk_qubits: Sequence[Qubit]) -> Tuple[int, ...]: +def _tk_qubits_to_inds(tk_qubits: Sequence[Qubit]) -> tuple[int, ...]: """ Convert Sequence of pytket qubits objects to tuple of integers qubit indices. @@ -65,7 +65,7 @@ def g(*args: Any, **kwargs: Any) -> Any: def _symbolic_command_to_gate_and_param_inds( command: Command, symbol_map: dict -) -> Tuple[Union[str, Callable[[jnp.ndarray], jnp.ndarray]], Sequence[int]]: +) -> tuple[Union[str, Callable[[jnp.ndarray], jnp.ndarray]], Sequence[int]]: """ Convert pytket command to qujax (gate, parameter indices) tuple. @@ -110,7 +110,7 @@ def _symbolic_command_to_gate_and_param_inds( return gate, param_inds -def tk_to_qujax_args(circuit: Circuit, symbol_map: Optional[dict] = None) -> Tuple[ +def tk_to_qujax_args(circuit: Circuit, symbol_map: Optional[dict] = None) -> tuple[ Sequence[Union[str, Callable[[jnp.ndarray], jnp.ndarray]]], Sequence[Sequence[int]], Sequence[Sequence[int]], @@ -294,7 +294,7 @@ def print_circuit( gate_ind_min: int = 0, gate_ind_max: int = jnp.inf, # type: ignore sep_length: int = 1, -) -> List[str]: +) -> list[str]: """ Returns and prints basic string representation of circuit. diff --git a/tests/test_tket.py b/tests/test_tket.py index aa163cb..90f9eac 100644 --- a/tests/test_tket.py +++ b/tests/test_tket.py @@ -70,11 +70,16 @@ def _test_circuit( assert jnp.allclose(test_jit_dm_diag, true_probs) if param is not None: - cost_func = lambda p: jnp.square(apply_circuit(p)).real.sum() + + def cost_func(p): + return jnp.square(apply_circuit(p)).real.sum() + grad_cost_func = grad(cost_func) assert isinstance(grad_cost_func(param), jnp.ndarray) - cost_jit_func = lambda p: jnp.square(jit_apply_circuit(p)).real.sum() + def cost_jit_func(p): + return jnp.square(jit_apply_circuit(p)).real.sum() + grad_cost_jit_func = grad(cost_jit_func) assert isinstance(grad_cost_jit_func(param), jnp.ndarray) diff --git a/tests/test_tket_symbolic.py b/tests/test_tket_symbolic.py index b986cf4..23dc00e 100644 --- a/tests/test_tket_symbolic.py +++ b/tests/test_tket_symbolic.py @@ -71,11 +71,16 @@ def _test_circuit( assert jnp.allclose(test_jit_dm_diag, true_probs) if len(params): - cost_func = lambda p: jnp.square(apply_circuit(p)).real.sum() + + def cost_func(p): + return jnp.square(apply_circuit(p)).real.sum() + grad_cost_func = grad(cost_func) assert isinstance(grad_cost_func(params), jnp.ndarray) - cost_jit_func = lambda p: jnp.square(jit_apply_circuit(p)).real.sum() + def cost_jit_func(p): + return jnp.square(jit_apply_circuit(p)).real.sum() + grad_cost_jit_func = grad(cost_jit_func) assert isinstance(grad_cost_jit_func(params), jnp.ndarray) From 0aae00b9a82fe79bd901c94623d95cae9636f8dd Mon Sep 17 00:00:00 2001 From: Melf Date: Fri, 4 Oct 2024 15:08:42 +0100 Subject: [PATCH 5/5] fix setup path --- setup.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/setup.py b/setup.py index b695eb4..4d2c596 100644 --- a/setup.py +++ b/setup.py @@ -14,6 +14,7 @@ import os import shutil +from pathlib import Path from setuptools import find_namespace_packages, setup # type: ignore @@ -38,7 +39,7 @@ "Tracker": "https://github.com/CQCL/pytket-qujax/issues", }, description="Extension for pytket, providing access to qujax functions", - long_description=open("README.md").read(), + long_description=(Path(__file__).parent / "README.md").read_text(), long_description_content_type="text/markdown", license="Apache 2", packages=find_namespace_packages(include=["pytket.*"]),