@@ -61,7 +61,6 @@ def run(n, backend, datatype, benchmark_mode):
61
61
def transpose(a):
62
62
return np.permute_dims(a, [1, 0])
63
63
64
- all_axes = [0, 1]
65
64
init(False)
66
65
67
66
elif backend == "numpy":
@@ -76,7 +75,6 @@ def transpose(a):
76
75
transpose = np.transpose
77
76
78
77
fini = sync = lambda x=None: None
79
- all_axes = None
80
78
else:
81
79
raise ValueError(f'Unknown backend: "{backend}"')
82
80
@@ -207,11 +205,11 @@ def bathymetry(x_t_2d, y_t_2d, lx, ly):
207
205
# set bathymetry
208
206
h[:, :] = bathymetry(x_t_2d, y_t_2d, lx, ly)
209
207
# steady state potential energy
210
- pe_offset = 0.5 * g * float(np.sum(h**2.0, all_axes )) / nx / ny
208
+ pe_offset = 0.5 * g * float(np.sum(h**2.0)) / nx / ny
211
209
212
210
# compute time step
213
211
alpha = 0.5
214
- h_max = float(np.max(h, all_axes ))
212
+ h_max = float(np.max(h))
215
213
c = (g * h_max) ** 0.5
216
214
dt = alpha * dx / c
217
215
dt = t_export / int(math.ceil(t_export / dt))
@@ -344,22 +342,22 @@ def step(u, v, e, u1, v1, e1, u2, v2, e2):
344
342
t = i * dt
345
343
346
344
if t >= next_t_export - 1e-8:
347
- _elev_max = np.max(e, all_axes )
348
- _u_max = np.max(u, all_axes )
349
- _q_max = np.max(q, all_axes )
350
- _total_v = np.sum(e + h, all_axes )
345
+ _elev_max = np.max(e)
346
+ _u_max = np.max(u)
347
+ _q_max = np.max(q)
348
+ _total_v = np.sum(e + h)
351
349
352
350
# potential energy
353
351
_pe = 0.5 * g * (e + h) * (e - h) + pe_offset
354
- _total_pe = np.sum(_pe, all_axes )
352
+ _total_pe = np.sum(_pe)
355
353
356
354
# kinetic energy
357
355
u2 = u * u
358
356
v2 = v * v
359
357
u2_at_t = 0.5 * (u2[1:, :] + u2[:-1, :])
360
358
v2_at_t = 0.5 * (v2[:, 1:] + v2[:, :-1])
361
359
_ke = 0.5 * (u2_at_t + v2_at_t) * (e + h)
362
- _total_ke = np.sum(_ke, all_axes )
360
+ _total_ke = np.sum(_ke)
363
361
364
362
total_pe = float(_total_pe) * dx * dy
365
363
total_ke = float(_total_ke) * dx * dy
@@ -406,7 +404,7 @@ def step(u, v, e, u1, v1, e1, u2, v2, e2):
406
404
2
407
405
]
408
406
err2 = (e_exact - e) * (e_exact - e) * dx * dy / lx / ly
409
- err_L2 = math.sqrt(float(np.sum(err2, all_axes )))
407
+ err_L2 = math.sqrt(float(np.sum(err2)))
410
408
info(f"L2 error: {err_L2:7.15e}")
411
409
412
410
if nx < 128 or ny < 128:
0 commit comments