import os
import torch
import torch.nn as nn
from torch.utils.data import Subset
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.loggers import TensorBoardLogger
from torchmetrics import Accuracy, ConfusionMatrix
from torchmetrics.functional.pairwise import pairwise_cosine_similarity
from torchvision import transforms
import torchvision.utils as t_utils
import numpy as np
import pandas as pd
import tables
from ctapipe.instrument import CameraGeometry
from ctapipe.visualization import CameraDisplay
from ctapipe.io import HDF5TableWriter
from lstchain.io import write_dl2_dataframe
from lstchain.reco.utils import add_delta_t_key
from lstchain.io.io import dl1_params_lstcam_key
from PIL import Image
from indexedconv.engine import IndexedConv
from indexedconv.utils import create_index_matrix, img2mat, pool_index_matrix, build_hexagonal_position
from astropy.table import Table
from gammalearn.constants import SOURCE, TARGET
import gammalearn.utils as utils
import gammalearn.criterions as criterions
import gammalearn.datasets as dsets
import gammalearn.version as gl_version
import gammalearn.constants as csts
import matplotlib.pyplot as plt
from pathlib import Path
class LogLambda(Callback):
    """
    Callback to send the loss and gradient weighting from BaseW to the logger.
    """
    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
log_lambda_loss_dict, log_lambda_grad_dict = {}, {}
targets = pl_module.experiment.loss_balancing.targets.copy()
        for task in targets:
            if isinstance(targets[task].get('loss_weight'), utils.BaseW):
                log_lambda_loss_dict[task] = targets[task]['loss_weight'].get_weight(trainer)
            if isinstance(targets[task].get('grad_weight'), utils.BaseW):
                log_lambda_grad_dict[task] = targets[task]['grad_weight'].get_weight(trainer)
if log_lambda_loss_dict:
pl_module.log('Lambda loss', log_lambda_loss_dict, on_epoch=False, on_step=True)
if log_lambda_grad_dict:
pl_module.log('Lambda grad', log_lambda_grad_dict, on_epoch=False, on_step=True)
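# Usage note (a minimal sketch, not part of this module): the callbacks in this
# file are attached to the PyTorch Lightning Trainer built by the experiment
# runner. The logger path and the `module` object below are hypothetical.
#
#     import pytorch_lightning as pl
#
#     trainer = pl.Trainer(
#         max_epochs=10,
#         logger=TensorBoardLogger('runs', name='my_experiment'),  # hypothetical save_dir/name
#         callbacks=[LogLambda(), LogLossWeighting(), LogGradientNorm()],
#     )
#     trainer.fit(module)  # `module` is the LightningModule wrapping the experiment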
class LogUncertaintyTracker(Callback):
    """
    Callback to send the loss log variances and precisions of the uncertainty
    estimation method to the logger.
    """
    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
if isinstance(pl_module.experiment.loss_balancing, criterions.UncertaintyWeighting):
logvar_dict = pl_module.experiment.loss_balancing.log_vars
log_logvar_dict = {}
log_precision_dict = {}
targets = pl_module.experiment.loss_balancing.targets.copy()
for i, task in enumerate(targets.keys()):
log_logvar_dict[task] = logvar_dict[i].detach().cpu()
log_precision_dict[task] = torch.exp(-logvar_dict[i].detach().cpu())
pl_module.log('Log_var', log_logvar_dict, on_epoch=False, on_step=True)
pl_module.log('Precision', log_precision_dict, on_epoch=False, on_step=True)
class LogGradNormTracker(Callback):
    """
    Callback to send the GradNorm parameters to the logger.
    """
    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
if isinstance(pl_module.experiment.loss_balancing, criterions.GradNorm):
g_dict = pl_module.experiment.loss_balancing.tracker_g
r_dict = pl_module.experiment.loss_balancing.tracker_r
k_dict = pl_module.experiment.loss_balancing.tracker_k
l_dict = pl_module.experiment.loss_balancing.tracker_l
l0_dict = pl_module.experiment.loss_balancing.tracker_l0
lgrad = pl_module.experiment.loss_balancing.tracker_lgrad
            log_g_dict, log_r_dict, log_k_dict, log_l_dict, log_l0_dict = {}, {}, {}, {}, {}
for task in r_dict.keys():
log_r_dict[task] = r_dict[task].detach().cpu()
log_g_dict[task] = g_dict[task].detach().cpu()
log_k_dict[task] = k_dict[task].detach().cpu()
log_l_dict[task] = l_dict[task].detach().cpu()
log_l0_dict[task] = l0_dict[task].detach().cpu()
log_lgrad = lgrad.detach().cpu()
pl_module.log('Gradient_norms', log_g_dict, on_epoch=False, on_step=True)
pl_module.log('Inverse_training_rate', log_r_dict, on_epoch=False, on_step=True)
pl_module.log('Constant', log_k_dict, on_epoch=False, on_step=True)
pl_module.log('Loss_ratio', log_l_dict, on_epoch=False, on_step=True)
pl_module.log('L0', log_l0_dict, on_epoch=False, on_step=True)
pl_module.log('Lgrad', log_lgrad, on_epoch=False, on_step=True)
class LogLossWeighting(Callback):
    """
    Callback to send the loss weight coefficients to the logger.
    """
    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
if isinstance(pl_module.experiment.loss_balancing, criterions.MultiLossBalancing):
weights_dict = pl_module.experiment.loss_balancing.weights_dict
log_weights_dict = {}
for task in weights_dict.keys():
log_weights_dict[task] = weights_dict[task].detach().cpu()
pl_module.log('Loss_weight_per_task', log_weights_dict, on_epoch=False, on_step=True)
class LogGradientNormPerTask(Callback):
    """
    Callback to send the per-task gradient norms to the logger.
    """
    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
if isinstance(pl_module.experiment.loss_balancing, criterions.MultiLossBalancing):
if pl_module.experiment.loss_balancing.requires_gradients:
gradients_dict = pl_module.experiment.loss_balancing.gradients_dict
log_gradients_dict = {}
for task in gradients_dict.keys():
log_gradients_dict[task] = gradients_dict[task].norm(p=2).detach().cpu()
pl_module.log('Gradient_norm_per_task', log_gradients_dict, on_epoch=False, on_step=True)
class LogGradientCosineSimilarity(Callback):
    """
    Callback to send the pairwise cosine similarity between task gradients to the logger.
    """
    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
if isinstance(pl_module.experiment.loss_balancing, criterions.MultiLossBalancing):
if pl_module.experiment.loss_balancing.requires_gradients:
gradients = pl_module.experiment.loss_balancing.gradients
similarity = pairwise_cosine_similarity(gradients, gradients)
log_similarity_dict = {}
targets = pl_module.experiment.targets.copy()
for i, task_i in enumerate(targets):
for j, task_j in enumerate(targets):
if i < j: # Only upper triangular matrix as similarity is symmetric
log_similarity_dict[task_i+'_'+task_j] = similarity[i, j]
                pl_module.log('Gradient_cosine_similarity', log_similarity_dict, on_epoch=False, on_step=True)
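# For reference, a minimal sketch of what the similarity logging above computes,
# assuming `gradients` stacks one flattened gradient vector per task (the toy
# shapes below are assumptions, not taken from MultiLossBalancing):
#
#     grads = torch.stack([torch.randn(100) for _ in range(3)])  # 3 tasks, toy gradients
#     sim = pairwise_cosine_similarity(grads, grads)             # (3, 3) symmetric matrix
#     # sim[i, j] < 0 indicates conflicting gradients between tasks i and j,
#     # which is why only the upper triangle (i < j) needs to be logged.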
class LogModelWeightNorm(Callback):
    """
    Callback to send the sum of the squared weights of the network to the logger.
    """
    def on_train_epoch_end(self, trainer, pl_module):
weights = 0
for name, param in pl_module.net.named_parameters():
if 'weight' in name:
weights += torch.sum(param.data ** 2)
pl_module.log('weights', weights, on_epoch=True, on_step=False)
class LogModelParameters(Callback):
    """
    Callback to send the network parameters to the logger.
    """
    def on_train_epoch_end(self, trainer, pl_module):
        if isinstance(pl_module.logger, TensorBoardLogger):
for name, param in pl_module.net.named_parameters():
pl_module.logger.experiment.add_histogram(name, param.detach().cpu(),
bins='tensorflow',
global_step=pl_module.current_epoch)
else:
# In that case, trainer_logger.watch is implemented in the experiment runner
pass
def make_activation_sender(pl_module, name):
    """
    Create a forward hook that sends a layer's activations to TensorBoard.
    Parameters
    ----------
    pl_module (LightningModule): the Lightning module whose logger is used
    name (string): name of the layer whose activations are logged
    Returns
    -------
    The hook function to register
    """
    def send(m, input, output):
        """
        Forward hook sending the activations of a module (e.g. nn.ReLU) to TensorBoard.
        """
pl_module.logger.experiment.add_histogram(name, output.detach().cpu(),
bins='tensorflow', global_step=pl_module.current_epoch)
return send
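# A self-contained sketch of the forward-hook pattern used above and in
# LogReLUActivations below (toy module, no logger involved):
#
#     relu = nn.ReLU()
#     hook = relu.register_forward_hook(lambda m, inp, out: print(out.shape))
#     relu(torch.randn(2, 4))  # the hook prints torch.Size([2, 4])
#     hook.remove()            # hooks are removed after each batch, as below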
class LogReLUActivations(Callback):
    """
    Callback to send the ReLU activations to the logger.
    """
    def setup(self, trainer, pl_module, stage):
self.hooks = []
    def on_train_epoch_start(self, trainer, pl_module):
for name, child in pl_module.net.named_children():
if isinstance(child, nn.ReLU):
sender = make_activation_sender(pl_module, name)
self.hooks.append(child.register_forward_hook(sender))
    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
for hook in self.hooks:
hook.remove()
def make_linear_gradient_logger(pl_module, name):
    """
    Create a backward hook sending a linear layer's input gradients to TensorBoard.
    """
def log_grad(m, grad_input, grad_output):
pl_module.logger.experiment.add_histogram(name + 'grad_in', grad_input[0].data.cpu(),
bins='tensorflow', global_step=pl_module.current_epoch)
return log_grad
class LogLinearGradient(Callback):
    """
    Callback to send the gradients of the linear layers to the logger.
    """
    def setup(self, trainer, pl_module, stage):
self.hooks = []
    def on_train_epoch_start(self, trainer, pl_module):
for name, child in pl_module.net.named_modules():
if isinstance(child, nn.Linear):
grad_logger = make_linear_gradient_logger(pl_module, name)
self.hooks.append(child.register_full_backward_hook(grad_logger))
    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
for hook in self.hooks:
hook.remove()
def make_feature_logger(pl_module, name, index_matrices):
    """
    Create a forward hook rendering feature maps as camera images and sending them to TensorBoard.
    """
def log_features(m, input, output):
if output.dim() == 3:
features = output.detach().cpu().clone()
images_list = []
index_matrix = index_matrices[features.shape[-1]]
pixel_pos = np.array(build_hexagonal_position(index_matrix.squeeze().squeeze()))
pix_area = np.full(features.shape[-1], 6/np.sqrt(3)*0.5**2)
# TODO load meta from datafile
geom = CameraGeometry.from_table(
Table(
{
'pix_id': np.arange(features.shape[-1]),
'pix_x': list(map(lambda x: x[0], pixel_pos)),
'pix_y': list(map(lambda x: x[1], pixel_pos)),
'pix_area': pix_area,
},
meta={
'PIX_TYPE': 'hexagonal',
'PIX_ROT': 0,
'CAM_ROT': 0,
}
)
)
for b, batch in enumerate(features):
for c, channel in enumerate(batch):
label = '{}_b{}_c{}'.format(name, b, c)
ax = plt.axes(label=label)
ax.set_aspect('equal', 'datalim')
disp = CameraDisplay(geom, ax=ax)
disp.image = channel
disp.add_colorbar()
ax.set_title(label)
canvas = plt.get_current_fig_manager().canvas
canvas.draw()
pil_img = Image.frombytes('RGB', canvas.get_width_height(), canvas.tostring_rgb())
images_list.append(transforms.ToTensor()(pil_img))
grid = t_utils.make_grid(images_list)
pl_module.logger.experiment.add_image('Features_{}'.format(name),
grid, pl_module.current_epoch)
return log_features
class LogFeatures(Callback):
    """
    Callback to send the feature maps of the network, rendered as camera images, to the logger.
    """
    def setup(self, trainer, pl_module, stage):
self.hooks = []
self.index_matrices = {}
index_matrix = create_index_matrix(pl_module.camera_parameters['nbRow'],
pl_module.camera_parameters['nbCol'],
pl_module.camera_parameters['injTable'])
n_pixels = int(torch.sum(torch.ge(index_matrix[0, 0], 0)).data)
self.index_matrices[n_pixels] = index_matrix
idx_matx = index_matrix
while n_pixels > 1:
idx_matx = pool_index_matrix(idx_matx, kernel_type=pl_module.camera_parameters['layout'])
n_pixels = int(torch.sum(torch.ge(idx_matx[0, 0], 0)).data)
self.index_matrices[n_pixels] = idx_matx
    def on_train_epoch_start(self, trainer, pl_module):
for name, child in pl_module.net.named_children():
if isinstance(child, nn.ReLU):
feature_logger = make_feature_logger(pl_module, name, self.index_matrices)
self.hooks.append(child.register_forward_hook(feature_logger))
    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
for hook in self.hooks:
hook.remove()
class LogGradientNorm(Callback):
    """
    Callback to send the total gradient norm to the logger.
    """
    def on_train_epoch_end(self, trainer, pl_module):
pl_module.log('Gradient_norm', pl_module.grad_norm, on_epoch=True, on_step=False)
class WriteDL2Files(Callback):
    """
    Callback to write the DL2 result files of the test.
    """
    def on_test_end(self, trainer, pl_module):
# Retrieve test data
merged_outputs = pd.concat([pd.DataFrame(utils.prepare_dict_of_tensors(output))
for output in pl_module.test_data['output']], ignore_index=True)
merged_dl1_params = pd.concat([pd.DataFrame(utils.prepare_dict_of_tensors(dl1))
for dl1 in pl_module.test_data['dl1_params']], ignore_index=True)
dl2_params = utils.post_process_data(merged_outputs, merged_dl1_params, pl_module.experiment.dataset_parameters)
if pl_module.experiment.data_module_test is None or pl_module.experiment.merge_test_datasets:
# Test has been done on the validation set or test dl1 have been merged in datasets by particle type
ratio = pl_module.experiment.validating_ratio if pl_module.experiment.data_module_test is None else 1.0
# Retrieve MC config information
mc_configuration = {}
def fetch_dataset_info(d):
if isinstance(d, torch.utils.data.ConcatDataset):
for d_c in d.datasets:
fetch_dataset_info(d_c)
elif isinstance(d, Subset):
fetch_dataset_info(d.dataset)
elif issubclass(pl_module.experiment.dataset_class, dsets.BaseLSTDataset):
particle_type = d.dl1_params['mc_type'][0]
if particle_type not in mc_configuration:
mc_configuration[particle_type] = {'mc_energies': [], 'run_configs': []}
if d.simu:
mc_energies = d.trig_energies
np.random.shuffle(mc_energies)
mc_energies = mc_energies[:int(len(mc_energies) * ratio)]
d.run_config['mcheader']['num_showers'] *= ratio
mc_configuration[particle_type]['mc_energies'].extend(mc_energies)
mc_configuration[particle_type]['run_configs'].append(d.run_config)
else:
pl_module.console_logger.error('Unknown dataset type, MC configuration cannot be retrieved')
raise ValueError
for dataloader in trainer.test_dataloaders:
fetch_dataset_info(dataloader.dataset)
# Write one file per particle type
for mc_type in mc_configuration:
particle_mask = merged_dl1_params['mc_type'] == mc_type
gb_file_path = pl_module.experiment.main_directory + '/' + pl_module.experiment.experiment_name + '/' + \
pl_module.experiment.experiment_name + '_' + str(mc_type) + '.h5'
if os.path.exists(gb_file_path):
os.remove(gb_file_path)
writer = HDF5TableWriter(gb_file_path)
dl1_version = []
ctapipe_version = []
runlist = []
for config in mc_configuration[mc_type]['run_configs']:
try:
dl1_version.append(config['metadata']['LSTCHAIN_VERSION'])
except Exception:
pl_module.console_logger.warning('There is no LSTCHAIN_VERSION in run config')
try:
ctapipe_version.append(config['metadata']['CTAPIPE_VERSION'])
except Exception:
pl_module.console_logger.warning('There is no CTAPIPE_VERSION in run config')
try:
runlist.extend(config['metadata']['SOURCE_FILENAMES'])
except Exception:
pl_module.console_logger.warning('There is no SOURCE_FILENAMES in run config')
try:
writer.write('simulation/run_config', config['mcheader'])
except Exception:
pl_module.console_logger.warning('Issue when writing run config')
writer.close()
try:
assert len(set(dl1_version)) == 1
except AssertionError:
                    warning_msg = 'There should be exactly one dl1 data handler version in the dataset but there are {}'\
                        .format(set(dl1_version))
pl_module.console_logger.warning(warning_msg)
dl1_version = 'Unknown'
else:
dl1_version = dl1_version[0]
try:
assert len(set(ctapipe_version)) == 1
except AssertionError:
                    warning_msg = 'There should be exactly one ctapipe version in the dataset but there are {}'\
                        .format(set(ctapipe_version))
pl_module.console_logger.warning(warning_msg)
ctapipe_version = 'Unknown'
else:
ctapipe_version = ctapipe_version[0]
try:
assert runlist
except AssertionError:
pl_module.console_logger.warning('Run list is empty')
metadata = {
'LSTCHAIN_VERSION': dl1_version,
'CTAPIPE_VERSION': ctapipe_version,
'mc_type': mc_type,
'GAMMALEARN_VERSION': gl_version.__version__,
}
                with tables.open_file(gb_file_path, mode='a') as file:
                    for k, item in metadata.items():
                        if k in file.root._v_attrs and isinstance(file.root._v_attrs[k], list):
                            # list.extend returns None, so extend in place and write the list back
                            attribute = file.root._v_attrs[k]
                            attribute.extend(item)
                            file.root._v_attrs[k] = attribute
                        else:
                            file.root._v_attrs[k] = item
                    if runlist and '/simulation' in file:
                        file.create_array('/simulation', 'runlist', obj=runlist)
pd.DataFrame(
{
'mc_trig_energies': np.array(mc_configuration[mc_type]['mc_energies'])
}
).to_hdf(gb_file_path,
key='triggered_events')
if mc_type == csts.REAL_DATA_ID:
# Post dl2 ops for real data
dl2_params = add_delta_t_key(dl2_params)
utils.write_dataframe(merged_dl1_params[particle_mask], outfile=gb_file_path,
table_path=dl1_params_lstcam_key)
write_dl2_dataframe(dl2_params[particle_mask], gb_file_path)
else:
# Prepare output
if pl_module.experiment.dl2_path is not None:
output_dir = pl_module.experiment.dl2_path
else:
output_dir = pl_module.experiment.main_directory + '/' + pl_module.experiment.experiment_name + '/dl2/'
if not os.path.exists(output_dir):
os.makedirs(output_dir)
dataset = trainer.test_dataloaders[0].dataset
output_name = os.path.basename(dataset.hdf5_file_path)
output_name = output_name.replace('dl1', 'dl2')
output_path = os.path.join(output_dir, output_name)
if os.path.exists(output_path):
os.remove(output_path)
mc_type = merged_dl1_params['mc_type'][0]
mc_energies = dataset.trig_energies
utils.write_dl2_file(dl2_params, dataset, output_path, mc_type=mc_type, mc_energies=mc_energies)
class WriteAutoEncoderDL1(Callback):
    """
    Callback to write the autoencoder test results and DL1 parameters to data files.
    """
    def on_test_end(self, trainer, pl_module):
# Set output dataframe
output_df = pd.DataFrame()
# Fill the output dataframe with the errors between the AE outputs and the ground truths
        merged_outputs = utils.merge_list_of_dict(pl_module.test_data['output'])  # TODO: output may be a dict
for k, v in merged_outputs.items():
output_df[k] = torch.cat(v).detach().to('cpu').numpy()
# Also fill with the dl1 parameters if they are available
merged_dl1_params = utils.merge_list_of_dict(pl_module.test_data['dl1_params'])
for k, v in merged_dl1_params.items():
if k in ['mc_core_x', 'mc_core_y', 'tel_pos_x', 'tel_pos_y', 'tel_pos_z', 'mc_x_max']:
output_df[k] = 1000 * torch.cat(v).detach().to('cpu').numpy()
else:
output_df[k] = torch.cat(v).detach().to('cpu').numpy()
# Get output path
if pl_module.experiment.data_module_test is None:
# Test has to be done on the validation set: Write one file
output_path = os.path.join(pl_module.experiment.main_directory, pl_module.experiment.experiment_name,
pl_module.experiment.experiment_name + '_ae_validation_results.h5')
else:
# One output file per dl1 file
output_dir = os.path.join(pl_module.experiment.main_directory,
pl_module.experiment.experiment_name,
'ae_test_results')
if not os.path.exists(output_dir):
os.makedirs(output_dir)
dataset = trainer.test_dataloaders[0].dataset
output_name = os.path.basename(dataset.hdf5_file_path)
output_name = output_name.replace('dl1', 'ae_results')
output_path = os.path.join(output_dir, output_name)
if os.path.exists(output_path):
os.remove(output_path)
# Write output dataframe
output_df.to_hdf(output_path, key='data')
class WriteData(Callback):
    """
    Base callback resolving the output path of the result files.
    """
    def __init__(self) -> None:
        super().__init__()
        self.output_dir_default = None
        self.output_file_default = None
    def get_output_path(self, experiment) -> Path:
# Prepare output folder
if experiment.output_dir is not None:
output_dir = Path(experiment.output_dir)
else:
output_dir = Path(experiment.main_directory, experiment.experiment_name, self.output_dir_default)
output_dir.mkdir(exist_ok=True)
# Prepare output file
if experiment.output_file is not None:
output_file = Path(experiment.output_file)
else:
output_file = Path(self.output_file_default)
# Get output path
output_path = output_dir.joinpath(output_file)
if output_path.exists():
output_path.unlink()
return output_path
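# Path resolution precedence implemented above: experiment.output_dir and
# experiment.output_file override the subclass defaults; otherwise the file goes
# to <main_directory>/<experiment_name>/<output_dir_default>/<output_file_default>,
# removing any previous result. A minimal hypothetical subclass (names and
# defaults are illustrative only):
#
#     class WriteLoss(WriteData):
#         def __init__(self):
#             super().__init__()
#             self.output_dir_default = 'loss_results'   # hypothetical defaults
#             self.output_file_default = 'loss.csv'
#
#         def on_test_end(self, trainer, pl_module):
#             pd.DataFrame({'loss': [0.0]}).to_csv(
#                 self.get_output_path(pl_module.experiment), index=False)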
class WriteAutoEncoder(WriteData):
    """
    Callback to write the autoencoder mean squared error to a CSV file.
    """
def __init__(self) -> None:
super().__init__()
self.output_dir_default = 'ae_results'
self.output_file_default = 'ae.csv'
    def on_test_end(self, trainer, pl_module):
# Compute error between the AE outputs and the ground truths
error = torch.empty((0, ))
for output, label in zip(pl_module.test_data['output'], pl_module.test_data['label']):
prediction = output['autoencoder']
target = label['autoencoder']
error = torch.hstack((error, torch.pow(prediction - target, 2).mean().cpu()))
# Compute the mean of the error
output_df = pd.DataFrame({'MSE': error.mean().numpy()}, index=[0])
# Get output path
output_path = self.get_output_path(pl_module.experiment)
# Write output dataframe
output_df.to_csv(output_path, index=False)
class WriteAccuracy(WriteData):
    """
    Callback to write the per-class test accuracy to a CSV file.
    """
def __init__(self) -> None:
super().__init__()
self.output_dir_default = 'accuracy_results'
self.output_file_default = 'accuracy.csv'
    def on_test_end(self, trainer, pl_module):
# Get prediction and ground truth
predictions, targets = torch.empty((0, )), torch.empty((0, ))
for output, label in zip(pl_module.test_data['output'], pl_module.test_data['label']):
predictions = torch.hstack((predictions, torch.argmax(output['class'], dim=1).cpu()))
targets = torch.hstack((targets, label['class'].cpu()))
predictions, targets = predictions.flatten().to(torch.int64), targets.flatten().to(torch.int64)
# Compute accuracy
num_classes = pl_module.experiment.targets['class']['output_shape']
accuracy = Accuracy(num_classes=num_classes, multiclass=True, average=None)
        output_df = pd.DataFrame({'Accuracy': accuracy(predictions, targets).numpy()}, index=np.arange(num_classes))
# Get output path
output_path = self.get_output_path(pl_module.experiment)
# Write output dataframe
output_df.to_csv(output_path, index=False)
class WriteAccuracyDomain(WriteData):
    """
    Callback to write the domain classification accuracy to a CSV file.
    """
def __init__(self) -> None:
super().__init__()
self.output_dir_default = 'accuracy_domain_results'
self.output_file_default = 'accuracy_domain.csv'
    def on_test_end(self, trainer, pl_module):
# Get prediction and ground truth
predictions = torch.empty((0, ))
for output in pl_module.test_data['output']:
predictions = torch.hstack((predictions, torch.argmax(output['domain_class'], dim=1).cpu()))
predictions = predictions.flatten().to(torch.int64)
labels_source = (torch.ones(predictions.shape) * SOURCE).to(torch.int64)
labels_target = (torch.ones(predictions.shape) * TARGET).to(torch.int64)
# Compute accuracy
num_classes = 2
accuracy = Accuracy(num_classes=num_classes)
        output_df = pd.DataFrame(
            {
                'Accuracy source': [accuracy(predictions, labels_source).numpy()],
                'Accuracy target': [accuracy(predictions, labels_target).numpy()],
            }
        )
# Get output path
output_path = self.get_output_path(pl_module.experiment)
# Write output dataframe
output_df.to_csv(output_path, index=False)
class WriteConfusionMatrix(WriteData):
    """
    Callback to write the test confusion matrix to a CSV file.
    """
def __init__(self) -> None:
super().__init__()
self.output_dir_default = 'confusion_matrix_results'
self.output_file_default = 'confusion_matrix.csv'
    def on_test_end(self, trainer, pl_module):
# Get prediction and ground truth
predictions, targets = torch.empty((0, )), torch.empty((0, ))
for output, label in zip(pl_module.test_data['output'], pl_module.test_data['label']):
predictions = torch.hstack((predictions, torch.argmax(output['class'], dim=1).cpu()))
targets = torch.hstack((targets, label['class'].cpu()))
predictions, targets = predictions.flatten().to(torch.int64), targets.flatten().to(torch.int64)
# Compute accuracy
num_classes = pl_module.experiment.targets['class']['output_shape']
cm = ConfusionMatrix(num_classes=num_classes)
        output_df = pd.DataFrame(cm(predictions, targets).numpy(), index=np.arange(num_classes),
                                 columns=np.arange(num_classes))
# Get output path
output_path = self.get_output_path(pl_module.experiment)
# Write output dataframe
output_df.to_csv(output_path, index=False)
class WriteADistance(WriteData):
    """
    Callback to write the proxy A-distance between domains to a CSV file.
    """
def __init__(self) -> None:
super().__init__()
self.output_dir_default = 'a_distance_results'
self.output_file_default = 'a_distance.csv'
    def on_test_end(self, trainer, pl_module):
# Get prediction
predictions, targets = torch.empty((0,)), torch.empty((0,))
for output, label in zip(pl_module.test_data['output'], pl_module.test_data['label']):
predictions = torch.hstack((predictions, torch.argmax(output['domain_class'], dim=1).cpu()))
            targets = torch.hstack((targets, label['domain_class'].cpu()))
predictions, targets = predictions.flatten().to(torch.int64), targets.flatten().to(torch.int64)
# Compute accuracy
accuracy_metric = Accuracy(num_classes=2)
# Compute a-distance
accuracy = accuracy_metric(predictions, targets)
error = 1. - accuracy
a_distance = torch.abs((2. * (1. - 2. * error)).mean()) # distance is 0 when classifier converges to 0.5 accuracy
output_df = pd.DataFrame({'accuracy': [accuracy.numpy()], 'A_distance': [a_distance.numpy()]})
# Get output path
output_path = self.get_output_path(pl_module.experiment)
# Write output dataframe
output_df.to_csv(output_path, index=False)
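# Sanity check of the proxy A-distance above, dA = |2 * (1 - 2 * err)|: a domain
# classifier at chance level (accuracy 0.5, err 0.5) gives dA = 0, i.e. the
# domains are indistinguishable, while a perfect classifier (err 0) gives the
# maximum dA = 2. A quick sketch:
#
#     for acc in (0.5, 0.75, 1.0):
#         err = 1.0 - acc
#         print(acc, abs(2.0 * (1.0 - 2.0 * err)))  # -> 0.0, 1.0, 2.0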