Skip to content

Commit

Permalink
Merge pull request #60 from trustimaging/filter-improvements
Browse files Browse the repository at this point in the history
Improve application of time trace filters and pre-conditioners
  • Loading branch information
ccuetom authored Oct 23, 2023
2 parents a77a500 + f832f1d commit 5b9c386
Show file tree
Hide file tree
Showing 9 changed files with 5,348 additions and 7,665 deletions.
17 changes: 17 additions & 0 deletions stride/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,9 +210,16 @@ async def adjoint(problem, pde, loss, optimisation_loop, optimiser, *args, **kwa

f_min = kwargs.pop('f_min', None)
f_max = kwargs.pop('f_max', None)
filter_wavelets_relaxation = kwargs.pop('filter_wavelets_relaxation', 0.25)
filter_traces_relaxation = kwargs.pop('filter_traces_relaxation', 0.75)
process_wavelets = ProcessWavelets.remote(f_min=f_min, f_max=f_max,
filter_relaxation=filter_wavelets_relaxation,
len=runtime.num_workers, **kwargs)
process_observed = ProcessObserved.remote(f_min=f_min, f_max=f_max,
filter_relaxation=filter_wavelets_relaxation,
len=runtime.num_workers, **kwargs)
process_traces = ProcessTraces.remote(f_min=f_min, f_max=f_max,
filter_relaxation=filter_traces_relaxation,
len=runtime.num_workers, **kwargs)

platform = kwargs.get('platform', 'cpu')
Expand Down Expand Up @@ -274,22 +281,32 @@ async def loop(worker, shot_id):
else:
raise ValueError('Unknown platform %s' % platform)

# pre-process wavelets and observed traces
wavelets = process_wavelets(wavelets,
problem=sub_problem, runtime=worker, **_kwargs)
await wavelets.init_future
observed = process_observed(observed,
problem=sub_problem, runtime=worker, **_kwargs)
await observed.init_future

# run PDE
modelled = pde(wavelets, *published_args,
problem=sub_problem, runtime=worker, **_kwargs)
await modelled.init_future

# post-process modelled and observed traces
traces = process_traces(modelled, observed,
problem=sub_problem, runtime=worker, **_kwargs)
await traces.init_future

# calculate loss
fun = await loss(traces.outputs[0], traces.outputs[1],
problem=sub_problem, runtime=worker, **_kwargs).result()

iteration.add_fun(fun)
logger.perf('Functional value for shot %d: %s' % (shot_id, fun))

# run adjoint
await fun.adjoint(**_kwargs)

logger.perf('Retrieved gradient for shot %d' % sub_problem.shot_id)
Expand Down
22 changes: 20 additions & 2 deletions stride/optimisation/pipelines/default_pipelines.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from .pipeline import Pipeline


__all__ = ['ProcessWavelets', 'ProcessTraces',
__all__ = ['ProcessWavelets', 'ProcessObserved', 'ProcessTraces',
'ProcessGlobalGradient', 'ProcessModelIteration']

# TODO Default configuration of pipelines should be better defined
Expand All @@ -17,7 +17,7 @@ class ProcessWavelets(Pipeline):
**Default steps:**
- ``filter_wavelets``
- ``filter_traces``
"""

Expand All @@ -27,6 +27,24 @@ def __init__(self, steps=None, no_grad=False, **kwargs):
if kwargs.pop('check_traces', True):
steps.append('check_traces')

steps.append('filter_traces')

super().__init__(steps, no_grad=no_grad, **kwargs)


@mosaic.tessera
class ProcessObserved(ProcessWavelets):
"""
Default pipeline to process observed data before running the forward problem.
**Default steps:**
- ``filter_traces``
"""

def __init__(self, steps=None, no_grad=False, **kwargs):
steps = steps or []
super().__init__(steps, no_grad=no_grad, **kwargs)


Expand Down
2 changes: 0 additions & 2 deletions stride/optimisation/pipelines/steps/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@

from .filter_wavelets import FilterWavelets
from .filter_traces import FilterTraces
from .norm_per_shot import NormPerShot
from .norm_per_trace import NormPerTrace
Expand All @@ -12,7 +11,6 @@


steps_registry = {
'filter_wavelets': FilterWavelets,
'filter_traces': FilterTraces,
'norm_per_shot': NormPerShot,
'norm_per_trace': NormPerTrace,
Expand Down
9 changes: 7 additions & 2 deletions stride/optimisation/pipelines/steps/filter_traces.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ class FilterTraces(Operator):
filter_type : str, optional
Type of filter to apply, from ``butterworth`` (default for band pass and high pass),
``fir``, or ``cos`` (default for low pass).
filter_relaxation : float, optional
Relaxation factor for the filter in range (0, 1], defaults to 1 (no dilation).
"""

Expand All @@ -28,6 +30,8 @@ def __init__(self, **kwargs):

self.filter_type = kwargs.pop('filter_type', None)

self.relaxation = kwargs.pop('filter_relaxation', 1.0)

self._num_traces = None

def forward(self, *traces, **kwargs):
Expand Down Expand Up @@ -63,9 +67,10 @@ def _apply(self, traces, **kwargs):

f_min = kwargs.pop('f_min', self.f_min)
f_max = kwargs.pop('f_max', self.f_max)
relaxation = kwargs.pop('filter_relaxation', self.relaxation)

f_min_dim_less = f_min*time.step if f_min is not None else 0
f_max_dim_less = f_max*time.step if f_max is not None else 0
f_min_dim_less = relaxation*f_min*time.step if f_min is not None else 0
f_max_dim_less = 1/relaxation*f_max*time.step if f_max is not None else 0

out_traces = traces.alike(name='filtered_%s' % traces.name)

Expand Down
71 changes: 0 additions & 71 deletions stride/optimisation/pipelines/steps/filter_wavelets.py

This file was deleted.

2 changes: 1 addition & 1 deletion stride/physics/boundaries/boundary.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ def damping(self, dimensions=None, damping_coefficient=None, mask=False,
val = dimension_coefficient * pos

else:
raise ValueError('Allowed dumping type are (`sine`, `quadratic`)')
raise ValueError('Allowed dumping type are (`sine`, `power`)')

# : slices
all_ind = [slice(0, d) for d in damp.shape]
Expand Down
38 changes: 27 additions & 11 deletions stride/problem/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,7 +339,7 @@ def release_grad(self):
"""
self.grad = None

def process_grad(self, prec_scale=1e-9, **kwargs):
def process_grad(self, prec_scale=0.15, **kwargs):
"""
Process the gradient by applying the pre-conditioner to it.
Expand All @@ -355,21 +355,37 @@ def process_grad(self, prec_scale=1e-9, **kwargs):
if not self.needs_grad:
return

grad = self.grad
prec = grad.prec
self.grad.apply_prec(prec_scale=prec_scale, **kwargs)
return self.grad

def apply_prec(self, prec_scale=0.15, prec=None, **kwargs):
"""
Apply a pre-conditioner to the current field.
Parameters
----------
prec_scale : float, optional
Condition scaling for the preconditioner.
prec : StructuredData, optional
Pre-conditioner to apply. Defaults to self.prec.
Returns
-------
"""
prec = self.prec if prec is None else prec

if prec is not None:
norm_prec = np.linalg.norm(prec.data)
prec_factor = np.sum(prec.data)

if norm_prec > 1e-31:
prec += prec_scale * norm_prec + 1e-31
prec /= np.max(np.abs(prec.data))
if prec_factor > 1e-31:
num_points = np.prod(prec.shape)
prec_factor = prec_scale * num_points / prec_factor
prec.data[:] = np.sqrt(prec.data * prec_factor + 1)
non_zero = np.abs(prec.data) > 0.
grad.data[non_zero] /= prec.data[non_zero]

self.grad = grad
self.data[non_zero] /= prec.data[non_zero]

return grad
return self

def allocate(self):
"""
Expand Down
2 changes: 1 addition & 1 deletion stride_examples/examples/breast2D/02_script_inverse.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ async def main(runtime):
optimisation_loop, optimiser, vp,
num_iters=num_iters,
select_shots=dict(num=16, randomly=True),
f_min=0.05e6, f_max=freq)
f_max=freq)

vp.plot()

Expand Down
Loading

0 comments on commit 5b9c386

Please sign in to comment.