From e2ad5eb0f5b51dd64c40b8c8a14fe1e2991df360 Mon Sep 17 00:00:00 2001 From: phockett Date: Fri, 23 Sep 2022 13:50:10 -0400 Subject: [PATCH] Fix nested Xarray attrs accidental overwrite with deepcopy(). May need to propagate this solution elsewhere too. --- epsproc/sphFuncs/sphConv.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/epsproc/sphFuncs/sphConv.py b/epsproc/sphFuncs/sphConv.py index eba8fe6..1fcf280 100644 --- a/epsproc/sphFuncs/sphConv.py +++ b/epsproc/sphFuncs/sphConv.py @@ -11,6 +11,7 @@ import numpy as np import xarray as xr +import copy # For attrs deepcopy. from epsproc.util.listFuncs import genLM, YLMtype, YLMdimList @@ -60,8 +61,9 @@ def SHcoeffsFromXR(dataIn, kind = None, keyDims = None): # else: # kind = kind - clm = pysh.SHCoeffs.from_zeros(lmax = dataIn['lDim'].max().values, kind = kind) - clm.set_coeffs(dataIn.values, dataIn['lDim'].astype(int), dataIn['mDim'].astype(int)) # NEEDS (values, ls, ms) + # Init zeros SHtools coeffs object & populate + clm = pysh.SHCoeffs.from_zeros(lmax = dataIn[lDim].max().values, kind = kind) + clm.set_coeffs(dataIn.values, dataIn[lDim].astype(int), dataIn[mDim].astype(int)) # NEEDS (values, ls, ms) return clm @@ -267,6 +269,8 @@ def sphRealConvert(dataIn, method = 'std', keyDims = None, incConj = True, rotPh """ dataCalc = dataIn.copy() + # dataCalc.attrs = dataIn.attrs.copy() # XR issue with attrs copy? Works for base dict, but not nested dicts. + dataCalc.attrs = copy.deepcopy(dataIn.attrs) # THIS WORKS ALSO FOR NESTED DICT CASE #*** Basic dim handling # TODO: should use checkDims here @@ -331,7 +335,8 @@ def sphRealConvert(dataIn, method = 'std', keyDims = None, incConj = True, rotPh # Set by coord assignment... Should match SHtools case. dataC = xr.zeros_like(dataCalc) # dataC = dataC.where(dataC[mDim]>-1, 1/np.sqrt(2)*(- 1j*dataCalc)) # -m case - dataC = dataC.where(dataC[mDim]>-1,((-1)**np.abs(dataCalc[mDim]))/np.sqrt(2)*(1j*dataCalc)) # -m case + # dataC = dataC.where(dataC[mDim]>-1,((-1)**np.abs(dataCalc[mDim]))/np.sqrt(2)*(1j*dataCalc)) # -m case + dataC = dataC.where(dataC[mDim]>-1,1/np.sqrt(2)*(1j*dataCalc)) # -m case, no additional (-1)^m phase term # dataC3 = dataC3.where(dataC3.m<1,((-1)**np.abs(dataIn.m))/np.sqrt(2)*(dataIn + 1j*dataIn)) # +m case dataC = dataC.where(dataC[mDim]<1,1/np.sqrt(2)*(dataCalc)) # +m case dataC = dataC.where(dataC[mDim]!=0,dataCalc) # m=0 case @@ -358,7 +363,8 @@ def sphRealConvert(dataIn, method = 'std', keyDims = None, incConj = True, rotPh # Propagate attrs dataC.attrs.update(dataCalc.attrs) - # dataC = checkSphDims(dataC, keyDims) # Will also unstack! + # dataC.attrs = dataCalc.attrs.copy() + # # dataC = checkSphDims(dataC, keyDims) # Will also unstack! dataC.attrs['harmonics'].update(YLMtype(method={'sphRealConvert':method},incConj=incConj,keyDims=keyDims)) # dataC.attrs['harmonics'] = {'dtype':'Complex harmonics',