modeling_timesfm.py 34 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845
  1. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  2. # This file was automatically generated from src/transformers/models/timesfm/modular_timesfm.py.
  3. # Do NOT edit this file manually as any edits will be overwritten by the generation of
  4. # the file from the modular. If any change should be done, please apply the change to the
  5. # modular_timesfm.py file directly. One of our CI enforces this.
  6. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  7. # coding=utf-8
  8. # Copyright 2025 Google LLC and HuggingFace Inc. team.
  9. #
  10. # Licensed under the Apache License, Version 2.0 (the "License");
  11. # you may not use this file except in compliance with the License.
  12. # You may obtain a copy of the License at
  13. #
  14. # http://www.apache.org/licenses/LICENSE-2.0
  15. #
  16. # Unless required by applicable law or agreed to in writing, software
  17. # distributed under the License is distributed on an "AS IS" BASIS,
  18. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  19. # See the License for the specific language governing permissions and
  20. # limitations under the License.
  21. import math
  22. from collections.abc import Sequence
  23. from dataclasses import dataclass
  24. from typing import Callable, Optional, Union
  25. import torch
  26. import torch.nn as nn
  27. import torch.nn.functional as F
  28. from ...integrations import use_kernel_forward_from_hub
  29. from ...modeling_flash_attention_utils import FlashAttentionKwargs
  30. from ...modeling_outputs import BaseModelOutput
  31. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  32. from ...processing_utils import Unpack
  33. from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging
  34. from .configuration_timesfm import TimesFmConfig
  35. logger = logging.get_logger(__name__)
  @dataclass
  @auto_docstring
  class TimesFmOutput(BaseModelOutput):
      r"""
      loc (`torch.Tensor` of shape `(batch_size,)`):
          The mean of the time series inputs.
      scale (`torch.Tensor` of shape `(batch_size,)`):
          The scale of the time series inputs.
      """

      # Both fields default to None so the output can be built without normalization statistics.
      loc: Optional[torch.Tensor] = None
      scale: Optional[torch.Tensor] = None
  @dataclass
  @auto_docstring
  class TimesFmOutputForPrediction(BaseModelOutput):
      r"""
      mean_predictions (`torch.Tensor` of shape `(batch_size, horizon_length)`):
          The mean (point) predictions of the time series. This is index 0 of the last axis of
          `full_predictions`. When `return_forecast_on_context=True` the time axis additionally contains the
          forecast on the context.
      full_predictions (`torch.Tensor` of shape `(batch_size, horizon_length, 1 + num_quantiles)`):
          The full predictions of the time series including the mean and the quantiles. Along the last axis,
          index 0 is the mean and the remaining entries are the configured quantiles.
      loss (`torch.Tensor` of shape `(1,)`, *optional*, returned when `future_values` is provided):
          The loss of the TimesFM model.
      """

      mean_predictions: Optional[torch.Tensor] = None
      full_predictions: Optional[torch.Tensor] = None
      loss: Optional[Union[torch.Tensor, float]] = None
  class TimesFmMLP(nn.Module):
      """Pax MLP in pytorch: LayerNorm -> Linear -> ReLU -> Linear, added back onto the input."""

      def __init__(self, config: TimesFmConfig):
          super().__init__()
          # Sub-modules are created in a fixed order so parameter registration stays stable.
          self.gate_proj = nn.Linear(config.hidden_size, config.intermediate_size)
          self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size)
          self.layer_norm = nn.LayerNorm(normalized_shape=config.hidden_size, eps=1e-6)

      def forward(self, x, paddings=None):
          """Apply the feed-forward transform and the residual connection.

          Args:
              x: input tensor; the last axis is normalized and projected.
              paddings: optional tensor indexed as `paddings[:, :, None]`; positions where it is 1
                  have their feed-forward contribution zeroed before the residual add.

          Returns:
              The transformed tensor plus `x`.
          """
          projected = self.down_proj(F.relu(self.gate_proj(self.layer_norm(x))))
          if paddings is None:
              return projected + x
          keep = 1.0 - paddings[:, :, None]
          return projected * keep + x
  class TimesFmResidualBlock(nn.Module):
      """TimesFM residual block: a two-layer SiLU network summed with a linear projection of the input."""

      def __init__(self, input_dims, hidden_dims, output_dims):
          super().__init__()
          self.input_dims, self.hidden_dims, self.output_dims = input_dims, hidden_dims, output_dims

          # Main path: input_dims -> hidden_dims -> output_dims with a SiLU in between.
          self.input_layer = nn.Linear(input_dims, hidden_dims)
          self.activation = nn.SiLU()
          self.output_layer = nn.Linear(hidden_dims, output_dims)
          # Skip path: a linear map so the residual matches output_dims.
          self.residual_layer = nn.Linear(input_dims, output_dims)

      def forward(self, x):
          """Return `output_layer(SiLU(input_layer(x))) + residual_layer(x)`."""
          transformed = self.output_layer(self.activation(self.input_layer(x)))
          skip = self.residual_layer(x)
          return transformed + skip
  @use_kernel_forward_from_hub("RMSNorm")
  class TimesFmRMSNorm(nn.Module):
      def __init__(self, hidden_size, eps=1e-6):
          """
          TimesFmRMSNorm is equivalent to T5LayerNorm
          """
          super().__init__()
          self.weight = nn.Parameter(torch.ones(hidden_size))
          self.variance_epsilon = eps

      def forward(self, hidden_states):
          """Scale `hidden_states` by the reciprocal root-mean-square of its last axis, then by `weight`."""
          original_dtype = hidden_states.dtype
          # The statistics are computed in float32 and cast back before the learned scale is applied.
          upcast = hidden_states.to(torch.float32)
          mean_square = upcast.pow(2).mean(-1, keepdim=True)
          normalized = upcast * torch.rsqrt(mean_square + self.variance_epsilon)
          return self.weight * normalized.to(original_dtype)

      def extra_repr(self):
          weight_shape = tuple(self.weight.shape)
          return f"{weight_shape}, eps={self.variance_epsilon}"
  class TimesFmPositionalEmbedding(nn.Module):
      """Generates position embedding for a given 1-d sequence."""

      def __init__(self, config: TimesFmConfig):
          super().__init__()
          min_timescale = config.min_timescale
          max_timescale = config.max_timescale
          self.embedding_dims = config.hidden_size

          # Half of the channels carry sines and half cosines, so only D // 2 timescales are needed.
          num_timescales = self.embedding_dims // 2
          # max(..., 1) avoids a division by zero when there is a single timescale.
          log_timescale_increment = math.log(float(max_timescale) / float(min_timescale)) / max(num_timescales - 1, 1)
          self.register_buffer(
              "inv_timescales",
              min_timescale * torch.exp(torch.arange(num_timescales, dtype=torch.float32) * -log_timescale_increment),
          )

      def forward(self, seq_length=None, position=None):
          """Generates a Tensor of sinusoids with different frequencies.

          Args:
              seq_length: an optional Python int defining the output sequence length.
                  Not needed if the `position` argument is specified.
              position: [B, seq_length], optional position for each token in the
                  sequence, only required when the sequence is packed.

          Returns:
              [B, seqlen, D] if `position` is specified, else [1, seqlen, D]

          Raises:
              ValueError: if neither `position` nor `seq_length` is given, or if `position`
                  is not 2-dimensional.
          """
          if position is None and seq_length is None:
              raise ValueError("Either position or seq_length must be provided")

          if position is None:
              # [1, seqlen]
              position = torch.arange(seq_length, dtype=torch.float32, device=self.inv_timescales.device).unsqueeze(0)
          elif position.ndim != 2:
              raise ValueError(f"position must be 2-dimensional, got shape {position.shape}")

          scaled_time = position.view(*position.shape, 1) * self.inv_timescales.view(1, 1, -1)
          signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=2)

          # Padding to ensure correct embedding dimension. F.pad consumes the pad tuple starting from the
          # LAST axis, so (0, k) appends k zero channels to the embedding axis when D is odd. The previous
          # (0, 0, 0, k) tuple padded the sequence axis instead and left the embedding axis one short.
          signal = F.pad(signal, (0, self.embedding_dims % 2))
          return signal
  def simple_eager_attention_forward(
      module: nn.Module,
      query_states: torch.Tensor,
      key_states: torch.Tensor,
      value_states: torch.Tensor,
      attention_mask: Optional[torch.Tensor],
      scaling: float,
      dropout: float = 0.0,
      **kwargs: Unpack[TransformersKwargs],
  ):
      """Plain matmul/softmax attention.

      Args:
          module: the calling module; only its `training` flag is read (to gate dropout).
          query_states, key_states, value_states: tensors whose axes 2 and 3 are used as
              (sequence, head_dim) in the matmuls.
          attention_mask: optional additive mask; it is sliced to the key length on its last axis.
          scaling: multiplier applied to the raw query-key scores.
          dropout: dropout probability applied to the attention weights.
          **kwargs: accepted for interface compatibility and ignored.

      Returns:
          `(attn_output, attn_weights)` where `attn_output` has axes 1 and 2 swapped relative to
          the matmul result and is contiguous.
      """
      scores = torch.matmul(query_states, key_states.transpose(2, 3)) * scaling
      if attention_mask is not None:
          num_keys = key_states.shape[-2]
          scores = scores + attention_mask[:, :, :, :num_keys]

      # Softmax runs in float32, then the weights are cast back to the query dtype.
      probs = nn.functional.softmax(scores, dim=-1, dtype=torch.float32).to(query_states.dtype)
      probs = nn.functional.dropout(probs, p=dropout, training=module.training)
      context = torch.matmul(probs, value_states).transpose(1, 2).contiguous()
      return context, probs
  class TimesFmAttention(nn.Module):
      """Implements the attention used in TimesFM. One key difference is that there is _per_dim_scaling of the query."""

      def __init__(self, config: TimesFmConfig, layer_idx: int):
          super().__init__()
          self.config = config
          self.is_causal = True
          self.attention_dropout = config.attention_dropout
          self.layer_idx = layer_idx

          self.num_heads = config.num_attention_heads
          self.hidden_size = config.hidden_size
          self.head_dim = config.head_dim

          projection_width = self.num_heads * self.head_dim
          self.q_size = projection_width
          self.kv_size = projection_width
          # Learned per-dimension query scale; allocated uninitialized here.
          self.scaling = nn.Parameter(torch.empty((self.head_dim,)))

          self.q_proj = nn.Linear(self.hidden_size, projection_width)
          self.k_proj = nn.Linear(self.hidden_size, projection_width)
          self.v_proj = nn.Linear(self.hidden_size, projection_width)
          self.o_proj = nn.Linear(projection_width, self.hidden_size)

      def _scale_query(self, query: torch.Tensor) -> torch.Tensor:
          """Multiply `query` along its last axis by `softplus(scaling) * 1.442695041 / sqrt(head_dim)`."""
          per_dim_scale = F.softplus(self.scaling) * (1.442695041 / math.sqrt(self.head_dim))
          return query * per_dim_scale[None, None, None, :]

      def forward(
          self,
          hidden_states: torch.Tensor,
          attention_mask: Optional[torch.Tensor] = None,
          **kwargs: Unpack[FlashAttentionKwargs],
      ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
          """Project to per-head q/k/v, scale the query, attend, and project back.

          Args:
              hidden_states: input whose last axis is projected; all leading axes are preserved.
              attention_mask: optional mask forwarded unchanged to the attention implementation.
              **kwargs: forwarded unchanged to the attention implementation.

          Returns:
              `(attn_output, attn_weights)` as produced by the attention implementation, with
              `attn_output` passed through `o_proj`.
          """
          leading_shape = hidden_states.shape[:-1]
          per_head_shape = (*leading_shape, -1, self.head_dim)

          def split_heads(projection: nn.Linear) -> torch.Tensor:
              # (..., seq, heads * head_dim) -> (..., heads, seq, head_dim)
              return projection(hidden_states).view(per_head_shape).transpose(1, 2)

          query_states = self._scale_query(split_heads(self.q_proj))
          key_states = split_heads(self.k_proj)
          value_states = split_heads(self.v_proj)

          if self.config._attn_implementation == "eager":
              attention_interface: Callable = simple_eager_attention_forward
          else:
              attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

          # The query is already scaled by `_scale_query`, so the interface is told to apply no scaling.
          attn_output, attn_weights = attention_interface(
              self,
              query_states,
              key_states,
              value_states,
              attention_mask,
              dropout=self.attention_dropout if self.training else 0.0,
              scaling=1.0,
              **kwargs,
          )
          merged = attn_output.reshape(*leading_shape, -1).contiguous()
          return self.o_proj(merged), attn_weights
  class TimesFmDecoderLayer(nn.Module):
      """Transformer layer: pre-normalized self-attention with a residual, followed by the MLP block."""

      def __init__(self, config: TimesFmConfig, layer_idx: int):
          super().__init__()
          self.self_attn = TimesFmAttention(config, layer_idx=layer_idx)
          self.mlp = TimesFmMLP(config)
          self.input_layernorm = TimesFmRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

      def forward(
          self,
          hidden_states: torch.Tensor,
          attention_mask: torch.Tensor,
          paddings: torch.Tensor,
          output_attentions: bool = False,
      ) -> tuple[Optional[torch.Tensor], torch.Tensor]:
          """Run one decoder layer.

          Returns:
              `(scores, hidden_states)` - note the attention scores come first.
          """
          # Self-attention on the normalized input, added back onto the un-normalized input.
          attn_output, scores = self.self_attn(
              hidden_states=self.input_layernorm(hidden_states),
              attention_mask=attention_mask,
              output_attentions=output_attentions,
          )
          hidden_states = hidden_states + attn_output

          # The MLP block applies its own normalization and residual internally.
          return scores, self.mlp(hidden_states, paddings=paddings)
  @auto_docstring
  class TimesFmPreTrainedModel(PreTrainedModel):
      config: TimesFmConfig
      base_model_prefix = "timesfm"
      _no_split_modules = ["TimesFmDecoderLayer"]
      main_input_name = "past_values"
      _supports_sdpa = True

      def _init_weights(self, module):
          # Run the parent initialization first, then handle the one TimesFM-specific parameter.
          super()._init_weights(module)
          if not isinstance(module, TimesFmAttention):
              return
          # The per-dimension query scale is allocated with torch.empty, so it must be filled here.
          nn.init.ones_(module.scaling)
  @auto_docstring
  class TimesFmModel(TimesFmPreTrainedModel):
      def __init__(self, config: TimesFmConfig):
          super().__init__(config)

          self.config = config
          # Each patch is concatenated with its padding indicator, hence 2 * patch_length inputs.
          self.input_ff_layer = TimesFmResidualBlock(
              input_dims=2 * config.patch_length,
              output_dims=config.hidden_size,
              hidden_dims=config.intermediate_size,
          )
          self.freq_emb = nn.Embedding(num_embeddings=config.freq_size, embedding_dim=config.hidden_size)
          self.layers = nn.ModuleList(
              [TimesFmDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
          )
          if self.config.use_positional_embedding:
              self.position_emb = TimesFmPositionalEmbedding(config=config)

          # Initialize weights and apply final processing
          self.post_init()

      def _forward_transform(
          self, inputs: torch.Tensor, patched_pads: torch.Tensor
      ) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
          """Normalize the patched inputs with per-series statistics.

          Input is of shape [B, N, P].

          Returns:
              The normalized inputs and the `(mu, sigma)` statistics used, each of shape [B].
          """
          mu, sigma = self._timesfm_masked_mean_std(inputs, patched_pads)
          # Guard against division by a (near-)zero standard deviation.
          sigma = torch.where(
              sigma < self.config.tolerance,
              torch.tensor(1.0, dtype=sigma.dtype, device=sigma.device),
              sigma,
          )

          # Normalize each patch
          outputs = (inputs - mu[:, None, None]) / sigma[:, None, None]
          # Positions whose raw value equals `pad_val` keep `pad_val` instead of a normalized value.
          outputs = torch.where(
              torch.abs(inputs - self.config.pad_val) < self.config.tolerance,
              torch.tensor(self.config.pad_val, dtype=outputs.dtype, device=outputs.device),
              outputs,
          )
          return outputs, (mu, sigma)

      @can_return_tuple
      @auto_docstring
      def forward(
          self,
          past_values: torch.Tensor,
          past_values_padding: torch.LongTensor,
          freq: torch.Tensor,
          output_attentions: bool = False,
          output_hidden_states: bool = False,
      ) -> TimesFmOutput:
          r"""
          past_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
              Past values of the time series that serves as input to the model.
          past_values_padding (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
              The padding indicator of the time series.
          freq (`torch.LongTensor` of shape `(batch_size,)`):
              Frequency indices for the time series data.
          """
          # Reshape into patches (using view for efficiency)
          bsize = past_values.shape[0]
          patched_inputs = past_values.view(bsize, -1, self.config.patch_length)
          patched_pads = past_values_padding.view(bsize, -1, self.config.patch_length)

          # Zero out the values at padded positions ...
          patched_inputs = torch.where(
              torch.abs(patched_pads - 1.0) < self.config.tolerance,
              torch.tensor(0.0, dtype=patched_inputs.dtype, device=patched_inputs.device),
              patched_inputs,
          )
          # ... and mark positions holding the `pad_val` sentinel as padded.
          patched_pads = torch.where(
              torch.abs(patched_inputs - self.config.pad_val) < self.config.tolerance,
              torch.tensor(1.0, dtype=patched_pads.dtype, device=patched_pads.device),
              patched_pads,
          )
          patched_inputs, stats = self._forward_transform(patched_inputs, patched_pads)

          # B x N x D
          patched_inputs = patched_inputs * (1.0 - patched_pads)
          concat_inputs = torch.cat([patched_inputs, patched_pads], dim=-1)
          model_input = self.input_ff_layer(concat_inputs)

          # A patch should not be padded even if there is at least one zero.
          patched_padding = torch.min(patched_pads, dim=-1)[0]  # Get the values from the min result
          if self.config.use_positional_embedding:
              pos_emb = self.position_emb(model_input.shape[1])
              pos_emb = torch.concat([pos_emb] * model_input.shape[0], dim=0)
              pos_emb = self._timesfm_shift_padded_seq(patched_padding, pos_emb)
              model_input += pos_emb

          f_emb = self.freq_emb(freq)  # B x 1 x D
          model_input += f_emb

          # Convert paddings to attention mask and combine with causal mask
          hidden_states = model_input
          attention_mask = self._prepare_4d_attention_mask(
              attention_mask=patched_padding,
              sequence_length=hidden_states.shape[1],
              dtype=hidden_states.dtype,
              device=hidden_states.device,
              is_causal=True,
          )

          all_attentions = []
          all_hidden_states = []

          for layer in self.layers[: self.config.num_hidden_layers]:
              scores, hidden_states = layer(
                  hidden_states=hidden_states,
                  attention_mask=attention_mask,
                  paddings=patched_padding,
                  output_attentions=output_attentions,
              )
              if output_attentions:
                  all_attentions.append(scores)
              if output_hidden_states:
                  all_hidden_states.append(hidden_states)

          if output_hidden_states:
              # The embedded input is reported as the first hidden state.
              all_hidden_states = [model_input] + all_hidden_states
          else:
              all_hidden_states = None

          return TimesFmOutput(
              last_hidden_state=hidden_states,
              hidden_states=all_hidden_states,
              attentions=all_attentions if output_attentions else None,
              loc=stats[0],
              scale=stats[1],
          )

      @staticmethod
      def _prepare_4d_attention_mask(
          attention_mask: Optional[torch.Tensor],
          sequence_length: int,
          dtype: torch.dtype,
          device: torch.device,
          is_causal: bool = True,
      ) -> Optional[torch.Tensor]:
          """
          Creates 4D attention mask and combines causal and padding masks if needed.

          Args:
              attention_mask: Optional tensor of shape (batch_size, seq_length) containing padding mask
              sequence_length: Length of the sequence
              dtype: Data type of the mask
              device: Device of the mask
              is_causal: Whether to apply causal masking

          Returns:
              4D attention mask of shape (batch_size, 1, seq_length, seq_length), or `None` when no
              padding mask is given and `is_causal` is False.
          """
          # Get minimum value for the dtype
          min_value = torch.finfo(dtype).min if dtype.is_floating_point else torch.iinfo(dtype).min

          # Handle padding mask
          if attention_mask is not None:
              # Convert 2D padding mask to 4D attention mask
              attention_mask = attention_mask.view(attention_mask.shape[0], 1, 1, -1)
              # Padded positions (1) become `min_value`; valid positions (0) stay 0.
              attention_mask = attention_mask * min_value

          # Create causal mask if needed
          if is_causal:
              causal_mask = torch.triu(
                  torch.ones((sequence_length, sequence_length), dtype=dtype, device=device) * min_value,
                  diagonal=1,
              )
              causal_mask = causal_mask.view(1, 1, sequence_length, sequence_length)

              # Combine with padding mask if it exists
              if attention_mask is not None:
                  # Elementwise minimum: a position is masked if either mask masks it.
                  attention_mask = torch.minimum(attention_mask, causal_mask)
              else:
                  attention_mask = causal_mask

          return attention_mask

      @staticmethod
      def _timesfm_masked_mean_std(inputs: torch.Tensor, padding: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
          """Calculates mean and standard deviation of `inputs` across axis 1.

          It excludes values where `padding` is 1.

          Args:
              inputs: A PyTorch tensor of shape [b, n, p].
              padding: A PyTorch tensor of shape [b, n, p] with values 0 or 1.

          Returns:
              A tuple containing the mean and standard deviation, each of shape [b].

              We return the statistics of the first patch with at least three non-padded values
              (falling back to the last patch when no patch qualifies).
          """

          # Selecting the first patch with at least 3 unpadded values.
          def _get_patch_index(arr: torch.Tensor):
              indices = torch.argmax((arr >= 3).to(torch.int32), dim=1)
              row_sum = (arr >= 3).to(torch.int32).sum(dim=1)
              # argmax returns 0 when no patch qualifies; use the last patch in that case.
              return torch.where(row_sum == 0, arr.shape[1] - 1, indices)

          pad_sum = torch.sum(1 - padding, dim=2)
          patch_indices = _get_patch_index(pad_sum)
          # Build the batch index on the inputs' device so both index tensors live on the same device.
          bidxs = torch.arange(inputs.shape[0], device=inputs.device)

          arr = inputs[bidxs, patch_indices, :]
          pad = padding[bidxs, patch_indices, :]

          # Create a mask where padding is 0
          mask = 1 - pad

          # Calculate the number of valid elements
          num_valid_elements = torch.sum(mask, dim=1)
          # Avoid dividing by zero for fully padded patches.
          num_valid_elements = torch.where(
              num_valid_elements == 0,
              torch.tensor(1, dtype=num_valid_elements.dtype, device=num_valid_elements.device),
              num_valid_elements,
          )

          # Calculate the masked sum and squared sum
          masked_sum = torch.sum(arr * mask, dim=1)
          masked_squared_sum = torch.sum((arr * mask) ** 2, dim=1)

          # Calculate the masked mean and standard deviation
          masked_mean = masked_sum / num_valid_elements
          masked_var = masked_squared_sum / num_valid_elements - masked_mean**2
          # E[x^2] - E[x]^2 can come out slightly negative numerically; clamp before the sqrt.
          masked_var = torch.where(
              masked_var < 0.0,
              torch.tensor(0.0, dtype=masked_var.dtype, device=masked_var.device),
              masked_var,
          )
          masked_std = torch.sqrt(masked_var)

          return masked_mean, masked_std

      @staticmethod
      def _timesfm_shift_padded_seq(mask: torch.Tensor, seq: torch.Tensor) -> torch.Tensor:
          """Shifts rows of seq based on the first 0 in each row of the mask.

          Args:
              mask: mask tensor of shape [B, N]
              seq: seq tensor of shape [B, N, P]

          Returns:
              The shifted sequence.
          """
          batch_size, num_seq, feature_dim = seq.shape

          new_mask: torch.BoolTensor = mask == 0

          # Use argmax to find the first True value in each row
          indices = new_mask.to(torch.int32).argmax(dim=1)

          # Handle rows that contain no 0 at all (argmax would misleadingly report index 0)
          indices[~new_mask.any(dim=1)] = -1

          # Create index ranges for each sequence in the batch
          idx_range = torch.arange(num_seq, device=seq.device).view(1, -1, 1).expand(batch_size, -1, feature_dim)

          # Calculate shifted indices for each element in each sequence (wrapping around via the modulo)
          shifted_idx = (idx_range - indices[:, None, None]) % num_seq

          # Gather values from seq using shifted indices
          shifted_seq = seq.gather(1, shifted_idx)

          return shifted_seq
  class TimesFmModelForPrediction(TimesFmPreTrainedModel):
      """TimesFM model for quantile and mean prediction."""

      def __init__(self, config: TimesFmConfig):
          super().__init__(config)

          self.config = config
          self.context_len = config.context_length
          self.horizon_len = config.horizon_length

          self.decoder = TimesFmModel(config)

          # quantile and mean output
          self.horizon_ff_layer = TimesFmResidualBlock(
              input_dims=config.hidden_size,
              output_dims=config.horizon_length * (1 + len(config.quantiles)),
              hidden_dims=config.intermediate_size,
          )

          # Initialize weights and apply final processing
          self.post_init()

      def _preprocess(
          self, inputs: Sequence[torch.Tensor], freq: Sequence[int]
      ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
          """Formats and pads raw inputs to feed into the model.

          Each time series is front-padded with zeros, or truncated to its most recent values, so
          that it is exactly `context_len` long.

          Args:
              inputs: A list of 1d Tensors. Each Tensor is the context time series of
                  a single forecast task.
              freq: list of frequencies

          Returns:
              A tuple of:
              - the padded input time series, stacked to shape [B, context_len].
              - the padding indicator (1 where padded), stacked to shape [B, context_len + horizon_len].
              - the frequencies as an int32 tensor of shape [B, 1].
          """
          input_ts, input_padding, inp_freq = [], [], []

          for i, ts in enumerate(inputs):
              input_len = ts.shape[0]
              # The indicator also covers the forecast horizon, which is never padded.
              padding = torch.zeros(input_len + self.horizon_len, dtype=ts.dtype, device=ts.device)
              if input_len < self.context_len:
                  num_front_pad = self.context_len - input_len
                  ts = torch.cat([torch.zeros(num_front_pad, dtype=ts.dtype, device=ts.device), ts], dim=0)
                  padding = torch.cat([torch.ones(num_front_pad, dtype=ts.dtype, device=padding.device), padding], dim=0)
              elif input_len > self.context_len:
                  ts = ts[-self.context_len :]
                  padding = padding[-(self.context_len + self.horizon_len) :]

              input_ts.append(ts)
              input_padding.append(padding)
              inp_freq.append(freq[i])

          return (
              torch.stack(input_ts, dim=0),
              torch.stack(input_padding, dim=0),
              torch.tensor(inp_freq, dtype=torch.int32).reshape(-1, 1),
          )

      def _postprocess_output(
          self, model_output: torch.Tensor, stats: tuple[torch.Tensor, torch.Tensor]
      ) -> torch.Tensor:
          """Postprocess output of stacked transformer.

          Projects the hidden states to forecasts, reshapes them to
          [B, N, horizon_length, 1 + num_quantiles] and undoes the input normalization with
          `stats = (mu, sigma)`.
          """
          # B x N x (H.Q)
          output_ts = self.horizon_ff_layer(model_output)

          # Reshape using view
          b, n, _ = output_ts.shape
          output_ts = output_ts.view(b, n, self.config.horizon_length, len(self.config.quantiles) + 1)

          mu, sigma = stats
          return output_ts * sigma[:, None, None, None] + mu[:, None, None, None]

      def _quantile_loss(self, predictions: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
          """Mean pinball loss over the configured quantiles.

          `predictions[..., i]` is compared against `targets` for the i-th configured quantile.
          """
          losses = []
          for i, q in enumerate(self.config.quantiles):
              errors = targets - predictions[..., i]
              loss = torch.max((q - 1) * errors, q * errors)
              losses.append(loss.mean())
          return torch.stack(losses).mean()

      @can_return_tuple
      @auto_docstring
      def forward(
          self,
          past_values: Sequence[torch.Tensor],
          freq: Optional[Sequence[Union[torch.Tensor, int]]] = None,
          window_size: Optional[int] = None,
          future_values: Optional[torch.Tensor] = None,
          forecast_context_len: Optional[int] = None,
          return_forecast_on_context: bool = False,
          truncate_negative: bool = False,
          output_attentions: Optional[bool] = None,
          output_hidden_states: Optional[bool] = None,
      ) -> TimesFmOutputForPrediction:
          r"""
          past_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
              Past values of the time series that serves as input to the model.
          freq (`torch.LongTensor` of shape `(batch_size,)`):
              Frequency indices for the time series data.
          window_size (`int`, *optional*):
              Window size of trend + residual decomposition. If None then we do not do decomposition.
          future_values (`torch.Tensor`, *optional*):
              Optional future time series values to be used for loss computation.
          forecast_context_len (`int`, *optional*):
              Optional max context length.
          return_forecast_on_context (`bool`, *optional*):
              True to return the forecast on the context when available, i.e. after the first input patch.
          truncate_negative (`bool`, *optional*):
              Truncate the forecasts to non-negative values if all of the contexts are non-negative,
              otherwise do nothing.
          output_attentions (`bool`, *optional*):
              Whether to output the attentions.
          output_hidden_states (`bool`, *optional*):
              Whether to output the hidden states.

          Example:

          ```python
          >>> import torch
          >>> from transformers import TimesFmModelForPrediction

          >>> model = TimesFmModelForPrediction.from_pretrained("google/timesfm-2.0-500m-pytorch")

          >>> forecast_input = [torch.linspace(0, 20, 100).sin(), torch.linspace(0, 20, 200).sin(), torch.linspace(0, 20, 400).sin()]
          >>> frequency_input = torch.tensor([0, 1, 2], dtype=torch.long)

          >>> # Generate
          >>> with torch.no_grad():
          ...     outputs = model(past_values=forecast_input, freq=frequency_input, return_dict=True)
          ...     point_forecast_conv = outputs.mean_predictions
          ...     quantile_forecast_conv = outputs.full_predictions
          ```
          """
          if forecast_context_len is None:
              fcontext_len = self.context_len
          else:
              fcontext_len = forecast_context_len

          # Get device from first input tensor
          device = past_values[0].device

          # Truncate inputs to forecast_context_len
          inputs = [ts[-fcontext_len:] for ts in past_values]
          # Smallest observed value over all series; decides whether `truncate_negative` applies.
          inp_min = torch.min(torch.stack([torch.min(ts) for ts in inputs]))

          if window_size is not None:
              # Each series is split into (trend, residual); both are forecast and summed again below.
              new_inputs = []
              new_freqs = []
              for i, ts in enumerate(inputs):
                  new_inputs.extend(self._timesfm_moving_average(ts, window_size))
                  if freq is not None:
                      new_freqs.extend([freq[i]] * 2)
              inputs = new_inputs
              if freq is not None:
                  freq = new_freqs

          if freq is None:
              logger.info("No frequency provided via `freq`. Default to high (0).")
              freq = [0] * len(inputs)

          if output_attentions is None:
              output_attentions = self.config.output_attentions
          if output_hidden_states is None:
              output_hidden_states = self.config.output_hidden_states

          input_ts, input_padding, inp_freq = self._preprocess(inputs, freq)
          # Move tensors to the same device as input
          input_ts = input_ts.to(device)
          input_padding = input_padding.to(device)
          inp_freq = inp_freq.to(device)

          final_out = input_ts
          context_len = final_out.shape[1]
          full_outputs = []

          if input_padding.shape[1] != final_out.shape[1] + self.horizon_len:
              raise ValueError(
                  "Length of paddings must match length of input + horizon_len:"
                  f" {input_padding.shape[1]} != {final_out.shape[1]} + {self.horizon_len}"
              )
          output_patch_len = self.config.horizon_length

          num_decode_patches = (self.horizon_len + output_patch_len - 1) // output_patch_len
          for step_index in range(num_decode_patches):
              # Use step-local names so the full-length `input_padding` is not overwritten by its
              # truncated slice and stays valid for every decode step.
              step_padding = input_padding[:, 0 : final_out.shape[1]][:, -fcontext_len:]
              step_inputs = final_out[:, -fcontext_len:]
              decoder_output = self.decoder(
                  past_values=step_inputs,
                  past_values_padding=step_padding,
                  freq=inp_freq,
                  output_attentions=output_attentions,
                  output_hidden_states=output_hidden_states,
              )
              fprop_outputs = self._postprocess_output(
                  decoder_output.last_hidden_state,
                  (decoder_output.loc, decoder_output.scale),
              )

              if return_forecast_on_context and step_index == 0:
                  # For the first decodings step, collect the model forecast on the
                  # context except the unavailable first input batch forecast.
                  new_full_ts = fprop_outputs[:, :-1, : self.config.patch_length, :]
                  # We have to use reshape and not view for non-contiguous memory
                  new_full_ts = new_full_ts.reshape(new_full_ts.size(0), -1, new_full_ts.size(3))

                  full_outputs.append(new_full_ts)

              # (full batch, last patch, output_patch_len, index of mean forecast = 0)
              new_ts = fprop_outputs[:, -1, :output_patch_len, 0]
              new_full_ts = fprop_outputs[:, -1, :output_patch_len, :]
              # (full batch, last patch, output_patch_len, all output indices)
              full_outputs.append(new_full_ts)
              final_out = torch.concatenate([final_out, new_ts], axis=-1)

          if return_forecast_on_context:
              # `full_outputs` indexing starts at after the first input patch.
              full_outputs = torch.concatenate(full_outputs, axis=1)[
                  :, : (context_len - self.config.patch_length + self.horizon_len), :
              ]
          else:
              # `full_outputs` indexing starts at the forecast horizon.
              full_outputs = torch.concatenate(full_outputs, axis=1)[:, 0 : self.horizon_len, :]

          mean_outputs = full_outputs[:, :, 0]
          if window_size is not None:
              # Recombine the (trend, residual) forecast pairs created above.
              mean_outputs = mean_outputs[0::2, ...] + mean_outputs[1::2, ...]
              full_outputs = full_outputs[0::2, ...] + full_outputs[1::2, ...]
          if inp_min >= 0 and truncate_negative:
              # torch.maximum only accepts tensors for both operands, so a Python float raised a
              # TypeError here; clamp expresses the same element-wise floor at zero.
              mean_outputs = torch.clamp(mean_outputs, min=0.0)
              full_outputs = torch.clamp(full_outputs, min=0.0)

          loss = None
          if future_values is not None:
              mse_loss = F.mse_loss(mean_outputs, future_values)
              # Index 0 of the last axis is the mean; the quantile forecasts start at index 1.
              quantile_loss = self._quantile_loss(full_outputs[:, :, 1:], future_values)
              loss = mse_loss + quantile_loss

          return TimesFmOutputForPrediction(
              last_hidden_state=decoder_output.last_hidden_state,
              attentions=decoder_output.attentions if output_attentions else None,
              hidden_states=decoder_output.hidden_states if output_hidden_states else None,
              mean_predictions=mean_outputs,
              full_predictions=full_outputs,
              loss=loss,
          )

      @staticmethod
      def _timesfm_moving_average(arr: torch.Tensor, window_size: int) -> list[torch.Tensor]:
          """Calculates the moving average using PyTorch's convolution function.

          Args:
              arr: a 1d tensor holding one time series.
              window_size: number of trailing points averaged for each output position.

          Returns:
              `[smoothed, arr - smoothed]`, i.e. the trend and the residual, both shaped like `arr`.
          """
          # Pad with zeros to handle initial window positions
          arr_padded = F.pad(arr, (window_size - 1, 0), "constant", 0)
          # Create a convolution kernel
          kernel = torch.ones(window_size, dtype=arr.dtype, device=arr.device) / window_size
          # Apply convolution to calculate the moving average. reshape(-1) is used instead of squeeze()
          # because squeeze() would collapse a length-1 series into a 0-d tensor.
          smoothed_arr = F.conv1d(arr_padded.view(1, 1, -1), kernel.view(1, 1, -1)).reshape(-1)
          return [smoothed_arr, arr - smoothed_arr]
  # Public API of this module.
  __all__ = [
      "TimesFmModelForPrediction",
      "TimesFmPreTrainedModel",
      "TimesFmModel",
  ]