diff --git a/tests/test.py b/tests/test.py index 92e0a82..4d4368d 100644 --- a/tests/test.py +++ b/tests/test.py @@ -1,4 +1,5 @@ import sys, logging +from pathlib import Path import unittest import warnings @@ -21,7 +22,6 @@ import tempfile from tempfile import TemporaryDirectory import os, errno, shutil, glob -import json from os import listdir from os.path import isfile, join @@ -5131,6 +5131,22 @@ def test_torch_linalg(self): err = np.sum(np.abs(W - W_reconstruct)) self.assertLess(err, 0.05, f"torch svd absolute reconstruction error was {err}") + def test_torch_linalg_eig(self): + # Note that if torch is not available then this will test scipy instead. + W = np.random.random((50,50)) + W = np.matmul(W, W.T) / 2500 + L, V = RMT_Util._eig_full_fast(W) + W_reconstruct = np.matmul(V.astype("float32"), np.matmul(np.diag(L), np.linalg.inv(V.astype("float32")))) + err = np.sum(np.abs(W - W_reconstruct)) + self.assertLess(err, 0.005, f"torch eig absolute reconstruction error was {err}") + + + def test_torch_linalg_svd(self): + W = np.random.random((50,100)) + U, S, Vh = RMT_Util._svd_full_fast(W) + W_reconstruct = np.matmul(U.astype("float32"), np.matmul(np.diag(S), Vh[:50,:].astype("float32"))) + err = np.sum(np.abs(W - W_reconstruct)) + self.assertLess(err, 0.75, f"torch svd absolute reconstruction error was {err}") S_vals_only = RMT_Util._svd_vals_accurate(W) err = np.sum(np.abs(S - S_vals_only)) self.assertLess(err, 0.0005, msg=f"torch svd and svd_vals differed by {err}") @@ -5252,11 +5268,108 @@ def setUp(self): self.model = models.resnet18(weights='ResNet18_Weights.IMAGENET1K_V1') self.watcher = ww.WeightWatcher(model=self.model, log_level=logging.WARNING) - def testPlots(self): - """ Simply tests that the plot functions will not generate an exception. - Does not guarantee correctness, yet. - """ - self.watcher.analyze(layers=[67], plot=True, randomize=True) + self.plotDir = Path("./testPlots") + + def tearDown_plots(self): + if not self.plotDir.exists(): return + + for f in self.plotDir.iterdir(): + f.unlink() + self.plotDir.rmdir() + + + def check_expected_plots(self, plot_figs): + if len(plot_figs) == 0: + self.assertFalse(self.plotDir.exists()) + return + + self.assertTrue(self.plotDir.exists()) + + figs = list(self.plotDir.iterdir()) + self.assertEqual(len(figs), len(plot_figs), f"plot={plot_figs} produced {len(figs)} images") + self.tearDown_plots() + + + self.plotDir = Path("./testPlots") + + def tearDown_plots(self): + if not self.plotDir.exists(): return + + for f in self.plotDir.iterdir(): + f.unlink() + self.plotDir.rmdir() + + + def check_expected_plots(self, plot_figs): + if len(plot_figs) == 0: + self.assertFalse(self.plotDir.exists()) + return + + self.assertTrue(self.plotDir.exists()) + + figs = list(self.plotDir.iterdir()) + self.assertEqual(len(figs), len(plot_figs), f"plot={plot_figs} produced {len(figs)} images") + self.tearDown_plots() + + + def testPlots(self): + """ Simply tests that the plot functions will not generate an exception. + Does not guarantee correctness, yet. + """ + self.tearDown_plots() # Sometimes tearDown_plots() doesn't get called when a test fails previously. + + self.watcher.analyze(layers=[67], plot=True, randomize=True, savefig=str(self.plotDir)) + + self.assertTrue(self.plotDir.exists(), f"Savefig dir {self.plotDir} should exist after analyze() with plot=True") + + expected_plots = WW_FIT_PL_PLOTS + WW_RANDESD_PLOTS + [WW_PLOT_MPDENSITY] + self.check_expected_plots(expected_plots) + + + def testPlotSingletons(self): + self.tearDown_plots() # Sometimes tearDown_plots() doesn't get called when a test fails previously. + + self.watcher.analyze(layers=[67], plot=[WW_PLOT_DETX], detX=True, savefig=str(self.plotDir)) + self.check_expected_plots([WW_PLOT_DETX]) + + # MPFIT needs Q=1 \/ + self.watcher.analyze(layers=[13], plot=[WW_PLOT_MPFIT], mp_fit=True, savefig=str(self.plotDir)) + self.check_expected_plots([WW_PLOT_MPFIT]) + + self.watcher.analyze(layers=[67], plot=[WW_PLOT_MPFIT], mp_fit=True, savefig=str(self.plotDir)) + self.check_expected_plots([]) + + self.watcher.analyze(layers=[67], plot=[WW_PLOT_MPDENSITY], mp_fit=True, savefig=str(self.plotDir)) + self.check_expected_plots([WW_PLOT_MPDENSITY]) + + for plot_fig in WW_FIT_PL_PLOTS: + self.watcher.analyze(layers=[67], plot=[plot_fig], savefig=str(self.plotDir)) + self.check_expected_plots([plot_fig]) + + for plot_fig in WW_RANDESD_PLOTS: + self.watcher.analyze(layers=[67], plot=[plot_fig], randomize=True, savefig=str(self.plotDir)) + self.check_expected_plots([plot_fig]) + + # Commented for future in case support for this is re-enabled. + # for plot_fig in WW_DELTAES_PLOTS: + # self.watcher.analyze(layers=[67], plot=[plot_fig], deltas=True, savefig=str(self.plotDir)) + # self.check_expected_plots([plot_fig]) + + + def testPlotCombos(self): + self.tearDown_plots() # Sometimes tearDown_plots() doesn't get called when a test fails previously. + + self.watcher.analyze(layers=[67], plot=WW_FIT_PL_PLOTS, savefig=str(self.plotDir)) + self.check_expected_plots(WW_FIT_PL_PLOTS) + + self.watcher.analyze(layers=[67], plot=WW_RANDESD_PLOTS, randomize=True, savefig=str(self.plotDir)) + self.check_expected_plots(WW_RANDESD_PLOTS) + + # Commented for future in case support for this is re-enabled. + # self.watcher.analyze(layers=[67], plot=WW_DELTAES_PLOTS, deltas=True, savefig=str(self.plotDir)) + # self.check_expected_plots(WW_DELTAES_PLOTS) + + class Test_Pandas(Test_Base): def setUp(self): @@ -5317,7 +5430,6 @@ def test_column_names_analyze_detX(self): 'xmax', 'xmin'] - details = self.watcher.analyze(layers=[67], detX=True) self.assertTrue(isinstance(details, pd.DataFrame), "details is a pandas DataFrame") self.assertEqual(len(expected_columns), len(details.columns)) diff --git a/weightwatcher/RMT_Util.py b/weightwatcher/RMT_Util.py index e0facf3..70b2d89 100644 --- a/weightwatcher/RMT_Util.py +++ b/weightwatcher/RMT_Util.py @@ -452,6 +452,9 @@ def plot_density_and_fit(eigenvalues=None, model=None, layer_name="", layer_id=0 # if eigenvalues is None: # eigenvalues = get_eigenvalues(model, weightfile, layer) + + if plot is True: plot = [WW_PLOT_MPDENSITY] + if plot is False: plot = [] if Q == 1: to_fit = np.sqrt(eigenvalues) @@ -463,7 +466,7 @@ def plot_density_and_fit(eigenvalues=None, model=None, layer_name="", layer_id=0 label = r'$\rho_{emp}(\lambda)$' title = " W{} ESD, MP Sigma={:0.3}f" - if plot: + if WW_PLOT_MPDENSITY in plot: plt.hist(to_fit, bins=100, alpha=alpha, color=color, density=True, label=label); plt.legend() @@ -491,7 +494,7 @@ def plot_density_and_fit(eigenvalues=None, model=None, layer_name="", layer_id=0 else: x, mp = marchenko_pastur_pdf(x_min, x_max, Q, sigma) - if plot: + if WW_PLOT_MPDENSITY in plot: plt.title(title.format(layer_name, sigma)) plt.plot(x, mp, linewidth=1, color='r', label="MP fit") diff --git a/weightwatcher/constants.py b/weightwatcher/constants.py index ebfe1b1..8233e84 100755 --- a/weightwatcher/constants.py +++ b/weightwatcher/constants.py @@ -154,6 +154,48 @@ PLOT = 'plot' STACKED = 'stacked' +# constants used to indicate which plots should be generated +WW_PLOT_DETX = 'detX' + +WW_PLOT_VECTOR_METRICS = 'vectors' +WW_PLOT_VECTOR_HIST = 'vector_hist' + +WW_PLOT_LOG_DELTAES = 'log_deltaEs' +WW_PLOT_DELTAES_LEVELS = 'deltaEs_levels' + +WW_PLOT_MPFIT = 'mpfit' +WW_PLOT_MPDENSITY = 'mpdensity' + +WW_PLOT_LOGLOG_ESD = 'loglog_esd' +WW_PLOT_LINLIN_ESD = 'linlin_esd' +WW_PLOT_LOGLIN_ESD = 'loglin_esd' +WW_PLOT_DKS = 'DKS' +WW_PLOT_XMIN_ALPHA = 'xmin_alpha' + +WW_PLOT_RANDESD = 'rand_esd' +WW_PLOT_LOG_RANDESD = 'log_rand_esd' + +WW_ALL_PLOTS = [ + WW_PLOT_DETX, + WW_PLOT_VECTOR_METRICS, WW_PLOT_VECTOR_HIST, + WW_PLOT_LOG_DELTAES, WW_PLOT_DELTAES_LEVELS, + WW_PLOT_MPFIT, WW_PLOT_MPDENSITY, + WW_PLOT_LOGLOG_ESD, WW_PLOT_LINLIN_ESD, WW_PLOT_LOGLIN_ESD, WW_PLOT_DKS, WW_PLOT_XMIN_ALPHA, + WW_PLOT_RANDESD, WW_PLOT_LOG_RANDESD, +] + +WW_FIT_PL_PLOTS = [ + WW_PLOT_LOGLOG_ESD, WW_PLOT_LINLIN_ESD, WW_PLOT_LOGLIN_ESD, WW_PLOT_DKS, WW_PLOT_XMIN_ALPHA, +] + +WW_RANDESD_PLOTS = [ + WW_PLOT_RANDESD, WW_PLOT_LOG_RANDESD, +] + +WW_DELTAES_PLOTS = [ + WW_PLOT_LOG_DELTAES, WW_PLOT_DELTAES_LEVELS, +] + CHANNELS_STR = 'channels' FIRST = 'first' LAST = 'last' diff --git a/weightwatcher/weightwatcher.py b/weightwatcher/weightwatcher.py index 34aa497..bad7506 100644 --- a/weightwatcher/weightwatcher.py +++ b/weightwatcher/weightwatcher.py @@ -2988,7 +2988,7 @@ def apply_random_esd(self, ww_layer, params=None): value = np.max(evals)-np.max(rand_evals) ww_layer.add_column("ww_maxdist", value) - if params[PLOT]: + if set(WW_RANDESD_PLOTS) & set(params[PLOT]): self.plot_random_esd(ww_layer, params) return ww_layer @@ -3061,21 +3061,30 @@ def apply_detX(self, ww_layer, params=None): detX_num, detX_idx = detX_constraint(evals, rescale=False) detX_val = evals[detX_idx] + alpha = ww_layer.alpha + + # Calculate the value of xmin - detX_val in normalized units. + detX_delta = None + if ww_layer.xmin: + xmin = ww_layer.xmin * np.max(evals)/ww_layer.xmax + detX_delta = xmin - detX_val + else: + xmin = None - if plot: + + if WW_PLOT_DETX in plot: name = ww_layer.name # fix rescaling to plot xmin layer_id = ww_layer.layer_id # where is the layer_id plot_id = ww_layer.plot_id - plt.title(f"DetX constraint for {name}") + plt.title(f"DetX constraint for {name}\n" + r"$\alpha$=" + f"{alpha:0.02f}") plt.xlabel("log10 eigenvalues (norm scaled)") nz_evals = evals[evals > EVALS_THRESH] plt.hist(np.log10(nz_evals), bins=100) plt.axvline(np.log10(detX_val), color='purple', label=r"detX$=1$") if ww_layer.xmin: - xmin = ww_layer.xmin * np.max(evals)/ww_layer.xmax plt.axvline(np.log10(xmin), color='red', label=r"PL $\lambda_{min}$") plt.legend() @@ -3092,6 +3101,7 @@ def apply_detX(self, ww_layer, params=None): ww_layer.add_column('detX_val', detX_val) ww_layer.add_column('detX_val_unrescaled', detX_val_unrescaled) + ww_layer.add_column('detX_delta', detX_delta) return ww_layer @@ -3127,7 +3137,6 @@ def apply_powerlaw(self, ww_layer, params=None): layer_id = ww_layer.layer_id plot_id = ww_layer.plot_id name = ww_layer.name - title = "{} {}".format(layer_id, name) xmin = None # TODO: allow other xmin settings xmax = params[XMAX]#issue 199np.max(evals) @@ -3143,7 +3152,7 @@ def apply_powerlaw(self, ww_layer, params=None): max_fingers = params[MAX_FINGERS] finger_thresh = params[FINGER_THRESH] - layer_name = "Layer {}".format(plot_id) + layer_name = f"{layer_id} {name}" fit_type = params[FIT] pl_package = params[PL_PACKAGE] @@ -3509,6 +3518,9 @@ def analyze(self, model=None, layers=[], logger.warning("WW2X option deprecated, reverting too POOL=False") ww2x=False pool=False + + if plot is True: plot = WW_ALL_PLOTS + if plot is False: plot = [] params=DEFAULT_PARAMS.copy() @@ -3597,7 +3609,7 @@ def analyze(self, model=None, layers=[], logger.debug("MP Fitting Layer: {} {} ".format(ww_layer.layer_id, ww_layer.name)) self.apply_mp_fit(ww_layer, random=False, params=params) - if params[DELTA_ES] and params[PLOT]: + if params[DELTA_ES] and set(WW_DELTAES_PLOTS) & set(params[PLOT]): logger.debug("Computing and Plotting Deltas: {} {} ".format(ww_layer.layer_id, ww_layer.name)) self.apply_plot_deltaEs(ww_layer, random=False, params=params) @@ -3612,7 +3624,7 @@ def analyze(self, model=None, layers=[], logger.debug("MP Fitting Random layer: {} {} ".format(ww_layer.layer_id, ww_layer.name)) self.apply_mp_fit(ww_layer, random=True, params=params) - if params[DELTA_ES] and params[PLOT]: + if params[DELTA_ES] and set(WW_DELTAES_PLOTS) & set(params[PLOT]): logger.debug("Computing and Plotting Deltas: {} {} ".format(ww_layer.layer_id, ww_layer.name)) self.apply_plot_deltaEs(ww_layer, random=True, params=params) @@ -3878,7 +3890,17 @@ def valid_params(params): params[SAVEDIR] = savefig logger.info("Saving all images to {}".format(savedir)) elif not isinstance(savefig,str) and not isinstance(savefig,bool): - valid = False + valid = False + + + if params[PLOT] is True: params[PLOT] = WW_ALL_PLOTS + if params[PLOT] is False: params[PLOT] = [] + plot = params[PLOT] + + invalid_plots = set(plot) - set(WW_ALL_PLOTS) + if invalid_plots: + logger.warning(f"Invalid plot types detected: {invalid_plots}. See WW_ALL_PLOTS in constants.py") + valid = False fix_fingers = params[FIX_FINGERS] @@ -4086,7 +4108,8 @@ def plot_random_esd(self, ww_layer, params=None): """Plot histogram and log histogram of ESD and randomized ESD""" if params is None: params = DEFAULT_PARAMS.copy() - + + plot = params[PLOT] savefig = params[SAVEFIG] savedir = params[SAVEDIR] @@ -4094,37 +4117,39 @@ def plot_random_esd(self, ww_layer, params=None): plot_id = ww_layer.plot_id evals = ww_layer.evals rand_evals = ww_layer.rand_evals - title = "Layer {} {}: ESD & Random ESD".format(ww_layer.layer_id,ww_layer.name) - + nonzero_evals = evals[evals > 0.0] nonzero_rand_evals = rand_evals[rand_evals > 0.0] max_rand_eval = np.max(rand_evals) - plt.hist((nonzero_evals), bins=100, density=True, color='g', label='original') - plt.hist((nonzero_rand_evals), bins=100, density=True, color='r', label='random', alpha=0.5) - plt.axvline(x=(max_rand_eval), color='orange', label='max rand') - plt.title(title) - plt.xlabel(r" Eigenvalues $(\lambda)$") - plt.legend() - if savefig: - #plt.savefig("ww.layer{}.esd.png".format(layer_id)) - save_fig(plt, "randesd1", plot_id, savedir) - plt.show(); plt.clf() + if WW_PLOT_RANDESD in plot: + plt.hist((nonzero_evals), bins=100, density=True, color='g', label='original') + plt.hist((nonzero_rand_evals), bins=100, density=True, color='r', label='random', alpha=0.5) + plt.axvline(x=(max_rand_eval), color='orange', label='max rand') + title = "Layer {} {}: ESD & Random ESD".format(ww_layer.layer_id,ww_layer.name) + plt.title(title) + plt.xlabel(r" Eigenvalues $(\lambda)$") + plt.legend() + if savefig: + #plt.savefig("ww.layer{}.esd.png".format(layer_id)) + save_fig(plt, "randesd1", plot_id, savedir) + plt.show(); plt.clf() - plt.hist(np.log10(nonzero_evals), bins=100, density=True, color='g', label='original') - plt.hist(np.log10(nonzero_rand_evals), bins=100, density=True, color='r', label='random', alpha=0.5) - plt.axvline(x=np.log10(max_rand_eval), color='orange', label='max rand') - title = "Layer {} {}: Log10 ESD & Random ESD".format(ww_layer.layer_id,ww_layer.name) - plt.title(title) - plt.xlabel(r"Log10 Eigenvalues $(log_{10}\lambda)$") - plt.legend() - if savefig: - #plt.savefig("ww.layer{}.randesd.2.png".format(layer_id)) - save_fig(plt, "randesd2", plot_id, savedir) - plt.show(); plt.clf() + if WW_PLOT_LOG_RANDESD in plot: + plt.hist(np.log10(nonzero_evals), bins=100, density=True, color='g', label='original') + plt.hist(np.log10(nonzero_rand_evals), bins=100, density=True, color='r', label='random', alpha=0.5) + plt.axvline(x=np.log10(max_rand_eval), color='orange', label='max rand') + title = "Layer {} {}: Log10 ESD & Random ESD".format(ww_layer.layer_id,ww_layer.name) + plt.title(title) + plt.xlabel(r"Log10 Eigenvalues $(log_{10}\lambda)$") + plt.legend() + if savefig: + #plt.savefig("ww.layer{}.randesd.2.png".format(layer_id)) + save_fig(plt, "randesd2", plot_id, savedir) + plt.show(); plt.clf() - def fit_powerlaw(self, evals, xmin=None, xmax=None, plot=True, layer_name="", layer_id=0, plot_id=0, \ + def fit_powerlaw(self, evals, xmin=None, xmax=None, plot=WW_ALL_PLOTS, layer_name="", layer_id=0, plot_id=0, \ sample=False, sample_size=None, savedir=DEF_SAVE_DIR, savefig=True, \ thresh=EVALS_THRESH,\ fix_fingers=False, finger_thresh=DEFAULT_FINGER_THRESH, xmin_max=None, max_fingers=DEFAULT_MAX_FINGERS, \ @@ -4305,26 +4330,29 @@ def fit_powerlaw(self, evals, xmin=None, xmax=None, plot=True, layer_name="", la #fit_entropy = line_entropy(fit.Ds) - if plot: - + + if set(WW_FIT_PL_PLOTS) & set(plot): if status==SUCCESS: min_evals_to_plot = (xmin/100) - - fig2 = fit.plot_pdf(color='b', linewidth=0) # invisbile - fig2 = fit.plot_pdf(color='r', linewidth=2) - if fit_type==POWER_LAW: - if pl_package == WW_POWERLAW_PACKAGE: - fit.plot_power_law_pdf(color='r', linestyle='--', ax=fig2) - else: - fit.power_law.plot_pdf(color='r', linestyle='--', ax=fig2) - - else: - fit.truncated_power_law.plot_pdf(color='r', linestyle='--', ax=fig2) else: xmin = -1 min_evals_to_plot = (0.4*np.max(evals)/100) + if WW_PLOT_LOGLOG_ESD in plot: evals_to_plot = evals[evals>min_evals_to_plot] + + if WW_PLOT_LOGLOG_ESD in plot: + fig2 = fit.plot_pdf(color='b', linewidth=0) # invisbile + fig2 = fit.plot_pdf(color='r', linewidth=2) + if fit_type==POWER_LAW: + if pl_package == WW_POWERLAW_PACKAGE: + fit.plot_power_law_pdf(color='r', linestyle='--', ax=fig2) + else: + fit.power_law.plot_pdf(color='r', linestyle='--', ax=fig2) + + else: + fit.truncated_power_law.plot_pdf(color='r', linestyle='--', ax=fig2) + plot_loghist(evals_to_plot, bins=100, xmin=xmin) title = "Log-Log ESD for {}\n".format(layer_name) @@ -4342,8 +4370,9 @@ def fit_powerlaw(self, evals, xmin=None, xmax=None, plot=True, layer_name="", la #plt.savefig("ww.layer{}.esd.png".format(layer_id)) save_fig(plt, "esd", plot_id, savedir) plt.show(); plt.clf() - - + + + if WW_PLOT_LINLIN_ESD in plot: # plot eigenvalue histogram num_bins = 100 # np.min([100,len(evals)]) plt.hist(evals_to_plot, bins=num_bins, density=True) @@ -4356,6 +4385,7 @@ def fit_powerlaw(self, evals, xmin=None, xmax=None, plot=True, layer_name="", la save_fig(plt, "esd2", plot_id, savedir) plt.show(); plt.clf() + if WW_PLOT_LOGLIN_ESD in plot: # plot log eigenvalue histogram nonzero_evals = evals_to_plot[evals_to_plot > 0.0] plt.hist(np.log10(nonzero_evals), bins=100, density=True) @@ -4370,7 +4400,8 @@ def fit_powerlaw(self, evals, xmin=None, xmax=None, plot=True, layer_name="", la plt.show(); plt.clf() # plot xmins vs D - + + if WW_PLOT_DKS in plot: plt.plot(fit.xmins, fit.Ds, label=r'$D_{KS}$') plt.axvline(x=fit.xmin, color='red', label=r'$\lambda_{xmin}$') #plt.plot(fit.xmins, fit.sigmas / fit.alphas, label=r'$\sigma /\alpha$', linestyle='--') @@ -4388,9 +4419,10 @@ def fit_powerlaw(self, evals, xmin=None, xmax=None, plot=True, layer_name="", la if savefig: save_fig(plt, "esd4", plot_id, savedir) #plt.savefig("ww.layer{}.esd4.png".format(layer_id)) - plt.show(); plt.clf() - - + plt.show(); plt.clf() + + + if WW_PLOT_XMIN_ALPHA in plot: plt.plot(fit.xmins, fit.alphas, label=r'$\alpha(xmin)$') plt.axvline(x=fit.xmin, color='red', label=r'$\lambda_{xmin}$') plt.xlabel(r'$x_{min}$') @@ -4401,9 +4433,7 @@ def fit_powerlaw(self, evals, xmin=None, xmax=None, plot=True, layer_name="", la if savefig: save_fig(plt, "esd5", plot_id, savedir) #plt.savefig("ww.layer{}.esd5.png".format(layer_id)) - - - plt.show(); plt.clf() + plt.show(); plt.clf() raw_alpha = -1 if raw_fit is not None: @@ -4574,54 +4604,58 @@ def apply_plot_deltaEs(self, ww_layer, random=False, params=None): plot_id = ww_layer.plot_id name = ww_layer.name or "" layer_name = "{} {}".format(plot_id, name) - + + plot = params[PLOT] savefig = params[SAVEFIG] savedir = params[SAVEDIR] + if not set(WW_DELTAES_PLOTS) & set(plot): return + if random: layer_name = "{} Randomized".format(layer_name) - title = "Layer {} W".format(layer_name) evals = ww_layer.rand_evals color='mediumorchid' - bulk_max = ww_layer.rand_bulk_max else: - title = "Layer {} W".format(layer_name) evals = ww_layer.evals color='blue' - # sequence of deltas deltaEs = np.diff(evals) logDeltaEs = np.log10(deltaEs) x = np.arange(len(deltaEs)) eqn = r"$\log_{10}\Delta(\lambda)$" - plt.scatter(x,logDeltaEs, color=color, marker='.') - - if not random: - idx = np.searchsorted(evals, ww_layer.xmin, side="left") - plt.axvline(x=idx, color='red', label=r'$\lambda_{xmin}$') - else: - idx = np.searchsorted(evals, bulk_max, side="left") - plt.axvline(x=idx, color='red', label=r'$\lambda_{+}$') - - plt.title("Log Delta Es for Layer {}".format(layer_name)) - plt.ylabel("Log Delta Es: "+eqn) - plt.legend() - if savefig: - #plt.savefig("ww.layer{}.deltaEs.png".format(layer_id)) - save_fig(plt, "deltaEs", plot_id, savedir) - plt.show(); plt.clf() - + # sequence of deltas + if WW_PLOT_LOG_DELTAES in plot: + plt.scatter(x,logDeltaEs, color=color, marker='.') + + if random: + bulk_max = ww_layer.rand_bulk_max + idx = np.searchsorted(evals, bulk_max, side="left") + label=r'$\lambda_{+}$' + else: + idx = np.searchsorted(evals, ww_layer.xmin, side="left") + label=r'$\lambda_{xmin}$' + plt.axvline(x=idx, color='red', label=label) + + plt.title("Log Delta Es for Layer {}".format(layer_name)) + plt.ylabel("Log Delta Es: "+eqn) + plt.legend() + if savefig: + #plt.savefig("ww.layer{}.deltaEs.png".format(layer_id)) + save_fig(plt, "deltaEs", plot_id, savedir) + plt.show(); plt.clf() + # level statistics (not mean adjusted because plotting log) - plt.hist(logDeltaEs, bins=100, color=color, density=True) - plt.title("Log Level Statisitcs for Layer {}".format(layer_name)) - plt.ylabel("density") - plt.xlabel(eqn) - plt.legend() - if savefig: - #plt.savefig("ww.layer{}.level-stats.png".format(layer_id)) - save_fig(plt, "level-stats", plot_id, savedir) - plt.show(); plt.clf() + if WW_PLOT_DELTAES_LEVELS in plot: + plt.hist(logDeltaEs, bins=100, color=color, density=True) + plt.title("Log Level Statisitcs for Layer {}".format(layer_name)) + plt.ylabel("density") + plt.xlabel(eqn) + plt.legend() + if savefig: + #plt.savefig("ww.layer{}.level-stats.png".format(layer_id)) + save_fig(plt, "level-stats", plot_id, savedir) + plt.show(); plt.clf() def apply_mp_fit(self, ww_layer, random=True, params=None): """Perform MP fit on random or actual random eigenvalues @@ -4741,7 +4775,7 @@ def mp_fit(self, evals, N, M, rf, layer_name, layer_id, plot_id, plot, savefig, #TODO: set cutoff #Even if the quarter circle applies, still plot the MP_fit - if plot: + if WW_PLOT_MPFIT in plot: plot_density(to_plot, Q=Q, sigma=s1, method="MP", color=color, cutoff=bulk_max_TW)#, scale=Wscale) plt.legend([r'$\rho_{emp}(\lambda)$', 'MP fit']) plt.title("MP ESD, sigma auto-fit for {}".format(layer_name)) @@ -4757,7 +4791,7 @@ def mp_fit(self, evals, N, M, rf, layer_name, layer_id, plot_id, plot, savefig, sigma_mp, x, mp = plot_density_and_fit(model=None, eigenvalues=to_plot, layer_name=layer_name, layer_id=0, Q=Q, num_spikes=0, sigma=s1, verbose = False, plot=plot, color=color, cutoff=bulk_max_TW)#, scale=Wscale) - if plot: + if WW_PLOT_MPDENSITY in plot: title = fit_law +" for layer "+layer_name+"\n Q={:0.3} ".format(Q) title = title + r"$\sigma_{mp}=$"+"{:0.3} ".format(sigma_mp) title = title + r"$\mathcal{R}_{mp}=$"+"{:0.3} ".format(mp_softrank) @@ -4771,6 +4805,7 @@ def mp_fit(self, evals, N, M, rf, layer_name, layer_id, plot_id, plot, savefig, # TODO: replot on log scale, along with randomized evals # we might add this back in later + # REMEMBER TO ADD a WW_PLOT_MPFIT_XXX constant # plt.hist(to_plot, bins=100, density=True) # plt.hist(to_plot, bins=100, density=True, color='red') @@ -4851,14 +4886,17 @@ def SVDSmoothing(self, model=None, percent=0.8, pool=True, layers=[], method=DET """ - self.set_model_(model) + self.set_model_(model) + + if plot is True: plot = WW_ALL_PLOTS + if plot is False: plot = [] params = DEFAULT_PARAMS.copy() params[POOL] = pool params[LAYERS] = layers params[FIT] = fit # only useful for method=LAMBDA_MINa - params[PLOT] = False + params[PLOT] = [] params[START_IDS] = start_ids params[SVD_METHOD] = svd_method @@ -5071,6 +5109,9 @@ def SVDSharpness(self, model=None, pool=True, layers=[], plot=False, start_ids= """ self.set_model_(model) + + if plot is True: plot = WW_ALL_PLOTS + if plot is False: plot = [] params=DEFAULT_PARAMS.copy() params[POOL] = pool @@ -5138,7 +5179,10 @@ def analyze_vectors(self, model=None, layers=[], min_evals=DEFAULT_MIN_EVALS, ma plot=True, savefig=DEF_SAVE_DIR, channels=None): """Seperate method to analyze the eigenvectors of each layer""" - self.set_model_(model) + self.set_model_(model) + + if plot is True: plot = WW_ALL_PLOTS + if plot is False: plot = [] params=DEFAULT_PARAMS.copy() params[SAVEFIG] = savefig @@ -5223,7 +5267,7 @@ def apply_analyze_eigenvectors(self, ww_layer, params=None): sort_ids = np.argsort(all_evals) - if params[PLOT]: + if WW_PLOT_VECTOR_METRICS in params[PLOT]: fig, axs = plt.subplots(4, figsize=(8, 12)) fig.suptitle("Vector Localization Metrics for {}".format(layer_name)) @@ -5301,7 +5345,7 @@ def apply_analyze_eigenvectors(self, ww_layer, params=None): ww_layer.add_column("tail_var_{}".format(name), tail_var) - if params[PLOT]: + if WW_PLOT_VECTOR_HIST in params[PLOT]: fig, axs = plt.subplots(3) fig.suptitle("Vector Bulk/Tail Metrics for {}".format(layer_name)) @@ -5315,8 +5359,7 @@ def apply_analyze_eigenvectors(self, ww_layer, params=None): data = np.array(arr)[sort_ids] bulk_data = data[bulk_ids] tail_data = data[tail_ids] - - + # should never happen if len(bulk_data)>0: ax.hist(bulk_data, bins=100, color='blue', alpha=0.5, label='bulk', density=True) @@ -5328,9 +5371,7 @@ def apply_analyze_eigenvectors(self, ww_layer, params=None): ax.set_ylabel(title) ax.label_outer() ax.legend() - - - + if savefig: save_fig(plt, "vector_histograms", ww_layer.plot_id, savedir) plt.show(); plt.clf()