API Reference

This document provides detailed API references for all modules in Decent-DP.

decent_dp.ddp

The core module containing the DecentralizedDataParallel class, which is the main wrapper for enabling decentralized training on PyTorch models.

OPTIM_FN_TYPE = Callable[[List[Tuple[str, Tensor]]], Optimizer] module-attribute

Type of the optimizer factory function: it takes a list of (name, parameter) tuples and returns an Optimizer.

LR_SCHEDULER_FN_TYPE = Callable[[Optimizer], LRScheduler] module-attribute

Type of the learning rate scheduler factory function: it takes an Optimizer and returns an LRScheduler.
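
For illustration, a minimal pair of factories matching these two signatures could look like the following sketch (the names my_optim_fn and my_lr_scheduler_fn and the hyperparameter values are arbitrary):

from typing import List, Tuple

import torch
from torch import Tensor
from torch.optim import Optimizer
from torch.optim.lr_scheduler import CosineAnnealingLR, LRScheduler


def my_optim_fn(named_params: List[Tuple[str, Tensor]]) -> Optimizer:
    # OPTIM_FN_TYPE: receives the (name, parameter) tuples of one bucket
    # and returns the optimizer responsible for that bucket
    return torch.optim.SGD([p for _, p in named_params], lr=0.1)


def my_lr_scheduler_fn(optimizer: Optimizer) -> LRScheduler:
    # LR_SCHEDULER_FN_TYPE: receives a bucket's optimizer and returns its scheduler
    return CosineAnnealingLR(optimizer, T_max=1000)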

DecentralizedDataParallel

Bases: Module

Decentralized data parallel wrapper for a PyTorch module.

1. In the first iteration, the wrapper registers backward hooks that trace the order in which the parameters are used.
2. It then splits the parameters into buckets, creates an optimizer and an LR scheduler for each bucket, and registers a hook on the last parameter of each bucket.
3. In subsequent backward passes of the training loop, these hooks trigger the bucket-wise update and communication.

Note

The wrapper currently does not support "channels_last" memory format.

Note

The wrapper assumes that each parameter is used only once in the backward pass.

Parameters:

- model (Module): PyTorch module to be wrapped. Required.
- optim_fn (OPTIM_FN_TYPE): Function that creates the optimizer from a list of (name, parameter) tuples. Required.
- lr_scheduler_fn (Optional[LR_SCHEDULER_FN_TYPE]): Function that creates the learning rate scheduler from the optimizer. Defaults to None.
- topology (str): Topology of the decentralized communication graph. Defaults to 'complete'.
- scaler (Optional[GradScaler]): Gradient scaler for mixed-precision training. Defaults to None.
- grad_clip_norm (float): Gradient clipping norm; 0.0 disables gradient clipping. Defaults to 0.0.
- param_as_bucket_view (bool): Whether to make each parameter a view into its bucket's contiguous buffer. Defaults to True.
- sync_buffer_in_global_avg (bool): Whether to also synchronize the floating-point buffers in global_avg. Defaults to False.
- bucket_size_in_mb (int): Size of each bucket in MB. Defaults to 25.
- _local_world_size (Optional[int]): Local world size to use instead of the LOCAL_WORLD_SIZE environment variable. Defaults to None.
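
A minimal end-to-end sketch of wrapping a model (assuming the process group has already been initialized, e.g. via decent_dp.utils.initialize_dist(); the model, data, and hyperparameters below are placeholders):

import torch
import torch.nn.functional as F

from decent_dp.ddp import DecentralizedDataParallel
from decent_dp.optim import optim_fn_adamw

model = torch.nn.Linear(128, 10)
ddp_model = DecentralizedDataParallel(
    model,
    optim_fn=optim_fn_adamw,   # one optimizer is created per parameter bucket
    lr_scheduler_fn=None,      # optional LR scheduler factory
    topology="complete",       # 'complete', 'ring', or 'one-peer-exp'
)

# placeholder in-memory data; a real script would use a DataLoader
data_loader = [(torch.randn(32, 128), torch.randint(0, 10, (32,))) for _ in range(10)]

ddp_model.train()
for x, y in data_loader:
    # (in a real script, move x and y to the model's device)
    loss = F.cross_entropy(ddp_model(x), y)
    loss.backward()            # the backward hooks perform the bucket-wise update and communication

ddp_model.global_avg()         # synchronize parameters across workers before evaluation

Note that there is no explicit optimizer.step() call: the per-bucket updates are driven entirely by the hooks registered during the backward pass.
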
Source code in src/decent_dp/ddp.py
class DecentralizedDataParallel(Module):
    """Decentralized data parallel wrapper for PyTorch module

    1. The wrapper places hooks during the backward pass to trace the order of used parameters in the first iteration, and \
    2. Split the parameters into buckets and create optimizers and LR schedulers for each bucket, \
        Add hooks on the last parameter of each bucket to perform the bucket-wise update and communication, \
    3. During the backward passes in the training loop, the hooks are triggered to perform the bucket-wise update and communication

    Note:
        The wrapper currently does not support "channels_last" memory format.

    Note:
        The wrapper assumes that the parameter will only be used once in the backward pass

    Args:
        model (Module): PyTorch module to be wrapped
        optim_fn (OPTIM_FN_TYPE): Function to create the optimizer, which takes a list of tuples of parameters and their names
        lr_scheduler_fn (Optional[LR_SCHEDULER_FN_TYPE], optional): Function to create the learning rate scheduler, \
            which takes the optimizer as input. Defaults to None.
        topology (str, optional): Topology of the decentralized communication graph. Defaults to 'complete'.
        scaler (Optional[GradScaler], optional): Gradient scaler for mixed precision training. Defaults to None.
        grad_clip_norm (float, optional): Gradient clipping norm, set to 0.0 if no gradient clipping is applied. Defaults to 0.0.
        param_as_bucket_view (bool, optional): Whether to use the parameter as a view of part of the contiguous buffer. Defaults to True.
        sync_buffer_in_global_avg (bool, optional): Whether to synchronize the float buffers in the global average. Defaults to False.
        bucket_size_in_mb (int, optional): Size of the bucket in MB. Defaults to 25 MB.
        _local_world_size (Optional[int], optional): Provide the local world size and not using the environment variable. Defaults to None.
    """

    """Buffer data types that need to be synchronized in global average"""
    FLOAT_DTYPES = [torch.float16, torch.float32, torch.float64]

    def __init__(
        self,
        model: Module,
        optim_fn: OPTIM_FN_TYPE,
        lr_scheduler_fn: Optional[LR_SCHEDULER_FN_TYPE] = None,
        topology: str = "complete",
        scaler: Optional[GradScaler] = None,
        grad_clip_norm: float = 0.0,
        param_as_bucket_view: bool = True,
        sync_buffer_in_global_avg: bool = False,
        bucket_size_in_mb: int = 25,
        _local_world_size: Optional[int] = None,
    ):
        super(DecentralizedDataParallel, self).__init__()
        assert dist.is_available() and dist.is_initialized(), "Distributed environment is not initialized"

        self._model = model.cuda() if torch.cuda.is_available() else model
        self._optim_fn = optim_fn
        self._lr_schd_fn = lr_scheduler_fn
        self._scaler = scaler
        self._grad_clip_norm = grad_clip_norm
        self._param_as_bucket_view = param_as_bucket_view
        self._sync_buffer_in_global_avg = sync_buffer_in_global_avg
        self._bucket_size = bucket_size_in_mb * 1024 * 1024
        self._local_world_size = (
            _local_world_size if _local_world_size is not None else int(os.environ.get("LOCAL_WORLD_SIZE", 1))
        )

        # get the rank and world size
        self._rank = dist.get_rank()
        self._world_size = dist.get_world_size()

        # check if the model is with "channels_last" memory format
        if self._check_channels_last():
            if self._rank == 0:
                logger.debug('The model is with "channels_last" memory format')

        if self._rank == 0:
            logger.debug("Initializing Decentralized Data Parallel")
            logger.debug(
                f"Rank: {self._rank}, Local World Size: {self._local_world_size}, World Size: {self._world_size}, Topology: {topology}"
            )

        # model parameters
        self._params: List[Tensor] = list([x for _, x in self._model.named_parameters() if x.requires_grad])
        self._param_names: List[str] = list([n for n, x in self._model.named_parameters() if x.requires_grad])

        # trace hooks and traced parameter ids
        self._trace_hooks: List[RemovableHandle] = []
        self._traced_param_ids: List[int] = []

        self._step: int = 0
        self._comm_ops: List[Optional[Work]] = []

        self._ddp_hooks: List[RemovableHandle] = []
        self._param_buckets: List[List[Tensor]] = []
        self._param_blocks: List[Tensor] = []
        self._comm_buffers: List[List[Tensor]] = []
        self._comm_blocks: List[Tensor] = []

        # Optimizer and LR scheduler
        self._optims: List[Optimizer] = []
        self._lr_schedulers: List[Optional[LRScheduler]] = []

        # initialize the topology
        self._topo: Topology = TopologyReg.registry[topology](self._local_world_size)

        # create hooks to trace the used parameters in backward
        self._create_trace_hooks()

        # sync the parameters at the start
        self._sync_at_start()

        # flag for gradient accumulation
        self._is_grad_accum_enable: bool = False

        # flag for initializing the parameters
        self._initialized: bool = False

    def _check_channels_last(self) -> bool:
        """Check if the model is with "channels_last" memory format

        Returns:
            bool: True if the model is with "channels_last" memory format
        """
        if any(
            [
                x.is_contiguous(memory_format=torch.channels_last) and (not x.is_contiguous())
                for x in self._model.parameters()
                if len(x.shape) == 4
            ]
        ):
            return True
        return False

    def _create_trace_hooks(self):
        """Create hooks to trace the order of used parameters in backward pass"""
        for pid, param in enumerate(self._params):
            self._trace_hooks.append(
                param.register_post_accumulate_grad_hook(partial(lambda data, pid: self._trace_fn(data, pid), pid=pid))
            )

    @torch.no_grad()
    def _sync_at_start(self):
        """Broadcast the parameters of worker 0 to all other workers at the start"""
        for param in self._params:
            dist.broadcast(param, 0)

    def set_accumulate_grad(self, enable: bool = True):
        """Set the gradient accumulation mode

        Args:
            enable (bool, optional): Whether to accumulate the gradients. Defaults to True.
        """
        self._is_grad_accum_enable = enable

    """Hook functions"""

    @torch.no_grad()
    def _trace_fn(self, _: Tensor, pid: int):
        """Hook function to trace the order of used parameters in backward pass

        Args:
            _ (Tensor): corresponding tensor (not used)
            pid (int): parameter id

        Raises:
            AssertionError: The parameter is used more than once in the backward pass
        """
        if self._is_grad_accum_enable:
            return
        assert pid not in self._traced_param_ids, "The parameter is used more than once in the backward pass"
        self._traced_param_ids.append(pid)

    @torch.no_grad()
    def _ddp_fn(self, _: Tensor, bucket_id: int):
        """Hook function to perform the bucket-wise update and communication

        Args:
            _ (Tensor): corresponding tensor (not used)
            bucket_id (int): bucket id
        """

        # skip the update and communication if the model is accumulating gradients
        if self._is_grad_accum_enable:
            return

        # perform the bucket-wise update and communication when all gradients in the bucket are accumulated
        comm_op = self._comm_ops[bucket_id]
        if comm_op is not None:
            # wait for the communication from the last iteration
            comm_op.wait()
            self._comm_ops[bucket_id] = None

            # get the peers to communicate with in this iteration
            edge = self._topo.get_edge(self._step)
            weight = edge.weight

            # optionally call the pre_average_hook for optimizers using the communication information
            if hasattr(self._optims[bucket_id], "pre_average_hook"):
                self._optims[bucket_id].pre_average_hook(edge, weight)  # type: ignore

            # replace the local model with the mixed model
            if self._param_as_bucket_view:
                self._param_blocks[bucket_id].mul_(weight - (1 - weight) / (len(edge.ranks) - 1))
                self._param_blocks[bucket_id].add_(self._comm_blocks[bucket_id])
            else:
                torch._foreach_mul_(self._param_buckets[bucket_id], weight - (1 - weight) / (len(edge.ranks) - 1))
                torch._foreach_add_(self._param_buckets[bucket_id], self._comm_buffers[bucket_id])

        # perform local update
        if self._scaler:
            if self._grad_clip_norm > 0:
                self._scaler.unscale_(self._optims[bucket_id])
                torch.nn.utils.clip_grad_norm_(self._param_buckets[bucket_id], self._grad_clip_norm)
            self._scaler.step(self._optims[bucket_id])
            if bucket_id == len(self._param_buckets) - 1:
                self._scaler.update()
        else:
            if self._grad_clip_norm > 0:
                torch.nn.utils.clip_grad_norm_(self._param_buckets[bucket_id], self._grad_clip_norm)
            self._optims[bucket_id].step()
        self._optims[bucket_id].zero_grad()

        if self._lr_schedulers[bucket_id] is not None:
            scheduler = cast(LRScheduler, self._lr_schedulers[bucket_id])
            scheduler.step()

        # launch the next communication after updating the weights
        if self._param_as_bucket_view:
            self._comm_blocks[bucket_id].copy_(self._param_blocks[bucket_id])
        else:
            torch._foreach_copy_(self._comm_buffers[bucket_id], self._param_buckets[bucket_id])

        edge = self._topo.get_edge(self._step + 1)
        weight = edge.weight
        self._comm_blocks[bucket_id].mul_((1 - weight) / (len(edge.ranks) - 1))

        self._comm_ops[bucket_id] = dist.all_reduce(
            self._comm_blocks[bucket_id], op=dist.ReduceOp.SUM, group=edge.group, async_op=True
        )

    @torch.no_grad()
    def _initialize_params(self):
        """Initialize the parameter buckets and communication buffers

        Raises:
            RuntimeError: Number/Order of elements in used parameters is different on different nodes
        """

        # verify the number of elements and the order of the parameters on different nodes are the same
        verify = [[(i, self._params[i].numel()) for i in self._traced_param_ids]]
        result = [[(0, 0)]] if self._rank != 0 else verify
        dist.broadcast_object_list(result, src=0)
        if not all([x == y for x, y in zip(verify[0], result[0])]):
            raise RuntimeError("Number/Order of elements in used parameters is different on different nodes")

        # remove the trace hooks
        for hook in self._trace_hooks:
            hook.remove()
        del self._trace_hooks

        # split the parameters into roughly equal-size buckets, and register hooks on the last parameter of each bucket
        start = 0
        size = 0
        for i in range(len(self._traced_param_ids)):
            size += (
                self._align(self._params[self._traced_param_ids[i]].numel())
                * self._params[self._traced_param_ids[i]].element_size()
            )
            if (size >= self._bucket_size) or (i == len(self._traced_param_ids) - 1):
                # register hooks on the last parameter of each bucket, passing the bucket id
                self._ddp_hooks.append(
                    self._params[self._traced_param_ids[i]].register_post_accumulate_grad_hook(
                        partial(lambda data, bucket_id: self._ddp_fn(data, bucket_id), bucket_id=len(self._ddp_hooks))
                    )
                )
                self._param_buckets.append([self._params[j] for j in self._traced_param_ids[start : i + 1]])
                param_names = [self._param_names[j] for j in self._traced_param_ids[start : i + 1]]

                # create optimizer and learning rate scheduler for parameters in each bucket
                self._optims.append(self._optim_fn(list(zip(param_names, self._param_buckets[-1]))))
                self._lr_schedulers.append(self._lr_schd_fn(self._optims[-1]) if self._lr_schd_fn is not None else None)
                size = 0
                start = i + 1

        size_dict = {}

        for i in range(len(self._param_buckets)):
            total_size = sum([self._align(p.numel()) for p in self._param_buckets[i]])

            # make sure the total size is unique for each bucket \
            # (not necessary, but make sure the communication operations are unique for each bucket with negligible overhead)
            while total_size in size_dict:
                total_size += 32
            size_dict[total_size] = True

            # create the communication buffer for each bucket
            comm_block = torch.zeros(
                total_size,
                device=self._param_buckets[i][0].device,
                requires_grad=False,
                dtype=self._param_buckets[i][0].dtype,
            )

            if self._param_as_bucket_view:
                # create contiguous blocks for each bucket, and let the parameters be views of the fragments of the block
                self._param_blocks.append(
                    torch.zeros(
                        total_size,
                        device=self._param_buckets[i][0].device,
                        requires_grad=True,
                        dtype=self._param_buckets[i][0].dtype,
                    )
                )
                start = 0
                for j in range(len(self._param_buckets[i])):
                    size = self._param_buckets[i][j].numel()
                    if (
                        (len(self._param_buckets[i][j].shape) == 4)
                        and self._param_buckets[i][j].is_contiguous(memory_format=torch.channels_last)
                        and (not self._param_buckets[i][j].is_contiguous())
                    ):
                        # permute the tensor to the channels_last format
                        self._param_blocks[-1].narrow(0, start, size).copy_(
                            self._param_buckets[i][j].permute(0, 2, 3, 1).view(-1)
                        )
                        self._param_buckets[i][j].data = (
                            self._param_blocks[-1]
                            .narrow(0, start, size)
                            .view(
                                (
                                    self._param_buckets[i][j].shape[0],
                                    self._param_buckets[i][j].shape[2],
                                    self._param_buckets[i][j].shape[3],
                                    self._param_buckets[i][j].shape[1],
                                )
                            )
                            .permute(0, 3, 1, 2)
                        )
                        assert self._param_buckets[i][j].is_contiguous(memory_format=torch.channels_last)
                        assert not self._param_buckets[i][j].is_contiguous()
                    else:
                        # otherwise, copy the tensor directly
                        assert self._param_buckets[i][j].is_contiguous()
                        self._param_blocks[-1].narrow(0, start, size).copy_(self._param_buckets[i][j].view(-1))
                        self._param_buckets[i][j].data = (
                            self._param_blocks[-1].narrow(0, start, size).view_as(self._param_buckets[i][j])
                        )
                    start += self._align(size)

            self._comm_blocks.append(comm_block)
            start = 0
            self._comm_buffers.append([])
            for j in range(len(self._param_buckets[i])):
                size = self._param_buckets[i][j].numel()
                if (
                    (len(self._param_buckets[i][j].shape) == 4)
                    and self._param_buckets[i][j].is_contiguous(memory_format=torch.channels_last)
                    and (not self._param_buckets[i][j].is_contiguous())
                ):
                    # permute the tensor to the channels_last format
                    self._comm_buffers[-1].append(
                        comm_block.narrow(0, start, size)
                        .view(
                            (
                                self._param_buckets[i][j].shape[0],
                                self._param_buckets[i][j].shape[2],
                                self._param_buckets[i][j].shape[3],
                                self._param_buckets[i][j].shape[1],
                            )
                        )
                        .permute(0, 3, 1, 2)
                    )
                else:
                    self._comm_buffers[-1].append(comm_block.narrow(0, start, size).view_as(self._param_buckets[i][j]))
                start += self._align(size)

                # attach the communication buffer to the parameter for "pre_average_hook" in the optimizer
                if hasattr(self._optims[i], "pre_average_hook"):
                    setattr(self._param_buckets[i][j], "comm_buffer", self._comm_buffers[-1][-1])

            # initialize the communication buffer with the initial parameters
            torch._foreach_copy_(self._comm_buffers[-1], self._param_buckets[i])

        self._comm_ops = [None] * len(self._param_buckets)

    def _align(self, size: int):
        """Align the size to 128-byte boundary"""
        return math.ceil(size / 32) * 32

    """Delegation functions"""

    def train(self, mode: bool = True):
        """Set the module in training mode

        Args:
            mode (bool, optional): Whether to set the module in training mode. Defaults to True.
        """
        self._model.train(mode)
        return self

    def eval(self):
        """Set the module in evaluation mode"""
        self._model.eval()
        return self

    def forward(self, *args, **kwargs):
        """Forward pass of the model"""
        # lazy initialization at the second iteration
        if (self._step == 1) and (not self._initialized):
            self._initialized = True
            # initialize the parameters and communication buffers
            self._initialize_params()

            # manually trigger the communications for the first iteration only
            with torch.no_grad():
                edge = self._topo.get_edge(self._step)
                weight = edge.weight
                for i in range(len(self._param_buckets)):
                    # optionally call the pre_average_hook for optimizers using the communication information
                    if hasattr(self._optims[i], "pre_average_hook"):
                        self._optims[i].pre_average_hook(edge, weight)  # type: ignore

                    # update parameters and launch the first communication
                    if self._scaler:
                        if self._grad_clip_norm > 0:
                            self._scaler.unscale_(self._optims[i])
                            torch.nn.utils.clip_grad_norm_(self._param_buckets[i], self._grad_clip_norm)
                        self._scaler.step(self._optims[i])
                        if i == len(self._param_buckets) - 1:
                            self._scaler.update()
                            # TODO: synchronize the scaler state across all workers?
                    else:
                        if self._grad_clip_norm > 0:
                            torch.nn.utils.clip_grad_norm_(self._param_buckets[i], self._grad_clip_norm)
                        self._optims[i].step()
                    self._optims[i].zero_grad()
                    if self._lr_schedulers[i] is not None:
                        scheduler = cast(LRScheduler, self._lr_schedulers[i])
                        scheduler.step()

                    # launch the first communication
                    if self._param_as_bucket_view:
                        self._comm_blocks[i].copy_(self._param_blocks[i])
                    else:
                        torch._foreach_copy_(self._comm_buffers[i], self._param_buckets[i])

                    self._comm_blocks[i].mul_((1 - weight) / (len(edge.ranks) - 1))
                    comm_op = dist.all_reduce(
                        self._comm_blocks[i], op=dist.ReduceOp.SUM, group=edge.group, async_op=True
                    )
                    self._comm_ops[i] = comm_op
                    # wait for the communication to finish to fully synchronize the workers
                    assert comm_op is not None
                    comm_op.wait()

        if self._model.training and (not self._is_grad_accum_enable):
            self._step += 1

        with torch.autograd.profiler.record_function("DecentralizedDataParallel.forward"):
            output = self._model(*args, **kwargs)
            return output

    def parameters(self, recurse: bool = True) -> Iterator[Parameter]:
        """Get the parameters of the model

        Args:
            recurse (bool, optional): Whether to get the parameters recursively. Defaults to True.

        Yields:
            Iterator[Parameter]: The iterator of the parameters
        """
        yield from self._model.parameters(recurse)

    def named_parameters(
        self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
    ) -> Iterator[Tuple[str, Parameter]]:
        """Get the named parameters of the model"""
        return super().named_parameters(prefix, recurse, remove_duplicate)

    """Utility functions"""

    @torch.no_grad()
    def global_avg(self):
        """Perform global average on the parameters (and buffers if sync_buffer_in_global_avg is True)
        The function is called at the end of the training loop to synchronize the parameters across all nodes for evaluation
        """
        for op in self._comm_ops:
            if op is not None:
                op.wait()
        self._comm_ops = [None for _ in range(len(self._param_buckets))]

        if self._param_as_bucket_view:
            torch._foreach_div_(self._param_blocks, self._world_size)
            for i in range(len(self._param_blocks)):
                dist.all_reduce(self._param_blocks[i], op=dist.ReduceOp.SUM)
        else:
            torch._foreach_div_([x.data for x in self._params], self._world_size)
            for x in self._params:
                dist.all_reduce(x.data, op=dist.ReduceOp.SUM)

        if self._sync_buffer_in_global_avg:
            # globally average the float buffers (e.g. running mean and variance in batch normalization)
            for x in self._model.buffers():
                if x.dtype in self.FLOAT_DTYPES:
                    dist.all_reduce(x.data, op=dist.ReduceOp.SUM)
                    x.data.div_(self._world_size)

set_accumulate_grad(enable=True)

Set the gradient accumulation mode

Parameters:

- enable (bool): Whether to accumulate the gradients. Defaults to True.
Source code in src/decent_dp/ddp.py
def set_accumulate_grad(self, enable: bool = True):
    """Set the gradient accumulation mode

    Args:
        enable (bool, optional): Whether to accumulate the gradients. Defaults to True.
    """
    self._is_grad_accum_enable = enable
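
In a training loop, this flag is typically raised for all but the last micro-batch of each accumulation window, so that the backward hooks skip the update and communication until the accumulated gradients are ready. A sketch (accum_steps, data_loader, and loss_fn are placeholders; ddp_model is the wrapped model):

accum_steps = 4  # placeholder: micro-batches per parameter update

for step, (x, y) in enumerate(data_loader):
    # True on all but the last micro-batch of the window, so the hooks return early
    ddp_model.set_accumulate_grad((step + 1) % accum_steps != 0)
    loss = loss_fn(ddp_model(x), y)
    loss.backward()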

train(mode=True)

Set the module in training mode

Parameters:

- mode (bool): Whether to set the module in training mode. Defaults to True.
Source code in src/decent_dp/ddp.py
def train(self, mode: bool = True):
    """Set the module in training mode

    Args:
        mode (bool, optional): Whether to set the module in training mode. Defaults to True.
    """
    self._model.train(mode)
    return self

eval()

Set the module in evaluation mode

Source code in src/decent_dp/ddp.py
def eval(self):
    """Set the module in evaluation mode"""
    self._model.eval()
    return self

forward(*args, **kwargs)

Forward pass of the model

Source code in src/decent_dp/ddp.py
def forward(self, *args, **kwargs):
    """Forward pass of the model"""
    # lazy initialization at the second iteration
    if (self._step == 1) and (not self._initialized):
        self._initialized = True
        # initialize the parameters and communication buffers
        self._initialize_params()

        # manually trigger the communications for the first iteration only
        with torch.no_grad():
            edge = self._topo.get_edge(self._step)
            weight = edge.weight
            for i in range(len(self._param_buckets)):
                # optionally call the pre_average_hook for optimizers using the communication information
                if hasattr(self._optims[i], "pre_average_hook"):
                    self._optims[i].pre_average_hook(edge, weight)  # type: ignore

                # update parameters and launch the first communication
                if self._scaler:
                    if self._grad_clip_norm > 0:
                        self._scaler.unscale_(self._optims[i])
                        torch.nn.utils.clip_grad_norm_(self._param_buckets[i], self._grad_clip_norm)
                    self._scaler.step(self._optims[i])
                    if i == len(self._param_buckets) - 1:
                        self._scaler.update()
                        # TODO: synchronize the scaler state across all workers?
                else:
                    if self._grad_clip_norm > 0:
                        torch.nn.utils.clip_grad_norm_(self._param_buckets[i], self._grad_clip_norm)
                    self._optims[i].step()
                self._optims[i].zero_grad()
                if self._lr_schedulers[i] is not None:
                    scheduler = cast(LRScheduler, self._lr_schedulers[i])
                    scheduler.step()

                # launch the first communication
                if self._param_as_bucket_view:
                    self._comm_blocks[i].copy_(self._param_blocks[i])
                else:
                    torch._foreach_copy_(self._comm_buffers[i], self._param_buckets[i])

                self._comm_blocks[i].mul_((1 - weight) / (len(edge.ranks) - 1))
                comm_op = dist.all_reduce(
                    self._comm_blocks[i], op=dist.ReduceOp.SUM, group=edge.group, async_op=True
                )
                self._comm_ops[i] = comm_op
                # wait for the communication to finish to fully synchronize the workers
                assert comm_op is not None
                comm_op.wait()

    if self._model.training and (not self._is_grad_accum_enable):
        self._step += 1

    with torch.autograd.profiler.record_function("DecentralizedDataParallel.forward"):
        output = self._model(*args, **kwargs)
        return output

parameters(recurse=True)

Get the parameters of the model

Parameters:

- recurse (bool): Whether to get the parameters recursively. Defaults to True.

Yields:

- Parameter: The model's parameters, one at a time.

Source code in src/decent_dp/ddp.py
def parameters(self, recurse: bool = True) -> Iterator[Parameter]:
    """Get the parameters of the model

    Args:
        recurse (bool, optional): Whether to get the parameters recursively. Defaults to True.

    Yields:
        Iterator[Parameter]: The iterator of the parameters
    """
    yield from self._model.parameters(recurse)

named_parameters(prefix='', recurse=True, remove_duplicate=True)

Get the named parameters of the model

Source code in src/decent_dp/ddp.py
def named_parameters(
    self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
) -> Iterator[Tuple[str, Parameter]]:
    """Get the named parameters of the model"""
    return super().named_parameters(prefix, recurse, remove_duplicate)

global_avg()

Perform a global average of the parameters (and of the float buffers if sync_buffer_in_global_avg is True). This is typically called at the end of the training loop to synchronize the parameters across all workers for evaluation.

Source code in src/decent_dp/ddp.py
@torch.no_grad()
def global_avg(self):
    """Perform global average on the parameters (and buffers if sync_buffer_in_global_avg is True)
    The function is called at the end of the training loop to synchronize the parameters across all nodes for evaluation
    """
    for op in self._comm_ops:
        if op is not None:
            op.wait()
    self._comm_ops = [None for _ in range(len(self._param_buckets))]

    if self._param_as_bucket_view:
        torch._foreach_div_(self._param_blocks, self._world_size)
        for i in range(len(self._param_blocks)):
            dist.all_reduce(self._param_blocks[i], op=dist.ReduceOp.SUM)
    else:
        torch._foreach_div_([x.data for x in self._params], self._world_size)
        for x in self._params:
            dist.all_reduce(x.data, op=dist.ReduceOp.SUM)

    if self._sync_buffer_in_global_avg:
        # globally average the float buffers (e.g. running mean and variance in batch normalization)
        for x in self._model.buffers():
            if x.dtype in self.FLOAT_DTYPES:
                dist.all_reduce(x.data, op=dist.ReduceOp.SUM)
                x.data.div_(self._world_size)
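
For example, at the end of training (or before a periodic evaluation), a wrapped model (ddp_model below) is typically synchronized as:

ddp_model.global_avg()   # average parameters (and float buffers, if enabled) across all workers
ddp_model.eval()
# ... run evaluation on the synchronized model ...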

decent_dp.optim

This module provides optimizer functions and custom optimizers designed specifically for decentralized training scenarios.

AccumAdamW

Bases: Optimizer

AccumAdamW optimizer

Parameters:

- params (Any): Parameter list or parameter groups. Required.
- lr (float): Base learning rate. Defaults to 1e-3.
- betas (Tuple[float, float]): Coefficients beta1 and beta2. Defaults to (0.9, 0.999).
- eps (float): Numerical stability term epsilon. Defaults to 1e-8.
- weight_decay (float): Weight decay. Defaults to 0.
- accum_iter (int): Number of accumulation steps; should scale with the number of workers. Defaults to 4.
Source code in src/decent_dp/optim.py
class AccumAdamW(torch.optim.Optimizer):
    """AccumAdamW optimizer

    Args:
        params (Any): parameters list or groups
        lr (float, optional): base learning rate. Defaults to 1e-3.
        betas (Tuple[float, float], optional): beta1 and beta2. Defaults to (0.9, 0.999).
        eps (float, optional): epsilon. Defaults to 1e-8.
        weight_decay (float, optional): weight decay. Defaults to 0.
        accum_iter (int, optional): number of accumulation steps. Defaults to 4. should be scaling up with the number of workers.
    """

    def __init__(
        self,
        params: Any,
        lr: float = 1e-3,
        betas: Tuple[float, float] = (0.9, 0.999),
        eps: float = 1e-8,
        weight_decay: float = 0,
        accum_iter: int = 4,
    ):
        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, accum_iter=accum_iter)
        super().__init__(params, defaults)

    def _init_group(self, group, params_with_grad, grads, exp_avgs, exp_avg_sqs, accum_grads, state_steps):
        for p in group["params"]:
            if p.grad is not None:
                params_with_grad.append(p)
                grads.append(p.grad)
                state = self.state[p]
                if len(state) == 0:
                    state["step"] = torch.tensor(0, dtype=torch.int64)
                    state["exp_avg"] = torch.zeros_like(p, memory_format=torch.preserve_format)
                    state["exp_avg_sq"] = torch.zeros_like(p, memory_format=torch.preserve_format)
                    state["accum_grad"] = torch.zeros_like(p, memory_format=torch.preserve_format)
                exp_avgs.append(state["exp_avg"])
                exp_avg_sqs.append(state["exp_avg_sq"])
                accum_grads.append(state["accum_grad"])
                state_steps.append(state["step"])

    @torch.no_grad()
    def step(self, closure=None):  # type: ignore
        self._cuda_graph_capture_health_check()
        assert closure is None

        for group in self.param_groups:
            params_with_grad = []
            grads = []
            exp_avgs = []
            exp_avg_sqs = []
            accum_grads = []
            state_steps = []
            beta1, beta2 = group["betas"]

            self._init_group(group, params_with_grad, grads, exp_avgs, exp_avg_sqs, accum_grads, state_steps)

            if len(state_steps) == 0:
                continue

            accum_adamw_foreach(
                params_with_grad,
                grads,
                exp_avgs,
                exp_avg_sqs,
                accum_grads,
                state_steps,
                beta1,
                beta2,
                group["lr"],
                group["weight_decay"],
                group["eps"],
                group["accum_iter"],
            )

AccumAdam

Bases: Optimizer

AccumAdam optimizer

Parameters:

- params (Any): Parameter list or parameter groups. Required.
- lr (float): Base learning rate. Defaults to 1e-3.
- betas (Tuple[float, float]): Coefficients beta1 and beta2. Defaults to (0.9, 0.999).
- eps (float): Numerical stability term epsilon. Defaults to 1e-8.
- weight_decay (float): Weight decay. Defaults to 0.0.
- accum_iter (int): Number of accumulation steps; should scale with the number of workers. Defaults to 4.
Source code in src/decent_dp/optim.py
class AccumAdam(torch.optim.Optimizer):
    """AccumAdamW optimizer

    Args:
        params (Any): parameters list or groups
        lr (float, optional): base learning rate. Defaults to 1e-3.
        betas (Tuple[float, float], optional): beta1 and beta2. Defaults to (0.9, 0.999).
        eps (float, optional): epsilon. Defaults to 1e-8.
        weight_decay (float, optional): weight decay. Defaults to 0.
        accum_iter (int, optional): number of accumulation steps. Defaults to 4. should be scaling up with the number of workers.
    """

    def __init__(
        self,
        params: Any,
        lr: float = 1e-3,
        betas: Tuple[float, float] = (0.9, 0.999),
        eps: float = 1e-8,
        weight_decay: float = 0.0,
        accum_iter: int = 4,
    ):
        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, accum_iter=accum_iter)
        super().__init__(params, defaults)

    def _init_group(self, group, params_with_grad, grads, exp_avgs, exp_avg_sqs, accum_grads, state_steps):
        for p in group["params"]:
            if p.grad is not None:
                params_with_grad.append(p)
                grads.append(p.grad)
                state = self.state[p]
                if len(state) == 0:
                    state["step"] = torch.tensor(0, dtype=torch.int64)
                    state["exp_avg"] = torch.zeros_like(p, memory_format=torch.preserve_format)
                    state["exp_avg_sq"] = torch.zeros_like(p, memory_format=torch.preserve_format)
                    state["accum_grad"] = torch.zeros_like(p, memory_format=torch.preserve_format)
                exp_avgs.append(state["exp_avg"])
                exp_avg_sqs.append(state["exp_avg_sq"])
                accum_grads.append(state["accum_grad"])
                state_steps.append(state["step"])

    @torch.no_grad()
    def step(self, closure=None):  # type: ignore
        self._cuda_graph_capture_health_check()
        assert closure is None, "Closure is not supported"

        for group in self.param_groups:
            params_with_grad = []
            grads = []
            exp_avgs = []
            exp_avg_sqs = []
            accum_grads = []
            state_steps = []
            beta1, beta2 = group["betas"]

            self._init_group(group, params_with_grad, grads, exp_avgs, exp_avg_sqs, accum_grads, state_steps)

            if len(state_steps) == 0:
                continue

            accum_adam_foreach(
                params_with_grad,
                grads,
                exp_avgs,
                exp_avg_sqs,
                accum_grads,
                state_steps,
                beta1,
                beta2,
                group["lr"],
                group["weight_decay"],
                group["eps"],
                group["accum_iter"],
            )

optim_fn_adam(params, lr=0.001, beta1=0.9, beta2=0.999, weight_decay=1.0 / 32768, eps=1e-08)

An example of a function that creates an Adam optimizer with the given parameters and their names. To change the hyperparameters of the optimizer, you can wrap it with functools.partial and pass the new values.

Returns:

- Optimizer: An Adam optimizer.

Source code in src/decent_dp/optim.py
def optim_fn_adam(
    params: List[Tuple[str, Tensor]],
    lr: float = 1e-3,
    beta1: float = 0.9,
    beta2: float = 0.999,
    weight_decay: float = 1.0 / 32768,
    eps: float = 1e-8,
) -> Optimizer:
    """An example of a function that creates an Adam optimizer with the given parameters and their names.
        To change the hyperparameters of the optimizer, you can wrap it with `functools.partial` and pass the new values.

    Returns:
        Optimizer: an Adam optimizer
    """
    return torch.optim.Adam(_get_param_groups(params, weight_decay), lr=lr, betas=(beta1, beta2), eps=eps)
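
For example, to override the defaults, the factory can be bound with functools.partial before being passed as optim_fn (the values below are arbitrary):

from functools import partial

from decent_dp.optim import optim_fn_adam

custom_optim_fn = partial(optim_fn_adam, lr=3e-4, weight_decay=0.01)
# custom_optim_fn still matches OPTIM_FN_TYPE and can be passed to DecentralizedDataParallel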

optim_fn_adamw(params, lr=0.001, beta1=0.9, beta2=0.999, weight_decay=0.1, eps=1e-08)

An example of a function that creates an AdamW optimizer with the given parameters and their names. To change the hyperparameters of the optimizer, you can wrap it with functools.partial and pass the new values.

Returns:

- Optimizer: An AdamW optimizer.

Source code in src/decent_dp/optim.py
def optim_fn_adamw(
    params: List[Tuple[str, Tensor]],
    lr: float = 1e-3,
    beta1: float = 0.9,
    beta2: float = 0.999,
    weight_decay: float = 0.1,
    eps: float = 1e-8,
) -> Optimizer:
    """An example of a function that creates an AdamW optimizer with the given parameters and their names.
        To change the hyperparameters of the optimizer, you can wrap it with `functools.partial` and pass the new values.

    Returns:
        Optimizer: an AdamW optimizer
    """
    return torch.optim.AdamW(_get_param_groups(params, weight_decay), lr=lr, betas=(beta1, beta2), eps=eps)

optim_fn_accum_adam(params, lr=0.001, beta1=0.9, beta2=0.999, eps=1e-08, weight_decay=1.0 / 32768, accum_iter=4)

An example of a function that creates an AccumAdam optimizer with the given parameters and their names. To change the hyperparameters of the optimizer, you can wrap it with functools.partial and pass the new values.

Returns:

- Optimizer: An AccumAdam optimizer.

Source code in src/decent_dp/optim.py
def optim_fn_accum_adam(
    params: List[Tuple[str, Tensor]],
    lr: float = 1e-3,
    beta1: float = 0.9,
    beta2: float = 0.999,
    eps: float = 1e-8,
    weight_decay: float = 1.0 / 32768,
    accum_iter: int = 4,
) -> Optimizer:
    """An example of a function that creates an AccumAdam optimizer with the given parameters and their names.
        To change the hyperparameters of the optimizer, you can wrap it with `functools.partial` and pass the new values.

    Returns:
        Optimizer: an AccumAdam optimizer
    """
    return AccumAdam(
        _get_param_groups(params, weight_decay), lr=lr, betas=(beta1, beta2), eps=eps, accum_iter=accum_iter
    )

optim_fn_accum_adamw(params, lr=0.001, beta1=0.9, beta2=0.999, eps=1e-08, weight_decay=0.1, accum_iter=4)

An example of a function that creates an AccumAdamW optimizer with the given parameters and their names. To change the hyperparameters of the optimizer, you can wrap it with functools.partial and pass the new values.

Returns:

- Optimizer: An AccumAdamW optimizer.

Source code in src/decent_dp/optim.py
def optim_fn_accum_adamw(
    params: List[Tuple[str, Tensor]],
    lr: float = 1e-3,
    beta1: float = 0.9,
    beta2: float = 0.999,
    eps: float = 1e-8,
    weight_decay: float = 0.1,
    accum_iter: int = 4,
) -> Optimizer:
    """An example of a function that creates an AccumAdamW optimizer with the given parameters and their names.
        To change the hyperparameters of the optimizer, you can wrap it with `functools.partial` and pass the new values.

    Returns:
        Optimizer: an AccumAdamW optimizer
    """
    return AccumAdamW(
        _get_param_groups(params, weight_decay), lr=lr, betas=(beta1, beta2), eps=eps, accum_iter=accum_iter
    )

lr_scheduler_fn_cosine_with_warmup(optimizer, t_max, t_warmup, cosine_eta_min=1e-06, warmup_decay=0.01)

An example of a function that creates a learning rate scheduler that combines a warmup and a cosine annealing schedule.

Returns:

- LRScheduler: A learning rate scheduler with a linear warmup followed by cosine annealing.

Source code in src/decent_dp/optim.py
def lr_scheduler_fn_cosine_with_warmup(
    optimizer: Optimizer, t_max: int, t_warmup: int, cosine_eta_min: float = 1e-6, warmup_decay: float = 0.01
) -> LRScheduler:
    """An example of a function that creates a learning rate scheduler that combines a warmup and a cosine annealing schedule.

    Returns:
        LRScheduler: a learning rate scheduler with the linear warmup followed by the cosine annealing
    """
    main_lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=t_max, eta_min=cosine_eta_min)
    warmup_lr_scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=warmup_decay, total_iters=t_warmup)
    return torch.optim.lr_scheduler.SequentialLR(
        optimizer, schedulers=[warmup_lr_scheduler, main_lr_scheduler], milestones=[t_warmup]
    )
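
Since LR_SCHEDULER_FN_TYPE only receives the optimizer, the remaining arguments are typically bound with functools.partial (the step counts below are arbitrary):

from functools import partial

from decent_dp.optim import lr_scheduler_fn_cosine_with_warmup

custom_lr_fn = partial(lr_scheduler_fn_cosine_with_warmup, t_max=10000, t_warmup=500)
# custom_lr_fn matches LR_SCHEDULER_FN_TYPE and can be passed as lr_scheduler_fn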

accum_adamw_foreach(params, grads, exp_avgs, exp_avg_sqs, accum_grads, state_steps, beta1, beta2, lr, weight_decay, eps, accum_iter)

Optimized implementation of the AccumAdamW update using torch._foreach operations (TODO: fused kernel).

Source code in src/decent_dp/optim.py
def accum_adamw_foreach(
    params: List[torch.Tensor],
    grads: List[torch.Tensor],
    exp_avgs: List[torch.Tensor],
    exp_avg_sqs: List[torch.Tensor],
    accum_grads: List[torch.Tensor],
    state_steps: List[torch.Tensor],
    beta1: float,
    beta2: float,
    lr: Union[float, torch.Tensor],
    weight_decay: float,
    eps: float,
    accum_iter: int,
):
    """Optimized version of AccumAdamW optimizer using torch._foreach
    TODO: fused kernel
    """

    torch._foreach_add_(state_steps, 1)
    if weight_decay != 0:
        torch._foreach_mul_(params, 1 - lr * weight_decay)

    step = state_steps[0].item()
    torch._foreach_add_(accum_grads, grads, alpha=1.0 / accum_iter)

    _exp_avgs = torch._foreach_add(exp_avgs, grads, alpha=1 - beta1)
    _exp_avg_sqs = torch._foreach_addcmul(exp_avg_sqs, grads, grads, value=1 - beta2)

    bias_correction1 = 1 - beta1 ** ((step + accum_iter - 1) // accum_iter)
    bias_correction2 = 1 - beta2 ** ((step + accum_iter - 1) // accum_iter)
    step_size = lr / bias_correction1
    bias_correction2_sqrt = math.sqrt(bias_correction2)

    torch._foreach_sqrt_(_exp_avg_sqs)
    torch._foreach_div_(_exp_avg_sqs, bias_correction2_sqrt)
    torch._foreach_add_(_exp_avg_sqs, eps)
    torch._foreach_addcdiv_(params, _exp_avgs, _exp_avg_sqs, value=-step_size)  # type: ignore

    if step % accum_iter == 0:
        torch._foreach_add_(exp_avgs, accum_grads, alpha=1 - beta1)
        torch._foreach_mul_(exp_avgs, beta1)
        torch._foreach_addcmul_(exp_avg_sqs, accum_grads, accum_grads, value=1 - beta2)
        torch._foreach_mul_(exp_avg_sqs, beta2)
        torch._foreach_zero_(accum_grads)

accum_adam_foreach(params, grads, exp_avgs, exp_avg_sqs, accum_grads, state_steps, beta1, beta2, lr, weight_decay, eps, accum_iter)

Optimized implementation of the AccumAdam update using torch._foreach operations (TODO: write a fused kernel).

Source code in src/decent_dp/optim.py
def accum_adam_foreach(
    params: List[torch.Tensor],
    grads: List[torch.Tensor],
    exp_avgs: List[torch.Tensor],
    exp_avg_sqs: List[torch.Tensor],
    accum_grads: List[torch.Tensor],
    state_steps: List[torch.Tensor],
    beta1: float,
    beta2: float,
    lr: Union[float, torch.Tensor],
    weight_decay: float,
    eps: float,
    accum_iter: int,
):
    """Optimized version of AccumAdam optimizer using torch._foreach
    TODO: write a fused kernel for this
    """
    torch._foreach_add_(state_steps, 1)
    if weight_decay != 0:
        torch._foreach_add_(grads, params, alpha=weight_decay)

    step = state_steps[0].item()
    torch._foreach_add_(accum_grads, grads, alpha=1.0 / accum_iter)

    _exp_avgs = torch._foreach_add(exp_avgs, grads, alpha=1 - beta1)
    _exp_avg_sqs = torch._foreach_addcmul(exp_avg_sqs, grads, grads, value=1 - beta2)

    bias_correction1 = 1 - beta1 ** ((step + accum_iter - 1) // accum_iter)
    bias_correction2 = 1 - beta2 ** ((step + accum_iter - 1) // accum_iter)
    step_size = lr / bias_correction1
    bias_correction2_sqrt = math.sqrt(bias_correction2)

    torch._foreach_sqrt_(_exp_avg_sqs)
    torch._foreach_div_(_exp_avg_sqs, bias_correction2_sqrt)
    torch._foreach_add_(_exp_avg_sqs, eps)
    torch._foreach_addcdiv_(params, _exp_avgs, _exp_avg_sqs, value=-step_size)  # type: ignore

    if step % accum_iter == 0:
        torch._foreach_add_(exp_avgs, accum_grads, alpha=1 - beta1)
        torch._foreach_mul_(exp_avgs, beta1)
        torch._foreach_addcmul_(exp_avg_sqs, accum_grads, accum_grads, value=1 - beta2)
        torch._foreach_mul_(exp_avg_sqs, beta2)
        torch._foreach_zero_(accum_grads)

decent_dp.topo

This module contains the topology classes that define communication patterns between workers in decentralized training.

Edge dataclass

Edge class for defining communication patterns among workers.

Weight defines the fraction of the update that each worker keeps for itself. For example, if the weight is 0.3, the worker keeps 30% of its own parameters and takes the remaining 70%, in equal shares, from the other workers in the edge: $$x_i \leftarrow w \cdot x_i + \frac{1 - w}{|\text{ranks}| - 1}\sum_{j \in \text{ranks},\, j \neq i} x_j$$ The weight should be between 0 and 1 for convergence.

Parameters:

- ranks (List[int]): Ranks of the workers that communicate over this edge. Required.
- weight (float): Mixing weight each worker keeps for its own parameters. Required.
- group (Optional[ProcessGroup]): Process group for the edge, created by the Topology class. Defaults to None.
Source code in src/decent_dp/topo.py
@dataclass
class Edge:
    """Edge class for defining communication patterns among workers.

    Weight defines the fraction of the message that each worker keeps. For example, if the weight is 0.3, \
        then the worker keeps 30% of its message and shares 70% with other workers. \
        $$x_i = w \\cdot x_i + \\frac{1}{|\\text{ranks}|}\\sum_{j \\in \\text{ranks}} x_j * (1 - w)$$. \
        The weight should be between 0 and 1 for convergence.

    Args:
        ranks (List[int]): List of ranks of workers that communicate in this edge
        weight (float): Weight that each worker keeps for its own parameters in the edge
        group (Optional[ProcessGroup]): Process group for the edge, which will be created by Topology class
    """

    ranks: List[int]
    weight: float
    group: Optional[ProcessGroup] = None
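
As a concrete check of the mixing rule above, take an edge with two ranks and w = 0.5 (the setting used by the ring and one-peer exponential topologies):

$$x_i \leftarrow 0.5 \cdot x_i + \frac{1 - 0.5}{2 - 1}\, x_j = \tfrac{1}{2}(x_i + x_j)$$

so the two workers simply average their parameters; with the complete topology and w = 1/n, the rule reduces to the exact global average.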

Topology

Source code in src/decent_dp/topo.py
class Topology:
    def __init__(self, local_world_size):
        """Topology class for defining communication patterns between workers.

        The class is responsible for creating process groups for each edge in the topology.
        The edges are defined as a list of lists of Edge objects, where each list of edges \
            corresponds to one iteration of the communication pattern.
        The topology is defined by implementing the _get_topo_edges method, which should return \
            a list of lists of Edge objects. When creating a new topology, the method should be \
            implemented to return the edges for the topology. Usable variables are `self._world_size`, \
            `self._local_world_size`, and `self._n_nodes` for the number of processes, processes per node, \
            and number of nodes, respectively.
        A valid topology is one where each node participates in exactly one communication in each iteration.

        Args:
            local_world_size (int): Number of processes in each node (added as argument for some testing \
                purposes, should be set as the environment variable LOCAL_WORLD_SIZE for normal cases)
        """

        assert dist.is_available() and dist.is_initialized(), "Distributed environment is not initialized"
        self._rank: int = dist.get_rank()
        self._world_size = dist.get_world_size()
        self._local_world_size = local_world_size
        assert self._world_size % local_world_size == 0, (
            f"World size must be divisible by local world size, \
            but {self._world_size} is not divisible by {local_world_size}"
        )
        self._n_nodes = self._world_size // local_world_size
        self._registry: Dict[str, ProcessGroup] = {}
        self._edges: List[Edge] = []
        self._create_edges()

    def _create_edges(self):
        """Create process groups for each "edge" (or group) in the topology"""
        all_edges = self._get_topo_edges()
        self._validate_edges(all_edges)

        # Create default group
        all_ranks = [i for i in range(self._world_size)]
        self._registry["all"] = cast(ProcessGroup, dist.new_group(all_ranks))

        for idx in range(len(all_edges)):
            for edge in all_edges[idx]:
                identifier = str(edge.ranks)
                if identifier not in self._registry:
                    self._registry[identifier] = cast(ProcessGroup, dist.new_group(edge.ranks))
                edge.group = self._registry[identifier]

        for idx in range(len(all_edges)):
            for edge in all_edges[idx]:
                if self._rank in edge.ranks:
                    self._edges.append(edge)
                    break

    def _validate_edges(self, edges: List[List[Edge]]):
        """Verify that the topology is valid. A valid topology is one where each \
            node participates in exactly communication in each iteration

        Args:
            edges (List[List[Edge]]): List of edges for each iteration
        """
        for idx in range(len(edges)):
            used = [False] * self._world_size
            for edge in edges[idx]:
                edge.ranks.sort()
                for rank in edge.ranks:
                    if used[rank]:
                        logger.error(f"Topology is not valid, node {rank} is used more than once in one iteration")
                        raise ValueError()
                    used[rank] = True
            if not all(used):
                logger.error("Topology is not valid, some nodes are not involved in an edge in one iteration")
                raise ValueError()

    def get_edge(self, step: int) -> Edge:
        """Get the edge for the given iteration"""
        return self._edges[step % len(self._edges)]

    def _get_topo_edges(self) -> List[List[Edge]]:
        raise NotImplementedError()

__init__(local_world_size)

Topology class for defining communication patterns between workers.

The class is responsible for creating the process groups for each edge in the topology. The edges are defined as a list of lists of Edge objects, where each inner list corresponds to one iteration of the communication pattern. A new topology is defined by implementing the _get_topo_edges method to return these edges; the attributes self._world_size, self._local_world_size, and self._n_nodes (the number of processes, processes per node, and nodes, respectively) are available for this purpose. A valid topology is one where each worker participates in exactly one communication in each iteration. A sketch of a custom topology is given after the built-in topologies below.

Parameters:

- local_world_size (int): Number of processes on each node (exposed as an argument for testing; normally taken from the LOCAL_WORLD_SIZE environment variable). Required.
Source code in src/decent_dp/topo.py
def __init__(self, local_world_size):
    """Topology class for defining communication patterns between workers.

    The class is responsible for creating process groups for each edge in the topology.
    The edges are defined as a list of lists of Edge objects, where each list of edges \
        corresponds to one iteration of the communication pattern.
    The topology is defined by implementing the _get_topo_edges method, which should return \
        a list of lists of Edge objects. When creating a new topology, the method should be \
        implemented to return the edges for the topology. Usable variables are `self._world_size`, \
        `self._local_world_size`, and `self._n_nodes` for the number of processes, processes per node, \
        and number of nodes, respectively.
    A valid topology is one where each node participates in exactly one communication in each iteration.

    Args:
        local_world_size (int): Number of processes in each node (added as argument for some testing \
            purposes, should be set as the environment variable LOCAL_WORLD_SIZE for normal cases)
    """

    assert dist.is_available() and dist.is_initialized(), "Distributed environment is not initialized"
    self._rank: int = dist.get_rank()
    self._world_size = dist.get_world_size()
    self._local_world_size = local_world_size
    assert self._world_size % local_world_size == 0, (
        f"World size must be divisible by local world size, \
        but {self._world_size} is not divisible by {local_world_size}"
    )
    self._n_nodes = self._world_size // local_world_size
    self._registry: Dict[str, ProcessGroup] = {}
    self._edges: List[Edge] = []
    self._create_edges()

get_edge(step)

Get the edge for the given iteration

Source code in src/decent_dp/topo.py
def get_edge(self, step: int) -> Edge:
    """Get the edge for the given iteration"""
    return self._edges[step % len(self._edges)]

CompleteTopology

Bases: Topology

Complete topology where each node communicates with all other nodes. The weights are 1/n.

Source code in src/decent_dp/topo.py
@TopologyReg.register("complete")
class CompleteTopology(Topology):
    """Complete topology where each node communicates with all other nodes. The weights are 1/n."""

    def _get_topo_edges(self) -> List[List[Edge]]:
        return [
            [
                Edge(
                    ranks=list(range(self._world_size)),
                    weight=1.0 / self._world_size,
                )
            ]
        ]

RingTopology

Bases: Topology

One-peer ring topology where each node communicates with either its left or its right neighbor (by index), alternating between iterations. The weight is 0.5.

Source code in src/decent_dp/topo.py
@TopologyReg.register("ring")
class RingTopology(Topology):
    """One-peer ring topology where each node communicates with one of its left and right \
        neighbors (by index) in each iteration. The weights are 0.5 for each neighbor.
    """

    def _get_topo_edges(self) -> List[List[Edge]]:
        if self._world_size % 2 != 0:
            logger.error("Ring topology is not supported for odd world size")
            raise ValueError()

        edges = [[], []]
        # Odd iterations
        for i in range(0, self._world_size, 2):
            edges[0].append(Edge(ranks=sorted([i, (i + 1) % self._world_size]), weight=0.5))
        # Even iterations
        for i in range(0, self._world_size, 2):
            edges[1].append(Edge(ranks=sorted([i, (i - 1 + self._world_size) % self._world_size]), weight=0.5))
        return edges

OnePeerExpTopology

Bases: Topology

One-peer exponential topology.

Source code in src/decent_dp/topo.py
@TopologyReg.register("one-peer-exp")
class OnePeerExpTopology(Topology):
    """One-peer exponential topology."""

    def _get_topo_edges(self) -> List[List[Edge]]:
        rounds = round(math.log2(self._world_size))
        if self._world_size != 2**rounds:
            logger.error("Exponential topology is only supported for 2^x world size")
            raise ValueError()

        edges = []
        for i in range(rounds):
            edges.append([])
            used = [False] * self._world_size
            for j in range(self._world_size):
                if not used[j]:
                    used[j] = True
                    used[(j + 2**i) % self._world_size] = True
                    edges[i].append(Edge(ranks=sorted([j, (j + 2**i) % self._world_size]), weight=0.5))
        return edges
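
Putting the pieces together, a custom topology can be registered by subclassing Topology and implementing _get_topo_edges. A minimal sketch, assuming four workers and that Edge, Topology, and TopologyReg are importable from decent_dp.topo (the topology name "pairs" is arbitrary):

from typing import List

from decent_dp.topo import Edge, Topology, TopologyReg


@TopologyReg.register("pairs")
class PairsTopology(Topology):
    """Alternate between the pairings {0,1},{2,3} and {0,3},{1,2} (assumes a world size of 4)."""

    def _get_topo_edges(self) -> List[List[Edge]]:
        assert self._world_size == 4, "this sketch assumes exactly 4 workers"
        return [
            [Edge(ranks=[0, 1], weight=0.5), Edge(ranks=[2, 3], weight=0.5)],
            [Edge(ranks=[0, 3], weight=0.5), Edge(ranks=[1, 2], weight=0.5)],
        ]

The registered name can then be passed as the topology argument of DecentralizedDataParallel. Each rank appears exactly once per iteration, so the topology passes _validate_edges.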

decent_dp.utils

Utility functions for setting up and managing the distributed training environment.

initialize_dist()

A utility function to initialize the distributed environment

Returns:

- Tuple[int, int]: The rank and the world size.

Source code in src/decent_dp/utils.py
def initialize_dist() -> Tuple[int, int]:
    """A utility function to initialize the distributed environment

    Returns:
        Tuple[int, int]: rank and world size
    """

    local_world_size = int(os.environ["LOCAL_WORLD_SIZE"])
    local_rank = int(os.environ["LOCAL_RANK"])
    gpus = os.environ.get("CUDA_VISIBLE_DEVICES", "")
    if gpus:
        if not (len(gpus.split(",")) == int(local_world_size)):
            logger.error(
                f"LOCAL_WORLD_SIZE and CUDA_VISIBLE_DEVICES are not consistent, \
                         {local_world_size} vs {len(gpus.split(','))}"
            )
            raise ValueError()
        os.environ["CUDA_VISIBLE_DEVICES"] = gpus.split(",")[local_rank]
        dist.init_process_group(backend="nccl")
    else:
        dist.init_process_group(backend="gloo")
    return dist.get_rank(), dist.get_world_size()
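
A typical entry point, assuming the script is launched with torchrun so that LOCAL_WORLD_SIZE and LOCAL_RANK are set (train.py is a placeholder name):

# launched e.g. with: torchrun --nproc-per-node=4 train.py
from decent_dp.utils import initialize_dist

rank, world_size = initialize_dist()
print(f"worker {rank} of {world_size} is ready")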