Skip to content

Commit

Permalink
Add funny kernel
Browse files Browse the repository at this point in the history
  • Loading branch information
mdoucet committed Jan 19, 2024
1 parent 535165b commit 7c24264
Show file tree
Hide file tree
Showing 2 changed files with 127 additions and 35 deletions.
86 changes: 54 additions & 32 deletions example/sphere_pytorch.py
Original file line number Diff line number Diff line change
@@ -1,55 +1,77 @@
"""
Minimal example of calling a kernel for a specific set of q values.
npts = values.pop(parameter.name+'_pd_n', 0)
width = values.pop(parameter.name+'_pd', 0.0)
nsigma = values.pop(parameter.name+'_pd_nsigma', 3.0)
distribution = values.pop(parameter.name+'_pd_type', 'gaussian')
"""
import time

import torch

import time
from numpy import logspace, sqrt
from matplotlib import pyplot as plt
from sasmodels.core import load_model
from sasmodels.direct_model import call_kernel, get_mesh
from sasmodels.details import make_kernel_args, dispersion_mesh
from sasmodels.direct_model import call_kernel,get_mesh
from sasmodels.details import make_kernel_args

import sasmodels.kerneltorch as kt

device = torch.device('mps')
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
#device = torch.device('mps')

def make_kernel(model, q_vectors):
print("device",device)

def make_kernel(model, q_vectors,device, new=False):
"""Instantiate the python kernel with input *q_vectors*"""
q_input = kt.PyInput(q_vectors, dtype=torch.float32)
return kt.PyKernel(model.info, q_input)
q_input = kt.PyInput(q_vectors, dtype=torch.double)
if new:
return kt.FunnyKernel(model.info, q_input, device = device)
else:
return kt.PyKernel(model.info, q_input, device = device)


model = load_model('_spherepy')
q = logspace(-3, -1, 200)
print("q",q[6])
kernel = model.make_kernel([q])

q = torch.logspace(-3, -1, 200).to(device)
pars = {'radius': 200, 'radius_pd': 0.1, 'radius_pd_n':1000, 'sld':2, 'sld_pd': 0.1, 'sld_pd_n':100, 'scale': 2, 'sld_solvent':1}
pars = {'radius': 200, 'sld':2, 'scale': 2, 'sld_solvent':3}

# Original
t_before = time.time()
Iq = call_kernel(kernel, pars)
t_after = time.time()
total_np = t_after -t_before
print("Iq",Iq[6])
print("Tota Numpy: ",total_np)

#qq = logspace(-3, -1, 200)
# PyTorch
t_before = time.time()
q_t = torch.logspace(start=-3, end=-1, steps=200).to(device)
kernel = make_kernel(model, [q_t],device)
Iq_t = call_kernel(kernel, pars)
print("Iq_t",Iq_t[6])

kernel = make_kernel(model, [q])


kernel = make_kernel(model, [q_t],device, new=True)
Iq_t2 = kernel.Iq([pars['sld'], pars['sld_solvent'], pars['radius']], scale=pars['scale'], background=0)
#Iq_t2 = call_kernel(kernel, pars)
print("Iq_t",Iq_t[6])
print("Iq_t2", Iq_t2[6])

pars = {'radius': 200, 'radius_pd': 0.2, 'radius_pd_n':10000, 'scale': 2}

#mesh = get_mesh(kernel.info, pars, dim=kernel.dim)
#print(mesh)

#call_details, values, is_magnetic = make_kernel_args(kernel, mesh)
#print(call_details)
#print(values)

t0 = time.time()
Iq = call_kernel(kernel, pars)
elapsed = time.time() - t0
print('Computation time:', elapsed)
# call_kernel unwrap
#calculator = kernel
#cutoff=0.
#mono=False

#mesh = get_mesh(calculator.info, pars, dim=calculator.dim, mono=mono)
#print("in call_kernel: pars:", list(zip(*mesh))[0])
#call_details, values, is_magnetic = make_kernel_args(calculator, mesh)
#print("in call_kernel: values:", values)
#Iq_t = calculator(call_details, values, cutoff, is_magnetic)

t_after = time.time()
total_torch = t_after -t_before



print("Total Pytorch: ",total_torch)


print(Iq)
76 changes: 73 additions & 3 deletions sasmodels/kerneltorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,71 @@ def release(self):
self.q = None


class FunnyKernel(Kernel):
def __init__(self, model_info, q_input, device):
self.device = device
self.dtype = np.dtype('d')
self.info = model_info
self.q_input = q_input
self.dim = '2d' if q_input.is_2d else '1d'


def Iq(self, pars, scale, background):
# type: (CallDetails, np.ndarray, np.ndarray, float, bool) -> np.ndarray
_, F2, _, shell_volume, _ = self.Fq(pars)
combined_scale = scale/shell_volume
background = background
return combined_scale*F2 + background
__call__ = Iq

def Fq(self, pars):
# type: (CallDetails, np.ndarray, np.ndarray, float, bool, int) -> np.ndarray

self._call_kernel(pars)
#print("returned",self.q_input.q, self.result)
nout = 2 if self.info.have_Fq and self.dim == '1d' else 1
total_weight = self.result[nout*self.q_input.nq + 0]
# Note: total_weight = sum(weight > cutoff), with cutoff >= 0, so it
# is okay to test directly against zero. If weight is zero then I(q),
# etc. must also be zero.
if total_weight == 0.:
total_weight = 1.
# Note: shell_volume == form_volume for solid objects
form_volume = self.result[nout*self.q_input.nq + 1]/total_weight
shell_volume = self.result[nout*self.q_input.nq + 2]/total_weight
radius_effective = self.result[nout*self.q_input.nq + 3]/total_weight
if shell_volume == 0.:
shell_volume = 1.
F1 = (self.result[1:nout*self.q_input.nq:nout]/total_weight
if nout == 2 else None)
F2 = self.result[0:nout*self.q_input.nq:nout]/total_weight
return F1, F2, radius_effective, shell_volume, form_volume/shell_volume

def release(self):
# type: () -> None
"""
Free resources associated with the kernel instance.
"""
#print("null release kernel")
pass

def _call_kernel(self, kernel_args):
_form = self.info.Iqxy if self.q_input.is_2d else self.info.Iq

#kernel_args = [parameters[i] for i in kernel_idx]
q_values = torch.DoubleTensor(self.q_input.q)
total = _form(q_values, *kernel_args)
weight_norm = 1.0

volume = self.info.form_volume

weighted_shell, weighted_form = volume(1, 200), volume(1, 200)
weighted_radius = 0

self.result = np.hstack((total, weight_norm, weighted_form, weighted_shell, weighted_radius))



class PyKernel(Kernel):
"""
Callable SAS kernel.
Expand Down Expand Up @@ -175,6 +240,7 @@ def __init__(self, model_info, q_input, device):
else (lambda mode: 1.0))

def _call_kernel(self, call_details, values, cutoff, magnetic, radius_effective_mode):
print("VALUES", values)
# type: (CallDetails, np.ndarray, np.ndarray, float, bool) -> None
if magnetic:
raise NotImplementedError("Magnetism not implemented for pure python models")
Expand Down Expand Up @@ -223,13 +289,17 @@ def _loops(parameters, kernel_idx, form, form_volume, form_radius, q_input, call
parameters[:] = torch.DoubleTensor(values[2:n_pars+2]).to(device)

print("parameters",parameters)
print("num active", call_details.num_active)
if call_details.num_active == 0:
kernel_args = [parameters[i] for i in kernel_idx]
print(kernel_args)
total = form(q_values, *kernel_args)
weight_norm = 1.0
weighted_shell, weighted_form = form_volume()
weighted_radius = form_radius()

print(weighted_shell, weighted_form)
print(weighted_radius)
result = np.hstack((total, weight_norm, weighted_form, weighted_shell, weighted_radius))
else:
#transform to tensor flow
pd_value = torch.DoubleTensor(values[2+n_pars:2+n_pars + call_details.num_weights]).to(device)
Expand Down Expand Up @@ -296,8 +366,8 @@ def _loops(parameters, kernel_idx, form, form_volume, form_radius, q_input, call
weighted_form += weight * unweighted_form
weighted_radius += weight * form_radius()

#result = np.hstack((total, weight_norm, weighted_form, weighted_shell, weighted_radius))
result = torch.hstack((total, weight_norm, weighted_form, weighted_shell, weighted_radius)).to(device)
#result = np.hstack((total, weight_norm, weighted_form, weighted_shell, weighted_radius))
result = torch.hstack((total, weight_norm, weighted_form, weighted_shell, weighted_radius)).to(device)

return result

Expand Down

0 comments on commit 7c24264

Please sign in to comment.