Add function to set logging level by gomezzz · Pull Request #7 · esa/NIDN · GitHub
[go: up one dir, main page]
More Web Proxy on the site http://driver.im/
Skip to content

Add function to set logging level #7

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Aug 23, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 16 additions & 4 deletions nidn/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
import os

# Set main device by default to cpu if no other choice was made before
if "TORCH_DEVICE" not in os.environ:
os.environ["TORCH_DEVICE"] = "cpu"
import torch
from loguru import logger

# Add exposed features here
from .plots.plot_model_grid import plot_model_grid
Expand All @@ -15,6 +13,19 @@
from .utils.fix_random_seeds import fix_random_seeds
from .utils.load_default_cfg import load_default_cfg
from .utils.print_cfg import print_cfg
from .utils.set_log_level import set_log_level

set_log_level("INFO")

# Set main device by default to cpu if no other choice was made before
if "TORCH_DEVICE" not in os.environ:
os.environ["TORCH_DEVICE"] = "cpu"

logger.info(f"Initialized NIDN for {os.environ['TORCH_DEVICE']}")

# Set precision (and potentially GPU)
torch.set_default_tensor_type(torch.DoubleTensor)
logger.info("Using double precision")

__all__ = [
"compute_target_frequencies",
Expand All @@ -28,5 +39,6 @@
"plot_model_grid_per_freq",
"plot_spectra",
"print_cfg",
"set_log_level",
"wl_to_phys_wl",
]
4 changes: 2 additions & 2 deletions nidn/materials/material_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,10 +52,10 @@ def _load_materials_folder(self):
self.N_materials = len(self.material_names)

def _load_material_data(self, name):
"""Loads the passed wavelength,n,k data from the passed csv file for the closest frequencies and returns epsilon (permittivity).
"""Loads data (wavelength, n, and k) from the passed csv file for the closest frequencies and returns epsilon (permittivity).

Args:
name (str): Path to csv
name (str): Path to csv.

Returns:
torch.tensor: Epsilon for the material (permittivity)
Expand Down
4 changes: 2 additions & 2 deletions nidn/tests/material_collection_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@


def test_material_collection_init():
"""Tests if the material collection can be initialized successfully"""
target_frequencies = [1.0, 0.1, 0.01]
"""Tests if the material collection can be initialized successfully."""
target_frequencies = [9.5,1.0, 0.1, 0.01]
mc = MaterialCollection(target_frequencies)
assert len(mc.material_names) > 0
assert mc.target_frequencies == target_frequencies
Expand Down
2 changes: 2 additions & 0 deletions nidn/training/model/model_to_eps_grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ def _eval_model(model, Nx_undersampled, Ny_undersampled, N_layers, target_freque
Ny_undersampled (int): Number of grid points in y direction. Potentially unesampled if eps_oversampling > 1.
N_layers (int): Number of layers in the model.
target_frequencies (list): Target frequencies.
Returns:
[torch.tensor]: Resulting 4D [real,imag] epsilon grid
"""
# Get the grid ticks
x = torch.linspace(-1, 1, Nx_undersampled)
Expand Down
7 changes: 3 additions & 4 deletions nidn/training/run_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
from .utils.validate_config import _validate_config



def _init_training(run_cfg: DotMap, model):
"""Initializes additional parameters required for training.
Args:
Expand Down Expand Up @@ -80,7 +79,7 @@ def run_training(
model (torch.model, optional): Model to continue training. If None, a new model will be created according to the run configuration. Defaults to None.

Returns:
torch.model,DotMap: The best model achieved in the training run, and the loss results of the training run.
torch.model, DotMap: The best model achieved in the training run, and the loss results of the training run.
"""
logger.trace("Initializing training...")

Expand Down Expand Up @@ -141,7 +140,7 @@ def run_training(
if loss < best_loss:
best_loss = loss
logger.info(
f"New Best={loss.item():.4f} SpectrumLoss={spectrum_loss.detach().item():4f}"
f"### New Best={loss.item():<6.4f} with SpectrumLoss={spectrum_loss.detach().item():<6.4f} ###"
)
if not renormalized:
logger.debug("Saving model state...")
Expand All @@ -158,7 +157,7 @@ def run_training(
if it % 5 == 0:
wa_out = np.mean(weighted_average)
logger.info(
f"It={it}\t loss={loss.item():.3e}\t weighted_average={wa_out:.3e}\t SpectrumLoss={spectrum_loss.detach().item():4f}"
f"It={it:<5} Loss={loss.item():<6.4f} | weighted_avg={wa_out:<6.4f} | SpectrumLoss={spectrum_loss.detach().item():<6.4f}"
)

# Zeroes the gradient (otherwise would accumulate)
Expand Down
19 changes: 19 additions & 0 deletions nidn/utils/set_log_level.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
from loguru import logger
import sys


def set_log_level(log_level: str):
    """Configure NIDN's console logging at the requested verbosity.

    All sinks currently attached to the shared loguru ``logger`` are removed.
    This includes loguru's default stderr sink and any sinks added elsewhere.
    A single colorized sink writing to ``sys.stderr`` is then installed. It
    only emits records coming from the ``nidn`` package, at ``log_level`` or
    above.

    Args:
        log_level (str): Minimum level to emit. Options are 'TRACE', 'DEBUG',
            'INFO', 'SUCCESS', 'WARNING', 'ERROR', 'CRITICAL'.
    """
    # Start from a clean slate so repeated calls do not stack duplicate sinks.
    logger.remove()

    # Timestamp | NIDN-<level> | message, with loguru color markup.
    line_format = (
        "<green>{time:HH:mm:ss}</green>|NIDN-<blue>{level}</blue>| <level>{message}</level>"
    )
    sink_options = dict(
        level=log_level,
        colorize=True,
        filter="nidn",  # only records whose module path is nidn or a submodule
        format=line_format,
    )
    logger.add(sys.stderr, **sink_options)

    logger.debug(f"Setting LogLevel to {log_level}")
11 changes: 1 addition & 10 deletions notebooks/Training.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,7 @@
"import sys\n",
"sys.path.append(\"../\")\n",
"\n",
"import nidn\n",
"\n",
"# Set precision (and potentially GPU)\n",
"import torch\n",
"torch.set_default_tensor_type(torch.DoubleTensor)\n",
"\n",
"# Set up some logging\n",
"from loguru import logger\n",
"logger.remove()\n",
"logger.add(sys.stderr, format=\"{level} {message}\", level=\"INFO\");"
"import nidn"
]
},
{
Expand Down