common.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387
# The Uni-Fold implementation is also open-sourced by the authors under the Apache-2.0 license,
# and is publicly available at https://github.com/dptech-corp/Uni-Fold.
  3. from functools import partial
  4. from typing import Any, Callable, Dict, Iterable, List, Optional
  5. import torch
  6. import torch.nn as nn
  7. import torch.nn.functional as F
  8. import torch.utils.checkpoint
  9. from unicore.modules import LayerNorm
  10. from unicore.utils import tensor_tree_map
  11. class Linear(nn.Linear):
  12. def __init__(
  13. self,
  14. d_in: int,
  15. d_out: int,
  16. bias: bool = True,
  17. init: str = 'default',
  18. ):
  19. super(Linear, self).__init__(d_in, d_out, bias=bias)
  20. self.use_bias = bias
  21. if self.use_bias:
  22. with torch.no_grad():
  23. self.bias.fill_(0)
  24. if init == 'default':
  25. self._trunc_normal_init(1.0)
  26. elif init == 'relu':
  27. self._trunc_normal_init(2.0)
  28. elif init == 'glorot':
  29. self._glorot_uniform_init()
  30. elif init == 'gating':
  31. self._zero_init(self.use_bias)
  32. elif init == 'normal':
  33. self._normal_init()
  34. elif init == 'final':
  35. self._zero_init(False)
  36. else:
  37. raise ValueError('Invalid init method.')
  38. def _trunc_normal_init(self, scale=1.0):
  39. # Constant from scipy.stats.truncnorm.std(a=-2, b=2, loc=0., scale=1.)
  40. TRUNCATED_NORMAL_STDDEV_FACTOR = 0.87962566103423978
  41. _, fan_in = self.weight.shape
  42. scale = scale / max(1, fan_in)
  43. std = (scale**0.5) / TRUNCATED_NORMAL_STDDEV_FACTOR
  44. nn.init.trunc_normal_(self.weight, mean=0.0, std=std)
  45. def _glorot_uniform_init(self):
  46. nn.init.xavier_uniform_(self.weight, gain=1)
  47. def _zero_init(self, use_bias=True):
  48. with torch.no_grad():
  49. self.weight.fill_(0.0)
  50. if use_bias:
  51. with torch.no_grad():
  52. self.bias.fill_(1.0)
  53. def _normal_init(self):
  54. torch.nn.init.kaiming_normal_(self.weight, nonlinearity='linear')
  55. class Transition(nn.Module):
  56. def __init__(self, d_in, n):
  57. super(Transition, self).__init__()
  58. self.d_in = d_in
  59. self.n = n
  60. self.layer_norm = LayerNorm(self.d_in)
  61. self.linear_1 = Linear(self.d_in, self.n * self.d_in, init='relu')
  62. self.act = nn.GELU()
  63. self.linear_2 = Linear(self.n * self.d_in, d_in, init='final')
  64. def _transition(self, x):
  65. x = self.layer_norm(x)
  66. x = self.linear_1(x)
  67. x = self.act(x)
  68. x = self.linear_2(x)
  69. return x
  70. @torch.jit.ignore
  71. def _chunk(
  72. self,
  73. x: torch.Tensor,
  74. chunk_size: int,
  75. ) -> torch.Tensor:
  76. return chunk_layer(
  77. self._transition,
  78. {'x': x},
  79. chunk_size=chunk_size,
  80. num_batch_dims=len(x.shape[:-2]),
  81. )
  82. def forward(
  83. self,
  84. x: torch.Tensor,
  85. chunk_size: Optional[int] = None,
  86. ) -> torch.Tensor:
  87. if chunk_size is not None:
  88. x = self._chunk(x, chunk_size)
  89. else:
  90. x = self._transition(x=x)
  91. return x
  92. class OuterProductMean(nn.Module):
  93. def __init__(self, d_msa, d_pair, d_hid, eps=1e-3):
  94. super(OuterProductMean, self).__init__()
  95. self.d_msa = d_msa
  96. self.d_pair = d_pair
  97. self.d_hid = d_hid
  98. self.eps = eps
  99. self.layer_norm = LayerNorm(d_msa)
  100. self.linear_1 = Linear(d_msa, d_hid)
  101. self.linear_2 = Linear(d_msa, d_hid)
  102. self.linear_out = Linear(d_hid**2, d_pair, init='relu')
  103. self.act = nn.GELU()
  104. self.linear_z = Linear(self.d_pair, self.d_pair, init='final')
  105. self.layer_norm_out = LayerNorm(self.d_pair)
  106. def _opm(self, a, b):
  107. outer = torch.einsum('...bac,...dae->...bdce', a, b)
  108. outer = outer.reshape(outer.shape[:-2] + (-1, ))
  109. outer = self.linear_out(outer)
  110. return outer
  111. @torch.jit.ignore
  112. def _chunk(self, a: torch.Tensor, b: torch.Tensor,
  113. chunk_size: int) -> torch.Tensor:
  114. a = a.reshape((-1, ) + a.shape[-3:])
  115. b = b.reshape((-1, ) + b.shape[-3:])
  116. out = []
  117. # TODO: optimize this
  118. for a_prime, b_prime in zip(a, b):
  119. outer = chunk_layer(
  120. partial(self._opm, b=b_prime),
  121. {'a': a_prime},
  122. chunk_size=chunk_size,
  123. num_batch_dims=1,
  124. )
  125. out.append(outer)
  126. if len(out) == 1:
  127. outer = out[0].unsqueeze(0)
  128. else:
  129. outer = torch.stack(out, dim=0)
  130. outer = outer.reshape(a.shape[:-3] + outer.shape[1:])
  131. return outer
  132. def apply_alphafold_original_mode(self):
  133. self.linear_z = None
  134. self.layer_norm_out = None
  135. def forward(
  136. self,
  137. m: torch.Tensor,
  138. mask: Optional[torch.Tensor] = None,
  139. chunk_size: Optional[int] = None,
  140. ) -> torch.Tensor:
  141. m = self.layer_norm(m)
  142. mask = mask.unsqueeze(-1)
  143. if self.layer_norm_out is not None:
  144. # for numerical stability
  145. mask = mask * (mask.size(-2)**-0.5)
  146. a = self.linear_1(m)
  147. b = self.linear_2(m)
  148. if self.training:
  149. a = a * mask
  150. b = b * mask
  151. else:
  152. a *= mask
  153. b *= mask
  154. a = a.transpose(-2, -3)
  155. b = b.transpose(-2, -3)
  156. if chunk_size is not None:
  157. z = self._chunk(a, b, chunk_size)
  158. else:
  159. z = self._opm(a, b)
  160. norm = torch.einsum('...abc,...adc->...bdc', mask, mask)
  161. z /= self.eps + norm
  162. if self.layer_norm_out is not None:
  163. z = self.act(z)
  164. z = self.layer_norm_out(z)
  165. z = self.linear_z(z)
  166. return z
  167. def residual(residual, x, training):
  168. if training:
  169. return x + residual
  170. else:
  171. residual += x
  172. return residual
  173. @torch.jit.script
  174. def fused_bias_dropout_add(
  175. x: torch.Tensor,
  176. bias: torch.Tensor,
  177. residual: torch.Tensor,
  178. dropmask: torch.Tensor,
  179. prob: float,
  180. ) -> torch.Tensor:
  181. return (x + bias) * F.dropout(dropmask, p=prob, training=True) + residual
  182. @torch.jit.script
  183. def fused_bias_dropout_add_inference(
  184. x: torch.Tensor,
  185. bias: torch.Tensor,
  186. residual: torch.Tensor,
  187. ) -> torch.Tensor:
  188. residual += bias + x
  189. return residual
  190. def bias_dropout_residual(module, residual, x, dropout_shared_dim, prob,
  191. training):
  192. bias = module.get_output_bias()
  193. if training:
  194. shape = list(x.shape)
  195. shape[dropout_shared_dim] = 1
  196. with torch.no_grad():
  197. mask = x.new_ones(shape)
  198. return fused_bias_dropout_add(x, bias, residual, mask, prob)
  199. else:
  200. return fused_bias_dropout_add_inference(x, bias, residual)
  201. @torch.jit.script
  202. def fused_bias_gated_dropout_add(
  203. x: torch.Tensor,
  204. bias: torch.Tensor,
  205. g: torch.Tensor,
  206. g_bias: torch.Tensor,
  207. residual: torch.Tensor,
  208. dropout_mask: torch.Tensor,
  209. prob: float,
  210. ) -> torch.Tensor:
  211. return (torch.sigmoid(g + g_bias) * (x + bias)) * F.dropout(
  212. dropout_mask,
  213. p=prob,
  214. training=True,
  215. ) + residual
  216. def tri_mul_residual(
  217. module,
  218. residual,
  219. outputs,
  220. dropout_shared_dim,
  221. prob,
  222. training,
  223. block_size,
  224. ):
  225. if training:
  226. x, g = outputs
  227. bias, g_bias = module.get_output_bias()
  228. shape = list(x.shape)
  229. shape[dropout_shared_dim] = 1
  230. with torch.no_grad():
  231. mask = x.new_ones(shape)
  232. return fused_bias_gated_dropout_add(
  233. x,
  234. bias,
  235. g,
  236. g_bias,
  237. residual,
  238. mask,
  239. prob,
  240. )
  241. elif block_size is None:
  242. x, g = outputs
  243. bias, g_bias = module.get_output_bias()
  244. residual += (torch.sigmoid(g + g_bias) * (x + bias))
  245. return residual
  246. else:
  247. # gated is not used here
  248. residual += outputs
  249. return residual
  250. class SimpleModuleList(nn.ModuleList):
  251. def __repr__(self):
  252. return str(len(self)) + ' X ...\n' + self[0].__repr__()
  253. def chunk_layer(
  254. layer: Callable,
  255. inputs: Dict[str, Any],
  256. chunk_size: int,
  257. num_batch_dims: int,
  258. ) -> Any:
  259. # TODO: support inplace add to output
  260. if not (len(inputs) > 0):
  261. raise ValueError('Must provide at least one input')
  262. def _dict_get_shapes(input):
  263. shapes = []
  264. if type(input) is torch.Tensor:
  265. shapes.append(input.shape)
  266. elif type(input) is dict:
  267. for v in input.values():
  268. shapes.extend(_dict_get_shapes(v))
  269. elif isinstance(input, Iterable):
  270. for v in input:
  271. shapes.extend(_dict_get_shapes(v))
  272. else:
  273. raise ValueError('Not supported')
  274. return shapes
  275. inputs = {k: v for k, v in inputs.items() if v is not None}
  276. initial_dims = [
  277. shape[:num_batch_dims] for shape in _dict_get_shapes(inputs)
  278. ]
  279. orig_batch_dims = tuple([max(s) for s in zip(*initial_dims)])
  280. flat_batch_dim = 1
  281. for d in orig_batch_dims:
  282. flat_batch_dim *= d
  283. num_chunks = (flat_batch_dim + chunk_size - 1) // chunk_size
  284. def _flat_inputs(t):
  285. t = t.view(-1, *t.shape[num_batch_dims:])
  286. assert (
  287. t.shape[0] == flat_batch_dim or t.shape[0] == 1
  288. ), 'batch dimension must be 1 or equal to the flat batch dimension'
  289. return t
  290. flat_inputs = tensor_tree_map(_flat_inputs, inputs)
  291. out = None
  292. for i in range(num_chunks):
  293. chunk_start = i * chunk_size
  294. chunk_end = min((i + 1) * chunk_size, flat_batch_dim)
  295. def select_chunk(t):
  296. if t.shape[0] == 1:
  297. return t[0:1]
  298. else:
  299. return t[chunk_start:chunk_end]
  300. chunkes = tensor_tree_map(select_chunk, flat_inputs)
  301. output_chunk = layer(**chunkes)
  302. if out is None:
  303. out = tensor_tree_map(
  304. lambda t: t.new_zeros((flat_batch_dim, ) + t.shape[1:]),
  305. output_chunk)
  306. out_type = type(output_chunk)
  307. if out_type is tuple:
  308. for x, y in zip(out, output_chunk):
  309. x[chunk_start:chunk_end] = y
  310. elif out_type is torch.Tensor:
  311. out[chunk_start:chunk_end] = output_chunk
  312. else:
  313. raise ValueError('Not supported')
  314. # reshape = lambda t: t.view(orig_batch_dims + t.shape[1:])
  315. def reshape(t):
  316. return t.view(orig_batch_dims + t.shape[1:])
  317. out = tensor_tree_map(reshape, out)
  318. return out