Source code for decent_dp.ddp

import os
import copy
import math
from loguru import logger
from functools import partial
from typing import Callable, Iterator, List, Optional, Tuple, cast
import torch
from torch import Tensor
from torch.nn import Module
import torch.distributed as dist
from torch.optim import Optimizer
from torch.distributed import Work
from torch import GradScaler
from torch.nn.parameter import Parameter
from torch.utils.hooks import RemovableHandle
from torch.optim.lr_scheduler import LRScheduler
from .topo import TopologyReg, Topology


"""Data type for the optimizer function"""
OPTIM_FN_TYPE = Callable[[List[Tuple[str, Tensor]]], Optimizer]


"""Data type for the learning rate scheduler function"""
LR_SCHEDULER_FN_TYPE = Callable[[Optimizer], LRScheduler]
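
# Illustrative sketch (not part of the original module): example callables that satisfy
# OPTIM_FN_TYPE and LR_SCHEDULER_FN_TYPE. The wrapper calls the optimizer factory once per
# parameter bucket with a list of (name, parameter) tuples; the hyperparameters below are
# placeholders, not recommendations.
def _example_optim_fn(named_params: List[Tuple[str, Tensor]]) -> Optimizer:
    # a single parameter group; per-name grouping (e.g. excluding biases from weight decay)
    # could be derived from the provided names
    return torch.optim.SGD([p for _, p in named_params], lr=0.1, momentum=0.9, weight_decay=1e-4)


def _example_lr_scheduler_fn(optimizer: Optimizer) -> LRScheduler:
    # a cosine schedule over an arbitrary horizon of 100 steps
    return torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=100)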


class DecentralizedDataParallel(Module):
    """Decentralized data parallel wrapper for a PyTorch module

    1. During the first iteration, the wrapper places hooks in the backward pass to trace the order in which \
        the parameters are used.
    2. It then splits the parameters into buckets, creates an optimizer and LR scheduler for each bucket, \
        and registers hooks on the last parameter of each bucket to perform the bucket-wise update and communication.
    3. During the backward passes of the training loop, these hooks are triggered to perform the bucket-wise \
        update and communication.

    :Warning: The wrapper currently does not support the "channels_last" memory format

    :Warning: The wrapper assumes that each parameter is used only once in the backward pass

    Args:
        model (Module): PyTorch module to be wrapped
        optim_fn (OPTIM_FN_TYPE): Function to create the optimizer, which takes a list of (name, parameter) tuples
        lr_scheduler_fn (Optional[LR_SCHEDULER_FN_TYPE], optional): Function to create the learning rate scheduler, \
            which takes the optimizer as input. Defaults to None.
        topology (str, optional): Topology of the decentralized communication graph. Defaults to 'complete'.
        scaler (Optional[GradScaler], optional): Gradient scaler for mixed-precision training. Defaults to None.
        grad_clip_norm (float, optional): Gradient clipping norm; set to 0.0 to disable gradient clipping. Defaults to 0.0.
        param_as_bucket_view (bool, optional): Whether to make each parameter a view of part of a contiguous per-bucket buffer. Defaults to True.
        sync_buffer_in_global_avg (bool, optional): Whether to synchronize the float buffers in the global average. Defaults to False.
        bucket_size_in_mb (int, optional): Size of each bucket in MB. Defaults to 25.
        _local_world_size (Optional[int], optional): Local world size, if not provided through the ``LOCAL_WORLD_SIZE`` environment variable. Defaults to None.
    """

    """Buffer data types that need to be synchronized in the global average"""
    FLOAT_DTYPES = [torch.float16, torch.float32, torch.float64]

    def __init__(self,
                 model: Module,
                 optim_fn: OPTIM_FN_TYPE,
                 lr_scheduler_fn: Optional[LR_SCHEDULER_FN_TYPE] = None,
                 topology: str = 'complete',
                 scaler: Optional[GradScaler] = None,
                 grad_clip_norm: float = 0.0,
                 param_as_bucket_view: bool = True,
                 sync_buffer_in_global_avg: bool = False,
                 bucket_size_in_mb: int = 25,
                 _local_world_size: Optional[int] = None):
        super(DecentralizedDataParallel, self).__init__()
        assert dist.is_available() and dist.is_initialized(), 'Distributed environment is not initialized'

        self._model = model.cuda() if torch.cuda.is_available() else model
        self._optim_fn = optim_fn
        self._lr_schd_fn = lr_scheduler_fn
        self._scaler = scaler
        self._grad_clip_norm = grad_clip_norm
        self._param_as_bucket_view = param_as_bucket_view
        self._sync_buffer_in_global_avg = sync_buffer_in_global_avg
        self._bucket_size = bucket_size_in_mb * 1024 * 1024
        self._local_world_size = _local_world_size if _local_world_size is not None \
            else int(os.environ.get('LOCAL_WORLD_SIZE', 1))

        # get the rank and world size
        self._rank = dist.get_rank()
        self._world_size = dist.get_world_size()

        # check if the model uses the "channels_last" memory format
        if self._check_channels_last():
            if self._rank == 0:
                logger.debug('The model uses the "channels_last" memory format')

        if self._rank == 0:
            logger.debug('Initializing Decentralized Data Parallel')
            logger.debug(f'Rank: {self._rank}, Local World Size: {self._local_world_size}, '
                         f'World Size: {self._world_size}, Topology: {topology}')

        # model parameters
        self._params: List[Tensor] = [x for _, x in self._model.named_parameters() if x.requires_grad]
        self._param_names: List[str] = [n for n, x in self._model.named_parameters() if x.requires_grad]

        # trace hooks and traced parameter ids
        self._trace_hooks: List[RemovableHandle] = []
        self._traced_param_ids: List[int] = []

        self._step: int = 0
        self._comm_ops: List[Optional[Work]] = []
        self._ddp_hooks: List[RemovableHandle] = []
        self._param_buckets: List[List[Tensor]] = []
        self._param_blocks: List[Tensor] = []
        self._comm_buffers: List[List[Tensor]] = []
        self._comm_blocks: List[Tensor] = []

        # optimizer and LR scheduler per bucket
        self._optims: List[Optimizer] = []
        self._lr_schedulers: List[Optional[LRScheduler]] = []

        # initialize the topology
        self._topo: Topology = TopologyReg.registry[topology](self._local_world_size)

        # create hooks to trace the used parameters in backward
        self._create_trace_hooks()

        # sync the parameters at the start
        self._sync_at_start()

        # flag for gradient accumulation
        self._is_grad_accum_enable: bool = False
        # flag for initializing the parameters
        self._initialized: bool = False

    def _check_channels_last(self) -> bool:
        """Check if the model uses the "channels_last" memory format

        Returns:
            bool: True if the model uses the "channels_last" memory format
        """
        return any(x.is_contiguous(memory_format=torch.channels_last) and (not x.is_contiguous())
                   for x in self._model.parameters() if len(x.shape) == 4)

    def _create_trace_hooks(self):
        """Create hooks to trace the order of used parameters in the backward pass"""
        for pid, param in enumerate(self._params):
            self._trace_hooks.append(
                param.register_post_accumulate_grad_hook(
                    partial(lambda data, pid: self._trace_fn(data, pid), pid=pid)
                )
            )

    @torch.no_grad()
    def _sync_at_start(self):
        """Broadcast the parameters of worker 0 to all other workers at the start"""
        for param in self._params:
            dist.broadcast(param, 0)
    def set_accumulate_grad(self, enable: bool = True):
        """Set the gradient accumulation mode

        Args:
            enable (bool, optional): Whether to accumulate the gradients. Defaults to True.
        """
        self._is_grad_accum_enable = enable
"""Hook functions""" @torch.no_grad() def _trace_fn(self, _: Tensor, pid: int): """Hook function to trace the order of used parameters in backward pass Args: _ (Tensor): corresponding tensor (not used) pid (int): parameter id Raises: AssertionError: The parameter is used more than once in the backward pass """ if self._is_grad_accum_enable: return assert not (pid in self._traced_param_ids), 'The parameter is used more than once in the backward pass' self._traced_param_ids.append(pid) @torch.no_grad() def _ddp_fn(self, _: Tensor, bucket_id: int): """Hook function to perform the bucket-wise update and communication Args: _ (Tensor): corresponding tensor (not used) bucket_id (int): bucket id """ # skip the update and communication if the model is accumulating gradients if self._is_grad_accum_enable: return # perform the bucket-wise update and communication when all gradients in the bucket are accumulated comm_op = self._comm_ops[bucket_id] if comm_op is not None: # wait for the communication from the last iteration comm_op.wait() self._comm_ops[bucket_id] = None # get the peers to communicate with in this iteration edge = self._topo.get_edge(self._step) weight = edge.weight # optionally call the pre_average_hook for optimizers using the communication information if hasattr(self._optims[bucket_id], 'pre_average_hook'): self._optims[bucket_id].pre_average_hook(edge, weight) # type: ignore # replace the local model with the mixed model if self._param_as_bucket_view: self._param_blocks[bucket_id].mul_(weight - (1 - weight) / (len(edge.ranks) - 1)) self._param_blocks[bucket_id].add_(self._comm_blocks[bucket_id]) else: torch._foreach_mul_(self._param_buckets[bucket_id], weight - (1 - weight) / (len(edge.ranks) - 1)) torch._foreach_add_(self._param_buckets[bucket_id], self._comm_buffers[bucket_id]) # perform local update if self._scaler: if self._grad_clip_norm > 0: self._scaler.unscale_(self._optims[bucket_id]) torch.nn.utils.clip_grad_norm_(self._param_buckets[bucket_id], self._grad_clip_norm) self._scaler.step(self._optims[bucket_id]) if bucket_id == len(self._param_buckets) - 1: self._scaler.update() else: if self._grad_clip_norm > 0: torch.nn.utils.clip_grad_norm_(self._param_buckets[bucket_id], self._grad_clip_norm) self._optims[bucket_id].step() self._optims[bucket_id].zero_grad() if self._lr_schedulers[bucket_id] is not None: scheduler = cast(LRScheduler, self._lr_schedulers[bucket_id]) scheduler.step() # launch the next communication after updating the weights if self._param_as_bucket_view: self._comm_blocks[bucket_id].copy_(self._param_blocks[bucket_id]) else: torch._foreach_copy_(self._comm_buffers[bucket_id], self._param_buckets[bucket_id]) edge = self._topo.get_edge(self._step + 1) weight = edge.weight self._comm_blocks[bucket_id].mul_((1 - weight) / (len(edge.ranks) - 1)) self._comm_ops[bucket_id] = dist.all_reduce( self._comm_blocks[bucket_id], op=dist.ReduceOp.SUM, group=edge.group, async_op=True ) @torch.no_grad() def _initialize_params(self): """Initialize the parameter buckets and communication buffers Raises: RuntimeError: Number/Order of elements in used parameters is different on different nodes """ # verify the number of elements and the order of the parameters on different nodes are the same verify = [[(i, self._params[i].numel()) for i in self._traced_param_ids]] result = [[(0, 0)]] if self._rank != 0 else verify dist.broadcast_object_list(result, src=0) if not all([x == y for x, y in zip(verify[0], result[0])]): raise RuntimeError('Number/Order of elements in used 
parameters is different on different nodes') # remove the trace hooks for hook in self._trace_hooks: hook.remove() del self._trace_hooks # split the parameters into roughly equal-size buckets, and register hooks on the last parameter of each bucket start = 0 size = 0 for i in range(len(self._traced_param_ids)): size += self._align(self._params[self._traced_param_ids[i]].numel()) * self._params[self._traced_param_ids[i]].element_size() if (size >= self._bucket_size) or (i == len(self._traced_param_ids) - 1): # register hooks on the last parameter of each bucket, passing the bucket id self._ddp_hooks.append( self._params[self._traced_param_ids[i]].register_post_accumulate_grad_hook( partial( lambda data, bucket_id: self._ddp_fn(data, bucket_id), bucket_id=len(self._ddp_hooks) ) ) ) self._param_buckets.append([self._params[j] for j in self._traced_param_ids[start:i+1]]) param_names = [self._param_names[j] for j in self._traced_param_ids[start:i+1]] # create optimizer and learning rate scheduler for parameters in each bucket self._optims.append(self._optim_fn(list(zip(param_names, self._param_buckets[-1])))) self._lr_schedulers.append(self._lr_schd_fn(self._optims[-1]) if self._lr_schd_fn is not None else None) size = 0 start = i + 1 size_dict = {} for i in range(len(self._param_buckets)): total_size = sum([self._align(p.numel()) for p in self._param_buckets[i]]) # make sure the total size is unique for each bucket \ # (not necessary, but make sure the communication operations are unique for each bucket with negligible overhead) while total_size in size_dict: total_size += 32 size_dict[total_size] = True # create the communication buffer for each bucket comm_block = torch.zeros(total_size, device=self._param_buckets[i][0].device, requires_grad=False, dtype=self._param_buckets[i][0].dtype) if self._param_as_bucket_view: # create contiguous blocks for each bucket, and let the parameters be views of the fragments of the block self._param_blocks.append(torch.zeros(total_size, device=self._param_buckets[i][0].device, requires_grad=True, dtype=self._param_buckets[i][0].dtype)) start = 0 for j in range(len(self._param_buckets[i])): size = self._param_buckets[i][j].numel() if (len(self._param_buckets[i][j].shape) == 4) and self._param_buckets[i][j].is_contiguous(memory_format=torch.channels_last) \ and (not self._param_buckets[i][j].is_contiguous()): # permute the tensor to the channels_last format self._param_blocks[-1].narrow(0, start, size).copy_(self._param_buckets[i][j].permute(0, 2, 3, 1).view(-1)) self._param_buckets[i][j].data = self._param_blocks[-1].narrow(0, start, size).view( (self._param_buckets[i][j].shape[0], self._param_buckets[i][j].shape[2], self._param_buckets[i][j].shape[3], self._param_buckets[i][j].shape[1]) ).permute(0, 3, 1, 2) assert self._param_buckets[i][j].is_contiguous(memory_format=torch.channels_last) assert not self._param_buckets[i][j].is_contiguous() else: # otherwise, copy the tensor directly assert self._param_buckets[i][j].is_contiguous() self._param_blocks[-1].narrow(0, start, size).copy_(self._param_buckets[i][j].view(-1)) self._param_buckets[i][j].data = self._param_blocks[-1].narrow(0, start, size).view_as(self._param_buckets[i][j]) start += self._align(size) self._comm_blocks.append(comm_block) start = 0 self._comm_buffers.append([]) for j in range(len(self._param_buckets[i])): size = self._param_buckets[i][j].numel() if (len(self._param_buckets[i][j].shape) == 4) and self._param_buckets[i][j].is_contiguous(memory_format=torch.channels_last) \ and (not 
self._param_buckets[i][j].is_contiguous()): # permute the tensor to the channels_last format self._comm_buffers[-1].append(comm_block.narrow(0, start, size).view( (self._param_buckets[i][j].shape[0], self._param_buckets[i][j].shape[2], self._param_buckets[i][j].shape[3], self._param_buckets[i][j].shape[1]) ).permute(0, 3, 1, 2)) else: self._comm_buffers[-1].append(comm_block.narrow(0, start, size).view_as(self._param_buckets[i][j])) start += self._align(size) # attach the communication buffer to the parameter for "pre_average_hook" in the optimizer if hasattr(self._optims[i], 'pre_average_hook'): setattr(self._param_buckets[i][j], 'comm_buffer', self._comm_buffers[-1][-1]) # initialize the communication buffer with the initial parameters torch._foreach_copy_(self._comm_buffers[-1], self._param_buckets[i]) self._comm_ops = [None] * len(self._param_buckets) def _align(self, size: int): """Align the size to 128-byte boundary """ return math.ceil(size / 32) * 32 """Delegation functions"""
    def train(self, mode: bool = True):
        """Set the module in training mode

        Args:
            mode (bool, optional): Whether to set the module in training mode. Defaults to True.
        """
        self._model.train(mode)
        return self
    def eval(self):
        """Set the module in evaluation mode"""
        self._model.eval()
        return self
    def forward(self, *args, **kwargs):
        """Forward pass of the model"""
        # lazy initialization at the second iteration
        if (self._step == 1) and (not self._initialized):
            self._initialized = True
            # initialize the parameters and communication buffers
            self._initialize_params()

            # manually trigger the communications for the first iteration only
            with torch.no_grad():
                edge = self._topo.get_edge(self._step)
                weight = edge.weight
                for i in range(len(self._param_buckets)):
                    # optionally call the pre_average_hook for optimizers using the communication information
                    if hasattr(self._optims[i], 'pre_average_hook'):
                        self._optims[i].pre_average_hook(edge, weight)  # type: ignore

                    # update parameters and launch the first communication
                    if self._scaler:
                        if self._grad_clip_norm > 0:
                            self._scaler.unscale_(self._optims[i])
                            torch.nn.utils.clip_grad_norm_(self._param_buckets[i], self._grad_clip_norm)
                        self._scaler.step(self._optims[i])
                        if i == len(self._param_buckets) - 1:
                            self._scaler.update()
                            # TODO: synchronize the scaler state across all workers?
                    else:
                        if self._grad_clip_norm > 0:
                            torch.nn.utils.clip_grad_norm_(self._param_buckets[i], self._grad_clip_norm)
                        self._optims[i].step()
                    self._optims[i].zero_grad()
                    if self._lr_schedulers[i] is not None:
                        scheduler = cast(LRScheduler, self._lr_schedulers[i])
                        scheduler.step()

                    # launch the first communication
                    if self._param_as_bucket_view:
                        self._comm_blocks[i].copy_(self._param_blocks[i])
                    else:
                        torch._foreach_copy_(self._comm_buffers[i], self._param_buckets[i])
                    self._comm_blocks[i].mul_((1 - weight) / (len(edge.ranks) - 1))
                    comm_op = dist.all_reduce(
                        self._comm_blocks[i],
                        op=dist.ReduceOp.SUM,
                        group=edge.group,
                        async_op=True
                    )
                    self._comm_ops[i] = comm_op
                    # wait for the communication to finish to fully synchronize the workers
                    assert comm_op is not None
                    comm_op.wait()

        if self._model.training and (not self._is_grad_accum_enable):
            self._step += 1

        with torch.autograd.profiler.record_function("DecentralizedDataParallel.forward"):
            output = self._model(*args, **kwargs)
        return output
    def parameters(self, recurse: bool = True) -> Iterator[Parameter]:
        """Get the parameters of the model

        Args:
            recurse (bool, optional): Whether to get the parameters recursively. Defaults to True.

        Yields:
            Iterator[Parameter]: The iterator over the parameters
        """
        yield from self._model.parameters(recurse)

    def named_parameters(self, prefix: str = '', recurse: bool = True, remove_duplicate: bool = True) -> Iterator[Tuple[str, Parameter]]:
        """Get the named parameters of the model"""
        return super().named_parameters(prefix, recurse, remove_duplicate)

    """Utility functions"""
    @torch.no_grad()
    def global_avg(self):
        """Perform global average on the parameters (and buffers if sync_buffer_in_global_avg is True)

        The function is called at the end of the training loop to synchronize the parameters across all nodes for evaluation
        """
        for op in self._comm_ops:
            if op is not None:
                op.wait()
        self._comm_ops = [None for _ in range(len(self._param_buckets))]

        if self._param_as_bucket_view:
            torch._foreach_div_(self._param_blocks, self._world_size)
            for i in range(len(self._param_blocks)):
                dist.all_reduce(self._param_blocks[i], op=dist.ReduceOp.SUM)
        else:
            torch._foreach_div_([x.data for x in self._params], self._world_size)
            for x in self._params:
                dist.all_reduce(x.data, op=dist.ReduceOp.SUM)

        if self._sync_buffer_in_global_avg:
            # globally average the float buffers (e.g. the running mean and variance in batch normalization)
            for x in self._model.buffers():
                if x.dtype in self.FLOAT_DTYPES:
                    dist.all_reduce(x.data, op=dist.ReduceOp.SUM)
                    x.data.div_(self._world_size)
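
# Illustrative usage sketch (not part of the original module). It assumes the default process
# group is already initialized (e.g. torchrun + dist.init_process_group) and that `build_model`,
# `train_loader`, and `loss_fn` are user-provided placeholders; moving inputs to the model's
# device is omitted for brevity.
def _example_training_loop(build_model, train_loader, loss_fn, epochs: int = 1):
    model = DecentralizedDataParallel(
        build_model(),
        optim_fn=lambda named_params: torch.optim.SGD(
            [p for _, p in named_params], lr=0.1, momentum=0.9),
        topology='complete')
    for _ in range(epochs):
        model.train()
        for inputs, targets in train_loader:
            # no explicit optimizer.step(): the backward hooks perform the bucket-wise
            # local updates and launch the asynchronous communication
            loss = loss_fn(model(inputs), targets)
            loss.backward()
            # for gradient accumulation, wrap the extra micro-batches with
            # model.set_accumulate_grad(True) / model.set_accumulate_grad(False)
    # average the parameters (and optionally the float buffers) across workers before evaluation
    model.global_avg()
    model.eval()
    return model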
__all__ = ['DecentralizedDataParallel', 'OPTIM_FN_TYPE', 'LR_SCHEDULER_FN_TYPE']