diff --git a/nidn/__init__.py b/nidn/__init__.py
index 1befd89..4320741 100644
--- a/nidn/__init__.py
+++ b/nidn/__init__.py
@@ -1,8 +1,6 @@
import os
-
-# Set main device by default to cpu if no other choice was made before
-if "TORCH_DEVICE" not in os.environ:
- os.environ["TORCH_DEVICE"] = "cpu"
+import torch
+from loguru import logger
# Add exposed features here
from .plots.plot_model_grid import plot_model_grid
@@ -15,6 +13,19 @@
from .utils.fix_random_seeds import fix_random_seeds
from .utils.load_default_cfg import load_default_cfg
from .utils.print_cfg import print_cfg
+from .utils.set_log_level import set_log_level
+
+set_log_level("INFO")
+
+# Set main device by default to cpu if no other choice was made before
+if "TORCH_DEVICE" not in os.environ:
+ os.environ["TORCH_DEVICE"] = "cpu"
+
+logger.info(f"Initialized NIDN for {os.environ['TORCH_DEVICE']}")
+
+# Set precision (and potentially GPU)
+torch.set_default_tensor_type(torch.DoubleTensor)
+logger.info("Using double precision")
__all__ = [
"compute_target_frequencies",
@@ -28,5 +39,6 @@
"plot_model_grid_per_freq",
"plot_spectra",
"print_cfg",
+ "set_log_level",
"wl_to_phys_wl",
]
diff --git a/nidn/materials/material_collection.py b/nidn/materials/material_collection.py
index 59b58d8..2a0e34e 100644
--- a/nidn/materials/material_collection.py
+++ b/nidn/materials/material_collection.py
@@ -52,10 +52,10 @@ def _load_materials_folder(self):
self.N_materials = len(self.material_names)
def _load_material_data(self, name):
- """Loads the passed wavelength,n,k data from the passed csv file for the closest frequencies and returns epsilon (permittivity).
+ """Loads data (wavelength, n, and k) from the passed csv file for the closest frequencies and returns epsilon (permittivity).
Args:
- name (str): Path to csv
+ name (str): Path to csv.
Returns:
torch.tensor: Epsilon for the material (permittivity)
diff --git a/nidn/tests/material_collection_test.py b/nidn/tests/material_collection_test.py
index a1cba87..a59787a 100644
--- a/nidn/tests/material_collection_test.py
+++ b/nidn/tests/material_collection_test.py
@@ -2,8 +2,8 @@
def test_material_collection_init():
- """Tests if the material collection can be initialized successfully"""
- target_frequencies = [1.0, 0.1, 0.01]
+ """Tests if the material collection can be initialized successfully."""
+ target_frequencies = [9.5,1.0, 0.1, 0.01]
mc = MaterialCollection(target_frequencies)
assert len(mc.material_names) > 0
assert mc.target_frequencies == target_frequencies
diff --git a/nidn/training/model/model_to_eps_grid.py b/nidn/training/model/model_to_eps_grid.py
index e23a3a8..91a65ae 100644
--- a/nidn/training/model/model_to_eps_grid.py
+++ b/nidn/training/model/model_to_eps_grid.py
@@ -11,6 +11,8 @@ def _eval_model(model, Nx_undersampled, Ny_undersampled, N_layers, target_freque
Ny_undersampled (int): Number of grid points in y direction. Potentially unesampled if eps_oversampling > 1.
N_layers (int): Number of layers in the model.
target_frequencies (list): Target frequencies.
+ Returns:
+ [torch.tensor]: Resulting 4D [real,imag] epsilon grid
"""
# Get the grid ticks
x = torch.linspace(-1, 1, Nx_undersampled)
diff --git a/nidn/training/run_training.py b/nidn/training/run_training.py
index 2e92ccc..c6f0b19 100644
--- a/nidn/training/run_training.py
+++ b/nidn/training/run_training.py
@@ -17,7 +17,6 @@
from .utils.validate_config import _validate_config
-
def _init_training(run_cfg: DotMap, model):
"""Initializes additional parameters required for training.
Args:
@@ -80,7 +79,7 @@ def run_training(
model (torch.model, optional): Model to continue training. If None, a new model will be created according to the run configuration. Defaults to None.
Returns:
- torch.model,DotMap: The best model achieved in the training run, and the loss results of the training run.
+ torch.model, DotMap: The best model achieved in the training run, and the loss results of the training run.
"""
logger.trace("Initializing training...")
@@ -141,7 +140,7 @@ def run_training(
if loss < best_loss:
best_loss = loss
logger.info(
- f"New Best={loss.item():.4f} SpectrumLoss={spectrum_loss.detach().item():4f}"
+ f"### New Best={loss.item():<6.4f} with SpectrumLoss={spectrum_loss.detach().item():<6.4f} ###"
)
if not renormalized:
logger.debug("Saving model state...")
@@ -158,7 +157,7 @@ def run_training(
if it % 5 == 0:
wa_out = np.mean(weighted_average)
logger.info(
- f"It={it}\t loss={loss.item():.3e}\t weighted_average={wa_out:.3e}\t SpectrumLoss={spectrum_loss.detach().item():4f}"
+ f"It={it:<5} Loss={loss.item():<6.4f} | weighted_avg={wa_out:<6.4f} | SpectrumLoss={spectrum_loss.detach().item():<6.4f}"
)
# Zeroes the gradient (otherwise would accumulate)
diff --git a/nidn/utils/set_log_level.py b/nidn/utils/set_log_level.py
new file mode 100644
index 0000000..0aa4e73
--- /dev/null
+++ b/nidn/utils/set_log_level.py
@@ -0,0 +1,19 @@
+from loguru import logger
+import sys
+
+
+def set_log_level(log_level: str):
+ """Set the log level for the logger.
+
+ Args:
+ log_level (str): The log level to set. Options are 'TRACE','DEBUG', 'INFO', 'SUCCESS', 'WARNING', 'ERROR', 'CRITICAL'.
+ """
+ logger.remove()
+ logger.add(
+ sys.stderr,
+ colorize=True,
+ level=log_level,
+ format="{time:HH:mm:ss}|NIDN-{level}| {message}",
+ filter="nidn",
+ )
+ logger.debug(f"Setting LogLevel to {log_level}")
diff --git a/notebooks/Training.ipynb b/notebooks/Training.ipynb
index 7edde8f..ea25fb0 100644
--- a/notebooks/Training.ipynb
+++ b/notebooks/Training.ipynb
@@ -15,16 +15,7 @@
"import sys\n",
"sys.path.append(\"../\")\n",
"\n",
- "import nidn\n",
- "\n",
- "# Set precision (and potentially GPU)\n",
- "import torch\n",
- "torch.set_default_tensor_type(torch.DoubleTensor)\n",
- "\n",
- "# Set up some logging\n",
- "from loguru import logger\n",
- "logger.remove()\n",
- "logger.add(sys.stderr, format=\"{level} {message}\", level=\"INFO\");"
+ "import nidn"
]
},
{