
Source code for mmselfsup.datasets.relative_loc

# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torchvision.transforms.functional as TF
from mmcv.utils import build_from_cfg
from torchvision.transforms import Compose, RandomCrop

from .base import BaseDataset
from .builder import DATASETS, PIPELINES
from .utils import to_numpy

def image_to_patches(img):
    """Crop split_per_side x split_per_side patches from input image.

    The image is divided into a regular 3 x 3 grid. From every grid cell a
    slightly smaller patch is taken at a random position, so that
    neighbouring patches do not share an exact common border.

    Args:
        img (PIL Image): input image.

    Returns:
        list[PIL Image]: A list of cropped patches in row-major order
            (top-left first, bottom-right last); index 4 is the center
            patch.

    Raises:
        AssertionError: If the image is so small that a grid cell is not
            larger than the jitter margin along either axis.
    """
    split_per_side = 3  # split of patches per image side
    patch_jitter = 21  # jitter of each patch from each grid
    # PIL reports size as (width, height), whereas TF.crop and RandomCrop
    # expect (height, width) ordering, so unpack accordingly.
    w, h = img.size
    # Floor division: any remainder pixels on the right/bottom are dropped.
    h_grid = h // split_per_side
    w_grid = w // split_per_side
    h_patch = h_grid - patch_jitter
    w_patch = w_grid - patch_jitter
    assert h_patch > 0 and w_patch > 0, (
        f'image of size (w={w}, h={h}) is too small for '
        f'{split_per_side}x{split_per_side} patches with a jitter of '
        f'{patch_jitter}')
    random_crop = RandomCrop((h_patch, w_patch))  # loop-invariant transform
    patches = []
    for i in range(split_per_side):
        for j in range(split_per_side):
            # Cut out the (i, j) grid cell: top, left, height, width.
            p = TF.crop(img, i * h_grid, j * w_grid, h_grid, w_grid)
            # Randomly place the final patch inside its grid cell.
            p = random_crop(p)
            patches.append(p)
    return patches

@DATASETS.register_module()
class RelativeLocDataset(BaseDataset):
    """Dataset for relative patch location.

    The dataset crops image into several patches and concatenates every
    surrounding patch with center one. Finally it also outputs corresponding
    labels `0, 1, 2, 3, 4, 5, 6, 7`.

    Args:
        data_source (dict): Data source defined in
            `mmselfsup.datasets.data_sources`.
        pipeline (list[dict]): A list of dict, where each element represents
            an operation defined in `mmselfsup.datasets.pipelines`.
        format_pipeline (list[dict]): A list of dict, it converts input format
            from PIL.Image to Tensor. The operation is defined in
            `mmselfsup.datasets.pipelines`.
        prefetch (bool, optional): Whether to prefetch data. Defaults to False.
    """

    def __init__(self, data_source, pipeline, format_pipeline, prefetch=False):
        super().__init__(data_source, pipeline, prefetch)
        # Build the per-patch formatting transforms from their configs and
        # chain them into a single callable.
        format_pipeline = [
            build_from_cfg(p, PIPELINES) for p in format_pipeline
        ]
        self.format_pipeline = Compose(format_pipeline)

    def __getitem__(self, idx):
        """Return the eight (surrounding, center) patch pairs of one image.

        Args:
            idx (int): Index of the sample in the data source.

        Returns:
            dict: ``img`` holds the stacked patch pairs and ``patch_label``
                the labels ``0..7``, one per pair.
        """
        img = self.data_source.get_img(idx)
        img = self.pipeline(img)
        patches = image_to_patches(img)
        if self.prefetch:
            # In prefetch mode only a raw array -> tensor conversion is done
            # here; the format pipeline is skipped.
            patches = [torch.from_numpy(to_numpy(p)) for p in patches]
        else:
            patches = [self.format_pipeline(p) for p in patches]
        # create a list of patch pairs: every surrounding patch is
        # concatenated with the center patch (index 4) along dim 0
        center = patches[4]
        perms = [
            torch.cat((patches[i], center), dim=0) for i in range(9)
            if i != 4
        ]
        # create corresponding labels for patch pairs
        patch_labels = torch.LongTensor([0, 1, 2, 3, 4, 5, 6, 7])
        return dict(
            img=torch.stack(perms),
            patch_label=patch_labels)  # 8(2C)HW, 8

    def evaluate(self, results, logger=None):
        """Evaluation is not provided for this dataset.

        NOTE(review): this returns the ``NotImplemented`` singleton rather
        than raising ``NotImplementedError``; kept as-is for compatibility.
        """
        return NotImplemented
Read the Docs v: latest
On Read the Docs
Project Home

Free document hosting provided by Read the Docs.