import torch
from gammalearn.utils import BaseW
from gammalearn.criterions import DANNLoss
def get_training_step_mae(**kwargs):
    """Build the training step for masked-autoencoder (MAE) pre-training.

    Parameters
    ----------
    kwargs: dict
        Step options. ``add_pointing`` (default: False) forwards the telescope pointing
        (``alt_tel``, ``az_tel`` stacked along dim 1, cast to float32) to the network as a
        second positional argument.

    Returns
    -------
    callable
        ``training_step_mae(module, batch)``.
    """
    add_pointing = kwargs.get('add_pointing', False)

    def training_step_mae(module, batch):
        """Run the training operations on one batch for MAE pre-training.

        Parameters
        ----------
        module: LightningModule
            Module exposing ``net`` and ``experiment``. ``module.net`` is called on the images
            and its return value is used directly as the loss.
        batch: dict
            Must contain 'image'. When pointing is enabled it must also contain 'dl1_params'
            with 'alt_tel' and 'az_tel' entries.

        Returns
        -------
        tuple
            ``(None, None, loss_data, loss)``. ``loss_data['autoencoder']`` is the network loss
            without regularization, matching the validation step. ``loss`` is the tensor to
            back-propagate and includes the weighted regularization term when one is configured.
        """
        images = batch['image']
        if add_pointing:
            dl1_params = batch['dl1_params']
            pointing = torch.stack((dl1_params['alt_tel'], dl1_params['az_tel']), dim=1).to(torch.float32)
            loss = module.net(images, pointing)
        else:
            loss = module.net(images)
        # Log the autoencoder loss before regularization so it is comparable with validation.
        loss_data = {'autoencoder': loss.detach().item()}
        regularization = module.experiment.regularization
        if regularization is not None:
            loss = loss + regularization['function'](module.net) * regularization['weight']
        return None, None, loss_data, loss

    return training_step_mae
def get_eval_step_mae(**kwargs):
    """Build the validation step for masked-autoencoder (MAE) pre-training.

    Parameters
    ----------
    kwargs: dict
        Step options. ``add_pointing`` (default: False) forwards the telescope pointing
        (``alt_tel``, ``az_tel`` stacked along dim 1, cast to float32) to the network as a
        second positional argument.

    Returns
    -------
    callable
        ``validation_step_mae(module, batch)``.
    """
    add_pointing = kwargs.get('add_pointing', False)

    def validation_step_mae(module, batch):
        """Run the validation operations on one batch for MAE pre-training.

        Parameters
        ----------
        module: LightningModule
            Module exposing ``net``; its return value is used directly as the loss.
        batch: dict
            Must contain 'image'. When pointing is enabled it must also contain 'dl1_params'
            with 'alt_tel' and 'az_tel' entries.

        Returns
        -------
        tuple
            ``(None, None, {'autoencoder': float}, loss)``. No regularization term is added.
        """
        images = batch['image']
        if add_pointing:
            dl1_params = batch['dl1_params']
            pointing = torch.stack((dl1_params['alt_tel'], dl1_params['az_tel']), dim=1).to(torch.float32)
            loss = module.net(images, pointing)
        else:
            loss = module.net(images)
        return None, None, {'autoencoder': loss.detach().item()}, loss

    return validation_step_mae
def get_training_step_mt(**kwargs):
    """Create the training step used for vanilla multi-task learning.

    Parameters
    ----------
    kwargs:
        Step options. ``add_pointing`` (default False) is forwarded to ``run_model`` as
        ``requires_pointing``.

    Returns
    -------
    callable
        ``training_step_mt(module, batch)``.
    """
    with_pointing = kwargs.get('add_pointing', False)

    def training_step_mt(module, batch):
        """Run one vanilla multi-task training step on ``batch``.

        Parameters
        ----------
        module: LightningModule
        batch: dict

        Returns
        -------
        tuple
            ``(outputs, labels, loss_data, loss)``. ``loss`` is the sum of the balanced task
            losses, plus the weighted regularization term when one is configured.
        """
        experiment = module.experiment
        results = run_model(module, batch, False, with_pointing)
        outputs, labels = results['outputs_source'], results['labels_source']

        task_losses, loss_data = experiment.LossComputing.compute_loss(outputs, labels, module)
        total_loss = sum(experiment.loss_balancing(task_losses, module).values())

        reg = experiment.regularization
        if reg is not None:
            total_loss += reg['function'](module.net) * reg['weight']
        return outputs, labels, loss_data, total_loss

    return training_step_mt
def get_training_step_dann(**kwargs):
    """Create the training step used for DANN domain adaptation.

    Parameters
    ----------
    kwargs:
        Step options. ``add_pointing`` (default False) is forwarded to ``run_model`` as
        ``requires_pointing``.

    Returns
    -------
    callable
        ``training_step_dann(module, batch)``.
    """
    with_pointing = kwargs.get('add_pointing', False)

    def training_step_dann(module, batch):
        """Run one DANN training step on a batch holding both source and target data.

        Parameters
        ----------
        module: LightningModule
        batch: dict

        Returns
        -------
        tuple
            ``(output, labels, loss_data, loss)``. The 'domain_class' entries of ``output`` and
            ``labels`` hold the source rows followed by the target rows.
        """
        experiment = module.experiment
        results = run_model(module, batch, False, with_pointing)
        output = results['outputs_source']
        labels = results['labels_source']
        target_output = results['outputs_target']
        target_labels = results['labels_target']

        # The domain classifier sees both domains: append the target predictions and labels.
        output['domain_class'] = torch.cat([output['domain_class'], target_output['domain_class']])
        labels['domain_class'] = torch.cat([labels['domain_class'], target_labels['domain_class']])

        if DANNLoss.fetch_domain_conditional_from_targets(experiment.targets):
            # Domain-conditional DANN: mask the domain loss using the class labels of both domains.
            both_domain_classes = torch.cat([labels['class'], target_labels['class']])
            DANNLoss.set_domain_loss_mask_from_targets(experiment.targets, both_domain_classes)

        task_losses, loss_data = experiment.LossComputing.compute_loss(output, labels, module)
        total_loss = sum(experiment.loss_balancing(task_losses, module).values())
        # NOTE(review): unlike the other training steps in this module, no regularization term
        # is added here - confirm this is intended.
        return output, labels, loss_data, total_loss

    return training_step_dann
def get_training_step_deepjdot(**kwargs):
    """Create the training step used for DeepJDOT domain adaptation.

    Parameters
    ----------
    kwargs:
        Step options. ``add_pointing`` (default False) is forwarded to ``run_model`` as
        ``requires_pointing``.

    Returns
    -------
    callable
        ``training_step_deepjdot(module, batch)``.
    """
    with_pointing = kwargs.get('add_pointing', False)

    def training_step_deepjdot(module, batch):
        """Run one DeepJDOT training step on ``batch``.

        The 'deepjdot' task receives the latent features captured by ``run_model``. Entry 0
        becomes its output and entry 1 becomes its label.

        Parameters
        ----------
        module: LightningModule
        batch: dict

        Returns
        -------
        tuple
            ``(output, labels, loss_data, loss)``. ``loss`` includes the weighted regularization
            term when one is configured.
        """
        experiment = module.experiment
        results = run_model(module, batch, True, with_pointing)
        output, labels = results['outputs_source'], results['labels_source']
        latent = results['latent_features']

        # The alignment task is optimized using the latent features of the source and the target.
        output['deepjdot'] = latent[0]
        labels['deepjdot'] = latent[1]

        task_losses, loss_data = experiment.LossComputing.compute_loss(output, labels, module)
        total_loss = sum(experiment.loss_balancing(task_losses, module).values())

        reg = experiment.regularization
        if reg is not None:
            total_loss += reg['function'](module.net) * reg['weight']
        return output, labels, loss_data, total_loss

    return training_step_deepjdot
def get_training_step_deepcoral(**kwargs):
    """Create the training step used for Deep CORAL domain adaptation.

    Parameters
    ----------
    kwargs:
        Step options. ``add_pointing`` (default False) is forwarded to ``run_model`` as
        ``requires_pointing``.

    Returns
    -------
    callable
        ``training_step_deepcoral(module, batch)``.
    """
    with_pointing = kwargs.get('add_pointing', False)

    def training_step_deepcoral(module, batch):
        """Run one Deep CORAL training step on ``batch``.

        The 'deepcoral' task receives the latent features captured by ``run_model``. Entry 0
        becomes its output and entry 1 becomes its label.

        Parameters
        ----------
        module: LightningModule
        batch: dict

        Returns
        -------
        tuple
            ``(output, labels, loss_data, loss)``. ``loss`` includes the weighted regularization
            term when one is configured.
        """
        experiment = module.experiment
        results = run_model(module, batch, True, with_pointing)
        output, labels = results['outputs_source'], results['labels_source']
        latent = results['latent_features']

        output['deepcoral'] = latent[0]
        labels['deepcoral'] = latent[1]

        task_losses, loss_data = experiment.LossComputing.compute_loss(output, labels, module)
        total_loss = sum(experiment.loss_balancing(task_losses, module).values())

        reg = experiment.regularization
        if reg is not None:
            total_loss += reg['function'](module.net) * reg['weight']
        return output, labels, loss_data, total_loss

    return training_step_deepcoral
def get_training_step_mkmmd(**kwargs):
    """Create the training step used for MK-MMD domain adaptation.

    Parameters
    ----------
    kwargs:
        Step options. ``add_pointing`` (default False) is forwarded to ``run_model`` as
        ``requires_pointing``.

    Returns
    -------
    callable
        ``training_step_mkmmd(module, batch)``.
    """
    with_pointing = kwargs.get('add_pointing', False)

    def training_step_mkmmd(module, batch):
        """Run one MK-MMD training step on ``batch``.

        The 'mkmmd' task receives the latent features captured by ``run_model``. Entry 0
        becomes its output and entry 1 becomes its label.

        Parameters
        ----------
        module: LightningModule
        batch: dict

        Returns
        -------
        tuple
            ``(output, labels, loss_data, loss)``. ``loss`` includes the weighted regularization
            term when one is configured.
        """
        experiment = module.experiment
        results = run_model(module, batch, True, with_pointing)
        output, labels = results['outputs_source'], results['labels_source']
        latent = results['latent_features']

        output['mkmmd'] = latent[0]
        labels['mkmmd'] = latent[1]

        task_losses, loss_data = experiment.LossComputing.compute_loss(output, labels, module)
        total_loss = sum(experiment.loss_balancing(task_losses, module).values())

        reg = experiment.regularization
        if reg is not None:
            total_loss += reg['function'](module.net) * reg['weight']
        return output, labels, loss_data, total_loss

    return training_step_mkmmd
def get_training_step_mt_gradient_penalty(**kwargs):
    """Build the multi-task training step with an input-gradient penalty.

    Parameters
    ----------
    kwargs: dict
        Step options. ``add_pointing`` (default: False) passes the telescope pointing to the
        network together with the images.

    Returns
    -------
    callable
        ``training_step_mt_gradient_penalty(module, batch)``.
    """
    add_pointing = kwargs.get('add_pointing', False)

    def training_step_mt_gradient_penalty(module, batch):
        """Run the vanilla multi-task training operations on one batch, with gradient penalty.

        When ``module.experiment.regularization`` is set, the penalty
        ``mean((||d loss / d images||_2 - 1) ** 2)`` is computed per sample and scaled by
        ``regularization['weight']``. Its 'function' entry is not used by this step.

        Parameters
        ----------
        module: LightningModule
        batch: dict
            Must contain 'image' and 'label'. When pointing is enabled it must also contain
            'dl1_params' with 'alt_tel' and 'az_tel' entries.

        Returns
        -------
        tuple
            ``(output, labels, loss_data, loss)``.
        """
        images = batch['image']
        labels = batch['label']
        # The penalty needs gradients with respect to the inputs.
        images.requires_grad_(True)
        if add_pointing:
            dl1_params = batch['dl1_params']
            # Cast to float32 like every other pointing construction in this module.
            pointing = torch.stack((dl1_params['alt_tel'], dl1_params['az_tel']), dim=1).to(torch.float32)
            # NOTE(review): this passes a dict, whereas run_model calls
            # module.net(inputs, pointing=...) - confirm which signature the net expects.
            output = module.net({'data': images, 'pointing': pointing})
        else:
            output = module.net(images)

        loss, loss_data = module.experiment.LossComputing.compute_loss(output, labels, module)
        loss = module.experiment.loss_balancing(loss, module)
        loss = sum(loss.values())

        regularization = module.experiment.regularization
        if regularization is not None:
            # create_graph=True keeps the input gradient differentiable w.r.t. the parameters;
            # without it the penalty is a constant and contributes no gradient.
            gradient_x = torch.autograd.grad(loss, images, create_graph=True, retain_graph=True)[0]
            gradient_norm = gradient_x.reshape(gradient_x.shape[0], -1).norm(2, dim=1)
            penalty = torch.mean((gradient_norm - 1) ** 2)
            loss = loss + penalty * regularization['weight']
        return output, labels, loss_data, loss

    return training_step_mt_gradient_penalty
def get_eval_step_mt(**kwargs):
    """Create the validation step used for vanilla multi-task learning.

    Parameters
    ----------
    kwargs:
        Step options. ``add_pointing`` (default False) is forwarded to ``run_model`` as
        ``requires_pointing``.

    Returns
    -------
    callable
        ``eval_step_mt(module, batch)``.
    """
    with_pointing = kwargs.get('add_pointing', False)

    def eval_step_mt(module, batch):
        """Run the validation operations on one batch.

        Parameters
        ----------
        module: LightningModule
        batch: dict

        Returns
        -------
        tuple
            ``(outputs, labels, loss_data, loss)``. ``loss`` is the sum of the balanced task
            losses; no regularization term is added.
        """
        experiment = module.experiment
        results = run_model(module, batch, False, with_pointing)
        outputs = results['outputs_source']
        labels = results['labels_source']

        task_losses, loss_data = experiment.LossComputing.compute_loss(outputs, labels, module)
        balanced_losses = experiment.loss_balancing(task_losses, module)
        return outputs, labels, loss_data, sum(balanced_losses.values())

    return eval_step_mt
def get_eval_step_dann(**kwargs):
    """Create the validation step used for DANN domain adaptation.

    Parameters
    ----------
    kwargs:
        Step options. ``add_pointing`` (default False) is forwarded to ``run_model`` as
        ``requires_pointing``.

    Returns
    -------
    callable
        ``eval_step_dann(module, batch)``.
    """
    with_pointing = kwargs.get('add_pointing', False)

    def eval_step_dann(module, batch):
        """Run the DANN validation operations on a batch holding source and target data.

        Parameters
        ----------
        module: LightningModule
        batch: dict

        Returns
        -------
        tuple
            ``(output, labels, loss_data, loss)``. The 'domain_class' entries of ``output`` and
            ``labels`` hold the source rows followed by the target rows.
        """
        experiment = module.experiment
        results = run_model(module, batch, False, with_pointing)
        output = results['outputs_source']
        labels = results['labels_source']
        target_output = results['outputs_target']
        target_labels = results['labels_target']

        # Append the target-domain predictions and labels for the domain classifier.
        output['domain_class'] = torch.cat([output['domain_class'], target_output['domain_class']])
        labels['domain_class'] = torch.cat([labels['domain_class'], target_labels['domain_class']])

        if DANNLoss.fetch_domain_conditional_from_targets(experiment.targets):
            # Domain-conditional DANN: mask the domain loss using the class labels of both domains.
            both_domain_classes = torch.cat([labels['class'], target_labels['class']])
            DANNLoss.set_domain_loss_mask_from_targets(experiment.targets, both_domain_classes)

        task_losses, loss_data = experiment.LossComputing.compute_loss(output, labels, module)
        balanced_losses = experiment.loss_balancing(task_losses, module)
        return output, labels, loss_data, sum(balanced_losses.values())

    return eval_step_dann
def get_eval_step_deepjdot(**kwargs):
    """Create the validation step used for DeepJDOT domain adaptation.

    Parameters
    ----------
    kwargs:
        Step options. ``add_pointing`` (default False) is forwarded to ``run_model`` as
        ``requires_pointing``.

    Returns
    -------
    callable
        ``eval_step_deepjdot(module, batch)``.
    """
    with_pointing = kwargs.get('add_pointing', False)

    def eval_step_deepjdot(module, batch):
        """Run the DeepJDOT validation operations on one batch.

        Entries 0 and 1 of the latent features captured by ``run_model`` become the output and
        label of the 'deepjdot' task.

        Parameters
        ----------
        module: LightningModule
        batch: dict

        Returns
        -------
        tuple
            ``(output, labels, loss_data, loss)``.
        """
        experiment = module.experiment
        results = run_model(module, batch, True, with_pointing)
        output, labels = results['outputs_source'], results['labels_source']
        latent = results['latent_features']

        # The alignment task is optimized using the latent features of the source and the target.
        output['deepjdot'] = latent[0]
        labels['deepjdot'] = latent[1]

        task_losses, loss_data = experiment.LossComputing.compute_loss(output, labels, module)
        total_loss = sum(experiment.loss_balancing(task_losses, module).values())

        # NOTE(review): this validation loss includes the regularization term, whereas
        # get_eval_step_mt does not add it - confirm this is intended.
        reg = experiment.regularization
        if reg is not None:
            total_loss += reg['function'](module.net) * reg['weight']
        return output, labels, loss_data, total_loss

    return eval_step_deepjdot
def get_eval_step_deepcoral(**kwargs):
    """Create the validation step used for Deep CORAL domain adaptation.

    Parameters
    ----------
    kwargs:
        Step options. ``add_pointing`` (default False) is forwarded to ``run_model`` as
        ``requires_pointing``.

    Returns
    -------
    callable
        ``eval_step_deepcoral(module, batch)``.
    """
    with_pointing = kwargs.get('add_pointing', False)

    def eval_step_deepcoral(module, batch):
        """Run the Deep CORAL validation operations on one batch.

        Entries 0 and 1 of the latent features captured by ``run_model`` become the output and
        label of the 'deepcoral' task.

        Parameters
        ----------
        module: LightningModule
        batch: dict

        Returns
        -------
        tuple
            ``(output, labels, loss_data, loss)``.
        """
        experiment = module.experiment
        results = run_model(module, batch, True, with_pointing)
        output, labels = results['outputs_source'], results['labels_source']
        latent = results['latent_features']

        output['deepcoral'] = latent[0]
        labels['deepcoral'] = latent[1]

        task_losses, loss_data = experiment.LossComputing.compute_loss(output, labels, module)
        total_loss = sum(experiment.loss_balancing(task_losses, module).values())

        # NOTE(review): this validation loss includes the regularization term, whereas
        # get_eval_step_mt does not add it - confirm this is intended.
        reg = experiment.regularization
        if reg is not None:
            total_loss += reg['function'](module.net) * reg['weight']
        return output, labels, loss_data, total_loss

    return eval_step_deepcoral
def get_eval_step_mkmmd(**kwargs):
    """Create the validation step used for MK-MMD domain adaptation.

    Parameters
    ----------
    kwargs:
        Step options. ``add_pointing`` (default False) is forwarded to ``run_model`` as
        ``requires_pointing``.

    Returns
    -------
    callable
        ``eval_step_mkmmd(module, batch)``.
    """
    with_pointing = kwargs.get('add_pointing', False)

    def eval_step_mkmmd(module, batch):
        """Run the MK-MMD validation operations on one batch.

        Entries 0 and 1 of the latent features captured by ``run_model`` become the output and
        label of the 'mkmmd' task.

        Parameters
        ----------
        module: LightningModule
        batch: dict

        Returns
        -------
        tuple
            ``(output, labels, loss_data, loss)``.
        """
        experiment = module.experiment
        results = run_model(module, batch, True, with_pointing)
        output, labels = results['outputs_source'], results['labels_source']
        latent = results['latent_features']

        output['mkmmd'] = latent[0]
        labels['mkmmd'] = latent[1]

        task_losses, loss_data = experiment.LossComputing.compute_loss(output, labels, module)
        total_loss = sum(experiment.loss_balancing(task_losses, module).values())

        # NOTE(review): this validation loss includes the regularization term, whereas
        # get_eval_step_mt does not add it - confirm this is intended.
        reg = experiment.regularization
        if reg is not None:
            total_loss += reg['function'](module.net) * reg['weight']
        return output, labels, loss_data, total_loss

    return eval_step_mkmmd
def get_test_step_mt(**kwargs):
    """Build the test step for multi-task models.

    Parameters
    ----------
    kwargs: dict
        Step options. ``add_pointing`` (default: False) is forwarded to ``run_model`` as
        ``requires_pointing``.

    Returns
    -------
    callable
        ``test_step_mt(module, batch)``.
    """
    add_pointing = kwargs.get('add_pointing', False)

    def test_step_mt(module, batch):
        """Run the test operations on one batch.

        Parameters
        ----------
        module: LightningModule
        batch: dict

        Returns
        -------
        tuple
            ``(outputs, labels, dl1_params)`` for the source data. ``dl1_params`` comes from
            'dl1_params_source' if present, else from 'dl1_params', else it is ``None``.
        """
        # Latent features are not used at test time, so no forward hook is registered.
        data = run_model(module, batch, False, add_pointing, train=False)
        # Use run_model's resolution of 'dl1_params' / 'dl1_params_source' so that batches in
        # the domain-adaptation layout also return their parameters.
        return data['outputs_source'], data['labels_source'], data['dl1_params_source']

    return test_step_mt
def _flatten_images(images):
    """Collapse the leading dimensions of an image tensor: Nx...xCxWxH -> N'xCxWxH.

    ``None`` and tensors with at most three dimensions are returned unchanged.
    """
    if images is None or len(images.shape) <= 3:
        return images
    return images.flatten(end_dim=-4)


def _stack_pointing(dl1_params):
    """Stack ``dl1_params['alt_tel']`` and ``dl1_params['az_tel']`` along a new dim 1, as float32."""
    return torch.stack((dl1_params['alt_tel'], dl1_params['az_tel']), dim=1).to(torch.float32)


def run_model(module, batch, requires_latent_features=False, requires_pointing=False, forward_params=None,
              train=True):
    """Run ``module.net`` on the source inputs of a batch, and on the target inputs if present.

    If the training is not in the context of domain adaptation, the information is stored in the
    ``xxx_source`` entries of the returned dictionary.

    Parameters
    ----------
    module: LightningModule
        The current module. Uses ``module.net``, ``module.experiment.targets``,
        ``module.trainer`` and, when latent features are requested,
        ``module.experiment.net_parameters_dic``.
    batch: dict
        The current batch of data. Recognised keys are 'image' / 'image_source',
        'image_target', 'label' / 'label_source', 'label_target', 'dl1_params' /
        'dl1_params_source' and 'dl1_params_target'. Each ``*_source`` key takes precedence
        over its plain counterpart.
    requires_latent_features: bool
        Registers a forward hook to collect the latent space. The model must have either a
        'module.net.main_task_model.feature' or a 'module.net.feature' component; otherwise no
        latent features are collected. The hook is always removed before returning, even if
        the forward pass raises.
    requires_pointing: bool
        Whether to pass the pointing information (alt_tel, az_tel) as the model's ``pointing``
        argument. When False, ``pointing=None`` is passed.
    forward_params: dict, optional
        Extra keyword arguments for the model's forward function. The dict is copied and never
        modified. A 'pointing' entry is always overridden, and 'grad_weight' is overridden when
        a target defines one.
    train: bool
        Whether the current step is a training or a test step. Currently unused by this function.

    Returns
    -------
    dict
        Keys 'inputs_source', 'inputs_target', 'labels_source', 'labels_target',
        'dl1_params_source', 'dl1_params_target', 'outputs_source' and 'outputs_target'
        (``None`` when absent), plus 'latent_features'. 'latent_features' is a list filled in
        forward-call order: source first, then target.
    """
    # Load data; the *_source keys (domain adaptation) take precedence over the plain keys.
    inputs_source = _flatten_images(batch.get('image_source', batch.get('image')))
    inputs_target = _flatten_images(batch.get('image_target'))
    labels_source = batch.get('label_source', batch.get('label'))
    labels_target = batch.get('label_target')
    dl1_params_source = batch.get('dl1_params_source', batch.get('dl1_params'))
    dl1_params_target = batch.get('dl1_params_target')

    pointing_source = pointing_target = None
    if requires_pointing:
        # Include the altitude - azimuth (alt - az) information into the network.
        if dl1_params_source is not None:
            pointing_source = _stack_pointing(dl1_params_source)
        if dl1_params_target is not None:
            pointing_target = _stack_pointing(dl1_params_target)

    # Work on a copy so the caller's dict is never mutated.
    forward_params = {} if forward_params is None else dict(forward_params)
    # Include gradient weighting if applied (the last target that defines one wins).
    for target in module.experiment.targets.values():
        grad_weight = target.get('grad_weight', None)
        if grad_weight is not None:
            forward_params['grad_weight'] = (
                grad_weight.get_weight(module.trainer) if isinstance(grad_weight, BaseW) else grad_weight
            )

    latent_features = []
    hook = None  # Forward hook that fetches the latent features
    if requires_latent_features:
        def get_latent_features(_layer, _inputs, output):
            # A UNet encoder outputs a tuple whose last element is the latent space.
            latent_features.append(output[-1] if isinstance(output, (tuple, list)) else output)

        if 'main_task' in module.experiment.net_parameters_dic['parameters']:
            # When DANN is used, the main task network is wrapped in 'main_task_model'.
            feature_parent = module.net.main_task_model
        else:
            # When DANN is not used (for example, a single UNet encoder).
            feature_parent = module.net
        if hasattr(feature_parent, 'feature'):
            hook = feature_parent.feature.register_forward_hook(get_latent_features)

    outputs_target = None
    try:
        outputs_source = module.net(inputs_source, **dict(forward_params, pointing=pointing_source))
        if inputs_target is not None:
            outputs_target = module.net(inputs_target, **dict(forward_params, pointing=pointing_target))
    finally:
        if hook is not None:
            # Always remove the hook, even if the forward pass raised, to avoid accumulation.
            hook.remove()

    return {
        'inputs_source': inputs_source,
        'inputs_target': inputs_target,
        'labels_source': labels_source,
        'labels_target': labels_target,
        'dl1_params_source': dl1_params_source,
        'dl1_params_target': dl1_params_target,
        'outputs_source': outputs_source,
        'outputs_target': outputs_target,
        'latent_features': latent_features,
    }