Shortcuts

Source code for mmselfsup.datasets.deepcluster

# Copyright (c) OpenMMLab. All rights reserved.
import torch

from .base import BaseDataset
from .builder import DATASETS
from .utils import to_numpy


@DATASETS.register_module()
class DeepClusterDataset(BaseDataset):
    """Dataset used by DC and ODC.

    Holds one clustering label per sample. Every label starts at ``-1`` and
    is replaced wholesale through :meth:`assign_labels`; :meth:`__getitem__`
    returns the current label next to the processed image.

    Args:
        data_source (dict): Data source defined in
            `mmselfsup.datasets.data_sources`.
        pipeline (list[dict]): A list of dict, where each element represents
            an operation defined in `mmselfsup.datasets.pipelines`.
        prefetch (bool, optional): Whether to prefetch data. Defaults to
            False.
    """

    def __init__(self, data_source, pipeline, prefetch=False):
        super().__init__(data_source, pipeline, prefetch)
        # One placeholder label (-1) per sample until labels are assigned.
        self.clustering_labels = [-1] * len(self.data_source)

    def __getitem__(self, idx):
        """Return the processed image, its clustering label and its index.

        Args:
            idx: Index of the sample to fetch.

        Returns:
            dict: Keys ``img``, ``pseudo_label`` and ``idx``.
        """
        sample = self.pipeline(self.data_source.get_img(idx))
        label = self.clustering_labels[idx]
        if self.prefetch:
            # With prefetch enabled the pipeline output is passed through
            # ``to_numpy`` and wrapped as a torch tensor.
            sample = torch.from_numpy(to_numpy(sample))
        return {'img': sample, 'pseudo_label': label, 'idx': idx}

    def assign_labels(self, labels):
        """Replace the stored clustering labels.

        Args:
            labels: New labels; must have exactly as many entries as the
                labels currently stored.

        Raises:
            AssertionError: If the lengths differ.
        """
        num_current = len(self.clustering_labels)
        num_given = len(labels)
        assert num_current == num_given, (
            f'Inconsistent length of assigned labels, '
            f'{num_current} vs {num_given}')
        # Keep a full slice of the input (a shallow copy when it is a list).
        self.clustering_labels = labels[:]

    def evaluate(self, results, logger=None):
        """Evaluation is not provided; returns ``NotImplemented``."""
        return NotImplemented
Read the Docs v: latest
Versions
latest
Downloads
pdf
html
epub
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.