DataModule представляет собой класс, который создает загрузчики данных для обучения, валидации и тестирования модели.
1.1 Параметры
config: Экземпляр класса TrainingConfig, который содержит конфигурационные параметры для обучения модели, такие как размер изображения, размер пакета для обучения, количество рабочих потоков и т. д.
Загружает данные для обучения, валидации и тестирования. Параметры для загрузки данных (например, путь к данным) берутся из конфигурации.
1.2.2setup()
Разделяет данные на обучающий, валидационный и тестовый наборы. Здесь также применяются преобразования данных, такие как изменение размера, горизонтальное отражение и нормализация.
Создают загрузчики данных для обучения, валидации и тестирования соответственно.
1.3 Свойства
Обучающий набор данных: Данные, используемые для обучения модели.
Валидационный набор данных: Данные, используемые для оценки модели в процессе обучения для настройки гиперпараметров и оценки производительности.
Тестовый набор данных: Независимый набор данных, используемый для окончательной оценки производительности модели.
Загрузчики данных: Подготавливают данные для эффективного использования в процессе обучения, валидации и тестирования модели.
Преобразования данных: Применяются к данным для улучшения обучения модели и обеспечения ее обобщающей способности.
2 Описание класса UNet2DLightning
Класс UNet2DLightning представляет собой реализацию многоканальной UNet модели в PyTorch Lightning для задачи обработки изображений.
2.1 Параметры
config: Экземпляр класса TrainingConfig, который содержит конфигурационные параметры для обучения модели, такие как размер изображения, размер пакета для обучения, количество рабочих потоков и т. д.
Инициализирует объект класса. Создает экземпляры необходимых компонентов модели, такие как слои, функции потерь и расписывает гиперпараметры.
2.2.2forward()
Прямой проход модели. Принимает на вход изображение x, временные шаги timesteps и метки класса class_labels. Если pred=True, возвращает сэмплы изображений с добавленным шумом, иначе возвращает предсказания модели.
Этот тест проверяет, что после подготовки данных методом setup(), обучающий, тестовый и валидационный наборы данных не являются пустыми, и что загрузчики данных создаются правильно.
3.2 Тесты на обучение, валидацию, тестирование и предсказание модели
Этот тест проверяет, что размер предсказанного изображения соответствует ожидаемому размеру.
3.3 Исправление ошибок при тестировании
При проверке тестов была обнаружена ошибка в тесте test_prepare_data_with_invalid_dataset, который проверяет обработку исключения ImportError при передаче недопустимого имени датасета в метод prepare_data().
Для исправления этой ошибки был использован следующий код:
Code
try: importlib.import_module(f'torchvision.datasets.{self.config.dataset_name.lower()}' )exceptImportError:raiseImportError(f'Dataset {self.config.dataset_name} is not available in torchvision.datasets' )
Этот код позволяет проверить, существует ли модуль датасета с указанным именем, и если нет, то выбрасывает исключение ImportError с информацией о недоступности датасета.
После применения этого исправления, тест test_prepare_data_with_invalid_dataset корректно отработал и проверил обработку недопустимого имени датасета.