Skip to content

Commit

Permalink
_jvec and _jtvec
Browse files Browse the repository at this point in the history
- Add _jvec and _jtvec to be used within SimPEG.
- Add initial support for data_type.
  • Loading branch information
sgkang authored and prisae committed Sep 5, 2021
1 parent f9d1686 commit e4d6ac8
Show file tree
Hide file tree
Showing 3 changed files with 154 additions and 14 deletions.
49 changes: 40 additions & 9 deletions emg3d/electrodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,11 @@ def __eq__(self, electrode):
# Check input.
if equal:
for name in self._serialize:
equal *= np.allclose(getattr(self, name),
getattr(electrode, name))
comp = getattr(self, name)
if isinstance(comp, np.ndarray):
equal *= np.allclose(comp, getattr(electrode, name))
else:
equal *= comp == getattr(electrode, name)

return bool(equal)

Expand Down Expand Up @@ -585,17 +588,28 @@ class Receiver(Wire):
Note that ``relative=True`` makes only sense in combination with
sources, such as is the case in a :class:`emg3d.surveys.Survey`.
data_type : str
Data type of the measured responses. Currently implemented is only
``'complex'``.
"""

# Add relative to attributes which have to be serialized.
_serialize = {'relative'} | Wire._serialize
_serialize = {'relative', 'data_type'} | Wire._serialize

def __init__(self, relative, **kwargs):
def __init__(self, relative, data_type, **kwargs):
"""Initiate a receiver."""

# Check data type is a known type.
if data_type.lower() != 'complex':
raise ValueError(f"Unknown data type '{data_type}'.")

# Store relative, add a repr-addition.
self._relative = relative
self._repr_add = f"{['absolute', 'relative'][self.relative]};"
self._data_type = data_type.lower()
self._repr_add = (
f"{['absolute', 'relative'][self.relative]}; {self.data_type};"
)

super().__init__(**kwargs)

Expand All @@ -604,6 +618,11 @@ def relative(self):
"""True if coordinates are relative to source, False if absolute."""
return self._relative

@property
def data_type(self):
"""Data type of the measured responses."""
return self._data_type

def center_abs(self, source):
"""Returns points as absolute positions."""
if self.relative:
Expand Down Expand Up @@ -636,13 +655,19 @@ class RxElectricPoint(Receiver, Point):
Note that ``relative=True`` makes only sense in combination with
sources, such as is the case in a :class:`emg3d.surveys.Survey`.
data_type : str, default: 'complex'
Data type of the measured responses. Currently implemented is only the
default value.
"""
_adjoint_source = TxElectricPoint

def __init__(self, coordinates, relative=False):
def __init__(self, coordinates, relative=False, data_type='complex'):
"""Initiate an electric point receiver."""

super().__init__(coordinates=coordinates, relative=relative)
super().__init__(
coordinates=coordinates, relative=relative, data_type=data_type
)


@utils._known_class
Expand All @@ -662,13 +687,19 @@ class RxMagneticPoint(Receiver, Point):
Note that ``relative=True`` makes only sense in combination with
sources, such as is the case in a :class:`emg3d.surveys.Survey`.
data_type : str, default: 'complex'
Data type of the measured responses. Currently implemented is only the
default value.
"""
_adjoint_source = TxMagneticPoint

def __init__(self, coordinates, relative=False):
def __init__(self, coordinates, relative=False, data_type='complex'):
"""Initiate a magnetic point receiver."""

super().__init__(coordinates=coordinates, relative=relative)
super().__init__(
coordinates=coordinates, relative=relative, data_type=data_type
)


# ROTATIONS AND CONVERSIONS
Expand Down
115 changes: 112 additions & 3 deletions emg3d/simulations.py
Original file line number Diff line number Diff line change
Expand Up @@ -622,7 +622,7 @@ def get_efield_info(self, source, frequency):
"""Return the solver information of the corresponding computation."""
return self._dict_efield_info[source][self._freq_inp2key(frequency)]

def _get_responses(self, source, frequency):
def _get_responses(self, source, frequency, efield=None):
"""Return electric and magnetic fields at receiver locations."""

# Get receiver types and their coordinates.
Expand All @@ -632,8 +632,9 @@ def _get_responses(self, source, frequency):
# Initiate output.
resp = np.zeros_like(self.data.synthetic.loc[source, :, frequency])

# efield of this source/frequency.
efield = self._dict_efield[source][frequency]
# efield of this source/frequency if not provided.
if efield is None:
efield = self._dict_efield[source][frequency]

if erec.size:

Expand Down Expand Up @@ -1027,6 +1028,114 @@ def _get_rfield(self, source, frequency):

return rfield

@utils._requires('discretize')
def _jvec(self, vector):
r"""Compute the sensitivity times a vector.
.. math::
:label: jvec
J v = P A^{-1} G v \ ,
where :math:`v` has size of the model.
Parameters
----------
vector : ndarray
Shape of the model.
Returns
-------
jvec : ndarray
Size of the data.
"""

# Create iterable form src/freq-list to call the process_map.
def collect_jfield_inputs(inp, vector=vector):
"""Collect inputs."""
source, frequency = inp

# Forward electric field
efield = self._dict_efield[source][frequency]

# Compute gvec = G * vector (using discretize)
gvec = efield.grid.get_edge_inner_product_deriv(
np.ones(efield.grid.n_cells))(efield.field) * vector
# Extension for tri-axial anisotropy is trivial:
# gvec = mesh.get_edge_inner_product_deriv(
# np.ones(mesh.n_cells)*3)(efield.field) * vector

gfield = fields.Field(
grid=efield.grid,
data=-efield.smu0*gvec,
frequency=efield.frequency
)

return self.model, gfield, None, self.solver_opts

# Compute and return A^-1 * G * vector
out = utils._process_map(
solver._solve,
list(map(collect_jfield_inputs, self._srcfreq)),
max_workers=self.max_workers,
**{'desc': 'Compute jvec', **self._tqdm_opts},
)

# Store gradient field and info.
if 'jvec' not in self.data.keys():
self.data['jvec'] = self.data.observed.copy(
data=np.full(self.survey.shape, np.nan+1j*np.nan))

# Loop over src-freq combinations to extract and store.
for i, (src, freq) in enumerate(self._srcfreq):

# Store responses at receivers.
resp = self._get_responses(src, freq, out[i][0])
self.data['jvec'].loc[src, :, freq] = resp

return self.data['jvec'].data

def _jtvec(self, vector):
r"""Compute the sensitivity transpose times a vector.
If `vector`=residual, `jtvec` corresponds to the `gradient`.
.. math::
:label: jtvec
J^H v = G^H A^{-H} P^H v \ ,
where :math:`v` has size of the data.
Parameters
----------
vector : ndarray
Shape of the data.
Returns
-------
jtvec : ndarray
Adjoint-state gradient (same shape as ``simulation.model``)
for the provided vector.
"""
# Note: The entire chain `gradient`->`_bcompute`->`_get_rfield` and
# also `_jtvec` could be re-factored much smarter.

# Replace residual by vector if provided.
self.survey.data['residual'][...] = vector

# Reset gradient, so it will be computed.
self._gradient = None

# Get gradient with `v` as residual.
return self.gradient

# UTILS
@property
def _dict_initiate(self):
Expand Down
4 changes: 2 additions & 2 deletions tests/test_electrodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,8 +299,8 @@ def test_receiver():
rcoo = [1000, -200, 0]
scoo = [50, 50, 50, 0, 0]

ra = electrodes.Receiver(False, coordinates=rcoo)
rr = electrodes.Receiver(True, coordinates=rcoo)
ra = electrodes.Receiver(False, coordinates=rcoo, data_type='complex')
rr = electrodes.Receiver(True, coordinates=rcoo, data_type='complex')
s1 = electrodes.TxElectricDipole(coordinates=scoo)

assert ra.relative is False
Expand Down

0 comments on commit e4d6ac8

Please sign in to comment.