diff --git a/nidn/__init__.py b/nidn/__init__.py index 1befd89..4320741 100644 --- a/nidn/__init__.py +++ b/nidn/__init__.py @@ -1,8 +1,6 @@ import os - -# Set main device by default to cpu if no other choice was made before -if "TORCH_DEVICE" not in os.environ: - os.environ["TORCH_DEVICE"] = "cpu" +import torch +from loguru import logger # Add exposed features here from .plots.plot_model_grid import plot_model_grid @@ -15,6 +13,19 @@ from .utils.fix_random_seeds import fix_random_seeds from .utils.load_default_cfg import load_default_cfg from .utils.print_cfg import print_cfg +from .utils.set_log_level import set_log_level + +set_log_level("INFO") + +# Set main device by default to cpu if no other choice was made before +if "TORCH_DEVICE" not in os.environ: + os.environ["TORCH_DEVICE"] = "cpu" + +logger.info(f"Initialized NIDN for {os.environ['TORCH_DEVICE']}") + +# Set precision (and potentially GPU) +torch.set_default_tensor_type(torch.DoubleTensor) +logger.info("Using double precision") __all__ = [ "compute_target_frequencies", @@ -28,5 +39,6 @@ "plot_model_grid_per_freq", "plot_spectra", "print_cfg", + "set_log_level", "wl_to_phys_wl", ] diff --git a/nidn/materials/material_collection.py b/nidn/materials/material_collection.py index 59b58d8..2a0e34e 100644 --- a/nidn/materials/material_collection.py +++ b/nidn/materials/material_collection.py @@ -52,10 +52,10 @@ def _load_materials_folder(self): self.N_materials = len(self.material_names) def _load_material_data(self, name): - """Loads the passed wavelength,n,k data from the passed csv file for the closest frequencies and returns epsilon (permittivity). + """Loads data (wavelength, n, and k) from the passed csv file for the closest frequencies and returns epsilon (permittivity). Args: - name (str): Path to csv + name (str): Path to csv. Returns: torch.tensor: Epsilon for the material (permittivity) diff --git a/nidn/tests/material_collection_test.py b/nidn/tests/material_collection_test.py index a1cba87..a59787a 100644 --- a/nidn/tests/material_collection_test.py +++ b/nidn/tests/material_collection_test.py @@ -2,8 +2,8 @@ def test_material_collection_init(): - """Tests if the material collection can be initialized successfully""" - target_frequencies = [1.0, 0.1, 0.01] + """Tests if the material collection can be initialized successfully.""" + target_frequencies = [9.5,1.0, 0.1, 0.01] mc = MaterialCollection(target_frequencies) assert len(mc.material_names) > 0 assert mc.target_frequencies == target_frequencies diff --git a/nidn/training/model/model_to_eps_grid.py b/nidn/training/model/model_to_eps_grid.py index e23a3a8..91a65ae 100644 --- a/nidn/training/model/model_to_eps_grid.py +++ b/nidn/training/model/model_to_eps_grid.py @@ -11,6 +11,8 @@ def _eval_model(model, Nx_undersampled, Ny_undersampled, N_layers, target_freque Ny_undersampled (int): Number of grid points in y direction. Potentially unesampled if eps_oversampling > 1. N_layers (int): Number of layers in the model. target_frequencies (list): Target frequencies. + Returns: + [torch.tensor]: Resulting 4D [real,imag] epsilon grid """ # Get the grid ticks x = torch.linspace(-1, 1, Nx_undersampled) diff --git a/nidn/training/run_training.py b/nidn/training/run_training.py index 2e92ccc..c6f0b19 100644 --- a/nidn/training/run_training.py +++ b/nidn/training/run_training.py @@ -17,7 +17,6 @@ from .utils.validate_config import _validate_config - def _init_training(run_cfg: DotMap, model): """Initializes additional parameters required for training. Args: @@ -80,7 +79,7 @@ def run_training( model (torch.model, optional): Model to continue training. If None, a new model will be created according to the run configuration. Defaults to None. Returns: - torch.model,DotMap: The best model achieved in the training run, and the loss results of the training run. + torch.model, DotMap: The best model achieved in the training run, and the loss results of the training run. """ logger.trace("Initializing training...") @@ -141,7 +140,7 @@ def run_training( if loss < best_loss: best_loss = loss logger.info( - f"New Best={loss.item():.4f} SpectrumLoss={spectrum_loss.detach().item():4f}" + f"### New Best={loss.item():<6.4f} with SpectrumLoss={spectrum_loss.detach().item():<6.4f} ###" ) if not renormalized: logger.debug("Saving model state...") @@ -158,7 +157,7 @@ def run_training( if it % 5 == 0: wa_out = np.mean(weighted_average) logger.info( - f"It={it}\t loss={loss.item():.3e}\t weighted_average={wa_out:.3e}\t SpectrumLoss={spectrum_loss.detach().item():4f}" + f"It={it:<5} Loss={loss.item():<6.4f} | weighted_avg={wa_out:<6.4f} | SpectrumLoss={spectrum_loss.detach().item():<6.4f}" ) # Zeroes the gradient (otherwise would accumulate) diff --git a/nidn/utils/set_log_level.py b/nidn/utils/set_log_level.py new file mode 100644 index 0000000..0aa4e73 --- /dev/null +++ b/nidn/utils/set_log_level.py @@ -0,0 +1,19 @@ +from loguru import logger +import sys + + +def set_log_level(log_level: str): + """Set the log level for the logger. + + Args: + log_level (str): The log level to set. Options are 'TRACE','DEBUG', 'INFO', 'SUCCESS', 'WARNING', 'ERROR', 'CRITICAL'. + """ + logger.remove() + logger.add( + sys.stderr, + colorize=True, + level=log_level, + format="{time:HH:mm:ss}|NIDN-{level}| {message}", + filter="nidn", + ) + logger.debug(f"Setting LogLevel to {log_level}") diff --git a/notebooks/Training.ipynb b/notebooks/Training.ipynb index 7edde8f..ea25fb0 100644 --- a/notebooks/Training.ipynb +++ b/notebooks/Training.ipynb @@ -15,16 +15,7 @@ "import sys\n", "sys.path.append(\"../\")\n", "\n", - "import nidn\n", - "\n", - "# Set precision (and potentially GPU)\n", - "import torch\n", - "torch.set_default_tensor_type(torch.DoubleTensor)\n", - "\n", - "# Set up some logging\n", - "from loguru import logger\n", - "logger.remove()\n", - "logger.add(sys.stderr, format=\"{level} {message}\", level=\"INFO\");" + "import nidn" ] }, {