# Source code for mmselfsup.datasets.data_sources.imagenet
# Copyright (c) OpenMMLab. All rights reserved.
import os
import os.path as osp
import numpy as np
from ..builder import DATASOURCES
from .base import BaseDataSource
def has_file_allowed_extension(filename, extensions):
    """Tell whether a file name ends with one of the given extensions.

    Only the file name is lower-cased before the comparison; the
    extensions are compared exactly as they are passed in.

    Args:
        filename (string): path to a file
        extensions (iterable of string): suffixes to test against,
            e.g. ``('.jpg', '.png')``

    Returns:
        bool: True if the lower-cased filename ends with any of the
            extensions, False otherwise
    """
    lowered = filename.lower()
    for suffix in extensions:
        if lowered.endswith(suffix):
            return True
    return False
def find_folders(root):
    """Assign a class index to every directory directly under a root.

    Directory names are sorted before numbering, so the mapping does not
    depend on the order in which the file system lists them. Entries of
    ``root`` that are not directories are ignored.

    Args:
        root (string): root directory of folders

    Returns:
        folder_to_idx (dict): the map from folder name to class idx
    """
    class_names = sorted(
        entry for entry in os.listdir(root)
        if osp.isdir(osp.join(root, entry)))
    return {name: idx for idx, name in enumerate(class_names)}
def get_samples(root, folder_to_idx, extensions):
    """Make dataset by walking all images under a root.

    Every class folder is walked recursively. Returned image paths are
    relative to ``root`` and keep any intermediate sub-directory, so a file
    stored at ``<root>/<class>/<sub>/<name>`` is reported as
    ``<class>/<sub>/<name>``. Files directly inside a class folder are
    reported as ``<class>/<name>``.

    Args:
        root (string): root directory of folders
        folder_to_idx (dict): the map from class name to class idx
        extensions (tuple): allowed extensions

    Returns:
        samples (list): a list of tuple where each element is (image, label)
    """
    samples = []
    root = osp.expanduser(root)
    for folder_name in sorted(folder_to_idx):
        label = folder_to_idx[folder_name]
        class_dir = osp.join(root, folder_name)
        # Sorting the walk by directory path (and the file names below)
        # keeps the sample order deterministic across file systems.
        for dirpath, _, fns in sorted(os.walk(class_dir)):
            # os.walk descends into sub-directories, so the path has to be
            # built from the directory actually being visited; joining
            # with ``folder_name`` alone would drop nested components.
            rel_dir = osp.relpath(dirpath, root)
            for fn in sorted(fns):
                if has_file_allowed_extension(fn, extensions):
                    samples.append((osp.join(rel_dir, fn), label))
    return samples
@DATASOURCES.register_module()
class ImageNet(BaseDataSource):
    """`ImageNet <http://www.image-net.org>`_ Dataset.

    This implementation is modified from
    https://github.com/pytorch/vision/blob/master/torchvision/datasets/imagenet.py
    """  # noqa: E501

    # Lower-case suffixes accepted when samples are discovered by scanning
    # the folders under ``data_prefix``.
    IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif')

    def load_annotations(self):
        """Collect the samples and build one info dict per image.

        When ``self.ann_file`` is None, the sub-folders of
        ``self.data_prefix`` are treated as classes and scanned for images;
        the folder-to-index map is stored in ``self.folder_to_idx``. When
        ``self.ann_file`` is a str, it is opened as a text file and each
        non-blank line is split at its last space into a filename and a
        label.

        The collected samples are stored in ``self.samples``.

        Returns:
            list[dict]: one dict per sample with the keys ``img_prefix``,
                ``img_info`` (holding ``filename``), ``gt_label`` (an
                ``np.int64`` array) and ``idx`` (position in the list).

        Raises:
            RuntimeError: if folder scanning finds no supported image.
            TypeError: if ``ann_file`` is neither a str nor None.
        """
        if self.ann_file is None:
            folder_to_idx = find_folders(self.data_prefix)
            samples = get_samples(
                self.data_prefix,
                folder_to_idx,
                extensions=self.IMG_EXTENSIONS)
            if len(samples) == 0:
                raise RuntimeError('Found 0 files in subfolders of: '
                                   f'{self.data_prefix}. '
                                   'Supported extensions are: '
                                   f'{",".join(self.IMG_EXTENSIONS)}')
            self.folder_to_idx = folder_to_idx
        elif isinstance(self.ann_file, str):
            with open(self.ann_file) as f:
                # Blank lines (e.g. a trailing empty line) carry no sample;
                # keeping them would later fail with an opaque unpacking
                # error when the (filename, label) pair is read.
                stripped_lines = (line.strip() for line in f)
                # Split on the last space only, so filenames that contain
                # spaces stay intact.
                samples = [
                    line.rsplit(' ', 1) for line in stripped_lines if line
                ]
        else:
            raise TypeError('ann_file must be a str or None')
        self.samples = samples

        data_infos = []
        for i, (filename, gt_label) in enumerate(self.samples):
            info = {'img_prefix': self.data_prefix}
            info['img_info'] = {'filename': filename}
            # Labels read from an annotation file are still strings here;
            # the int64 conversion happens in np.array.
            info['gt_label'] = np.array(gt_label, dtype=np.int64)
            info['idx'] = i
            data_infos.append(info)
        return data_infos