# Source code for geogenie.utils.callbacks

import logging
from pathlib import Path

import numpy as np
import torch


class EarlyStopping:
    """Early stopping PyTorch callback.

    Tracks a monitored validation loss across calls. Each time the loss
    improves, the model's ``state_dict`` is written to a checkpoint file.
    Once ``patience`` consecutive calls pass without improvement,
    ``early_stop`` is set.

    Attributes:
        patience (int): Number of calls without improvement before stopping.
        verbose (int): Verbosity mode; truthy values enable info-level logging.
        delta (float): Minimum decrease in loss required to count as an
            improvement.
        output_dir (str): Directory under which a ``models`` subdirectory
            holds the checkpoints.
        prefix (str): Prefix for the checkpoint filenames.
        counter (int): Consecutive calls since the last improvement.
        early_stop (bool): True once early stopping has been triggered.
        best_score (float | None): Lowest validation loss seen so far; None
            before the first call.
        val_loss_min (float): Validation loss of the most recently saved
            checkpoint (``inf`` before any save).
        boot (int | None): Bootstrap replicate number, used in the filename.
        trial (int | None): Optuna trial number, used in the filename.
        logger (logging.Logger): Logger for this module.
    """

    def __init__(
        self,
        output_dir,
        prefix,
        patience=100,
        verbose=0,
        delta=0,
        trial=None,
        boot=None,
    ):
        """Initialize the EarlyStopping callback.

        Args:
            output_dir (str): Directory to save checkpoints.
            prefix (str): Prefix for the checkpoint filenames.
            patience (int): How many calls to wait after the last improvement
                before stopping. Default: 100.
            verbose (int): Verbosity mode. Default: 0.
            delta (float): Minimum change in the monitored quantity to
                qualify as an improvement. Default: 0.
            trial (optuna.trial.Trial, optional): Optuna trial; only its
                ``number`` attribute is kept. Default: None.
            boot (int, optional): Bootstrap number for ensemble models.
                Default: None.

        Raises:
            ValueError: If both boot and trial are defined.
        """
        self.patience = patience
        self.verbose = verbose
        self.delta = delta
        self.output_dir = output_dir
        self.prefix = prefix
        self.counter = 0
        self.early_stop = False
        self.best_score = None
        # ``np.Inf`` was removed in NumPy 2.0; ``np.inf`` works everywhere.
        self.val_loss_min = np.inf
        self.boot = boot
        # Only the trial's integer id is needed (it goes into the filename).
        self.trial = trial.number if trial is not None else None
        self.logger = logging.getLogger(__name__)

        if self.boot is not None and self.trial is not None:
            msg = "Both boot and trial cannot both be defined."
            self.logger.error(msg)
            raise ValueError(msg)

    def __call__(self, val_loss, model):
        """Update the early-stopping state with a new validation loss.

        Args:
            val_loss (float): Current validation loss. Anything accepted by
                ``float()`` (e.g. a 0-d tensor) is allowed.
            model (torch.nn.Module): The model being trained; checkpointed
                when the loss improves.

        Returns:
            bool: True if early stopping is triggered, False otherwise.
        """
        score = float(val_loss)

        if np.isnan(score):
            # A NaN loss (e.g. a diverged epoch) must never count as an
            # improvement. With a plain ``>`` comparison it would fall through
            # to the "improved" branch and overwrite the best checkpoint.
            self._register_no_improvement()
        elif self.best_score is None:
            self.best_score = score
            self.save_checkpoint(score, model)
        elif score > self.best_score - self.delta:
            self._register_no_improvement()
        else:
            self.best_score = score
            self.save_checkpoint(score, model)
            self.counter = 0
        return self.early_stop

    def _register_no_improvement(self):
        """Bump the patience counter and set ``early_stop`` when exhausted."""
        self.counter += 1
        if self.verbose:
            self.logger.info(
                f"EarlyStopping counter: {self.counter}/{self.patience}"
            )
        if self.counter >= self.patience:
            self.early_stop = True

    def _checkpoint_path(self):
        """Return the checkpoint file path for this boot (or trial)."""
        chkdir = Path(self.output_dir, "models")
        if self.boot is not None:
            return chkdir / f"{self.prefix}_boot{self.boot}_checkpoint.pt"
        return chkdir / f"{self.prefix}_trial{self.trial}_checkpoint.pt"

    def save_checkpoint(self, val_loss, model):
        """Save the model's state_dict when validation loss decreases.

        Args:
            val_loss (float): Current validation loss.
            model (torch.nn.Module): The model being trained.
        """
        if self.verbose:
            self.logger.info(
                f"Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). Saving model ..."
            )
        fn = self._checkpoint_path()
        fn.parent.mkdir(parents=True, exist_ok=True)
        if self.verbose:
            if self.boot is not None:
                self.logger.info(f"Saving checkpoint for boot {self.boot}")
            else:
                self.logger.info(f"Saving checkpoint for trial {self.trial}")
        torch.save(model.state_dict(), fn)
        self.val_loss_min = val_loss

    def load_best_model(self, model):
        """Load the best model from the checkpoint file.

        Args:
            model (torch.nn.Module): The model to load the checkpoint into.

        Returns:
            torch.nn.Module: The same model, with weights loaded from the
            best checkpoint.

        Raises:
            FileNotFoundError: If the checkpoint file is not found.
        """
        fn = self._checkpoint_path()
        if not fn.exists():
            msg = f"Checkpoint file {fn} not found. Early stopping failed and model not loaded."
            self.logger.error(msg)
            raise FileNotFoundError(msg)

        model.load_state_dict(torch.load(fn))
        if self.verbose:
            self.logger.info("Loaded the best model from checkpoint.")
        return model
def callback_init(optimizer, args, trial=None, boot=None):
    """Initialize early stopping and learning rate scheduler callbacks.

    Builds an ``EarlyStopping`` callback (with ``delta=0``) and a
    ``ReduceLROnPlateau`` scheduler in ``"min"`` mode. Both are driven by
    attributes of ``args``.

    Args:
        optimizer (torch.optim.Optimizer): The optimizer the learning rate
            scheduler will wrap.
        args (argparse.Namespace): Namespace providing ``verbose``, ``debug``,
            ``output_dir``, ``prefix``, ``early_stop_patience``,
            ``lr_scheduler_factor`` and ``lr_scheduler_patience``.
        trial (optuna.trial.Trial, optional): Optuna trial object for
            hyperparameter optimization. Defaults to None.
        boot (int, optional): Bootstrap identifier used by the early stopping
            callback. Defaults to None.

    Returns:
        tuple: ``(early_stopping, lr_scheduler)``, i.e. the initialized
        EarlyStopping callback and ReduceLROnPlateau scheduler.
    """
    verbose = args.verbose >= 2 or args.debug

    early_stopping = EarlyStopping(
        output_dir=args.output_dir,
        prefix=args.prefix,
        patience=args.early_stop_patience,
        verbose=verbose,
        delta=0,
        trial=trial,
        boot=boot,
    )

    scheduler_kwargs = {}
    if verbose:
        # ``verbose`` is deprecated in recent PyTorch schedulers, and passing
        # it at all (even as False) emits a UserWarning. Leaving it out
        # selects the default (no LR-change printing), so forward it only
        # when logging was actually requested.
        scheduler_kwargs["verbose"] = verbose

    lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode="min",
        factor=args.lr_scheduler_factor,
        patience=args.lr_scheduler_patience,
        **scheduler_kwargs,
    )
    return early_stopping, lr_scheduler