Skip to content

Commit

Permalink
Fix failing tests
Browse files Browse the repository at this point in the history
  • Loading branch information
tomvothecoder committed Nov 8, 2023
1 parent 260f0ad commit 1472988
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 8 deletions.
5 changes: 3 additions & 2 deletions e3sm_to_cmip/cmor_handlers/_formulas.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,8 +167,9 @@ def mrfso(ds: xr.Dataset) -> xr.DataArray:

# Reconstruct the xarray.DataArray. Make sure not include "levgrnd" since
# it has been summed over.
coords = {dim: var[dim] for dim in var.dims if dim != "levgrnd"}
da = xr.DataArray(coords=coords, data=result, attrs=var.attrs)
dims = [dim for dim in var.dims if dim != "levgrnd"]
coords = {dim: var[dim] for dim in dims}
da = xr.DataArray(dims=dims, coords=coords, data=result, attrs=var.attrs)

return da

Expand Down
32 changes: 26 additions & 6 deletions tests/cmor_handlers/test__formulas.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,11 +184,20 @@ def test_mmrso4():


def test_mrfso():
ds = xr.Dataset(data_vars={"SOILICE": _dummy_dataarray()})
ds = xr.Dataset(
data_vars={
"SOILICE": xr.DataArray(
dims=["lat", "levgrnd"],
data=np.array(
[[0, 1, 5000], [0, 1, 6000], [0, 1, 4000]], dtype="float64"
),
)
}
)

result = mrfso(ds)
expected = xr.DataArray(
dims=["lat", "lon"], data=np.array([[3, 3, 3], [3, 3, 3], [3, 3, 3]])
dims=["lat"], data=np.array([5000, 5000, 4001]), coords={"lat": [0, 1, 2]}
)
xr.testing.assert_allclose(result, expected)

Expand All @@ -199,13 +208,24 @@ def test_mrfso():

def test_mrso():
ds = xr.Dataset(
data_vars={"SOILICE": _dummy_dataarray(), "SOILLIQ": _dummy_dataarray()}
data_vars={
"SOILICE": xr.DataArray(
dims=["lat", "levgrnd"],
data=np.array(
[[0, 1, 5000], [0, 1, 6000], [0, 1, 4000]], dtype="float64"
),
),
"SOILLIQ": xr.DataArray(
dims=["lat", "levgrnd"],
data=np.array(
[[0, 1, 5000], [0, 1, 6000], [0, 1, 4000]], dtype="float64"
),
),
}
)

result = mrso(ds)
expected = xr.DataArray(
dims=["lat", "lon"], data=np.array([[6, 6, 6], [6, 6, 6], [6, 6, 6]])
)
expected = xr.DataArray(dims=["lat"], data=np.array([5000, 5000, 5000]))
xr.testing.assert_allclose(result, expected)

# Test when required variable keys are NOT in the data dictionary.
Expand Down

0 comments on commit 1472988

Please sign in to comment.