From 1eae95d9eab9bb587e58967398abc66557015482 Mon Sep 17 00:00:00 2001 From: Chris Hinrichs Date: Thu, 7 Sep 2023 14:56:34 -0700 Subject: [PATCH 01/13] First: change the plot arg from boolean to list everywhere, and add a check in valid_params --- weightwatcher/constants.py | 7 +++++++ weightwatcher/weightwatcher.py | 31 +++++++++++++++++++++++++++---- 2 files changed, 34 insertions(+), 4 deletions(-) diff --git a/weightwatcher/constants.py b/weightwatcher/constants.py index ebfe1b1..2a57551 100755 --- a/weightwatcher/constants.py +++ b/weightwatcher/constants.py @@ -154,6 +154,13 @@ PLOT = 'plot' STACKED = 'stacked' +# constants used to indicate which plots should be generated +WW_PLOT_DETX = 'detX' + +WW_ALL_PLOTS = [ + WW_PLOT_DETX , +] + CHANNELS_STR = 'channels' FIRST = 'first' LAST = 'last' diff --git a/weightwatcher/weightwatcher.py b/weightwatcher/weightwatcher.py index 34aa497..a27603a 100644 --- a/weightwatcher/weightwatcher.py +++ b/weightwatcher/weightwatcher.py @@ -3509,6 +3509,9 @@ def analyze(self, model=None, layers=[], logger.warning("WW2X option deprecated, reverting too POOL=False") ww2x=False pool=False + + if plot is True: plot = WW_ALL_PLOTS + if plot is False: plot = [] params=DEFAULT_PARAMS.copy() @@ -3878,7 +3881,18 @@ def valid_params(params): params[SAVEDIR] = savefig logger.info("Saving all images to {}".format(savedir)) elif not isinstance(savefig,str) and not isinstance(savefig,bool): - valid = False + valid = False + + + plot = params.get(PLOT) + if (plot is True) or (plot is False): + logger.warning(f"plot param has not been converted from boolean to list") + valid = False + + invalid_plots = set(plot) - set(WW_ALL_PLOTS) + if invalid_plots: + logger.warning(f"Invalid plot types detected: {invalid_plots}. 
See WW_ALL_PLOTS in constants.py") + valid = False fix_fingers = params[FIX_FINGERS] @@ -4851,14 +4865,17 @@ def SVDSmoothing(self, model=None, percent=0.8, pool=True, layers=[], method=DET """ - self.set_model_(model) + self.set_model_(model) + + if plot is True: plot = WW_ALL_PLOTS + if plot is False: plot = [] params = DEFAULT_PARAMS.copy() params[POOL] = pool params[LAYERS] = layers params[FIT] = fit # only useful for method=LAMBDA_MINa - params[PLOT] = False + params[PLOT] = [] params[START_IDS] = start_ids params[SVD_METHOD] = svd_method @@ -5071,6 +5088,9 @@ def SVDSharpness(self, model=None, pool=True, layers=[], plot=False, start_ids= """ self.set_model_(model) + + if plot is True: plot = WW_ALL_PLOTS + if plot is False: plot = [] params=DEFAULT_PARAMS.copy() params[POOL] = pool @@ -5138,7 +5158,10 @@ def analyze_vectors(self, model=None, layers=[], min_evals=DEFAULT_MIN_EVALS, ma plot=True, savefig=DEF_SAVE_DIR, channels=None): """Seperate method to analyze the eigenvectors of each layer""" - self.set_model_(model) + self.set_model_(model) + + if plot is True: plot = WW_ALL_PLOTS + if plot is False: plot = [] params=DEFAULT_PARAMS.copy() params[SAVEFIG] = savefig From cc32d2a3169202dad842e6e2014d663e7e04c05d Mon Sep 17 00:00:00 2001 From: Chris Hinrichs Date: Thu, 7 Sep 2023 15:20:49 -0700 Subject: [PATCH 02/13] Adapted fit_powerlaw to use singleton plots --- weightwatcher/constants.py | 14 ++++++++++++-- weightwatcher/weightwatcher.py | 28 +++++++++++++++------------- 2 files changed, 27 insertions(+), 15 deletions(-) diff --git a/weightwatcher/constants.py b/weightwatcher/constants.py index 2a57551..ddc4606 100755 --- a/weightwatcher/constants.py +++ b/weightwatcher/constants.py @@ -155,10 +155,20 @@ STACKED = 'stacked' # constants used to indicate which plots should be generated -WW_PLOT_DETX = 'detX' +WW_PLOT_DETX = 'detX' +WW_PLOT_LOGLOG_ESD = 'loglog_esd' +WW_PLOT_LINLIN_ESD = 'linlin_esd' +WW_PLOT_LOGLIN_ESD = 'loglin_esd' +WW_PLOT_DKS = 'DKS' 
+WW_PLOT_XMIN_ALPHA = 'xmin_alpha' WW_ALL_PLOTS = [ - WW_PLOT_DETX , + WW_PLOT_DETX, + WW_PLOT_LOGLOG_ESD, WW_PLOT_LINLIN_ESD, WW_PLOT_LOGLIN_ESD, WW_PLOT_DKS, WW_PLOT_XMIN_ALPHA, +] + +WW_FIT_PL_PLOTS = [ + WW_PLOT_LOGLOG_ESD, WW_PLOT_LINLIN_ESD, WW_PLOT_LOGLIN_ESD, WW_PLOT_DKS, WW_PLOT_XMIN_ALPHA, ] CHANNELS_STR = 'channels' diff --git a/weightwatcher/weightwatcher.py b/weightwatcher/weightwatcher.py index a27603a..4b4cbc2 100644 --- a/weightwatcher/weightwatcher.py +++ b/weightwatcher/weightwatcher.py @@ -3062,7 +3062,7 @@ def apply_detX(self, ww_layer, params=None): detX_val = evals[detX_idx] - if plot: + if WW_PLOT_DETX in plot: name = ww_layer.name # fix rescaling to plot xmin @@ -4138,7 +4138,7 @@ def plot_random_esd(self, ww_layer, params=None): plt.show(); plt.clf() - def fit_powerlaw(self, evals, xmin=None, xmax=None, plot=True, layer_name="", layer_id=0, plot_id=0, \ + def fit_powerlaw(self, evals, xmin=None, xmax=None, plot=WW_ALL_PLOTS, layer_name="", layer_id=0, plot_id=0, \ sample=False, sample_size=None, savedir=DEF_SAVE_DIR, savefig=True, \ thresh=EVALS_THRESH,\ fix_fingers=False, finger_thresh=DEFAULT_FINGER_THRESH, xmin_max=None, max_fingers=DEFAULT_MAX_FINGERS, \ @@ -4319,8 +4319,7 @@ def fit_powerlaw(self, evals, xmin=None, xmax=None, plot=True, layer_name="", la #fit_entropy = line_entropy(fit.Ds) - if plot: - + if set(WW_FIT_PL_PLOTS) & set(plot): if status==SUCCESS: min_evals_to_plot = (xmin/100) @@ -4338,6 +4337,7 @@ def fit_powerlaw(self, evals, xmin=None, xmax=None, plot=True, layer_name="", la xmin = -1 min_evals_to_plot = (0.4*np.max(evals)/100) + if WW_PLOT_LOGLOG_ESD in plot: evals_to_plot = evals[evals>min_evals_to_plot] plot_loghist(evals_to_plot, bins=100, xmin=xmin) title = "Log-Log ESD for {}\n".format(layer_name) @@ -4356,8 +4356,9 @@ def fit_powerlaw(self, evals, xmin=None, xmax=None, plot=True, layer_name="", la #plt.savefig("ww.layer{}.esd.png".format(layer_id)) save_fig(plt, "esd", plot_id, savedir) plt.show(); plt.clf() - - 
+ + + if WW_PLOT_LINLIN_ESD in plot: # plot eigenvalue histogram num_bins = 100 # np.min([100,len(evals)]) plt.hist(evals_to_plot, bins=num_bins, density=True) @@ -4370,6 +4371,7 @@ def fit_powerlaw(self, evals, xmin=None, xmax=None, plot=True, layer_name="", la save_fig(plt, "esd2", plot_id, savedir) plt.show(); plt.clf() + if WW_PLOT_LOGLIN_ESD in plot: # plot log eigenvalue histogram nonzero_evals = evals_to_plot[evals_to_plot > 0.0] plt.hist(np.log10(nonzero_evals), bins=100, density=True) @@ -4384,7 +4386,8 @@ def fit_powerlaw(self, evals, xmin=None, xmax=None, plot=True, layer_name="", la plt.show(); plt.clf() # plot xmins vs D - + + if WW_PLOT_DKS in plot: plt.plot(fit.xmins, fit.Ds, label=r'$D_{KS}$') plt.axvline(x=fit.xmin, color='red', label=r'$\lambda_{xmin}$') #plt.plot(fit.xmins, fit.sigmas / fit.alphas, label=r'$\sigma /\alpha$', linestyle='--') @@ -4402,9 +4405,10 @@ def fit_powerlaw(self, evals, xmin=None, xmax=None, plot=True, layer_name="", la if savefig: save_fig(plt, "esd4", plot_id, savedir) #plt.savefig("ww.layer{}.esd4.png".format(layer_id)) - plt.show(); plt.clf() - - + plt.show(); plt.clf() + + + if WW_PLOT_XMIN_ALPHA in plot: plt.plot(fit.xmins, fit.alphas, label=r'$\alpha(xmin)$') plt.axvline(x=fit.xmin, color='red', label=r'$\lambda_{xmin}$') plt.xlabel(r'$x_{min}$') @@ -4415,9 +4419,7 @@ def fit_powerlaw(self, evals, xmin=None, xmax=None, plot=True, layer_name="", la if savefig: save_fig(plt, "esd5", plot_id, savedir) #plt.savefig("ww.layer{}.esd5.png".format(layer_id)) - - - plt.show(); plt.clf() + plt.show(); plt.clf() raw_alpha = -1 if raw_fit is not None: From 773e28479071f3fcb1b3a0b247ad639a6153da73 Mon Sep 17 00:00:00 2001 From: Chris Hinrichs Date: Thu, 7 Sep 2023 15:35:23 -0700 Subject: [PATCH 03/13] Adapted mp_fit to use singleton plots --- weightwatcher/RMT_Util.py | 7 +++++-- weightwatcher/constants.py | 4 ++++ weightwatcher/weightwatcher.py | 4 ++-- 3 files changed, 11 insertions(+), 4 deletions(-) diff --git 
a/weightwatcher/RMT_Util.py b/weightwatcher/RMT_Util.py index e0facf3..70b2d89 100644 --- a/weightwatcher/RMT_Util.py +++ b/weightwatcher/RMT_Util.py @@ -452,6 +452,9 @@ def plot_density_and_fit(eigenvalues=None, model=None, layer_name="", layer_id=0 # if eigenvalues is None: # eigenvalues = get_eigenvalues(model, weightfile, layer) + + if plot is True: plot = [WW_PLOT_MPDENSITY] + if plot is False: plot = [] if Q == 1: to_fit = np.sqrt(eigenvalues) @@ -463,7 +466,7 @@ def plot_density_and_fit(eigenvalues=None, model=None, layer_name="", layer_id=0 label = r'$\rho_{emp}(\lambda)$' title = " W{} ESD, MP Sigma={:0.3}f" - if plot: + if WW_PLOT_MPDENSITY in plot: plt.hist(to_fit, bins=100, alpha=alpha, color=color, density=True, label=label); plt.legend() @@ -491,7 +494,7 @@ def plot_density_and_fit(eigenvalues=None, model=None, layer_name="", layer_id=0 else: x, mp = marchenko_pastur_pdf(x_min, x_max, Q, sigma) - if plot: + if WW_PLOT_MPDENSITY in plot: plt.title(title.format(layer_name, sigma)) plt.plot(x, mp, linewidth=1, color='r', label="MP fit") diff --git a/weightwatcher/constants.py b/weightwatcher/constants.py index ddc4606..471336d 100755 --- a/weightwatcher/constants.py +++ b/weightwatcher/constants.py @@ -156,6 +156,9 @@ # constants used to indicate which plots should be generated WW_PLOT_DETX = 'detX' +WW_PLOT_MPFIT = 'mpfit' +WW_PLOT_MPFIT2 = 'mpfit2' +WW_PLOT_MPDENSITY = 'mpdensity' WW_PLOT_LOGLOG_ESD = 'loglog_esd' WW_PLOT_LINLIN_ESD = 'linlin_esd' WW_PLOT_LOGLIN_ESD = 'loglin_esd' @@ -164,6 +167,7 @@ WW_ALL_PLOTS = [ WW_PLOT_DETX, + WW_PLOT_MPFIT, WW_PLOT_MPFIT2, WW_PLOT_MPDENSITY, WW_PLOT_LOGLOG_ESD, WW_PLOT_LINLIN_ESD, WW_PLOT_LOGLIN_ESD, WW_PLOT_DKS, WW_PLOT_XMIN_ALPHA, ] diff --git a/weightwatcher/weightwatcher.py b/weightwatcher/weightwatcher.py index 4b4cbc2..c69ab83 100644 --- a/weightwatcher/weightwatcher.py +++ b/weightwatcher/weightwatcher.py @@ -4757,7 +4757,7 @@ def mp_fit(self, evals, N, M, rf, layer_name, layer_id, plot_id, plot, savefig, 
#TODO: set cutoff #Even if the quarter circle applies, still plot the MP_fit - if plot: + if WW_PLOT_MPFIT in plot: plot_density(to_plot, Q=Q, sigma=s1, method="MP", color=color, cutoff=bulk_max_TW)#, scale=Wscale) plt.legend([r'$\rho_{emp}(\lambda)$', 'MP fit']) plt.title("MP ESD, sigma auto-fit for {}".format(layer_name)) @@ -4773,7 +4773,7 @@ def mp_fit(self, evals, N, M, rf, layer_name, layer_id, plot_id, plot, savefig, sigma_mp, x, mp = plot_density_and_fit(model=None, eigenvalues=to_plot, layer_name=layer_name, layer_id=0, Q=Q, num_spikes=0, sigma=s1, verbose = False, plot=plot, color=color, cutoff=bulk_max_TW)#, scale=Wscale) - if plot: + if WW_PLOT_MPFIT2 in plot: title = fit_law +" for layer "+layer_name+"\n Q={:0.3} ".format(Q) title = title + r"$\sigma_{mp}=$"+"{:0.3} ".format(sigma_mp) title = title + r"$\mathcal{R}_{mp}=$"+"{:0.3} ".format(mp_softrank) From 1d422f69ebe5a3601c57d44ab07fe7f0641d7b8d Mon Sep 17 00:00:00 2001 From: Chris Hinrichs Date: Thu, 7 Sep 2023 16:07:04 -0700 Subject: [PATCH 04/13] Adapted apply_plot_randesd to use singleton plots --- weightwatcher/constants.py | 11 +++++++ weightwatcher/weightwatcher.py | 53 ++++++++++++++++++---------------- 2 files changed, 39 insertions(+), 25 deletions(-) diff --git a/weightwatcher/constants.py b/weightwatcher/constants.py index 471336d..6dad823 100755 --- a/weightwatcher/constants.py +++ b/weightwatcher/constants.py @@ -156,25 +156,36 @@ # constants used to indicate which plots should be generated WW_PLOT_DETX = 'detX' +WW_PLOT_DELTAES = 'deltaEs' + WW_PLOT_MPFIT = 'mpfit' WW_PLOT_MPFIT2 = 'mpfit2' WW_PLOT_MPDENSITY = 'mpdensity' + WW_PLOT_LOGLOG_ESD = 'loglog_esd' WW_PLOT_LINLIN_ESD = 'linlin_esd' WW_PLOT_LOGLIN_ESD = 'loglin_esd' WW_PLOT_DKS = 'DKS' WW_PLOT_XMIN_ALPHA = 'xmin_alpha' +WW_PLOT_RANDESD = 'rand_esd' +WW_PLOT_LOG_RANDESD = 'log_rand_esd' + WW_ALL_PLOTS = [ WW_PLOT_DETX, WW_PLOT_MPFIT, WW_PLOT_MPFIT2, WW_PLOT_MPDENSITY, WW_PLOT_LOGLOG_ESD, WW_PLOT_LINLIN_ESD, WW_PLOT_LOGLIN_ESD, 
WW_PLOT_DKS, WW_PLOT_XMIN_ALPHA, + WW_PLOT_RANDESD, WW_PLOT_LOG_RANDESD, ] WW_FIT_PL_PLOTS = [ WW_PLOT_LOGLOG_ESD, WW_PLOT_LINLIN_ESD, WW_PLOT_LOGLIN_ESD, WW_PLOT_DKS, WW_PLOT_XMIN_ALPHA, ] +WW_RANDESD_PLOTS = [ + WW_PLOT_RANDESD, WW_PLOT_LOG_RANDESD, +] + CHANNELS_STR = 'channels' FIRST = 'first' LAST = 'last' diff --git a/weightwatcher/weightwatcher.py b/weightwatcher/weightwatcher.py index c69ab83..553f5f9 100644 --- a/weightwatcher/weightwatcher.py +++ b/weightwatcher/weightwatcher.py @@ -2988,7 +2988,7 @@ def apply_random_esd(self, ww_layer, params=None): value = np.max(evals)-np.max(rand_evals) ww_layer.add_column("ww_maxdist", value) - if params[PLOT]: + if set(WW_RANDESD_PLOTS) & set(params[PLOT]): self.plot_random_esd(ww_layer, params) return ww_layer @@ -4100,7 +4100,8 @@ def plot_random_esd(self, ww_layer, params=None): """Plot histogram and log histogram of ESD and randomized ESD""" if params is None: params = DEFAULT_PARAMS.copy() - + + plot = params[PLOT] savefig = params[SAVEFIG] savedir = params[SAVEDIR] @@ -4108,34 +4109,36 @@ def plot_random_esd(self, ww_layer, params=None): plot_id = ww_layer.plot_id evals = ww_layer.evals rand_evals = ww_layer.rand_evals - title = "Layer {} {}: ESD & Random ESD".format(ww_layer.layer_id,ww_layer.name) - + nonzero_evals = evals[evals > 0.0] nonzero_rand_evals = rand_evals[rand_evals > 0.0] max_rand_eval = np.max(rand_evals) - plt.hist((nonzero_evals), bins=100, density=True, color='g', label='original') - plt.hist((nonzero_rand_evals), bins=100, density=True, color='r', label='random', alpha=0.5) - plt.axvline(x=(max_rand_eval), color='orange', label='max rand') - plt.title(title) - plt.xlabel(r" Eigenvalues $(\lambda)$") - plt.legend() - if savefig: - #plt.savefig("ww.layer{}.esd.png".format(layer_id)) - save_fig(plt, "randesd1", plot_id, savedir) - plt.show(); plt.clf() + if WW_PLOT_RANDESD in plot: + plt.hist((nonzero_evals), bins=100, density=True, color='g', label='original') + plt.hist((nonzero_rand_evals), 
bins=100, density=True, color='r', label='random', alpha=0.5) + plt.axvline(x=(max_rand_eval), color='orange', label='max rand') + title = "Layer {} {}: ESD & Random ESD".format(ww_layer.layer_id,ww_layer.name) + plt.title(title) + plt.xlabel(r" Eigenvalues $(\lambda)$") + plt.legend() + if savefig: + #plt.savefig("ww.layer{}.esd.png".format(layer_id)) + save_fig(plt, "randesd1", plot_id, savedir) + plt.show(); plt.clf() - plt.hist(np.log10(nonzero_evals), bins=100, density=True, color='g', label='original') - plt.hist(np.log10(nonzero_rand_evals), bins=100, density=True, color='r', label='random', alpha=0.5) - plt.axvline(x=np.log10(max_rand_eval), color='orange', label='max rand') - title = "Layer {} {}: Log10 ESD & Random ESD".format(ww_layer.layer_id,ww_layer.name) - plt.title(title) - plt.xlabel(r"Log10 Eigenvalues $(log_{10}\lambda)$") - plt.legend() - if savefig: - #plt.savefig("ww.layer{}.randesd.2.png".format(layer_id)) - save_fig(plt, "randesd2", plot_id, savedir) - plt.show(); plt.clf() + if WW_PLOT_LOG_RANDESD in plot: + plt.hist(np.log10(nonzero_evals), bins=100, density=True, color='g', label='original') + plt.hist(np.log10(nonzero_rand_evals), bins=100, density=True, color='r', label='random', alpha=0.5) + plt.axvline(x=np.log10(max_rand_eval), color='orange', label='max rand') + title = "Layer {} {}: Log10 ESD & Random ESD".format(ww_layer.layer_id,ww_layer.name) + plt.title(title) + plt.xlabel(r"Log10 Eigenvalues $(log_{10}\lambda)$") + plt.legend() + if savefig: + #plt.savefig("ww.layer{}.randesd.2.png".format(layer_id)) + save_fig(plt, "randesd2", plot_id, savedir) + plt.show(); plt.clf() def fit_powerlaw(self, evals, xmin=None, xmax=None, plot=WW_ALL_PLOTS, layer_name="", layer_id=0, plot_id=0, \ From 6ee5bc78d76a13336a1576c31d254bc6eb15b1c5 Mon Sep 17 00:00:00 2001 From: Chris Hinrichs Date: Thu, 7 Sep 2023 15:42:35 -0700 Subject: [PATCH 05/13] Adapted apply_deltaEs to use singleton plots --- weightwatcher/constants.py | 8 +++- 
weightwatcher/weightwatcher.py | 70 ++++++++++++++++++---------------- 2 files changed, 44 insertions(+), 34 deletions(-) diff --git a/weightwatcher/constants.py b/weightwatcher/constants.py index 6dad823..003c41d 100755 --- a/weightwatcher/constants.py +++ b/weightwatcher/constants.py @@ -156,7 +156,8 @@ # constants used to indicate which plots should be generated WW_PLOT_DETX = 'detX' -WW_PLOT_DELTAES = 'deltaEs' +WW_PLOT_LOG_DELTAES = 'log_deltaEs' +WW_PLOT_DELTAES_LEVELS = 'deltaEs_levels' WW_PLOT_MPFIT = 'mpfit' WW_PLOT_MPFIT2 = 'mpfit2' @@ -173,6 +174,7 @@ WW_ALL_PLOTS = [ WW_PLOT_DETX, + WW_PLOT_LOG_DELTAES, WW_PLOT_DELTAES_LEVELS, WW_PLOT_MPFIT, WW_PLOT_MPFIT2, WW_PLOT_MPDENSITY, WW_PLOT_LOGLOG_ESD, WW_PLOT_LINLIN_ESD, WW_PLOT_LOGLIN_ESD, WW_PLOT_DKS, WW_PLOT_XMIN_ALPHA, WW_PLOT_RANDESD, WW_PLOT_LOG_RANDESD, @@ -186,6 +188,10 @@ WW_PLOT_RANDESD, WW_PLOT_LOG_RANDESD, ] +WW_DELTAES_PLOTS = [ + WW_PLOT_LOG_DELTAES, WW_PLOT_DELTAES_LEVELS, +] + CHANNELS_STR = 'channels' FIRST = 'first' LAST = 'last' diff --git a/weightwatcher/weightwatcher.py b/weightwatcher/weightwatcher.py index 553f5f9..05e35ea 100644 --- a/weightwatcher/weightwatcher.py +++ b/weightwatcher/weightwatcher.py @@ -3600,7 +3600,7 @@ def analyze(self, model=None, layers=[], logger.debug("MP Fitting Layer: {} {} ".format(ww_layer.layer_id, ww_layer.name)) self.apply_mp_fit(ww_layer, random=False, params=params) - if params[DELTA_ES] and params[PLOT]: + if params[DELTA_ES] and set(WW_DELTAES_PLOTS) & set(params[PLOT]): logger.debug("Computing and Plotting Deltas: {} {} ".format(ww_layer.layer_id, ww_layer.name)) self.apply_plot_deltaEs(ww_layer, random=False, params=params) @@ -3615,7 +3615,7 @@ def analyze(self, model=None, layers=[], logger.debug("MP Fitting Random layer: {} {} ".format(ww_layer.layer_id, ww_layer.name)) self.apply_mp_fit(ww_layer, random=True, params=params) - if params[DELTA_ES] and params[PLOT]: + if params[DELTA_ES] and set(WW_DELTAES_PLOTS) & set(params[PLOT]): 
logger.debug("Computing and Plotting Deltas: {} {} ".format(ww_layer.layer_id, ww_layer.name)) self.apply_plot_deltaEs(ww_layer, random=True, params=params) @@ -4593,54 +4593,58 @@ def apply_plot_deltaEs(self, ww_layer, random=False, params=None): plot_id = ww_layer.plot_id name = ww_layer.name or "" layer_name = "{} {}".format(plot_id, name) - + + plot = params[PLOT] savefig = params[SAVEFIG] savedir = params[SAVEDIR] + if not set(WW_DELTAES_PLOTS) & set(plot): return + if random: layer_name = "{} Randomized".format(layer_name) - title = "Layer {} W".format(layer_name) evals = ww_layer.rand_evals color='mediumorchid' - bulk_max = ww_layer.rand_bulk_max else: - title = "Layer {} W".format(layer_name) evals = ww_layer.evals color='blue' - # sequence of deltas deltaEs = np.diff(evals) logDeltaEs = np.log10(deltaEs) x = np.arange(len(deltaEs)) eqn = r"$\log_{10}\Delta(\lambda)$" - plt.scatter(x,logDeltaEs, color=color, marker='.') - - if not random: - idx = np.searchsorted(evals, ww_layer.xmin, side="left") - plt.axvline(x=idx, color='red', label=r'$\lambda_{xmin}$') - else: - idx = np.searchsorted(evals, bulk_max, side="left") - plt.axvline(x=idx, color='red', label=r'$\lambda_{+}$') - - plt.title("Log Delta Es for Layer {}".format(layer_name)) - plt.ylabel("Log Delta Es: "+eqn) - plt.legend() - if savefig: - #plt.savefig("ww.layer{}.deltaEs.png".format(layer_id)) - save_fig(plt, "deltaEs", plot_id, savedir) - plt.show(); plt.clf() - + # sequence of deltas + if WW_PLOT_LOG_DELTAES in plot: + plt.scatter(x,logDeltaEs, color=color, marker='.') + + if random: + bulk_max = ww_layer.rand_bulk_max + idx = np.searchsorted(evals, bulk_max, side="left") + label=r'$\lambda_{+}$' + else: + idx = np.searchsorted(evals, ww_layer.xmin, side="left") + label=r'$\lambda_{xmin}$' + plt.axvline(x=idx, color='red', label=label) + + plt.title("Log Delta Es for Layer {}".format(layer_name)) + plt.ylabel("Log Delta Es: "+eqn) + plt.legend() + if savefig: + 
#plt.savefig("ww.layer{}.deltaEs.png".format(layer_id)) + save_fig(plt, "deltaEs", plot_id, savedir) + plt.show(); plt.clf() + # level statistics (not mean adjusted because plotting log) - plt.hist(logDeltaEs, bins=100, color=color, density=True) - plt.title("Log Level Statisitcs for Layer {}".format(layer_name)) - plt.ylabel("density") - plt.xlabel(eqn) - plt.legend() - if savefig: - #plt.savefig("ww.layer{}.level-stats.png".format(layer_id)) - save_fig(plt, "level-stats", plot_id, savedir) - plt.show(); plt.clf() + if WW_PLOT_DELTAES_LEVELS in plot: + plt.hist(logDeltaEs, bins=100, color=color, density=True) + plt.title("Log Level Statisitcs for Layer {}".format(layer_name)) + plt.ylabel("density") + plt.xlabel(eqn) + plt.legend() + if savefig: + #plt.savefig("ww.layer{}.level-stats.png".format(layer_id)) + save_fig(plt, "level-stats", plot_id, savedir) + plt.show(); plt.clf() def apply_mp_fit(self, ww_layer, random=True, params=None): """Perform MP fit on random or actual random eigenvalues From 4e5ed64a043d1955b83203a25f2d478e116fac4e Mon Sep 17 00:00:00 2001 From: Chris Hinrichs Date: Thu, 7 Sep 2023 16:39:11 -0700 Subject: [PATCH 06/13] Added a comment --- weightwatcher/weightwatcher.py | 1 + 1 file changed, 1 insertion(+) diff --git a/weightwatcher/weightwatcher.py b/weightwatcher/weightwatcher.py index 05e35ea..1352202 100644 --- a/weightwatcher/weightwatcher.py +++ b/weightwatcher/weightwatcher.py @@ -4794,6 +4794,7 @@ def mp_fit(self, evals, N, M, rf, layer_name, layer_id, plot_id, plot, savefig, # TODO: replot on log scale, along with randomized evals # we might add this back in later + # REMEMBER TO ADD a WW_PLOT_MPFIT_XXX constant # plt.hist(to_plot, bins=100, density=True) # plt.hist(to_plot, bins=100, density=True, color='red') From 6b71c995113e0717c9c1c1e1ab104908ff57847c Mon Sep 17 00:00:00 2001 From: Chris Hinrichs Date: Thu, 7 Sep 2023 16:41:38 -0700 Subject: [PATCH 07/13] Adapted apply_analyze_eigenvectors to use singleton plots --- 
weightwatcher/constants.py | 29 +++++++++++++++++------------ weightwatcher/weightwatcher.py | 11 ++++------- 2 files changed, 21 insertions(+), 19 deletions(-) diff --git a/weightwatcher/constants.py b/weightwatcher/constants.py index 003c41d..189235d 100755 --- a/weightwatcher/constants.py +++ b/weightwatcher/constants.py @@ -155,25 +155,30 @@ STACKED = 'stacked' # constants used to indicate which plots should be generated -WW_PLOT_DETX = 'detX' -WW_PLOT_LOG_DELTAES = 'log_deltaEs' +WW_PLOT_DETX = 'detX' + +WW_PLOT_VECTOR_METRICS = 'vectors' +WW_PLOT_VECTOR_HIST = 'vector_hist' + +WW_PLOT_LOG_DELTAES = 'log_deltaEs' WW_PLOT_DELTAES_LEVELS = 'deltaEs_levels' -WW_PLOT_MPFIT = 'mpfit' -WW_PLOT_MPFIT2 = 'mpfit2' -WW_PLOT_MPDENSITY = 'mpdensity' +WW_PLOT_MPFIT = 'mpfit' +WW_PLOT_MPFIT2 = 'mpfit2' +WW_PLOT_MPDENSITY = 'mpdensity' -WW_PLOT_LOGLOG_ESD = 'loglog_esd' -WW_PLOT_LINLIN_ESD = 'linlin_esd' -WW_PLOT_LOGLIN_ESD = 'loglin_esd' -WW_PLOT_DKS = 'DKS' -WW_PLOT_XMIN_ALPHA = 'xmin_alpha' +WW_PLOT_LOGLOG_ESD = 'loglog_esd' +WW_PLOT_LINLIN_ESD = 'linlin_esd' +WW_PLOT_LOGLIN_ESD = 'loglin_esd' +WW_PLOT_DKS = 'DKS' +WW_PLOT_XMIN_ALPHA = 'xmin_alpha' -WW_PLOT_RANDESD = 'rand_esd' -WW_PLOT_LOG_RANDESD = 'log_rand_esd' +WW_PLOT_RANDESD = 'rand_esd' +WW_PLOT_LOG_RANDESD = 'log_rand_esd' WW_ALL_PLOTS = [ WW_PLOT_DETX, + WW_PLOT_VECTOR_METRICS, WW_PLOT_VECTOR_HIST, WW_PLOT_LOG_DELTAES, WW_PLOT_DELTAES_LEVELS, WW_PLOT_MPFIT, WW_PLOT_MPFIT2, WW_PLOT_MPDENSITY, WW_PLOT_LOGLOG_ESD, WW_PLOT_LINLIN_ESD, WW_PLOT_LOGLIN_ESD, WW_PLOT_DKS, WW_PLOT_XMIN_ALPHA, diff --git a/weightwatcher/weightwatcher.py b/weightwatcher/weightwatcher.py index 1352202..82d7b22 100644 --- a/weightwatcher/weightwatcher.py +++ b/weightwatcher/weightwatcher.py @@ -5256,7 +5256,7 @@ def apply_analyze_eigenvectors(self, ww_layer, params=None): sort_ids = np.argsort(all_evals) - if params[PLOT]: + if WW_PLOT_VECTOR_METRICS in params[PLOT]: fig, axs = plt.subplots(4, figsize=(8, 12)) fig.suptitle("Vector 
Localization Metrics for {}".format(layer_name)) @@ -5334,7 +5334,7 @@ def apply_analyze_eigenvectors(self, ww_layer, params=None): ww_layer.add_column("tail_var_{}".format(name), tail_var) - if params[PLOT]: + if WW_PLOT_VECTOR_HIST in params[PLOT]: fig, axs = plt.subplots(3) fig.suptitle("Vector Bulk/Tail Metrics for {}".format(layer_name)) @@ -5348,8 +5348,7 @@ def apply_analyze_eigenvectors(self, ww_layer, params=None): data = np.array(arr)[sort_ids] bulk_data = data[bulk_ids] tail_data = data[tail_ids] - - + # should never happen if len(bulk_data)>0: ax.hist(bulk_data, bins=100, color='blue', alpha=0.5, label='bulk', density=True) @@ -5361,9 +5360,7 @@ def apply_analyze_eigenvectors(self, ww_layer, params=None): ax.set_ylabel(title) ax.label_outer() ax.legend() - - - + if savefig: save_fig(plt, "vector_histograms", ww_layer.plot_id, savedir) plt.show(); plt.clf() From 205f20612651e4ae4d2ca49e54accb3cc09195cc Mon Sep 17 00:00:00 2001 From: Chris Hinrichs Date: Thu, 7 Sep 2023 22:02:46 -0700 Subject: [PATCH 08/13] Relaxed check in valid_params so tests will work --- weightwatcher/weightwatcher.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/weightwatcher/weightwatcher.py b/weightwatcher/weightwatcher.py index 82d7b22..5f4cc76 100644 --- a/weightwatcher/weightwatcher.py +++ b/weightwatcher/weightwatcher.py @@ -3884,10 +3884,9 @@ def valid_params(params): valid = False - plot = params.get(PLOT) - if (plot is True) or (plot is False): - logger.warning(f"plot param has not been converted from boolean to list") - valid = False + if params[PLOT] is True: params[PLOT] = WW_ALL_PLOTS + if params[PLOT] is False: params[PLOT] = [] + plot = params[PLOT] invalid_plots = set(plot) - set(WW_ALL_PLOTS) if invalid_plots: From 88c6d3872c7c783b1487a67d227920aea65f58b5 Mon Sep 17 00:00:00 2001 From: Chris Hinrichs Date: Fri, 8 Sep 2023 12:44:22 -0700 Subject: [PATCH 09/13] bugfix --- weightwatcher/weightwatcher.py | 2 +- 1 file changed, 1 insertion(+), 
1 deletion(-) diff --git a/weightwatcher/weightwatcher.py b/weightwatcher/weightwatcher.py index 5f4cc76..4635241 100644 --- a/weightwatcher/weightwatcher.py +++ b/weightwatcher/weightwatcher.py @@ -4338,9 +4338,9 @@ def fit_powerlaw(self, evals, xmin=None, xmax=None, plot=WW_ALL_PLOTS, layer_nam else: xmin = -1 min_evals_to_plot = (0.4*np.max(evals)/100) + evals_to_plot = evals[evals>min_evals_to_plot] if WW_PLOT_LOGLOG_ESD in plot: - evals_to_plot = evals[evals>min_evals_to_plot] plot_loghist(evals_to_plot, bins=100, xmin=xmin) title = "Log-Log ESD for {}\n".format(layer_name) From 49ca1a893952fcc377128e6916c2a251f30748de Mon Sep 17 00:00:00 2001 From: Chris Hinrichs Date: Fri, 8 Sep 2023 13:22:59 -0700 Subject: [PATCH 10/13] Added a new set of tests to ensure plot functionality produces the expected number of plots --- tests/test.py | 90 +++++++++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 84 insertions(+), 6 deletions(-) diff --git a/tests/test.py b/tests/test.py index 92e0a82..658bb88 100644 --- a/tests/test.py +++ b/tests/test.py @@ -1,4 +1,5 @@ import sys, logging +from pathlib import Path import unittest import warnings @@ -21,7 +22,6 @@ import tempfile from tempfile import TemporaryDirectory import os, errno, shutil, glob -import json from os import listdir from os.path import isfile, join @@ -5252,11 +5252,89 @@ def setUp(self): self.model = models.resnet18(weights='ResNet18_Weights.IMAGENET1K_V1') self.watcher = ww.WeightWatcher(model=self.model, log_level=logging.WARNING) - def testPlots(self): - """ Simply tests that the plot functions will not generate an exception. - Does not guarantee correctness, yet. 
- """ - self.watcher.analyze(layers=[67], plot=True, randomize=True) + self.plotDir = Path("./testPlots") + + def tearDown_plots(self): + if not self.plotDir.exists(): return + + for f in self.plotDir.iterdir(): + f.unlink() + self.plotDir.rmdir() + + + def check_expected_plots(self, plot_figs): + if len(plot_figs) == 0: + self.assertFalse(self.plotDir.exists()) + return + + self.assertTrue(self.plotDir.exists()) + + figs = list(self.plotDir.iterdir()) + self.assertEqual(len(figs), len(plot_figs), f"plot={plot_figs} produced {len(figs)} images") + self.tearDown_plots() + + + def testPlots(self): + """ Simply tests that the plot functions will not generate an exception. + Does not guarantee correctness, yet. + """ + self.tearDown_plots() # Sometimes tearDown_plots() doesn't get called when a test fails previously. + + self.watcher.analyze(layers=[67], plot=True, randomize=True, savefig=str(self.plotDir)) + + self.assertTrue(self.plotDir.exists(), f"Savefig dir {self.plotDir} should exist after analyze() with plot=True") + + expected_plots = WW_FIT_PL_PLOTS + WW_RANDESD_PLOTS + [WW_PLOT_MPFIT2, WW_PLOT_MPDENSITY] + self.check_expected_plots(expected_plots) + + + def testPlotSingletons(self): + self.tearDown_plots() # Sometimes tearDown_plots() doesn't get called when a test fails previously. 
+ + self.watcher.analyze(layers=[67], plot=[WW_PLOT_DETX], detX=True, savefig=str(self.plotDir)) + self.check_expected_plots([WW_PLOT_DETX]) + + # MPFIT needs Q=1 \/ + self.watcher.analyze(layers=[13], plot=[WW_PLOT_MPFIT], mp_fit=True, savefig=str(self.plotDir)) + self.check_expected_plots([WW_PLOT_MPFIT]) + + self.watcher.analyze(layers=[67], plot=[WW_PLOT_MPFIT], mp_fit=True, savefig=str(self.plotDir)) + self.check_expected_plots([]) + + self.watcher.analyze(layers=[67], plot=[WW_PLOT_MPFIT2], mp_fit=True, savefig=str(self.plotDir)) + self.check_expected_plots([WW_PLOT_MPFIT2]) + + self.watcher.analyze(layers=[67], plot=[WW_PLOT_MPDENSITY], mp_fit=True, savefig=str(self.plotDir)) + self.check_expected_plots([WW_PLOT_MPDENSITY]) + + for plot_fig in WW_FIT_PL_PLOTS: + self.watcher.analyze(layers=[67], plot=[plot_fig], savefig=str(self.plotDir)) + self.check_expected_plots([plot_fig]) + + for plot_fig in WW_RANDESD_PLOTS: + self.watcher.analyze(layers=[67], plot=[plot_fig], randomize=True, savefig=str(self.plotDir)) + self.check_expected_plots([plot_fig]) + + # Commented for future in case support for this is re-enabled. + # for plot_fig in WW_DELTAES_PLOTS: + # self.watcher.analyze(layers=[67], plot=[plot_fig], deltas=True, savefig=str(self.plotDir)) + # self.check_expected_plots([plot_fig]) + + + def testPlotCombos(self): + self.tearDown_plots() # Sometimes tearDown_plots() doesn't get called when a test fails previously. + + self.watcher.analyze(layers=[67], plot=WW_FIT_PL_PLOTS, savefig=str(self.plotDir)) + self.check_expected_plots(WW_FIT_PL_PLOTS) + + self.watcher.analyze(layers=[67], plot=WW_RANDESD_PLOTS, randomize=True, savefig=str(self.plotDir)) + self.check_expected_plots(WW_RANDESD_PLOTS) + + # Commented for future in case support for this is re-enabled. 
+ # self.watcher.analyze(layers=[67], plot=WW_DELTAES_PLOTS, deltas=True, savefig=str(self.plotDir)) + # self.check_expected_plots(WW_DELTAES_PLOTS) + + class Test_Pandas(Test_Base): def setUp(self): From f479ab887918e6b275131a1e830be229e9eab716 Mon Sep 17 00:00:00 2001 From: Chris Hinrichs Date: Fri, 8 Sep 2023 14:12:06 -0700 Subject: [PATCH 11/13] Fixed problems relating to half-precision in test_torch_linalg; split into two tests --- tests/test.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/tests/test.py b/tests/test.py index 658bb88..6c53df2 100644 --- a/tests/test.py +++ b/tests/test.py @@ -5117,6 +5117,7 @@ def test_torch_availability(self): self.assertTrue(RMT_Util._svd_vals_accurate is RMT_Util._svd_vals_fast) + def test_torch_linalg(self): # Note that if torch is not available then this will test scipy instead. W = np.random.random((50,50)) @@ -5131,6 +5132,21 @@ def test_torch_linalg(self): err = np.sum(np.abs(W - W_reconstruct)) self.assertLess(err, 0.05, f"torch svd absolute reconstruction error was {err}") + def test_torch_linalg_eig(self): + # Note that if torch is not available then this will test scipy instead. 
+ W = np.random.random((50,50)) + W = np.matmul(W, W.T) / 2500 + L, V = RMT_Util._eig_full_fast(W) + W_reconstruct = np.matmul(V.astype("float32"), np.matmul(np.diag(L), np.linalg.inv(V.astype("float32")))) + err = np.sum(np.abs(W - W_reconstruct)) + self.assertLess(err, 0.005, f"torch eig absolute reconstruction error was {err}") + + def test_torch_linalg_svd(self): + W = np.random.random((50,100)) + U, S, Vh = RMT_Util._svd_full_fast(W) + W_reconstruct = np.matmul(U.astype("float32"), np.matmul(np.diag(S), Vh[:50,:].astype("float32"))) + err = np.sum(np.abs(W - W_reconstruct)) + self.assertLess(err, 0.75, f"torch svd absolute reconstruction error was {err}") S_vals_only = RMT_Util._svd_vals_accurate(W) err = np.sum(np.abs(S - S_vals_only)) self.assertLess(err, 0.0005, msg=f"torch svd and svd_vals differed by {err}") From f9776600fe3563e4b56387ba26cf748021aaa1d7 Mon Sep 17 00:00:00 2001 From: Chris Hinrichs Date: Fri, 8 Sep 2023 15:01:25 -0700 Subject: [PATCH 12/13] Bugfix: Mistakenly thought there was an mpfit2 figure --- tests/test.py | 5 +---- weightwatcher/constants.py | 3 +-- weightwatcher/weightwatcher.py | 2 +- 3 files changed, 3 insertions(+), 7 deletions(-) diff --git a/tests/test.py b/tests/test.py index 6c53df2..975a2cf 100644 --- a/tests/test.py +++ b/tests/test.py @@ -5300,7 +5300,7 @@ def testPlots(self): self.assertTrue(self.plotDir.exists(), f"Savefig dir {self.plotDir} should exist after analyze() with plot=True") - expected_plots = WW_FIT_PL_PLOTS + WW_RANDESD_PLOTS + [WW_PLOT_MPFIT2, WW_PLOT_MPDENSITY] + expected_plots = WW_FIT_PL_PLOTS + WW_RANDESD_PLOTS + [WW_PLOT_MPDENSITY] self.check_expected_plots(expected_plots) @@ -5317,9 +5317,6 @@ def testPlotSingletons(self): self.watcher.analyze(layers=[67], plot=[WW_PLOT_MPFIT], mp_fit=True, savefig=str(self.plotDir)) self.check_expected_plots([]) - self.watcher.analyze(layers=[67], plot=[WW_PLOT_MPFIT2], mp_fit=True, savefig=str(self.plotDir)) - self.check_expected_plots([WW_PLOT_MPFIT2]) - 
self.watcher.analyze(layers=[67], plot=[WW_PLOT_MPDENSITY], mp_fit=True, savefig=str(self.plotDir)) self.check_expected_plots([WW_PLOT_MPDENSITY]) diff --git a/weightwatcher/constants.py b/weightwatcher/constants.py index 189235d..8233e84 100755 --- a/weightwatcher/constants.py +++ b/weightwatcher/constants.py @@ -164,7 +164,6 @@ WW_PLOT_DELTAES_LEVELS = 'deltaEs_levels' WW_PLOT_MPFIT = 'mpfit' -WW_PLOT_MPFIT2 = 'mpfit2' WW_PLOT_MPDENSITY = 'mpdensity' WW_PLOT_LOGLOG_ESD = 'loglog_esd' @@ -180,7 +179,7 @@ WW_PLOT_DETX, WW_PLOT_VECTOR_METRICS, WW_PLOT_VECTOR_HIST, WW_PLOT_LOG_DELTAES, WW_PLOT_DELTAES_LEVELS, - WW_PLOT_MPFIT, WW_PLOT_MPFIT2, WW_PLOT_MPDENSITY, + WW_PLOT_MPFIT, WW_PLOT_MPDENSITY, WW_PLOT_LOGLOG_ESD, WW_PLOT_LINLIN_ESD, WW_PLOT_LOGLIN_ESD, WW_PLOT_DKS, WW_PLOT_XMIN_ALPHA, WW_PLOT_RANDESD, WW_PLOT_LOG_RANDESD, ] diff --git a/weightwatcher/weightwatcher.py b/weightwatcher/weightwatcher.py index 4635241..55442d8 100644 --- a/weightwatcher/weightwatcher.py +++ b/weightwatcher/weightwatcher.py @@ -4779,7 +4779,7 @@ def mp_fit(self, evals, N, M, rf, layer_name, layer_id, plot_id, plot, savefig, sigma_mp, x, mp = plot_density_and_fit(model=None, eigenvalues=to_plot, layer_name=layer_name, layer_id=0, Q=Q, num_spikes=0, sigma=s1, verbose = False, plot=plot, color=color, cutoff=bulk_max_TW)#, scale=Wscale) - if WW_PLOT_MPFIT2 in plot: + if WW_PLOT_MPDENSITY in plot: title = fit_law +" for layer "+layer_name+"\n Q={:0.3} ".format(Q) title = title + r"$\sigma_{mp}=$"+"{:0.3} ".format(sigma_mp) title = title + r"$\mathcal{R}_{mp}=$"+"{:0.3} ".format(mp_softrank) From 53d8de98772f43dda1cfbac8e1b7f4018f6ba580 Mon Sep 17 00:00:00 2001 From: Chris Hinrichs Date: Tue, 19 Sep 2023 12:38:26 -0700 Subject: [PATCH 13/13] bugfix - loglog esd was being plotted on plots that did not expect it. 
--- weightwatcher/weightwatcher.py | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/weightwatcher/weightwatcher.py b/weightwatcher/weightwatcher.py index 55442d8..39c056b 100644 --- a/weightwatcher/weightwatcher.py +++ b/weightwatcher/weightwatcher.py @@ -4324,23 +4324,23 @@ def fit_powerlaw(self, evals, xmin=None, xmax=None, plot=WW_ALL_PLOTS, layer_nam if set(WW_FIT_PL_PLOTS) & set(plot): if status==SUCCESS: min_evals_to_plot = (xmin/100) - - fig2 = fit.plot_pdf(color='b', linewidth=0) # invisbile - fig2 = fit.plot_pdf(color='r', linewidth=2) - if fit_type==POWER_LAW: - if pl_package == WW_POWERLAW_PACKAGE: - fit.plot_power_law_pdf(color='r', linestyle='--', ax=fig2) - else: - fit.power_law.plot_pdf(color='r', linestyle='--', ax=fig2) - - else: - fit.truncated_power_law.plot_pdf(color='r', linestyle='--', ax=fig2) else: xmin = -1 min_evals_to_plot = (0.4*np.max(evals)/100) evals_to_plot = evals[evals>min_evals_to_plot] if WW_PLOT_LOGLOG_ESD in plot: + fig2 = fit.plot_pdf(color='b', linewidth=0) # invisible + fig2 = fit.plot_pdf(color='r', linewidth=2) + if fit_type==POWER_LAW: + if pl_package == WW_POWERLAW_PACKAGE: + fit.plot_power_law_pdf(color='r', linestyle='--', ax=fig2) + else: + fit.power_law.plot_pdf(color='r', linestyle='--', ax=fig2) + + else: + fit.truncated_power_law.plot_pdf(color='r', linestyle='--', ax=fig2) + plot_loghist(evals_to_plot, bins=100, xmin=xmin) title = "Log-Log ESD for {}\n".format(layer_name)