modeling_rope_utils.py 42 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773
  1. # Copyright 2024 The HuggingFace Team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import math
  15. from functools import wraps
  16. from typing import Optional
  17. from .configuration_utils import PretrainedConfig
  18. from .utils import is_torch_available, logging
  19. logger = logging.get_logger(__name__)
  20. if is_torch_available():
  21. import torch
  22. def dynamic_rope_update(rope_forward):
  23. """
  24. Decorator function to update the RoPE parameters in the forward pass, if the model is using a dynamic RoPE
  25. (i.e. a RoPE implementation that may recompute its frequencies in the forward pass).
  26. Args:
  27. rope_forward (Callable):
  28. The forward pass of the RoPE implementation.
  29. Returns:
  30. The decorated forward pass.
  31. """
  32. def longrope_frequency_update(self, position_ids, device):
  33. """Longrope uses long factor if sequence is larger than original pretraining length, short otherwise."""
  34. seq_len = torch.max(position_ids) + 1
  35. if hasattr(self.config, "original_max_position_embeddings"):
  36. original_max_position_embeddings = self.config.original_max_position_embeddings
  37. else:
  38. original_max_position_embeddings = self.config.max_position_embeddings
  39. if seq_len > original_max_position_embeddings:
  40. if not hasattr(self, "long_inv_freq"):
  41. self.long_inv_freq, _ = self.rope_init_fn(
  42. self.config, device, seq_len=original_max_position_embeddings + 1
  43. )
  44. self.register_buffer("inv_freq", self.long_inv_freq, persistent=False)
  45. else:
  46. # This .to() is needed if the model has been moved to a device after being initialized (because
  47. # the buffer is automatically moved, but not the original copy)
  48. self.original_inv_freq = self.original_inv_freq.to(device)
  49. self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
  50. def dynamic_frequency_update(self, position_ids, device):
  51. """
  52. dynamic RoPE layers should recompute `inv_freq` in the following situations:
  53. 1 - growing beyond the cached sequence length (allow scaling)
  54. 2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
  55. """
  56. seq_len = torch.max(position_ids) + 1
  57. if seq_len > self.max_seq_len_cached: # growth
  58. inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len)
  59. self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
  60. self.max_seq_len_cached = seq_len
  61. if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
  62. # This .to() is needed if the model has been moved to a device after being initialized (because
  63. # the buffer is automatically moved, but not the original copy)
  64. self.original_inv_freq = self.original_inv_freq.to(device)
  65. self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
  66. self.max_seq_len_cached = self.original_max_seq_len
  67. @wraps(rope_forward)
  68. def wrapper(self, x, position_ids):
  69. if "dynamic" in self.rope_type:
  70. dynamic_frequency_update(self, position_ids, device=x.device)
  71. elif self.rope_type == "longrope":
  72. longrope_frequency_update(self, position_ids, device=x.device)
  73. return rope_forward(self, x, position_ids)
  74. return wrapper
  75. def _compute_default_rope_parameters(
  76. config: Optional[PretrainedConfig] = None,
  77. device: Optional["torch.device"] = None,
  78. seq_len: Optional[int] = None,
  79. ) -> tuple["torch.Tensor", float]:
  80. """
  81. Computes the inverse frequencies according to the original RoPE implementation
  82. Args:
  83. config ([`~transformers.PretrainedConfig`]):
  84. The model configuration. This function assumes that the config will provide at least the following
  85. properties:
  86. * rope_theta (`float`): The base wavelength from which the inverse frequencies will be derived.
  87. * hidden_size (`int`): The numerator when deriving a head_dim, if not provided directly.
  88. * num_attention_heads (`int`): The denominator when deriving a head_dim, if not provided directly.
  89. Additionally, this function will make use of the following properties if they are found in the config:
  90. * head_dim (`int`, *optional*): The size of the key-value heads in the model. If None, this value will be
  91. derived as hidden_size // num_attention_heads.
  92. * partial_rotary_factor (`float`, *optional*): If less than 1.0, inverse frequencies will be returned for
  93. the first fraction of the head_dim. Defaults to 1.0.
  94. device (`torch.device`):
  95. The device to use for initialization of the inverse frequencies.
  96. seq_len (`int`, *optional*):
  97. The current sequence length. Unused for this type of RoPE.
  98. Returns:
  99. Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
  100. post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
  101. """
  102. base = config.rope_theta
  103. partial_rotary_factor = getattr(config, "partial_rotary_factor", 1.0)
  104. head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
  105. dim = int(head_dim * partial_rotary_factor)
  106. attention_factor = 1.0 # Unused in this type of RoPE
  107. # Compute the inverse frequencies
  108. inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim))
  109. return inv_freq, attention_factor
  110. def _compute_linear_scaling_rope_parameters(
  111. config: Optional[PretrainedConfig] = None,
  112. device: Optional["torch.device"] = None,
  113. seq_len: Optional[int] = None,
  114. ) -> tuple["torch.Tensor", float]:
  115. """
  116. Computes the inverse frequencies with linear scaling. Credits to the Reddit user /u/kaiokendev
  117. Args:
  118. config ([`~transformers.PretrainedConfig`]):
  119. The model configuration. This function assumes that the config will provide at least the following
  120. properties:
  121. * rope_theta (`float`): The base wavelength from which the inverse frequencies will be derived.
  122. * hidden_size (`int`): The numerator when deriving a head_dim, if not provided directly.
  123. * num_attention_heads (`int`): The denominator when deriving a head_dim, if not provided directly.
  124. Additionally, this function will make use of the following properties if they are found in the config:
  125. * head_dim (`int`, *optional*): The size of the key-value heads in the model. If None, this value will be
  126. derived as hidden_size // num_attention_heads.
  127. * partial_rotary_factor (`float`, *optional*): If less than 1.0, inverse frequencies will be returned for
  128. the first fraction of the head_dim. Defaults to 1.0.
  129. device (`torch.device`):
  130. The device to use for initialization of the inverse frequencies.
  131. seq_len (`int`, *optional*):
  132. The current sequence length. Unused for this type of RoPE.
  133. Returns:
  134. Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
  135. post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
  136. """
  137. factor = config.rope_scaling["factor"]
  138. # Gets the default RoPE parameters
  139. inv_freq, attention_factor = _compute_default_rope_parameters(config, device, seq_len)
  140. # Then applies linear scaling to the frequencies.
  141. # NOTE: originally, scaling was applied to the position_ids. However, we get `embs = inv_freq @ position_ids`, so
  142. # applying scaling to the inverse frequencies is equivalent.
  143. inv_freq /= factor
  144. return inv_freq, attention_factor
  145. def _compute_dynamic_ntk_parameters(
  146. config: Optional[PretrainedConfig] = None,
  147. device: Optional["torch.device"] = None,
  148. seq_len: Optional[int] = None,
  149. ) -> tuple["torch.Tensor", float]:
  150. """
  151. Computes the inverse frequencies with NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla
  152. Args:
  153. config ([`~transformers.PretrainedConfig`]):
  154. The model configuration. This function assumes that the config will provide at least the following
  155. properties:
  156. * rope_theta (`float`): The base wavelength from which the inverse frequencies will be derived.
  157. * hidden_size (`int`): The numerator when deriving a head_dim, if not provided directly.
  158. * num_attention_heads (`int`): The denominator when deriving a head_dim, if not provided directly.
  159. * max_position_embeddings (`int`): The default sequence length used to update the dynamic RoPE at
  160. inference time
  161. * rope_scaling (`dict[str, float]`): The standard RoPE scaling parameters, from which `factor`
  162. will be accessed. The value of `factor` is used to determine the new base frequency, along with the
  163. current sequence length (seq_len), the maximum positional embeddings (max_position_embeddings), and the
  164. computed dimensionality (dim) of the rotary embeddings. If seq_len <= max_position_embeddings, this
  165. factor has no effect. If seq_len <= max_position_embeddings, this factor effectively stretches the
  166. context window using an exponent derived from `dim`.
  167. Additionally, this function will make use of the following properties if they are found in the config:
  168. * head_dim (`int`, *optional*): The size of the key-value heads in the model. If None, this value will be
  169. derived as hidden_size // num_attention_heads.
  170. * partial_rotary_factor (`float`, *optional*): If less than 1.0, inverse frequencies will be returned for
  171. the first fraction of the head_dim. Defaults to 1.0.
  172. device (`torch.device`):
  173. The device to use for initialization of the inverse frequencies.
  174. seq_len (`int`, *optional*):
  175. The current sequence length, used to update the dynamic RoPE at inference time. If `None` or shorter than
  176. max_position_embeddings, this value will be overridden by max_position_embeddings.
  177. Returns:
  178. Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
  179. post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
  180. """
  181. # TODO (joao): use the new `original_max_position_embeddings` from rope_scaling
  182. base = config.rope_theta
  183. partial_rotary_factor = getattr(config, "partial_rotary_factor", 1.0)
  184. head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
  185. dim = int(head_dim * partial_rotary_factor)
  186. max_position_embeddings = config.max_position_embeddings
  187. factor = config.rope_scaling["factor"]
  188. attention_factor = 1.0 # Unused in this type of RoPE
  189. # seq_len: default to max_position_embeddings, e.g. at init time
  190. if seq_len is None:
  191. seq_len = max_position_embeddings
  192. elif isinstance(seq_len, torch.Tensor):
  193. seq_len = torch.maximum(
  194. seq_len,
  195. torch.tensor(max_position_embeddings, dtype=seq_len.dtype, device=seq_len.device),
  196. )
  197. else:
  198. seq_len = max(seq_len, max_position_embeddings)
  199. # Compute the inverse frequencies
  200. base = base * ((factor * seq_len / max_position_embeddings) - (factor - 1)) ** (dim / (dim - 2))
  201. inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim))
  202. return inv_freq, attention_factor
  203. def _compute_yarn_parameters(
  204. config: PretrainedConfig, device: "torch.device", seq_len: Optional[int] = None
  205. ) -> tuple["torch.Tensor", float]:
  206. """
  207. Computes the inverse frequencies with NTK scaling. Please refer to the
  208. [original paper](https://huggingface.co/papers/2309.00071)
  209. Args:
  210. config ([`~transformers.PretrainedConfig`]):
  211. The model configuration. This function assumes that the config will provide at least the following
  212. properties:
  213. * rope_theta (`float`): The base wavelength from which the inverse frequencies will be derived.
  214. * hidden_size (`int`): The numerator when deriving a head_dim, if not provided directly.
  215. * num_attention_heads (`int`): The denominator when deriving a head_dim, if not provided directly.
  216. * max_position_embeddings (`int`): The maximum length of the positional embeddings.
  217. * rope_scaling (`dict[str, float | int]`): The standard RoPE scaling parameters, from which the following
  218. keys will be accessed:
  219. * `attention_factor` (`float`, *optional*): The scaling factor to be applied to the computed cos/sin.
  220. If None, the value is inferred from `factor`, `mscale`, and `mscale_all_dim` as avaialble.
  221. * `beta_fast` (`float`, *optional*, defaults to 32): Parameter to set the boundary for extrapolation
  222. (only) in the linear ramp function.
  223. * `beta_slow` (`float`, *optional*, defaults to 1): Parameter to set the boundary for interpolation
  224. (only) in the linear ramp function.
  225. * `factor` (`float`, *optional*): The scaling factor applied when interpolating the position IDs to
  226. extend the possible context length. Additionally, if `attention_factor` is None, the log of this
  227. value is used to compute a value for `attention_factor`, possibly in conjunciton with `mscale` and
  228. `mscale_all_dim`, if provided.
  229. * `mscale` (`float`, *optional*): If `attention_factor` is None and both `mscale` and
  230. `mscale_all_dim` are provided, `mscale` acts scalar augmenting `log(factor)` when computing the
  231. numerator for the inferred value of `attention_factor`. If not provided, `attention_factor` will be
  232. calculated based on `factor` only.
  233. * `mscale_all_dim` (`float`, *optional*): If `attention_factor` is None and both `mscale` and
  234. `mscale_all_dim` are provided, `mscale_all_dim` acts scalar augmenting `log(factor)` when computing
  235. the denominator for the inferred value of `attention_factor`. If not provided, `attention_factor`
  236. will be calculated based on `factor` only.
  237. * `original_max_position_embeddings` (`int`, *optional*): The original max position embeddings used
  238. during pretraining. If not provided, the function falls back to `max_position_embeddings`.
  239. * `truncate` (`bool`, *optional*): Whether to truncate the correction range.
  240. Additionally, this function will make use of the following properties if they are found in the config:
  241. * head_dim (`int`, *optional*): The size of the key-value heads in the model. If None, this value will be
  242. derived as hidden_size // num_attention_heads.
  243. * partial_rotary_factor (`float`, *optional*, defaults to 1.0): If less than 1.0, inverse frequencies
  244. will be returned for the first fraction of the head_dim.
  245. device (`torch.device`):
  246. The device to use for initialization of the inverse frequencies.
  247. seq_len (`int`, *optional*):
  248. The current sequence length. Unused for this type of RoPE.
  249. Returns:
  250. Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
  251. post-processing scaling factor applied to the computed cos/sin.
  252. """
  253. base = config.rope_theta
  254. partial_rotary_factor = getattr(config, "partial_rotary_factor", 1.0)
  255. head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
  256. dim = int(head_dim * partial_rotary_factor)
  257. factor = config.rope_scaling["factor"]
  258. attention_factor = config.rope_scaling.get("attention_factor")
  259. mscale = config.rope_scaling.get("mscale")
  260. mscale_all_dim = config.rope_scaling.get("mscale_all_dim")
  261. original_max_position_embeddings = (
  262. config.rope_scaling.get("original_max_position_embeddings") or config.max_position_embeddings
  263. )
  264. def get_mscale(scale, mscale=1):
  265. if scale <= 1:
  266. return 1.0
  267. return 0.1 * mscale * math.log(scale) + 1.0
  268. # Sets the attention factor as suggested in the paper
  269. if attention_factor is None:
  270. if mscale and mscale_all_dim:
  271. attention_factor = float(get_mscale(factor, mscale) / get_mscale(factor, mscale_all_dim))
  272. else:
  273. attention_factor = get_mscale(factor)
  274. # Optional config options
  275. # beta_fast/beta_slow: as suggested in the paper, default to 32 and 1 respectively
  276. beta_fast = config.rope_scaling.get("beta_fast") or 32
  277. beta_slow = config.rope_scaling.get("beta_slow") or 1
  278. # Compute the inverse frequencies
  279. def find_correction_dim(num_rotations, dim, base, max_position_embeddings):
  280. """Inverse dimension formula to find the dimension based on the number of rotations"""
  281. return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (2 * math.log(base))
  282. def find_correction_range(low_rot, high_rot, dim, base, max_position_embeddings, truncate):
  283. """Find dimension range bounds based on rotations"""
  284. low = find_correction_dim(low_rot, dim, base, max_position_embeddings)
  285. high = find_correction_dim(high_rot, dim, base, max_position_embeddings)
  286. if truncate:
  287. low = math.floor(low)
  288. high = math.ceil(high)
  289. return max(low, 0), min(high, dim - 1)
  290. def linear_ramp_factor(min, max, dim):
  291. if min == max:
  292. max += 0.001 # Prevent singularity
  293. linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min)
  294. ramp_func = torch.clamp(linear_func, 0, 1)
  295. return ramp_func
  296. # Note on variable naming: "interpolation" comes from the original technique, where we interpolate the position IDs
  297. # to expand the possible context length. In other words, interpolation = apply scaling factor.
  298. pos_freqs = base ** (torch.arange(0, dim, 2).to(device=device, dtype=torch.float) / dim)
  299. inv_freq_extrapolation = 1.0 / pos_freqs
  300. inv_freq_interpolation = 1.0 / (factor * pos_freqs)
  301. truncate = config.rope_scaling.get("truncate", True)
  302. low, high = find_correction_range(beta_fast, beta_slow, dim, base, original_max_position_embeddings, truncate)
  303. # Get n-dimensional rotational scaling corrected for extrapolation
  304. inv_freq_extrapolation_factor = 1 - linear_ramp_factor(low, high, dim // 2).to(device=device, dtype=torch.float)
  305. inv_freq = (
  306. inv_freq_interpolation * (1 - inv_freq_extrapolation_factor)
  307. + inv_freq_extrapolation * inv_freq_extrapolation_factor
  308. )
  309. return inv_freq, attention_factor
  310. def _compute_longrope_parameters(
  311. config: PretrainedConfig, device: "torch.device", seq_len: Optional[int] = None
  312. ) -> tuple["torch.Tensor", float]:
  313. """
  314. Computes the inverse frequencies with LongRoPE scaling. Please refer to the
  315. [original implementation](https://github.com/microsoft/LongRoPE)
  316. Args:
  317. config ([`~transformers.PretrainedConfig`]):
  318. The model configuration. This function assumes that the config will provide at least the following
  319. properties:
  320. * rope_theta (`float`): The base wavelength from which the inverse frequencies will be derived.
  321. * hidden_size (`int`): The numerator when deriving a head_dim, if not provided directly.
  322. * num_attention_heads (`int`): The denominator when deriving a head_dim, if not provided directly.
  323. * max_position_embeddings (`int`): The maximum length of the positional embeddings.
  324. * original_max_position_embeddings (`int`, *optional*): The original max position embeddings used during
  325. pretraining. If not provided, defaults to `max_position_embeddings`.
  326. * rope_scaling (`dict[str, float]`): The standard RoPE scaling parameters, from which the following keys
  327. will be accessed:
  328. * `attention_factor` (`float`, *optional*): The scaling factor to be applied on the attention
  329. computation. If unspecified, it defaults to value recommended by the implementation, inferred from
  330. the value of `factor`.
  331. * `factor` (`float`, *optional*): The scaling factor to apply to the RoPE embeddings. If both
  332. `max_position_embeddings` and `original_max_position_embeddings` are provided, this value will be
  333. overridden s the ratio between those values.
  334. * `long_factor` (`float`, *optional*): The scale factor applied when computing the inverse
  335. frequencies if `seq_len` is provided and greater than `original_max_position_embeddings`.
  336. * `short_factor` (`float`, *optional*): The scale factor applied when computing the inverse
  337. frequencies if `seq_len` is None or less-than-or-equal-to `original_max_position_embeddings`.
  338. Additionally, this function will make use of the following properties if they are found in the config:
  339. * head_dim (`int`, *optional*): The size of the key-value heads in the model. If None, this value will be
  340. derived as hidden_size // num_attention_heads.
  341. * partial_rotary_factor (`float`, *optional*, defaults to 1.0): If less than 1.0, inverse frequencies
  342. will be returned for the first fraction of the head_dim.
  343. device (`torch.device`):
  344. The device to use for initialization of the inverse frequencies.
  345. seq_len (`int`, *optional*):
  346. The current sequence length.
  347. Returns:
  348. Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
  349. post-processing scaling factor applied to the computed cos/sin.
  350. """
  351. # TODO (joao): use the new `original_max_position_embeddings` from rope_scaling
  352. base = config.rope_theta
  353. partial_rotary_factor = getattr(config, "partial_rotary_factor", 1.0)
  354. head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
  355. dim = int(head_dim * partial_rotary_factor)
  356. long_factor = config.rope_scaling["long_factor"]
  357. short_factor = config.rope_scaling["short_factor"]
  358. factor = config.rope_scaling.get("factor")
  359. attention_factor = config.rope_scaling.get("attention_factor")
  360. # NOTE: Phi3 (and potentially other models) modify `max_position_embeddings` and have a
  361. # `original_max_position_embeddings` field containing the pretrained value. They use the ratio between these two
  362. # values to compute the default attention scaling factor, instead of using `factor`.
  363. if original_max_position_embeddings := getattr(config, "original_max_position_embeddings", None):
  364. factor = config.max_position_embeddings / original_max_position_embeddings
  365. else:
  366. original_max_position_embeddings = config.max_position_embeddings
  367. # Sets the attention factor as suggested in the paper
  368. if attention_factor is None:
  369. if factor <= 1.0:
  370. attention_factor = 1.0
  371. else:
  372. attention_factor = math.sqrt(1 + math.log(factor) / math.log(original_max_position_embeddings))
  373. # Compute the inverse frequencies -- scaled based on the target sequence length
  374. if seq_len and seq_len > original_max_position_embeddings:
  375. ext_factors = torch.tensor(long_factor, dtype=torch.float32, device=device)
  376. else:
  377. ext_factors = torch.tensor(short_factor, dtype=torch.float32, device=device)
  378. inv_freq_shape = torch.arange(0, dim, 2, dtype=torch.int64, device=device).float() / dim
  379. inv_freq = 1.0 / (ext_factors * base**inv_freq_shape)
  380. return inv_freq, attention_factor
  381. def _compute_llama3_parameters(
  382. config: PretrainedConfig, device: "torch.device", seq_len: Optional[int] = None
  383. ) -> tuple["torch.Tensor", float]:
  384. """
  385. Computes the inverse frequencies for llama 3.1.
  386. Args:
  387. config ([`~transformers.PretrainedConfig`]):
  388. The model configuration. This function assumes that the config will provide at least the following
  389. properties:
  390. * rope_theta (`float`): The base wavelength from which the inverse frequencies will be derived.
  391. * hidden_size (`int`): The numerator when deriving a head_dim, if not provided directly.
  392. * num_attention_heads (`int`): The denominator when deriving a head_dim, if not provided directly.
  393. * rope_scaling (`dict[str, float | int]`): The standard RoPE scaling parameters, from which the following
  394. keys will be accessed:
  395. * `factor` (`float`, *optional*): The scaling factor applied to the inverse frequencies when 1) the
  396. wavelength is greater than `low_freq_wavelen` prior to smoothing, and 2) to all inverse frequencies
  397. during smoothing.
  398. * `high_freq_factor` (`float`): The scale factor used to compute `high_freq_wavelen` and
  399. the value for the denominator of the smoothing factor prior to the `low_freq_factor` shift.
  400. * `low_freq_factor` (`float`): The scale factor used to compute `low_freq_wavelen` and
  401. the shift applied to the numerator and denominator of the smoothing factor.
  402. frequencies if `seq_len` is None or less-than-or-equal-to `original_max_position_embeddings`.
  403. * `original_max_position_embeddings` (`int`): The original max position embeddings used
  404. during pretraining. If not provided, the function falls back to `max_position_embeddings`.
  405. Additionally, this function will make use of the following properties if they are found in the config:
  406. * head_dim (`int`, *optional*): The size of the key-value heads in the model. If None, this value will be
  407. derived as hidden_size // num_attention_heads.
  408. * partial_rotary_factor (`float`, *optional*): If less than 1.0, inverse frequencies will be returned for
  409. the first fraction of the head_dim. Defaults to 1.0.
  410. device (`torch.device`):
  411. The device to use for initialization of the inverse frequencies.
  412. seq_len (`int`, *optional*):
  413. The current sequence length. Unused for this type of RoPE.
  414. Returns:
  415. Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
  416. post-processing scaling factor applied to the computed cos/sin.
  417. """
  418. # Gets the default RoPE parameters
  419. inv_freq, attention_factor = _compute_default_rope_parameters(config, device, seq_len)
  420. factor = config.rope_scaling["factor"] # `8` in the original implementation
  421. low_freq_factor = config.rope_scaling["low_freq_factor"] # `1` in the original implementation
  422. high_freq_factor = config.rope_scaling["high_freq_factor"] # `4` in the original implementation
  423. old_context_len = config.rope_scaling["original_max_position_embeddings"] # `8192` in the original implementation
  424. low_freq_wavelen = old_context_len / low_freq_factor
  425. high_freq_wavelen = old_context_len / high_freq_factor
  426. wavelen = 2 * math.pi / inv_freq
  427. # wavelen < high_freq_wavelen: do nothing
  428. # wavelen > low_freq_wavelen: divide by factor
  429. inv_freq_llama = torch.where(wavelen > low_freq_wavelen, inv_freq / factor, inv_freq)
  430. # otherwise: interpolate between the two, using a smooth factor
  431. smooth_factor = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
  432. smoothed_inv_freq = (1 - smooth_factor) * inv_freq_llama / factor + smooth_factor * inv_freq_llama
  433. is_medium_freq = ~(wavelen < high_freq_wavelen) * ~(wavelen > low_freq_wavelen)
  434. inv_freq_llama = torch.where(is_medium_freq, smoothed_inv_freq, inv_freq_llama)
  435. return inv_freq_llama, attention_factor
  436. # This maps the "rope_type" string field in rope config to the corresponding function to compute the RoPE parameters
  437. # from the model config. You can append new {'rope_type': callable} pairs to this dictionary to enable custom RoPE
  438. # parameterizations, as long as the callable has the same signature.
  439. ROPE_INIT_FUNCTIONS = {
  440. "default": _compute_default_rope_parameters,
  441. "linear": _compute_linear_scaling_rope_parameters,
  442. "dynamic": _compute_dynamic_ntk_parameters,
  443. "yarn": _compute_yarn_parameters,
  444. "longrope": _compute_longrope_parameters,
  445. "llama3": _compute_llama3_parameters,
  446. }
  447. def _check_received_keys(
  448. rope_type: str,
  449. received_keys: set,
  450. required_keys: set,
  451. optional_keys: Optional[set] = None,
  452. ignore_keys: Optional[set] = None,
  453. ):
  454. """Compare the received keys in `config.rope_scaling` against the expected and optional keys"""
  455. # BC: "rope_type" was originally "type" -- let's check for "rope_type" when "type" is present
  456. if "type" in received_keys:
  457. received_keys -= {"type"}
  458. required_keys.add("rope_type")
  459. # Some models need to store model-specific keys, and we don't want to throw warning at them
  460. if ignore_keys is not None:
  461. received_keys -= ignore_keys
  462. missing_keys = required_keys - received_keys
  463. if missing_keys:
  464. raise KeyError(f"Missing required keys in `rope_scaling` for 'rope_type'='{rope_type}': {missing_keys}")
  465. if optional_keys is not None:
  466. unused_keys = received_keys - required_keys - optional_keys
  467. else:
  468. unused_keys = received_keys - required_keys
  469. if unused_keys:
  470. logger.warning(f"Unrecognized keys in `rope_scaling` for 'rope_type'='{rope_type}': {unused_keys}")
  471. def _validate_default_rope_parameters(config: PretrainedConfig, ignore_keys: Optional[set] = None):
  472. rope_scaling = config.rope_scaling
  473. rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type"
  474. required_keys = {"rope_type"}
  475. received_keys = set(rope_scaling.keys())
  476. _check_received_keys(rope_type, received_keys, required_keys, ignore_keys=ignore_keys)
  477. def _validate_linear_scaling_rope_parameters(config: PretrainedConfig, ignore_keys: Optional[set] = None):
  478. rope_scaling = config.rope_scaling
  479. rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type"
  480. required_keys = {"rope_type", "factor"}
  481. received_keys = set(rope_scaling.keys())
  482. _check_received_keys(rope_type, received_keys, required_keys, ignore_keys=ignore_keys)
  483. factor = rope_scaling["factor"]
  484. if factor is None or not isinstance(factor, float) or factor < 1.0:
  485. logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")
  486. def _validate_dynamic_scaling_rope_parameters(config: PretrainedConfig, ignore_keys: Optional[set] = None):
  487. rope_scaling = config.rope_scaling
  488. rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type"
  489. required_keys = {"rope_type", "factor"}
  490. # TODO (joao): update logic for the inclusion of `original_max_position_embeddings`
  491. optional_keys = {"original_max_position_embeddings"}
  492. received_keys = set(rope_scaling.keys())
  493. _check_received_keys(rope_type, received_keys, required_keys, optional_keys, ignore_keys=ignore_keys)
  494. factor = rope_scaling["factor"]
  495. if factor is None or not isinstance(factor, float) or factor < 1.0:
  496. logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")
  497. def _validate_yarn_parameters(config: PretrainedConfig, ignore_keys: Optional[set] = None):
  498. rope_scaling = config.rope_scaling
  499. rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type"
  500. required_keys = {"rope_type", "factor"}
  501. optional_keys = {
  502. "attention_factor",
  503. "beta_fast",
  504. "beta_slow",
  505. "original_max_position_embeddings",
  506. "mscale",
  507. "mscale_all_dim",
  508. "truncate",
  509. }
  510. received_keys = set(rope_scaling.keys())
  511. _check_received_keys(rope_type, received_keys, required_keys, optional_keys, ignore_keys=ignore_keys)
  512. factor = rope_scaling["factor"]
  513. if factor is None or not isinstance(factor, float) or factor < 1.0:
  514. logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")
  515. attention_factor = rope_scaling.get("attention_factor")
  516. if attention_factor is not None and (not isinstance(attention_factor, float) or attention_factor < 0):
  517. logger.warning(
  518. f"`rope_scaling`'s attention_factor field must be a float greater than 0, got {attention_factor}"
  519. )
  520. beta_fast = rope_scaling.get("beta_fast")
  521. if beta_fast is not None and not isinstance(beta_fast, float):
  522. logger.warning(f"`rope_scaling`'s beta_fast field must be a float, got {beta_fast}")
  523. beta_slow = rope_scaling.get("beta_slow")
  524. if beta_slow is not None and not isinstance(beta_slow, float):
  525. logger.warning(f"`rope_scaling`'s beta_slow field must be a float, got {beta_slow}")
  526. if (beta_fast or 32) < (beta_slow or 1):
  527. logger.warning(
  528. f"`rope_scaling`'s beta_fast field must be greater than beta_slow, got beta_fast={beta_fast} "
  529. f"(defaults to 32 if None) and beta_slow={beta_slow} (defaults to 1 if None)"
  530. )
  531. # Models should set `config.rope_scaling["original_max_position_embeddings"]` to their original (pre-yarn) context
  532. # length, with `config.max_position_embeddings` corresponding to their post-yarn context length.
  533. # However, for BC purposes, we allow the former to be unset.
  534. original_max_position_embeddings = config.rope_scaling.get("original_max_position_embeddings")
  535. if original_max_position_embeddings is not None:
  536. # Double-check: `factor` should be the ratio between the pre-yarn and post-yarn context lengths.
  537. implicit_factor = config.max_position_embeddings / original_max_position_embeddings
  538. if implicit_factor != factor:
  539. logger.warning_once(
  540. f"The explicitly set RoPE scaling factor (config.rope_scaling['factor'] = {factor}) does not match "
  541. "the ratio implicitly set by other parameters (implicit factor = "
  542. "post-yarn context length / pre-yarn context length = "
  543. "config.max_position_embeddings / config.rope_scaling['original_max_position_embeddings'] = "
  544. f"{implicit_factor}). Using the explicit factor ({factor}) in YaRN. This may cause unexpected "
  545. "behaviour in model usage, please correct the 'max_position_embeddings' fields in the model config."
  546. )
  547. # No `config.rope_scaling["original_max_position_embeddings"]`. Is `config.max_position_embeddings` the
  548. # pre-yarn or the post-yarn context length?
  549. # BC: we assume it is the pre-yarn context length.
  550. else:
  551. logger.warning_once(
  552. "config.rope_scaling['original_max_position_embeddings'], the pre-yarn context length, is unset. We will "
  553. "**assume** config.max_position_embeddings holds the pre-yarn context length. Some use cases may expect "
  554. "config.max_position_embeddings to hold the post-yarn context length (pre-yarn context length * "
  555. "factor) -- we recommend updating both fields for optimal downstream model usage."
  556. )
  557. def _validate_longrope_parameters(config: PretrainedConfig, ignore_keys: Optional[set] = None):
  558. rope_scaling = config.rope_scaling
  559. rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type"
  560. required_keys = {"rope_type", "short_factor", "long_factor"}
  561. # TODO (joao): update logic for the inclusion of `original_max_position_embeddings`
  562. optional_keys = {"attention_factor", "factor", "original_max_position_embeddings"}
  563. received_keys = set(rope_scaling.keys())
  564. _check_received_keys(rope_type, received_keys, required_keys, optional_keys, ignore_keys=ignore_keys)
  565. partial_rotary_factor = getattr(config, "partial_rotary_factor", 1.0)
  566. head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
  567. dim = int(head_dim * partial_rotary_factor)
  568. short_factor = rope_scaling.get("short_factor")
  569. if not isinstance(short_factor, list) and all(isinstance(x, (int, float)) for x in short_factor):
  570. logger.warning(f"`rope_scaling`'s short_factor field must be a list of numbers, got {short_factor}")
  571. if len(short_factor) != dim // 2:
  572. logger.warning(f"`rope_scaling`'s short_factor field must have length {dim // 2}, got {len(short_factor)}")
  573. long_factor = rope_scaling.get("long_factor")
  574. if not isinstance(long_factor, list) and all(isinstance(x, (int, float)) for x in long_factor):
  575. logger.warning(f"`rope_scaling`'s long_factor field must be a list of numbers, got {long_factor}")
  576. if len(long_factor) != dim // 2:
  577. logger.warning(f"`rope_scaling`'s long_factor field must have length {dim // 2}, got {len(long_factor)}")
  578. # Handle Phi3 divergence: prefer the use of `attention_factor` and/or `factor` over
  579. # `original_max_position_embeddings` to compute internal variables. The latter lives outside `rope_scaling` and is
  580. # unique to longrope (= undesirable)
  581. if hasattr(config, "original_max_position_embeddings"):
  582. logger.warning_once(
  583. "This model has set a `original_max_position_embeddings` field, to be used together with "
  584. "`max_position_embeddings` to determine a scaling factor. Please set the `factor` field of `rope_scaling`"
  585. "with this ratio instead -- we recommend the use of this field over `original_max_position_embeddings`, "
  586. "as it is compatible with most model architectures."
  587. )
  588. else:
  589. factor = rope_scaling.get("factor")
  590. if factor is None:
  591. logger.warning("Missing required keys in `rope_scaling`: 'factor'")
  592. elif not isinstance(factor, float) or factor < 1.0:
  593. logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")
  594. attention_factor = rope_scaling.get("attention_factor")
  595. if attention_factor is not None:
  596. if not isinstance(attention_factor, float) or attention_factor < 0.0:
  597. logger.warning(
  598. f"`rope_scaling`'s attention_factor field must be a float greater than 0, got {attention_factor}"
  599. )
  600. def _validate_llama3_parameters(config: PretrainedConfig, ignore_keys: Optional[set] = None):
  601. rope_scaling = config.rope_scaling
  602. rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type"
  603. required_keys = {"rope_type", "factor", "original_max_position_embeddings", "low_freq_factor", "high_freq_factor"}
  604. received_keys = set(rope_scaling.keys())
  605. _check_received_keys(rope_type, received_keys, required_keys, ignore_keys=ignore_keys)
  606. factor = rope_scaling["factor"]
  607. if factor is None or not isinstance(factor, float) or factor < 1.0:
  608. logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")
  609. low_freq_factor = rope_scaling["low_freq_factor"]
  610. high_freq_factor = rope_scaling["high_freq_factor"]
  611. if low_freq_factor is None or not isinstance(low_freq_factor, float):
  612. logger.warning(f"`rope_scaling`'s low_freq_factor field must be a float, got {low_freq_factor}")
  613. if high_freq_factor is None or not isinstance(high_freq_factor, float):
  614. logger.warning(f"`rope_scaling`'s high_freq_factor field must be a float, got {high_freq_factor}")
  615. if high_freq_factor <= low_freq_factor:
  616. logger.warning(
  617. "`rope_scaling`'s high_freq_factor field must be greater than low_freq_factor, got high_freq_factor="
  618. f"{high_freq_factor} and low_freq_factor={low_freq_factor}"
  619. )
  620. original_max_position_embeddings = rope_scaling["original_max_position_embeddings"]
  621. if original_max_position_embeddings is None or not isinstance(original_max_position_embeddings, int):
  622. logger.warning(
  623. "`rope_scaling`'s original_max_position_embeddings field must be an integer, got "
  624. f"{original_max_position_embeddings}"
  625. )
  626. if original_max_position_embeddings >= config.max_position_embeddings:
  627. logger.warning(
  628. "`rope_scaling`'s original_max_position_embeddings field must be less than max_position_embeddings, got "
  629. f"{original_max_position_embeddings} and max_position_embeddings={config.max_position_embeddings}"
  630. )
  631. # Like `ROPE_INIT_FUNCTIONS`, this validation function mapping can be dynamically updated for custom RoPE types.
  632. ROPE_VALIDATION_FUNCTIONS = {
  633. "default": _validate_default_rope_parameters,
  634. "linear": _validate_linear_scaling_rope_parameters,
  635. "dynamic": _validate_dynamic_scaling_rope_parameters,
  636. "yarn": _validate_yarn_parameters,
  637. "longrope": _validate_longrope_parameters,
  638. "llama3": _validate_llama3_parameters,
  639. }
  640. def rope_config_validation(config: PretrainedConfig, ignore_keys: Optional[set] = None):
  641. """
  642. Validate the RoPE config arguments, given a `PretrainedConfig` object
  643. """
  644. rope_scaling = getattr(config, "rope_scaling", None) # not a default parameter in `PretrainedConfig`
  645. if rope_scaling is None:
  646. return
  647. # BC: "rope_type" was originally "type"
  648. rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", "default"))
  649. validation_fn = ROPE_VALIDATION_FUNCTIONS.get(rope_type)
  650. if validation_fn is not None:
  651. validation_fn(config, ignore_keys=ignore_keys)
  652. else:
  653. logger.warning(
  654. f"Missing validation function mapping in `ROPE_VALIDATION_FUNCTIONS` for 'rope_type'='{rope_type}'"
  655. )