| from torchvision.datasets import ImageFolder | |
| import torch | |
| import os | |
| import collections | |
| torch.manual_seed(0) | |
| ImageItem = collections.namedtuple('ImageItem', ('image_name', 'tag')) | |
| class RobustnessDataset(ImageFolder): | |
| def __init__(self, dataset_path): | |
| self._dataset_path = dataset_path | |
| self._tag_list = [tag for tag in os.listdir(self._dataset_path)] | |
| self._all_images = [] | |
| for tag in self._tag_list: | |
| base_dir = os.path.join(self._dataset_path, tag) | |
| for i, file in enumerate(os.listdir(base_dir)): | |
| self._all_images.append(ImageItem(file, tag)) | |
| def __getitem__(self, item): | |
| image_item = self._all_images[item] | |
| image_path = os.path.join(self._dataset_path, image_item.tag, image_item.image_name) | |
| return image_path | |
| def __len__(self): | |
| return len(self._all_images) |