Tests in ML project

1 Описание класса DataModule

DataModule представляет собой класс, который создает загрузчики данных для обучения, валидации и тестирования модели.

1.1 Параметры

  • config: Экземпляр класса TrainingConfig, который содержит конфигурационные параметры для обучения модели, такие как размер изображения, размер пакета для обучения, количество рабочих потоков и т. д.

Пример данных

Code
dataset_name = draw(st.just('MNIST'))
train_batch_size = draw(st.integers(min_value=1, max_value=16))
test_batch_size = draw(st.integers(min_value=1, max_value=16))
val_batch_size = draw(st.integers(min_value=1, max_value=16))
num_workers = draw(st.integers(min_value=1, max_value=8))

1.2 Методы

1.2.1 prepare_data()

Загружает данные для обучения, валидации и тестирования. Параметры для загрузки данных (например, путь к данным) берутся из конфигурации.

1.2.2 setup()

Разделяет данные на обучающий, валидационный и тестовый наборы. Здесь также применяются преобразования данных, такие как изменение размера, горизонтальное отражение и нормализация.

1.2.3 train_dataloader(), val_dataloader(), test_dataloader()

Создают загрузчики данных для обучения, валидации и тестирования соответственно.

1.3 Свойства

  • Обучающий набор данных: Данные, используемые для обучения модели.
  • Валидационный набор данных: Данные, используемые для оценки модели в процессе обучения для настройки гиперпараметров и оценки производительности.
  • Тестовый набор данных: Независимый набор данных, используемый для окончательной оценки производительности модели.
  • Загрузчики данных: Подготавливают данные для эффективного использования в процессе обучения, валидации и тестирования модели.
  • Преобразования данных: Применяются к данным для улучшения обучения модели и обеспечения ее обобщающей способности.

2 Описание класса UNet2DLightning

Класс UNet2DLightning представляет собой реализацию многоканальной UNet модели в PyTorch Lightning для задачи обработки изображений.

2.1 Параметры

  • config: Экземпляр класса TrainingConfig, который содержит конфигурационные параметры для обучения модели, такие как размер изображения, размер пакета для обучения, количество рабочих потоков и т. д.

Пример данных

Code
dataset_name = draw(st.just('MNIST'))
train_batch_size = draw(st.integers(min_value=1, max_value=16))
test_batch_size = draw(st.integers(min_value=1, max_value=16))
val_batch_size = draw(st.integers(min_value=1, max_value=16))
num_workers = draw(st.integers(min_value=1, max_value=8))

2.2 Методы

2.2.1 __init__()

Инициализирует объект класса. Создает экземпляры необходимых компонентов модели, такие как слои, функции потерь и расписывает гиперпараметры.

2.2.2 forward()

Прямой проход модели. Принимает на вход изображение x, временные шаги timesteps и метки класса class_labels. Если pred=True, возвращает сэмплы изображений с добавленным шумом, иначе возвращает предсказания модели.

2.2.3 on_train_start(), on_training_epoch_end(), validation_step(), on_validation_epoch_end(), test_step()

Методы, используемые для обучения, валидации и тестирования модели. Реализуют соответствующие этапы обучения и вычисления функции потерь.

2.2.4 predict_step()

Метод для предсказания модели. Принимает на вход метки класса class_labels и возвращает сгенерированные изображения.

2.2.5 configure_optimizers()

Конфигурирует оптимизатор и планировщик шага обучения для модели.

3 Тесты

3.1 Тесты на инициализацию, подготовку данных и настройку

3.1.1 Тест на инициализацию DataModule с конфигурацией

Code
@given(valid_training_config())
def test_data_module_init(config):
    data_module = DataModule(config)
    assert data_module.config == config

Этот тест проверяет, что при инициализации объекта DataModule с заданной конфигурацией объект конфигурации сохраняется внутри модуля.

3.1.2 Тест на подготовку данных

Code
@given(valid_training_config())
def test_prepare_data(config):
    data_module = DataModule(config)
    data_module.prepare_data()
    assert data_module.dataset is not None

Данный тест проверяет, что после вызова метода prepare_data() объект dataset в DataModule не является пустым.

3.1.3 Тест на подготовку данных с недопустимым именем датасета

Code
@given(valid_training_config())
def test_prepare_data_with_invalid_dataset(config):
    data_module = DataModule(config)
    config.dataset_name = "invalid_dataset_name"
    data_module = DataModule(config)
    with pytest.raises(ImportError):
        data_module.prepare_data()

Этот тест проверяет, что при передаче недопустимого имени датасета в метод prepare_data(), вызывается исключение ImportError.

3.1.4 Тест на настройку DataModule

Code
@given(valid_training_config())
def test_setup(config):
    data_module = DataModule(config)
    data_module.prepare_data()
    data_module.setup()
    assert data_module.train_dataset is not None
    assert data_module.test_dataset is not None
    assert data_module.val_dataset is not None
    assert isinstance(
        data_module.train_dataloader(), torch.utils.data.dataloader.DataLoader
    )
    assert isinstance(
        data_module.val_dataloader(), torch.utils.data.dataloader.DataLoader
    )
    assert isinstance(
        data_module.test_dataloader(), torch.utils.data.dataloader.DataLoader
    )

Этот тест проверяет, что после подготовки данных методом setup(), обучающий, тестовый и валидационный наборы данных не являются пустыми, и что загрузчики данных создаются правильно.

3.2 Тесты на обучение, валидацию, тестирование и предсказание модели

3.2.1 Тест на обучение модели

@given(model_config())
def test_model_training(config):
    config.num_timesteps = 1
    model = UNet2DLightning(config=config)
    image = torch.randn(
        (
            config.pred_batch_size,
            1,
            config.image_size,
            config.image_size,
        ),
        device=model.device,
    )
    rand_label = torch.tensor(0)
    batch = [image, rand_label]
    train_loss = model.training_step(batch=batch, batch_idx=0)
    assert train_loss.item() >= 0

Этот тест проверяет, что функция потерь на этапе обучения неотрицательна.

3.2.2 Тест на валидацию модели

@given(model_config())
def test_model_validation(config):
    config.num_timesteps = 1
    model = UNet2DLightning(config=config)
    image = torch.randn(
        (
            config.pred_batch_size,
            1,
            config.image_size,
            config.image_size,
        ),
        device=model.device,
    )
    rand_label = torch.tensor(0)
    batch = [image, rand_label]
    val_loss = model.validation_step(batch=batch, batch_idx=0)
    assert val_loss.item() >= 0

Этот тест проверяет, что функция потерь на этапе валидации неотрицательна.

3.2.3 Тест на тестирование модели

@given(model_config())
def test_model_testing(config):
    config.num_timesteps = 1
    model = UNet2DLightning(config=config)
    image = torch.randn(
        (
            config.pred_batch_size,
            1,
            config.image_size,
            config.image_size,
        ),
        device=model.device,
    )
    rand_label = torch.tensor(0)
    batch = [image, rand_label]
    test_loss = model.test_step(batch=batch, batch_idx=0)
    assert test_loss.item() >= 0

Этот тест проверяет, что функция потерь на этапе тестирования неотрицательна.

3.2.4 Тест на предсказание модели

@given(model_config())
def test_model_prediction(config):
    config.num_timesteps = 1
    model = UNet2DLightning(config=config)
    rand_label = torch.tensor(0)
    with torch.no_grad():
        pred_image = model.predict_step(batch=rand_label, batch_idx=0)
    assert pred_image.shape == (
        config.pred_batch_size,
        config.image_size,
        config.image_size,
        1,
    )

Этот тест проверяет, что размер предсказанного изображения соответствует ожидаемому размеру.

3.3 Исправление ошибок при тестировании

При проверке тестов была обнаружена ошибка в тесте test_prepare_data_with_invalid_dataset, который проверяет обработку исключения ImportError при передаче недопустимого имени датасета в метод prepare_data().

Для исправления этой ошибки был использован следующий код:

Code
try:
    importlib.import_module(
        f'torchvision.datasets.{self.config.dataset_name.lower()}'
    )
except ImportError:
    raise ImportError(
        f'Dataset {self.config.dataset_name} is not available in torchvision.datasets'
    )

Этот код позволяет проверить, существует ли модуль датасета с указанным именем, и если нет, то выбрасывает исключение ImportError с информацией о недоступности датасета.

После применения этого исправления, тест test_prepare_data_with_invalid_dataset корректно отработал и проверил обработку недопустимого имени датасета.