Skip to content

Commit

Permalink
Create Torchable subclass of ModelMember
Browse files Browse the repository at this point in the history
  • Loading branch information
rileyjmurray committed May 7, 2024
1 parent a3ffa68 commit f5383b9
Show file tree
Hide file tree
Showing 7 changed files with 71 additions and 75 deletions.
2 changes: 2 additions & 0 deletions pygsti/forwardsims/torchfwdsim.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from pygsti.circuits.circuit import SeparatePOVMCircuit
from pygsti.layouts.copalayout import CircuitOutcomeProbabilityArrayLayout

from pygsti.modelmembers.torchable import Torchable
from collections import OrderedDict
import warnings as warnings

Expand Down Expand Up @@ -84,6 +85,7 @@ def __init__(self, model: ExplicitOpModel, layout):

self.param_metadata = []
for lbl, obj in model._iter_parameterized_objs():
assert isinstance(obj, Torchable)
param_type = type(obj)
param_data = (lbl, param_type) + (obj.stateless_data(),)
self.param_metadata.append(param_data)
Expand Down
43 changes: 0 additions & 43 deletions pygsti/modelmembers/modelmember.py
Original file line number Diff line number Diff line change
Expand Up @@ -1058,49 +1058,6 @@ def _print_gpindices(self, prefix="", member_label=None, param_labels=None, max_
def _oneline_contents(self):
""" Summarizes the contents of this object in a single line. Does not summarize submembers. """
return "(contents not available)"

def stateless_data(self):
"""
Return the data of this model that is considered considered constant for purposes
of model fitting.
Note: the word "stateless" here is used in the sense of object-oriented programming.
"""
raise NotImplementedError()

# TODO: verify that something like that following won't work for AD.
# def moretorch(self, vec):
# import torch
# oldvec = self.to_vector()
# self.from_vector(vec)
# numpyrep = self.base
# torchrep = torch.from_numpy(numpyrep)
# self.from_vector(oldvec)
# return torchrep

@staticmethod
def torch_base(sd, vec, torch_handle=None):
"""
Suppose "obj" is an instance of some ModelMember subclass. If we compute
sd = obj.stateless_data()
vec = obj.to_vector()
T = type(obj).torch_base(sd, vec, grad)
then T will be a PyTorch Tensor that represents "obj" in a canonical numerical way.
The meaning of "canonical" is implementation dependent. If type(obj) implements
the ``.base`` attribute, then a reasonable implementation will probably satisfy
np.allclose(obj.base, T.numpy()).
Optional args
-------------
torch_handle can be None or it can be a reference to torch as a Python package
(analogous to the variable "np" after we do "import numpy as np"). If it's none
then we'll import torch as the first step of this function.
"""
raise NotImplementedError()


def _compose_gpindices(parent_gpindices, child_gpindices):
Expand Down
19 changes: 9 additions & 10 deletions pygsti/modelmembers/operations/fulltpop.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,12 @@
from pygsti.modelmembers.operations.denseop import DenseOperator as _DenseOperator
from pygsti.modelmembers.operations.linearop import LinearOperator as _LinearOperator
from pygsti.baseobjs.protectedarray import ProtectedArray as _ProtectedArray
from typing import Tuple, Optional, TypeVar
Tensor = TypeVar('Tensor') # torch.tensor.
from pygsti.modelmembers.torchable import Torchable as _Torchable
from typing import Tuple


class FullTPOp(_DenseOperator):

class FullTPOp(_DenseOperator, _Torchable):
"""
A trace-preserving operation matrix.
Expand Down Expand Up @@ -157,19 +158,17 @@ def from_vector(self, v, close=False, dirty_value=True):
self._ptr_has_changed() # because _rep.base == _ptr (same memory)
self.dirty = dirty_value

def stateless_data(self):
def stateless_data(self) -> Tuple[int]:
return (self.dim,)

@staticmethod
def torch_base(sd: Tuple[int], t_param: Tensor, torch_handle=None):
if torch_handle is None:
import torch as torch_handle

def torch_base(sd: Tuple[int], t_param: _Torchable.Tensor) -> _Torchable.Tensor:
torch = _Torchable.torch_handle
dim = sd[0]
t_const = torch_handle.zeros(size=(1, dim), dtype=torch_handle.double)
t_const = torch.zeros(size=(1, dim), dtype=torch.double)
t_const[0,0] = 1.0
t_param_mat = t_param.reshape((dim - 1, dim))
t = torch_handle.row_stack((t_const, t_param_mat))
t = torch.row_stack((t_const, t_param_mat))
return t


Expand Down
19 changes: 8 additions & 11 deletions pygsti/modelmembers/povms/tppovm.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,13 @@
#***************************************************************************************************

import numpy as _np
from pygsti.modelmembers.torchable import Torchable as _Torchable
from pygsti.modelmembers.povms.basepovm import _BasePOVM
from pygsti.modelmembers.povms.effect import POVMEffect as _POVMEffect
from pygsti.modelmembers.povms.fulleffect import FullPOVMEffect as _FullPOVMEffect
from pygsti.modelmembers.povms.conjugatedeffect import ConjugatedStatePOVMEffect as _ConjugatedStatePOVMEffect
from typing import Tuple, Optional, TypeVar
Tensor = TypeVar('Tensor') # torch.tensor.
from typing import Tuple


class TPPOVM(_BasePOVM):
class TPPOVM(_BasePOVM, _Torchable):
"""
A POVM whose sum-of-effects is constrained to what, by definition, we call the "identity".
Expand Down Expand Up @@ -78,19 +76,18 @@ def to_vector(self):
vec = _np.concatenate(effect_vecs)
return vec

def stateless_data(self):
def stateless_data(self) -> Tuple[int, int]:
dim1 = len(self)
dim2 = self.dim
return (dim1, dim2)

@staticmethod
def torch_base(sd: Tuple[int, int], t_param: Tensor, torch_handle=None):
if torch_handle is None:
import torch as torch_handle
def torch_base(sd: Tuple[int, int], t_param: _Torchable.Tensor) -> _Torchable.Tensor:
torch = _Torchable.torch_handle
num_effects, dim = sd
first_basis_vec = torch_handle.zeros(size=(1, dim), dtype=torch_handle.double)
first_basis_vec = torch.zeros(size=(1, dim), dtype=torch.double)
first_basis_vec[0,0] = dim ** 0.25
t_param_mat = t_param.reshape((num_effects - 1, dim))
t_func = first_basis_vec - t_param_mat.sum(axis=0, keepdim=True)
t = torch_handle.row_stack((t_param_mat, t_func))
t = torch.row_stack((t_param_mat, t_func))
return t
18 changes: 8 additions & 10 deletions pygsti/modelmembers/states/tpstate.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,14 @@

from pygsti.baseobjs import Basis as _Basis
from pygsti.baseobjs import statespace as _statespace
from pygsti.modelmembers.torchable import Torchable as _Torchable
from pygsti.modelmembers.states.densestate import DenseState as _DenseState
from pygsti.modelmembers.states.state import State as _State
from pygsti.baseobjs.protectedarray import ProtectedArray as _ProtectedArray
from typing import Tuple, Optional, TypeVar
Tensor = TypeVar('Tensor') # torch.tensor.
from typing import Tuple


class TPState(_DenseState):
class TPState(_DenseState, _Torchable):
"""
A fixed-unit-trace state vector.
Expand Down Expand Up @@ -160,17 +160,15 @@ def from_vector(self, v, close=False, dirty_value=True):
self._ptr_has_changed()
self.dirty = dirty_value

def stateless_data(self):
def stateless_data(self) -> Tuple[int]:
return (self.dim,)

@staticmethod
def torch_base(sd: Tuple[int], t_param: Tensor, torch_handle=None):
if torch_handle is None:
import torch as torch_handle

def torch_base(sd: Tuple[int], t_param: _Torchable.Tensor) -> _Torchable.Tensor:
torch = _Torchable.torch_handle
dim = sd[0]
t_const = (dim ** -0.25) * torch_handle.ones(1, dtype=torch_handle.double)
t = torch_handle.concat((t_const, t_param))
t_const = (dim ** -0.25) * torch.ones(1, dtype=torch.double)
t = torch.concat((t_const, t_param))
return t

def deriv_wrt_params(self, wrt_filter=None):
Expand Down
44 changes: 44 additions & 0 deletions pygsti/modelmembers/torchable.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
from pygsti.modelmembers.modelmember import ModelMember
from typing import TypeVar, Tuple

try:
import torch
torch_handle = torch
Tensor = torch.Tensor
except ImportError:
torch_handle = None
Tensor = TypeVar('Tensor') # we'll access this for type annotations elsewhere.


class Torchable(ModelMember):

Tensor = Tensor
torch_handle = torch

def stateless_data(self) -> Tuple:
"""
Return the data of this model that is considered considered constant for purposes
of model fitting.
Note: the word "stateless" here is used in the sense of object-oriented programming.
"""
raise NotImplementedError()

@staticmethod
def torch_base(sd : Tuple, t_param : Tensor) -> Tensor:
"""
Suppose "obj" is an instance of some ModelMember subclass. If we compute
sd = obj.stateless_data()
vec = obj.to_vector()
t_param = torch.from_numpy(vec)
T = type(obj).torch_base(sd, t_param, grad)
then T will be a PyTorch Tensor that represents "obj" in a canonical numerical way.
The meaning of "canonical" is implementation dependent. If type(obj) implements
the ``.base`` attribute, then a reasonable implementation will probably satisfy
np.allclose(obj.base, T.numpy()).
"""
raise NotImplementedError()
1 change: 0 additions & 1 deletion test/unit/objects/test_forwardsim.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import numpy as np
import pytest

from pygsti.models import modelconstruction as _setc
import pygsti.models as models
from pygsti.forwardsims import ForwardSimulator, \
MapForwardSimulator, SimpleMapForwardSimulator, \
Expand Down

0 comments on commit f5383b9

Please sign in to comment.