
Commit

Update training_multiSample.py
yiwang12 authored Oct 3, 2024
1 parent 1c2a75b commit 50fdb61
Showing 1 changed file with 108 additions and 84 deletions.
192 changes: 108 additions & 84 deletions mNSF/training_multiSample.py
@@ -308,96 +308,120 @@ def from_pickle(pth, epoch=None):
fname = "converged.pickle"
return unpickle_from_file(path.join(pth, fname))

def _train_model_fixed_lr(self, list_tro, list_Dtrain, list_D__, ckpt_mgr, Dval=None,  # Ntr=None,
                          S=3,
                          verbose=True, num_epochs=500,
                          ptic=process_time(), wtic=time(), ckpt_freq=50, test_cvdNorm=False,
                          kernel_hp_update_freq=10, status_freq=10,
                          span=100, tol=1e-4, tol_norm=0.4, pickle_freq=None, check_convergence: bool = True):
"""
train_step
Dtrain, Dval : tensorflow Datasets produced by prepare_datasets_tf func
ckpt_mgr must store at least 2 checkpoints (max_to_keep)
Ntr: total number of training observations, needed to adjust KL term in ELBO
S: number of samples to approximate the ELBO
verbose: should status updates be printed
num_epochs: maximum passes through the data after which optimization will be stopped
ptic,wtic: process and wall time baselines
kernel_hp_update_freq: how often to update the kernel hyperparameters (eg every 10 epochs)
updating less than once per epoch improves speed but reduces numerical stability
status_freq: how often to check for convergence and print updates
ckpt_freq: how often to save tensorflow checkpoints to disk
span: when checking for convergence, how many recent observations to consider
tol: numerical (relative) change below which convergence is declared
pickle_freq: how often to save the entire object to disk as a pickle file
"""
    ptic, wtic = self.checkpoint(ckpt_mgr, process_time() - ptic, time() - wtic)
    self.loss["train"] = rpad(self.loss["train"], num_epochs + 1)
    if pickle_freq is None:  # only pickle at the end
        pickle_freq = num_epochs
    msg = '{:04d} train: {:.3e}'
    if Dval:
        msg += ', val: {:.3e}'
        self.loss["val"] = rpad(self.loss["val"], num_epochs + 1)
    msg2 = ""  # modified later to include rel_chg
    cvg = 0  # incremented each time convergence appears to have been reached
    cvg_normalized = 0
    cc = ConvergenceChecker(span)
    while (not self.converged) and (self.epoch < num_epochs):
        epoch_loss = tf.keras.metrics.Mean()
        chol = (self.epoch % kernel_hp_update_freq == 0)
        trl = 0.0
        nsample = len(list_Dtrain)
        for ksample in range(0, nsample):
            # assign the k-th sample's inducing point locations to its model
            list_tro[ksample].model.Z = list_D__[ksample]["Z"]
            Dtrain_ksample = list_Dtrain[ksample]
            for D in Dtrain_ksample:  # iterate through each of the batches
                epoch_loss.update_state(list_tro[ksample].model.train_step(
                    D, list_tro[ksample].optimizer, list_tro[ksample].optimizer_k,
                    Ntot=list_tro[ksample].model.delta.shape[1], chol=True))
"""train_step
Dtrain, Dval : tensorflow Datasets produced by prepare_datasets_tf func
ckpt_mgr must store at least 2 checkpoints (max_to_keep)
Ntr: total number of training observations, needed to adjust KL term in ELBO
S: number of samples to approximate the ELBO
verbose: should status updates be printed
num_epochs: maximum passes through the data after which optimization will be stopped
ptic,wtic: process and wall time baselines
kernel_hp_update_freq: how often to update the kernel hyperparameters (eg every 10 epochs)
updating less than once per epoch improves speed but reduces numerical stability
status_freq: how often to check for convergence and print updates
ckpt_freq: how often to save tensorflow checkpoints to disk
span: when checking for convergence, how many recent observations to consider
tol: numerical (relative) change below which convergence is declared
pickle_freq: how often to save the entire object to disk as a pickle file
"""
ptic,wtic = self.checkpoint(ckpt_mgr, process_time()-ptic, time()-wtic)
self.loss["train"] = rpad(self.loss["train"],num_epochs+1)
if pickle_freq is None: #only pickle at the end
pickle_freq = num_epochs
msg = '{:04d} train: {:.3e}'
if Dval:
msg += ', val: {:.3e}'
self.loss["val"] = rpad(self.loss["val"],num_epochs+1)
msg2 = "" #modified later to include rel_chg
cvg = 0 #increment each time we think it has converged
cvg_normalized=0
cc = ConvergenceChecker(span)
#epoch=0
while (not self.converged) and (self.epoch < num_epochs):
epoch_loss = tf.keras.metrics.Mean()
#epoch=epoch+1
#epoch=self.epoch
chol=(self.epoch % kernel_hp_update_freq==0)
trl=0.0
nsample=len(list_Dtrain)
for ksample in range(0,nsample):
list_tro[ksample].model.Z=list_D__[ksample]["Z"]
Dtrain_ksample = list_Dtrain[ksample]
for D in Dtrain_ksample: #iterate through each of the batches
epoch_loss.update_state(list_tro[ksample].model.train_step( D, list_tro[ksample].optimizer, list_tro[ksample].optimizer_k,
Ntot=list_tro[ksample].model.delta.shape[1], chol=True))
            trl = trl + epoch_loss.result().numpy()
        # average the loadings across samples (start from a zero matrix of the right shape)
        W_updated = list_tro[ksample].model.W - list_tro[ksample].model.W
        for ksample in range(0, nsample):
            W_updated = W_updated + (list_tro[ksample].model.W / nsample)
        self.epoch.assign_add(1)
        i = self.epoch.numpy()
        self.loss["train"][i] = trl

        # check for NaN in any sample's loadings (W)
        for fit_i in list_tro:
            if np.isnan(fit_i.model.W).any():
                print('NaN in sample ' + str(list_tro.index(fit_i) + 1))

        if not np.isfinite(trl):
            print("training loss calculated at the point of divergence: ")
            print(trl)
            raise NumericalDivergenceError
        # if not np.isfinite(trl) or trl > self.loss["train"][1]:
        #     raise NumericalDivergenceError
        if i % status_freq == 0 or i == num_epochs:
            if Dval:
                val_loss = self.model.validation_step(Dval, S=S, chol=False).numpy()
                self.loss["val"][i] = val_loss
            if i > span and check_convergence:  # check for convergence
                rel_chg = cc.relative_change(self.loss["train"], idx=i)
                print("rel_chg")
                print(rel_chg)
                msg2 = ", chg: {:.2e}".format(-rel_chg)
                if abs(rel_chg) < tol:
                    cvg += 1
                else:
                    cvg = 0
                if test_cvdNorm:
                    rel_chg_normalized = cc.relative_chg_normalized(self.loss["train"], idx_current=i)
                    print("rel_chg_normalized")
                    print(rel_chg_normalized)
                    # a positive rel_chg_normalized indicates the loss has increased over the recent window
                    if -(rel_chg_normalized) < tol_norm:
                        cvg_normalized += 1
                if cvg >= 2 or cvg_normalized >= 2:  # convergence (plain or normalized) detected twice
                    self.converged = True
                    pickle_freq = i  # ensures the final pickling will happen
                    self.loss = truncate_history(self.loss, i)
            if verbose:
                if Dval:
                    print(msg.format(i, trl, val_loss) + msg2)
                else:
                    print(msg.format(i, trl) + msg2)
        if i % ckpt_freq == 0:
            ptic, wtic = self.checkpoint(ckpt_mgr, process_time() - ptic, time() - wtic)
        if self.pickle_path and i % pickle_freq == 0:
            ptic, wtic = self.pickle(process_time() - ptic, time() - wtic)
        # except tf.errors.InvalidArgumentError as err:  # cholesky failed
        #     j = i.numpy()  # save the current epoch value for printing
        #     ptic, wtic = self.restore()
        #     # self.ckpt.restore(self.manager.latest_checkpoint)  # resets i to last checkpt
        #     if ng < 1.0: ng *= 10.0
        #     else: raise err  # nugget has gotten too big so just give up
        #     try: self.model.set_nugget(ng)  # spatial or integrated model
        #     except AttributeError: raise err  # nonspatial model
        #     if verbose:
        #         print("Epoch: {:04d}, numerical error, reverting to epoch {:04d}, \
        #             increase nugget to {:.3E}".format(j, i.numpy(), ng))
        #     self.loss = truncate_history(self.loss, i)
        #     continue
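
For context, a minimal sketch of how this training routine might be invoked; it is illustrative only and not part of the committed file. It assumes `tro` is the trainer object defined in this module and that `list_tro`, `list_Dtrain`, and `list_D` were prepared by the usual mNSF multi-sample setup; the variable names and checkpoint directory are hypothetical, and the leading underscore suggests the method is normally reached through a higher-level wrapper rather than called directly.

import tensorflow as tf

# Hypothetical setup: `tro`, `list_tro`, `list_Dtrain`, `list_D` are assumed to
# come from the mNSF multi-sample workflow; only the training call is shown here.
ckpt = tf.train.Checkpoint(epoch=tro.epoch)
ckpt_mgr = tf.train.CheckpointManager(ckpt, directory="checkpoints", max_to_keep=2)  # docstring asks for >= 2
tro._train_model_fixed_lr(list_tro, list_Dtrain, list_D, ckpt_mgr,
                          num_epochs=500,   # upper bound on passes through the data
                          status_freq=10,   # how often to check convergence and print
                          ckpt_freq=50,     # how often to checkpoint to disk
                          tol=1e-4)         # relative-change convergence tolerance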


def find_checkpoint(self, ckpt_freq, back=1, epoch0=0):
"""
