From f5383b92c0ce6fcfec1009a32ff31f47b970cfb4 Mon Sep 17 00:00:00 2001 From: Riley Murray Date: Tue, 7 May 2024 09:03:13 -0400 Subject: [PATCH] Create Torchable subclass of ModelMember --- pygsti/forwardsims/torchfwdsim.py | 2 + pygsti/modelmembers/modelmember.py | 43 --------------------- pygsti/modelmembers/operations/fulltpop.py | 19 +++++----- pygsti/modelmembers/povms/tppovm.py | 19 ++++------ pygsti/modelmembers/states/tpstate.py | 18 ++++----- pygsti/modelmembers/torchable.py | 44 ++++++++++++++++++++++ test/unit/objects/test_forwardsim.py | 1 - 7 files changed, 71 insertions(+), 75 deletions(-) create mode 100644 pygsti/modelmembers/torchable.py diff --git a/pygsti/forwardsims/torchfwdsim.py b/pygsti/forwardsims/torchfwdsim.py index b172df455..d328ae7a0 100644 --- a/pygsti/forwardsims/torchfwdsim.py +++ b/pygsti/forwardsims/torchfwdsim.py @@ -18,6 +18,7 @@ from pygsti.circuits.circuit import SeparatePOVMCircuit from pygsti.layouts.copalayout import CircuitOutcomeProbabilityArrayLayout +from pygsti.modelmembers.torchable import Torchable from collections import OrderedDict import warnings as warnings @@ -84,6 +85,7 @@ def __init__(self, model: ExplicitOpModel, layout): self.param_metadata = [] for lbl, obj in model._iter_parameterized_objs(): + assert isinstance(obj, Torchable) param_type = type(obj) param_data = (lbl, param_type) + (obj.stateless_data(),) self.param_metadata.append(param_data) diff --git a/pygsti/modelmembers/modelmember.py b/pygsti/modelmembers/modelmember.py index 5767d7983..27e36e692 100644 --- a/pygsti/modelmembers/modelmember.py +++ b/pygsti/modelmembers/modelmember.py @@ -1058,49 +1058,6 @@ def _print_gpindices(self, prefix="", member_label=None, param_labels=None, max_ def _oneline_contents(self): """ Summarizes the contents of this object in a single line. Does not summarize submembers. """ return "(contents not available)" - - def stateless_data(self): - """ - Return the data of this model that is considered considered constant for purposes - of model fitting. - - Note: the word "stateless" here is used in the sense of object-oriented programming. - """ - raise NotImplementedError() - - # TODO: verify that something like that following won't work for AD. - # def moretorch(self, vec): - # import torch - # oldvec = self.to_vector() - # self.from_vector(vec) - # numpyrep = self.base - # torchrep = torch.from_numpy(numpyrep) - # self.from_vector(oldvec) - # return torchrep - - @staticmethod - def torch_base(sd, vec, torch_handle=None): - """ - Suppose "obj" is an instance of some ModelMember subclass. If we compute - - sd = obj.stateless_data() - vec = obj.to_vector() - T = type(obj).torch_base(sd, vec, grad) - - then T will be a PyTorch Tensor that represents "obj" in a canonical numerical way. - - The meaning of "canonical" is implementation dependent. If type(obj) implements - the ``.base`` attribute, then a reasonable implementation will probably satisfy - - np.allclose(obj.base, T.numpy()). - - Optional args - ------------- - torch_handle can be None or it can be a reference to torch as a Python package - (analogous to the variable "np" after we do "import numpy as np"). If it's none - then we'll import torch as the first step of this function. - """ - raise NotImplementedError() def _compose_gpindices(parent_gpindices, child_gpindices): diff --git a/pygsti/modelmembers/operations/fulltpop.py b/pygsti/modelmembers/operations/fulltpop.py index 1c5910e50..72079249c 100644 --- a/pygsti/modelmembers/operations/fulltpop.py +++ b/pygsti/modelmembers/operations/fulltpop.py @@ -15,11 +15,12 @@ from pygsti.modelmembers.operations.denseop import DenseOperator as _DenseOperator from pygsti.modelmembers.operations.linearop import LinearOperator as _LinearOperator from pygsti.baseobjs.protectedarray import ProtectedArray as _ProtectedArray -from typing import Tuple, Optional, TypeVar -Tensor = TypeVar('Tensor') # torch.tensor. +from pygsti.modelmembers.torchable import Torchable as _Torchable +from typing import Tuple -class FullTPOp(_DenseOperator): + +class FullTPOp(_DenseOperator, _Torchable): """ A trace-preserving operation matrix. @@ -157,19 +158,17 @@ def from_vector(self, v, close=False, dirty_value=True): self._ptr_has_changed() # because _rep.base == _ptr (same memory) self.dirty = dirty_value - def stateless_data(self): + def stateless_data(self) -> Tuple[int]: return (self.dim,) @staticmethod - def torch_base(sd: Tuple[int], t_param: Tensor, torch_handle=None): - if torch_handle is None: - import torch as torch_handle - + def torch_base(sd: Tuple[int], t_param: _Torchable.Tensor) -> _Torchable.Tensor: + torch = _Torchable.torch_handle dim = sd[0] - t_const = torch_handle.zeros(size=(1, dim), dtype=torch_handle.double) + t_const = torch.zeros(size=(1, dim), dtype=torch.double) t_const[0,0] = 1.0 t_param_mat = t_param.reshape((dim - 1, dim)) - t = torch_handle.row_stack((t_const, t_param_mat)) + t = torch.row_stack((t_const, t_param_mat)) return t diff --git a/pygsti/modelmembers/povms/tppovm.py b/pygsti/modelmembers/povms/tppovm.py index eb76bd4b6..c5c34df43 100644 --- a/pygsti/modelmembers/povms/tppovm.py +++ b/pygsti/modelmembers/povms/tppovm.py @@ -11,15 +11,13 @@ #*************************************************************************************************** import numpy as _np +from pygsti.modelmembers.torchable import Torchable as _Torchable from pygsti.modelmembers.povms.basepovm import _BasePOVM -from pygsti.modelmembers.povms.effect import POVMEffect as _POVMEffect from pygsti.modelmembers.povms.fulleffect import FullPOVMEffect as _FullPOVMEffect -from pygsti.modelmembers.povms.conjugatedeffect import ConjugatedStatePOVMEffect as _ConjugatedStatePOVMEffect -from typing import Tuple, Optional, TypeVar -Tensor = TypeVar('Tensor') # torch.tensor. +from typing import Tuple -class TPPOVM(_BasePOVM): +class TPPOVM(_BasePOVM, _Torchable): """ A POVM whose sum-of-effects is constrained to what, by definition, we call the "identity". @@ -78,19 +76,18 @@ def to_vector(self): vec = _np.concatenate(effect_vecs) return vec - def stateless_data(self): + def stateless_data(self) -> Tuple[int, int]: dim1 = len(self) dim2 = self.dim return (dim1, dim2) @staticmethod - def torch_base(sd: Tuple[int, int], t_param: Tensor, torch_handle=None): - if torch_handle is None: - import torch as torch_handle + def torch_base(sd: Tuple[int, int], t_param: _Torchable.Tensor) -> _Torchable.Tensor: + torch = _Torchable.torch_handle num_effects, dim = sd - first_basis_vec = torch_handle.zeros(size=(1, dim), dtype=torch_handle.double) + first_basis_vec = torch.zeros(size=(1, dim), dtype=torch.double) first_basis_vec[0,0] = dim ** 0.25 t_param_mat = t_param.reshape((num_effects - 1, dim)) t_func = first_basis_vec - t_param_mat.sum(axis=0, keepdim=True) - t = torch_handle.row_stack((t_param_mat, t_func)) + t = torch.row_stack((t_param_mat, t_func)) return t diff --git a/pygsti/modelmembers/states/tpstate.py b/pygsti/modelmembers/states/tpstate.py index 000040913..a79a6c26f 100644 --- a/pygsti/modelmembers/states/tpstate.py +++ b/pygsti/modelmembers/states/tpstate.py @@ -15,14 +15,14 @@ from pygsti.baseobjs import Basis as _Basis from pygsti.baseobjs import statespace as _statespace +from pygsti.modelmembers.torchable import Torchable as _Torchable from pygsti.modelmembers.states.densestate import DenseState as _DenseState from pygsti.modelmembers.states.state import State as _State from pygsti.baseobjs.protectedarray import ProtectedArray as _ProtectedArray -from typing import Tuple, Optional, TypeVar -Tensor = TypeVar('Tensor') # torch.tensor. +from typing import Tuple -class TPState(_DenseState): +class TPState(_DenseState, _Torchable): """ A fixed-unit-trace state vector. @@ -160,17 +160,15 @@ def from_vector(self, v, close=False, dirty_value=True): self._ptr_has_changed() self.dirty = dirty_value - def stateless_data(self): + def stateless_data(self) -> Tuple[int]: return (self.dim,) @staticmethod - def torch_base(sd: Tuple[int], t_param: Tensor, torch_handle=None): - if torch_handle is None: - import torch as torch_handle - + def torch_base(sd: Tuple[int], t_param: _Torchable.Tensor) -> _Torchable.Tensor: + torch = _Torchable.torch_handle dim = sd[0] - t_const = (dim ** -0.25) * torch_handle.ones(1, dtype=torch_handle.double) - t = torch_handle.concat((t_const, t_param)) + t_const = (dim ** -0.25) * torch.ones(1, dtype=torch.double) + t = torch.concat((t_const, t_param)) return t def deriv_wrt_params(self, wrt_filter=None): diff --git a/pygsti/modelmembers/torchable.py b/pygsti/modelmembers/torchable.py new file mode 100644 index 000000000..07153dbc2 --- /dev/null +++ b/pygsti/modelmembers/torchable.py @@ -0,0 +1,44 @@ +from pygsti.modelmembers.modelmember import ModelMember +from typing import TypeVar, Tuple + +try: + import torch + torch_handle = torch + Tensor = torch.Tensor +except ImportError: + torch_handle = None + Tensor = TypeVar('Tensor') # we'll access this for type annotations elsewhere. + + +class Torchable(ModelMember): + + Tensor = Tensor + torch_handle = torch + + def stateless_data(self) -> Tuple: + """ + Return the data of this model that is considered considered constant for purposes + of model fitting. + + Note: the word "stateless" here is used in the sense of object-oriented programming. + """ + raise NotImplementedError() + + @staticmethod + def torch_base(sd : Tuple, t_param : Tensor) -> Tensor: + """ + Suppose "obj" is an instance of some ModelMember subclass. If we compute + + sd = obj.stateless_data() + vec = obj.to_vector() + t_param = torch.from_numpy(vec) + T = type(obj).torch_base(sd, t_param, grad) + + then T will be a PyTorch Tensor that represents "obj" in a canonical numerical way. + + The meaning of "canonical" is implementation dependent. If type(obj) implements + the ``.base`` attribute, then a reasonable implementation will probably satisfy + + np.allclose(obj.base, T.numpy()). + """ + raise NotImplementedError() diff --git a/test/unit/objects/test_forwardsim.py b/test/unit/objects/test_forwardsim.py index adc8fb06c..ea3d0ba87 100644 --- a/test/unit/objects/test_forwardsim.py +++ b/test/unit/objects/test_forwardsim.py @@ -5,7 +5,6 @@ import numpy as np import pytest -from pygsti.models import modelconstruction as _setc import pygsti.models as models from pygsti.forwardsims import ForwardSimulator, \ MapForwardSimulator, SimpleMapForwardSimulator, \