Skip to content

Commit

Permalink
float32
Browse files Browse the repository at this point in the history
  • Loading branch information
mdoucet committed Jan 18, 2024
1 parent ce90a04 commit 71481f2
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 11 deletions.
10 changes: 7 additions & 3 deletions example/sphere_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,18 @@

import sasmodels.kerneltorch as kt

device = torch.device('mps')

def make_kernel(model, q_vectors):
"""Instantiate the python kernel with input *q_vectors*"""
q_input = kt.PyInput(q_vectors, dtype=torch.float64)
q_input = kt.PyInput(q_vectors, dtype=torch.float32)
return kt.PyKernel(model.info, q_input)


model = load_model('_spherepy')
print(model.info)

q = torch.logspace(-3, -1, 200)
q = torch.logspace(-3, -1, 200).to(device)


#qq = logspace(-3, -1, 200)

Expand All @@ -49,3 +51,5 @@ def make_kernel(model, q_vectors):
Iq = call_kernel(kernel, pars)
elapsed = time.time() - t0
print('Computation time:', elapsed)

print(Iq)
4 changes: 3 additions & 1 deletion example/sphere_pytorch_prototype.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@

import sasmodels.kerneltorch as kt


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device = torch.device('mps')

print("device",device)

Expand All @@ -24,7 +26,7 @@ def make_kernel(model, q_vectors,device):
print("q",q[6])
kernel = model.make_kernel([q])

pars = {'radius': 200, 'radius_pd': 0.1, 'radius_pd_n':100, 'scale': 2}
pars = {'radius': 200, 'radius_pd': 0.1, 'radius_pd_n':10000, 'scale': 2}

t_before = time.time()
Iq = call_kernel(kernel, pars)
Expand Down
14 changes: 7 additions & 7 deletions sasmodels/kerneltorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def __init__(self, q_vectors, dtype):
self.q[:, 1] = q_vectors[1]
else:
# Create empty tensor
self.q = torch.tensor(np.empty(self.nq, dtype=np.float64))
self.q = torch.tensor(np.empty(self.nq, dtype=np.float32))
self.q[:self.nq] = q_vectors[0]

def release(self):
Expand Down Expand Up @@ -124,7 +124,7 @@ def __init__(self, model_info, q_input, device):
self.dtype = np.dtype('d')
self.info = model_info
self.q_input = q_input
self.res = np.empty(q_input.nq, np.float64)
self.res = np.empty(q_input.nq, np.float32)
self.dim = '2d' if q_input.is_2d else '1d'

partable = model_info.parameters
Expand Down Expand Up @@ -224,10 +224,10 @@ def _loops(parameters, form, form_volume, form_radius, nq, call_details,
# mesh, we update the components with the polydispersity values before
# calling the respective functions.
n_pars = len(parameters)
parameters = torch.DoubleTensor(parameters).to(device)
parameters = torch.tensor(parameters, dtype=torch.float32).to(device)

#parameters[:] = values[2:n_pars+2]
parameters[:] = torch.DoubleTensor(values[2:n_pars+2])
parameters[:] = torch.tensor(values[2:n_pars+2], dtype=torch.float32)

print("parameters",parameters)
if call_details.num_active == 0:
Expand All @@ -238,8 +238,8 @@ def _loops(parameters, form, form_volume, form_radius, nq, call_details,

else:
#transform to tensor flow
pd_value = torch.DoubleTensor(values[2+n_pars:2+n_pars + call_details.num_weights])
pd_weight = torch.DoubleTensor(values[2+n_pars + call_details.num_weights:])
pd_value = torch.tensor(values[2+n_pars:2+n_pars + call_details.num_weights], dtype=torch.float32)
pd_weight = torch.tensor(values[2+n_pars + call_details.num_weights:], dtype=torch.float32)

#print("pd_value",pd_value)
#print("pd_weight",pd_weight)
Expand All @@ -262,7 +262,7 @@ def _loops(parameters, form, form_volume, form_radius, nq, call_details,
pd_length = call_details.pd_length[:call_details.num_active]

#total = np.zeros(nq, np.float64)
total = torch.zeros(nq, dtype= torch.float64).to(device)
total = torch.zeros(nq, dtype= torch.float32).to(device)

#print("ll", range(call_details.num_eval))
#parallel for loop
Expand Down

0 comments on commit 71481f2

Please sign in to comment.