"""Unit tests for gammalearn.data.nets (gammalearn.tests.test_nets)."""

import unittest
import gammalearn.data.nets as nets
import torch
import torchvision.models as models
from ctapipe.instrument import CameraGeometry


class TestNets(unittest.TestCase):
    """Tests for the network wrappers in :mod:`gammalearn.data.nets`.

    Covers two things:
    - ``TorchConvNet`` wrapping of torchvision backbones (ResNet-18,
      MobileNet v2/v3, EfficientNet-B7): feature count, output spatial size
      and the derived ``n_pixels`` attribute.
    - ``LSTMaskedAutoEncoder``: patchify/unpatchify round trip, random
      masking ratio, encoder/decoder shapes and the full forward pass.
    """

    def setUp(self) -> None:
        """Build a shared input batch and one config dict per backbone."""
        # 10 images, 2 channels, 55x55 pixels — shared by all TorchConvNet tests.
        self.input = torch.rand(10, 2, 55, 55)
        self.resnet18 = {
            'model': models.resnet18,
            'parameters': {
                'output_size': (7, 7),
                'num_channels': 2,
            },
        }
        self.mobilenet_v2 = {
            'model': models.mobilenet_v2,
            'parameters': {
                'output_size': (9, 9),
                'num_channels': 2,
            },
        }
        self.mobilenet_v3 = {
            'model': models.mobilenet_v3_large,
            'parameters': {
                'output_size': (9, 9),
                'num_channels': 2,
            },
        }
        self.efficient_net = {
            'model': models.efficientnet_b7,
            'parameters': {
                'output_size': (9, 9),
                'num_channels': 2,
            },
        }
        self.mae = {
            'model': nets.LSTMaskedAutoEncoder,
            'parameters': {
                'backbone': {
                    'parameters': {
                        'num_channels': 2,
                        'blocks': 24,
                        'embed_dim': 1024,
                        'mlp_ratio': 4,
                        'heads': 16,
                        'add_token_list': [],
                        'mask_ratio': 0.75,
                        'camera_geometry': CameraGeometry.from_name('LSTCam'),
                        'add_pointing': True,
                    }
                },
                'decoder': {
                    'parameters': {
                        'blocks': 8,
                        'embed_dim': 512,
                        'mlp_ratio': 4,
                        'heads': 16,
                    }
                },
                'norm_pixel_loss': False,
            },
        }

    def test_resnet18(self):
        """ResNet-18 backbone: 512 features, expected spatial output size."""
        net = nets.TorchConvNet(self.resnet18)
        output = net(self.input)
        assert net.num_features == 512
        assert output.shape[1] == 512
        # torch.Size subclasses tuple, so direct comparison with the
        # configured (H, W) tuple is valid.
        assert output.shape[2:] == self.resnet18['parameters']['output_size']
        assert net.n_pixels == torch.prod(torch.tensor(output.shape[2:]))

    def test_mobilenet_v2(self):
        """MobileNet v2 backbone: 1280 features, expected spatial output size."""
        net = nets.TorchConvNet(self.mobilenet_v2)
        output = net(self.input)
        assert net.num_features == 1280
        assert output.shape[1] == 1280
        assert output.shape[2:] == self.mobilenet_v2['parameters']['output_size']
        assert net.n_pixels == torch.prod(torch.tensor(output.shape[2:]))

    def test_mobilenet_v3(self):
        """MobileNet v3-large backbone: 960 features, expected spatial output size."""
        net = nets.TorchConvNet(self.mobilenet_v3)
        output = net(self.input)
        assert net.num_features == 960
        assert output.shape[1] == 960
        assert output.shape[2:] == self.mobilenet_v3['parameters']['output_size']
        assert net.n_pixels == torch.prod(torch.tensor(output.shape[2:]))

    def test_efficient_net(self):
        """EfficientNet-B7 backbone: 2560 features, expected spatial output size."""
        net = nets.TorchConvNet(self.efficient_net)
        output = net(self.input)
        assert net.num_features == 2560
        assert output.shape[1] == 2560
        assert output.shape[2:] == self.efficient_net['parameters']['output_size']
        assert net.n_pixels == torch.prod(torch.tensor(output.shape[2:]))

    def test_mae(self):
        """Exercise the LST masked auto-encoder pipeline step by step."""
        net = self.mae['model'](self.mae['parameters'])
        # Two events, 2 channels, 1855 pixels (LSTCam pixel count —
        # matches the camera geometry configured in setUp).
        images = torch.rand(2, 2, 1855)
        pointing = torch.rand(2, 2)

        # Patchify must be lossless: unpatchify(patchify(x)) == x.
        patches = net.patchify(images)
        assert patches.ndim == 3
        assert torch.all(net.unpatchify(patches) == images)

        # Random masking should remove ~75% of the patch tokens.
        ratio = torch.tensor(0.75)
        tokens = net.patch_projection(images)
        tokens = tokens.transpose(1, 2)
        masked_tokens, mask, ids_restore = net.apply_random_mask(tokens, ratio)
        assert masked_tokens.ndim == 3
        masked_patches_ratio = torch.count_nonzero(mask, dim=1) / mask.shape[1]
        assert torch.allclose(masked_patches_ratio, ratio, atol=0.01)

        # Encoder -> decoder embedding -> token un-masking restores the full
        # token sequence (264 patches + 1 extra token, decoder dim 512).
        enc_tokens, mask, ids_restore = net.forward_encoder(images, pointing)
        enc_tokens_dec = net.decoder_embedding(enc_tokens)
        unmasked_tokens = net._unmask_tokens(enc_tokens_dec, ids_restore)
        assert unmasked_tokens.shape == (2, 265, 512)

        # Decoder predictions live in patch space and unpatchify back to
        # the original image shape.
        predictions = net.forward_decoder(enc_tokens, ids_restore)
        image_patches = net.patchify(images)
        assert predictions.shape == image_patches.shape
        unpatched_pred = net.unpatchify(predictions)
        assert unpatched_pred.shape == images.shape

        # Smoke check only: the full forward pass must run without raising.
        # NOTE(review): the returned loss is never asserted on — presumably a
        # scalar tensor; consider asserting finiteness once confirmed.
        loss = net(images, pointing)