Shortcuts

Source code for mmselfsup.datasets.data_sources.imagenet

# Copyright (c) OpenMMLab. All rights reserved.
import os
import os.path as osp

import numpy as np

from ..builder import DATASOURCES
from .base import BaseDataSource


def has_file_allowed_extension(filename, extensions):
    """Checks if a file is an allowed extension.

    The comparison is case-insensitive on both sides: ``filename`` and
    every entry of ``extensions`` are lower-cased before matching, so
    ``'A.JPG'`` matches ``'.jpg'`` and ``'a.jpg'`` matches ``'.JPG'``.

    Args:
        filename (string): path to a file
        extensions (Iterable[str]): allowed suffixes, e.g.
            ``('.jpg', '.png')``. An empty iterable matches nothing.

    Returns:
        bool: True if the filename ends with one of ``extensions``
    """
    filename_lower = filename.lower()
    # str.endswith accepts a tuple of suffixes and tests them all in C.
    # The extensions are lower-cased too; otherwise an upper-case entry
    # could never match the lower-cased filename.
    allowed = tuple(ext.lower() for ext in extensions)
    return filename_lower.endswith(allowed)


def find_folders(root):
    """Assign an integer class index to every sub-directory of ``root``.

    Only entries that are directories are kept (plain files are ignored).
    Their names are sorted and numbered from 0 in that order, so the
    mapping is deterministic for a given directory listing.

    Args:
        root (string): root directory of folders

    Returns:
        folder_to_idx (dict): the map from folder name to class idx
    """
    subdirs = sorted(
        name for name in os.listdir(root)
        if osp.isdir(osp.join(root, name)))
    return {name: idx for idx, name in enumerate(subdirs)}


def get_samples(root, folder_to_idx, extensions):
    """Make dataset by walking all images under a root.

    Each class folder ``root/<folder_name>`` is walked recursively. Every
    file accepted by ``has_file_allowed_extension`` is recorded with a path
    relative to ``root``. A file directly inside the class folder gives
    ``'<folder_name>/<file>'``; a nested file keeps its sub-directory, e.g.
    ``'<folder_name>/sub/<file>'``. Folders and files are visited in sorted
    order, so the output order is deterministic.

    Args:
        root (string): root directory of folders; ``~`` is expanded
        folder_to_idx (dict): the map from class name to class idx
        extensions (tuple): allowed extensions

    Returns:
        samples (list): a list of tuple where each element is (image, label)
    """
    samples = []
    root = osp.expanduser(root)
    for folder_name in sorted(folder_to_idx):
        label = folder_to_idx[folder_name]
        class_dir = osp.join(root, folder_name)
        for dirpath, _, fns in sorted(os.walk(class_dir)):
            # os.walk descends into sub-directories, so the path must keep
            # the sub-directory component; otherwise nested files would be
            # recorded at locations that do not exist under ``root``.
            sub_dir = osp.relpath(dirpath, class_dir)
            if sub_dir == os.curdir:
                prefix = folder_name
            else:
                prefix = osp.join(folder_name, sub_dir)
            for fn in sorted(fns):
                if has_file_allowed_extension(fn, extensions):
                    samples.append((osp.join(prefix, fn), label))
    return samples


@DATASOURCES.register_module()
class ImageNet(BaseDataSource):
    """`ImageNet <http://www.image-net.org>`_ Dataset.

    This implementation is modified from
    https://github.com/pytorch/vision/blob/master/torchvision/datasets/imagenet.py
    """  # noqa: E501

    # File suffixes accepted when scanning class folders (used only when
    # no annotation file is given).
    IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif')

    def load_annotations(self):
        """Build one info dict per sample.

        Samples come from one of two sources:

        * ``self.ann_file is None``: every sub-folder of ``self.data_prefix``
          is treated as a class and scanned for images with
          ``IMG_EXTENSIONS``. The folder -> class index mapping is stored
          on ``self.folder_to_idx``.
        * ``self.ann_file`` is a str: the file is read as text with one
          ``<filename> <label>`` pair per line. Each line is split on its
          last space, so filenames may themselves contain spaces. Blank
          lines are ignored.

        The raw samples are kept on ``self.samples``.

        Returns:
            list[dict]: one dict per sample with keys ``'img_prefix'``
            (``self.data_prefix``), ``'img_info'`` (``{'filename': ...}``),
            ``'gt_label'`` (0-d ``np.int64`` array) and ``'idx'`` (position
            in ``self.samples``).

        Raises:
            RuntimeError: if scanning the folders finds no image files.
            TypeError: if ``ann_file`` is neither None nor a str.
        """
        if self.ann_file is None:
            folder_to_idx = find_folders(self.data_prefix)
            samples = get_samples(
                self.data_prefix,
                folder_to_idx,
                extensions=self.IMG_EXTENSIONS)
            if len(samples) == 0:
                raise RuntimeError('Found 0 files in subfolders of: '
                                   f'{self.data_prefix}. '
                                   'Supported extensions are: '
                                   f'{",".join(self.IMG_EXTENSIONS)}')
            self.folder_to_idx = folder_to_idx
        elif isinstance(self.ann_file, str):
            with open(self.ann_file) as f:
                # Skip blank lines (e.g. a trailing newline at EOF): they
                # would split into a single empty field and break the
                # (filename, gt_label) unpacking below.
                samples = [
                    line.rsplit(' ', 1) for line in map(str.strip, f) if line
                ]
        else:
            raise TypeError('ann_file must be a str or None')
        self.samples = samples

        data_infos = []
        for i, (filename, gt_label) in enumerate(self.samples):
            info = {'img_prefix': self.data_prefix}
            info['img_info'] = {'filename': filename}
            # gt_label is an int when scanned from folders and a str when
            # read from the annotation file; np.array converts both.
            info['gt_label'] = np.array(gt_label, dtype=np.int64)
            info['idx'] = i
            data_infos.append(info)
        return data_infos
Read the Docs v: latest
Versions
latest
Downloads
pdf
html
epub
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.