Shortcuts

Source code for mmselfsup.core.hooks.momentum_update_hook

# Copyright (c) OpenMMLab. All rights reserved.
from math import cos, pi

from mmcv.parallel import is_module_wrapper
from mmcv.runner import HOOKS, Hook


@HOOKS.register_module(name=['BYOLHook', 'MomentumUpdateHook'])
class MomentumUpdateHook(Hook):
    """Hook for updating momentum parameter, used by BYOL, MoCoV3, etc.

    This hook includes momentum adjustment following:

    .. math::
        m = m_{end} - (m_{end} - m_0) * (cos(pi * k / K) + 1) / 2

    where :math:`k` is the current step, :math:`K` is the total steps,
    :math:`m_0` is the algorithm's ``base_momentum`` and :math:`m_{end}` is
    ``end_momentum``. The schedule starts at :math:`m_0` (k = 0) and reaches
    :math:`m_{end}` at k = K. With the default ``end_momentum=1`` it reduces
    to :math:`m = 1 - (1 - m_0) * (cos(pi * k / K) + 1) / 2`.

    The algorithm (``runner.model``, or ``runner.model.module`` when the model
    is wrapped by a module wrapper) must expose ``momentum``,
    ``base_momentum`` and a ``momentum_update()`` method.

    Args:
        end_momentum (float): The final momentum coefficient for the target
            network. Defaults to 1.
        update_interval (int, optional): The momentum update interval of the
            weights. Defaults to 1.
    """

    def __init__(self, end_momentum=1., update_interval=1, **kwargs):
        # Extra keyword arguments are accepted and ignored so that configs
        # passing additional keys keep building this hook.
        self.end_momentum = end_momentum
        self.update_interval = update_interval

    @staticmethod
    def _get_algorithm(runner):
        """Return the algorithm held by ``runner``, unwrapping it if needed.

        ``runner.model`` may or may not be a module wrapper; only wrappers
        carry the real algorithm under ``.module``.
        """
        model = runner.model
        if is_module_wrapper(model):
            return model.module
        return model

    def before_train_iter(self, runner):
        """Set the algorithm's ``momentum`` from the cosine schedule.

        Runs every ``update_interval`` iterations (as decided by
        ``self.every_n_iters``). The new value is
        ``end_momentum - (end_momentum - base_momentum) * (cos(pi * k / K) + 1) / 2``
        with ``k = runner.iter`` and ``K = runner.max_iters``.

        Raises:
            AssertionError: If the algorithm lacks ``momentum`` or
                ``base_momentum``.
        """
        # Previously this always dereferenced ``runner.model.module``, which
        # fails for unwrapped models even though ``after_train_iter`` already
        # supports them; unwrap consistently instead.
        algorithm = self._get_algorithm(runner)
        assert hasattr(algorithm, 'momentum'), \
            "The runner must have attribute \"momentum\" in algorithms."
        assert hasattr(algorithm, 'base_momentum'), \
            "The runner must have attribute \"base_momentum\" in algorithms."
        if self.every_n_iters(runner, self.update_interval):
            cur_iter = runner.iter
            max_iter = runner.max_iters
            base_m = algorithm.base_momentum
            m = self.end_momentum - (self.end_momentum - base_m) * (
                cos(pi * cur_iter / float(max_iter)) + 1) / 2
            algorithm.momentum = m

    def after_train_iter(self, runner):
        """Call the algorithm's ``momentum_update()`` every
        ``update_interval`` iterations, unwrapping module wrappers first."""
        if self.every_n_iters(runner, self.update_interval):
            self._get_algorithm(runner).momentum_update()
Read the Docs v: latest
Versions
latest
Downloads
pdf
html
epub
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.