diff --git a/dedalus/core/solvers.py b/dedalus/core/solvers.py index 9e7b0f9b..6407ea33 100644 --- a/dedalus/core/solvers.py +++ b/dedalus/core/solvers.py @@ -406,7 +406,7 @@ class InitialValueSolver: """ - def __init__(self, problem, timestepper, matsolver=None, enforce_real_cadence=100): + def __init__(self, problem, timestepper, matsolver=None, enforce_real_cadence=100, warmup_iterations=10): logger.debug('Beginning IVP instantiation') @@ -418,11 +418,14 @@ def __init__(self, problem, timestepper, matsolver=None, enforce_real_cadence=10 self.matsolver = matsolver self.enforce_real_cadence = enforce_real_cadence self._float_array = np.zeros(1, dtype=float) - self.start_time = self.get_world_time() + self.init_time = self.get_world_time() + self._wall_time_array = np.zeros(1, dtype=float) # Build pencils and pencil matrices self.pencils = pencil.build_pencils(domain) pencil.build_matrices(self.pencils, problem, ['M', 'L']) + local_modes = sum(p.pre_right.shape[1] for p in self.pencils) + self.total_modes = self.domain.dist.comm.allreduce(local_modes, op=MPI.SUM) # Build systems namespace = problem.namespace @@ -447,7 +450,8 @@ def __init__(self, problem, timestepper, matsolver=None, enforce_real_cadence=10 # Attributes self.sim_time = self.initial_sim_time = 0. self.iteration = self.initial_iteration = 0 - + self.warmup_iterations = warmup_iterations + # Default integration parameters self.stop_sim_time = np.inf self.stop_wall_time = np.inf @@ -531,7 +535,7 @@ def proceed(self): if self.sim_time >= self.stop_sim_time: logger.info('Simulation stop time reached.') return False - elif (self.get_world_time() - self.start_time) >= self.stop_wall_time: + elif (self.get_world_time() - self.init_time) >= self.stop_wall_time: logger.info('Wall stop time reached.') return False elif self.iteration >= self.stop_iteration: @@ -567,6 +571,12 @@ def step(self, dt, trim=False): dt = min(dt, schedule - t) # (Safety gather) self.state.gather() + # Record times + wall_time = self.get_world_time() + if self.iteration == self.initial_iteration: + self.start_time = wall_time + if self.iteration == self.initial_iteration + self.warmup_iterations: + self.warmup_time = wall_time # Advance using timestepper self.timestepper.step(self, dt) # (Safety scatter) @@ -584,6 +594,8 @@ def step(self, dt, trim=False): path.increment(self.state.fields) for path in self.domain.dist.paths[::-1]: path.decrement(self.state.fields) + + return dt def evolve(self, timestep_function): @@ -613,3 +625,22 @@ def evaluate_handlers_now(self, dt, handlers=None): handlers = self.evaluator.handlers self.evaluator.evaluate_handlers(handlers, timestep=dt, sim_time=self.sim_time, world_time=end_world_time, wall_time=end_wall_time, iteration=self.iteration) + def log_stats(self, format=".4g"): + """Log timing statistics with specified string formatting (optional).""" + log_time = self.get_world_time() + logger.info(f"Final iteration: {self.iteration}") + logger.info(f"Final sim time: {self.sim_time}") + setup_time = self.start_time - self.init_time + logger.info(f"Setup time (init - iter 0): {setup_time:{format}} sec") + if self.iteration >= self.initial_iteration + self.warmup_iterations: + warmup_time = self.warmup_time - self.start_time + run_time = log_time - self.warmup_time + cpus = self.domain.dist.comm.size + modes = self.total_modes + stages = (self.iteration - self.warmup_iterations - self.initial_iteration) * self.timestepper.stages + logger.info(f"Warmup time (iter 0-{self.warmup_iterations}): {warmup_time:{format}} sec") + logger.info(f"Run time (iter {self.warmup_iterations}-end): {run_time:{format}} sec") + logger.info(f"CPU time (iter {self.warmup_iterations}-end): {run_time*cpus/3600:{format}} cpu-hr") + logger.info(f"Speed: {(modes*stages/cpus/run_time):{format}} mode-stages/cpu-sec") + else: + logger.info(f"Timings unavailable due because warmup did not complete.")