diff --git a/src/pyhf/constraints.py b/src/pyhf/constraints.py
index 882a221041..59d6b56469 100644
--- a/src/pyhf/constraints.py
+++ b/src/pyhf/constraints.py
@@ -195,10 +195,13 @@ def __init__(self, pdfconfig, batch_size=None):
             self._batched_factors = default_backend.tile(
                 factors, (self.batch_size or 1, 1)
             )
-
-            access_field = default_backend.concatenate(
-                self.param_viewer.index_selection, axis=1
-            )
+            print('ok')
+            try:
+                selection = [x.cpu().numpy() for x in self.param_viewer.index_selection]
+            except AttributeError:
+                selection = [x for x in self.param_viewer.index_selection]
+            print('selection', selection)
+            access_field = default_backend.concatenate(selection, axis=1)
             self._access_field = access_field

         self._precompute()
diff --git a/src/pyhf/modifiers/shapesys.py b/src/pyhf/modifiers/shapesys.py
index 34e762e4f7..e89da07f86 100644
--- a/src/pyhf/modifiers/shapesys.py
+++ b/src/pyhf/modifiers/shapesys.py
@@ -101,6 +101,11 @@ def _reindex_access_field(self, pdfconfig):
                 )

                 sample_mask = self._shapesys_mask[syst_index][singular_sample_index][0]
+                try:
+                    selection = selection.cpu().numpy()
+                except AttributeError:
+                    pass
+                print(sample_mask, access_field_for_syst_and_batch, selection)
                 access_field_for_syst_and_batch[sample_mask] = selection
                 self._access_field[
                     syst_index, batch_index
diff --git a/src/pyhf/tensor/pytorch_backend.py b/src/pyhf/tensor/pytorch_backend.py
index cded58bd87..7ab4dd5842 100644
--- a/src/pyhf/tensor/pytorch_backend.py
+++ b/src/pyhf/tensor/pytorch_backend.py
@@ -4,6 +4,7 @@
 from torch.distributions.utils import broadcast_all
 import logging
 import math
+import torch_xla.core.xla_model as xm

 log = logging.getLogger(__name__)

@@ -177,7 +178,7 @@ def astensor(self, tensor_in, dtype='float'):
             )
             raise

-        return torch.as_tensor(tensor_in, dtype=dtype)
+        return torch.as_tensor(tensor_in, dtype=dtype, device=xm.xla_device())

     def gather(self, tensor, indices):
         return tensor[indices.type(torch.LongTensor)]

@@ -225,10 +226,10 @@ def abs(self, tensor):
         return torch.abs(tensor)

     def ones(self, shape):
-        return torch.ones(shape, dtype=self.dtypemap['float'])
+        return torch.ones(shape, dtype=self.dtypemap['float'], device=xm.xla_device())

     def zeros(self, shape):
-        return torch.zeros(shape, dtype=self.dtypemap['float'])
+        return torch.zeros(shape, dtype=self.dtypemap['float'], device=xm.xla_device())

     def power(self, tensor_in_1, tensor_in_2):
         return torch.pow(tensor_in_1, tensor_in_2)
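
For context, here is a minimal standalone sketch (not part of the patch) of why the `.cpu().numpy()` fallback in `constraints.py` and `shapesys.py` is needed: `default_backend` in pyhf is NumPy, and NumPy cannot consume tensors that live on an XLA (TPU) device, so index selections produced by the pytorch backend must be copied back to host memory before `concatenate` runs. The `ImportError` fallback and the sample values are assumptions for illustration only:

```python
import numpy as np
import torch

try:
    # torch_xla is only available in XLA-enabled environments (e.g. TPU VMs)
    import torch_xla.core.xla_model as xm
    device = xm.xla_device()
except ImportError:
    device = torch.device("cpu")  # assumption: CPU fallback so the sketch runs anywhere

index_selection = [
    torch.as_tensor([[0, 1]], device=device),
    torch.as_tensor([[2, 3]], device=device),
]

# Mirrors the try/except in the patched constraints.py: torch tensors expose
# .cpu().numpy(), while plain NumPy arrays raise AttributeError and pass through.
try:
    selection = [x.cpu().numpy() for x in index_selection]
except AttributeError:
    selection = list(index_selection)

access_field = np.concatenate(selection, axis=1)
print(access_field)  # [[0 1 2 3]]
```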
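
And a sketch of how the `pytorch_backend.py` changes could avoid a hard `torch_xla` import at module load: resolve the target device once, falling back to CPU when `torch_xla` is absent. The module-level `DEVICE` constant and the standalone functions below are hypothetical, not pyhf API; the patch itself assumes an XLA environment:

```python
import torch

try:
    import torch_xla.core.xla_model as xm
    DEVICE = xm.xla_device()  # XLA/TPU device when torch_xla is installed
except ImportError:
    DEVICE = torch.device("cpu")  # hypothetical fallback; the patch requires XLA

def astensor(tensor_in, dtype=torch.float64):
    # counterpart of the patched astensor: materialize directly on DEVICE
    return torch.as_tensor(tensor_in, dtype=dtype, device=DEVICE)

def zeros(shape):
    # counterpart of the patched zeros/ones: allocate on DEVICE up front
    return torch.zeros(shape, dtype=torch.float64, device=DEVICE)

print(astensor([1.0, 2.0]).device, zeros((2, 3)).device)
```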