masactrl.py 3.2 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677
  1. # ------------------------------------------------------------------------
  2. # Modified from https://github.com/TencentARC/MasaCtrl/blob/main/masactrl/masactrl.py
  3. # Copyright (c) 2023 TencentARC. All Rights Reserved.
  4. # ------------------------------------------------------------------------
  5. import torch
  6. from einops import rearrange
  7. from .masactrl_utils import AttentionBase
  8. class MutualSelfAttentionControl(AttentionBase):
  9. def __init__(self,
  10. start_step=4,
  11. start_layer=10,
  12. layer_idx=None,
  13. step_idx=None,
  14. total_steps=50):
  15. """
  16. Mutual self-attention control for Stable-Diffusion model
  17. Args:
  18. start_step: the step to start mutual self-attention control
  19. start_layer: the layer to start mutual self-attention control
  20. layer_idx: list of the layers to apply mutual self-attention control
  21. step_idx: list the steps to apply mutual self-attention control
  22. total_steps: the total number of steps
  23. """
  24. super().__init__()
  25. self.total_steps = total_steps
  26. self.start_step = start_step
  27. self.start_layer = start_layer
  28. self.layer_idx = layer_idx if layer_idx is not None else list(
  29. range(start_layer, 16))
  30. self.step_idx = step_idx if step_idx is not None else list(
  31. range(start_step, total_steps)) # denoise index
  32. print('step_idx: ', self.step_idx)
  33. print('layer_idx: ', self.layer_idx)
  34. def attn_batch(self, q, k, v, sim, attn, is_cross, place_in_unet,
  35. num_heads, **kwargs):
  36. b = q.shape[0] // num_heads
  37. q = rearrange(q, '(b h) n d -> h (b n) d', h=num_heads)
  38. k = rearrange(k, '(b h) n d -> h (b n) d', h=num_heads)
  39. v = rearrange(v, '(b h) n d -> h (b n) d', h=num_heads)
  40. sim = torch.einsum('h i d, h j d -> h i j', q, k) * kwargs.get('scale')
  41. attn = sim.softmax(-1)
  42. out = torch.einsum('h i j, h j d -> h i d', attn, v)
  43. out = rearrange(out, 'h (b n) d -> b n (h d)', b=b)
  44. return out
  45. def forward(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads,
  46. **kwargs):
  47. """
  48. Attention forward function
  49. """
  50. if is_cross or self.cur_step not in self.step_idx or self.cur_att_layer // 2 not in self.layer_idx:
  51. return super().forward(q, k, v, sim, attn, is_cross, place_in_unet,
  52. num_heads, **kwargs)
  53. qu, qc = q.chunk(2) # uncond, cond
  54. ku, kc = k.chunk(2)
  55. vu, vc = v.chunk(2)
  56. attnu, attnc = attn.chunk(2)
  57. # uncond
  58. # ku[:num_heads], vu[:num_heads] -> source
  59. # qu -> [source, target]
  60. out_u = self.attn_batch(qu, ku[:num_heads], vu[:num_heads],
  61. sim[:num_heads], attnu, is_cross,
  62. place_in_unet, num_heads, **kwargs)
  63. out_c = self.attn_batch(qc, kc[:num_heads], vc[:num_heads],
  64. sim[:num_heads], attnc, is_cross,
  65. place_in_unet, num_heads, **kwargs)
  66. out = torch.cat([out_u, out_c], dim=0)
  67. return out