| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | from __future__ import annotations |
| |
|
| | import collections |
| | import collections.abc |
| | import ctypes |
| | import functools |
| | import os |
| | from datetime import timedelta |
| | from typing import Any, Callable, Optional |
| |
|
| | import pynvml |
| | import torch |
| | import torch.distributed as dist |
| |
|
| | from .log import log |
| | from .device import Device |
| |
|
| |
|
def init() -> int | None:
    """Prepare the current process for distributed training.

    In order, this pins the process to the CPU set reported by ``Device(local_rank)``, exports the
    NCCL environment settings, creates the default NCCL process group (unless one already exists)
    and adjusts a CUDA device limit through ``libcudart``.

    Returns:
        (int | None): The current CUDA device index if the process group was already initialized
            (the function returns early in that case); otherwise None.
    """
    # CPU affinity: bind this process to the cores reported for its local GPU.
    pynvml.nvmlInit()
    local_rank = int(os.getenv("LOCAL_RANK", 0))
    gpu = Device(local_rank)
    os.sched_setaffinity(0, gpu.get_cpu_affinity())

    # NCCL settings are exported before the process group is created.
    os.environ["TORCH_NCCL_BLOCKING_WAIT"] = "0"
    os.environ["TORCH_NCCL_ASYNC_ERROR_HANDLING"] = "1"

    if dist.is_available():
        if dist.is_initialized():
            return torch.cuda.current_device()
        torch.cuda.set_device(local_rank)

        # The environment value may be a string, so it is coerced to int for the timedelta
        # while the raw value is what gets logged.
        timeout_seconds = os.getenv("TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC", 1800)
        dist.init_process_group(
            backend="nccl",
            init_method="env://",
            timeout=timedelta(seconds=int(timeout_seconds)),
        )
        log.critical(
            f"Initialized distributed training with local rank {local_rank} with timeout {timeout_seconds}",
            rank0_only=False,
        )

    # Limit 0x05 is cudaLimitMaxL2FetchGranularity in the CUDA runtime enum; request 128 bytes.
    cudart = ctypes.CDLL("libcudart.so")
    limit_out = ctypes.cast((ctypes.c_int * 1)(), ctypes.POINTER(ctypes.c_int))
    cudart.cudaDeviceSetLimit(ctypes.c_int(0x05), ctypes.c_int(128))
    # The value read back is not inspected.
    cudart.cudaDeviceGetLimit(limit_out, ctypes.c_int(0x05))
    log.info(f"Training with {get_world_size()} GPUs.")
| |
|
| |
|
def get_rank(group: Optional[dist.ProcessGroup] = None) -> int:
    """Return the rank of the calling worker.

    Args:
        group (Optional[dist.ProcessGroup]): Process group to query. None is passed through to
            ``dist.get_rank`` unchanged.

    Returns:
        rank (int): The worker's rank, or 0 when torch.distributed is unavailable or has not been
            initialized.
    """
    # Outside a distributed run there is exactly one worker: rank 0.
    if not dist.is_available() or not dist.is_initialized():
        return 0
    return dist.get_rank(group)
| |
|
| |
|
def get_world_size(group: Optional[dist.ProcessGroup] = None) -> int:
    """Return the number of workers (GPUs) taking part in this job.

    Args:
        group (Optional[dist.ProcessGroup]): Process group to query. None is passed through to
            ``dist.get_world_size`` unchanged.

    Returns:
        world_size (int): The size of the group, or 1 when torch.distributed is unavailable or has
            not been initialized.
    """
    # A non-distributed run counts as a single worker.
    if not dist.is_available() or not dist.is_initialized():
        return 1
    return dist.get_world_size(group)
| |
|
| |
|
def is_rank0() -> bool:
    """Tell whether the calling process is the master (rank 0) worker.

    Returns:
        (bool): True when ``get_rank()`` reports 0, else False.
    """
    current_rank = get_rank()
    return current_rank == 0
| |
|
| |
|
def rank0_only(func: Callable) -> Callable:
    """Decorator that runs the wrapped function on the master (rank 0) worker only.

    Example usage:
        @rank0_only
        def func(x):
            return x + 3

    Args:
        func (Callable): The function to restrict.

    Returns:
        (Callable): A wrapper that forwards the call to ``func`` on rank 0 and returns None on every
            other rank without calling ``func``.
    """

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        # The rank is checked on every call, not once at decoration time.
        if not is_rank0():
            return None
        return func(*args, **kwargs)

    return wrapper
| |
|
| |
|
def barrier() -> None:
    """Synchronize all workers; a no-op when torch.distributed is unavailable or uninitialized."""
    if not dist.is_available():
        return
    if dist.is_initialized():
        dist.barrier()
| |
|
| |
|
class DistributedDataParallel(torch.nn.parallel.DistributedDataParallel):
    """This extends torch.nn.parallel.DistributedDataParallel with .training_step().

    This borrows the concept of `forward-redirection` from Pytorch lightning. It wraps a coreModel such that
    model.training_step() would be executed when calling self.training_step(), while preserving the behavior of calling
    model() for Pytorch modules. Internally, this is a double rerouting mechanism (training_step -> forward ->
    training_step), allowing us to preserve the function names and signatures.
    """

    def __init__(self, model: torch.nn.Module, *args, **kwargs):
        super().__init__(model, *args, **kwargs)

    def training_step(self, *args, **kwargs) -> Any:
        """Run the wrapped module's ``training_step`` through this wrapper's ``__call__``.

        The wrapped module's ``forward`` is temporarily replaced by a function that calls
        ``module.training_step``; invoking ``self(...)`` then reaches that function via the regular
        forward path of the wrapper.

        Args:
            *args: Positional arguments forwarded to ``module.training_step``.
            **kwargs: Keyword arguments forwarded to ``module.training_step``.

        Returns:
            (Any): Whatever ``module.training_step`` returns.
        """
        # Keep a handle on the real forward so it can be put back.
        original_forward = self.module.forward

        def wrapped_training_step(*_args, **_kwargs):
            # Unpatch first: training_step itself may want to call the real forward.
            self.module.forward = original_forward
            return self.module.training_step(*_args, **_kwargs)

        # Patch the module's forward so that self(...) lands in training_step.
        self.module.forward = wrapped_training_step
        try:
            return self(*args, **kwargs)
        finally:
            # If self(...) raised before reaching the patched forward, the module would otherwise
            # stay patched and its next forward() call would silently run training_step.
            self.module.forward = original_forward
| |
|
| |
|
def collate_batches(data_batches: list[dict[str, torch.Tensor]]) -> torch.Tensor | dict[str, torch.Tensor]:
    """Aggregate the list of data batches from all devices and process the results.

    This is used for gathering validation data batches with utils.dataloader.DistributedEvalSampler.
    It will return the data/output of the entire validation set in its original index order. The sizes of data_batches
    in different ranks may differ by 1 (if dataset size is not evenly divisible), in which case a dummy sample will be
    created before calling dist.all_gather().

    When torch.distributed is unavailable or not initialized, the local batches are simply concatenated.

    Args:
        data_batches (list[dict[str, torch.Tensor]]): List of tensors or (hierarchical) dictionary where
            leaf entries are tensors. Must be non-empty.

    Returns:
        data_gather (torch.Tensor | dict[str, torch.Tensor]): tensors or (hierarchical) dictionary where
            leaf entries are concatenated tensors.

    Raises:
        TypeError: If the entries are neither tensors nor mappings.
    """
    first = data_batches[0]

    if isinstance(first, torch.Tensor):
        # Concatenate the local data along the batch dimension.
        data_concat = torch.cat(data_batches, dim=0)

        # Without a process group there is nothing to gather from.
        if not (dist.is_available() and dist.is_initialized()):
            return data_concat

        # Find the largest local sample count across ranks.
        max_num_local_samples = torch.tensor(len(data_concat), device="cuda")
        dist.all_reduce(max_num_local_samples, op=dist.ReduceOp.MAX)
        if len(data_concat) < max_num_local_samples:
            assert len(data_concat) + 1 == max_num_local_samples, "local sample counts may differ by at most 1"
            # Pad with one uninitialized row so every rank contributes the same shape.
            # new_empty (rather than empty_like on a slice) still yields a row when the
            # local tensor has zero samples.
            dummy = data_concat.new_empty((1, *data_concat.shape[1:]))
            data_concat = torch.cat([data_concat, dummy], dim=0)
            dummy_count = torch.tensor(1, device="cuda")
        else:
            dummy_count = torch.tensor(0, device="cuda")

        # Total number of padding rows over all ranks.
        dist.all_reduce(dummy_count, op=dist.ReduceOp.SUM)
        gathered = all_gather_tensor(data_concat.contiguous())
        # Interleave the per-rank tensors: row i of rank r lands at index i * world_size + r.
        data_collate = torch.stack(gathered, dim=1).flatten(start_dim=0, end_dim=1)

        # Drop the padding rows. NOTE(review): this assumes the padded ranks are the highest
        # ranks, so that all dummy rows end up at the tail -- confirm against the sampler.
        num_dummies = int(dummy_count.item())
        if num_dummies > 0:
            data_collate = data_collate[:-num_dummies]
        return data_collate

    if isinstance(first, collections.abc.Mapping):
        # Recurse per key; the key set of the first batch is taken as the schema.
        return {key: collate_batches([data[key] for data in data_batches]) for key in first.keys()}

    raise TypeError(
        f"collate_batches expects a list of tensors or mappings of tensors, got {type(first).__name__}"
    )
| |
|
| |
|
@torch.no_grad()
def all_gather_tensor(tensor: torch.Tensor) -> list[torch.Tensor]:
    """Collect the given tensor from every worker into a list.

    Args:
        tensor (torch.Tensor): The local Pytorch tensor to contribute.

    Returns:
        tensor_list (list[torch.Tensor]): One tensor per worker, as filled in by ``dist.all_gather``.
    """
    world_size = get_world_size()
    # Pre-allocate one receive buffer per worker, shaped like the local tensor.
    gathered = []
    for _ in range(world_size):
        gathered.append(torch.zeros_like(tensor))
    dist.all_gather(gathered, tensor)
    return gathered
| |
|
| |
|
def broadcast(tensor, src, group=None, async_op=False):
    """Broadcast a tensor from the source rank to the other workers, in place.

    Args:
        tensor (torch.Tensor): Data to send on rank ``src``; receive buffer on every other rank.
        src (int): Rank that owns the data to broadcast.
        group (Optional[dist.ProcessGroup]): Process group passed through to ``dist.broadcast``.
        async_op (bool): Whether the collective should be launched asynchronously.

    Returns:
        The tensor itself when running with fewer than two workers (nothing is communicated) or when
        ``async_op`` is False. When ``async_op`` is True in a multi-worker run, the value returned by
        ``dist.broadcast`` (the async work handle) so the caller can wait on it.
    """
    # Same condition as a world size below 2: no peers to broadcast to.
    if not (dist.is_available() and dist.is_initialized()) or dist.get_world_size() < 2:
        return tensor
    work = dist.broadcast(tensor, src=src, group=group, async_op=async_op)
    # Hand the work handle back for async calls; previously it was dropped and could not be awaited.
    return work if async_op else tensor
| |
|