torch_utils.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. # Following code is partially borrowed from openmmlab/mmcv
  3. import functools
  4. import inspect
  5. import os
  6. import pickle
  7. import random
  8. import socket
  9. import subprocess
  10. import tempfile
  11. from typing import Callable, List, Optional, Tuple
  12. import numpy as np
  13. import torch
  14. import torch.multiprocessing as mp
  15. from packaging import version
  16. from torch import distributed as dist
  17. def _find_free_port() -> str:
  18. # Copied from https://github.com/facebookresearch/detectron2/blob/main/detectron2/engine/launch.py # noqa: E501
  19. sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
  20. # Binding to port 0 will cause the OS to find an available port for us
  21. sock.bind(('', 0))
  22. port = sock.getsockname()[1]
  23. sock.close()
  24. # NOTE: there is still a chance the port could be taken by other processes.
  25. return port
  26. def _is_free_port(port: int) -> bool:
  27. ips = socket.gethostbyname_ex(socket.gethostname())[-1]
  28. ips.append('localhost')
  29. with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
  30. return all(s.connect_ex((ip, port)) != 0 for ip in ips)
  31. def compile_model(model, **compile_options):
  32. # Compile the model with torch 2.0
  33. if hasattr(model, 'compile'):
  34. model = model.compile(**compile_options)
  35. elif version.parse(torch.__version__) >= version.parse('2.0.0.dev'):
  36. model = torch.compile(model, **compile_options)
  37. else:
  38. print(
  39. 'Compiling model needs torch version > 2.0.0, '
  40. f'your torch version is: {torch.__version__}, origin model will be returned.'
  41. )
  42. return model
  43. def init_dist(launcher: str, backend: str = 'nccl', **kwargs) -> None:
  44. if mp.get_start_method(allow_none=True) is None:
  45. mp.set_start_method('spawn')
  46. if launcher == 'pytorch':
  47. _init_dist_pytorch(backend, **kwargs)
  48. elif launcher == 'mpi':
  49. _init_dist_mpi(backend, **kwargs)
  50. elif launcher == 'slurm':
  51. _init_dist_slurm(backend, **kwargs)
  52. else:
  53. raise ValueError(f'Invalid launcher type: {launcher}')
  54. def _init_dist_pytorch(backend: str, **kwargs) -> None:
  55. # rank = int(os.environ['RANK'])
  56. local_rank = int(os.environ['LOCAL_RANK'])
  57. torch.cuda.set_device(local_rank)
  58. dist.init_process_group(backend=backend, **kwargs)
  59. def _init_dist_mpi(backend: str, **kwargs) -> None:
  60. local_rank = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK'])
  61. torch.cuda.set_device(local_rank)
  62. if 'MASTER_PORT' not in os.environ:
  63. # 29500 is torch.distributed default port
  64. os.environ['MASTER_PORT'] = '29500'
  65. if 'MASTER_ADDR' not in os.environ:
  66. raise KeyError('The environment variable MASTER_ADDR is not set')
  67. os.environ['WORLD_SIZE'] = os.environ['OMPI_COMM_WORLD_SIZE']
  68. os.environ['RANK'] = os.environ['OMPI_COMM_WORLD_RANK']
  69. dist.init_process_group(backend=backend, **kwargs)
  70. def _init_dist_slurm(backend: str, port: Optional[int] = None) -> None:
  71. """Initialize slurm distributed training environment.
  72. If argument ``port`` is not specified, then the master port will be system
  73. environment variable ``MASTER_PORT``. If ``MASTER_PORT`` is not in system
  74. environment variable, then a default port ``29500`` will be used.
  75. Args:
  76. backend (str): Backend of torch.distributed.
  77. port (int, optional): Master port. Defaults to None.
  78. """
  79. proc_id = int(os.environ['SLURM_PROCID'])
  80. ntasks = int(os.environ['SLURM_NTASKS'])
  81. node_list = os.environ['SLURM_NODELIST']
  82. num_gpus = torch.cuda.device_count()
  83. torch.cuda.set_device(proc_id % num_gpus)
  84. addr = subprocess.getoutput(
  85. f'scontrol show hostname {node_list} | head -n1')
  86. # specify master port
  87. if port is not None:
  88. os.environ['MASTER_PORT'] = str(port)
  89. elif 'MASTER_PORT' in os.environ:
  90. pass # use MASTER_PORT in the environment variable
  91. else:
  92. # if torch.distributed default port(29500) is available
  93. # then use it, else find a free port
  94. if _is_free_port(29500):
  95. os.environ['MASTER_PORT'] = '29500'
  96. else:
  97. os.environ['MASTER_PORT'] = str(_find_free_port())
  98. # use MASTER_ADDR in the environment variable if it already exists
  99. if 'MASTER_ADDR' not in os.environ:
  100. os.environ['MASTER_ADDR'] = addr
  101. os.environ['WORLD_SIZE'] = str(ntasks)
  102. os.environ['LOCAL_RANK'] = str(proc_id % num_gpus)
  103. os.environ['RANK'] = str(proc_id)
  104. dist.init_process_group(backend=backend)
  105. def get_dist_info(group=None) -> Tuple[int, int]:
  106. """Get dist info of a specified group
  107. Args:
  108. group: The parallel group, default None, for the global group
  109. Returns:
  110. A tuple of the current rank and world_size of the group
  111. """
  112. if is_dist():
  113. from modelscope.utils.megatron_utils import is_megatron_initialized
  114. if group is None and is_megatron_initialized():
  115. from megatron_util import mpu
  116. group = mpu.get_data_parallel_group()
  117. rank = dist.get_rank(group)
  118. world_size = dist.get_world_size(group)
  119. else:
  120. rank = 0
  121. world_size = 1
  122. return rank, world_size
  123. def get_local_rank():
  124. return int(os.environ.get('LOCAL_RANK', 0))
  125. def get_rank():
  126. if not dist.is_available():
  127. return 0
  128. if not dist.is_initialized():
  129. return 0
  130. return dist.get_rank()
  131. def get_world_size():
  132. if not dist.is_available():
  133. return 1
  134. if not dist.is_initialized():
  135. return 1
  136. return dist.get_world_size()
  137. def synchronize():
  138. """
  139. Helper function to synchronize (barrier)
  140. among all processes when using distributed training
  141. """
  142. if not dist.is_available():
  143. return
  144. if not dist.is_initialized():
  145. return
  146. world_size = dist.get_world_size()
  147. if world_size == 1:
  148. return
  149. dist.barrier()
  150. def is_dist():
  151. return dist.is_available() and dist.is_initialized()
  152. def is_master(group=None):
  153. return dist.get_rank(group) == 0 if is_dist() else True
  154. def master_only(group=None):
  155. def decorate(func: Callable) -> Callable:
  156. @functools.wraps(func)
  157. def wrapper(*args, **kwargs):
  158. if is_master(group):
  159. return func(*args, **kwargs)
  160. return wrapper
  161. return decorate
  162. def make_tmp_dir():
  163. """Make sure each rank has the same temporary directory on the distributed mode.
  164. """
  165. if not is_dist():
  166. return tempfile.mkdtemp()
  167. tmpdir = None
  168. if is_master():
  169. tmpdir = tempfile.mkdtemp()
  170. dist.barrier()
  171. tmpdir = broadcast(tmpdir, 0)
  172. return tmpdir
  173. def broadcast(inputs, src):
  174. """
  175. Broadcasts the inputs to all ranks.
  176. Arguments:
  177. inputs : Any objects that can be serialized by pickle.
  178. src (int): Source rank.
  179. Returns:
  180. Each rank returns the same value as src.
  181. """
  182. rank = dist.get_rank()
  183. shape_tensor = torch.tensor([0], device='cuda')
  184. if rank == src:
  185. inputs_tensor = torch.tensor(
  186. bytearray(pickle.dumps(inputs)), dtype=torch.uint8, device='cuda')
  187. shape_tensor = torch.tensor(inputs_tensor.shape, device='cuda')
  188. dist.barrier()
  189. dist.broadcast(shape_tensor, src)
  190. if rank != src:
  191. inputs_tensor = torch.full((shape_tensor.item(), ),
  192. 0,
  193. dtype=torch.uint8,
  194. device='cuda')
  195. dist.barrier()
  196. dist.broadcast(inputs_tensor, src)
  197. return pickle.loads(inputs_tensor.cpu().numpy().tobytes())
  198. def set_random_seed(seed):
  199. if seed is not None and seed >= 0:
  200. random.seed(seed)
  201. np.random.seed(seed)
  202. torch.manual_seed(seed)
  203. torch.cuda.manual_seed_all(seed)
  204. else:
  205. raise ValueError(
  206. f'Random seed should be positive, current seed is {seed}')
  207. @functools.lru_cache()
  208. def _get_global_gloo_group():
  209. """
  210. Return a process group based on gloo backend, containing all the ranks
  211. The result is cached.
  212. """
  213. if dist.get_backend() == 'nccl':
  214. return dist.new_group(backend='gloo')
  215. else:
  216. return dist.group.WORLD
  217. def _serialize_to_tensor(data, group):
  218. backend = dist.get_backend(group)
  219. assert backend in ['gloo', 'nccl']
  220. device = torch.device('cpu' if backend == 'gloo' else 'cuda')
  221. buffer = pickle.dumps(data)
  222. if len(buffer) > 1024**3:
  223. logger.warning(
  224. 'Rank {} trying to all-gather {:.2f} GB of data on device {}'.
  225. format(get_rank(),
  226. len(buffer) / (1024**3), device))
  227. storage = torch.ByteStorage.from_buffer(buffer)
  228. tensor = torch.ByteTensor(storage).to(device=device)
  229. return tensor
  230. def _pad_to_largest_tensor(tensor, group):
  231. """
  232. Returns:
  233. list[int]: size of the tensor, on each rank
  234. Tensor: padded tensor that has the max size
  235. """
  236. world_size = dist.get_world_size(group=group)
  237. assert (
  238. world_size >= 1
  239. ), 'comm.gather/all_gather must be called from ranks within the group!'
  240. local_size = torch.tensor([tensor.numel()],
  241. dtype=torch.int64,
  242. device=tensor.device)
  243. size_list = [
  244. torch.zeros([1], dtype=torch.int64, device=tensor.device)
  245. for _ in range(world_size)
  246. ]
  247. dist.all_gather(size_list, local_size, group=group)
  248. size_list = [int(size.item()) for size in size_list]
  249. max_size = max(size_list)
  250. # we pad the tensor because torch all_gather does not support
  251. # gathering tensors of different shapes
  252. if local_size != max_size:
  253. padding = torch.zeros((max_size - local_size, ),
  254. dtype=torch.uint8,
  255. device=tensor.device)
  256. tensor = torch.cat((tensor, padding), dim=0)
  257. return size_list, tensor
  258. def all_gather(data, group=None):
  259. """
  260. Run all_gather on arbitrary picklable data (not necessarily tensors).
  261. Args:
  262. data: any picklable object
  263. group: a torch process group. By default, will use a group which
  264. contains all ranks on gloo backend.
  265. Returns:
  266. list[data]: list of data gathered from each rank
  267. """
  268. if get_world_size() == 1:
  269. return [data]
  270. if group is None:
  271. group = _get_global_gloo_group()
  272. if dist.get_world_size(group) == 1:
  273. return [data]
  274. tensor = _serialize_to_tensor(data, group)
  275. size_list, tensor = _pad_to_largest_tensor(tensor, group)
  276. max_size = max(size_list)
  277. # receiving Tensor from all ranks
  278. tensor_list = [
  279. torch.empty((max_size, ), dtype=torch.uint8, device=tensor.device)
  280. for _ in size_list
  281. ]
  282. dist.all_gather(tensor_list, tensor, group=group)
  283. data_list = []
  284. for size, tensor in zip(size_list, tensor_list):
  285. buffer = tensor.cpu().numpy().tobytes()[:size]
  286. data_list.append(pickle.loads(buffer))
  287. return data_list
  288. def is_on_same_device(model: torch.nn.Module) -> bool:
  289. device_set = set(str(p.device) for p in model.parameters()) - {'cpu'}
  290. return len(device_set) <= 1
  291. def apply_chunking_to_forward(
  292. forward_fn: Callable[..., torch.Tensor],
  293. chunk_size: int,
  294. chunk_dim: int,
  295. *input_tensors,
  296. ) -> torch.Tensor:
  297. # Copied from transformers, the latest version of transformers deletes this function
  298. assert len(input_tensors
  299. ) > 0, f'{input_tensors} has to be a tuple/list of tensors'
  300. # inspect.signature exist since python 3.5 and is a python method -> no problem with backward compatibility
  301. num_args_in_forward_chunk_fn = len(
  302. inspect.signature(forward_fn).parameters)
  303. if num_args_in_forward_chunk_fn != len(input_tensors):
  304. raise ValueError(
  305. f'forward_chunk_fn expects {num_args_in_forward_chunk_fn} arguments, but only {len(input_tensors)} input '
  306. 'tensors are given')
  307. if chunk_size > 0:
  308. tensor_shape = input_tensors[0].shape[chunk_dim]
  309. for input_tensor in input_tensors:
  310. if input_tensor.shape[chunk_dim] != tensor_shape:
  311. raise ValueError(
  312. f'All input tenors have to be of the same shape: {tensor_shape}, '
  313. f'found shape {input_tensor.shape[chunk_dim]}')
  314. if input_tensors[0].shape[chunk_dim] % chunk_size != 0:
  315. raise ValueError(
  316. f'The dimension to be chunked {input_tensors[0].shape[chunk_dim]} has to be a multiple of the chunk '
  317. f'size {chunk_size}')
  318. num_chunks = input_tensors[0].shape[chunk_dim] // chunk_size
  319. # chunk input tensor into tuples
  320. input_tensors_chunks = tuple(
  321. input_tensor.chunk(num_chunks, dim=chunk_dim)
  322. for input_tensor in input_tensors)
  323. # apply forward fn to every tuple
  324. output_chunks = tuple(
  325. forward_fn(*input_tensors_chunk)
  326. for input_tensors_chunk in zip(*input_tensors_chunks))
  327. # concatenate output at same dimension
  328. return torch.cat(output_chunks, dim=chunk_dim)
  329. return forward_fn(*input_tensors)
  330. def find_pruneable_heads_and_indices(
  331. heads: list[int], n_heads: int, head_size: int,
  332. already_pruned_heads: set[int]) -> tuple[set[int], torch.Tensor]:
  333. # Copied from transformers, the latest version of transformers deletes this function
  334. mask = torch.ones(n_heads, head_size)
  335. heads = set(
  336. heads
  337. ) - already_pruned_heads # Convert to set and remove already pruned heads
  338. for head in heads:
  339. # Compute how many pruned heads are before the head and move the index accordingly
  340. head = head - sum(1 if h < head else 0 for h in already_pruned_heads)
  341. mask[head] = 0
  342. mask = mask.view(-1).contiguous().eq(1)
  343. index: torch.LongTensor = torch.arange(len(mask))[mask].long()
  344. return heads, index
  345. def prune_linear_layer(layer: torch.nn.Linear,
  346. index: torch.LongTensor,
  347. dim: int = 0) -> torch.nn.Linear:
  348. # Copied from transformers, the latest version of transformers deletes this function
  349. index = index.to(layer.weight.device)
  350. W = layer.weight.index_select(dim, index).detach().clone()
  351. if layer.bias is not None:
  352. if dim == 1:
  353. b = layer.bias.detach().clone()
  354. else:
  355. b = layer.bias[index].detach().clone()
  356. new_size = list(layer.weight.size())
  357. new_size[dim] = len(index)
  358. new_layer = torch.nn.Linear(
  359. new_size[1], new_size[0], bias=layer.bias
  360. is not None).to(layer.weight.device)
  361. new_layer.weight.requires_grad = False
  362. new_layer.weight.copy_(W.contiguous())
  363. new_layer.weight.requires_grad = True
  364. if layer.bias is not None:
  365. new_layer.bias.requires_grad = False
  366. new_layer.bias.copy_(b.contiguous())
  367. new_layer.bias.requires_grad = True
  368. return new_layer