@@ -61,7 +61,6 @@ def run(n, backend, datatype, benchmark_mode):
6161 def transpose (a ):
6262 return np .permute_dims (a , [1 , 0 ])
6363
64- all_axes = [0 , 1 ]
6564 init (False )
6665
6766 elif backend == "numpy" :
@@ -76,7 +75,6 @@ def transpose(a):
7675 transpose = np .transpose
7776
7877 fini = sync = lambda x = None : None
79- all_axes = None
8078 else :
8179 raise ValueError (f'Unknown backend: "{ backend } "' )
8280
@@ -207,11 +205,11 @@ def bathymetry(x_t_2d, y_t_2d, lx, ly):
207205 # set bathymetry
208206 h [:, :] = bathymetry (x_t_2d , y_t_2d , lx , ly )
209207 # steady state potential energy
210- pe_offset = 0.5 * g * float (np .sum (h ** 2.0 , all_axes )) / nx / ny
208+ pe_offset = 0.5 * g * float (np .sum (h ** 2.0 )) / nx / ny
211209
212210 # compute time step
213211 alpha = 0.5
214- h_max = float (np .max (h , all_axes ))
212+ h_max = float (np .max (h ))
215213 c = (g * h_max ) ** 0.5
216214 dt = alpha * dx / c
217215 dt = t_export / int (math .ceil (t_export / dt ))
@@ -344,22 +342,22 @@ def step(u, v, e, u1, v1, e1, u2, v2, e2):
344342 t = i * dt
345343
346344 if t >= next_t_export - 1e-8 :
347- _elev_max = np .max (e , all_axes )
348- _u_max = np .max (u , all_axes )
349- _q_max = np .max (q , all_axes )
350- _total_v = np .sum (e + h , all_axes )
345+ _elev_max = np .max (e )
346+ _u_max = np .max (u )
347+ _q_max = np .max (q )
348+ _total_v = np .sum (e + h )
351349
352350 # potential energy
353351 _pe = 0.5 * g * (e + h ) * (e - h ) + pe_offset
354- _total_pe = np .sum (_pe , all_axes )
352+ _total_pe = np .sum (_pe )
355353
356354 # kinetic energy
357355 u2 = u * u
358356 v2 = v * v
359357 u2_at_t = 0.5 * (u2 [1 :, :] + u2 [:- 1 , :])
360358 v2_at_t = 0.5 * (v2 [:, 1 :] + v2 [:, :- 1 ])
361359 _ke = 0.5 * (u2_at_t + v2_at_t ) * (e + h )
362- _total_ke = np .sum (_ke , all_axes )
360+ _total_ke = np .sum (_ke )
363361
364362 total_pe = float (_total_pe ) * dx * dy
365363 total_ke = float (_total_ke ) * dx * dy
@@ -406,7 +404,7 @@ def step(u, v, e, u1, v1, e1, u2, v2, e2):
406404 2
407405 ]
408406 err2 = (e_exact - e ) * (e_exact - e ) * dx * dy / lx / ly
409- err_L2 = math .sqrt (float (np .sum (err2 , all_axes )))
407+ err_L2 = math .sqrt (float (np .sum (err2 )))
410408 info (f"L2 error: { err_L2 :7.15e} " )
411409
412410 if nx < 128 or ny < 128 :
0 commit comments