# Custom Optimizers Tutorial
Decent-DP provides specialized optimizers designed for decentralized training scenarios, particularly with gradient accumulation. This tutorial explains how to use the built-in optimizers and create your own custom optimizer functions.
## Built-in Optimizers
### Standard Optimizers
Decent-DP includes wrapper functions for standard PyTorch optimizers:
- `optim_fn_adam`: Adam optimizer with parameter grouping
- `optim_fn_adamw`: AdamW optimizer with parameter grouping
These functions automatically handle parameter grouping for weight decay:
```python
from functools import partial

from decent_dp.ddp import DecentralizedDataParallel as DecentDP
from decent_dp.optim import optim_fn_adamw

# Basic usage
ddp_model = DecentDP(model, optim_fn=optim_fn_adamw)

# Customized hyperparameters via functools.partial
custom_adamw = partial(
    optim_fn_adamw,
    lr=0.001,
    beta1=0.9,
    beta2=0.999,
    weight_decay=0.01,
    eps=1e-8
)
ddp_model = DecentDP(model, optim_fn=custom_adamw)
```
<!-- ### Accumulated Gradient Optimizers
For scenarios where you want to accumulate gradients over multiple steps before updating:
optim_fn_accum_adam: Adam optimizer with gradient accumulationoptim_fn_accum_adamw: AdamW optimizer with gradient accumulation
These optimizers are designed to work with Decent-DP's gradient accumulation feature:
from decent_dp.optim import optim_fn_accum_adamw
# With gradient accumulation
ddp_model = DecentDP(model, optim_fn=optim_fn_accum_adamw)
# Enable gradient accumulation in training loop
ddp_model.set_accumulate_grad(True)
for i in range(accumulation_steps):
output = ddp_model(data)
loss = criterion(output, target) / accumulation_steps
loss.backward()
# Disable gradient accumulation and perform update
ddp_model.set_accumulate_grad(False)
``` -->
## Creating Custom Optimizer Functions
Optimizer functions in Decent-DP must follow a specific signature:
```python
from typing import List, Tuple

import torch
from torch import Tensor
from torch.optim import Optimizer


def optimizer_function(params: List[Tuple[str, Tensor]]) -> Optimizer:
    """Create an optimizer for the given parameters.

    Args:
        params: List of (parameter_name, parameter_tensor) tuples

    Returns:
        torch.optim.Optimizer: Configured optimizer instance
    """
    # Extract the parameter tensors from the (name, tensor) tuples
    param_tensors = [p for _, p in params]
    # Create and return the optimizer (placeholder optimizer and hyperparameters)
    return torch.optim.YourOptimizer(param_tensors, your_hyperparameters)
```

### Example: Custom SGD Optimizer
```python
import torch
from torch.optim import Optimizer
from decent_dp.ddp import DecentralizedDataParallel as DecentDP


def optim_fn_sgd(params, lr=0.01, momentum=0.9, weight_decay=0.0) -> Optimizer:
    """Custom SGD optimizer function."""
    # Group parameters by weight decay (similar to the built-in functions):
    # batch-norm and bias parameters are excluded from weight decay.
    params_no_decay = [x for n, x in params if ("bn" in n) or ("bias" in n)]
    params_decay = [x for n, x in params if not (("bn" in n) or ("bias" in n))]
    param_groups = [
        {"params": params_decay, "weight_decay": weight_decay},
        {"params": params_no_decay, "weight_decay": 0.0},
    ]
    return torch.optim.SGD(param_groups, lr=lr, momentum=momentum)

# Use the custom optimizer function
ddp_model = DecentDP(model, optim_fn=optim_fn_sgd)
```
### Example: Custom Adam with Different Parameter Grouping
```python
def optim_fn_custom_adam(params, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
    """Custom Adam optimizer with different parameter grouping logic."""
    # Group parameters based on their names.
    # Note: this grouping assumes every parameter name contains either
    # "encoder" or "decoder".
    encoder_params = [x for n, x in params if "encoder" in n]
    decoder_params = [x for n, x in params if "decoder" in n]
    # Create parameter groups with different learning rates
    param_groups = [
        {"params": encoder_params, "lr": lr * 0.1},  # Encoder with lower LR
        {"params": decoder_params, "lr": lr},        # Decoder with normal LR
    ]
    return torch.optim.Adam(param_groups, betas=(beta1, beta2), eps=eps)

# Use the custom optimizer function
ddp_model = DecentDP(model, optim_fn=optim_fn_custom_adam)
```
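Parameters whose names match neither "encoder" nor "decoder" would be silently dropped from the optimizer above; a catch-all group avoids that. A minimal sketch (the function name and the third group's learning rate are illustrative choices, not part of Decent-DP):

```python
import torch


def optim_fn_custom_adam_full(params, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
    """Variant of the custom Adam function that covers every parameter."""
    encoder_params = [x for n, x in params if "encoder" in n]
    decoder_params = [x for n, x in params if "decoder" in n]
    # Catch-all group for parameters that belong to neither the encoder nor the decoder
    other_params = [x for n, x in params if ("encoder" not in n) and ("decoder" not in n)]
    param_groups = [
        {"params": encoder_params, "lr": lr * 0.1},  # Encoder with lower LR
        {"params": decoder_params, "lr": lr},        # Decoder with normal LR
        {"params": other_params, "lr": lr},          # Everything else with normal LR
    ]
    return torch.optim.Adam(param_groups, betas=(beta1, beta2), eps=eps)
```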
## Working with Learning Rate Schedulers
Decent-DP also supports learning rate schedulers through the `lr_scheduler_fn` parameter:
```python
from decent_dp.optim import (
    optim_fn_adamw,
    lr_scheduler_fn_cosine_with_warmup
)

# Create model with optimizer and LR scheduler
ddp_model = DecentDP(
    model,
    optim_fn=optim_fn_adamw,
    lr_scheduler_fn=lr_scheduler_fn_cosine_with_warmup
)
```
### Custom Learning Rate Scheduler Function
```python
import torch


def custom_lr_scheduler_fn(optimizer, step_size=10, gamma=0.1):
    """Custom learning rate scheduler function."""
    return torch.optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=gamma)

# Use with the custom scheduler
ddp_model = DecentDP(
    model,
    optim_fn=optim_fn_adamw,
    lr_scheduler_fn=custom_lr_scheduler_fn
)
```
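As with optimizer functions, scheduler hyperparameters can be bound ahead of time with `functools.partial`. A small sketch using the custom function above (the step-size and gamma values are arbitrary):

```python
from functools import partial

# Bind non-default hyperparameters before handing the function to DecentDP
custom_scheduler = partial(custom_lr_scheduler_fn, step_size=30, gamma=0.5)

ddp_model = DecentDP(
    model,
    optim_fn=optim_fn_adamw,
    lr_scheduler_fn=custom_scheduler
)
```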
## Advanced: Optimizers with Communication Awareness
Decent-DP allows optimizers to be aware of the communication topology through a special `pre_average_hook` method:
```python
import torch


class TopologyAwareOptimizer(torch.optim.Optimizer):
    # (step() and state handling are omitted here for brevity)
    def __init__(self, params, lr=0.01):
        super().__init__(params, dict(lr=lr))

    def pre_average_hook(self, edge, weight):
        """Called before parameter averaging in each iteration.

        Args:
            edge: Communication edge information
            weight: Averaging weight for this worker
        """
        # Adjust optimizer behavior based on the communication topology
        print(f"Communicating with ranks {edge.ranks} using weight {weight}")
        # Example: adjust the learning rate based on the communication pattern
        for param_group in self.param_groups:
            param_group['lr'] = self.defaults['lr'] * len(edge.ranks)
```
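To use such an optimizer with Decent-DP, it still has to be wrapped in an optimizer function with the signature described earlier. A minimal sketch (the wrapper name `optim_fn_topology_aware` is illustrative):

```python
def optim_fn_topology_aware(params, lr=0.01):
    """Adapt TopologyAwareOptimizer to Decent-DP's optimizer-function signature."""
    # params is a list of (name, tensor) tuples; keep only the tensors here
    param_tensors = [p for _, p in params]
    return TopologyAwareOptimizer(param_tensors, lr=lr)

ddp_model = DecentDP(model, optim_fn=optim_fn_topology_aware)
```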
## Best Practices
- Parameter Grouping: Always consider how to group parameters for weight decay
- Hyperparameter Tuning: Decentralized training may require different hyperparameters than standard training
- Gradient Accumulation: Use accumulated gradient optimizers when simulating larger batch sizes
- Consistent Signatures: Ensure your optimizer functions follow the expected signature
- Testing: Test custom optimizers with simple models before using in production (a minimal check is sketched below)
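For the last point, a quick smoke test can run entirely outside of distributed training, since an optimizer function only needs a list of named parameters. A minimal sketch, assuming the `optim_fn_sgd` example defined earlier (the throwaway model is arbitrary):

```python
import torch
import torch.nn as nn

# Throwaway model used only to exercise the optimizer function
model = nn.Sequential(nn.Linear(8, 16), nn.BatchNorm1d(16), nn.Linear(16, 2))

params = list(model.named_parameters())
optimizer = optim_fn_sgd(params)  # the custom optimizer function defined earlier

# Every parameter should end up in exactly one parameter group
n_grouped = sum(len(g["params"]) for g in optimizer.param_groups)
assert n_grouped == len(params), "some parameters are missing from the optimizer"

# A single forward/backward/step should run without errors
loss = model(torch.randn(4, 8)).sum()
loss.backward()
optimizer.step()
```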
## Troubleshooting
### Common Issues
- Optimizer Function Signature: Ensure your function takes `params: List[Tuple[str, Tensor]]` and returns an `Optimizer`
- Parameter Grouping: Make sure all parameters are included in exactly one group
- Device Placement: Ensure parameters are on the correct device before creating optimizers
### Debugging Tips
Enable logging to see optimizer creation details:
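For example, assuming Decent-DP routes its messages through Python's standard `logging` machinery (the logger name below is an assumption):

```python
import logging

# Assumption: Decent-DP logs through the standard logging module.
# Raising the level to DEBUG surfaces optimizer-creation messages;
# adjust the logger name if the library logs under a different one.
logging.basicConfig(level=logging.DEBUG)
logging.getLogger("decent_dp").setLevel(logging.DEBUG)
```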
This will show information about optimizer creation for each parameter bucket.