attention_visualizer.py 9.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251
  1. # Copyright 2025 The HuggingFace Inc. team.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import requests
  15. from PIL import Image
  16. from ..masking_utils import create_causal_mask
  17. from ..models.auto.auto_factory import _get_model_class
  18. from ..models.auto.configuration_auto import AutoConfig
  19. from ..models.auto.modeling_auto import MODEL_FOR_PRETRAINING_MAPPING, MODEL_MAPPING
  20. from ..models.auto.processing_auto import PROCESSOR_MAPPING_NAMES, AutoProcessor
  21. from ..models.auto.tokenization_auto import TOKENIZER_MAPPING_NAMES, AutoTokenizer
  22. from .import_utils import is_torch_available
  23. if is_torch_available():
  24. import torch
  25. import torch.nn as nn
  26. # Print the matrix with words as row labels
  27. GREEN = "\033[92m"
  28. YELLOW = "\033[93m"
  29. RESET = "\033[0m"
  30. BLACK_SQUARE = "■"
  31. WHITE_SQUARE = "⬚"
  32. def generate_attention_matrix_from_mask(
  33. words, mask, img_token="<img>", sliding_window=None, token_type_ids=None, image_seq_length=None
  34. ):
  35. """
  36. Generates an attention matrix from a given attention mask.
  37. Optionally applies a sliding window mask (e.g., for Gemma2/3) and
  38. marks regions where image tokens occur based on the specified `img_token`.
  39. """
  40. mask = mask.int()
  41. if mask.ndim == 3:
  42. mask = mask[0, :, :]
  43. if mask.ndim == 4:
  44. mask = mask[0, 0, :, :]
  45. n = len(words)
  46. max_word_length = max(len(repr(word)) for word in words)
  47. first_img_idx = 0
  48. output = []
  49. for i, k in enumerate(words):
  50. if k == img_token and not first_img_idx:
  51. first_img_idx = i
  52. mask[i, i] = 2 # Mark yellow regions
  53. if first_img_idx > 0 and (k != img_token or i == n - 1):
  54. if i == n - 1:
  55. i += 1
  56. mask[first_img_idx:i, first_img_idx:i] = 2 # Mark yellow regions
  57. first_img_idx = 0
  58. # Generate sliding window mask (size = 4), excluding img_token
  59. sliding_window_mask = None
  60. if sliding_window is not None:
  61. sliding_window_mask = [[1 if (0 <= i - j < sliding_window) else 0 for j in range(n)] for i in range(n)]
  62. row_dummy = " ".join(
  63. f"{YELLOW}{BLACK_SQUARE}{RESET}"
  64. if mask[0, j]
  65. else f"{GREEN}{BLACK_SQUARE}{RESET}"
  66. if 0 == j
  67. else BLACK_SQUARE
  68. if mask[0, j]
  69. else WHITE_SQUARE
  70. for j in range(n)
  71. )
  72. if token_type_ids is not None:
  73. is_special = token_type_ids == 1
  74. token_type_buckets = torch.where(
  75. (token_type_ids.cumsum(-1) % 5 + is_special).bool(), token_type_ids.cumsum(-1), 0
  76. )
  77. boundaries = torch.arange(0, image_seq_length + 1, image_seq_length)
  78. token_type_buckets = torch.bucketize(token_type_buckets, boundaries=boundaries)
  79. # Print headers
  80. legend = f"{GREEN}{BLACK_SQUARE}{RESET}: i == j (diagonal) {YELLOW}{BLACK_SQUARE}{RESET}: token_type_ids"
  81. output.append(" " + legend)
  82. f_string = " " * (max_word_length + 5) + "Attention Matrix".ljust(len(row_dummy) // 2)
  83. if sliding_window is not None:
  84. f_string += "Sliding Window Mask"
  85. output.append(f_string)
  86. vertical_header = []
  87. for idx, word in enumerate(words):
  88. if mask[idx, idx] == 2:
  89. vertical_header.append([f"{YELLOW}{k}{RESET}" for k in list(str(idx).rjust(len(str(n))))])
  90. else:
  91. vertical_header.append(list(str(idx).rjust(len(str(n)))))
  92. vertical_header = list(map(list, zip(*vertical_header))) # Transpose
  93. for row in vertical_header:
  94. output.append(
  95. (max_word_length + 5) * " " + " ".join(row) + " | " + " ".join(row)
  96. if sliding_window is not None
  97. else ""
  98. )
  99. for i, word in enumerate(words):
  100. word_repr = repr(word).ljust(max_word_length)
  101. colored_word = f"{YELLOW}{word_repr}{RESET}" if img_token in word else word_repr
  102. row_display = " ".join(
  103. f"{YELLOW}{BLACK_SQUARE}{RESET}"
  104. if img_token in words[j] and mask[i, j] and img_token in word
  105. else f"{GREEN}{BLACK_SQUARE}{RESET}"
  106. if i == j
  107. else BLACK_SQUARE
  108. if mask[i, j]
  109. else WHITE_SQUARE
  110. for j in range(n)
  111. )
  112. sliding_window_row = ""
  113. if sliding_window is not None:
  114. sliding_window_row = " ".join(
  115. f"{YELLOW}{BLACK_SQUARE}{RESET}"
  116. if img_token in words[j] and img_token in word and token_type_buckets[0, i] == token_type_buckets[0, j]
  117. else f"{GREEN}{BLACK_SQUARE}{RESET}"
  118. if i == j
  119. else BLACK_SQUARE
  120. if sliding_window_mask[i][j]
  121. else WHITE_SQUARE
  122. for j in range(n)
  123. )
  124. output.append(f"{colored_word}: {str(i).rjust(2)} {row_display} | {sliding_window_row}")
  125. return "\n".join(output)
  126. class AttentionMaskVisualizer:
  127. def __init__(self, model_name: str):
  128. config = AutoConfig.from_pretrained(model_name)
  129. self.image_token = "<img>"
  130. if hasattr(config.get_text_config(), "sliding_window"):
  131. self.sliding_window = getattr(config.get_text_config(), "sliding_window", None)
  132. try:
  133. mapped_cls = _get_model_class(config, MODEL_MAPPING)
  134. except Exception:
  135. mapped_cls = _get_model_class(config, MODEL_FOR_PRETRAINING_MAPPING)
  136. if mapped_cls is None:
  137. raise ValueError(f"Model name {model_name} is not supported for attention visualization")
  138. self.mapped_cls = mapped_cls
  139. class _ModelWrapper(mapped_cls, nn.Module):
  140. def __init__(self, config, model_name):
  141. nn.Module.__init__(self)
  142. self.dummy_module = nn.Linear(1, 1)
  143. self.config = config
  144. self.model = _ModelWrapper(config, model_name)
  145. self.model.to(config.dtype)
  146. self.repo_id = model_name
  147. self.config = config
  148. def __call__(self, input_sentence: str, suffix=""):
  149. self.visualize_attention_mask(input_sentence, suffix=suffix)
  150. def visualize_attention_mask(self, input_sentence: str, suffix=""):
  151. model = self.model
  152. kwargs = {}
  153. image_seq_length = None
  154. if self.config.model_type in PROCESSOR_MAPPING_NAMES:
  155. img = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/bee.jpg?download=true"
  156. img = Image.open(requests.get(img, stream=True).raw)
  157. image_seq_length = 5
  158. processor = AutoProcessor.from_pretrained(self.repo_id, image_seq_length=image_seq_length)
  159. if hasattr(processor, "image_token"):
  160. image_token = processor.image_token
  161. else:
  162. image_token = processor.tokenizer.convert_ids_to_tokens([processor.image_token_id])[0]
  163. if image_token:
  164. input_sentence = input_sentence.replace("<img>", image_token)
  165. inputs = processor(images=img, text=input_sentence, suffix=suffix, return_tensors="pt")
  166. self.image_token = processor.tokenizer.convert_ids_to_tokens([processor.image_token_id])[0]
  167. attention_mask = inputs["attention_mask"]
  168. if "token_type_ids" in inputs: # TODO inspect signature of update causal mask
  169. kwargs["token_type_ids"] = inputs["token_type_ids"]
  170. tokens = processor.tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
  171. elif self.config.model_type in TOKENIZER_MAPPING_NAMES:
  172. tokenizer = AutoTokenizer.from_pretrained(self.repo_id)
  173. tokens = tokenizer.tokenize(input_sentence)
  174. attention_mask = tokenizer(input_sentence, return_tensors="pt")["attention_mask"]
  175. else:
  176. raise ValueError(f"Model type {model.config.model_type} does not support attention visualization")
  177. model.config._attn_implementation = "eager"
  178. model.train()
  179. batch_size, seq_length = attention_mask.shape
  180. input_embeds = torch.zeros((batch_size, seq_length, model.config.hidden_size), dtype=self.model.dtype)
  181. cache_position = torch.arange(seq_length)
  182. causal_mask = create_causal_mask(
  183. config=model.config,
  184. input_embeds=input_embeds,
  185. attention_mask=attention_mask,
  186. cache_position=cache_position,
  187. past_key_values=None,
  188. )
  189. if causal_mask is not None:
  190. attention_mask = ~causal_mask.bool()
  191. else:
  192. attention_mask = attention_mask.unsqueeze(1).unsqueeze(1).expand(batch_size, 1, seq_length, seq_length)
  193. top_bottom_border = "##" * (
  194. len(f"Attention visualization for {self.config.model_type} | {self.mapped_cls}") + 4
  195. ) # Box width adjusted to text length
  196. side_border = "##"
  197. print(f"\n{top_bottom_border}")
  198. print(
  199. "##"
  200. + f" Attention visualization for \033[1m{self.config.model_type}:{self.repo_id}\033[0m {self.mapped_cls.__name__}".center(
  201. len(top_bottom_border)
  202. )
  203. + " "
  204. + side_border,
  205. )
  206. print(f"{top_bottom_border}")
  207. f_string = generate_attention_matrix_from_mask(
  208. tokens,
  209. attention_mask,
  210. img_token=self.image_token,
  211. sliding_window=getattr(self.config, "sliding_window", None),
  212. token_type_ids=kwargs.get("token_type_ids"),
  213. image_seq_length=image_seq_length,
  214. )
  215. print(f_string)
  216. print(f"{top_bottom_border}")