
Commit

Update training_multiSample.py
yiwang12 authored Oct 3, 2024
1 parent 9131e21 commit fafe52c
Showing 1 changed file with 85 additions and 85 deletions.
170 changes: 85 additions & 85 deletions mNSF/training_multiSample.py
@@ -314,92 +314,92 @@ def _train_model_fixed_lr(self, list_tro,list_Dtrain, list_D__, ckpt_mgr,Dval=None
                               ptic = process_time(), wtic = time(), ckpt_freq=50, test_cvdNorm=False,
                               kernel_hp_update_freq=10, status_freq=10,
                               span=100, tol=1e-4, tol_norm=0.4, pickle_freq=None, check_convergence: bool = True):
"""
train_step
Dtrain, Dval : tensorflow Datasets produced by prepare_datasets_tf func
ckpt_mgr must store at least 2 checkpoints (max_to_keep)
Ntr: total number of training observations, needed to adjust KL term in ELBO
S: number of samples to approximate the ELBO
verbose: should status updates be printed
num_epochs: maximum passes through the data after which optimization will be stopped
ptic,wtic: process and wall time baselines
kernel_hp_update_freq: how often to update the kernel hyperparameters (eg every 10 epochs)
updating less than once per epoch improves speed but reduces numerical stability
status_freq: how often to check for convergence and print updates
ckpt_freq: how often to save tensorflow checkpoints to disk
span: when checking for convergence, how many recent observations to consider
tol: numerical (relative) change below which convergence is declared
pickle_freq: how often to save the entire object to disk as a pickle file
"""
        ptic, wtic = self.checkpoint(ckpt_mgr, process_time() - ptic, time() - wtic)
        self.loss["train"] = rpad(self.loss["train"], num_epochs + 1)
        if pickle_freq is None:  # only pickle at the end
            pickle_freq = num_epochs
        msg = '{:04d} train: {:.3e}'
        if Dval:
            msg += ', val: {:.3e}'
            self.loss["val"] = rpad(self.loss["val"], num_epochs + 1)
        msg2 = ""  # modified later to include rel_chg
        cvg = 0  # incremented each time convergence is detected
        cvg_normalized = 0
        cc = ConvergenceChecker(span)
        while (not self.converged) and (self.epoch < num_epochs):
            epoch_loss = tf.keras.metrics.Mean()
            chol = (self.epoch % kernel_hp_update_freq == 0)  # note: the train_step call below currently passes chol=True
            trl = 0.0
            nsample = len(list_Dtrain)
            for ksample in range(nsample):
                list_tro[ksample].model.Z = list_D__[ksample]["Z"]
                Dtrain_ksample = list_Dtrain[ksample]
                for D in Dtrain_ksample:  # iterate through each of the batches
                    epoch_loss.update_state(list_tro[ksample].model.train_step(D, list_tro[ksample].optimizer, list_tro[ksample].optimizer_k,
                                                                               Ntot=list_tro[ksample].model.delta.shape[1], chol=True))
                trl = trl + epoch_loss.result().numpy()
            # average the loadings matrix W across samples
            W_updated = list_tro[ksample].model.W - list_tro[ksample].model.W  # zero array with the shape of W
            for ksample in range(nsample):
                W_updated = W_updated + (list_tro[ksample].model.W / nsample)
            self.epoch.assign_add(1)
            i = self.epoch.numpy()
            self.loss["train"][i] = trl

            # check for NaN in any sample's loadings
            for fit_i in list_tro:
                if np.isnan(fit_i.model.W).any():
                    print('NaN in sample ' + str(list_tro.index(fit_i) + 1))

            if not np.isfinite(trl):
                print("training loss calculated at the point of divergence: ")
                print(trl)
                raise NumericalDivergenceError
            # alternative check (disabled): also treat an increase above the initial loss as divergence
            # if not np.isfinite(trl) or trl > self.loss["train"][1]:
            #     raise NumericalDivergenceError
            if i % status_freq == 0 or i == num_epochs:
                if Dval:
                    val_loss = self.model.validation_step(Dval, S=S, chol=False).numpy()
                    self.loss["val"][i] = val_loss
                if i > span and check_convergence:  # checking for convergence
                    rel_chg = cc.relative_change(self.loss["train"], idx=i)
                    print("rel_chg")
                    print(rel_chg)
                    msg2 = ", chg: {:.2e}".format(-rel_chg)
                    if abs(rel_chg) < tol:
                        cvg += 1
                    else:
                        cvg = 0
                    if test_cvdNorm:
                        rel_chg_normalized = cc.relative_chg_normalized(self.loss["train"], idx_current=i)
                        print("rel_chg_normalized")
                        print(rel_chg_normalized)
                        # a positive rel_chg_normalized indicates the loss increased over the recent iterations
                        if -rel_chg_normalized < tol_norm:
                            cvg_normalized += 1
                    if cvg >= 2 or cvg_normalized >= 2:  # either criterion has been met twice in a row
                        self.converged = True
                        pickle_freq = i  # ensures final pickling will happen
                        self.loss = truncate_history(self.loss, i)
                if verbose:
                    if Dval:
                        print(msg.format(i, trl, val_loss) + msg2)
                    else:
                        print(msg.format(i, trl) + msg2)
            if i % ckpt_freq == 0:
                ptic, wtic = self.checkpoint(ckpt_mgr, process_time() - ptic, time() - wtic)
            if self.pickle_path and i % pickle_freq == 0:
                ptic, wtic = self.pickle(process_time() - ptic, time() - wtic)
"""
train_step
Dtrain, Dval : tensorflow Datasets produced by prepare_datasets_tf func
ckpt_mgr must store at least 2 checkpoints (max_to_keep)
Ntr: total number of training observations, needed to adjust KL term in ELBO
S: number of samples to approximate the ELBO
verbose: should status updates be printed
num_epochs: maximum passes through the data after which optimization will be stopped
ptic,wtic: process and wall time baselines
kernel_hp_update_freq: how often to update the kernel hyperparameters (eg every 10 epochs)
updating less than once per epoch improves speed but reduces numerical stability
status_freq: how often to check for convergence and print updates
ckpt_freq: how often to save tensorflow checkpoints to disk
span: when checking for convergence, how many recent observations to consider
tol: numerical (relative) change below which convergence is declared
pickle_freq: how often to save the entire object to disk as a pickle file
"""
ptic,wtic = self.checkpoint(ckpt_mgr, process_time()-ptic, time()-wtic)
if self.pickle_path and i%pickle_freq==0:
ptic,wtic = self.pickle(process_time()-ptic, time()-wtic)
self.loss["train"] = rpad(self.loss["train"],num_epochs+1)
if pickle_freq is None: #only pickle at the end
pickle_freq = num_epochs
msg = '{:04d} train: {:.3e}'
if Dval:
msg += ', val: {:.3e}'
self.loss["val"] = rpad(self.loss["val"],num_epochs+1)
msg2 = "" #modified later to include rel_chg
cvg = 0 #increment each time we think it has converged
cvg_normalized=0
cc = ConvergenceChecker(span)
while (not self.converged) and (self.epoch < num_epochs):
epoch_loss = tf.keras.metrics.Mean()
chol=(self.epoch % kernel_hp_update_freq==0)
trl=0.0
nsample=len(list_Dtrain)
for ksample in range(0,nsample):
list_tro[ksample].model.Z=list_D__[ksample]["Z"]
Dtrain_ksample = list_Dtrain[ksample]
for D in Dtrain_ksample: #iterate through each of the batches
epoch_loss.update_state(list_tro[ksample].model.train_step( D, list_tro[ksample].optimizer, list_tro[ksample].optimizer_k,
Ntot=list_tro[ksample].model.delta.shape[1], chol=True))
trl = trl + epoch_loss.result().numpy()
W_updated=list_tro[ksample].model.W-list_tro[ksample].model.W
for ksample in range(0,nsample):
W_updated = W_updated+ (list_tro[ksample].model.W / nsample)
self.epoch.assign_add(1)
i = self.epoch.numpy()
self.loss["train"][i] = trl

## check for nan in any sample loadings
for fit_i in list_tro:
if np.isnan(fit_i.model.W).any():
print('NaN in sample ' + str(list_tro.index(fit_i) + 1))

if not np.isfinite(trl): ### modified
print("training loss calculated at the point of divergence: ")
print(trl)
raise NumericalDivergenceError###!!!NumericalDivergenceError
#if not np.isfinite(trl) or trl>self.loss["train"][1]: ### modified
# raise NumericalDivergenceError###!!!NumericalDivergenceError
if i%status_freq==0 or i==num_epochs:
if Dval:
val_loss = self.model.validation_step(Dval, S=S, chol=False).numpy()
self.loss["val"][i] = val_loss
if i>span and check_convergence: #checking for convergence
rel_chg = cc.relative_change(self.loss["train"],idx=i)
print("rel_chg")
print(rel_chg)
msg2 = ", chg: {:.2e}".format(-rel_chg)
if abs(rel_chg)<tol: cvg+=1
else: cvg=0
if test_cvdNorm:
rel_chg_normalized=cc.relative_chg_normalized(self.loss["train"],idx_current=i)
print("rel_chg_normalized")
print(rel_chg_normalized)
if(-(rel_chg_normalized)<tol_norm): cvg_normalized+=1 # positive values of rel_chg_normalized indicates increase of loss throughout the past 10 iterations
if cvg>=2 or cvg_normalized>=2: #i.e. either convergence or normalized convergence has been detected twice in a row
self.converged=True
pickle_freq = i #ensures final pickling will happen
self.loss = truncate_history(self.loss, i)
if verbose:
if Dval: print(msg.format(i,trl,val_loss)+msg2)
else: print(msg.format(i,trl)+msg2)
if i%ckpt_freq==0:
ptic,wtic = self.checkpoint(ckpt_mgr, process_time()-ptic, time()-wtic)
if self.pickle_path and i%pickle_freq==0:
ptic,wtic = self.pickle(process_time()-ptic, time()-wtic)

    def find_checkpoint(self, ckpt_freq, back=1, epoch0=0):
        """
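Note on the convergence logic in the hunk above: every status_freq epochs, once more than span training losses have accumulated, ConvergenceChecker.relative_change summarizes the recent loss trajectory, and the loop stops after the relative change stays below tol on two consecutive checks (the test_cvdNorm branch adds an analogous normalized criterion). The snippet below is a minimal, self-contained sketch of that stopping rule; the relative_change stand-in here (slope of a least-squares fit over the last span losses, scaled by the mean loss magnitude) is an illustrative assumption, not the actual ConvergenceChecker implementation.

import numpy as np

def relative_change(losses, idx, span=100):
    # Illustrative stand-in for ConvergenceChecker.relative_change (assumed, not the real code):
    # slope of a least-squares line over the last `span` losses, scaled by their mean magnitude.
    window = np.asarray(losses[idx - span + 1: idx + 1], dtype=float)
    slope = np.polyfit(np.arange(span), window, 1)[0]
    return slope / np.mean(np.abs(window))

def has_converged(losses, span=100, status_freq=10, tol=1e-4):
    # Mirrors the stopping rule in _train_model_fixed_lr: convergence is declared once the
    # relative change is below tol on two consecutive status checks.
    cvg = 0
    for i in range(span + 1, len(losses)):
        if i % status_freq == 0:
            cvg = cvg + 1 if abs(relative_change(losses, i, span)) < tol else 0
            if cvg >= 2:
                return True, i
    return False, None

# toy usage: a loss curve that decays and then flattens out
losses = list(2.0 * np.exp(-0.05 * np.arange(400)) + 1.0)
print(has_converged(losses))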
