
Commit 66cd96c

feat: overwriting self.np.zeros + fix cast in PSR

1 parent: 5cf909c

File tree: 3 files changed (+24 / -16 lines)

src/qiboml/backends/pytorch.py (+12 / -10)

@@ -71,6 +71,18 @@ def __init__(self):
         self.np.sign = self.np.sgn
         self.np.flatnonzero = lambda x: self.np.nonzero(x).flatten()

+        # These functions are device dependent
+        torch_zeros = self.np.zeros
+
+        def zeros(shape, dtype=None, device=None):
+            if dtype is None:
+                dtype = self.dtype
+            if device is None:
+                device = self.device
+            return torch_zeros(shape, dtype=dtype, device=device)
+
+        setattr(self.np, "zeros", zeros)
+
     def _torch_dtype(self, dtype):
         if dtype == "float":
             dtype += "32"

@@ -178,16 +190,6 @@ def _cast_parameter(self, x, trainable):
             x, dtype=self.dtype, requires_grad=trainable, device=self.device
         )

-    def zero_state(self, nqubits):
-        state = self.np.zeros(2**nqubits, dtype=self.dtype, device=self.device)
-        state[0] = 1
-        return state
-
-    def zero_density_matrix(self, nqubits):
-        state = self.np.zeros(2 * (2**nqubits,), dtype=self.dtype, device=self.device)
-        state[0, 0] = 1
-        return state
-
     def is_sparse(self, x):
         if isinstance(x, self.np.Tensor):
             return x.is_sparse
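The new zeros override is a wrap-and-delegate monkey-patch: it keeps a handle to the original torch.zeros, then installs a thin wrapper that fills in the backend's dtype and device as defaults. With those defaults baked in, the explicit zero_state and zero_density_matrix overrides presumably become redundant, which would explain their removal in the second hunk. A minimal standalone sketch of the pattern, where TinyBackend and its attributes are illustrative stand-ins rather than the actual qiboml class:

    import torch

    class TinyBackend:
        """Illustrative stand-in for qiboml's PyTorch backend (hypothetical)."""

        def __init__(self, dtype=torch.complex128, device="cpu"):
            self.np = torch  # the backend aliases torch as self.np
            self.dtype = dtype
            self.device = device

            # Keep a handle to the original function before patching it.
            torch_zeros = self.np.zeros

            def zeros(shape, dtype=None, device=None):
                # Fall back to the backend's defaults when not given.
                if dtype is None:
                    dtype = self.dtype
                if device is None:
                    device = self.device
                return torch_zeros(shape, dtype=dtype, device=device)

            # Note: self.np is the torch module itself, so this patch is global.
            setattr(self.np, "zeros", zeros)

    backend = TinyBackend()
    state = backend.np.zeros(2**3)  # no explicit kwargs needed anymore
    print(state.dtype, state.device)  # torch.complex128 cpu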

src/qiboml/interfaces/pytorch.py (+11 / -5)

@@ -150,14 +150,17 @@ def forward(
         ctx.backend = backend
         ctx.differentiation = differentiation
         x_clone = x.clone().detach().cpu().numpy()
-        x_clone = backend.cast(x_clone)
+        x_clone = backend.cast(x_clone, dtype=x_clone.dtype)
         params = [
-            backend.cast(par.clone().detach().cpu().numpy()) for par in parameters
+            backend.cast(par.clone().detach().cpu().numpy(), dtype=x_clone.dtype)
+            for par in parameters
         ]
         x_clone = encoding(x_clone) + circuit
         x_clone.set_parameters(params)
         x_clone = decoding(x_clone)
-        x_clone = torch.as_tensor(backend.to_numpy(x_clone).tolist()).to(x.device)
+        x_clone = torch.as_tensor(
+            backend.to_numpy(x_clone).tolist(), dtype=x.dtype, device=x.device
+        )
         return x_clone

@@ -166,11 +169,14 @@ def backward(ctx, grad_output: torch.Tensor):
         x_clone = x.clone().detach().cpu().numpy()
         x_clone = ctx.backend.cast(x_clone, dtype=x_clone.dtype)
         params = [
-            ctx.backend.cast(par.clone().detach().cpu().numpy()) for par in parameters
+            ctx.backend.cast(par.clone().detach().cpu().numpy(), dtype=x_clone.dtype)
+            for par in parameters
         ]
         wrt_inputs = not x.is_leaf and ctx.encoding.differentiable
         grad_input, *gradients = (
-            torch.as_tensor(ctx.backend.to_numpy(grad).tolist()).to(x.device)
+            torch.as_tensor(
+                ctx.backend.to_numpy(grad).tolist(), dtype=x.dtype, device=x.device
+            )
             for grad in ctx.differentiation.evaluate(
                 x_clone,
                 ctx.encoding,
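The explicit dtype= and device= arguments matter because .tolist() strips the numpy dtype: torch.as_tensor then infers torch's default dtype from plain Python floats, and the old .to(x.device) call only fixed the device. A small self-contained illustration of the failure mode (not qiboml code; the out array stands in for backend.to_numpy(result)):

    import torch

    x = torch.ones(3, dtype=torch.float64)
    out = x.clone().detach().cpu().numpy()  # stand-in for the backend result

    legacy = torch.as_tensor(out.tolist()).to(x.device)
    fixed = torch.as_tensor(out.tolist(), dtype=x.dtype, device=x.device)

    print(legacy.dtype)  # torch.float32 -- silently diverges from x
    print(fixed.dtype)   # torch.float64 -- matches the input tensor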

src/qiboml/operations/differentiation.py (+1 / -1)

@@ -131,7 +131,7 @@ def one_parameter_shift(
         generator_eigenval = gate.generator_eigenvalue()
         s = np.pi / (4 * generator_eigenval)

-        tmp_params = backend.cast(parameters, copy=True)
+        tmp_params = backend.cast(parameters, copy=True, dtype=parameters[0].dtype)
         tmp_params = self.shift_parameter(tmp_params, parameter_index, s, backend)

         circuit.set_parameters(tmp_params)
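For context on the PSR fix: one_parameter_shift uses the standard two-point parameter-shift rule. For a gate generated by G with eigenvalues ±r, the exact derivative is d⟨O⟩/dθ = r · [f(θ + s) − f(θ − s)] with s = π / (4r), matching the s computed above. A numeric sketch of the rule, with a toy expectation function standing in for the circuit evaluation (not qiboml code):

    import numpy as np

    def expectation(theta):
        # <Z> after RY(theta)|0>: the RY generator is Y/2, so r = 1/2.
        return np.cos(theta)

    r = 0.5
    s = np.pi / (4 * r)
    theta = 0.7

    # Two circuit evaluations give the exact derivative, not an approximation.
    psr = r * (expectation(theta + s) - expectation(theta - s))
    print(psr, -np.sin(theta))  # both ~ -0.6442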
