import collections
import collections.abc
import logging
import logging.config
from functools import partial

import torch
import torch.multiprocessing as mp
import tqdm
from pytorch_lightning import LightningDataModule
from torch.utils.data import ConcatDataset, DataLoader, Dataset, Subset
from torchvision import transforms

from gammalearn import utils
from gammalearn.datasets import VisionDomainAdaptationDataset, GlearnDomainAdaptationDataset
from gammalearn.logging import LOGGING_CONFIG


def create_dataset_worker(file,
                          dataset_class,
                          train,
                          **kwargs):
    """Load one data file into a dataset.

    Returns None if the file is not healthy or if the dataset is empty after filtering.
    """
    torch.set_num_threads(1)
    # Reload the logging config (lost by the spawn start method)
    logging.config.dictConfig(LOGGING_CONFIG)
    if utils.is_datafile_healthy(file):
        dataset = dataset_class(file, train=train, **kwargs)
        if kwargs.get('image_filter') is not None:
            dataset.filter_image(kwargs.get('image_filter'))
        if kwargs.get('event_filter') is not None:
            dataset.filter_event(kwargs.get('event_filter'))
        if len(dataset) > 0:
            return dataset


def create_datasets(datafiles_list, experiment, train=True, **kwargs):
    """
    Create datasets from a list of data files; data are loaded in memory.

    Parameters
    ----------
    datafiles_list (list): the files to load data from
    experiment (Experiment): the experiment
    train (bool): whether the datasets are built for training

    Returns
    -------
    List of Datasets
    """
    logger = logging.getLogger('gammalearn')
    assert datafiles_list, 'The data file list is empty !'
    logger.info('length of data file list : {}'.format(len(datafiles_list)))
    # We use the spawn context because fork can cause deadlocks in sub-processes
    # of multi-threaded programs (especially with logging)
    ctx = mp.get_context('spawn')
    if experiment.preprocessing_workers > 0:
        num_workers = experiment.preprocessing_workers
    else:
        num_workers = 1
    with ctx.Pool(processes=num_workers) as pool:
        datasets = list(tqdm.tqdm(pool.imap(partial(create_dataset_worker,
                                                    dataset_class=experiment.dataset_class,
                                                    train=train,
                                                    **kwargs),
                                            datafiles_list),
                                  total=len(datafiles_list),
                                  desc='Load data files'))
    # Discard files that were unhealthy or left empty by the filters (the worker returns None for them)
    datasets = [dset for dset in datasets if dset is not None]
    return datasets
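
# Minimal usage sketch (illustrative, not part of the module's API): assuming an `experiment`
# object exposing `dataset_class`, `preprocessing_workers` and `dataset_parameters`, and a
# hypothetical data directory, the train/validation pool could be built as:
#
#     file_list = sorted(utils.find_datafiles(['/path/to/data'], -1))
#     datasets = create_datasets(file_list, experiment, train=True,
#                                **experiment.dataset_parameters)
#     full_set = ConcatDataset(datasets)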


def split_dataset(datasets, ratio):
    """Split a dataset into a train set and a validation set

    Parameters
    ----------
    datasets (Dataset): the dataset to split (e.g. a ConcatDataset)
    ratio (float): the fraction of data kept for validation

    Returns
    -------
    train set, validation set
    """
    # Creation of the train and validation subsets
    assert isinstance(datasets, Dataset)
    assert 1 > ratio > 0, 'Validating ratio must be greater than 0 and smaller than 1.'
    train_max_index = int(len(datasets) * (1 - ratio))
    shuffled_indices = torch.randperm(len(datasets)).numpy()
    train_datasets = Subset(datasets, shuffled_indices[:train_max_index])
    val_datasets = Subset(datasets, shuffled_indices[train_max_index:])
    return train_datasets, val_datasets
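
# Example sketch: with ratio=0.2 and a dataset of 100 events, split_dataset returns two
# random Subsets of the same dataset, 80 events for training and 20 for validation:
#
#     train_set, val_set = split_dataset(full_set, ratio=0.2)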


def balance_datasets(source_datasets, target_datasets):
    """Randomly subsample the source and target datasets down to the size of the smaller one."""
    max_index = min(len(source_datasets), len(target_datasets))
    shuffled_indices_source = torch.randperm(len(source_datasets)).numpy()
    shuffled_indices_target = torch.randperm(len(target_datasets)).numpy()
    source_datasets = Subset(source_datasets, shuffled_indices_source[:max_index])
    target_datasets = Subset(target_datasets, shuffled_indices_target[:max_index])
    return source_datasets, target_datasets
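
# Example sketch: with a source dataset of 1000 events and a target dataset of 600 events,
# balance_datasets returns two Subsets of 600 randomly selected events each, so that both
# domains contribute the same number of samples during domain-adaptation training.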


class BaseDataModule(LightningDataModule):
    """
    Create datasets and dataloaders.

    Parameters
    ----------
    experiment (Experiment): the experiment
    """
    def __init__(self, experiment):
        super().__init__()
        self.experiment = experiment
        self.logger = logging.getLogger(__name__)
        self.train_set = None
        self.val_set = None
        self.test_sets = None  # list of test datasets
        self.collate_fn = torch.utils.data.default_collate

    def setup(self, stage=None):
        """
        When the train and test data modules differ, two separate setup functions are defined
        in order to avoid loading the data twice.
        """
        self.setup_train()
        self.setup_test()

    def setup_train(self):
        """
        This function is used if train is set to True in the experiment settings file.
        """
        self.logger.info('Start creating datasets')
        self.logger.info('look for data files')
        # Creation of the global train/val dataset
        datasets = self.get_dataset(train=True)
        assert datasets, 'Dataset is empty !'
        # Creation of the train and validation subsets
        train_datasets, val_datasets = split_dataset(datasets, self.experiment.validating_ratio)
        self.train_set = train_datasets
        self.logger.info('training set length : {}'.format(len(self.train_set)))
        self.val_set = val_datasets
        try:
            assert len(self.val_set) > 0
        except AssertionError as e:
            self.logger.exception('Validating set must contain data')
            raise e
        self.logger.info('validating set length : {}'.format(len(self.val_set)))

    def setup_test(self):
        """
        This function is used if test is set to True in the experiment settings file.
        If no test data module is provided, the test is run on the validation set.
        If neither a test data module nor a validation set is provided, an error is raised.
        """
        if self.experiment.data_module_test is not None:
            # Look for test-specific dataset parameters
            if self.experiment.test_dataset_parameters is not None:
                self.experiment.dataset_parameters.update(self.experiment.test_dataset_parameters)
            # Creation of the test datasets
            self.test_sets = self.get_dataset(train=False)
        else:
            # No test data module provided: fall back on the validation set
            assert self.val_set is not None, 'Test is required but no test file is provided and val_set is None'
            self.test_sets = [self.val_set]
        self.logger.info('test set length : {}'.format(torch.tensor([len(t) for t in self.test_sets]).sum()))

    def train_dataloader(self):
        training_loader = DataLoader(self.train_set,
                                     batch_size=self.experiment.batch_size,
                                     shuffle=True,
                                     drop_last=True,
                                     num_workers=self.experiment.dataloader_workers,
                                     pin_memory=self.experiment.pin_memory,
                                     collate_fn=self.collate_fn)
        self.logger.info('training loader length : {} batches'.format(len(training_loader)))
        return training_loader

    def val_dataloader(self):
        validating_loader = DataLoader(self.val_set,
                                       batch_size=self.experiment.batch_size,
                                       shuffle=False,
                                       num_workers=self.experiment.dataloader_workers,
                                       drop_last=True,
                                       pin_memory=self.experiment.pin_memory,
                                       collate_fn=self.collate_fn)
        self.logger.info('validating loader length : {} batches'.format(len(validating_loader)))
        return validating_loader

    def test_dataloaders(self):
        test_loaders = [DataLoader(test_set, batch_size=self.experiment.test_batch_size, shuffle=False,
                                   drop_last=False, num_workers=self.experiment.dataloader_workers)
                        for test_set in self.test_sets]
        self.logger.info('number of test loaders : {}'.format(len(test_loaders)))
        self.logger.info('test loaders length : {} batches'.format(torch.tensor([len(t) for t in test_loaders]).sum()))
        return test_loaders

    def get_dataset(self, train):
        """
        DataModule-specific method to be overridden to load the dataset.
        """
        raise NotImplementedError

    def get_collate_fn(self):
        numpy_type_map = {
            'float64': torch.DoubleTensor,
            'float32': torch.FloatTensor,
            'float16': torch.HalfTensor,
            'int64': torch.LongTensor,
            'int32': torch.IntTensor,
            'int16': torch.ShortTensor,
            'int8': torch.CharTensor,
            'uint8': torch.ByteTensor,
        }

        def collate_fn(batch):
            """
            Puts each data field into a tensor with outer dimension batch size. From:
            https://github.com/hughperkins/pytorch-pytorch/blob/c902f1cf980eef27541f3660c685f7b59490e744/torch/utils/data/dataloader.py#L91
            """
            error_msg = "batch must contain tensors, numbers, dicts or lists; found {}"
            elem_type = type(batch[0])
            if torch.is_tensor(batch[0]):
                out = None
                return torch.stack(batch, 0, out=out)
            elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
                    and elem_type.__name__ != 'string_':
                elem = batch[0]
                if elem_type.__name__ == 'ndarray':
                    return torch.stack([torch.from_numpy(b) for b in batch], 0)
                if elem.shape == ():  # scalars
                    py_type = float if elem.dtype.name.startswith('float') else int
                    return numpy_type_map[elem.dtype.name](list(map(py_type, batch)))
            elif isinstance(batch[0], int):
                return torch.LongTensor(batch)
            elif isinstance(batch[0], float):
                return torch.DoubleTensor(batch)
            elif isinstance(batch[0], (str, bytes)):
                return batch
            elif isinstance(batch[0], collections.abc.Mapping):
                # If MC and real data are mixed in the target, keep only the common keys
                common_keys = set(batch[0].keys())
                for d in batch[1:]:
                    common_keys.intersection_update(d.keys())
                return {key: collate_fn([d[key] for d in batch]) for key in common_keys}
            elif isinstance(batch[0], collections.abc.Sequence):
                transposed = zip(*batch)
                return [collate_fn(samples) for samples in transposed]
            raise TypeError(error_msg.format(type(batch[0])))

        return collate_fn
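
    # Example sketch of the Mapping branch above (dictionary key names are illustrative):
    # collating a batch that mixes MC samples {'image': ..., 'energy': ...} with real
    # samples {'image': ...} keeps only the common key 'image', with the corresponding
    # values stacked along the batch dimension.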


class GLearnDataModule(BaseDataModule):
    """Data module building datasets with the experiment's gammalearn dataset class."""
    def __init__(self, experiment):
        super().__init__(experiment)

    def get_dataset(self, train):
        max_files = self.experiment.train_files_max_number if train else self.experiment.test_files_max_number
        data_module = utils.fetch_data_module_settings(self.experiment, train=train, domain=None)
        dataset = self.get_glearn_dataset_from_path(data_module, train, domain=None, max_files=max_files)
        return dataset

    def get_glearn_dataset_from_path(self, data_module, train, domain=None, max_files=None):
        if max_files is None:
            max_files = -1
        elif isinstance(max_files, dict) and domain is not None:
            max_files = max_files.get(domain, -1)
        file_list = utils.find_datafiles(data_module['paths'], max_files)
        file_list = list(file_list)
        file_list.sort()
        datasets = create_datasets(file_list, self.experiment, train=train, domain=domain, **data_module,
                                   **self.experiment.dataset_parameters)
        if train:
            # Check the dataset list heterogeneity (e.g. simu and real data in target)
            if not (all([dset.simu for dset in datasets]) or not any([dset.simu for dset in datasets])):
                self.collate_fn = self.get_collate_fn()
            return ConcatDataset(datasets)
        else:
            if self.experiment.merge_test_datasets:
                # Group the MC test datasets by particle type and gather real data in a separate list
                particle_dict = {}
                for dset in datasets:
                    if dset.simu:
                        particle_type = dset.dl1_params['mc_type'][0]
                        particle_dict.setdefault(particle_type, []).append(dset)
                    else:
                        particle_dict.setdefault('real_list', []).append(dset)
                return [ConcatDataset(dset) for dset in particle_dict.values()]
            else:
                return datasets
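
    # Note on merge_test_datasets (illustrative): with this option enabled, all MC test files
    # sharing the same mc_type (e.g. all gamma files, all proton files) are concatenated into
    # one test set each, and all real-data files are concatenated into an additional test set,
    # so test_dataloaders() yields one loader per group instead of one loader per file.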


class GLearnDomainAdaptationDataModule(GLearnDataModule):
    """GLearnDataModule that builds paired source and target datasets for domain adaptation."""
    def __init__(self, experiment):
        super().__init__(experiment)
        self.dataset_balancing = experiment.dataset_parameters.get('dataset_balancing', False)

    def get_dataset(self, train):
        max_files = self.experiment.train_files_max_number if train else self.experiment.test_files_max_number
        data_module_source = utils.fetch_data_module_settings(self.experiment, train=train, domain='source')
        data_module_target = utils.fetch_data_module_settings(self.experiment, train=train, domain='target')
        source_datasets = self.get_glearn_dataset_from_path(data_module_source,
                                                            train,
                                                            domain='source',
                                                            max_files=max_files)
        target_datasets = self.get_glearn_dataset_from_path(data_module_target,
                                                            train,
                                                            domain='target',
                                                            max_files=max_files)
        if self.dataset_balancing:
            source_datasets, target_datasets = balance_datasets(source_datasets, target_datasets)
        return GlearnDomainAdaptationDataset(source_datasets, target_datasets)


class VisionDataModule(BaseDataModule):
    """
    Create datasets and dataloaders.

    Parameters
    ----------
    experiment (Experiment): the experiment
    """
    def __init__(self, experiment):
        super().__init__(experiment)

    def get_dataset(self, train):
        max_files = self.experiment.train_files_max_number if train else self.experiment.test_files_max_number
        data_module = utils.fetch_data_module_settings(self.experiment, train=train, domain=None)
        dataset = self.get_dataset_from_path(data_module, train=train, domain=None, max_files=max_files)
        return dataset

    def get_dataset_from_path(self, data_module, train, domain=None, max_files=None):
        datasets = self.experiment.dataset_class(
            paths=data_module['paths'],
            dataset_parameters=self.experiment.dataset_parameters,
            transform=data_module['transform'],
            target_transform=data_module['target_transform'],
            train=train,
            domain=domain,
            max_files=max_files,
            num_workers=self.experiment.preprocessing_workers,
        )
        # At test time, BaseDataModule expects a list of test sets
        return [datasets] if not train else datasets


class VisionDomainAdaptationDataModule(VisionDataModule):
    """
    Create paired source and target datasets and dataloaders for domain adaptation.

    Parameters
    ----------
    experiment (Experiment): the experiment
    """
    def __init__(self, experiment):
        super().__init__(experiment)
        self.dataset_balancing = experiment.dataset_parameters.get('dataset_balancing', False)

    def get_dataset(self, train):
        max_files = self.experiment.train_files_max_number if train else self.experiment.test_files_max_number
        data_module_source = utils.fetch_data_module_settings(self.experiment, train=train, domain='source')
        data_module_target = utils.fetch_data_module_settings(self.experiment, train=train, domain='target')
        dataset_src = self.get_dataset_from_path(data_module_source, train=train, domain='source', max_files=max_files)
        dataset_trg = self.get_dataset_from_path(data_module_target, train=train, domain='target', max_files=max_files)
        if self.dataset_balancing:
            dataset_src, dataset_trg = balance_datasets(dataset_src, dataset_trg)
        return VisionDomainAdaptationDataset(dataset_src, dataset_trg)
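

# Usage sketch (illustrative, not taken from the experiment runner): a data module is
# typically driven roughly as follows once an `experiment` object has been built from
# the settings file:
#
#     data_module = GLearnDataModule(experiment)
#     data_module.setup()
#     train_loader = data_module.train_dataloader()
#     val_loader = data_module.val_dataloader()
#     test_loaders = data_module.test_dataloaders()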