Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Pop kwarg to avoid TypeError #58

Merged
merged 1 commit into from
Oct 12, 2023

Conversation

charlesbmi
Copy link
Contributor

Background

if the kwarg save_wavefield is set in async def before_forward, then save_wavefield is passed as a kwarg twice to _stencil(), once explicitly and once as part of **kwargs.

Changes

Pop kwarg to avoid TypeError _stencil() got multiple values for keyword argument 'save_wavefield'.

For reference

error stack from calling code in NDK:

tests/neurotechdevkit/test_scenario_with_all_parameters.py:113: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
venv/lib/python3.10/site-packages/neurotechdevkit/scenarios/_base.py:421: in simulate_steady_state
    traces = self._execute_pde(
venv/lib/python3.10/site-packages/neurotechdevkit/scenarios/_base.py:782: in _execute_pde
    return loop.run_until_complete(
venv/lib/python3.10/site-packages/nest_asyncio.py:[99](https://github.com/agencyenterprise/neurotechdevkit/actions/runs/6433635735/job/17570714414#step:4:100): in run_until_complete
    return f.result()
/usr/local/lib/python3.10/asyncio/futures.py:201: in result
    raise self._exception
/usr/local/lib/python3.10/asyncio/tasks.py:232: in __step
    result = coro.send(None)
venv/lib/python3.10/site-packages/stride/core.py:641: in __call__
    outputs = await self.forward(*args, **kwargs)
venv/lib/python3.10/site-packages/stride/physics/problem_type.py:76: in forward
    await self.before_forward(*args, **kwargs)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

self = head:isoacousticdevito_0, wavelets = wavelets, vp = vp, rho = rho
alpha = alpha
kwargs = {'boundary_type': 'complex_frequency_shift_PML_2', 'devito_args': {}, 'platform': None, 'problem': <stride.problem.problem.SubProblem object at 0x7f7c010449a0>, ...}
problem = <stride.problem.problem.SubProblem object at 0x7f7c010449a0>
shot = <stride.problem.acquisitions.Shot object at 0x7f7c01d9b0d0>
num_sources = [100](https://github.com/agencyenterprise/neurotechdevkit/actions/runs/6433635735/job/17570714414#step:4:101)0, num_receivers = 0, save_wavefield = True, platform = None
diff_source = True

    async def before_forward(self, wavelets, vp, rho=None, alpha=None, **kwargs):
        """
        Prepare the problem type to run the state or forward problem.
    
        Parameters
        ----------
        wavelets : Traces
            Source wavelets.
        vp : ScalarField
            Compressional speed of sound fo the medium, in [m/s].
        rho : ScalarField, optional
            Density of the medium, defaults to homogeneous, in [kg/m^3].
        alpha : ScalarField, optional
            Attenuation coefficient of the medium, defaults to 0, in [dB/cm].
        problem : Problem
            Sub-problem being solved by the PDE.
        save_wavefield : bool, optional
            Whether or not to solve the forward wavefield, defaults to True when
            a gradient is expected, and to False otherwise.
        save_bounds : tuple of int, optional
            If saving the wavefield, specify the ``(min timestep, max timestep)``
            where the wavefield should be saved
        save_undersampling : int, optional
            Amount of undersampling in time when saving the forward wavefield. If not given,
            it is calculated given the bandwidth.
        save_compression : str, optional
            Compression applied to saved wavefield, only available with DevitoPRO. Defaults to no
            compression in 2D and `bitcomp` in 3D.
        boundary_type : str, optional
            Type of boundary for the wave equation (``sponge_boundary_2`` or
            ``complex_frequency_shift_PML_2``), defaults to ``sponge_boundary_2``.
            Note that ``complex_frequency_shift_PML_2`` boundaries have lower OT4 stability
            limit than other boundaries.
        interpolation_type : str, optional
            Type of source/receiver interpolation (``linear`` for bi-/tri-linear or ``hicks`` for sinc
            interpolation), defaults to ``linear``.
        attenuation_power : int, optional
            Power of the attenuation law if attenuation is given (``0`` or ``2``),
            defaults to ``0``.
        drp : bool, optional
            Whether or not to use dispersion-relation preserving coefficients (only
            available in some versions of Stride). Defaults to False.
        kernel : str, optional
            Type of time kernel to use (``OT2`` for 2nd order in time or ``OT4`` for 4th
            order in time). If not given, it is automatically decided given the time spacing.
        diff_source : bool, optional
            Whether the source should be injected as is, or as its 1st time derivative. Defaults to
            False, leaving it unchanged.
        adaptive_boxes : bool, optional
            Whether to activate adaptive boxes (requires DevitoPRO and only
            available in some versions of Stride). Defaults to False.
        platform : str, optional
            Platform on which to run the operator, ``None`` to run on the CPU or ``nvidia-acc`` to run on
            the GPU with OpenACC. Defaults to ``None``.
        devito_config : dict, optional
            Additional keyword arguments to configure Devito before operator generation.
        devito_args : dict, optional
            Additional keyword arguments used when calling the generated operator.
    
    
        Returns
        -------
    
        """
        problem = kwargs.get('problem')
        shot = problem.shot
    
        self._check_problem(wavelets, vp, rho=rho, alpha=alpha, **kwargs)
    
        num_sources = shot.num_points_sources
        num_receivers = shot.num_points_receivers
    
        save_wavefield = kwargs.get('save_wavefield', False)
        if save_wavefield is False:
            save_wavefield = vp.needs_grad
            if rho is not None:
                save_wavefield |= rho.needs_grad
            if alpha is not None:
                save_wavefield |= alpha.needs_grad
    
        platform = kwargs.get('platform', 'cpu')
        diff_source = kwargs.pop('diff_source', False)
        save_compression = kwargs.get('save_compression',
                                      'bitcomp' if self.space.dim > 2 else None)
        save_compression = save_compression if platform and 'nvidia' in platform and devito.pro_available else None
    
        # If there's no previous operator, generate one
        if self.state_operator.devito_operator is None:
            # Define variables
            src = self.dev_grid.sparse_time_function('src', num=num_sources,
                                                     coordinates=shot.source_coordinates,
                                                     interpolation_type=self.interpolation_type)
            rec = self.dev_grid.sparse_time_function('rec', num=num_receivers,
                                                     coordinates=shot.receiver_coordinates,
                                                     interpolation_type=self.interpolation_type)
    
            p = self.dev_grid.time_function('p', coefficients='symbolic' if self.drp else 'standard')
    
            # Create stencil
>           stencil = self._stencil(p, wavelets, vp, rho=rho, alpha=alpha, direction='forward',
                                    save_wavefield=save_wavefield, **kwargs)
E           TypeError: stride.physics.iso_acoustic.devito.IsoAcousticDevito._stencil() got multiple values for keyword argument 'save_wavefield'

@ccuetom ccuetom merged commit 31d52cf into trustimaging:master Oct 12, 2023
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants