| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452 |
- # Copyright (c) Alibaba, Inc. and its affiliates.
- # Following code is partially borrowed from openmmlab/mmcv
- import functools
- import inspect
- import os
- import pickle
- import random
- import socket
- import subprocess
- import tempfile
- from typing import Callable, List, Optional, Tuple
- import numpy as np
- import torch
- import torch.multiprocessing as mp
- from packaging import version
- from torch import distributed as dist
- def _find_free_port() -> str:
- # Copied from https://github.com/facebookresearch/detectron2/blob/main/detectron2/engine/launch.py # noqa: E501
- sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
- # Binding to port 0 will cause the OS to find an available port for us
- sock.bind(('', 0))
- port = sock.getsockname()[1]
- sock.close()
- # NOTE: there is still a chance the port could be taken by other processes.
- return port
- def _is_free_port(port: int) -> bool:
- ips = socket.gethostbyname_ex(socket.gethostname())[-1]
- ips.append('localhost')
- with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
- return all(s.connect_ex((ip, port)) != 0 for ip in ips)
def compile_model(model, **compile_options):
    """Compile ``model`` with torch 2.x and return the compiled model.

    If ``model`` has its own ``compile`` method it is used. Otherwise
    ``torch.compile`` is applied when the installed torch is at least 2.0.
    With older torch a message is printed and ``model`` is returned unchanged.

    Args:
        model: the model (typically a ``torch.nn.Module``) to compile.
        **compile_options: forwarded to ``model.compile`` / ``torch.compile``.

    Returns:
        The compiled model.
    """
    if hasattr(model, 'compile'):
        compiled = model.compile(**compile_options)
        # ``torch.nn.Module.compile`` (torch >= 2.2) compiles in place and
        # returns None; in that case the (now compiled) model itself is the
        # result, while a custom ``compile`` may return a new object.
        return model if compiled is None else compiled
    if version.parse(torch.__version__) >= version.parse('2.0.0.dev'):
        return torch.compile(model, **compile_options)
    print(
        'Compiling model needs torch version >= 2.0.0, '
        f'your torch version is: {torch.__version__}, origin model will be returned.'
    )
    return model
def init_dist(launcher: str, backend: str = 'nccl', **kwargs) -> None:
    """Set up torch.distributed for the given job launcher.

    Args:
        launcher (str): ``'pytorch'``, ``'mpi'`` or ``'slurm'``.
        backend (str): torch.distributed backend name. Defaults to 'nccl'.
        **kwargs: forwarded to the launcher-specific initializer.

    Raises:
        ValueError: if ``launcher`` is not one of the supported values.
    """
    # Only pick 'spawn' when no start method has been chosen yet.
    if mp.get_start_method(allow_none=True) is None:
        mp.set_start_method('spawn')

    if launcher not in ('pytorch', 'mpi', 'slurm'):
        raise ValueError(f'Invalid launcher type: {launcher}')

    if launcher == 'pytorch':
        _init_dist_pytorch(backend, **kwargs)
    elif launcher == 'mpi':
        _init_dist_mpi(backend, **kwargs)
    else:
        _init_dist_slurm(backend, **kwargs)
def _init_dist_pytorch(backend: str, **kwargs) -> None:
    """Initialize torch.distributed for ``launcher == 'pytorch'``.

    Binds this process to the CUDA device numbered by the ``LOCAL_RANK``
    environment variable, then initializes the default process group with
    ``backend`` and any extra ``kwargs``.

    Raises:
        KeyError: if ``LOCAL_RANK`` is not set.
    """
    device_index = int(os.environ['LOCAL_RANK'])
    torch.cuda.set_device(device_index)
    dist.init_process_group(backend=backend, **kwargs)
def _init_dist_mpi(backend: str, **kwargs) -> None:
    """Initialize torch.distributed from Open MPI (``OMPI_COMM_WORLD_*``) variables.

    Selects the CUDA device from ``OMPI_COMM_WORLD_LOCAL_RANK`` and defaults
    ``MASTER_PORT`` to 29500. It then copies the MPI world size and rank into
    ``WORLD_SIZE`` / ``RANK`` before creating the default process group.

    Raises:
        KeyError: if ``MASTER_ADDR`` (or a required OMPI variable) is missing.
    """
    torch.cuda.set_device(int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK']))
    # 29500 is torch.distributed default port
    os.environ.setdefault('MASTER_PORT', '29500')
    if 'MASTER_ADDR' not in os.environ:
        raise KeyError('The environment variable MASTER_ADDR is not set')
    env = os.environ
    env['WORLD_SIZE'] = env['OMPI_COMM_WORLD_SIZE']
    env['RANK'] = env['OMPI_COMM_WORLD_RANK']
    dist.init_process_group(backend=backend, **kwargs)
def _slurm_master_addr(node_list: str) -> str:
    """Return the first host name of a SLURM node list, expanded by ``scontrol``.

    Raises:
        FileNotFoundError: if ``scontrol`` is not installed.
        subprocess.CalledProcessError: if ``scontrol`` exits with an error.
        RuntimeError: if ``scontrol`` prints no host name.
    """
    # Pass an argument list instead of a shell string: node lists such as
    # ``node[01-04]`` contain glob characters that an unquoted shell command
    # would try to expand against files in the current directory.
    result = subprocess.run(['scontrol', 'show', 'hostname', node_list],
                            stdout=subprocess.PIPE,
                            universal_newlines=True,
                            check=True)
    hosts = result.stdout.split()
    if not hosts:
        raise RuntimeError(
            f'scontrol returned no host name for node list {node_list!r}')
    return hosts[0]


def _init_dist_slurm(backend: str, port: Optional[int] = None) -> None:
    """Initialize slurm distributed training environment.

    If argument ``port`` is not specified, then the master port will be system
    environment variable ``MASTER_PORT``. If ``MASTER_PORT`` is not in system
    environment variable, then port ``29500`` is used when it is free on this
    host, otherwise a free port chosen by the OS.

    ``MASTER_ADDR`` is kept if already set; otherwise it becomes the first host
    of ``SLURM_NODELIST``. ``WORLD_SIZE``, ``LOCAL_RANK`` and ``RANK`` are
    derived from ``SLURM_NTASKS``, ``SLURM_PROCID`` and the local GPU count.

    Args:
        backend (str): Backend of torch.distributed.
        port (int, optional): Master port. Defaults to None.
    """
    proc_id = int(os.environ['SLURM_PROCID'])
    ntasks = int(os.environ['SLURM_NTASKS'])
    node_list = os.environ['SLURM_NODELIST']
    num_gpus = torch.cuda.device_count()
    # Global ranks are assigned round-robin onto this node's GPUs.
    local_rank = proc_id % num_gpus
    torch.cuda.set_device(local_rank)

    # specify master port
    if port is not None:
        os.environ['MASTER_PORT'] = str(port)
    elif 'MASTER_PORT' not in os.environ:
        # NOTE(review): the probe only inspects this host, so ranks on different
        # nodes may pick different ports; set MASTER_PORT (or pass ``port``)
        # explicitly for multi-node jobs.
        if _is_free_port(29500):
            os.environ['MASTER_PORT'] = '29500'
        else:
            os.environ['MASTER_PORT'] = str(_find_free_port())

    # Use MASTER_ADDR from the environment if present; only query scontrol
    # when it is missing.
    if 'MASTER_ADDR' not in os.environ:
        os.environ['MASTER_ADDR'] = _slurm_master_addr(node_list)
    os.environ['WORLD_SIZE'] = str(ntasks)
    os.environ['LOCAL_RANK'] = str(local_rank)
    os.environ['RANK'] = str(proc_id)
    dist.init_process_group(backend=backend)
def get_dist_info(group=None) -> Tuple[int, int]:
    """Return ``(rank, world_size)`` for a process group.

    Without an initialized torch.distributed this is ``(0, 1)``. When
    ``group`` is None and Megatron has been initialized, Megatron's
    data-parallel group is queried instead of the global group.

    Args:
        group: the process group to query; None means the global group
            (or Megatron's data-parallel group, see above).

    Returns:
        Tuple[int, int]: the current rank and the world size of the group.
    """
    if not is_dist():
        return 0, 1

    from modelscope.utils.megatron_utils import is_megatron_initialized
    if group is None and is_megatron_initialized():
        from megatron_util import mpu
        group = mpu.get_data_parallel_group()
    return dist.get_rank(group), dist.get_world_size(group)
def get_local_rank():
    """Return the ``LOCAL_RANK`` environment variable as an int (0 when unset)."""
    raw = os.environ.get('LOCAL_RANK')
    if raw is None:
        return 0
    return int(raw)
def get_rank():
    """Return the global rank, or 0 when torch.distributed is unavailable or uninitialized."""
    ready = dist.is_available() and dist.is_initialized()
    return dist.get_rank() if ready else 0
def get_world_size():
    """Return the global world size, or 1 when torch.distributed is unavailable or uninitialized."""
    ready = dist.is_available() and dist.is_initialized()
    return dist.get_world_size() if ready else 1
def synchronize():
    """Barrier across all ranks of the default process group.

    Does nothing when torch.distributed is unavailable, uninitialized, or
    the world size is 1.
    """
    if not (dist.is_available() and dist.is_initialized()):
        return
    if dist.get_world_size() != 1:
        dist.barrier()
def is_dist():
    """Return True when torch.distributed is available and its default group is initialized."""
    if not dist.is_available():
        return False
    return dist.is_initialized()
def is_master(group=None):
    """Return True on rank 0 of ``group``, or always when not running distributed."""
    if not is_dist():
        return True
    return dist.get_rank(group) == 0
def master_only(group=None):
    """Decorator that runs the wrapped function only on the master rank.

    On other ranks the call is skipped and returns None. Both
    ``@master_only()`` / ``@master_only(group)`` and the bare ``@master_only``
    form are supported.

    Args:
        group: process group whose rank 0 is the master; None for the
            global group. When a callable is passed (bare decorator usage),
            it is treated as the function to wrap.

    Returns:
        A decorator, or the wrapped function when used bare.
    """
    if callable(group):
        # Bare ``@master_only``: ``group`` is actually the decorated function.
        return master_only(None)(group)

    def decorate(func: Callable) -> Callable:

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            if is_master(group):
                return func(*args, **kwargs)

        return wrapper

    return decorate
def make_tmp_dir():
    """Create a temporary directory and return the same path on every rank.

    Without distributed training this is just ``tempfile.mkdtemp()``.
    Otherwise only the master creates the directory and its path is
    broadcast from rank 0.

    NOTE(review): on multi-node jobs the directory exists only on the
    master's filesystem unless the temp location is shared.
    """
    if not is_dist():
        return tempfile.mkdtemp()
    path = tempfile.mkdtemp() if is_master() else None
    dist.barrier()
    return broadcast(path, 0)
def broadcast(inputs, src):
    """
    Broadcasts the inputs to all ranks.

    The object is pickled on ``src`` and sent as a uint8 tensor, preceded by
    its byte length. Tensors live on CUDA for the NCCL backend and on CPU for
    other backends (e.g. gloo), so CPU-only groups work too.

    Arguments:
        inputs : Any objects that can be serialized by pickle.
        src (int): Source rank.

    Returns:
        Each rank returns the same value as src.
    """
    rank = dist.get_rank()
    # NCCL can only communicate CUDA tensors; other backends accept CPU ones.
    device = torch.device('cuda' if dist.get_backend() == 'nccl' else 'cpu')
    size_tensor = torch.zeros(1, dtype=torch.int64, device=device)
    if rank == src:
        # np.frombuffer over bytes is read-only, hence the copy.
        payload = np.frombuffer(pickle.dumps(inputs), dtype=np.uint8).copy()
        inputs_tensor = torch.from_numpy(payload).to(device)
        size_tensor.fill_(inputs_tensor.numel())
    dist.barrier()
    dist.broadcast(size_tensor, src)
    if rank != src:
        inputs_tensor = torch.zeros(
            int(size_tensor.item()), dtype=torch.uint8, device=device)
    dist.barrier()
    dist.broadcast(inputs_tensor, src)
    # Ranks only exchange data produced by this same job; do not feed this
    # from untrusted peers (pickle).
    return pickle.loads(inputs_tensor.cpu().numpy().tobytes())
def set_random_seed(seed):
    """Seed the python, numpy, torch and all CUDA random generators.

    Args:
        seed: a non-negative integer seed.

    Raises:
        ValueError: if ``seed`` is None or negative.
    """
    # ``not seed >= 0`` (rather than ``seed < 0``) keeps rejecting any value
    # for which ``seed >= 0`` does not hold.
    if seed is None or not seed >= 0:
        raise ValueError(
            f'Random seed should be non-negative, current seed is {seed}')
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
- @functools.lru_cache()
- def _get_global_gloo_group():
- """
- Return a process group based on gloo backend, containing all the ranks
- The result is cached.
- """
- if dist.get_backend() == 'nccl':
- return dist.new_group(backend='gloo')
- else:
- return dist.group.WORLD
- def _serialize_to_tensor(data, group):
- backend = dist.get_backend(group)
- assert backend in ['gloo', 'nccl']
- device = torch.device('cpu' if backend == 'gloo' else 'cuda')
- buffer = pickle.dumps(data)
- if len(buffer) > 1024**3:
- logger.warning(
- 'Rank {} trying to all-gather {:.2f} GB of data on device {}'.
- format(get_rank(),
- len(buffer) / (1024**3), device))
- storage = torch.ByteStorage.from_buffer(buffer)
- tensor = torch.ByteTensor(storage).to(device=device)
- return tensor
- def _pad_to_largest_tensor(tensor, group):
- """
- Returns:
- list[int]: size of the tensor, on each rank
- Tensor: padded tensor that has the max size
- """
- world_size = dist.get_world_size(group=group)
- assert (
- world_size >= 1
- ), 'comm.gather/all_gather must be called from ranks within the group!'
- local_size = torch.tensor([tensor.numel()],
- dtype=torch.int64,
- device=tensor.device)
- size_list = [
- torch.zeros([1], dtype=torch.int64, device=tensor.device)
- for _ in range(world_size)
- ]
- dist.all_gather(size_list, local_size, group=group)
- size_list = [int(size.item()) for size in size_list]
- max_size = max(size_list)
- # we pad the tensor because torch all_gather does not support
- # gathering tensors of different shapes
- if local_size != max_size:
- padding = torch.zeros((max_size - local_size, ),
- dtype=torch.uint8,
- device=tensor.device)
- tensor = torch.cat((tensor, padding), dim=0)
- return size_list, tensor
def all_gather(data, group=None):
    """
    Run all_gather on arbitrary picklable data (not necessarily tensors).

    Args:
        data: any picklable object
        group: a torch process group. By default, will use a group which
            contains all ranks on gloo backend.

    Returns:
        list[data]: list of data gathered from each rank
    """
    if get_world_size() == 1:
        return [data]
    if group is None:
        group = _get_global_gloo_group()
    if dist.get_world_size(group) == 1:
        return [data]

    local = _serialize_to_tensor(data, group)
    size_list, local = _pad_to_largest_tensor(local, group)
    longest = max(size_list)

    # One equally sized receive buffer per rank.
    received = [
        torch.empty((longest, ), dtype=torch.uint8, device=local.device)
        for _ in size_list
    ]
    dist.all_gather(received, local, group=group)

    # Strip each rank's padding before unpickling.
    return [
        pickle.loads(chunk.cpu().numpy().tobytes()[:size])
        for size, chunk in zip(size_list, received)
    ]
def is_on_same_device(model: torch.nn.Module) -> bool:
    """Return True if ``model``'s non-CPU parameters share at most one device.

    CPU parameters are ignored, so a CPU-only model (or one without
    parameters) counts as being on the same device.
    """
    devices = {str(param.device) for param in model.parameters()}
    devices.discard('cpu')
    return len(devices) < 2
def apply_chunking_to_forward(
    forward_fn: Callable[..., torch.Tensor],
    chunk_size: int,
    chunk_dim: int,
    *input_tensors,
) -> torch.Tensor:
    """Apply ``forward_fn`` chunk by chunk along ``chunk_dim`` and concatenate.

    Copied from transformers (the latest version of transformers deletes
    this function).

    Args:
        forward_fn: callable taking exactly ``len(input_tensors)`` arguments.
        chunk_size: size of each chunk; values <= 0 disable chunking.
        chunk_dim: dimension along which the inputs are split.
        *input_tensors: tensors with equal size along ``chunk_dim``.

    Returns:
        torch.Tensor: outputs of ``forward_fn`` concatenated on ``chunk_dim``.

    Raises:
        ValueError: on an argument-count mismatch, unequal sizes along
            ``chunk_dim``, or a size that is not a multiple of ``chunk_size``.
    """
    assert len(input_tensors
               ) > 0, f'{input_tensors} has to be a tuple/list of tensors'
    expected_args = len(inspect.signature(forward_fn).parameters)
    given_args = len(input_tensors)
    if expected_args != given_args:
        raise ValueError(
            f'forward_chunk_fn expects {expected_args} arguments, but only {given_args} input '
            'tensors are given')

    if chunk_size <= 0:
        return forward_fn(*input_tensors)

    dim_len = input_tensors[0].shape[chunk_dim]
    for candidate in input_tensors:
        if candidate.shape[chunk_dim] != dim_len:
            raise ValueError(
                f'All input tenors have to be of the same shape: {dim_len}, '
                f'found shape {candidate.shape[chunk_dim]}')
    if dim_len % chunk_size != 0:
        raise ValueError(
            f'The dimension to be chunked {dim_len} has to be a multiple of the chunk '
            f'size {chunk_size}')

    n_chunks = dim_len // chunk_size
    split_inputs = [
        candidate.chunk(n_chunks, dim=chunk_dim) for candidate in input_tensors
    ]
    # zip(*...) groups the i-th chunk of every input together.
    outputs = [forward_fn(*chunk_args) for chunk_args in zip(*split_inputs)]
    return torch.cat(outputs, dim=chunk_dim)
def find_pruneable_heads_and_indices(
        heads: list[int], n_heads: int, head_size: int,
        already_pruned_heads: set[int]) -> tuple[set[int], torch.Tensor]:
    """Work out which heads to prune and which flat row indices to keep.

    Copied from transformers (the latest version of transformers deletes
    this function).

    Args:
        heads: head ids (in original numbering) requested for pruning.
        n_heads: number of heads currently present in the layer.
        head_size: number of rows per head.
        already_pruned_heads: head ids (original numbering) pruned earlier.

    Returns:
        The set of heads to prune now (requested minus already pruned), and
        a LongTensor of the ``n_heads * head_size`` flat positions to keep.
    """
    to_prune = set(heads) - already_pruned_heads
    keep = torch.ones(n_heads, head_size)
    for head in to_prune:
        # Heads pruned earlier with a smaller id no longer exist, so shift
        # the original id down to its current row.
        shift = sum(1 for pruned in already_pruned_heads if pruned < head)
        keep[head - shift] = 0
    keep = keep.view(-1).contiguous().eq(1)
    index: torch.LongTensor = torch.arange(len(keep))[keep].long()
    return to_prune, index
def prune_linear_layer(layer: torch.nn.Linear,
                       index: torch.LongTensor,
                       dim: int = 0) -> torch.nn.Linear:
    """Return a new ``Linear`` keeping only the entries ``index`` along ``dim``.

    Copied from transformers (the latest version of transformers deletes
    this function).

    Args:
        layer: the layer to prune.
        index: indices to keep along ``dim`` of ``layer.weight``.
        dim: 0 prunes output features (weight rows and bias), 1 prunes input
            features (weight columns; the bias is kept whole).

    Returns:
        torch.nn.Linear: a new trainable layer on the same device and with
        the same dtype as ``layer``.
    """
    index = index.to(layer.weight.device)
    W = layer.weight.index_select(dim, index).detach().clone()
    has_bias = layer.bias is not None
    if has_bias:
        b = (layer.bias if dim == 1 else layer.bias[index]).detach().clone()
    new_size = list(layer.weight.size())
    new_size[dim] = len(index)
    # Match dtype as well as device: a freshly built Linear defaults to
    # float32, which would silently upcast/downcast fp16 or fp64 layers.
    new_layer = torch.nn.Linear(
        new_size[1], new_size[0],
        bias=has_bias).to(device=layer.weight.device, dtype=layer.weight.dtype)
    # In-place copies into leaf parameters must not be tracked by autograd;
    # the parameters stay trainable (requires_grad=True) afterwards.
    with torch.no_grad():
        new_layer.weight.copy_(W.contiguous())
        if has_bias:
            new_layer.bias.copy_(b.contiguous())
    return new_layer
|