vim.py 6.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187
  1. # Part of this code is adopted from PETL-ViT,
  2. # made publicly available under the MIT License at https://github.com/JieShibo/PETL-ViT
  3. import math
  4. import torch
  5. import torch.nn as nn
  6. import torch.nn.functional as F
  7. from einops import rearrange
  8. def _agg_conv1d(weight_list, bias_list, agg, x):
  9. """
  10. weight list: list of conv1d weight ([out, in] * a)
  11. bias list: list of conv1d bias ([out] * a)
  12. agg: aggreagtion weights (a)
  13. x: input tensor (b, in, n)
  14. return output in (b, n, out)
  15. """
  16. weight_list = torch.cat([w.unsqueeze(0) for w in weight_list],
  17. dim=0) # n_ada, out, in
  18. weight = torch.sum(
  19. weight_list * rearrange(agg, 'a -> a 1 1'),
  20. dim=0).unsqueeze(2) # out, in, 1
  21. bias_list = torch.cat([w.unsqueeze(0) for w in bias_list],
  22. dim=0) # n_ada, out
  23. bias = torch.sum(bias_list * rearrange(agg, 'a -> a 1'), dim=0) # out
  24. x = F.conv1d(x, weight=weight, bias=bias)
  25. return x
  26. def _agg_conv2d(weight_list, bias_list, agg, x):
  27. """
  28. weight list: list of conv2d weight ([out, in, m, n] * a)
  29. bias list: list of conv2d bias ([out] * a)
  30. agg: aggregation weights (a)
  31. x: input tensor (b, in, p, q)
  32. return output in (b, out, p, q)
  33. """
  34. weight_list = torch.cat([w.unsqueeze(0) for w in weight_list],
  35. dim=0) # n_ada, out, in, m, n
  36. weight = torch.sum(
  37. weight_list * rearrange(agg, 'a -> a 1 1 1 1'), dim=0) # out, in, m, n
  38. bias_list = torch.cat([w.unsqueeze(0) for w in bias_list],
  39. dim=0) # n_ada, out
  40. bias = torch.sum(bias_list * rearrange(agg, 'a -> a 1'), dim=0) # out
  41. x = F.conv2d(
  42. x, weight=weight, bias=bias, stride=1, padding=1) # 1 (b out) p q
  43. return x
  44. class QuickGELU(nn.Module):
  45. def forward(self, x: torch.Tensor):
  46. return x * torch.sigmoid(1.702 * x)
  47. class ViM(nn.Module):
  48. def __init__(self):
  49. super().__init__()
  50. self.act = QuickGELU()
  51. self.adapter_conv_weight = nn.ParameterList()
  52. self.adapter_conv_bias = nn.ParameterList()
  53. self.adapter_up_weight = nn.ParameterList()
  54. self.adapter_up_bias = nn.ParameterList()
  55. self.adapter_down_weight = nn.ParameterList()
  56. self.adapter_down_bias = nn.ParameterList()
  57. # agg related
  58. self.num_modules = 0
  59. self.task_list = []
  60. self.agg_weights = {}
  61. self.agg_algos = {}
  62. def register_ViM(self, vim_list):
  63. self.num_modules = len(vim_list)
  64. for state_dict in vim_list:
  65. self.adapter_conv_weight.append(
  66. nn.Parameter(state_dict['adapter_conv.weight']))
  67. self.adapter_conv_bias.append(
  68. nn.Parameter(state_dict['adapter_conv.bias']))
  69. self.adapter_up_weight.append(
  70. nn.Parameter(state_dict['adapter_up.weight']))
  71. self.adapter_up_bias.append(
  72. nn.Parameter(state_dict['adapter_up.bias']))
  73. self.adapter_down_weight.append(
  74. nn.Parameter(state_dict['adapter_down.weight']))
  75. self.adapter_down_bias.append(
  76. nn.Parameter(state_dict['adapter_down.bias']))
  77. def register_task(self, task_name, agg_weights, agg_algo):
  78. assert agg_weights.shape[0] == self.num_modules
  79. self.task_list.append(task_name)
  80. self.agg_weights[task_name] = agg_weights
  81. self.agg_algos[task_name] = agg_algo
  82. def forward(self, x, task_name):
  83. assert task_name in self.task_list
  84. agg_algo = self.agg_algos[task_name]
  85. if agg_algo == 'Ens-MoE':
  86. return self.forward_ens_moe(x, self.agg_weights[task_name])
  87. else:
  88. raise NotImplementedError(
  89. 'Aggregation algorithm [{}] is currently not supported!'.
  90. format(agg_algo))
  91. def forward_ens_moe(self, x, agg):
  92. logits = agg
  93. k = agg.shape[0] # MoE-full (k=N)
  94. top_logits, top_indices = logits.topk(
  95. min(k + 1, logits.size(0)), dim=0)
  96. top_k_logits = top_logits[:k]
  97. top_k_indices = top_indices[:k]
  98. top_k_gates = F.softmax(top_k_logits, dim=0)
  99. zeros = torch.zeros_like(logits, requires_grad=True)
  100. gates = zeros.scatter(0, top_k_indices, top_k_gates)
  101. N, B, C = x.shape
  102. x = x.permute(1, 2, 0)
  103. output = None
  104. for i in range(self.num_modules):
  105. if gates[i] > 0:
  106. x_down = F.conv1d(
  107. x,
  108. weight=self.adapter_down_weight[i].unsqueeze(2),
  109. bias=self.adapter_down_bias[i]) # equivalent to 1 * 1 Conv
  110. x_down = self.act(x_down)
  111. num_patch_side = int(math.sqrt(x_down.size(2) - 1))
  112. x_patch = x_down[:, :,
  113. 1:].reshape(B, -1, num_patch_side,
  114. num_patch_side) # b, in, p, p
  115. x_patch = F.conv2d(
  116. x_patch,
  117. weight=self.adapter_conv_weight[i],
  118. bias=self.adapter_conv_bias[i],
  119. stride=1,
  120. padding=1)
  121. x_patch = rearrange(x_patch, 'b o p q -> b o (p q)')
  122. x_cls = x_down[:, :, :1].reshape(B, -1, 1, 1)
  123. x_cls = F.conv2d(
  124. x_cls,
  125. weight=self.adapter_conv_weight[i],
  126. bias=self.adapter_conv_bias[i],
  127. stride=1,
  128. padding=1)
  129. x_cls = rearrange(x_cls, 'b o 1 1 -> b o 1')
  130. x_down = torch.cat([x_cls, x_patch], dim=2)
  131. x_down = self.act(x_down)
  132. x_up = F.conv1d(
  133. x_down,
  134. weight=self.adapter_up_weight[i].unsqueeze(2),
  135. bias=self.adapter_up_bias[i]) # equivalent to 1 * 1 Conv
  136. if output is None:
  137. output = x_up * gates[i]
  138. else:
  139. output += x_up * gates[i]
  140. return output.permute(2, 0, 1)