From caacad71ce2bd01c1bb87f48ab25babe18adcfd7 Mon Sep 17 00:00:00 2001
From: Robert Hallberg <Robert.Hallberg@noaa.gov>
Date: Sat, 7 Jan 2023 06:35:30 -0500
Subject: [PATCH] +(*)Use reproducing sums to homogenize fields

  Added the ability to use reproducing sums to create spatially homogenized
tracer fields when Z_INIT_HOMOGENIZE = True, for rotational symmetry and
consistency across layouts.  The previous version had used non-reproducing
sums.  This new code is used when HOR_REGRID_ANSWER_DATE >= 20230101, and the
comments describing HOR_REGRID_ANSWER_DATE have been updated to reflect this. As
a part of this change, the new publicly visible routine homogenize_field was
added to the MOM_horizontal_regridding module, so that the homogenization
occurs in a single part of the code rather than being spread across several
files.

  By default this commit could lead to answer changes in some cases, depending
whether and how HOR_REGRID_ANSWER_DATE is set, but it turns out that the
existing single_column test cases that use this have only 4 points and happen to
give the same answers with either the older or newer version of the code.

  This commit addresses MOM6 issue #296 (github.com/NOAA-GFDL/MOM6/issues/296),
which can be closed as soon as this commit is merged in to the dev/gfdl branch
of MOM6.
---
 src/framework/MOM_horizontal_regridding.F90   | 112 +++++++++++++-----
 .../MOM_state_initialization.F90              |  28 +----
 .../MOM_tracer_initialization_from_Z.F90      |   1 +
 .../vertical/MOM_ALE_sponge.F90               |   2 +
 4 files changed, 91 insertions(+), 52 deletions(-)

diff --git a/src/framework/MOM_horizontal_regridding.F90 b/src/framework/MOM_horizontal_regridding.F90
index 7a0ace9279..c2fe772571 100644
--- a/src/framework/MOM_horizontal_regridding.F90
+++ b/src/framework/MOM_horizontal_regridding.F90
@@ -5,6 +5,7 @@ module MOM_horizontal_regridding
 
 use MOM_debugging,     only : hchksum
 use MOM_coms,          only : max_across_PEs, min_across_PEs, sum_across_PEs, broadcast
+use MOM_coms,          only : reproducing_sum
 use MOM_cpu_clock,     only : cpu_clock_id, cpu_clock_begin, cpu_clock_end, CLOCK_LOOP
 use MOM_domains,       only : pass_var
 use MOM_error_handler, only : MOM_mesg, MOM_error, FATAL, WARNING, is_root_pe
@@ -24,7 +25,7 @@ module MOM_horizontal_regridding
 
 #include <MOM_memory.h>
 
-public :: horiz_interp_and_extrap_tracer, myStats
+public :: horiz_interp_and_extrap_tracer, myStats, homogenize_field
 
 !> Extrapolate and interpolate data
 interface horiz_interp_and_extrap_tracer
@@ -321,7 +322,6 @@ subroutine horiz_interp_and_extrap_tracer_record(filename, varnam, recnum, G, tr
   logical :: found_attr
   logical :: add_np
   logical :: is_ongrid
-  character(len=8)  :: laynum
   type(horiz_interp_type) :: Interp
   type(axis_info), dimension(4) :: axes_info ! Axis information used for regridding
   integer :: is, ie, js, je     ! compute domain indices
@@ -332,8 +332,6 @@ subroutine horiz_interp_and_extrap_tracer_record(filename, varnam, recnum, G, tr
   real :: I_scale               ! The inverse of the scale factor for diagnostic output [conc CU-1 ~> 1]
   real :: dtr_iter_stop         ! The tolerance for changes in tracer concentrations between smoothing
                                 ! iterations that determines when to stop iterating [CU ~> conc]
-  real :: npoints   ! The number of points in an average [nondim]
-  real :: varAvg    ! The sum of tracer variables being averaged, then their average [CU ~> conc]
   real, dimension(SZI_(G),SZJ_(G)) :: lon_out ! The longitude of points on the model grid [radians]
   real, dimension(SZI_(G),SZJ_(G)) :: lat_out ! The latitude of points on the model grid [radians]
   real, dimension(SZI_(G),SZJ_(G)) :: tr_out  ! The tracer on the model grid [CU ~> conc]
@@ -461,13 +459,12 @@ subroutine horiz_interp_and_extrap_tracer_record(filename, varnam, recnum, G, tr
   ! Loop through each data level and interpolate to model grid.
   ! After interpolating, fill in points which will be needed to define the layers.
   do k=1,kd
-    write(laynum,'(I8)') k ; laynum = adjustl(laynum)
     mask_in(:,:)  = 0.0
     tr_out(:,:) = 0.0
 
     if (is_ongrid) then
       start(1) = is+G%HI%idg_offset ; start(2) = js+G%HI%jdg_offset ; start(3) = k
-      count(1) = ie-is+1 ; count(2) = je-js+1; count(3) = 1; start(4) = 1; count(4) = 1
+      count(1) = ie-is+1 ; count(2) = je-js+1 ; count(3) = 1 ; start(4) = 1 ; count(4) = 1
       call MOM_read_data(trim(filename), trim(varnam), tr_in, start, count, G%Domain)
       do j=js,je
         do i=is,ie
@@ -487,6 +484,7 @@ subroutine horiz_interp_and_extrap_tracer_record(filename, varnam, recnum, G, tr
       start(:) = 1 ; start(3) = k
       count(:) = 1 ; count(1) = id ; count(2) = jd
       call read_variable(trim(filename), trim(varnam), tr_in, start=start, nread=count)
+
       if (is_root_pe()) then
         if (add_np) then
           pole = 0.0 ; npole = 0.0
@@ -539,14 +537,11 @@ subroutine horiz_interp_and_extrap_tracer_record(filename, varnam, recnum, G, tr
 
     fill(:,:) = 0.0 ; good(:,:) = 0.0
 
-    nPoints = 0 ; varAvg = 0.
     do j=js,je ; do i=is,ie
       if (mask_out(i,j) < 1.0) then
         tr_out(i,j) = missing_value
       else
         good(i,j) = 1.0
-        nPoints = nPoints + 1
-        varAvg = varAvg + tr_out(i,j)
       endif
       if ((G%mask2dT(i,j) == 1.0) .and. (z_edges_in(k) <= G%bathyT(i,j) + G%Z_ref) .and. &
           (mask_out(i,j) < 1.0)) &
@@ -561,13 +556,7 @@ subroutine horiz_interp_and_extrap_tracer_record(filename, varnam, recnum, G, tr
 
     ! Horizontally homogenize data to produce perfectly "flat" initial conditions
     if (PRESENT(homogenize)) then ; if (homogenize) then
-      !### These averages will not reproduce across PE layouts or grid rotation.
-      call sum_across_PEs(nPoints)
-      call sum_across_PEs(varAvg)
-      if (nPoints>0) then
-        varAvg = varAvg / real(nPoints)
-      endif
-      tr_out(:,:) = varAvg
+      call homogenize_field(tr_out, mask_out, G, scale, answer_date)
     endif ; endif
 
     ! tr_out contains input z-space data on the model grid with missing values
@@ -663,7 +652,6 @@ subroutine horiz_interp_and_extrap_tracer_fms_id(fms_id, Time, G, tr_z, mask_z,
   real :: missing_val_in ! The missing value in the input field [conc]
   real :: roundoff  ! The magnitude of roundoff, usually ~2e-16 [nondim]
   logical :: add_np
-  character(len=8)  :: laynum
   type(horiz_interp_type) :: Interp
   type(axistype), dimension(4) :: axes_data
   integer :: is, ie, js, je     ! compute domain indices
@@ -677,8 +665,6 @@ subroutine horiz_interp_and_extrap_tracer_fms_id(fms_id, Time, G, tr_z, mask_z,
   real :: I_scale               ! The inverse of the scale factor for diagnostic output [conc CU-1 ~> 1]
   real :: dtr_iter_stop         ! The tolerance for changes in tracer concentrations between smoothing
                                 ! iterations that determines when to stop iterating [CU ~> conc]
-  real :: npoints   ! The number of points in an average [nondim]
-  real :: varAvg    ! The sum of tracer variables being averaged, then their average [CU ~> conc]
   real, dimension(SZI_(G),SZJ_(G)) :: lon_out ! The longitude of points on the model grid [radians]
   real, dimension(SZI_(G),SZJ_(G)) :: lat_out ! The latitude of points on the model grid [radians]
   real, dimension(SZI_(G),SZJ_(G)) :: tr_out  ! The tracer on the model grid [CU ~> conc]
@@ -791,10 +777,10 @@ subroutine horiz_interp_and_extrap_tracer_fms_id(fms_id, Time, G, tr_z, mask_z,
   if (.not.is_ongrid) then
     if (is_root_pe()) &
       call time_interp_external(fms_id, Time, data_in, verbose=(verbosity>5), turns=turns)
+
     ! Loop through each data level and interpolate to model grid.
     ! After interpolating, fill in points which will be needed to define the layers.
     do k=1,kd
-      write(laynum,'(I8)') k ; laynum = adjustl(laynum)
       if (is_root_pe()) then
         tr_in(1:id,1:jd) = data_in(1:id,1:jd,k)
         if (add_np) then
@@ -851,14 +837,11 @@ subroutine horiz_interp_and_extrap_tracer_fms_id(fms_id, Time, G, tr_z, mask_z,
 
       fill(:,:) = 0.0 ; good(:,:) = 0.0
 
-      nPoints = 0 ; varAvg = 0.
       do j=js,je ; do i=is,ie
         if (mask_out(i,j) < 1.0) then
           tr_out(i,j) = missing_value
         else
           good(i,j) = 1.0
-          nPoints = nPoints + 1
-          varAvg = varAvg + tr_out(i,j)
         endif
         if ((G%mask2dT(i,j) == 1.0) .and. (z_edges_in(k) <= G%bathyT(i,j) + G%Z_ref) .and. &
             (mask_out(i,j) < 1.0)) &
@@ -873,13 +856,7 @@ subroutine horiz_interp_and_extrap_tracer_fms_id(fms_id, Time, G, tr_z, mask_z,
 
       ! Horizontally homogenize data to produce perfectly "flat" initial conditions
       if (PRESENT(homogenize)) then ; if (homogenize) then
-        !### These averages will not reproduce across PE layouts or grid rotation.
-        call sum_across_PEs(nPoints)
-        call sum_across_PEs(varAvg)
-        if (nPoints>0) then
-          varAvg = varAvg / real(nPoints)
-        endif
-        tr_out(:,:) = varAvg
+        call homogenize_field(tr_out, mask_out, G, scale, answer_date)
       endif ; endif
 
       ! tr_out contains input z-space data on the model grid with missing values
@@ -920,6 +897,81 @@ subroutine horiz_interp_and_extrap_tracer_fms_id(fms_id, Time, G, tr_z, mask_z,
 
 end subroutine horiz_interp_and_extrap_tracer_fms_id
 
+!> Replace all values of a 2-d field with the weighted average over the valid points.
+subroutine homogenize_field(field, weight, G, scale, answer_date, wt_unscale)
+  type(ocean_grid_type),            intent(inout) :: G      !< Ocean grid type
+  real, dimension(SZI_(G),SZJ_(G)), intent(inout) :: field  !< The tracer on the model grid [A ~> a]
+  real, dimension(SZI_(G),SZJ_(G)), intent(in)    :: weight !< The weights for the tracer [B ~> b]
+  real,                             intent(in)    :: scale  !< A rescaling factor that has been used for the
+                                                            !! variable and has to be undone before the
+                                                            !! reproducing sums [A a-1 ~> 1]
+  integer,                optional, intent(in)    :: answer_date !< The vintage of the expressions in the code.
+                                                            !! Dates before 20230101 use non-reproducing sums
+                                                            !! in their averages, while later versions use
+                                                            !! reproducing sums for rotational symmetry and
+                                                            !! consistency across PE layouts.
+  real,                   optional, intent(in)    :: wt_unscale !< A factor that undoes any dimensional scaling
+                                                            !! of the weights so that they can be used with
+                                                            !! reproducing sums [b B-1 ~> 1]
+
+  ! Local variables
+  real, dimension(SZI_(G),SZJ_(G)) :: field_for_Sums  ! The field times the weights with the scaling undone [a b]
+  real, dimension(SZI_(G),SZJ_(G)) :: wts_for_Sums    ! A copy of the wieghts with the scaling undone [b]
+  real :: var_unscale ! The reciprocal of the scaling factor for the field and weights [a b A-1 B-1 ~> 1]
+  real :: wt_descale  ! A factor that undoes any dimensional scaling of the weights so that they
+                      ! can be used with reproducing sums [b B-1 ~> 1]
+  real :: wt_sum      ! The sum of the weights, in [b] (reproducing) or [B ~> b] (non-reproducing)
+  real :: varsum      ! The weighted sum of field being averaged [A B ~> a b]
+  real :: varAvg      ! The average of the field [A ~> a]
+  logical :: use_repro_sums  ! If true, use reproducing sums.
+  integer :: i, j, is, ie, js, je
+
+  is = G%isc ; ie = G%iec ; js = G%jsc ; je = G%jec
+
+  varAvg = 0.0  ! This value will be used if wt_sum is 0.
+
+  use_repro_sums = .false. ; if (present(answer_date)) use_repro_sums = (answer_date >= 20230101)
+
+  if (scale == 0.0) then
+    ! This seems like an unlikely case to ever be used, but dealing with it is better than having NaNs arise?
+    varAvg = 0.0
+  elseif (use_repro_sums) then
+    wt_descale = 1.0 ; if (present(wt_unscale)) wt_descale = wt_unscale
+    var_unscale = wt_descale / scale
+
+    field_for_Sums(:,:) = 0.0
+    wts_for_Sums(:,:) = 0.0
+    do j=js,je ; do i=is,ie
+      wts_for_Sums(i,j) = wt_descale * weight(i,j)
+      field_for_Sums(i,j) = var_unscale * (field(i,j) * weight(i,j))
+    enddo ; enddo
+
+    wt_sum = reproducing_sum(wts_for_Sums)
+    if (abs(wt_sum) > 0.0) &
+      varAvg = reproducing_sum(field_for_Sums) * (scale / wt_sum)
+
+  else  ! Do the averages with order-dependent sums to reproduce older answers.
+    wt_sum = 0 ; varsum = 0.
+    do j=js,je ; do i=is,ie
+      if (weight(i,j) > 0.0) then
+        wt_sum = wt_sum + weight(i,j)
+        varsum = varsum + field(i,j) * weight(i,j)
+      endif
+    enddo ; enddo
+
+    ! Note that these averages will not reproduce across PE layouts or grid rotation.
+    call sum_across_PEs(wt_sum)
+    if (wt_sum > 0.0) then
+      call sum_across_PEs(varsum)
+      varAvg = varsum / wt_sum
+    endif
+  endif
+
+  field(:,:) = varAvg
+
+end subroutine homogenize_field
+
+
 !> Create a 2d-mesh of grid coordinates from 1-d arrays.
 subroutine meshgrid(x, y, x_T, y_T)
   real, dimension(:),                   intent(in)    :: x  !< input 1-dimensional vector
diff --git a/src/initialization/MOM_state_initialization.F90 b/src/initialization/MOM_state_initialization.F90
index 14459f7d0a..c1a160f9f0 100644
--- a/src/initialization/MOM_state_initialization.F90
+++ b/src/initialization/MOM_state_initialization.F90
@@ -92,7 +92,7 @@ module MOM_state_initialization
 use MOM_regridding, only : regridding_CS, set_regrid_params, getCoordinateResolution
 use MOM_regridding, only : regridding_main, regridding_preadjust_reqs, convective_adjustment
 use MOM_remapping, only : remapping_CS, initialize_remapping, remapping_core_h
-use MOM_horizontal_regridding, only : horiz_interp_and_extrap_tracer
+use MOM_horizontal_regridding, only : horiz_interp_and_extrap_tracer, homogenize_field
 use MOM_oda_incupd, only: oda_incupd_CS, initialize_oda_incupd_fixed, initialize_oda_incupd
 use MOM_oda_incupd, only: set_up_oda_incupd_field, set_up_oda_incupd_vel_field
 use MOM_oda_incupd, only: calc_oda_increments, output_oda_incupd_inc
@@ -2671,6 +2671,7 @@ subroutine MOM_temp_salt_initialize_from_Z(h, tv, depth_tot, G, GV, US, PF, just
                  "The vintage of the order of arithmetic for horizontal regridding.  "//&
                  "Dates before 20190101 give the same answers as the code did in late 2018, "//&
                  "while later versions add parentheses for rotational symmetry.  "//&
+                 "Dates after 20230101 use reproducing sums for global averages.  "//&
                  "If both HOR_REGRID_2018_ANSWERS and HOR_REGRID_ANSWER_DATE are specified, the "//&
                  "latter takes precedence.", default=default_hor_reg_ans_date, do_not_log=just_read)
 
@@ -2930,31 +2931,14 @@ subroutine MOM_temp_salt_initialize_from_Z(h, tv, depth_tot, G, GV, US, PF, just
       endif
     endif
 
-    call tracer_z_init_array(temp_z, z_edges_in, kd, zi, temp_land_fill, G, nz, nlevs, eps_z, &
-                             tv%T)
-    call tracer_z_init_array(salt_z, z_edges_in, kd, zi, salt_land_fill, G, nz, nlevs, eps_z, &
-                             tv%S)
+    call tracer_z_init_array(temp_z, z_edges_in, kd, zi, temp_land_fill, G, nz, nlevs, eps_z, tv%T)
+    call tracer_z_init_array(salt_z, z_edges_in, kd, zi, salt_land_fill, G, nz, nlevs, eps_z, tv%S)
 
     if (homogenize) then
       ! Horizontally homogenize data to produce perfectly "flat" initial conditions
       do k=1,nz
-        nPoints = 0 ; tempAvg = 0. ; saltAvg = 0.
-        do j=js,je ; do i=is,ie ; if (G%mask2dT(i,j) > 0.0) then
-          nPoints = nPoints + 1
-          tempAvg = tempAvg + tv%T(i,j,k)
-          saltAvg = saltAvg + tv%S(i,j,k)
-        endif ; enddo ; enddo
-
-        !### These averages will not reproduce across PE layouts or grid rotation.
-        call sum_across_PEs(nPoints)
-        call sum_across_PEs(tempAvg)
-        call sum_across_PEs(saltAvg)
-        if (nPoints>0) then
-          tempAvg = tempAvg / real(nPoints)
-          saltAvg = saltAvg / real(nPoints)
-        endif
-        tv%T(:,:,k) = tempAvg
-        tv%S(:,:,k) = saltAvg
+        call homogenize_field(tv%T(:,:,k), G%mask2dT, G, scale=US%degC_to_C, answer_date=hor_regrid_answer_date)
+        call homogenize_field(tv%S(:,:,k), G%mask2dT, G, scale=US%ppt_to_S, answer_date=hor_regrid_answer_date)
       enddo
     endif
 
diff --git a/src/initialization/MOM_tracer_initialization_from_Z.F90 b/src/initialization/MOM_tracer_initialization_from_Z.F90
index 7c62ea496e..bd77ec54d5 100644
--- a/src/initialization/MOM_tracer_initialization_from_Z.F90
+++ b/src/initialization/MOM_tracer_initialization_from_Z.F90
@@ -154,6 +154,7 @@ subroutine MOM_initialize_tracer_from_Z(h, tr, G, GV, US, PF, src_file, src_var_
                  "The vintage of the order of arithmetic for horizontal regridding.  "//&
                  "Dates before 20190101 give the same answers as the code did in late 2018, "//&
                  "while later versions add parentheses for rotational symmetry.  "//&
+                 "Dates after 20230101 use reproducing sums for global averages.  "//&
                  "If both HOR_REGRID_2018_ANSWERS and HOR_REGRID_ANSWER_DATE are specified, the "//&
                  "latter takes precedence.", default=default_hor_reg_ans_date)
 
diff --git a/src/parameterizations/vertical/MOM_ALE_sponge.F90 b/src/parameterizations/vertical/MOM_ALE_sponge.F90
index 7ff3bd3701..2e2a3edf07 100644
--- a/src/parameterizations/vertical/MOM_ALE_sponge.F90
+++ b/src/parameterizations/vertical/MOM_ALE_sponge.F90
@@ -255,6 +255,7 @@ subroutine initialize_ALE_sponge_fixed(Iresttime, G, GV, param_file, CS, data_h,
                  "The vintage of the order of arithmetic for horizontal regridding.  "//&
                  "Dates before 20190101 give the same answers as the code did in late 2018, "//&
                  "while later versions add parentheses for rotational symmetry.  "//&
+                 "Dates after 20230101 use reproducing sums for global averages.  "//&
                  "If both HOR_REGRID_2018_ANSWERS and HOR_REGRID_ANSWER_DATE are specified, the "//&
                  "latter takes precedence.", default=default_hor_reg_ans_date)
 
@@ -545,6 +546,7 @@ subroutine initialize_ALE_sponge_varying(Iresttime, G, GV, param_file, CS, Irest
                  "The vintage of the order of arithmetic for horizontal regridding.  "//&
                  "Dates before 20190101 give the same answers as the code did in late 2018, "//&
                  "while later versions add parentheses for rotational symmetry.  "//&
+                 "Dates after 20230101 use reproducing sums for global averages.  "//&
                  "If both HOR_REGRID_2018_ANSWERS and HOR_REGRID_ANSWER_DATE are specified, the "//&
                  "latter takes precedence.", default=default_hor_reg_ans_date)
   call get_param(param_file, mdl, "SPONGE_DATA_ONGRID", CS%spongeDataOngrid, &