From 7c24264e84b1e0dea1852a352ac8fef4a835678d Mon Sep 17 00:00:00 2001 From: Mathieu Doucet Date: Fri, 19 Jan 2024 13:29:07 -0500 Subject: [PATCH] Add funny kernel --- example/sphere_pytorch.py | 86 ++++++++++++++++++++++++--------------- sasmodels/kerneltorch.py | 76 ++++++++++++++++++++++++++++++++-- 2 files changed, 127 insertions(+), 35 deletions(-) diff --git a/example/sphere_pytorch.py b/example/sphere_pytorch.py index 20945cca..590219fa 100644 --- a/example/sphere_pytorch.py +++ b/example/sphere_pytorch.py @@ -1,55 +1,77 @@ -""" -Minimal example of calling a kernel for a specific set of q values. - - npts = values.pop(parameter.name+'_pd_n', 0) - width = values.pop(parameter.name+'_pd', 0.0) - nsigma = values.pop(parameter.name+'_pd_nsigma', 3.0) - distribution = values.pop(parameter.name+'_pd_type', 'gaussian') - -""" -import time - import torch - +import time from numpy import logspace, sqrt from matplotlib import pyplot as plt from sasmodels.core import load_model -from sasmodels.direct_model import call_kernel, get_mesh -from sasmodels.details import make_kernel_args, dispersion_mesh +from sasmodels.direct_model import call_kernel,get_mesh +from sasmodels.details import make_kernel_args import sasmodels.kerneltorch as kt -device = torch.device('mps') +device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') +#device = torch.device('mps') -def make_kernel(model, q_vectors): +print("device",device) + +def make_kernel(model, q_vectors,device, new=False): """Instantiate the python kernel with input *q_vectors*""" - q_input = kt.PyInput(q_vectors, dtype=torch.float32) - return kt.PyKernel(model.info, q_input) + q_input = kt.PyInput(q_vectors, dtype=torch.double) + if new: + return kt.FunnyKernel(model.info, q_input, device = device) + else: + return kt.PyKernel(model.info, q_input, device = device) model = load_model('_spherepy') +q = logspace(-3, -1, 200) +print("q",q[6]) +kernel = model.make_kernel([q]) -q = torch.logspace(-3, -1, 200).to(device) +pars = {'radius': 200, 'radius_pd': 0.1, 'radius_pd_n':1000, 'sld':2, 'sld_pd': 0.1, 'sld_pd_n':100, 'scale': 2, 'sld_solvent':1} +pars = {'radius': 200, 'sld':2, 'scale': 2, 'sld_solvent':3} +# Original +t_before = time.time() +Iq = call_kernel(kernel, pars) +t_after = time.time() +total_np = t_after -t_before +print("Iq",Iq[6]) +print("Tota Numpy: ",total_np) -#qq = logspace(-3, -1, 200) +# PyTorch +t_before = time.time() +q_t = torch.logspace(start=-3, end=-1, steps=200).to(device) +kernel = make_kernel(model, [q_t],device) +Iq_t = call_kernel(kernel, pars) +print("Iq_t",Iq_t[6]) -kernel = make_kernel(model, [q]) +kernel = make_kernel(model, [q_t],device, new=True) +Iq_t2 = kernel.Iq([pars['sld'], pars['sld_solvent'], pars['radius']], scale=pars['scale'], background=0) +#Iq_t2 = call_kernel(kernel, pars) +print("Iq_t",Iq_t[6]) +print("Iq_t2", Iq_t2[6]) -pars = {'radius': 200, 'radius_pd': 0.2, 'radius_pd_n':10000, 'scale': 2} -#mesh = get_mesh(kernel.info, pars, dim=kernel.dim) -#print(mesh) -#call_details, values, is_magnetic = make_kernel_args(kernel, mesh) -#print(call_details) -#print(values) -t0 = time.time() -Iq = call_kernel(kernel, pars) -elapsed = time.time() - t0 -print('Computation time:', elapsed) +# call_kernel unwrap +#calculator = kernel +#cutoff=0. +#mono=False + +#mesh = get_mesh(calculator.info, pars, dim=calculator.dim, mono=mono) +#print("in call_kernel: pars:", list(zip(*mesh))[0]) +#call_details, values, is_magnetic = make_kernel_args(calculator, mesh) +#print("in call_kernel: values:", values) +#Iq_t = calculator(call_details, values, cutoff, is_magnetic) + +t_after = time.time() +total_torch = t_after -t_before + + + +print("Total Pytorch: ",total_torch) + -print(Iq) diff --git a/sasmodels/kerneltorch.py b/sasmodels/kerneltorch.py index 4e6e0c07..c2a8d36a 100644 --- a/sasmodels/kerneltorch.py +++ b/sasmodels/kerneltorch.py @@ -97,6 +97,71 @@ def release(self): self.q = None +class FunnyKernel(Kernel): + def __init__(self, model_info, q_input, device): + self.device = device + self.dtype = np.dtype('d') + self.info = model_info + self.q_input = q_input + self.dim = '2d' if q_input.is_2d else '1d' + + + def Iq(self, pars, scale, background): + # type: (CallDetails, np.ndarray, np.ndarray, float, bool) -> np.ndarray + _, F2, _, shell_volume, _ = self.Fq(pars) + combined_scale = scale/shell_volume + background = background + return combined_scale*F2 + background + __call__ = Iq + + def Fq(self, pars): + # type: (CallDetails, np.ndarray, np.ndarray, float, bool, int) -> np.ndarray + + self._call_kernel(pars) + #print("returned",self.q_input.q, self.result) + nout = 2 if self.info.have_Fq and self.dim == '1d' else 1 + total_weight = self.result[nout*self.q_input.nq + 0] + # Note: total_weight = sum(weight > cutoff), with cutoff >= 0, so it + # is okay to test directly against zero. If weight is zero then I(q), + # etc. must also be zero. + if total_weight == 0.: + total_weight = 1. + # Note: shell_volume == form_volume for solid objects + form_volume = self.result[nout*self.q_input.nq + 1]/total_weight + shell_volume = self.result[nout*self.q_input.nq + 2]/total_weight + radius_effective = self.result[nout*self.q_input.nq + 3]/total_weight + if shell_volume == 0.: + shell_volume = 1. + F1 = (self.result[1:nout*self.q_input.nq:nout]/total_weight + if nout == 2 else None) + F2 = self.result[0:nout*self.q_input.nq:nout]/total_weight + return F1, F2, radius_effective, shell_volume, form_volume/shell_volume + + def release(self): + # type: () -> None + """ + Free resources associated with the kernel instance. + """ + #print("null release kernel") + pass + + def _call_kernel(self, kernel_args): + _form = self.info.Iqxy if self.q_input.is_2d else self.info.Iq + + #kernel_args = [parameters[i] for i in kernel_idx] + q_values = torch.DoubleTensor(self.q_input.q) + total = _form(q_values, *kernel_args) + weight_norm = 1.0 + + volume = self.info.form_volume + + weighted_shell, weighted_form = volume(1, 200), volume(1, 200) + weighted_radius = 0 + + self.result = np.hstack((total, weight_norm, weighted_form, weighted_shell, weighted_radius)) + + + class PyKernel(Kernel): """ Callable SAS kernel. @@ -175,6 +240,7 @@ def __init__(self, model_info, q_input, device): else (lambda mode: 1.0)) def _call_kernel(self, call_details, values, cutoff, magnetic, radius_effective_mode): + print("VALUES", values) # type: (CallDetails, np.ndarray, np.ndarray, float, bool) -> None if magnetic: raise NotImplementedError("Magnetism not implemented for pure python models") @@ -223,13 +289,17 @@ def _loops(parameters, kernel_idx, form, form_volume, form_radius, q_input, call parameters[:] = torch.DoubleTensor(values[2:n_pars+2]).to(device) print("parameters",parameters) + print("num active", call_details.num_active) if call_details.num_active == 0: kernel_args = [parameters[i] for i in kernel_idx] + print(kernel_args) total = form(q_values, *kernel_args) weight_norm = 1.0 weighted_shell, weighted_form = form_volume() weighted_radius = form_radius() - + print(weighted_shell, weighted_form) + print(weighted_radius) + result = np.hstack((total, weight_norm, weighted_form, weighted_shell, weighted_radius)) else: #transform to tensor flow pd_value = torch.DoubleTensor(values[2+n_pars:2+n_pars + call_details.num_weights]).to(device) @@ -296,8 +366,8 @@ def _loops(parameters, kernel_idx, form, form_volume, form_radius, q_input, call weighted_form += weight * unweighted_form weighted_radius += weight * form_radius() - #result = np.hstack((total, weight_norm, weighted_form, weighted_shell, weighted_radius)) - result = torch.hstack((total, weight_norm, weighted_form, weighted_shell, weighted_radius)).to(device) + #result = np.hstack((total, weight_norm, weighted_form, weighted_shell, weighted_radius)) + result = torch.hstack((total, weight_norm, weighted_form, weighted_shell, weighted_radius)).to(device) return result