Source code for decent_dp.optim

import math
import torch
from torch import Tensor
from torch.optim.optimizer import Optimizer
from torch.optim.lr_scheduler import LRScheduler
from typing import Any, List, Tuple, Union


def _get_param_groups(params: List[Tuple[str, Tensor]], weight_decay: float) -> list:
    """Split the named parameters into a no-weight-decay group and a weight-decay group.

    Batch-norm ('bn') and bias parameters are excluded from weight decay, following common practice.

    Returns:
        list: two parameter-group dicts, the first with ``weight_decay=0.`` and the second with the given ``weight_decay``
    """
    params_no_decay = [x for n, x in params if ('bn' in n) or ('bias' in n)]
    params_decay = [x for n, x in params if not (('bn' in n) or ('bias' in n))]

    return [
        {'params': params_no_decay, 'weight_decay': 0.},
        {'params': params_decay, 'weight_decay': weight_decay}
    ]
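
# A minimal usage sketch (illustrative only, not part of the module): how `_get_param_groups`
# splits the named parameters of a small model. It assumes norm layers are registered under
# names containing 'bn', as the name check above expects.
#
#   import torch.nn as nn
#   model = nn.Linear(4, 4)
#   groups = _get_param_groups(list(model.named_parameters()), weight_decay=0.1)
#   # groups[0] holds the bias parameter with weight_decay 0.
#   # groups[1] holds the weight parameter with weight_decay 0.1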


def optim_fn_adam(params: List[Tuple[str, Tensor]],
                  lr: float = 1e-3,
                  beta1: float = 0.9,
                  beta2: float = 0.999,
                  weight_decay: float = 1. / 32768,
                  eps: float = 1e-8) -> Optimizer:
    """An example of a function that creates an Adam optimizer with the given parameters and their names.

    To change the hyperparameters of the optimizer, you can wrap it with `functools.partial` and pass the new values.

    Returns:
        Optimizer: an Adam optimizer
    """
    return torch.optim.Adam(_get_param_groups(params, weight_decay), lr=lr, betas=(beta1, beta2), eps=eps)

def optim_fn_adamw(params: List[Tuple[str, Tensor]],
                   lr: float = 1e-3,
                   beta1: float = 0.9,
                   beta2: float = 0.999,
                   weight_decay: float = 0.1,
                   eps: float = 1e-8) -> Optimizer:
    """An example of a function that creates an AdamW optimizer with the given parameters and their names.

    To change the hyperparameters of the optimizer, you can wrap it with `functools.partial` and pass the new values.

    Returns:
        Optimizer: an AdamW optimizer
    """
    return torch.optim.AdamW(_get_param_groups(params, weight_decay), lr=lr, betas=(beta1, beta2), eps=eps)

def optim_fn_accum_adam(params: List[Tuple[str, Tensor]],
                        lr: float = 1e-3,
                        beta1: float = 0.9,
                        beta2: float = 0.999,
                        eps: float = 1e-8,
                        weight_decay: float = 1. / 32768,
                        accum_iter: int = 4) -> Optimizer:
    """An example of a function that creates an AccumAdam optimizer with the given parameters and their names.

    To change the hyperparameters of the optimizer, you can wrap it with `functools.partial` and pass the new values.

    Returns:
        Optimizer: an AccumAdam optimizer
    """
    return AccumAdam(_get_param_groups(params, weight_decay), lr=lr, betas=(beta1, beta2), eps=eps, accum_iter=accum_iter)

def optim_fn_accum_adamw(params: List[Tuple[str, Tensor]],
                         lr: float = 1e-3,
                         beta1: float = 0.9,
                         beta2: float = 0.999,
                         eps: float = 1e-8,
                         weight_decay: float = 0.1,
                         accum_iter: int = 4) -> Optimizer:
    """An example of a function that creates an AccumAdamW optimizer with the given parameters and their names.

    To change the hyperparameters of the optimizer, you can wrap it with `functools.partial` and pass the new values.

    Returns:
        Optimizer: an AccumAdamW optimizer
    """
    return AccumAdamW(_get_param_groups(params, weight_decay), lr=lr, betas=(beta1, beta2), eps=eps, accum_iter=accum_iter)

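
# Usage sketch: the docstrings above suggest `functools.partial` for overriding the default
# hyperparameters of these factory functions. A minimal example (the trainer that would
# ultimately consume `optim_fn` is assumed, not shown here):
#
#   from functools import partial
#   import torch.nn as nn
#
#   model = nn.Linear(8, 2)
#   optim_fn = partial(optim_fn_adamw, lr=4e-3, weight_decay=0.05)
#   optimizer = optim_fn(list(model.named_parameters()))
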
def lr_scheduler_fn_cosine_with_warmup(optimizer: Optimizer,
                                       t_max: int,
                                       t_warmup: int,
                                       cosine_eta_min: float = 1e-6,
                                       warmup_decay: float = 0.01) -> LRScheduler:
    """An example of a function that creates a learning rate scheduler that combines a warmup and a cosine annealing schedule.

    Returns:
        LRScheduler: a learning rate scheduler with the linear warmup followed by the cosine annealing
    """
    main_lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=t_max, eta_min=cosine_eta_min)
    warmup_lr_scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=warmup_decay, total_iters=t_warmup)
    return torch.optim.lr_scheduler.SequentialLR(optimizer, schedulers=[warmup_lr_scheduler, main_lr_scheduler], milestones=[t_warmup])

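
# Usage sketch: wiring an optimizer factory to the warmup + cosine schedule. `t_max` and
# `t_warmup` are counted in scheduler steps; the numbers below are placeholders, not
# recommended values.
#
#   model = torch.nn.Linear(8, 2)
#   optimizer = optim_fn_adamw(list(model.named_parameters()), lr=1e-3)
#   scheduler = lr_scheduler_fn_cosine_with_warmup(optimizer, t_max=10000, t_warmup=500)
#   for _ in range(10500):
#       ...  # forward/backward and optimizer.step()
#       scheduler.step()
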
def accum_adamw_foreach(params: List[torch.Tensor],
                        grads: List[torch.Tensor],
                        exp_avgs: List[torch.Tensor],
                        exp_avg_sqs: List[torch.Tensor],
                        accum_grads: List[torch.Tensor],
                        state_steps: List[torch.Tensor],
                        beta1: float,
                        beta2: float,
                        lr: Union[float, torch.Tensor],
                        weight_decay: float,
                        eps: float,
                        accum_iter: int):
    """Optimized version of the AccumAdamW update using `torch._foreach_*` ops.

    TODO: fused kernel
    """
    torch._foreach_add_(state_steps, 1)
    # Decoupled (AdamW-style) weight decay: shrink the parameters directly.
    if weight_decay != 0:
        torch._foreach_mul_(params, 1 - lr * weight_decay)
    step = state_steps[0].item()
    # Accumulate the current gradient into the running buffer, scaled by 1/accum_iter.
    torch._foreach_add_(accum_grads, grads, alpha=1. / accum_iter)
    # Temporary (out-of-place) moment estimates that include the current gradient.
    _exp_avgs = torch._foreach_add(exp_avgs, grads, alpha=1 - beta1)
    _exp_avg_sqs = torch._foreach_addcmul(exp_avg_sqs, grads, grads, value=1 - beta2)
    # Bias corrections are computed in units of effective (accumulated) steps.
    bias_correction1 = 1 - beta1 ** ((step + accum_iter - 1) // accum_iter)
    bias_correction2 = 1 - beta2 ** ((step + accum_iter - 1) // accum_iter)
    step_size = lr / bias_correction1
    bias_correction2_sqrt = math.sqrt(bias_correction2)
    torch._foreach_sqrt_(_exp_avg_sqs)
    torch._foreach_div_(_exp_avg_sqs, bias_correction2_sqrt)
    torch._foreach_add_(_exp_avg_sqs, eps)
    torch._foreach_addcdiv_(params, _exp_avgs, _exp_avg_sqs, value=-step_size)  # type: ignore
    # Every `accum_iter` steps, fold the accumulated gradient into the persistent
    # moment estimates and reset the accumulation buffer.
    if step % accum_iter == 0:
        torch._foreach_add_(exp_avgs, accum_grads, alpha=1 - beta1)
        torch._foreach_mul_(exp_avgs, beta1)
        torch._foreach_addcmul_(exp_avg_sqs, accum_grads, accum_grads, value=1 - beta2)
        torch._foreach_mul_(exp_avg_sqs, beta2)
        torch._foreach_zero_(accum_grads)

class AccumAdamW(torch.optim.Optimizer):
    """AccumAdamW optimizer

    Args:
        params (Any): parameter list or parameter groups
        lr (float, optional): base learning rate. Defaults to 1e-3.
        betas (Tuple[float, float], optional): beta1 and beta2. Defaults to (0.9, 0.999).
        eps (float, optional): epsilon. Defaults to 1e-8.
        weight_decay (float, optional): weight decay. Defaults to 0.
        accum_iter (int, optional): number of accumulation steps; should scale up with the number of workers. Defaults to 4.
    """

    def __init__(self,
                 params: Any,
                 lr: float = 1e-3,
                 betas: Tuple[float, float] = (0.9, 0.999),
                 eps: float = 1e-8,
                 weight_decay: float = 0,
                 accum_iter: int = 4):
        defaults = dict(
            lr=lr,
            betas=betas,
            eps=eps,
            weight_decay=weight_decay,
            accum_iter=accum_iter
        )
        super().__init__(params, defaults)

    def _init_group(self, group, params_with_grad, grads, exp_avgs, exp_avg_sqs, accum_grads, state_steps):
        for p in group['params']:
            if p.grad is not None:
                params_with_grad.append(p)
                grads.append(p.grad)
                state = self.state[p]
                if len(state) == 0:
                    state['step'] = torch.tensor(0, dtype=torch.int64)
                    state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format)
                    state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
                    state['accum_grad'] = torch.zeros_like(p, memory_format=torch.preserve_format)
                exp_avgs.append(state['exp_avg'])
                exp_avg_sqs.append(state['exp_avg_sq'])
                accum_grads.append(state['accum_grad'])
                state_steps.append(state['step'])
    @torch.no_grad()
    def step(self, closure=None):  # type: ignore
        self._cuda_graph_capture_health_check()
        assert closure is None
        for group in self.param_groups:
            params_with_grad = []
            grads = []
            exp_avgs = []
            exp_avg_sqs = []
            accum_grads = []
            state_steps = []
            beta1, beta2 = group["betas"]
            self._init_group(
                group,
                params_with_grad,
                grads,
                exp_avgs,
                exp_avg_sqs,
                accum_grads,
                state_steps
            )
            if len(state_steps) == 0:
                continue
            accum_adamw_foreach(
                params_with_grad,
                grads,
                exp_avgs,
                exp_avg_sqs,
                accum_grads,
                state_steps,
                beta1,
                beta2,
                group['lr'],
                group['weight_decay'],
                group['eps'],
                group['accum_iter']
            )

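
# Usage sketch for AccumAdamW: gradient accumulation is handled inside the optimizer via
# the `accum_grad` state, so `step()` and `zero_grad()` are called every iteration and the
# moment estimates are refreshed from the accumulated gradient every `accum_iter` steps.
# `loader` is assumed to yield (input, target) pairs.
#
#   model = torch.nn.Linear(8, 2)
#   opt = AccumAdamW(model.parameters(), lr=1e-3, weight_decay=0.1, accum_iter=4)
#   for x, y in loader:
#       loss = torch.nn.functional.mse_loss(model(x), y)
#       loss.backward()
#       opt.step()
#       opt.zero_grad()
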
def accum_adam_foreach(params: List[torch.Tensor],
                       grads: List[torch.Tensor],
                       exp_avgs: List[torch.Tensor],
                       exp_avg_sqs: List[torch.Tensor],
                       accum_grads: List[torch.Tensor],
                       state_steps: List[torch.Tensor],
                       beta1: float,
                       beta2: float,
                       lr: Union[float, torch.Tensor],
                       weight_decay: float,
                       eps: float,
                       accum_iter: int):
    """Optimized version of the AccumAdam update using `torch._foreach_*` ops.

    TODO: write a fused kernel for this
    """
    torch._foreach_add_(state_steps, 1)
    # Coupled (L2) weight decay: add `weight_decay * param` to the gradient.
    if weight_decay != 0:
        torch._foreach_add_(grads, params, alpha=weight_decay)
    step = state_steps[0].item()
    # Accumulate the current gradient into the running buffer, scaled by 1/accum_iter.
    torch._foreach_add_(accum_grads, grads, alpha=1. / accum_iter)
    # Temporary (out-of-place) moment estimates that include the current gradient.
    _exp_avgs = torch._foreach_add(exp_avgs, grads, alpha=1 - beta1)
    _exp_avg_sqs = torch._foreach_addcmul(exp_avg_sqs, grads, grads, value=1 - beta2)
    # Bias corrections are computed in units of effective (accumulated) steps.
    bias_correction1 = 1 - beta1 ** ((step + accum_iter - 1) // accum_iter)
    bias_correction2 = 1 - beta2 ** ((step + accum_iter - 1) // accum_iter)
    step_size = lr / bias_correction1
    bias_correction2_sqrt = math.sqrt(bias_correction2)
    torch._foreach_sqrt_(_exp_avg_sqs)
    torch._foreach_div_(_exp_avg_sqs, bias_correction2_sqrt)
    torch._foreach_add_(_exp_avg_sqs, eps)
    torch._foreach_addcdiv_(params, _exp_avgs, _exp_avg_sqs, value=-step_size)  # type: ignore
    # Every `accum_iter` steps, fold the accumulated gradient into the persistent
    # moment estimates and reset the accumulation buffer.
    if step % accum_iter == 0:
        torch._foreach_add_(exp_avgs, accum_grads, alpha=1 - beta1)
        torch._foreach_mul_(exp_avgs, beta1)
        torch._foreach_addcmul_(exp_avg_sqs, accum_grads, accum_grads, value=1 - beta2)
        torch._foreach_mul_(exp_avg_sqs, beta2)
        torch._foreach_zero_(accum_grads)

class AccumAdam(torch.optim.Optimizer):
    """AccumAdam optimizer

    Args:
        params (Any): parameter list or parameter groups
        lr (float, optional): base learning rate. Defaults to 1e-3.
        betas (Tuple[float, float], optional): beta1 and beta2. Defaults to (0.9, 0.999).
        eps (float, optional): epsilon. Defaults to 1e-8.
        weight_decay (float, optional): weight decay. Defaults to 1. / 32768.
        accum_iter (int, optional): number of accumulation steps; should scale up with the number of workers. Defaults to 4.
    """

    def __init__(self,
                 params: Any,
                 lr: float = 1e-3,
                 betas: Tuple[float, float] = (0.9, 0.999),
                 eps: float = 1e-8,
                 weight_decay: float = 1. / 32768,
                 accum_iter: int = 4):
        defaults = dict(
            lr=lr,
            betas=betas,
            eps=eps,
            weight_decay=weight_decay,
            accum_iter=accum_iter
        )
        super().__init__(params, defaults)

    def _init_group(self, group, params_with_grad, grads, exp_avgs, exp_avg_sqs, accum_grads, state_steps):
        for p in group['params']:
            if p.grad is not None:
                params_with_grad.append(p)
                grads.append(p.grad)
                state = self.state[p]
                if len(state) == 0:
                    state['step'] = torch.tensor(0, dtype=torch.int64)
                    state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format)
                    state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
                    state['accum_grad'] = torch.zeros_like(p, memory_format=torch.preserve_format)
                exp_avgs.append(state['exp_avg'])
                exp_avg_sqs.append(state['exp_avg_sq'])
                accum_grads.append(state['accum_grad'])
                state_steps.append(state['step'])
    @torch.no_grad()
    def step(self, closure=None):  # type: ignore
        self._cuda_graph_capture_health_check()
        assert closure is None, "Closure is not supported"
        for group in self.param_groups:
            params_with_grad = []
            grads = []
            exp_avgs = []
            exp_avg_sqs = []
            accum_grads = []
            state_steps = []
            beta1, beta2 = group["betas"]
            self._init_group(
                group,
                params_with_grad,
                grads,
                exp_avgs,
                exp_avg_sqs,
                accum_grads,
                state_steps
            )
            if len(state_steps) == 0:
                continue
            accum_adam_foreach(
                params_with_grad,
                grads,
                exp_avgs,
                exp_avg_sqs,
                accum_grads,
                state_steps,
                beta1,
                beta2,
                group['lr'],
                group['weight_decay'],
                group['eps'],
                group['accum_iter']
            )

__all__ = [
    'optim_fn_adam',
    'optim_fn_adamw',
    'optim_fn_accum_adam',
    'optim_fn_accum_adamw',
    'lr_scheduler_fn_cosine_with_warmup',
    'AccumAdam',
    'AccumAdamW'
]