Skip to content

Commit

Permalink
remove static constant from TorchForwardSimulator class
Browse files Browse the repository at this point in the history
  • Loading branch information
rileyjmurray committed May 7, 2024
1 parent f5383b9 commit ac2e8e7
Show file tree
Hide file tree
Showing 3 changed files with 4 additions and 6 deletions.
2 changes: 1 addition & 1 deletion pygsti/forwardsims/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

from .forwardsim import ForwardSimulator
from .mapforwardsim import SimpleMapForwardSimulator, MapForwardSimulator
from .torchfwdsim import TorchForwardSimulator
from .torchfwdsim import TorchForwardSimulator, TORCH_ENABLED
from .matrixforwardsim import SimpleMatrixForwardSimulator, MatrixForwardSimulator
from .termforwardsim import TermForwardSimulator
from .weakforwardsim import WeakForwardSimulator
4 changes: 1 addition & 3 deletions pygsti/forwardsims/torchfwdsim.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,13 +175,11 @@ def jac_friendly_circuit_probs(self, *free_params: Tuple[torch.Tensor]):

class TorchForwardSimulator(ForwardSimulator):

ENABLED = TORCH_ENABLED

"""
A forward simulator that leverages automatic differentiation in PyTorch.
"""
def __init__(self, model : Optional[ExplicitOpModel] = None):
if not TorchForwardSimulator.ENABLED:
if not TORCH_ENABLED:
raise RuntimeError('PyTorch could not be imported.')
self.model = model
super(ForwardSimulator, self).__init__(model)
Expand Down
4 changes: 2 additions & 2 deletions test/unit/objects/test_forwardsim.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from pygsti.forwardsims import ForwardSimulator, \
MapForwardSimulator, SimpleMapForwardSimulator, \
MatrixForwardSimulator, SimpleMatrixForwardSimulator, \
TorchForwardSimulator
TorchForwardSimulator, TORCH_ENABLED
from pygsti.models import ExplicitOpModel
from pygsti.circuits import Circuit
from pygsti.baseobjs import Label as L
Expand Down Expand Up @@ -177,7 +177,7 @@ def test_simple_matrix_fwdsim(self):
def test_simple_map_fwdsim(self):
self._run(SimpleMapForwardSimulator)

@pytest.mark.skipif(not TorchForwardSimulator.ENABLED, reason="PyTorch is not installed.")
@pytest.mark.skipif(not TORCH_ENABLED, reason="PyTorch is not installed.")
def test_torch_fwdsim(self):
self._run(TorchForwardSimulator)

Expand Down

0 comments on commit ac2e8e7

Please sign in to comment.