diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml
index 69318735..99b553e4 100644
--- a/.github/workflows/build.yml
+++ b/.github/workflows/build.yml
@@ -136,7 +136,7 @@ jobs:
         run: python .github/workflows/set_cibw_build.py
 
       - name: Build wheels
-        uses: pypa/cibuildwheel@v2.18
+        uses: pypa/cibuildwheel@v2.19
         env:
           CIBW_BUILD: ${{ env.CIBW_BUILD }}
         with:
@@ -190,7 +190,7 @@ jobs:
         run: python .github/workflows/set_cibw_build.py
 
       - name: Build wheels
-        uses: pypa/cibuildwheel@v2.18
+        uses: pypa/cibuildwheel@v2.19
         env:
           CIBW_BUILD: ${{ env.CIBW_BUILD }}
         with:
diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml
index 24d06fe7..f156ffe3 100644
--- a/.github/workflows/tests.yml
+++ b/.github/workflows/tests.yml
@@ -80,15 +80,15 @@ jobs:
           USE_FP16: "ON"
           TORCH_CUDA_ARCH_LIST: "Common"
         run: |
-          python -m pip install -vvv -e .
+          python -m pip install -vvv --editable .
 
       - name: Test with pytest
         run: |
           make pytest
 
       - name: Upload coverage to Codecov
-        if: runner.os == 'Linux'
         uses: codecov/codecov-action@v4
+        if: ${{ matrix.os == 'ubuntu-latest' }}
         with:
           token: ${{ secrets.CODECOV_TOKEN }}
           file: ./tests/coverage.xml
@@ -127,7 +127,7 @@ jobs:
 
       - name: Install TorchOpt
         run: |
-          python -m pip install -vvv -e .
+          python -m pip install -vvv --editable .
         env:
           TORCHOPT_NO_EXTENSIONS: "true"
 
@@ -136,8 +136,8 @@ jobs:
           make pytest
 
       - name: Upload coverage to Codecov
-        if: runner.os == 'Linux'
         uses: codecov/codecov-action@v4
+        if: ${{ matrix.os == 'ubuntu-latest' }}
         with:
           token: ${{ secrets.CODECOV_TOKEN }}
           file: ./tests/coverage.xml
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index f8419466..4814c681 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -26,11 +26,11 @@ repos:
       - id: debug-statements
       - id: double-quote-string-fixer
   - repo: https://github.com/pre-commit/mirrors-clang-format
-    rev: v18.1.5
+    rev: v18.1.6
     hooks:
       - id: clang-format
   - repo: https://github.com/astral-sh/ruff-pre-commit
-    rev: v0.4.7
+    rev: v0.4.9
     hooks:
       - id: ruff
         args: [--fix, --exit-non-zero-on-fix]
@@ -43,7 +43,7 @@ repos:
     hooks:
       - id: black-jupyter
   - repo: https://github.com/asottile/pyupgrade
-    rev: v3.15.2
+    rev: v3.16.0
     hooks:
       - id: pyupgrade
         args: [--py38-plus] # sync with requires-python
@@ -52,7 +52,7 @@ repos:
             ^examples/
           )
   - repo: https://github.com/pycqa/flake8
-    rev: 7.0.0
+    rev: 7.1.0
     hooks:
       - id: flake8
         additional_dependencies:
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 13814dee..101ba3ec 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -17,7 +17,14 @@ cmake_minimum_required(VERSION 3.11)  # for FetchContent
 project(torchopt LANGUAGES CXX)
 
 include(FetchContent)
-set(PYBIND11_VERSION v2.12.0)
+
+set(THIRD_PARTY_DIR "${CMAKE_SOURCE_DIR}/third-party")
+if(NOT DEFINED PYBIND11_VERSION AND NOT "$ENV{PYBIND11_VERSION}" STREQUAL "")
+    set(PYBIND11_VERSION "$ENV{PYBIND11_VERSION}")
+endif()
+if(NOT PYBIND11_VERSION)
+    set(PYBIND11_VERSION stable)
+endif()
 
 if(NOT CMAKE_BUILD_TYPE)
     set(CMAKE_BUILD_TYPE Release)
@@ -172,7 +179,7 @@ endif()
 
 system(
     STRIP OUTPUT_VARIABLE PYTHON_VERSION
-    COMMAND "${PYTHON_EXECUTABLE}" -c "print(__import__('platform').python_version())"
+    COMMAND "${PYTHON_EXECUTABLE}" -c "print('.'.join(map(str, __import__('sys').version_info[:3])))"
 )
 
 message(STATUS "Use Python version: ${PYTHON_VERSION}")
@@ -216,11 +223,12 @@ if("${PYBIND11_CMAKE_DIR}" STREQUAL "")
         GIT_REPOSITORY https://github.com/pybind/pybind11.git
         GIT_TAG "${PYBIND11_VERSION}"
         GIT_SHALLOW TRUE
-        SOURCE_DIR "${CMAKE_SOURCE_DIR}/third-party/pybind11"
-        BINARY_DIR "${CMAKE_SOURCE_DIR}/third-party/.cmake/pybind11/build"
-        STAMP_DIR "${CMAKE_SOURCE_DIR}/third-party/.cmake/pybind11/stamp"
+        SOURCE_DIR "${THIRD_PARTY_DIR}/pybind11"
+        BINARY_DIR "${THIRD_PARTY_DIR}/.cmake/pybind11/build"
+        STAMP_DIR "${THIRD_PARTY_DIR}/.cmake/pybind11/stamp"
     )
     FetchContent_GetProperties(pybind11)
+
     if(NOT pybind11_POPULATED)
         message(STATUS "Populating Git repository pybind11@${PYBIND11_VERSION} to third-party/pybind11...")
         FetchContent_MakeAvailable(pybind11)
diff --git a/pyproject.toml b/pyproject.toml
index ed93944a..d343e04a 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -235,6 +235,7 @@ extend-exclude = ["examples"]
 select = [
     "E", "W",  # pycodestyle
     "F",       # pyflakes
+    "C90",     # mccabe
     "UP",      # pyupgrade
     "ANN",     # flake8-annotations
     "S",       # flake8-bandit
@@ -243,7 +244,10 @@ select = [
     "COM",     # flake8-commas
     "C4",      # flake8-comprehensions
     "EXE",     # flake8-executable
+    "FA",      # flake8-future-annotations
+    "LOG",     # flake8-logging
     "ISC",     # flake8-implicit-str-concat
+    "INP",     # flake8-no-pep420
     "PIE",     # flake8-pie
     "PYI",     # flake8-pyi
     "Q",       # flake8-quotes
@@ -251,6 +255,10 @@ select = [
     "RET",     # flake8-return
     "SIM",     # flake8-simplify
     "TID",     # flake8-tidy-imports
+    "TCH",     # flake8-type-checking
+    "PERF",    # perflint
+    "FURB",    # refurb
+    "TRY",     # tryceratops
     "RUF",     # ruff
 ]
 ignore = [
@@ -268,9 +276,9 @@ ignore = [
     # S101: use of `assert` detected
     # internal use and may never raise at runtime
     "S101",
-    # PLR0402: use from {module} import {name} in lieu of alias
-    # use alias for import convention (e.g., `import torch.nn as nn`)
-    "PLR0402",
+    # TRY003: avoid specifying long messages outside the exception class
+    # long messages are necessary for clarity
+    "TRY003",
 ]
 typing-modules = ["torchopt.typing"]
 
@@ -296,6 +304,9 @@ typing-modules = ["torchopt.typing"]
     "F401",  # unused-import
     "F811",  # redefined-while-unused
 ]
+"docs/source/conf.py" = [
+    "INP001",  # flake8-no-pep420
+]
 
 [tool.ruff.lint.flake8-annotations]
 allow-star-arg-any = true
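
Note (editorial, not part of the patch): the newly selected `TCH` (flake8-type-checking) rules above drive the recurring change throughout the Python files below, where imports needed only for type annotations move under an `if TYPE_CHECKING:` guard so they are never executed at import time. A minimal sketch of the pattern, assuming a hypothetical `tree_size` helper that is not in the patch:

    from __future__ import annotations  # PEP 563: annotations are not evaluated at runtime

    from typing import TYPE_CHECKING

    from torchopt import pytree  # used at runtime, so it stays a regular import

    if TYPE_CHECKING:
        # Seen only by static type checkers (mypy/pyright); skipped when the module loads.
        from torchopt.typing import TensorTree


    def tree_size(tree: TensorTree) -> int:
        """Count the leaves of a tensor pytree (illustrative only)."""
        return len(pytree.tree_leaves(tree))
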
diff --git a/tests/helpers.py b/tests/helpers.py
index 0dc415d4..ca5aa443 100644
--- a/tests/helpers.py
+++ b/tests/helpers.py
@@ -20,7 +20,7 @@
 import itertools
 import os
 import random
-from typing import Iterable
+from typing import TYPE_CHECKING, Iterable
 
 import numpy as np
 import pytest
@@ -30,7 +30,10 @@
 from torch.utils import data
 
 from torchopt import pytree
-from torchopt.typing import TensorTree
+
+
+if TYPE_CHECKING:
+    from torchopt.typing import TensorTree
 
 
 BATCH_SIZE = 64
diff --git a/tests/test_alias.py b/tests/test_alias.py
index 58b5a328..3c42d7c8 100644
--- a/tests/test_alias.py
+++ b/tests/test_alias.py
@@ -15,7 +15,7 @@
 
 from __future__ import annotations
 
-from typing import Callable
+from typing import TYPE_CHECKING, Callable
 
 import functorch
 import pytest
@@ -26,7 +26,10 @@
 import torchopt
 from torchopt import pytree
 from torchopt.alias.utils import _set_use_chain_flat
-from torchopt.typing import TensorTree
+
+
+if TYPE_CHECKING:
+    from torchopt.typing import TensorTree
 
 
 @helpers.parametrize(
diff --git a/tests/test_implicit.py b/tests/test_implicit.py
index ff0ba15c..6cccb716 100644
--- a/tests/test_implicit.py
+++ b/tests/test_implicit.py
@@ -18,7 +18,7 @@
 import copy
 import re
 from collections import OrderedDict
-from types import FunctionType
+from typing import TYPE_CHECKING
 
 import functorch
 import numpy as np
@@ -47,6 +47,10 @@
     HAS_JAX = False
 
 
+if TYPE_CHECKING:
+    from types import FunctionType
+
+
 BATCH_SIZE = 8
 NUM_UPDATES = 3
 
@@ -123,7 +127,7 @@ def get_rr_dataset_torch() -> data.DataLoader:
     inner_lr=[2e-2, 2e-3],
     inner_update=[20, 50, 100],
 )
-def test_imaml_solve_normal_cg(
+def test_imaml_solve_normal_cg(  # noqa: C901
     dtype: torch.dtype,
     lr: float,
     inner_lr: float,
@@ -251,7 +255,7 @@ def outer_level(p, xs, ys):
     inner_update=[20, 50, 100],
     ns=[False, True],
 )
-def test_imaml_solve_inv(
+def test_imaml_solve_inv(  # noqa: C901
     dtype: torch.dtype,
     lr: float,
     inner_lr: float,
@@ -375,7 +379,12 @@ def outer_level(p, xs, ys):
     inner_lr=[2e-2, 2e-3],
     inner_update=[20, 50, 100],
 )
-def test_imaml_module(dtype: torch.dtype, lr: float, inner_lr: float, inner_update: int) -> None:
+def test_imaml_module(  # noqa: C901
+    dtype: torch.dtype,
+    lr: float,
+    inner_lr: float,
+    inner_update: int,
+) -> None:
     np_dtype = helpers.dtype_torch2numpy(dtype)
 
     jax_model, jax_params = get_model_jax(dtype=np_dtype)
@@ -763,7 +772,7 @@ def solve(self):
         make_optimality_from_objective(MyModule2)
 
 
-def test_module_abstract_methods() -> None:
+def test_module_abstract_methods() -> None:  # noqa: C901
     class MyModule1(torchopt.nn.ImplicitMetaGradientModule):
         def objective(self):
             return torch.tensor(0.0)
@@ -809,7 +818,7 @@ def solve(self):
 
         class MyModule5(torchopt.nn.ImplicitMetaGradientModule):
             @classmethod
-            def optimality(self):
+            def optimality(cls):
                 return ()
 
             def solve(self):
@@ -846,7 +855,7 @@ def solve(self):
 
         class MyModule8(torchopt.nn.ImplicitMetaGradientModule):
             @classmethod
-            def objective(self):
+            def objective(cls):
                 return ()
 
             def solve(self):
diff --git a/tests/test_utils.py b/tests/test_utils.py
index 5215e7b3..57c35e47 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -13,6 +13,8 @@
 # limitations under the License.
 # ==============================================================================
 
+import operator
+
 import torch
 
 import torchopt
@@ -80,7 +82,7 @@ def test_module_clone() -> None:
         assert y.is_cuda
 
 
-def test_extract_state_dict():
+def test_extract_state_dict():  # noqa: C901
     fc = torch.nn.Linear(1, 1)
     state_dict = torchopt.extract_state_dict(fc, by='reference', device=torch.device('meta'))
     for param_dict in state_dict.params:
@@ -121,7 +123,7 @@ def test_extract_state_dict():
     loss = fc(torch.ones(1, 1)).sum()
     optim.step(loss)
     state_dict = torchopt.extract_state_dict(optim)
-    same = pytree.tree_map(lambda x, y: x is y, state_dict, tuple(optim.state_groups))
+    same = pytree.tree_map(operator.is_, state_dict, tuple(optim.state_groups))
     assert all(pytree.tree_flatten(same)[0])
 
 
diff --git a/torchopt/__init__.py b/torchopt/__init__.py
index 5e568526..830072e3 100644
--- a/torchopt/__init__.py
+++ b/torchopt/__init__.py
@@ -81,50 +81,50 @@
 
 
 __all__ = [
-    'accelerated_op_available',
-    'adam',
-    'adamax',
-    'adadelta',
-    'radam',
-    'adamw',
-    'adagrad',
-    'rmsprop',
-    'sgd',
-    'clip_grad_norm',
-    'nan_to_num',
-    'register_hook',
-    'chain',
-    'Optimizer',
     'SGD',
-    'Adam',
-    'AdaMax',
-    'Adamax',
     'AdaDelta',
-    'Adadelta',
-    'RAdam',
-    'AdamW',
     'AdaGrad',
+    'AdaMax',
+    'Adadelta',
     'Adagrad',
-    'RMSProp',
-    'RMSprop',
-    'MetaOptimizer',
-    'MetaSGD',
-    'MetaAdam',
-    'MetaAdaMax',
-    'MetaAdamax',
+    'Adam',
+    'AdamW',
+    'Adamax',
+    'FuncOptimizer',
     'MetaAdaDelta',
-    'MetaAdadelta',
-    'MetaRAdam',
-    'MetaAdamW',
     'MetaAdaGrad',
+    'MetaAdaMax',
+    'MetaAdadelta',
     'MetaAdagrad',
+    'MetaAdam',
+    'MetaAdamW',
+    'MetaAdamax',
+    'MetaOptimizer',
+    'MetaRAdam',
     'MetaRMSProp',
     'MetaRMSprop',
-    'FuncOptimizer',
+    'MetaSGD',
+    'Optimizer',
+    'RAdam',
+    'RMSProp',
+    'RMSprop',
+    'accelerated_op_available',
+    'adadelta',
+    'adagrad',
+    'adam',
+    'adamax',
+    'adamw',
     'apply_updates',
+    'chain',
+    'clip_grad_norm',
     'extract_state_dict',
-    'recover_state_dict',
-    'stop_gradient',
     'module_clone',
     'module_detach_',
+    'nan_to_num',
+    'radam',
+    'recover_state_dict',
+    'register_hook',
+    'rmsprop',
+    'sgd',
+    'stop_gradient',
 ]
diff --git a/torchopt/accelerated_op/__init__.py b/torchopt/accelerated_op/__init__.py
index 103b6fc0..90452046 100644
--- a/torchopt/accelerated_op/__init__.py
+++ b/torchopt/accelerated_op/__init__.py
@@ -16,12 +16,15 @@
 
 from __future__ import annotations
 
-from typing import Iterable
+from typing import TYPE_CHECKING, Iterable
 
 import torch
 
 from torchopt.accelerated_op.adam_op import AdamOp
-from torchopt.typing import Device
+
+
+if TYPE_CHECKING:
+    from torchopt.typing import Device
 
 
 def is_available(devices: Device | Iterable[Device] | None = None) -> bool:
@@ -42,6 +45,6 @@ def is_available(devices: Device | Iterable[Device] | None = None) -> bool:
                 return False
             updates = torch.tensor(1.0, device=device)
             op(updates, updates, updates, 1)
-        return True
     except Exception:  # noqa: BLE001 # pylint: disable=broad-except
         return False
+    return True
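
Note (editorial, not part of the patch): moving `return True` below the `try`/`except` block matches the tryceratops `TRY300` rule selected in `pyproject.toml` above: the success-path `return` should not sit inside the `try` body. A minimal sketch of the shape, using a hypothetical `feature_available()` probe:

    def feature_available() -> bool:
        """Return True when an optional capability works, False otherwise (illustrative)."""
        try:
            import math  # stand-in for the real capability check

            math.sqrt(4.0)
        except Exception:  # broad on purpose, mirroring the original intent
            return False
        return True  # success path lives after the try/except, as TRY300 prefers
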
diff --git a/torchopt/accelerated_op/_src/adam_op.py b/torchopt/accelerated_op/_src/adam_op.py
index bc999766..d7f9796d 100644
--- a/torchopt/accelerated_op/_src/adam_op.py
+++ b/torchopt/accelerated_op/_src/adam_op.py
@@ -18,7 +18,11 @@
 
 from __future__ import annotations
 
-import torch
+from typing import TYPE_CHECKING
+
+
+if TYPE_CHECKING:
+    import torch
 
 
 def forward_(
diff --git a/torchopt/alias/__init__.py b/torchopt/alias/__init__.py
index 3cfb5b8b..5767c5d7 100644
--- a/torchopt/alias/__init__.py
+++ b/torchopt/alias/__init__.py
@@ -41,4 +41,13 @@
 from torchopt.alias.sgd import sgd
 
 
-__all__ = ['adagrad', 'radam', 'adam', 'adamax', 'adadelta', 'adamw', 'rmsprop', 'sgd']
+__all__ = [
+    'adadelta',
+    'adagrad',
+    'adam',
+    'adamax',
+    'adamw',
+    'radam',
+    'rmsprop',
+    'sgd',
+]
diff --git a/torchopt/alias/adadelta.py b/torchopt/alias/adadelta.py
index fb0b551a..910cb13e 100644
--- a/torchopt/alias/adadelta.py
+++ b/torchopt/alias/adadelta.py
@@ -16,6 +16,8 @@
 
 from __future__ import annotations
 
+from typing import TYPE_CHECKING
+
 from torchopt.alias.utils import (
     _get_use_chain_flat,
     flip_sign_and_add_weight_decay,
@@ -23,7 +25,10 @@
 )
 from torchopt.combine import chain
 from torchopt.transform import scale_by_adadelta
-from torchopt.typing import GradientTransformation, ScalarOrSchedule
+
+
+if TYPE_CHECKING:
+    from torchopt.typing import GradientTransformation, ScalarOrSchedule
 
 
 __all__ = ['adadelta']
diff --git a/torchopt/alias/adam.py b/torchopt/alias/adam.py
index 9419e908..0ae0eb8e 100644
--- a/torchopt/alias/adam.py
+++ b/torchopt/alias/adam.py
@@ -33,6 +33,8 @@
 
 from __future__ import annotations
 
+from typing import TYPE_CHECKING
+
 from torchopt.alias.utils import (
     _get_use_chain_flat,
     flip_sign_and_add_weight_decay,
@@ -40,7 +42,10 @@
 )
 from torchopt.combine import chain
 from torchopt.transform import scale_by_accelerated_adam, scale_by_adam
-from torchopt.typing import GradientTransformation, ScalarOrSchedule
+
+
+if TYPE_CHECKING:
+    from torchopt.typing import GradientTransformation, ScalarOrSchedule
 
 
 __all__ = ['adam']
diff --git a/torchopt/alias/adamax.py b/torchopt/alias/adamax.py
index f80c0c2f..3da16713 100644
--- a/torchopt/alias/adamax.py
+++ b/torchopt/alias/adamax.py
@@ -16,6 +16,8 @@
 
 from __future__ import annotations
 
+from typing import TYPE_CHECKING
+
 from torchopt.alias.utils import (
     _get_use_chain_flat,
     flip_sign_and_add_weight_decay,
@@ -23,7 +25,10 @@
 )
 from torchopt.combine import chain
 from torchopt.transform import scale_by_adamax
-from torchopt.typing import GradientTransformation, ScalarOrSchedule
+
+
+if TYPE_CHECKING:
+    from torchopt.typing import GradientTransformation, ScalarOrSchedule
 
 
 __all__ = ['adamax']
diff --git a/torchopt/alias/adamw.py b/torchopt/alias/adamw.py
index 38d4d5ac..2dc72ef1 100644
--- a/torchopt/alias/adamw.py
+++ b/torchopt/alias/adamw.py
@@ -33,7 +33,7 @@
 
 from __future__ import annotations
 
-from typing import Callable
+from typing import TYPE_CHECKING, Callable
 
 from torchopt.alias.utils import (
     _get_use_chain_flat,
@@ -42,7 +42,10 @@
 )
 from torchopt.combine import chain
 from torchopt.transform import add_decayed_weights, scale_by_accelerated_adam, scale_by_adam
-from torchopt.typing import GradientTransformation, OptState, Params, ScalarOrSchedule
+
+
+if TYPE_CHECKING:
+    from torchopt.typing import GradientTransformation, OptState, Params, ScalarOrSchedule
 
 
 __all__ = ['adamw']
diff --git a/torchopt/alias/radam.py b/torchopt/alias/radam.py
index 56d3d3d5..9e2880ee 100644
--- a/torchopt/alias/radam.py
+++ b/torchopt/alias/radam.py
@@ -16,6 +16,8 @@
 
 from __future__ import annotations
 
+from typing import TYPE_CHECKING
+
 from torchopt.alias.utils import (
     _get_use_chain_flat,
     flip_sign_and_add_weight_decay,
@@ -23,7 +25,10 @@
 )
 from torchopt.combine import chain
 from torchopt.transform import scale_by_radam
-from torchopt.typing import GradientTransformation, ScalarOrSchedule
+
+
+if TYPE_CHECKING:
+    from torchopt.typing import GradientTransformation, ScalarOrSchedule
 
 
 __all__ = ['radam']
diff --git a/torchopt/alias/utils.py b/torchopt/alias/utils.py
index 49f8784d..0f41e822 100644
--- a/torchopt/alias/utils.py
+++ b/torchopt/alias/utils.py
@@ -16,14 +16,18 @@
 from __future__ import annotations
 
 import threading
-
-import torch
+from typing import TYPE_CHECKING
 
 from torchopt import pytree
 from torchopt.base import EmptyState, GradientTransformation, identity
 from torchopt.transform import scale, scale_by_schedule
 from torchopt.transform.utils import tree_map_flat, tree_map_flat_
-from torchopt.typing import Numeric, OptState, Params, ScalarOrSchedule, Updates
+
+
+if TYPE_CHECKING:
+    import torch
+
+    from torchopt.typing import Numeric, OptState, Params, ScalarOrSchedule, Updates
 
 
 __all__ = ['flip_sign_and_add_weight_decay', 'scale_by_neg_lr']
@@ -68,7 +72,7 @@ def _flip_sign_and_add_weight_decay_flat(
     )
 
 
-def _flip_sign_and_add_weight_decay(
+def _flip_sign_and_add_weight_decay(  # noqa: C901
     weight_decay: float = 0.0,
     maximize: bool = False,
     *,
diff --git a/torchopt/base.py b/torchopt/base.py
index 572708e2..81892e17 100644
--- a/torchopt/base.py
+++ b/torchopt/base.py
@@ -44,10 +44,10 @@
 
 
 __all__ = [
+    'ChainedGradientTransformation',
     'EmptyState',
-    'UninitializedState',
     'GradientTransformation',
-    'ChainedGradientTransformation',
+    'UninitializedState',
     'identity',
 ]
 
diff --git a/torchopt/clip.py b/torchopt/clip.py
index 55ae83fc..d64afc58 100644
--- a/torchopt/clip.py
+++ b/torchopt/clip.py
@@ -19,11 +19,16 @@
 
 from __future__ import annotations
 
+from typing import TYPE_CHECKING
+
 import torch
 
 from torchopt import pytree
 from torchopt.base import EmptyState, GradientTransformation
-from torchopt.typing import OptState, Params, Updates
+
+
+if TYPE_CHECKING:
+    from torchopt.typing import OptState, Params, Updates
 
 
 __all__ = ['clip_grad_norm']
diff --git a/torchopt/combine.py b/torchopt/combine.py
index 158ec982..15345286 100644
--- a/torchopt/combine.py
+++ b/torchopt/combine.py
@@ -33,9 +33,14 @@
 
 from __future__ import annotations
 
+from typing import TYPE_CHECKING
+
 from torchopt import pytree
 from torchopt.base import ChainedGradientTransformation, GradientTransformation, identity
-from torchopt.typing import OptState, Params, Updates
+
+
+if TYPE_CHECKING:
+    from torchopt.typing import OptState, Params, Updates
 
 
 __all__ = ['chain', 'chain_flat']
diff --git a/torchopt/diff/implicit/__init__.py b/torchopt/diff/implicit/__init__.py
index 21737015..4cff14c6 100644
--- a/torchopt/diff/implicit/__init__.py
+++ b/torchopt/diff/implicit/__init__.py
@@ -19,4 +19,4 @@
 from torchopt.diff.implicit.nn import ImplicitMetaGradientModule
 
 
-__all__ = ['custom_root', 'ImplicitMetaGradientModule']
+__all__ = ['ImplicitMetaGradientModule', 'custom_root']
diff --git a/torchopt/diff/implicit/decorator.py b/torchopt/diff/implicit/decorator.py
index d3efda2c..11ba0153 100644
--- a/torchopt/diff/implicit/decorator.py
+++ b/torchopt/diff/implicit/decorator.py
@@ -37,20 +37,23 @@
 
 import functools
 import inspect
-from typing import Any, Callable, Dict, Sequence, Tuple
+from typing import TYPE_CHECKING, Any, Callable, Dict, Sequence, Tuple
 
 import functorch
 import torch
 from torch.autograd import Function
 
 from torchopt import linear_solve, pytree
-from torchopt.typing import (
-    ListOfOptionalTensors,
-    ListOfTensors,
-    TensorOrTensors,
-    TupleOfOptionalTensors,
-    TupleOfTensors,
-)
+
+
+if TYPE_CHECKING:
+    from torchopt.typing import (
+        ListOfOptionalTensors,
+        ListOfTensors,
+        TensorOrTensors,
+        TupleOfOptionalTensors,
+        TupleOfTensors,
+    )
 
 
 __all__ = ['custom_root']
@@ -253,7 +256,7 @@ def _merge_tensor_and_others(
 
 
 # pylint: disable-next=too-many-arguments,too-many-statements
-def _custom_root(
+def _custom_root(  # noqa: C901
     solver_fn: Callable[..., TensorOrTensors | tuple[TensorOrTensors, Any]],
     optimality_fn: Callable[..., TensorOrTensors],
     solve: Callable[..., TensorOrTensors],
@@ -271,7 +274,7 @@ def _custom_root(
         fn = getattr(reference_signature, 'subfn', reference_signature)
         reference_signature = inspect.signature(fn)
 
-    def make_custom_vjp_solver_fn(
+    def make_custom_vjp_solver_fn(  # noqa: C901
         solver_fn: Callable[..., TensorOrTensors | tuple[TensorOrTensors, Any]],
         kwarg_keys: Sequence[str],
         args_signs: tuple[tuple[int, int, type[tuple | list] | None], ...],
diff --git a/torchopt/diff/implicit/nn/module.py b/torchopt/diff/implicit/nn/module.py
index 8719f675..6b214cb8 100644
--- a/torchopt/diff/implicit/nn/module.py
+++ b/torchopt/diff/implicit/nn/module.py
@@ -22,15 +22,19 @@
 import functools
 import inspect
 import itertools
-from typing import Any, Iterable
+from typing import TYPE_CHECKING, Any, Iterable
 
 import functorch
-import torch
 
 from torchopt.diff.implicit.decorator import custom_root
 from torchopt.nn.module import MetaGradientModule
 from torchopt.nn.stateless import reparametrize, swap_state
-from torchopt.typing import LinearSolver, TupleOfTensors
+
+
+if TYPE_CHECKING:
+    import torch
+
+    from torchopt.typing import LinearSolver, TupleOfTensors
 
 
 __all__ = ['ImplicitMetaGradientModule']
diff --git a/torchopt/diff/zero_order/__init__.py b/torchopt/diff/zero_order/__init__.py
index f00e097a..4369f4e5 100644
--- a/torchopt/diff/zero_order/__init__.py
+++ b/torchopt/diff/zero_order/__init__.py
@@ -25,7 +25,7 @@
 from torchopt.diff.zero_order.nn import ZeroOrderGradientModule
 
 
-__all__ = ['zero_order', 'ZeroOrderGradientModule']
+__all__ = ['ZeroOrderGradientModule', 'zero_order']
 
 
 class _CallableModule(_ModuleType):  # pylint: disable=too-few-public-methods
diff --git a/torchopt/diff/zero_order/decorator.py b/torchopt/diff/zero_order/decorator.py
index b1126636..e498b43c 100644
--- a/torchopt/diff/zero_order/decorator.py
+++ b/torchopt/diff/zero_order/decorator.py
@@ -17,6 +17,7 @@
 from __future__ import annotations
 
 import functools
+import itertools
 from typing import Any, Callable, Literal, Sequence
 from typing_extensions import TypeAlias  # Python 3.10+
 
@@ -43,7 +44,7 @@ def sample(
         return self.sample_fn(sample_shape)
 
 
-def _zero_order_naive(  # pylint: disable=too-many-statements
+def _zero_order_naive(  # noqa: C901 # pylint: disable=too-many-statements
     fn: Callable[..., torch.Tensor],
     distribution: Samplable,
     argnums: tuple[int, ...],
@@ -51,7 +52,7 @@ def _zero_order_naive(  # pylint: disable=too-many-statements
     sigma: float,
 ) -> Callable[..., torch.Tensor]:
     @functools.wraps(fn)
-    def apply(*args: Any) -> torch.Tensor:  # pylint: disable=too-many-statements
+    def apply(*args: Any) -> torch.Tensor:  # noqa: C901 # pylint: disable=too-many-statements
         diff_params = [args[argnum] for argnum in argnums]
         flat_diff_params: list[Any]
         flat_diff_params, diff_params_treespec = pytree.tree_flatten(diff_params)  # type: ignore[arg-type]
@@ -81,7 +82,7 @@ def forward(ctx: Any, *args: Any, **kwargs: Any) -> torch.Tensor:
 
                 output = fn(*origin_args)
                 if not isinstance(output, torch.Tensor):
-                    raise RuntimeError('`output` must be a tensor.')
+                    raise TypeError('`output` must be a tensor.')
                 if output.ndim != 0:
                     raise RuntimeError('`output` must be a scalar tensor.')
                 ctx.save_for_backward(*flat_diff_params, *tensors)
@@ -122,9 +123,9 @@ def add_perturbation(
 
                 for _ in range(num_samples):
                     noises = [distribution.sample(sample_shape=p.shape) for p in flat_diff_params]
-                    flat_noisy_params = [
-                        add_perturbation(t, n) for t, n in zip(flat_diff_params, noises)  # type: ignore[arg-type]
-                    ]
+                    flat_noisy_params = list(
+                        itertools.starmap(add_perturbation, zip(flat_diff_params, noises)),
+                    )
                     noisy_params: list[Any] = pytree.tree_unflatten(  # type: ignore[assignment]
                         diff_params_treespec,
                         flat_noisy_params,
@@ -149,7 +150,7 @@ def add_perturbation(
     return apply
 
 
-def _zero_order_forward(  # pylint: disable=too-many-statements
+def _zero_order_forward(  # noqa: C901 # pylint: disable=too-many-statements
     fn: Callable[..., torch.Tensor],
     distribution: Samplable,
     argnums: tuple[int, ...],
@@ -157,7 +158,7 @@ def _zero_order_forward(  # pylint: disable=too-many-statements
     sigma: float,
 ) -> Callable[..., torch.Tensor]:
     @functools.wraps(fn)
-    def apply(*args: Any) -> torch.Tensor:  # pylint: disable=too-many-statements
+    def apply(*args: Any) -> torch.Tensor:  # noqa: C901 # pylint: disable=too-many-statements
         diff_params = [args[argnum] for argnum in argnums]
         flat_diff_params: list[Any]
         flat_diff_params, diff_params_treespec = pytree.tree_flatten(diff_params)  # type: ignore[arg-type]
@@ -187,7 +188,7 @@ def forward(ctx: Any, *args: Any, **kwargs: Any) -> torch.Tensor:
 
                 output = fn(*origin_args)
                 if not isinstance(output, torch.Tensor):
-                    raise RuntimeError('`output` must be a tensor.')
+                    raise TypeError('`output` must be a tensor.')
                 if output.ndim != 0:
                     raise RuntimeError('`output` must be a scalar tensor.')
                 ctx.save_for_backward(*flat_diff_params, *tensors, output)
@@ -226,9 +227,9 @@ def add_perturbation(tensor: torch.Tensor, noise: torch.Tensor) -> torch.Tensor:
 
                 for _ in range(num_samples):
                     noises = [distribution.sample(sample_shape=p.shape) for p in flat_diff_params]
-                    flat_noisy_params = [
-                        add_perturbation(t, n) for t, n in zip(flat_diff_params, noises)  # type: ignore[arg-type]
-                    ]
+                    flat_noisy_params = list(
+                        itertools.starmap(add_perturbation, zip(flat_diff_params, noises)),
+                    )
                     noisy_params: list[Any] = pytree.tree_unflatten(  # type: ignore[assignment]
                         diff_params_treespec,
                         flat_noisy_params,
@@ -254,7 +255,7 @@ def add_perturbation(tensor: torch.Tensor, noise: torch.Tensor) -> torch.Tensor:
     return apply
 
 
-def _zero_order_antithetic(  # pylint: disable=too-many-statements
+def _zero_order_antithetic(  # noqa: C901 # pylint: disable=too-many-statements
     fn: Callable[..., torch.Tensor],
     distribution: Samplable,
     argnums: tuple[int, ...],
@@ -262,7 +263,7 @@ def _zero_order_antithetic(  # pylint: disable=too-many-statements
     sigma: float,
 ) -> Callable[..., torch.Tensor]:
     @functools.wraps(fn)
-    def apply(*args: Any) -> torch.Tensor:  # pylint: disable=too-many-statements
+    def apply(*args: Any) -> torch.Tensor:  # noqa: C901 # pylint: disable=too-many-statements
         diff_params = [args[argnum] for argnum in argnums]
         flat_diff_params: list[Any]
         flat_diff_params, diff_params_treespec = pytree.tree_flatten(diff_params)  # type: ignore[arg-type]
@@ -292,7 +293,7 @@ def forward(ctx: Any, *args: Any, **kwargs: Any) -> torch.Tensor:
 
                 output = fn(*origin_args)
                 if not isinstance(output, torch.Tensor):
-                    raise RuntimeError('`output` must be a tensor.')
+                    raise TypeError('`output` must be a tensor.')
                 if output.ndim != 0:
                     raise RuntimeError('`output` must be a scalar tensor.')
                 ctx.save_for_backward(*flat_diff_params, *tensors)
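
Note (editorial, not part of the patch): two recurring edits in this file are worth calling out: the per-pair list comprehension over `zip(...)` is rewritten with `itertools.starmap` (a refurb-style simplification, presumably FURB140; the exact rule code is an assumption), and a wrong argument type now raises `TypeError` rather than `RuntimeError`, the conventional exception for type-check failures. A small sketch of the `starmap` equivalence:

    import itertools
    import operator

    params = [1.0, 2.0, 3.0]
    noises = [0.1, 0.2, 0.3]

    # list(itertools.starmap(f, zip(a, b))) == [f(x, y) for x, y in zip(a, b)]
    perturbed = list(itertools.starmap(operator.add, zip(params, noises)))
    assert perturbed == [x + n for x, n in zip(params, noises)]
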
diff --git a/torchopt/diff/zero_order/nn/module.py b/torchopt/diff/zero_order/nn/module.py
index 7ac12bb4..eeddabeb 100644
--- a/torchopt/diff/zero_order/nn/module.py
+++ b/torchopt/diff/zero_order/nn/module.py
@@ -20,14 +20,17 @@
 
 import abc
 import functools
-from typing import Any, Sequence
+from typing import TYPE_CHECKING, Any, Sequence
 
 import torch
 import torch.nn as nn
 
 from torchopt.diff.zero_order.decorator import Method, Samplable, zero_order
 from torchopt.nn.stateless import reparametrize
-from torchopt.typing import Numeric, TupleOfTensors
+
+
+if TYPE_CHECKING:
+    from torchopt.typing import Numeric, TupleOfTensors
 
 
 __all__ = ['ZeroOrderGradientModule']
diff --git a/torchopt/distributed/api.py b/torchopt/distributed/api.py
index 117af9ab..97be682f 100644
--- a/torchopt/distributed/api.py
+++ b/torchopt/distributed/api.py
@@ -42,15 +42,15 @@
 
 __all__ = [
     'TensorDimensionPartitioner',
-    'dim_partitioner',
     'batch_partitioner',
+    'dim_partitioner',
     'mean_reducer',
-    'sum_reducer',
-    'remote_async_call',
-    'remote_sync_call',
     'parallelize',
     'parallelize_async',
     'parallelize_sync',
+    'remote_async_call',
+    'remote_sync_call',
+    'sum_reducer',
 ]
 
 
@@ -107,7 +107,7 @@ def __init__(
         self.workers = workers
 
     # pylint: disable-next=too-many-branches,too-many-locals
-    def __call__(
+    def __call__(  # noqa: C901
         self,
         *args: Any,
         **kwargs: Any,
@@ -310,7 +310,7 @@ def remote_async_call(
     elif callable(partitioner):
         partitions = partitioner(*args, **kwargs)  # type: ignore[assignment]
     else:
-        raise ValueError(f'Invalid partitioner: {partitioner!r}.')
+        raise TypeError(f'Invalid partitioner: {partitioner!r}.')
 
     futures = []
     for rank, worker_args, worker_kwargs in partitions:
diff --git a/torchopt/distributed/autograd.py b/torchopt/distributed/autograd.py
index f7da4f46..71afdb86 100644
--- a/torchopt/distributed/autograd.py
+++ b/torchopt/distributed/autograd.py
@@ -17,15 +17,18 @@
 from __future__ import annotations
 
 from threading import Lock
+from typing import TYPE_CHECKING
 
 import torch
 import torch.distributed.autograd as autograd
 from torch.distributed.autograd import context
 
-from torchopt.typing import TensorOrTensors, TupleOfOptionalTensors
 
+if TYPE_CHECKING:
+    from torchopt.typing import TensorOrTensors, TupleOfOptionalTensors
 
-__all__ = ['is_available', 'context']
+
+__all__ = ['context', 'is_available']
 
 
 LOCK = Lock()
@@ -121,7 +124,7 @@ def grad(
         for p in inputs:
             try:
                 grads.append(all_local_grads[p])
-            except KeyError as ex:
+            except KeyError as ex:  # noqa: PERF203
                 if not allow_unused:
                     raise RuntimeError(
                         'One of the differentiated Tensors appears to not have been used in the '
@@ -131,4 +134,4 @@ def grad(
 
         return tuple(grads)
 
-    __all__ += ['DistAutogradContext', 'get_gradients', 'backward', 'grad']
+    __all__ += ['DistAutogradContext', 'backward', 'get_gradients', 'grad']
diff --git a/torchopt/distributed/world.py b/torchopt/distributed/world.py
index a61280c5..610e52a0 100644
--- a/torchopt/distributed/world.py
+++ b/torchopt/distributed/world.py
@@ -26,19 +26,19 @@
 
 
 __all__ = [
-    'get_world_info',
-    'get_world_rank',
-    'get_rank',
-    'get_world_size',
+    'auto_init_rpc',
+    'barrier',
     'get_local_rank',
     'get_local_world_size',
+    'get_rank',
     'get_worker_id',
-    'barrier',
-    'auto_init_rpc',
-    'on_rank',
+    'get_world_info',
+    'get_world_rank',
+    'get_world_size',
     'not_on_rank',
-    'rank_zero_only',
+    'on_rank',
     'rank_non_zero_only',
+    'rank_zero_only',
 ]
 
 
diff --git a/torchopt/hook.py b/torchopt/hook.py
index b51e29eb..c11b92f6 100644
--- a/torchopt/hook.py
+++ b/torchopt/hook.py
@@ -16,16 +16,19 @@
 
 from __future__ import annotations
 
-from typing import Callable
-
-import torch
+from typing import TYPE_CHECKING, Callable
 
 from torchopt import pytree
 from torchopt.base import EmptyState, GradientTransformation
-from torchopt.typing import OptState, Params, Updates
 
 
-__all__ = ['zero_nan_hook', 'nan_to_num_hook', 'register_hook']
+if TYPE_CHECKING:
+    import torch
+
+    from torchopt.typing import OptState, Params, Updates
+
+
+__all__ = ['nan_to_num_hook', 'register_hook', 'zero_nan_hook']
 
 
 def zero_nan_hook(g: torch.Tensor) -> torch.Tensor:
diff --git a/torchopt/linalg/cg.py b/torchopt/linalg/cg.py
index a82ff877..1096a5af 100644
--- a/torchopt/linalg/cg.py
+++ b/torchopt/linalg/cg.py
@@ -36,14 +36,17 @@
 from __future__ import annotations
 
 from functools import partial
-from typing import Callable
+from typing import TYPE_CHECKING, Callable
 
 import torch
 
 from torchopt import pytree
 from torchopt.linalg.utils import cat_shapes, normalize_matvec
 from torchopt.pytree import tree_vdot_real
-from torchopt.typing import TensorTree
+
+
+if TYPE_CHECKING:
+    from torchopt.typing import TensorTree
 
 
 __all__ = ['cg']
diff --git a/torchopt/linalg/ns.py b/torchopt/linalg/ns.py
index b049a5ad..5fc8d478 100644
--- a/torchopt/linalg/ns.py
+++ b/torchopt/linalg/ns.py
@@ -19,13 +19,16 @@
 from __future__ import annotations
 
 import functools
-from typing import Callable
+from typing import TYPE_CHECKING, Callable
 
 import torch
 
 from torchopt import pytree
 from torchopt.linalg.utils import normalize_matvec
-from torchopt.typing import TensorTree
+
+
+if TYPE_CHECKING:
+    from torchopt.typing import TensorTree
 
 
 __all__ = ['ns', 'ns_inv']
diff --git a/torchopt/linalg/utils.py b/torchopt/linalg/utils.py
index a5ac765d..bbcc80aa 100644
--- a/torchopt/linalg/utils.py
+++ b/torchopt/linalg/utils.py
@@ -17,12 +17,15 @@
 from __future__ import annotations
 
 import itertools
-from typing import Callable
+from typing import TYPE_CHECKING, Callable
 
 import torch
 
 from torchopt import pytree
-from torchopt.typing import TensorTree
+
+
+if TYPE_CHECKING:
+    from torchopt.typing import TensorTree
 
 
 def cat_shapes(tree: TensorTree) -> tuple[int, ...]:
diff --git a/torchopt/linear_solve/__init__.py b/torchopt/linear_solve/__init__.py
index 2d61eb6d..43ca1da0 100644
--- a/torchopt/linear_solve/__init__.py
+++ b/torchopt/linear_solve/__init__.py
@@ -36,4 +36,4 @@
 from torchopt.linear_solve.normal_cg import solve_normal_cg
 
 
-__all__ = ['solve_cg', 'solve_normal_cg', 'solve_inv']
+__all__ = ['solve_cg', 'solve_inv', 'solve_normal_cg']
diff --git a/torchopt/linear_solve/cg.py b/torchopt/linear_solve/cg.py
index f4127639..23814cc2 100644
--- a/torchopt/linear_solve/cg.py
+++ b/torchopt/linear_solve/cg.py
@@ -36,11 +36,14 @@
 from __future__ import annotations
 
 import functools
-from typing import Any, Callable
+from typing import TYPE_CHECKING, Any, Callable
 
 from torchopt import linalg
 from torchopt.linear_solve.utils import make_ridge_matvec
-from torchopt.typing import LinearSolver, TensorTree
+
+
+if TYPE_CHECKING:
+    from torchopt.typing import LinearSolver, TensorTree
 
 
 __all__ = ['solve_cg']
diff --git a/torchopt/linear_solve/inv.py b/torchopt/linear_solve/inv.py
index f37be8c5..4dbe1542 100644
--- a/torchopt/linear_solve/inv.py
+++ b/torchopt/linear_solve/inv.py
@@ -36,13 +36,16 @@
 from __future__ import annotations
 
 import functools
-from typing import Any, Callable
+from typing import TYPE_CHECKING, Any, Callable
 
 import torch
 
 from torchopt import linalg, pytree
 from torchopt.linear_solve.utils import make_ridge_matvec, materialize_matvec
-from torchopt.typing import LinearSolver, TensorTree
+
+
+if TYPE_CHECKING:
+    from torchopt.typing import LinearSolver, TensorTree
 
 
 __all__ = ['solve_inv']
diff --git a/torchopt/linear_solve/normal_cg.py b/torchopt/linear_solve/normal_cg.py
index 405ab43c..a5af49b2 100644
--- a/torchopt/linear_solve/normal_cg.py
+++ b/torchopt/linear_solve/normal_cg.py
@@ -36,11 +36,14 @@
 from __future__ import annotations
 
 import functools
-from typing import Any, Callable
+from typing import TYPE_CHECKING, Any, Callable
 
 from torchopt import linalg
 from torchopt.linear_solve.utils import make_normal_matvec, make_ridge_matvec, make_rmatvec
-from torchopt.typing import LinearSolver, TensorTree
+
+
+if TYPE_CHECKING:
+    from torchopt.typing import LinearSolver, TensorTree
 
 
 __all__ = ['solve_normal_cg']
diff --git a/torchopt/linear_solve/utils.py b/torchopt/linear_solve/utils.py
index 5e4bf7bd..9d1b8779 100644
--- a/torchopt/linear_solve/utils.py
+++ b/torchopt/linear_solve/utils.py
@@ -33,12 +33,15 @@
 
 from __future__ import annotations
 
-from typing import Callable
+from typing import TYPE_CHECKING, Callable
 
 import functorch
 
 from torchopt import pytree
-from torchopt.typing import TensorTree
+
+
+if TYPE_CHECKING:
+    from torchopt.typing import TensorTree
 
 
 def make_rmatvec(
diff --git a/torchopt/nn/__init__.py b/torchopt/nn/__init__.py
index 7665f201..b55e49d7 100644
--- a/torchopt/nn/__init__.py
+++ b/torchopt/nn/__init__.py
@@ -21,10 +21,10 @@
 
 
 __all__ = [
-    'MetaGradientModule',
     'ImplicitMetaGradientModule',
+    'MetaGradientModule',
     'ZeroOrderGradientModule',
-    'reparametrize',
     'reparameterize',
+    'reparametrize',
     'swap_state',
 ]
diff --git a/torchopt/nn/module.py b/torchopt/nn/module.py
index 419afb6a..8c40f58a 100644
--- a/torchopt/nn/module.py
+++ b/torchopt/nn/module.py
@@ -17,14 +17,17 @@
 from __future__ import annotations
 
 from collections import OrderedDict
-from typing import Any, Iterator, NamedTuple
+from typing import TYPE_CHECKING, Any, Iterator, NamedTuple
 from typing_extensions import Self  # Python 3.11+
 
 import torch
 import torch.nn as nn
 
 from torchopt import pytree
-from torchopt.typing import TensorContainer
+
+
+if TYPE_CHECKING:
+    from torchopt.typing import TensorContainer
 
 
 class MetaInputsContainer(NamedTuple):
@@ -61,7 +64,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:  # pylint: disable=unused
         """Initialize a new module instance."""
         super().__init__()
 
-    def __getattr__(self, name: str) -> torch.Tensor | nn.Module:
+    def __getattr__(self, name: str) -> torch.Tensor | nn.Module:  # noqa: C901
         """Get an attribute of the module."""
         if '_parameters' in self.__dict__:
             _parameters = self.__dict__['_parameters']
@@ -86,7 +89,7 @@ def __getattr__(self, name: str) -> torch.Tensor | nn.Module:
         raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
 
     # pylint: disable-next=too-many-branches,too-many-statements
-    def __setattr__(self, name: str, value: torch.Tensor | nn.Module) -> None:
+    def __setattr__(self, name: str, value: torch.Tensor | nn.Module) -> None:  # noqa: C901
         """Set an attribute of the module."""
 
         def remove_from(*dicts_or_sets: dict[str, Any] | set[str]) -> None:
diff --git a/torchopt/nn/stateless.py b/torchopt/nn/stateless.py
index d3437d0d..c7f92b86 100644
--- a/torchopt/nn/stateless.py
+++ b/torchopt/nn/stateless.py
@@ -17,13 +17,15 @@
 from __future__ import annotations
 
 import contextlib
-from typing import Generator, Iterable
+from typing import TYPE_CHECKING, Generator, Iterable
 
-import torch
-import torch.nn as nn
 
+if TYPE_CHECKING:
+    import torch
+    import torch.nn as nn
 
-__all__ = ['swap_state', 'reparametrize', 'reparameterize']
+
+__all__ = ['reparameterize', 'reparametrize', 'swap_state']
 
 
 MISSING: torch.Tensor = object()  # type: ignore[assignment]
diff --git a/torchopt/optim/adadelta.py b/torchopt/optim/adadelta.py
index a64e00e4..600b69c5 100644
--- a/torchopt/optim/adadelta.py
+++ b/torchopt/optim/adadelta.py
@@ -16,13 +16,16 @@
 
 from __future__ import annotations
 
-from typing import Iterable
-
-import torch
+from typing import TYPE_CHECKING, Iterable
 
 from torchopt import alias
 from torchopt.optim.base import Optimizer
-from torchopt.typing import ScalarOrSchedule
+
+
+if TYPE_CHECKING:
+    import torch
+
+    from torchopt.typing import ScalarOrSchedule
 
 
 __all__ = ['AdaDelta', 'Adadelta']
diff --git a/torchopt/optim/adagrad.py b/torchopt/optim/adagrad.py
index 277b7105..06091281 100644
--- a/torchopt/optim/adagrad.py
+++ b/torchopt/optim/adagrad.py
@@ -16,13 +16,16 @@
 
 from __future__ import annotations
 
-from typing import Iterable
-
-import torch
+from typing import TYPE_CHECKING, Iterable
 
 from torchopt import alias
 from torchopt.optim.base import Optimizer
-from torchopt.typing import ScalarOrSchedule
+
+
+if TYPE_CHECKING:
+    import torch
+
+    from torchopt.typing import ScalarOrSchedule
 
 
 __all__ = ['AdaGrad', 'Adagrad']
diff --git a/torchopt/optim/adam.py b/torchopt/optim/adam.py
index 6ff68a69..555af22e 100644
--- a/torchopt/optim/adam.py
+++ b/torchopt/optim/adam.py
@@ -16,13 +16,16 @@
 
 from __future__ import annotations
 
-from typing import Iterable
-
-import torch
+from typing import TYPE_CHECKING, Iterable
 
 from torchopt import alias
 from torchopt.optim.base import Optimizer
-from torchopt.typing import ScalarOrSchedule
+
+
+if TYPE_CHECKING:
+    import torch
+
+    from torchopt.typing import ScalarOrSchedule
 
 
 __all__ = ['Adam']
diff --git a/torchopt/optim/adamax.py b/torchopt/optim/adamax.py
index f693723c..e4996e85 100644
--- a/torchopt/optim/adamax.py
+++ b/torchopt/optim/adamax.py
@@ -16,13 +16,16 @@
 
 from __future__ import annotations
 
-from typing import Iterable
-
-import torch
+from typing import TYPE_CHECKING, Iterable
 
 from torchopt import alias
 from torchopt.optim.base import Optimizer
-from torchopt.typing import ScalarOrSchedule
+
+
+if TYPE_CHECKING:
+    import torch
+
+    from torchopt.typing import ScalarOrSchedule
 
 
 __all__ = ['AdaMax', 'Adamax']
diff --git a/torchopt/optim/adamw.py b/torchopt/optim/adamw.py
index 463f245f..a60061ea 100644
--- a/torchopt/optim/adamw.py
+++ b/torchopt/optim/adamw.py
@@ -16,13 +16,16 @@
 
 from __future__ import annotations
 
-from typing import Callable, Iterable
-
-import torch
+from typing import TYPE_CHECKING, Callable, Iterable
 
 from torchopt import alias
 from torchopt.optim.base import Optimizer
-from torchopt.typing import OptState, Params, ScalarOrSchedule
+
+
+if TYPE_CHECKING:
+    import torch
+
+    from torchopt.typing import OptState, Params, ScalarOrSchedule
 
 
 __all__ = ['AdamW']
diff --git a/torchopt/optim/func/base.py b/torchopt/optim/func/base.py
index 7bb27877..fa287f04 100644
--- a/torchopt/optim/func/base.py
+++ b/torchopt/optim/func/base.py
@@ -16,13 +16,18 @@
 
 from __future__ import annotations
 
+from typing import TYPE_CHECKING
+
 import torch
 
 from torchopt.base import GradientTransformation, UninitializedState
-from torchopt.typing import OptState, Params
 from torchopt.update import apply_updates
 
 
+if TYPE_CHECKING:
+    from torchopt.typing import OptState, Params
+
+
 __all__ = ['FuncOptimizer']
 
 
diff --git a/torchopt/optim/meta/adadelta.py b/torchopt/optim/meta/adadelta.py
index 49bdf23c..eb386ae3 100644
--- a/torchopt/optim/meta/adadelta.py
+++ b/torchopt/optim/meta/adadelta.py
@@ -16,11 +16,16 @@
 
 from __future__ import annotations
 
-import torch.nn as nn
+from typing import TYPE_CHECKING
 
 from torchopt import alias
 from torchopt.optim.meta.base import MetaOptimizer
-from torchopt.typing import ScalarOrSchedule
+
+
+if TYPE_CHECKING:
+    import torch.nn as nn
+
+    from torchopt.typing import ScalarOrSchedule
 
 
 __all__ = ['MetaAdaDelta', 'MetaAdadelta']
diff --git a/torchopt/optim/meta/adagrad.py b/torchopt/optim/meta/adagrad.py
index 58d913aa..129c1338 100644
--- a/torchopt/optim/meta/adagrad.py
+++ b/torchopt/optim/meta/adagrad.py
@@ -16,11 +16,16 @@
 
 from __future__ import annotations
 
-import torch.nn as nn
+from typing import TYPE_CHECKING
 
 from torchopt import alias
 from torchopt.optim.meta.base import MetaOptimizer
-from torchopt.typing import ScalarOrSchedule
+
+
+if TYPE_CHECKING:
+    import torch.nn as nn
+
+    from torchopt.typing import ScalarOrSchedule
 
 
 __all__ = ['MetaAdaGrad', 'MetaAdagrad']
diff --git a/torchopt/optim/meta/adam.py b/torchopt/optim/meta/adam.py
index bac71790..7a78ea7f 100644
--- a/torchopt/optim/meta/adam.py
+++ b/torchopt/optim/meta/adam.py
@@ -16,11 +16,16 @@
 
 from __future__ import annotations
 
-import torch.nn as nn
+from typing import TYPE_CHECKING
 
 from torchopt import alias
 from torchopt.optim.meta.base import MetaOptimizer
-from torchopt.typing import ScalarOrSchedule
+
+
+if TYPE_CHECKING:
+    import torch.nn as nn
+
+    from torchopt.typing import ScalarOrSchedule
 
 
 __all__ = ['MetaAdam']
diff --git a/torchopt/optim/meta/adamax.py b/torchopt/optim/meta/adamax.py
index 568a46f7..d6b40427 100644
--- a/torchopt/optim/meta/adamax.py
+++ b/torchopt/optim/meta/adamax.py
@@ -16,11 +16,16 @@
 
 from __future__ import annotations
 
-import torch.nn as nn
+from typing import TYPE_CHECKING
 
 from torchopt import alias
 from torchopt.optim.meta.base import MetaOptimizer
-from torchopt.typing import ScalarOrSchedule
+
+
+if TYPE_CHECKING:
+    import torch.nn as nn
+
+    from torchopt.typing import ScalarOrSchedule
 
 
 __all__ = ['MetaAdaMax', 'MetaAdamax']
diff --git a/torchopt/optim/meta/adamw.py b/torchopt/optim/meta/adamw.py
index 05387b77..62864582 100644
--- a/torchopt/optim/meta/adamw.py
+++ b/torchopt/optim/meta/adamw.py
@@ -16,13 +16,16 @@
 
 from __future__ import annotations
 
-from typing import Callable
-
-import torch.nn as nn
+from typing import TYPE_CHECKING, Callable
 
 from torchopt import alias
 from torchopt.optim.meta.base import MetaOptimizer
-from torchopt.typing import OptState, Params, ScalarOrSchedule
+
+
+if TYPE_CHECKING:
+    import torch.nn as nn
+
+    from torchopt.typing import OptState, Params, ScalarOrSchedule
 
 
 __all__ = ['MetaAdamW']
diff --git a/torchopt/optim/meta/radam.py b/torchopt/optim/meta/radam.py
index a32670d0..bb07b5ba 100644
--- a/torchopt/optim/meta/radam.py
+++ b/torchopt/optim/meta/radam.py
@@ -16,11 +16,16 @@
 
 from __future__ import annotations
 
-import torch.nn as nn
+from typing import TYPE_CHECKING
 
 from torchopt import alias
 from torchopt.optim.meta.base import MetaOptimizer
-from torchopt.typing import ScalarOrSchedule
+
+
+if TYPE_CHECKING:
+    import torch.nn as nn
+
+    from torchopt.typing import ScalarOrSchedule
 
 
 __all__ = ['MetaRAdam']
diff --git a/torchopt/optim/radam.py b/torchopt/optim/radam.py
index bba8c0d4..20e9dd22 100644
--- a/torchopt/optim/radam.py
+++ b/torchopt/optim/radam.py
@@ -16,13 +16,16 @@
 
 from __future__ import annotations
 
-from typing import Iterable
-
-import torch
+from typing import TYPE_CHECKING, Iterable
 
 from torchopt import alias
 from torchopt.optim.base import Optimizer
-from torchopt.typing import ScalarOrSchedule
+
+
+if TYPE_CHECKING:
+    import torch
+
+    from torchopt.typing import ScalarOrSchedule
 
 
 __all__ = ['RAdam']
diff --git a/torchopt/pytree.py b/torchopt/pytree.py
index 6adea0e8..53abc2d2 100644
--- a/torchopt/pytree.py
+++ b/torchopt/pytree.py
@@ -18,7 +18,7 @@
 
 import functools
 import operator
-from typing import Callable
+from typing import TYPE_CHECKING, Callable
 
 import optree
 import optree.typing as typing  # pylint: disable=unused-import
@@ -26,7 +26,9 @@
 import torch.distributed.rpc as rpc
 from optree import *  # pylint: disable=wildcard-import,unused-wildcard-import
 
-from torchopt.typing import Future, RRef, Scalar, T, TensorTree
+
+if TYPE_CHECKING:
+    from torchopt.typing import Future, RRef, Scalar, T, TensorTree
 
 
 __all__ = [
diff --git a/torchopt/schedule/__init__.py b/torchopt/schedule/__init__.py
index 8e5545a4..d3d3eff5 100644
--- a/torchopt/schedule/__init__.py
+++ b/torchopt/schedule/__init__.py
@@ -35,4 +35,4 @@
 from torchopt.schedule.polynomial import linear_schedule, polynomial_schedule
 
 
-__all__ = ['exponential_decay', 'polynomial_schedule', 'linear_schedule']
+__all__ = ['exponential_decay', 'linear_schedule', 'polynomial_schedule']
diff --git a/torchopt/schedule/exponential_decay.py b/torchopt/schedule/exponential_decay.py
index 0925e164..c19c54b9 100644
--- a/torchopt/schedule/exponential_decay.py
+++ b/torchopt/schedule/exponential_decay.py
@@ -31,11 +31,15 @@
 # ==============================================================================
 """Exponential learning rate decay."""
 
+from __future__ import annotations
+
 import logging
 import math
-from typing import Optional
+from typing import TYPE_CHECKING
+
 
-from torchopt.typing import Numeric, Scalar, Schedule
+if TYPE_CHECKING:
+    from torchopt.typing import Numeric, Scalar, Schedule
 
 
 __all__ = ['exponential_decay']
@@ -48,7 +52,7 @@ def exponential_decay(
     transition_begin: int = 0,
     transition_steps: int = 1,
     staircase: bool = False,
-    end_value: Optional[float] = None,
+    end_value: float | None = None,
 ) -> Schedule:
     """Construct a schedule with either continuous or discrete exponential decay.
 
diff --git a/torchopt/schedule/polynomial.py b/torchopt/schedule/polynomial.py
index 2482f769..d2a5160c 100644
--- a/torchopt/schedule/polynomial.py
+++ b/torchopt/schedule/polynomial.py
@@ -31,15 +31,20 @@
 # ==============================================================================
 """Polynomial learning rate schedules."""
 
+from __future__ import annotations
+
 import logging
+from typing import TYPE_CHECKING
 
 import numpy as np
 import torch
 
-from torchopt.typing import Numeric, Scalar, Schedule
+
+if TYPE_CHECKING:
+    from torchopt.typing import Numeric, Scalar, Schedule
 
 
-__all__ = ['polynomial_schedule', 'linear_schedule']
+__all__ = ['linear_schedule', 'polynomial_schedule']
 
 
 def polynomial_schedule(
diff --git a/torchopt/transform/__init__.py b/torchopt/transform/__init__.py
index adef5596..fa59a43b 100644
--- a/torchopt/transform/__init__.py
+++ b/torchopt/transform/__init__.py
@@ -46,18 +46,18 @@
 
 
 __all__ = [
-    'trace',
-    'scale',
-    'scale_by_schedule',
     'add_decayed_weights',
     'masked',
+    'nan_to_num',
+    'scale',
+    'scale_by_accelerated_adam',
+    'scale_by_adadelta',
     'scale_by_adam',
     'scale_by_adamax',
-    'scale_by_adadelta',
     'scale_by_radam',
-    'scale_by_accelerated_adam',
-    'scale_by_rss',
     'scale_by_rms',
+    'scale_by_rss',
+    'scale_by_schedule',
     'scale_by_stddev',
-    'nan_to_num',
+    'trace',
 ]
diff --git a/torchopt/transform/add_decayed_weights.py b/torchopt/transform/add_decayed_weights.py
index 950682cf..0cb67837 100644
--- a/torchopt/transform/add_decayed_weights.py
+++ b/torchopt/transform/add_decayed_weights.py
@@ -34,17 +34,20 @@
 
 from __future__ import annotations
 
-from typing import Any, Callable, NamedTuple
-
-import torch
+from typing import TYPE_CHECKING, Any, Callable, NamedTuple
 
 from torchopt import pytree
 from torchopt.base import EmptyState, GradientTransformation, identity
 from torchopt.transform.utils import tree_map_flat, tree_map_flat_
-from torchopt.typing import OptState, Params, Updates
 
 
-__all__ = ['masked', 'add_decayed_weights']
+if TYPE_CHECKING:
+    import torch
+
+    from torchopt.typing import OptState, Params, Updates
+
+
+__all__ = ['add_decayed_weights', 'masked']
 
 
 class MaskedState(NamedTuple):
@@ -189,7 +192,7 @@ def _add_decayed_weights_flat(
     )
 
 
-def _add_decayed_weights(
+def _add_decayed_weights(  # noqa: C901
     weight_decay: float = 0.0,
     mask: OptState | Callable[[Params], OptState] | None = None,
     *,
diff --git a/torchopt/transform/nan_to_num.py b/torchopt/transform/nan_to_num.py
index d3530853..740df1b0 100644
--- a/torchopt/transform/nan_to_num.py
+++ b/torchopt/transform/nan_to_num.py
@@ -16,11 +16,16 @@
 
 from __future__ import annotations
 
-import torch
+from typing import TYPE_CHECKING
 
 from torchopt import pytree
 from torchopt.base import EmptyState, GradientTransformation
-from torchopt.typing import OptState, Params, Updates
+
+
+if TYPE_CHECKING:
+    import torch
+
+    from torchopt.typing import OptState, Params, Updates
 
 
 def nan_to_num(
diff --git a/torchopt/transform/scale.py b/torchopt/transform/scale.py
index 493b7196..2b492bdf 100644
--- a/torchopt/transform/scale.py
+++ b/torchopt/transform/scale.py
@@ -33,12 +33,17 @@
 
 from __future__ import annotations
 
-import torch
+from typing import TYPE_CHECKING
 
 from torchopt import pytree
 from torchopt.base import EmptyState, GradientTransformation
 from torchopt.transform.utils import tree_map_flat, tree_map_flat_
-from torchopt.typing import OptState, Params, Updates
+
+
+if TYPE_CHECKING:
+    import torch
+
+    from torchopt.typing import OptState, Params, Updates
 
 
 __all__ = ['scale']
diff --git a/torchopt/transform/scale_by_adadelta.py b/torchopt/transform/scale_by_adadelta.py
index f389d293..6d05e5dd 100644
--- a/torchopt/transform/scale_by_adadelta.py
+++ b/torchopt/transform/scale_by_adadelta.py
@@ -19,14 +19,17 @@
 
 from __future__ import annotations
 
-from typing import NamedTuple
+from typing import TYPE_CHECKING, NamedTuple
 
 import torch
 
 from torchopt import pytree
 from torchopt.base import GradientTransformation
 from torchopt.transform.utils import tree_map_flat, update_moment
-from torchopt.typing import OptState, Params, Updates
+
+
+if TYPE_CHECKING:
+    from torchopt.typing import OptState, Params, Updates
 
 
 __all__ = ['scale_by_adadelta']
diff --git a/torchopt/transform/scale_by_adam.py b/torchopt/transform/scale_by_adam.py
index b08c6a14..d45d1eb2 100644
--- a/torchopt/transform/scale_by_adam.py
+++ b/torchopt/transform/scale_by_adam.py
@@ -35,7 +35,7 @@
 
 from __future__ import annotations
 
-from typing import NamedTuple
+from typing import TYPE_CHECKING, NamedTuple
 
 import torch
 
@@ -43,10 +43,13 @@
 from torchopt.accelerated_op import AdamOp
 from torchopt.base import GradientTransformation
 from torchopt.transform.utils import inc_count, tree_map_flat, update_moment
-from torchopt.typing import OptState, Params, Updates
 
 
-__all__ = ['scale_by_adam', 'scale_by_accelerated_adam']
+if TYPE_CHECKING:
+    from torchopt.typing import OptState, Params, Updates
+
+
+__all__ = ['scale_by_accelerated_adam', 'scale_by_adam']
 
 
 TRIPLE_PYTREE_SPEC = pytree.tree_structure((0, 1, 2), none_is_leaf=True)  # type: ignore[arg-type]
@@ -277,7 +280,7 @@ def _scale_by_accelerated_adam_flat(
 
 
 # pylint: disable-next=too-many-arguments
-def _scale_by_accelerated_adam(
+def _scale_by_accelerated_adam(  # noqa: C901
     b1: float = 0.9,
     b2: float = 0.999,
     eps: float = 1e-8,
diff --git a/torchopt/transform/scale_by_adamax.py b/torchopt/transform/scale_by_adamax.py
index f11ed311..cfacbf35 100644
--- a/torchopt/transform/scale_by_adamax.py
+++ b/torchopt/transform/scale_by_adamax.py
@@ -19,14 +19,17 @@
 
 from __future__ import annotations
 
-from typing import NamedTuple
+from typing import TYPE_CHECKING, NamedTuple
 
 import torch
 
 from torchopt import pytree
 from torchopt.base import GradientTransformation
 from torchopt.transform.utils import tree_map_flat, update_moment
-from torchopt.typing import OptState, Params, Updates
+
+
+if TYPE_CHECKING:
+    from torchopt.typing import OptState, Params, Updates
 
 
 __all__ = ['scale_by_adamax']
diff --git a/torchopt/transform/scale_by_radam.py b/torchopt/transform/scale_by_radam.py
index fad32b13..95f26149 100644
--- a/torchopt/transform/scale_by_radam.py
+++ b/torchopt/transform/scale_by_radam.py
@@ -20,14 +20,17 @@
 from __future__ import annotations
 
 import math
-from typing import NamedTuple
+from typing import TYPE_CHECKING, NamedTuple
 
 import torch
 
 from torchopt import pytree
 from torchopt.base import GradientTransformation
 from torchopt.transform.utils import tree_map_flat, update_moment
-from torchopt.typing import OptState, Params, Updates
+
+
+if TYPE_CHECKING:
+    from torchopt.typing import OptState, Params, Updates
 
 
 __all__ = ['scale_by_radam']
@@ -89,7 +92,7 @@ def _scale_by_radam_flat(
     )
 
 
-def _scale_by_radam(
+def _scale_by_radam(  # noqa: C901
     b1: float = 0.9,
     b2: float = 0.999,
     eps: float = 1e-6,
diff --git a/torchopt/transform/scale_by_rms.py b/torchopt/transform/scale_by_rms.py
index 4ee67ed0..f2141388 100644
--- a/torchopt/transform/scale_by_rms.py
+++ b/torchopt/transform/scale_by_rms.py
@@ -33,14 +33,17 @@
 
 from __future__ import annotations
 
-from typing import NamedTuple
+from typing import TYPE_CHECKING, NamedTuple
 
 import torch
 
 from torchopt import pytree
 from torchopt.base import GradientTransformation
 from torchopt.transform.utils import tree_map_flat, tree_map_flat_, update_moment
-from torchopt.typing import OptState, Params, Updates
+
+
+if TYPE_CHECKING:
+    from torchopt.typing import OptState, Params, Updates
 
 
 __all__ = ['scale_by_rms']
diff --git a/torchopt/transform/scale_by_rss.py b/torchopt/transform/scale_by_rss.py
index 9bc97206..642b2e5c 100644
--- a/torchopt/transform/scale_by_rss.py
+++ b/torchopt/transform/scale_by_rss.py
@@ -33,14 +33,17 @@
 
 from __future__ import annotations
 
-from typing import NamedTuple
+from typing import TYPE_CHECKING, NamedTuple
 
 import torch
 
 from torchopt import pytree
 from torchopt.base import GradientTransformation
 from torchopt.transform.utils import tree_map_flat, update_moment
-from torchopt.typing import OptState, Params, Updates
+
+
+if TYPE_CHECKING:
+    from torchopt.typing import OptState, Params, Updates
 
 
 __all__ = ['scale_by_rss']
diff --git a/torchopt/transform/scale_by_schedule.py b/torchopt/transform/scale_by_schedule.py
index 48f3f271..499e2adb 100644
--- a/torchopt/transform/scale_by_schedule.py
+++ b/torchopt/transform/scale_by_schedule.py
@@ -33,14 +33,17 @@
 
 from __future__ import annotations
 
-from typing import NamedTuple
+from typing import TYPE_CHECKING, NamedTuple
 
 import torch
 
 from torchopt import pytree
 from torchopt.base import GradientTransformation
 from torchopt.transform.utils import inc_count, tree_map_flat, tree_map_flat_
-from torchopt.typing import Numeric, OptState, Params, Schedule, SequenceOfTensors, Updates
+
+
+if TYPE_CHECKING:
+    from torchopt.typing import Numeric, OptState, Params, Schedule, SequenceOfTensors, Updates
 
 
 __all__ = ['scale_by_schedule']
diff --git a/torchopt/transform/scale_by_stddev.py b/torchopt/transform/scale_by_stddev.py
index 6b99f31a..5a3e6655 100644
--- a/torchopt/transform/scale_by_stddev.py
+++ b/torchopt/transform/scale_by_stddev.py
@@ -35,14 +35,17 @@
 
 from __future__ import annotations
 
-from typing import NamedTuple
+from typing import TYPE_CHECKING, NamedTuple
 
 import torch
 
 from torchopt import pytree
 from torchopt.base import GradientTransformation
 from torchopt.transform.utils import tree_map_flat, tree_map_flat_, update_moment
-from torchopt.typing import OptState, Params, Updates
+
+
+if TYPE_CHECKING:
+    from torchopt.typing import OptState, Params, Updates
 
 
 __all__ = ['scale_by_stddev']
diff --git a/torchopt/transform/trace.py b/torchopt/transform/trace.py
index 9bf37e2f..219cbbec 100644
--- a/torchopt/transform/trace.py
+++ b/torchopt/transform/trace.py
@@ -35,14 +35,17 @@
 
 from __future__ import annotations
 
-from typing import NamedTuple
+from typing import TYPE_CHECKING, NamedTuple
 
 import torch
 
 from torchopt import pytree
 from torchopt.base import GradientTransformation, identity
 from torchopt.transform.utils import tree_map_flat, tree_map_flat_
-from torchopt.typing import OptState, Params, Updates
+
+
+if TYPE_CHECKING:
+    from torchopt.typing import OptState, Params, Updates
 
 
 __all__ = ['trace']
@@ -101,7 +104,7 @@ def _trace_flat(
     )
 
 
-def _trace(
+def _trace(  # noqa: C901
     momentum: float = 0.9,
     dampening: float = 0.0,
     nesterov: bool = False,
@@ -136,7 +139,7 @@ def init_fn(params: Params) -> OptState:
 
     first_call = True
 
-    def update_fn(
+    def update_fn(  # noqa: C901
         updates: Updates,
         state: OptState,
         *,
diff --git a/torchopt/transform/utils.py b/torchopt/transform/utils.py
index ec4e51c1..9b38d561 100644
--- a/torchopt/transform/utils.py
+++ b/torchopt/transform/utils.py
@@ -34,15 +34,18 @@
 from __future__ import annotations
 
 from collections import deque
-from typing import Any, Callable, Sequence
+from typing import TYPE_CHECKING, Any, Callable, Sequence
 
 import torch
 
 from torchopt import pytree
-from torchopt.typing import TensorTree, Updates
 
 
-__all__ = ['tree_map_flat', 'tree_map_flat_', 'inc_count', 'update_moment']
+if TYPE_CHECKING:
+    from torchopt.typing import TensorTree, Updates
+
+
+__all__ = ['inc_count', 'tree_map_flat', 'tree_map_flat_', 'update_moment']
 
 
 INT64_MAX = torch.iinfo(torch.int64).max
@@ -161,7 +164,7 @@ def _update_moment_flat(
 
 
 # pylint: disable-next=too-many-arguments
-def _update_moment(
+def _update_moment(  # noqa: C901
     updates: Updates,
     moments: TensorTree,
     decay: float,
diff --git a/torchopt/typing.py b/torchopt/typing.py
index 60d11e0e..fcd888fb 100644
--- a/torchopt/typing.py
+++ b/torchopt/typing.py
@@ -14,6 +14,8 @@
 # ==============================================================================
 """Typing utilities."""
 
+from __future__ import annotations
+
 import abc
 from typing import (
     Callable,
@@ -45,39 +47,39 @@
 
 
 __all__ = [
-    'GradientTransformation',
     'ChainedGradientTransformation',
+    'Device',
+    'Distribution',
     'EmptyState',
-    'UninitializedState',
-    'Params',
-    'Updates',
+    'Future',
+    'GradientTransformation',
+    'LinearSolver',
+    'ListOfOptionalTensors',
+    'ListOfTensors',
+    'ModuleTensorContainers',
+    'Numeric',
     'OptState',
+    'OptionalTensor',
+    'OptionalTensorOrOptionalTensors',
+    'OptionalTensorTree',
+    'Params',
+    'PyTree',
+    'Samplable',
+    'SampleFunc',
     'Scalar',
-    'Numeric',
-    'Schedule',
     'ScalarOrSchedule',
-    'PyTree',
-    'Tensor',
-    'OptionalTensor',
-    'ListOfTensors',
-    'TupleOfTensors',
+    'Schedule',
+    'SequenceOfOptionalTensors',
     'SequenceOfTensors',
+    'Size',
+    'Tensor',
+    'TensorContainer',
     'TensorOrTensors',
     'TensorTree',
-    'ListOfOptionalTensors',
     'TupleOfOptionalTensors',
-    'SequenceOfOptionalTensors',
-    'OptionalTensorOrOptionalTensors',
-    'OptionalTensorTree',
-    'TensorContainer',
-    'ModuleTensorContainers',
-    'Future',
-    'LinearSolver',
-    'Device',
-    'Size',
-    'Distribution',
-    'SampleFunc',
-    'Samplable',
+    'TupleOfTensors',
+    'UninitializedState',
+    'Updates',
 ]
 
 T = TypeVar('T')
@@ -138,7 +140,7 @@ class Samplable(Protocol):  # pylint: disable=too-few-public-methods
     def sample(
         self,
         sample_shape: Size = Size(),  # noqa: B008 # pylint: disable=unused-argument
-    ) -> Union[Tensor, Sequence[Numeric]]:
+    ) -> Tensor | Sequence[Numeric]:
         # pylint: disable-next=line-too-long
         """Generate a sample_shape shaped sample or sample_shape shaped batch of samples if the distribution parameters are batched."""
         raise NotImplementedError
diff --git a/torchopt/update.py b/torchopt/update.py
index 8636d7a4..3f2d71fe 100644
--- a/torchopt/update.py
+++ b/torchopt/update.py
@@ -33,10 +33,15 @@
 
 from __future__ import annotations
 
-import torch
+from typing import TYPE_CHECKING
 
 from torchopt import pytree
-from torchopt.typing import Params, Updates
+
+
+if TYPE_CHECKING:
+    import torch
+
+    from torchopt.typing import Params, Updates
 
 
 __all__ = ['apply_updates']
diff --git a/torchopt/utils.py b/torchopt/utils.py
index c067d570..5f9202a3 100644
--- a/torchopt/utils.py
+++ b/torchopt/utils.py
@@ -34,11 +34,11 @@
 
 __all__ = [
     'ModuleState',
-    'stop_gradient',
     'extract_state_dict',
-    'recover_state_dict',
     'module_clone',
     'module_detach_',
+    'recover_state_dict',
+    'stop_gradient',
 ]
 
 
@@ -115,7 +115,7 @@ def extract_state_dict(
 
 
 # pylint: disable-next=too-many-arguments,too-many-branches,too-many-locals
-def extract_state_dict(
+def extract_state_dict(  # noqa: C901
     target: nn.Module | MetaOptimizer,
     *,
     by: CopyMode = 'reference',
@@ -272,7 +272,7 @@ def get_variable(t: torch.Tensor | None) -> torch.Tensor | None:
 
         return pytree.tree_map(get_variable, state)  # type: ignore[arg-type,return-value]
 
-    raise RuntimeError(f'Unexpected class of {target}')
+    raise TypeError(f'Unexpected class of {target}')
 
 
 def extract_module_containers(
@@ -346,7 +346,7 @@ def clone_detach_(t: torch.Tensor) -> torch.Tensor:
         state = cast(Sequence[OptState], state)
         target.load_state_dict(state)
     else:
-        raise RuntimeError(f'Unexpected class of {target}')
+        raise TypeError(f'Unexpected class of {target}')
 
 
 @overload
@@ -383,7 +383,7 @@ def module_clone(
 
 
 # pylint: disable-next=too-many-locals
-def module_clone(
+def module_clone(  # noqa: C901
     target: nn.Module | MetaOptimizer | TensorTree,
     *,
     by: CopyMode = 'reference',
diff --git a/torchopt/visual.py b/torchopt/visual.py
index d7885889..7638d7ec 100644
--- a/torchopt/visual.py
+++ b/torchopt/visual.py
@@ -19,16 +19,19 @@
 
 from __future__ import annotations
 
-from typing import Any, Generator, Iterable, Mapping, cast
+from typing import TYPE_CHECKING, Any, Generator, Iterable, Mapping, cast
 
 import torch
 from graphviz import Digraph
 
 from torchopt import pytree
-from torchopt.typing import TensorTree
 from torchopt.utils import ModuleState
 
 
+if TYPE_CHECKING:
+    from torchopt.typing import TensorTree
+
+
 __all__ = ['make_dot', 'resize_graph']
 
 
@@ -69,7 +72,7 @@ def truncate(s: str) -> str:  # pylint: disable=invalid-name
 
 
 # pylint: disable-next=too-many-branches,too-many-statements,too-many-locals
-def make_dot(
+def make_dot(  # noqa: C901
     var: TensorTree,
     params: (
         Mapping[str, torch.Tensor]
@@ -153,7 +156,7 @@ def get_var_name_with_flag(var: torch.Tensor) -> str | None:
             return f'{param_map[var][0]}\n{size_to_str(param_map[var][1].size())}'
         return None
 
-    def add_nodes(fn: Any) -> None:  # pylint: disable=too-many-branches
+    def add_nodes(fn: Any) -> None:  # noqa: C901 # pylint: disable=too-many-branches
         assert not isinstance(fn, torch.Tensor)
         if fn in seen:
             return