Source code for gammalearn.optimizers

import torch.optim as optim
import torch
from gammalearn.utils import compute_dann_hparams
from torch.optim.lr_scheduler import _LRScheduler
import re


def load_sgd(net, parameters):
    """
    Load the SGD optimizer

    Parameters
    ----------
    net (nn.Module): the network of the experiment
    parameters (dict): a dictionary describing the parameters of the optimizer

    Returns
    -------
    the optimizer
    """
    assert 'lr' in parameters.keys(), 'Missing learning rate for the optimizer !'
    assert 'weight_decay' in parameters.keys(), 'Missing weight decay for the optimizer !'
    return optim.SGD(net.parameters(), **parameters)

def load_adam(net, parameters):
    """
    Load the Adam optimizer

    Parameters
    ----------
    net (nn.Module): the network of the experiment
    parameters (dict): a dictionary describing the parameters of the optimizer

    Returns
    -------
    the optimizer
    """
    assert 'lr' in parameters.keys(), 'Missing learning rate for the optimizer !'
    return optim.Adam(net.parameters(), **parameters)

def load_adam_w(net, parameters):
    """
    Load the AdamW optimizer

    Parameters
    ----------
    net (nn.Module): the network of the experiment
    parameters (dict): a dictionary describing the parameters of the optimizer

    Returns
    -------
    the optimizer
    """
    assert 'lr' in parameters.keys(), 'Missing learning rate for the optimizer !'
    return optim.AdamW(net.parameters(), **parameters)

def load_rmsprop(net, parameters):
    """
    Load the RMSprop optimizer

    Parameters
    ----------
    net (nn.Module): the network of the experiment
    parameters (dict): a dictionary describing the parameters of the optimizer

    Returns
    -------
    the optimizer
    """
    assert 'lr' in parameters.keys(), 'Missing learning rate for the optimizer !'
    return optim.RMSprop(net.parameters(), **parameters)

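# Usage sketch (illustrative, not part of the original module): each loader above
# takes the experiment's network and a plain dictionary that is unpacked into the
# corresponding torch optimizer. The network and hyperparameter values below are
# hypothetical.
#
#     net = MyModel()
#     optimizer = load_sgd(net, {'lr': 1e-2, 'weight_decay': 1e-4, 'momentum': 0.9})
#     optimizer = load_adam(net, {'lr': 1e-3})
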
def load_per_layer_sgd(net, parameters):
    """
    Load the SGD optimizer with a different learning rate for each layer.

    Parameters
    ----------
    net (nn.Module): the network of the experiment
    parameters (dict): a dictionary describing the parameters of the optimizer

    Returns
    -------
    the optimizer
    """
    assert 'lr' in parameters.keys(), 'Missing learning rate for the optimizer !'
    assert 'weight_decay' in parameters.keys(), 'Missing weight decay for the optimizer !'
    assert 'alpha' in parameters.keys(), 'Missing alpha !'
    lr_default = parameters['lr']
    alpha = parameters.pop('alpha')
    feature_modules = []  # The feature parameters
    base_modules = []  # The other parameters
    for name, module in net.named_modules():
        if isinstance(module, torch.nn.Linear) or isinstance(module, torch.nn.Conv2d):
            if name.split('.')[0] == 'feature':
                feature_modules.append(module)
            else:
                base_modules.append(module)
    feature_lr = [lr_default / (alpha ** layer) for layer in range(1, len(feature_modules) + 1)]
    feature_lr.reverse()
    parameter_group = [{'params': p.parameters()} for p in base_modules]
    parameter_group += [{'params': p.parameters(), 'lr': lr} for p, lr in zip(feature_modules, feature_lr)]
    return torch.optim.SGD(parameter_group, **parameters)

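# Worked example (illustrative) of the per-layer schedule built above: with
# lr = 1e-2 and alpha = 2, a 'feature' sub-network containing three Linear/Conv2d
# layers receives learning rates [1e-2 / 8, 1e-2 / 4, 1e-2 / 2] after the reversal,
# so earlier feature layers are updated more slowly than later ones, while all
# remaining layers keep the default lr = 1e-2.
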
def freeze(net, parameters):
    """
    Freeze the network parameters

    Parameters
    ----------
    net (nn.Module): the network or the subnetwork (e.g. feature)
    parameters (dict): a dictionary describing the parameters of the optimizer

    Returns
    -------
    None (no optimizer is created; gradients are simply disabled)
    """
    for p in net.parameters():
        p.requires_grad = False
    return None

def prime_optimizer(net: torch.nn.Module, parameters: dict) -> torch.optim.Optimizer:
    """
    Load the optimizer for Masked AutoEncoder fine-tuning (transformers)

    Parameters
    ----------
    net (nn.Module): the network of the experiment
    parameters (dict): a dictionary describing the parameters of the optimizer

    Returns
    -------
    the optimizer
    """
    num_blocks = len(list(net.encoder.children())) + 1
    layer_scales = [parameters['layer_decay'] ** (num_blocks - i) for i in range(num_blocks + 1)]
    no_weight_decay = ['pos_embedding', 'additional_tokens']
    param_groups = {}
    for n, p in net.named_parameters():
        if p.requires_grad:
            # Parameters excluded from weight decay: 1-d tensors (biases, norms) and special tokens
            if p.ndim == 1 or n in no_weight_decay:
                this_decay = 0.
            else:
                this_decay = None
            layer_id = get_layer_id_for_prime(n)
            group_name = str(layer_id) + '_' + str(this_decay)
            if group_name not in param_groups:
                layer_scale = layer_scales[layer_id]
                param_groups[group_name] = {
                    # 'lr_scale': layer_scale,
                    'lr': layer_scale * parameters['optimizer_parameters']['lr'],
                    'params': []
                }
                if this_decay is not None:
                    param_groups[group_name]['weight_decay'] = this_decay
            param_groups[group_name]['params'].append(p)
    return parameters['optimizer'](list(param_groups.values()), **parameters['optimizer_parameters'])

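# Worked example (illustrative; exact values depend on the encoder structure): if
# net.encoder has four children, then num_blocks = 5 and, with layer_decay = 0.75,
# layer_scales = [0.75**5, ..., 0.75**1, 0.75**0] ≈ [0.24, 0.32, 0.42, 0.56, 0.75, 1.0].
# Shallow layers (layer id 0) thus get the smallest fraction of the base learning rate,
# while parameters mapped to layer id -1 (see get_layer_id_for_prime below) pick the
# last entry, i.e. the full base learning rate.
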
def get_layer_id_for_prime(name: str) -> int:
    """
    Retrieve GammaPhysNetPrime layer id from parameter name
    """
    if any(layer in name for layer in ['pos_embedding', 'patch_projection']):
        return 0
    else:
        try:
            block = re.findall(r'enc_block_\d', name)[0]
            block = int(block.split('_')[-1]) + 1
            return block
        except IndexError:
            return -1

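# Mapping examples (parameter names are illustrative; only the patterns matched by
# the function matter):
#     get_layer_id_for_prime('patch_projection.weight')         -> 0
#     get_layer_id_for_prime('encoder.enc_block_0.attn.weight') -> 1
#     get_layer_id_for_prime('regression_head.bias')            -> -1
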
#############################
# Regularization strategies #
#############################

def l1(net):
    """
    Simple L1 penalty.

    Parameters
    ----------
    net (nn.Module): the network.

    Returns
    -------
    the penalty
    """
    penalty = 0
    for param in net.parameters():
        penalty += torch.norm(param, 1)
    return penalty

def l2(net):
    """
    Simple L2 penalty.

    Parameters
    ----------
    net (nn.Module): the network.

    Returns
    -------
    the penalty
    """
    penalty = 0
    for param in net.parameters():
        penalty += torch.norm(param, 2) ** 2
    return penalty / 2

def elastic(net):
    """
    Elastic penalty (L1 + L2).

    Parameters
    ----------
    net (nn.Module): the network.

    Returns
    -------
    the penalty
    """
    return l1(net) + l2(net)

def srip(net):
    """
    Spectral Restricted Isometry Property (SRIP) regularization penalty.
    See https://arxiv.org/abs/1810.09102

    Parameters
    ----------
    net (nn.Module): the network.

    Returns
    -------
    the penalty
    """
    penalty = 0
    for n, W in net.named_parameters():
        if W.ndimension() >= 2:
            cols = W[0].numel()
            rows = W.shape[0]
            w1 = W.view(-1, cols)
            wt = torch.transpose(w1, 0, 1)
            if rows > cols:
                m = torch.matmul(wt, w1)
                ident = torch.eye(cols, cols, device=W.device)
            else:
                m = torch.matmul(w1, wt)
                ident = torch.eye(rows, rows, device=W.device)
            w_tmp = m - ident
            # One power-iteration step to estimate the spectral norm of (W^T W - I)
            b_k = torch.rand(w_tmp.shape[1], 1, device=W.device)
            v1 = torch.matmul(w_tmp, b_k)
            norm1 = torch.norm(v1, 2)
            v2 = torch.div(v1, norm1)
            v3 = torch.matmul(w_tmp, v2)
            penalty += (torch.norm(v3, 2)) ** 2
    return penalty

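# Usage sketch (illustrative): each regularizer above returns a scalar tensor meant
# to be weighted and added to the task loss during training. The criterion, tensors
# and the 1e-4 weighting factor below are hypothetical.
#
#     loss = criterion(output, target) + 1e-4 * elastic(net)
#     loss.backward()
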
class DANNLR(_LRScheduler):
    def __init__(self, optimizer, domain_classifier=False):
        self.domain_classifier = domain_classifier

        # Attach optimizer
        if not isinstance(optimizer, optim.Optimizer):
            raise TypeError('{} is not an Optimizer'.format(type(optimizer).__name__))
        self.optimizer = optimizer

    def step(self, module) -> None:
        lambda_p, mu_p = compute_dann_hparams(module)
        if self.domain_classifier:
            new_lr = mu_p / lambda_p
        else:
            new_lr = mu_p

        for param_group in self.optimizer.param_groups:
            param_group['lr'] = new_lr
        self.last_lr = new_lr

    def get_lr(self) -> float:
        return self.last_lr

    def state_dict(self) -> dict:
        return {key: value for key, value in self.__dict__.items() if key != 'optimizer'}

    def load_state_dict(self, state_dict) -> None:
        self.__dict__.update(state_dict)

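# Usage sketch (illustrative, not part of the original module): DANNLR rescales the
# learning rate of an attached optimizer from the DANN hyper-parameters returned by
# compute_dann_hparams, so step() must receive the training module that function
# expects; `net` and `module` below are hypothetical.
#
#     optimizer = load_sgd(net, {'lr': 1e-2, 'weight_decay': 1e-4})
#     scheduler = DANNLR(optimizer, domain_classifier=True)
#     scheduler.step(module)  # sets every param group lr to mu_p / lambda_p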