Skip to content

Commit

Permalink
Fix nested Xarray attrs accidental overwrite with deepcopy().
Browse files Browse the repository at this point in the history
May need to propagate this solution elsewhere too.
  • Loading branch information
phockett committed Sep 23, 2022
1 parent fe42f98 commit e2ad5eb
Showing 1 changed file with 10 additions and 4 deletions.
14 changes: 10 additions & 4 deletions epsproc/sphFuncs/sphConv.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

import numpy as np
import xarray as xr
import copy # For attrs deepcopy.

from epsproc.util.listFuncs import genLM, YLMtype, YLMdimList

Expand Down Expand Up @@ -60,8 +61,9 @@ def SHcoeffsFromXR(dataIn, kind = None, keyDims = None):
# else:
# kind = kind

clm = pysh.SHCoeffs.from_zeros(lmax = dataIn['lDim'].max().values, kind = kind)
clm.set_coeffs(dataIn.values, dataIn['lDim'].astype(int), dataIn['mDim'].astype(int)) # NEEDS (values, ls, ms)
# Init zeros SHtools coeffs object & populate
clm = pysh.SHCoeffs.from_zeros(lmax = dataIn[lDim].max().values, kind = kind)
clm.set_coeffs(dataIn.values, dataIn[lDim].astype(int), dataIn[mDim].astype(int)) # NEEDS (values, ls, ms)

return clm

Expand Down Expand Up @@ -267,6 +269,8 @@ def sphRealConvert(dataIn, method = 'std', keyDims = None, incConj = True, rotPh
"""

dataCalc = dataIn.copy()
# dataCalc.attrs = dataIn.attrs.copy() # XR issue with attrs copy? Works for base dict, but not nested dicts.
dataCalc.attrs = copy.deepcopy(dataIn.attrs) # THIS WORKS ALSO FOR NESTED DICT CASE

#*** Basic dim handling
# TODO: should use checkDims here
Expand Down Expand Up @@ -331,7 +335,8 @@ def sphRealConvert(dataIn, method = 'std', keyDims = None, incConj = True, rotPh
# Set by coord assignment... Should match SHtools case.
dataC = xr.zeros_like(dataCalc)
# dataC = dataC.where(dataC[mDim]>-1, 1/np.sqrt(2)*(- 1j*dataCalc)) # -m case
dataC = dataC.where(dataC[mDim]>-1,((-1)**np.abs(dataCalc[mDim]))/np.sqrt(2)*(1j*dataCalc)) # -m case
# dataC = dataC.where(dataC[mDim]>-1,((-1)**np.abs(dataCalc[mDim]))/np.sqrt(2)*(1j*dataCalc)) # -m case
dataC = dataC.where(dataC[mDim]>-1,1/np.sqrt(2)*(1j*dataCalc)) # -m case, no additional (-1)^m phase term
# dataC3 = dataC3.where(dataC3.m<1,((-1)**np.abs(dataIn.m))/np.sqrt(2)*(dataIn + 1j*dataIn)) # +m case
dataC = dataC.where(dataC[mDim]<1,1/np.sqrt(2)*(dataCalc)) # +m case
dataC = dataC.where(dataC[mDim]!=0,dataCalc) # m=0 case
Expand All @@ -358,7 +363,8 @@ def sphRealConvert(dataIn, method = 'std', keyDims = None, incConj = True, rotPh

# Propagate attrs
dataC.attrs.update(dataCalc.attrs)
# dataC = checkSphDims(dataC, keyDims) # Will also unstack!
# dataC.attrs = dataCalc.attrs.copy()
# # dataC = checkSphDims(dataC, keyDims) # Will also unstack!
dataC.attrs['harmonics'].update(YLMtype(method={'sphRealConvert':method},incConj=incConj,keyDims=keyDims))

# dataC.attrs['harmonics'] = {'dtype':'Complex harmonics',
Expand Down

0 comments on commit e2ad5eb

Please sign in to comment.