From ee7a3348eee1e811c24e8dc6baf82a86ec62aea5 Mon Sep 17 00:00:00 2001 From: Charles Guan Date: Mon, 9 Oct 2023 13:41:53 -0700 Subject: [PATCH] Update based on PR feedback --- stride/physics/iso_acoustic/devito.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/stride/physics/iso_acoustic/devito.py b/stride/physics/iso_acoustic/devito.py index 191ac55e..68f30aea 100644 --- a/stride/physics/iso_acoustic/devito.py +++ b/stride/physics/iso_acoustic/devito.py @@ -269,10 +269,9 @@ async def before_forward(self, wavelets, vp, rho=None, alpha=None, **kwargs): platform = kwargs.get('platform', 'cpu') diff_source = kwargs.pop('diff_source', False) - if platform and ('nvidia' in platform) and devito.pro_available and (self.space.dim > 2): - save_compression = kwargs.get('save_compression', 'bitcomp') - else: - save_compression = None + save_compression = kwargs.get('save_compression', + 'bitcomp' if self.space.dim > 2 else None) + save_compression = save_compression if platform and 'nvidia' in platform and devito.pro_available else None # If there's no previous operator, generate one if self.state_operator.devito_operator is None: @@ -305,7 +304,7 @@ async def before_forward(self, wavelets, vp, rho=None, alpha=None, **kwargs): # Define the saving of the wavefield if save_wavefield is True: - layers = devito.HostDevice if (platform and ('nvidia' in platform)) else devito.NoLayers + layers = devito.HostDevice if platform and 'nvidia' in platform else devito.NoLayers p_saved = self.dev_grid.undersampled_time_function('p_saved', bounds=kwargs.pop('save_bounds', None), factor=self.undersampling_factor, @@ -337,7 +336,7 @@ async def before_forward(self, wavelets, vp, rho=None, alpha=None, **kwargs): else: # If the wavefield is lazily streamed, re-create every time - if platform and ('nvidia' in platform) and devito.pro_available: + if platform and 'nvidia' in platform and devito.pro_available: self.dev_grid.undersampled_time_function('p_saved', bounds=kwargs.pop('save_bounds', None), factor=self.undersampling_factor,