Merge pull request #4 from feiyulu/main
small fixes for demo
johannag126 authored May 14, 2021
2 parents ed169e0 + 8dbc2df commit 1824404
Showing 1 changed file with 98 additions and 108 deletions.
206 changes: 98 additions & 108 deletions DA_demo_L96.ipynb
@@ -2,7 +2,8 @@
"cells": [
{
"cell_type": "code",
"execution_count": 89,
"execution_count": null,
"id": "180fa13f",
"metadata": {},
"outputs": [],
"source": [
@@ -12,15 +13,14 @@
"from L96_model import L96, L96s, L96_eq1_xdot, L96_2t_xdot_ydot, RK4\n",
"import time\n",
"import os\n",
"from numba import jit\n",
"\n",
"rng=np.random.default_rng()\n",
"\n",
"config=dict(K=40,J=10,obs_freq=10,\n",
" F_truth=10,#+np.concatenate((np.linspace(-1.8,2,20),np.linspace(1.8,-2,20))),\n",
" F_fcst=10,#+np.concatenate((np.linspace(-1.8,2,20),np.linspace(2,-1.8,20))),\n",
" GCM_param=np.array([0,0,0,0]),ns_da=20000,\n",
" ns=20000,ns_spinup=200,dt=0.005,si=0.005,B_loc=5,DA='EnKF',nens=100,\n",
" GCM_param=np.array([0,0,0,0]),ns_da=2000,\n",
" ns=2000,ns_spinup=200,dt=0.005,si=0.005,B_loc=5,DA='EnKF',nens=100,\n",
" inflate_opt=\"relaxation\",inflate_factor=0.2,hybrid_factor=0.1,\n",
" param_DA=False,param_sd=[0.01,0.02,0.1,0.5],param_inflate='multiplicative',param_inf_factor=0.02,\n",
" obs_density=0.2,DA_freq=10,obs_sigma=0.5,\n",
@@ -57,10 +57,9 @@
},
{
"cell_type": "code",
"execution_count": 81,
"metadata": {
"collapsed": true
},
"execution_count": null,
"id": "93f98ded",
"metadata": {},
"outputs": [],
"source": [
"def s(k,K):\n",
@@ -111,7 +110,7 @@
" return weight\n",
"\n",
"def find_obs(loc,obs,t_obs,l_obs,period):\n",
" t_period=np.where((t_obs[:,0]>period[0]) & (t_obs[:,0]<=period[1]))\n",
" t_period=np.where((t_obs[:,0]>=period[0]) & (t_obs[:,0]<period[1]))\n",
" obs_period=np.zeros(t_period[0].shape)\n",
" obs_period[:]=np.nan\n",
" for i in np.arange(len(obs_period)):\n",
@@ -133,9 +132,9 @@
},
{
"cell_type": "code",
"execution_count": 82,
"execution_count": null,
"id": "3c1fd0a9",
"metadata": {
"collapsed": true,
"scrolled": false
},
"outputs": [],
@@ -150,20 +149,20 @@
"# Run L96 to generate the \"truth\"\n",
"M_truth.set_state(X_init, Y_init)\n",
"\n",
"# Give F a \"seasonal cycle\" in the truth model\n",
"ann_period=2000\n",
"mon_period=100\n",
"mon_per_ann=ann_period/mon_period\n",
"X_truth,Y_truth,t_truth = M_truth.run(config['si'], config['si']*mon_period)\n",
"for i in range(1,int(config['ns']/mon_period)):\n",
" M_truth.set_state(X_truth[-1,...], Y_truth[-1,...])\n",
" M_truth.set_param(F=config['F_truth']+2*np.sin(2*np.pi*i/mon_per_ann))\n",
" X_step,Y_step,t_step = M_truth.run(config['si'], config['si']*mon_period)\n",
" X_truth=np.concatenate((X_truth,X_step[1:None,...]))\n",
" Y_truth=np.concatenate((Y_truth,Y_step[1:None,...]))\n",
" t_truth=np.concatenate((t_truth,t_truth[-1]+t_step[1:None]))\n",
"\n",
"# X_truth,Y_truth,t_truth = M_truth.run(config['si'], config['si']*config['ns'])\n",
"# # Give F a \"seasonal cycle\" in the truth model\n",
"# ann_period=2000\n",
"# mon_period=100\n",
"# mon_per_ann=ann_period/mon_period\n",
"# X_truth,Y_truth,t_truth = M_truth.run(config['si'], config['si']*mon_period)\n",
"# for i in range(1,int(config['ns']/mon_period)):\n",
"# M_truth.set_state(X_truth[-1,...], Y_truth[-1,...])\n",
"# M_truth.set_param(F=config['F_truth']+2*np.sin(2*np.pi*i/mon_per_ann))\n",
"# X_step,Y_step,t_step = M_truth.run(config['si'], config['si']*mon_period)\n",
"# X_truth=np.concatenate((X_truth,X_step[1:None,...]))\n",
"# Y_truth=np.concatenate((Y_truth,Y_step[1:None,...]))\n",
"# t_truth=np.concatenate((t_truth,t_truth[-1]+t_step[1:None]))\n",
"\n",
"X_truth,Y_truth,t_truth = M_truth.run(config['si'], config['si']*config['ns'])\n",
"\n",
"# # generate climatological background covariance for 2-scale L96 model\n",
"# B_clim = np.cov(X_truth.T)\n",
@@ -195,10 +194,9 @@
},
{
"cell_type": "code",
"execution_count": 83,
"metadata": {
"collapsed": true
},
"execution_count": null,
"id": "3ffa2f5d",
"metadata": {},
"outputs": [],
"source": [
"# Sample the \"truth\" to generate observations at certain times (t_obs) and locations (l_obs)\n",
@@ -212,37 +210,56 @@
"# Calculated observation covariance matrix, assuming independent observations\n",
"R = config['obs_sigma']**2*np.eye(int(config['K']*config['obs_density']))\n",
"\n",
"# plt.figure(figsize=[6,4])\n",
"# plt.scatter(t_obs,X_obs)"
"plt.figure(figsize=[10,6])\n",
"plt.plot(range(1000,1500),X_truth[1000:1500,0],label='truth')\n",
"plt.scatter(t_obs[100:150,0],find_obs(0,X_obs,t_obs,l_obs,[t_obs[100,0],t_obs[150,0]]),color='k',label='obs')\n",
"plt.legend()"
]
},
{
"cell_type": "code",
"execution_count": 84,
"metadata": {
"scrolled": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(200, 40, 100)\n",
"29.788599014282227\n"
]
}
],
"execution_count": null,
"id": "c9316b7d",
"metadata": {},
"outputs": [],
"source": [
"import DA_methods\n",
"import importlib\n",
"importlib.reload(DA_methods)\n",
"\n",
"t0 = time.time()\n",
"\n",
"# load pre-calculated climatological background covariance matrix from a long simulation\n",
"B_clim=np.load('B_clim_L96s.npy')\n",
"B_clim1=np.load('B_clim_L96s.npy')\n",
"B_loc,W_clim=cov_loc(B_clim,loc=config['B_loc'])\n",
"\n",
"B_clim2=np.load('B_clim_L96.npy')\n",
"B_corr1=np.zeros(B_clim1.shape)\n",
"B_corr2=np.zeros(B_clim2.shape)\n",
"for i in range(B_clim1.shape[0]):\n",
" for j in range(B_clim1.shape[1]):\n",
" B_corr1[i,j]=B_clim1[i,j]/np.sqrt(B_clim1[i,i]*B_clim1[j,j])\n",
" B_corr2[i,j]=B_clim2[i,j]/np.sqrt(B_clim2[i,i]*B_clim2[j,j])\n",
" \n",
"plt.figure(figsize=(16,6))\n",
"plt.subplot(121)\n",
"plt.contourf(B_corr1,cmap='bwr',extend='both',levels=np.linspace(-0.95,0.95,20))\n",
"plt.colorbar()\n",
"plt.title('Background correlation matrix 1-scale L96')\n",
"plt.subplot(122)\n",
"plt.contourf(B_corr2,cmap='bwr',extend='both',levels=np.linspace(-0.95,0.95,20))\n",
"plt.colorbar()\n",
"plt.title('Background correlation matrix 2-scale L96')"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "58124d8e",
"metadata": {
"scrolled": false
},
"outputs": [],
"source": [
"t0 = time.time()\n",
"# set up array to store DA increments\n",
"X_inc=np.zeros((int(config['ns_da']/config['DA_freq']),config['K'],config['nens']))\n",
"if config['DA']=='3DVar':\n",
@@ -254,13 +271,13 @@
"ensX=X_init[None,:,None]+rng.standard_normal((1,config['K'],config['nens']))\n",
"X_post=ensX[0,...]\n",
"\n",
"\n",
"if config['param_DA']:\n",
" mean_param=np.zeros((int(config['ns_da']/config['DA_freq']),len(config['GCM_param'])))\n",
" spread_param=np.zeros((int(config['ns_da']/config['DA_freq']),len(config['GCM_param'])))\n",
" param_scale=config['param_sd']\n",
" W=np.ones((config['K']+len(config['GCM_param']),config['K']+len(config['GCM_param'])))\n",
" W[0:config['K'],0:config['K']]=W_clim\n",
" \n",
"else: \n",
" W=W_clim\n",
" param_scale=np.zeros(config['GCM_param'].shape)\n",
@@ -276,10 +293,12 @@
" # set up array to store model forecast for each DA cycle\n",
" ensX_fcst=np.zeros((config['DA_freq']+1,config['K'],config['nens']))\n",
"\n",
" # model forecast for next DA cycle\n",
" for n in range(config['nens']):\n",
" ensX_fcst[...,n] = GCM(X_post[0:config['K'],n], config['F_fcst'], config['dt'], config['DA_freq'], ens_param[:,n])[0]\n",
" i_t=i_t+config['DA_freq']\n",
"\n",
" # get prior/background from the forecast\n",
" X_prior=ensX_fcst[-1,...] # get prior from model integration\n",
" \n",
" # call DA\n",
@@ -303,7 +322,7 @@
" ens_param=X_post[-len(config['GCM_param']):None,:]\n",
" elif config['DA']=='HyEnKF':\n",
" H=ObsOp(config['K'],l_obs,t_obs,i_t)\n",
" B_ens = np.cov(X_prior)*(1-config['hybrid_factor'])+B_clim*config['hybrid_factor']\n",
" B_ens = np.cov(X_prior)*(1-config['hybrid_factor'])+B_clim1*config['hybrid_factor']\n",
" B_ens_loc = B_ens*W\n",
" X_post=DA_methods.EnKF(X_prior,X_obs[t_obs==i_t],H,R,B_ens_loc)\n",
" X_post=DA_methods.ens_inflate(X_post,X_prior,config['inflate_opt'],config['inflate_factor'])\n",
@@ -339,7 +358,8 @@
},
{
"cell_type": "code",
"execution_count": 85,
"execution_count": null,
"id": "bdc0b556",
"metadata": {
"scrolled": false
},
@@ -361,32 +381,32 @@
"axes[0,1].set_xlabel('time'); axes[0,1].set_title('RMSE (X - X_truth)');\n",
"axes[0,1].grid(which='both',linestyle='--')\n",
"\n",
"# axes[0,2].plot(M_truth.k, np.sqrt(((meanX-X_truth[0:(config['ns_da']+1),:])**2).mean(axis=0)),label='RMSE'); \n",
"# X_inc_ave=X_inc/config['DA_freq']/config['si']\n",
"# axes[0,2].plot(M_truth.k, X_inc_ave.mean(axis=(0,-1)),label='Inc'); \n",
"# axes[0,2].plot(M_truth.k, running_ave(X_inc_ave.mean(axis=(0,-1)),7),label='Inc Ave'); \n",
"# axes[0,2].plot(M_truth.k, np.ones(M_truth.k.shape)*(config['F_fcst']-config['F_truth']),label='F_bias'); \n",
"# axes[0,2].plot(M_truth.k, np.ones(M_truth.k.shape)*(X_inc/config['DA_freq']/config['si']).mean(),'k:',label='Ave Inc'); \n",
"# axes[0,2].legend()\n",
"# axes[0,2].set_xlabel('s'); axes[0,2].set_title('Increments');\n",
"# axes[0,2].grid(which='both',linestyle='--')\n",
"\n",
"X_inc_ave=(X_inc/config['DA_freq']/config['si']).mean(axis=(1,2)).\\\n",
" reshape(int(config['ns_da']/ann_period),int(ann_period/config['DA_freq'])).mean(axis=0)\n",
"axes[0,2].plot(np.arange(ann_period/config['DA_freq']),X_inc_ave,label='Inc')\n",
"axes[0,2].plot(np.arange(ann_period/config['DA_freq']),running_ave(X_inc_ave,10),label='Inc Ave');\n",
"axes[0,2].plot(np.arange(0,ann_period/config['DA_freq'],mon_period/config['DA_freq']),\n",
" -2*np.sin(2*np.pi*np.arange(mon_per_ann)/mon_per_ann),label='F_bias')\n",
"axes[0,2].plot(M_truth.k, np.sqrt(((meanX-X_truth[0:(config['ns_da']+1),:])**2).mean(axis=0)),label='RMSE'); \n",
"X_inc_ave=X_inc/config['DA_freq']/config['si']\n",
"axes[0,2].plot(M_truth.k, X_inc_ave.mean(axis=(0,-1)),label='Inc'); \n",
"axes[0,2].plot(M_truth.k, running_ave(X_inc_ave.mean(axis=(0,-1)),7),label='Inc Ave'); \n",
"axes[0,2].plot(M_truth.k, np.ones(M_truth.k.shape)*(config['F_fcst']-config['F_truth']),label='F_bias'); \n",
"axes[0,2].plot(M_truth.k, np.ones(M_truth.k.shape)*(X_inc/config['DA_freq']/config['si']).mean(),'k:',label='Ave Inc'); \n",
"axes[0,2].legend()\n",
"axes[0,2].set_xlabel('\"annual cycle\"'); axes[0,2].set_title('Increments');\n",
"axes[0,2].set_xlabel('s'); axes[0,2].set_title('Increments');\n",
"axes[0,2].grid(which='both',linestyle='--')\n",
"\n",
"plot_start,plot_end=1000, 1400\n",
"# X_inc_ave=(X_inc/config['DA_freq']/config['si']).mean(axis=(1,2)).\\\n",
"# reshape(int(config['ns_da']/ann_period),int(ann_period/config['DA_freq'])).mean(axis=0)\n",
"# axes[0,2].plot(np.arange(ann_period/config['DA_freq']),X_inc_ave,label='Inc')\n",
"# axes[0,2].plot(np.arange(ann_period/config['DA_freq']),running_ave(X_inc_ave,10),label='Inc Ave');\n",
"# axes[0,2].plot(np.arange(0,ann_period/config['DA_freq'],mon_period/config['DA_freq']),\n",
"# -2*np.sin(2*np.pi*np.arange(mon_per_ann)/mon_per_ann),label='F_bias')\n",
"# axes[0,2].legend()\n",
"# axes[0,2].set_xlabel('\"annual cycle\"'); axes[0,2].set_title('Increments');\n",
"# axes[0,2].grid(which='both',linestyle='--')\n",
"\n",
"plot_start,plot_end=1000, 1500\n",
"plot_start_DA, plot_end_DA=int(plot_start/config['DA_freq']), int(plot_end/config['DA_freq'])\n",
"plot_x=0\n",
"axes[1,0].plot(t_truth[plot_start:plot_end],X_truth[plot_start:plot_end,plot_x],label='truth')\n",
"axes[1,0].plot(t_truth[plot_start:plot_end],meanX[plot_start:plot_end,plot_x],label='forecast')\n",
"axes[1,0].scatter(t_DA[plot_start_DA:plot_end_DA],find_obs(plot_x,X_obs,t_obs,l_obs,[plot_start,plot_end]),label='obs')\n",
"axes[1,0].scatter(t_DA[plot_start_DA-1:plot_end_DA-1],find_obs(plot_x,X_obs,t_obs,l_obs,[plot_start,plot_end]),label='obs')\n",
"axes[1,0].grid(which='both',linestyle='--')\n",
"axes[1,0].set_xlabel('time'); axes[1,0].set_title('k='+str(plot_x+1)+' truth and forecast');\n",
"axes[1,0].legend()\n",
@@ -401,10 +421,10 @@
" axes[1,1].legend()\n",
" axes[1,1].grid(which='both',linestyle='--')\n",
"\n",
"axes[1,2].text(0.1,0.1,'GCM param={}\\nRMSE={:3f}\\nSpread={:3f}\\nDA={},{},{}\\nDA_freq={}\\nB_loc={}\\ninflation={},{}\\nobs_density={}\\nobs_sigma={}\\nobs_freq={}'.\\\n",
" format(config['GCM_param'],np.sqrt(((meanX-X_truth[0:(config['ns_da']+1),:])**2).mean()),\n",
" np.mean(np.std(ensX,axis=-1)),config['DA'],\n",
" config['nens'],config['hybrid_factor'],config['DA_freq'],config['B_loc'],\n",
"axes[1,2].text(0.1,0.1,'RMSE={:3f}\\nSpread={:3f}\\nGCM param={}\\nDA={},{}\\nDA_freq={}\\nB_loc={}\\ninflation={},{}\\nobs_density={}\\nobs_sigma={}\\nobs_freq={}'.\\\n",
" format(np.sqrt(((meanX-X_truth[0:(config['ns_da']+1),:])**2).mean()),\n",
" np.mean(np.std(ensX,axis=-1)),config['DA'],config['GCM_param'],\n",
" config['nens'],config['DA_freq'],config['B_loc'],\n",
" config['inflate_opt'],config['inflate_factor'],config['obs_density'],config['obs_sigma'],\n",
" config['obs_freq']),\n",
" fontsize=15)\n",
@@ -422,39 +442,9 @@
},
{
"cell_type": "code",
"execution_count": 86,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# B_clim1=np.load('B_clim_L96s.npy')\n",
"# B_clim2=np.load('B_clim_L96.npy')\n",
"# B_corr1=np.zeros(B_clim1.shape)\n",
"# B_corr2=np.zeros(B_clim2.shape)\n",
"# for i in range(40):\n",
"# for j in range(40):\n",
"# B_corr1[i,j]=B_clim1[i,j]/np.sqrt(B_clim1[i,i]*B_clim1[j,j])\n",
"# B_corr2[i,j]=B_clim2[i,j]/np.sqrt(B_clim2[i,i]*B_clim2[j,j])\n",
" \n",
"# print(B_corr)\n",
"# plt.figure(figsize=(16,6))\n",
"# plt.subplot(121)\n",
"# plt.contourf(B_corr1,cmap='bwr',extend='both',levels=np.linspace(-0.95,0.95,20))\n",
"# plt.colorbar()\n",
"# plt.title('Background correlation matrix 1-scale L96')\n",
"# plt.subplot(122)\n",
"# plt.contourf(B_corr2,cmap='bwr',extend='both',levels=np.linspace(-0.95,0.95,20))\n",
"# plt.colorbar()\n",
"# plt.title('Background correlation matrix 2-scale L96')"
]
},
{
"cell_type": "code",
"execution_count": 87,
"metadata": {
"collapsed": true
},
"execution_count": null,
"id": "17803f0e",
"metadata": {},
"outputs": [],
"source": [
"#save DA output for further analysis\n",
Expand Down Expand Up @@ -493,7 +483,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.1"
"version": "3.8.6"
}
},
"nbformat": 4,
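Note on the find_obs change above: the observation window moves from the half-open interval (start, end] to [start, end), consistent with the obs scatter in the diagnostics cell shifting to t_DA[plot_start_DA-1:plot_end_DA-1]. A minimal standalone sketch of the two conventions (the helper names and observation times below are invented for illustration and are not part of the notebook):

import numpy as np

# hypothetical observation times, one column as in the notebook's t_obs array
t_obs = np.arange(10.0, 101.0, 10.0).reshape(-1, 1)

def window_old(t_obs, period):
    # convention before this commit: (start, end]
    return np.where((t_obs[:, 0] > period[0]) & (t_obs[:, 0] <= period[1]))[0]

def window_new(t_obs, period):
    # convention after this commit: [start, end)
    return np.where((t_obs[:, 0] >= period[0]) & (t_obs[:, 0] < period[1]))[0]

print(window_old(t_obs, [10, 50]))  # picks the obs at t = 20, 30, 40, 50
print(window_new(t_obs, [10, 50]))  # picks the obs at t = 10, 20, 30, 40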
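The newly activated background-correlation cell normalizes each loaded covariance matrix entry by entry, B_corr[i,j] = B_clim[i,j] / sqrt(B_clim[i,i] * B_clim[j,j]). Below is an equivalent vectorized sketch; the random matrix is only a stand-in for the B_clim_L96s.npy / B_clim_L96.npy files, which are not included on this page:

import numpy as np

rng = np.random.default_rng(0)
B = np.cov(rng.standard_normal((40, 500)))  # stand-in for a 40x40 climatological covariance
d = np.sqrt(np.diag(B))
B_corr = B / np.outer(d, d)                 # B_corr[i,j] = B[i,j] / sqrt(B[i,i]*B[j,j])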