Shortcuts

Source code for mmselfsup.datasets.rotation_pred

# Copyright (c) OpenMMLab. All rights reserved.
import torch

from .base import BaseDataset
from .builder import DATASETS
from .utils import to_numpy


def rotate(img):
    """Rotate input image with 0, 90, 180, and 270 degrees.

    Rotations are counter-clockwise in the plane of the last two dimensions
    (height, width), using the same convention as ``torch.rot90`` with
    ``k = 0, 1, 2, 3``. Leading dimensions (channels, or an optional batch
    axis) are left untouched, so ``(C, H, W)`` images as well as ``(H, W)``
    and ``(N, C, H, W)`` tensors are supported.

    Args:
        img (Tensor): input image of shape (..., H, W) with at least two
            dimensions, typically (C, H, W).

    Returns:
        list[Tensor]: A list of four rotated images, ordered 0, 90, 180 and
        270 degrees. The first element is ``img`` itself (not a copy). The
        90- and 270-degree rotations have shape (..., W, H), so all four
        share the input shape only when H == W.
    """
    # Rotate in the (H, W) plane regardless of how many leading dims exist.
    spatial_dims = [-2, -1]
    # k = 0 is the identity; return the input object itself rather than the
    # copy torch.rot90 would make.
    return [img] + [torch.rot90(img, k, dims=spatial_dims) for k in (1, 2, 3)]


@DATASETS.register_module()
class RotationPredDataset(BaseDataset):
    """Dataset for rotation prediction.

    The dataset rotates the image with 0, 90, 180, and 270 degrees and
    outputs labels `0, 1, 2, 3` correspondingly.

    Args:
        data_source (dict): Data source defined in
            `mmselfsup.datasets.data_sources`.
        pipeline (list[dict]): A list of dict, where each element represents
            an operation defined in `mmselfsup.datasets.pipelines`.
        prefetch (bool, optional): Whether to prefetch data. Defaults to False.
    """

    def __init__(self, data_source, pipeline, prefetch=False):
        super().__init__(data_source, pipeline, prefetch)

    def __getitem__(self, idx):
        """Load one image and return all four of its rotations.

        Args:
            idx (int): Index of the image in the data source.

        Returns:
            dict: ``img`` stacks the rotated copies (0, 90, 180, 270 degrees)
            along a new leading dimension; ``rot_label`` is the matching
            int64 tensor ``[0, 1, 2, 3]``.
        """
        img = self.data_source.get_img(idx)
        img = self.pipeline(img)
        if self.prefetch:
            # In prefetch mode the pipeline output is converted to a tensor
            # here (through a numpy array) rather than by the pipeline.
            img = torch.from_numpy(to_numpy(img))
        rotated = rotate(img)
        # torch.stack requires equal shapes; the 90/270-degree rotations swap
        # height and width, so the processed image must be square (H == W).
        img = torch.stack(rotated, dim=0)
        # Label k identifies the k-th rotation, i.e. k * 90 degrees; derive it
        # from the number of rotations so the two always stay in sync.
        rotation_labels = torch.arange(len(rotated), dtype=torch.long)
        return dict(img=img, rot_label=rotation_labels)

    def evaluate(self, results, logger=None):
        """Evaluation is not supported for rotation prediction.

        Args:
            results: Unused.
            logger: Unused.

        Raises:
            NotImplementedError: Always; this dataset defines no metric.
        """
        # Raise instead of returning the ``NotImplemented`` sentinel, which is
        # meant only for binary operator methods and would otherwise be passed
        # on silently to callers expecting a result.
        raise NotImplementedError(
            'RotationPredDataset does not support evaluation.')
Read the Docs v: latest
Versions
latest
Downloads
pdf
html
epub
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.