modeling_patchtsmixer.py 83 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120
  1. # coding=utf-8
  2. # Copyright 2023 IBM and HuggingFace Inc. team. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """PyTorch PatchTSMixer model."""
  16. import math
  17. from dataclasses import dataclass
  18. from typing import Callable, Optional, Union
  19. import torch
  20. import torch.nn as nn
  21. from transformers.modeling_utils import PreTrainedModel
  22. from transformers.utils import ModelOutput
  23. from ...modeling_flash_attention_utils import FlashAttentionKwargs
  24. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
  25. from ...processing_utils import Unpack
  26. from ...time_series_utils import NegativeBinomialOutput, NormalOutput, StudentTOutput
  27. from ...utils import auto_docstring, logging
  28. from .configuration_patchtsmixer import PatchTSMixerConfig
  29. logger = logging.get_logger(__name__)
  30. class PatchTSMixerGatedAttention(nn.Module):
  31. """
  32. Module that applies gated attention to input data.
  33. Args:
  34. in_size (`int`): The input size.
  35. out_size (`int`): The output size.
  36. """
  37. def __init__(self, in_size: int, out_size: int):
  38. super().__init__()
  39. self.attn_layer = nn.Linear(in_size, out_size)
  40. self.attn_softmax = nn.Softmax(dim=-1)
  41. def forward(self, inputs):
  42. attn_weight = self.attn_softmax(self.attn_layer(inputs))
  43. inputs = inputs * attn_weight
  44. return inputs
  45. # Copied from transformers.models.patchtst.modeling_patchtst.PatchTSTBatchNorm with PatchTST->PatchTSMixer
  46. class PatchTSMixerBatchNorm(nn.Module):
  47. """
  48. Compute batch normalization over the sequence length (time) dimension.
  49. """
  50. def __init__(self, config: PatchTSMixerConfig):
  51. super().__init__()
  52. self.batchnorm = nn.BatchNorm1d(config.d_model, eps=config.norm_eps)
  53. def forward(self, inputs: torch.Tensor):
  54. """
  55. Parameters:
  56. inputs (`torch.Tensor` of shape `(batch_size, sequence_length, d_model)`):
  57. input for Batch norm calculation
  58. Returns:
  59. `torch.Tensor` of shape `(batch_size, sequence_length, d_model)`
  60. """
  61. output = inputs.transpose(1, 2) # output: (batch_size, d_model, sequence_length)
  62. output = self.batchnorm(output)
  63. return output.transpose(1, 2)
  64. class PatchTSMixerPositionalEncoding(nn.Module):
  65. """
  66. Class for positional encoding
  67. """
  68. def __init__(self, config: PatchTSMixerConfig):
  69. super().__init__()
  70. # positional encoding: [num_patches x d_model]
  71. if config.use_positional_encoding:
  72. self.position_enc = self._init_pe(config)
  73. else:
  74. self.position_enc = nn.Parameter(torch.zeros(config.num_patches, config.d_model))
  75. @staticmethod
  76. def _init_pe(config: PatchTSMixerConfig) -> nn.Parameter:
  77. # Positional encoding
  78. if config.positional_encoding_type == "random":
  79. position_enc = nn.Parameter(torch.randn(config.num_patches, config.d_model), requires_grad=True)
  80. elif config.positional_encoding_type == "sincos":
  81. position_enc = torch.zeros(config.num_patches, config.d_model)
  82. position = torch.arange(0, config.num_patches).unsqueeze(1)
  83. div_term = torch.exp(torch.arange(0, config.d_model, 2) * -(math.log(10000.0) / config.d_model))
  84. position_enc[:, 0::2] = torch.sin(position * div_term)
  85. position_enc[:, 1::2] = torch.cos(position * div_term)
  86. position_enc = position_enc - position_enc.mean()
  87. position_enc = position_enc / (position_enc.std() * 10)
  88. position_enc = nn.Parameter(position_enc, requires_grad=False)
  89. else:
  90. raise ValueError(
  91. f"{config.positional_encoding_type} is not a valid positional encoder. Available types are 'random' and 'sincos'."
  92. )
  93. return position_enc
  94. def forward(self, patch_input: torch.Tensor):
  95. # hidden_state: [bs x num_channels x num_patches x d_model]
  96. hidden_state = patch_input + self.position_enc
  97. return hidden_state
  98. class PatchTSMixerNormLayer(nn.Module):
  99. """Normalization block
  100. Args:
  101. config (`PatchTSMixerConfig`):
  102. Configuration.
  103. """
  104. def __init__(self, config: PatchTSMixerConfig):
  105. super().__init__()
  106. self.norm_mlp = config.norm_mlp
  107. if "batch" in config.norm_mlp.lower():
  108. self.norm = PatchTSMixerBatchNorm(config)
  109. else:
  110. self.norm = nn.LayerNorm(config.d_model, eps=config.norm_eps)
  111. def forward(self, inputs: torch.Tensor):
  112. """
  113. Args:
  114. inputs (`torch.Tensor` of shape `((batch_size, num_channels, num_patches, d_model))`):
  115. Input to the normalization layer.
  116. Returns:
  117. `torch.Tensor` of shape `((batch_size, num_channels, num_patches, d_model))`
  118. """
  119. if "batch" in self.norm_mlp.lower():
  120. # reshape the data
  121. inputs_reshaped = torch.reshape(
  122. inputs,
  123. (
  124. inputs.shape[0] * inputs.shape[1],
  125. inputs.shape[2],
  126. inputs.shape[3],
  127. ),
  128. ) # inputs_reshaped: [batch_size*num_channels, num_patches, d_model]
  129. # inputs_reshaped: [batch_size*num_channels, num_patches, d_model]
  130. inputs_reshaped = self.norm(inputs_reshaped)
  131. # put back data to the original shape
  132. inputs = torch.reshape(inputs_reshaped, inputs.shape)
  133. else:
  134. inputs = self.norm(inputs)
  135. return inputs
  136. class PatchTSMixerMLP(nn.Module):
  137. def __init__(self, in_features, out_features, config):
  138. super().__init__()
  139. num_hidden = in_features * config.expansion_factor
  140. self.fc1 = nn.Linear(in_features, num_hidden)
  141. self.dropout1 = nn.Dropout(config.dropout)
  142. self.fc2 = nn.Linear(num_hidden, out_features)
  143. self.dropout2 = nn.Dropout(config.dropout)
  144. def forward(self, inputs: torch.Tensor):
  145. """
  146. Args:
  147. inputs (`torch.Tensor` of shape `((batch_size, num_channels, num_patches, d_model))`):
  148. Input to the MLP layer.
  149. Returns:
  150. `torch.Tensor` of the same shape as `inputs`
  151. """
  152. inputs = self.dropout1(nn.functional.gelu(self.fc1(inputs)))
  153. inputs = self.fc2(inputs)
  154. inputs = self.dropout2(inputs)
  155. return inputs
  156. class PatchTSMixerChannelFeatureMixerBlock(nn.Module):
  157. """This module mixes the features in the channel dimension.
  158. Args:
  159. config (`PatchTSMixerConfig`):
  160. Configuration.
  161. """
  162. def __init__(self, config: PatchTSMixerConfig):
  163. super().__init__()
  164. self.norm = PatchTSMixerNormLayer(config)
  165. self.gated_attn = config.gated_attn
  166. self.mlp = PatchTSMixerMLP(
  167. in_features=config.num_input_channels,
  168. out_features=config.num_input_channels,
  169. config=config,
  170. )
  171. if config.gated_attn:
  172. self.gating_block = PatchTSMixerGatedAttention(
  173. in_size=config.num_input_channels, out_size=config.num_input_channels
  174. )
  175. def forward(self, inputs: torch.Tensor):
  176. """
  177. Args:
  178. inputs (`torch.Tensor` of shape `((batch_size, num_channels, num_patches, d_model))`):
  179. input to the MLP layer
  180. Returns:
  181. `torch.Tensor` of the same shape as `inputs`
  182. """
  183. residual = inputs
  184. inputs = self.norm(inputs)
  185. inputs = inputs.permute(0, 3, 2, 1)
  186. if self.gated_attn:
  187. inputs = self.gating_block(inputs)
  188. inputs = self.mlp(inputs)
  189. inputs = inputs.permute(0, 3, 2, 1)
  190. out = inputs + residual
  191. return out
  192. # Copied from transformers.models.bart.modeling_bart.eager_attention_forward
  193. def eager_attention_forward(
  194. module: nn.Module,
  195. query: torch.Tensor,
  196. key: torch.Tensor,
  197. value: torch.Tensor,
  198. attention_mask: Optional[torch.Tensor],
  199. scaling: Optional[float] = None,
  200. dropout: float = 0.0,
  201. head_mask: Optional[torch.Tensor] = None,
  202. **kwargs,
  203. ):
  204. if scaling is None:
  205. scaling = query.size(-1) ** -0.5
  206. attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
  207. if attention_mask is not None:
  208. attn_weights = attn_weights + attention_mask
  209. attn_weights = nn.functional.softmax(attn_weights, dim=-1)
  210. if head_mask is not None:
  211. attn_weights = attn_weights * head_mask.view(1, -1, 1, 1)
  212. attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
  213. attn_output = torch.matmul(attn_weights, value)
  214. attn_output = attn_output.transpose(1, 2).contiguous()
  215. return attn_output, attn_weights
  216. # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Attention with Wav2Vec2->PatchTSMixer
  217. class PatchTSMixerAttention(nn.Module):
  218. """Multi-headed attention from 'Attention Is All You Need' paper"""
  219. def __init__(
  220. self,
  221. embed_dim: int,
  222. num_heads: int,
  223. dropout: float = 0.0,
  224. is_decoder: bool = False,
  225. bias: bool = True,
  226. is_causal: bool = False,
  227. config: Optional[PatchTSMixerConfig] = None,
  228. ):
  229. super().__init__()
  230. self.embed_dim = embed_dim
  231. self.num_heads = num_heads
  232. self.dropout = dropout
  233. self.head_dim = embed_dim // num_heads
  234. self.config = config
  235. if (self.head_dim * num_heads) != self.embed_dim:
  236. raise ValueError(
  237. f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
  238. f" and `num_heads`: {num_heads})."
  239. )
  240. self.scaling = self.head_dim**-0.5
  241. self.is_decoder = is_decoder
  242. self.is_causal = is_causal
  243. self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
  244. self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
  245. self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
  246. self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
  247. def forward(
  248. self,
  249. hidden_states: torch.Tensor,
  250. key_value_states: Optional[torch.Tensor] = None,
  251. attention_mask: Optional[torch.Tensor] = None,
  252. layer_head_mask: Optional[torch.Tensor] = None,
  253. output_attentions: Optional[bool] = False,
  254. # TODO: we need a refactor so that the different attention modules can get their specific kwargs
  255. # ATM, we have mixed things encoder, decoder, and encoder-decoder attn
  256. **kwargs: Unpack[FlashAttentionKwargs],
  257. ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
  258. """Input shape: Batch x Time x Channel"""
  259. # if key_value_states are provided this layer is used as a cross-attention layer
  260. # for the decoder
  261. is_cross_attention = key_value_states is not None
  262. # determine input shapes
  263. bsz, tgt_len = hidden_states.shape[:-1]
  264. src_len = key_value_states.shape[1] if is_cross_attention else tgt_len
  265. q_input_shape = (bsz, tgt_len, -1, self.head_dim)
  266. kv_input_shape = (bsz, src_len, -1, self.head_dim)
  267. # get query proj
  268. query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2)
  269. current_states = key_value_states if is_cross_attention else hidden_states
  270. key_states = self.k_proj(current_states).view(*kv_input_shape).transpose(1, 2)
  271. value_states = self.v_proj(current_states).view(*kv_input_shape).transpose(1, 2)
  272. attention_interface: Callable = eager_attention_forward
  273. if self.config._attn_implementation != "eager":
  274. attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
  275. attn_output, attn_weights = attention_interface(
  276. self,
  277. query_states,
  278. key_states,
  279. value_states,
  280. attention_mask,
  281. dropout=0.0 if not self.training else self.dropout,
  282. scaling=self.scaling,
  283. output_attentions=output_attentions,
  284. head_mask=layer_head_mask,
  285. **kwargs,
  286. )
  287. attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous()
  288. attn_output = self.out_proj(attn_output)
  289. return attn_output, attn_weights, None
  290. class PatchMixerBlock(nn.Module):
  291. """This module mixes the patch dimension.
  292. Args:
  293. config (`PatchTSMixerConfig`):
  294. Configuration.
  295. """
  296. def __init__(self, config: PatchTSMixerConfig):
  297. super().__init__()
  298. self.norm = PatchTSMixerNormLayer(config)
  299. self.self_attn = config.self_attn
  300. self.gated_attn = config.gated_attn
  301. self.mlp = PatchTSMixerMLP(
  302. in_features=config.num_patches,
  303. out_features=config.num_patches,
  304. config=config,
  305. )
  306. if config.gated_attn:
  307. self.gating_block = PatchTSMixerGatedAttention(in_size=config.num_patches, out_size=config.num_patches)
  308. if config.self_attn:
  309. self.self_attn_layer = PatchTSMixerAttention(
  310. embed_dim=config.d_model,
  311. num_heads=config.self_attn_heads,
  312. dropout=config.dropout,
  313. config=config,
  314. )
  315. self.norm_attn = PatchTSMixerNormLayer(config)
  316. def forward(self, hidden_state):
  317. """
  318. Args:
  319. hidden_state (`torch.Tensor`): Input tensor.
  320. Returns:
  321. `torch.Tensor`: Transformed tensor.
  322. """
  323. residual = hidden_state
  324. hidden_state = self.norm(hidden_state)
  325. if self.self_attn:
  326. batch_size, n_vars, num_patches, d_model = hidden_state.shape
  327. hidden_state_reshaped = hidden_state.reshape(batch_size * n_vars, num_patches, d_model)
  328. x_attn, _, _ = self.self_attn_layer(hidden_state_reshaped, output_attentions=False)
  329. x_attn = x_attn.reshape(batch_size, n_vars, num_patches, d_model)
  330. # Transpose so that num_patches is the last dimension
  331. hidden_state = hidden_state.transpose(2, 3)
  332. hidden_state = self.mlp(hidden_state)
  333. if self.gated_attn:
  334. hidden_state = self.gating_block(hidden_state)
  335. # Transpose back
  336. hidden_state = hidden_state.transpose(2, 3)
  337. if self.self_attn:
  338. hidden_state = self.norm_attn(hidden_state + x_attn)
  339. out = hidden_state + residual
  340. return out
  341. class FeatureMixerBlock(nn.Module):
  342. """This module mixes the hidden feature dimension.
  343. Args:
  344. config (`PatchTSMixerConfig`):
  345. Configuration.
  346. """
  347. def __init__(self, config: PatchTSMixerConfig):
  348. super().__init__()
  349. self.norm = PatchTSMixerNormLayer(config)
  350. self.gated_attn = config.gated_attn
  351. self.mlp = PatchTSMixerMLP(
  352. in_features=config.d_model,
  353. out_features=config.d_model,
  354. config=config,
  355. )
  356. if config.gated_attn:
  357. self.gating_block = PatchTSMixerGatedAttention(in_size=config.d_model, out_size=config.d_model)
  358. def forward(self, hidden: torch.Tensor):
  359. """
  360. Args:
  361. hidden (`torch.Tensor` of shape `(batch_size, num_patches, d_model)`):
  362. Input tensor to the layer.
  363. Returns:
  364. `torch.Tensor`: Transformed tensor.
  365. """
  366. residual = hidden
  367. hidden = self.norm(hidden)
  368. hidden = self.mlp(hidden)
  369. if self.gated_attn:
  370. hidden = self.gating_block(hidden)
  371. out = hidden + residual
  372. return out
  373. class PatchTSMixerLayer(nn.Module):
  374. """
  375. The `PatchTSMixer` layer that does all three kinds of mixing.
  376. Args:
  377. config (`PatchTSMixerConfig`):
  378. Configuration.
  379. """
  380. def __init__(self, config: PatchTSMixerConfig):
  381. super().__init__()
  382. self.patch_mixer = PatchMixerBlock(config=config)
  383. self.feature_mixer = FeatureMixerBlock(config=config)
  384. self.mode = config.mode
  385. if config.mode == "mix_channel":
  386. self.channel_feature_mixer = PatchTSMixerChannelFeatureMixerBlock(config=config)
  387. def forward(self, hidden: torch.Tensor):
  388. """
  389. Args:
  390. hidden (`torch.Tensor` of shape `(batch_size, num_patches, d_model)`):
  391. Input tensor to the layer.
  392. Returns:
  393. `torch.Tensor`: Transformed tensor.
  394. """
  395. if self.mode == "mix_channel":
  396. hidden = self.channel_feature_mixer(hidden)
  397. hidden = self.patch_mixer(hidden)
  398. hidden = self.feature_mixer(hidden) # hidden: (batch_size x num_patches x d_model)
  399. return hidden
  400. class PatchTSMixerBlock(nn.Module):
  401. """The main computing framework of the `PatchTSMixer` model.
  402. Args:
  403. config (`PatchTSMixerConfig`):
  404. Configuration.
  405. """
  406. def __init__(self, config: PatchTSMixerConfig):
  407. super().__init__()
  408. num_layers = config.num_layers
  409. self.mixers = nn.ModuleList([PatchTSMixerLayer(config=config) for _ in range(num_layers)])
  410. def forward(self, hidden_state, output_hidden_states: bool = False):
  411. """
  412. Args:
  413. hidden_state (`torch.Tensor`): The input tensor.
  414. output_hidden_states (`bool`, *optional*, defaults to False.):
  415. Whether to output the hidden states as well.
  416. Returns:
  417. `torch.Tensor`: The embedding. `list`: List of all hidden states if `output_hidden_states` is set to
  418. `True`.
  419. """
  420. all_hidden_states = []
  421. embedding = hidden_state
  422. for mod in self.mixers:
  423. embedding = mod(embedding)
  424. if output_hidden_states:
  425. all_hidden_states.append(embedding)
  426. if output_hidden_states:
  427. return embedding, all_hidden_states
  428. else:
  429. return embedding, None
  430. class PatchTSMixerForPredictionHead(nn.Module):
  431. """Prediction Head for Forecasting
  432. Args:
  433. config (`PatchTSMixerConfig`):
  434. Configuration.
  435. """
  436. def __init__(self, config: PatchTSMixerConfig, distribution_output=None):
  437. super().__init__()
  438. self.prediction_channel_indices = config.prediction_channel_indices
  439. if self.prediction_channel_indices is not None:
  440. self.prediction_channel_indices.sort()
  441. self.dropout_layer = nn.Dropout(config.head_dropout)
  442. if distribution_output is None:
  443. self.base_forecast_block = nn.Linear((config.num_patches * config.d_model), config.prediction_length)
  444. else:
  445. self.base_forecast_block = distribution_output.get_parameter_projection(
  446. config.num_patches * config.d_model
  447. )
  448. self.flatten = nn.Flatten(start_dim=-2)
  449. def forward(self, hidden_features):
  450. """
  451. Args:
  452. hidden_features (`torch.Tensor` of shape `(batch_size, num_patch, d_model)` in `flatten` mode
  453. or `(batch_size, n_vars, num_patch, d_model)` in `common_channel`/`mix_channel` mode.): Input hidden
  454. features.
  455. Returns:
  456. `torch.Tensor` of shape `(batch_size, prediction_length, nvars)`.
  457. """
  458. hidden_features = self.flatten(hidden_features) # [batch_size x n_vars x num_patch * d_model]
  459. hidden_features = self.dropout_layer(hidden_features) # [batch_size x n_vars x num_patch * d_model]
  460. forecast = self.base_forecast_block(hidden_features) # [batch_size x n_vars x prediction_length]
  461. if isinstance(forecast, tuple):
  462. forecast = tuple(z.transpose(-1, -2) for z in forecast)
  463. else:
  464. forecast = forecast.transpose(-1, -2) # [batch_size x prediction_length x n_vars]
  465. if self.prediction_channel_indices is not None:
  466. if isinstance(forecast, tuple):
  467. forecast = tuple(z[..., self.prediction_channel_indices] for z in forecast)
  468. else:
  469. forecast = forecast[..., self.prediction_channel_indices] # [batch_size x prediction_length x n_vars]
  470. return forecast
  471. class PatchTSMixerLinearHead(nn.Module):
  472. """Linear head for Classification and Regression.
  473. Args:
  474. config (`PatchTSMixerConfig`):
  475. Configuration.
  476. """
  477. def __init__(self, config: PatchTSMixerConfig, distribution_output=None):
  478. super().__init__()
  479. self.head_aggregation = config.head_aggregation
  480. self.output_range = config.output_range
  481. if config.head_aggregation is None:
  482. mul_factor = config.num_patches
  483. else:
  484. mul_factor = 1
  485. self.distribution_output = distribution_output
  486. if distribution_output is None:
  487. self.projection = nn.Linear(
  488. config.d_model * config.num_input_channels * mul_factor,
  489. config.num_targets,
  490. )
  491. else:
  492. self.projection = distribution_output.get_parameter_projection(
  493. config.d_model * config.num_input_channels * mul_factor
  494. )
  495. if config.head_aggregation is None:
  496. self.flatten = nn.Flatten(start_dim=-3)
  497. else:
  498. self.flatten = nn.Flatten(start_dim=-2)
  499. self.dropout = nn.Dropout(config.head_dropout)
  500. def forward(self, hidden_features):
  501. """
  502. Args:
  503. hidden_features (`torch.Tensor` of shape `(batch_size x num_patch x d_model)` in `flatten` mode
  504. or `(batch_size x n_vars x num_patch x d_model)` in `common_channel`/`mix_channel` mode.): Input hidden
  505. features.
  506. Returns:
  507. `torch.Tensor` of shape `(batch_size x num_targets)`.
  508. """
  509. # batch_size x d_model x num_patch or batch_size x n_vars x d_model x num_patch
  510. hidden_features = hidden_features.transpose(-1, -2)
  511. if self.head_aggregation == "use_last":
  512. # batch_size x d_model (flatten) or # batch_size x n_vars x d_model (common_channel)
  513. hidden_features = hidden_features[..., -1]
  514. elif self.head_aggregation == "max_pool":
  515. # batch_size x n_vars x d_model or batch_size x d_model
  516. hidden_features = hidden_features.max(dim=-1).values
  517. elif self.head_aggregation == "avg_pool":
  518. # batch_size x n_vars x d_model or batch_size x d_model
  519. hidden_features = hidden_features.mean(dim=-1)
  520. if self.flatten:
  521. hidden_features = self.flatten(hidden_features)
  522. hidden_features = self.dropout(hidden_features)
  523. hidden_features = self.projection(hidden_features) # batch_size x num_targets
  524. if (self.distribution_output is None) and (self.output_range is not None):
  525. hidden_features = (
  526. torch.sigmoid(hidden_features) * (self.output_range[1] - self.output_range[0]) + self.output_range[0]
  527. )
  528. return hidden_features
  529. @auto_docstring
  530. class PatchTSMixerPreTrainedModel(PreTrainedModel):
  531. # Weight initialization
  532. config: PatchTSMixerConfig
  533. base_model_prefix = "model"
  534. main_input_name = "past_values"
  535. supports_gradient_checkpointing = False
  536. def _init_weights(self, module):
  537. """Initialize weights"""
  538. if isinstance(module, PatchTSMixerPositionalEncoding):
  539. # initialize positional encoding
  540. if self.config.positional_encoding_type == "random":
  541. nn.init.normal_(module.position_enc, mean=0.0, std=0.1)
  542. elif isinstance(module, (nn.LayerNorm, nn.BatchNorm1d)):
  543. module.bias.data.zero_()
  544. module.weight.data.fill_(1.0)
  545. elif isinstance(module, PatchTSMixerBatchNorm):
  546. module.batchnorm.bias.data.zero_()
  547. module.batchnorm.weight.data.fill_(1.0)
  548. elif isinstance(module, nn.Linear):
  549. module.weight.data.normal_(mean=0.0, std=self.config.init_std)
  550. if module.bias is not None:
  551. module.bias.data.zero_()
  552. class PatchTSMixerPretrainHead(nn.Module):
  553. """Pretraining head.
  554. Args:
  555. config (`PatchTSMixerConfig`):
  556. Configuration.
  557. """
  558. def __init__(self, config: PatchTSMixerConfig):
  559. super().__init__()
  560. self.dropout_layer = nn.Dropout(config.head_dropout)
  561. self.base_pt_block = nn.Linear(config.d_model, config.patch_length)
  562. def forward(self, hidden_features):
  563. """
  564. Args:
  565. hidden_features (`torch.Tensor` of shape `(batch_size x num_patch x d_model)` in `flatten` mode
  566. or `(batch_size x n_vars x num_patch x d_model)` in `common_channel`/`mix_channel` mode.): Input hidden
  567. features.
  568. Returns:
  569. `torch.Tensor` of shape `(batch_size x n_vars x num_patch x patch_length)`.
  570. """
  571. hidden_features = self.dropout_layer(hidden_features)
  572. forecast = self.base_pt_block(hidden_features) # [batch_size x n_vars x num_patch x patch_length]
  573. return forecast
  574. # Copied from transformers.models.patchtst.modeling_patchtst.random_masking
  575. def random_masking(
  576. inputs: torch.Tensor,
  577. mask_ratio: float,
  578. unmasked_channel_indices: Optional[list] = None,
  579. channel_consistent_masking: bool = False,
  580. mask_value: int = 0,
  581. ):
  582. """random_masking: Mask the input considering the control variables.
  583. Args:
  584. inputs (`torch.Tensor` of shape `(batch_size, num_channels, sequence_length, num_features)`):
  585. The input tensor to mask.
  586. mask_ratio (`float`):
  587. Masking ratio applied to mask the input data during random pretraining. It is the number between 0 and 1.
  588. unmasked_channel_indices (list, *optional*):
  589. Indices of channels that will not be masked.
  590. channel_consistent_masking (bool, *optional*, defaults to `False`):
  591. When true, masking will be same across all channels of a timeseries. Otherwise, masking positions will vary
  592. across channels.
  593. mask_value (int, *optional*, defaults to 0):
  594. Define the value of masked patches for pretraining.
  595. Returns:
  596. `tuple(torch.Tensor)`: inputs_mask, masked input, same shape as input Tensor and mask tensor of shape [bs x c x
  597. n]
  598. """
  599. if mask_ratio < 0 or mask_ratio >= 1:
  600. raise ValueError(f"Mask ratio {mask_ratio} has to be between 0 and 1.")
  601. batch_size, num_channels, sequence_length, num_features = inputs.shape
  602. device = inputs.device
  603. len_keep = int(sequence_length * (1 - mask_ratio))
  604. if channel_consistent_masking:
  605. noise = torch.rand(batch_size, 1, sequence_length, device=device) # noise in [0, 1], bs x 1 x L
  606. noise = noise.repeat(1, num_channels, 1) # bs x num_channels x time
  607. else:
  608. # noise in [0, 1], bs x num_channels x L
  609. noise = torch.rand(batch_size, num_channels, sequence_length, device=device)
  610. # mask: [bs x num_channels x num_patch]
  611. mask = torch.ones(batch_size, num_channels, sequence_length, device=device)
  612. mask[:, :, :len_keep] = 0
  613. # sort noise for each sample
  614. ids_shuffle = torch.argsort(noise, dim=-1) # ascend: small is keep, large is remove
  615. ids_restore = torch.argsort(ids_shuffle, dim=-1) # ids_restore: [bs x num_channels x L]
  616. mask = torch.gather(mask, dim=-1, index=ids_restore)
  617. mask = mask.unsqueeze(-1).repeat(1, 1, 1, num_features) # mask: [bs x num_channels x num_patches x patch_length]
  618. if unmasked_channel_indices is not None:
  619. mask[:, unmasked_channel_indices, :, :] = 0
  620. inputs_mask = inputs.masked_fill(mask.bool(), mask_value)
  621. return inputs_mask, mask[..., 0]
  622. # Copied from transformers.models.patchtst.modeling_patchtst.forecast_masking
  623. def forecast_masking(
  624. inputs: torch.Tensor,
  625. num_forecast_mask_patches: Union[list, int],
  626. unmasked_channel_indices: Optional[list] = None,
  627. mask_value: int = 0,
  628. ):
  629. """Forecast masking that masks the last K patches where K is from the num_forecast_mask_patches.
  630. If num_forecast_mask_patches is a list, samples in the batch will be randomly masked by numbers defined in the list.
  631. Parameters:
  632. inputs (`torch.Tensor`):
  633. Input of shape `(bs, num_channels, num_patch, patch_length)`
  634. num_forecast_mask_patches (`list`):
  635. Number of patches to be masked at the end of each batch sample. e.g. 4 or [3, 5].
  636. unmasked_channel_indices (`list`, *optional*):
  637. Indices of channels that are not masked.
  638. mask_value (`int`, *optional*, defaults to 0):
  639. Values in the masked patches will be filled by `mask_value`.
  640. Returns:
  641. `tuple(torch.Tensor)`: inputs_mask, masked input, same shape as inputs Tensor and Mask tensor of shape `(bs,
  642. num_channels , num_patch)` or `(bs, tsg1, tsg2, num_channels, num_patch)`
  643. """
  644. if isinstance(num_forecast_mask_patches, int):
  645. num_forecast_mask_patches = [num_forecast_mask_patches]
  646. forecast_mask_ratios = [1 for _ in num_forecast_mask_patches]
  647. batch_size, num_channels, sequence_length, num_features = inputs.shape
  648. mask = torch.zeros(batch_size, num_channels, sequence_length, device=inputs.device)
  649. t_list = []
  650. total_length = 0
  651. total_ratio = sum(forecast_mask_ratios)
  652. for patch_length, ratio in zip(num_forecast_mask_patches, forecast_mask_ratios):
  653. if patch_length <= 0 or patch_length >= sequence_length:
  654. raise ValueError(
  655. f"num_forecast_mask_patches {patch_length} should be greater than 0 and less than total patches."
  656. )
  657. temp_len = int(batch_size * ratio / total_ratio)
  658. t_list.append([patch_length, ratio, temp_len])
  659. total_length += temp_len
  660. t_list = sorted(t_list, key=lambda x: x[2])
  661. if total_length < batch_size:
  662. t_list[0][2] = t_list[0][2] + (batch_size - total_length)
  663. elif total_length > batch_size:
  664. t_list[-1][2] = t_list[-1][2] + (total_length - batch_size)
  665. batch1 = 0
  666. for patch_len, _, temp_len in t_list:
  667. batch2 = batch1 + temp_len
  668. mask[batch1:batch2, :, -patch_len:] = 1
  669. batch1 = batch2
  670. perm = torch.randperm(mask.shape[0])
  671. mask = mask[perm]
  672. mask = mask.unsqueeze(-1).repeat(1, 1, 1, num_features) # mask: [bs x num_channels x num_patch x patch_len]
  673. if unmasked_channel_indices is not None:
  674. mask[:, unmasked_channel_indices, :, :] = 0
  675. inputs_mask = inputs.masked_fill(mask.bool(), mask_value)
  676. return inputs_mask, mask[..., 0]
  677. # Copied from transformers.models.patchtst.modeling_patchtst.PatchTSTPatchify with PatchTST->PatchTSMixer
  678. class PatchTSMixerPatchify(nn.Module):
  679. """
  680. A class to patchify the time series sequence into different patches
  681. Returns:
  682. `torch.Tensor` of shape `(batch_size, num_channels, num_patches, patch_length)`
  683. """
  684. def __init__(self, config: PatchTSMixerConfig):
  685. super().__init__()
  686. self.sequence_length = config.context_length
  687. self.patch_length = config.patch_length
  688. self.patch_stride = config.patch_stride
  689. if self.sequence_length <= self.patch_length:
  690. raise ValueError(
  691. f"Sequence length ({self.sequence_length}) has to be greater than the patch length ({self.patch_length})"
  692. )
  693. # get the number of patches
  694. self.num_patches = (max(self.sequence_length, self.patch_length) - self.patch_length) // self.patch_stride + 1
  695. new_sequence_length = self.patch_length + self.patch_stride * (self.num_patches - 1)
  696. self.sequence_start = self.sequence_length - new_sequence_length
  697. def forward(self, past_values: torch.Tensor):
  698. """
  699. Parameters:
  700. past_values (`torch.Tensor` of shape `(batch_size, sequence_length, num_channels)`, *required*):
  701. Input for patchification
  702. Returns:
  703. `torch.Tensor` of shape `(batch_size, num_channels, num_patches, patch_length)`
  704. """
  705. sequence_length = past_values.shape[-2]
  706. if sequence_length != self.sequence_length:
  707. raise ValueError(
  708. f"Input sequence length ({sequence_length}) doesn't match model configuration ({self.sequence_length})."
  709. )
  710. # output: [bs x new_sequence_length x num_channels]
  711. output = past_values[:, self.sequence_start :, :]
  712. # output: [bs x num_patches x num_input_channels x patch_length]
  713. output = output.unfold(dimension=-2, size=self.patch_length, step=self.patch_stride)
  714. # output: [bs x num_input_channels x num_patches x patch_length]
  715. output = output.transpose(-2, -3).contiguous()
  716. return output
  717. # Copied from transformers.models.patchtst.modeling_patchtst.PatchTSTMasking with PatchTST->PatchTSMixer
  718. class PatchTSMixerMasking(nn.Module):
  719. """
  720. Class to perform random or forecast masking.
  721. Parameters:
  722. config (`PatchTSMixerConfig`): model config
  723. Returns:
  724. x_mask (`torch.Tensor` of shape `(batch_size, num_channels, num_patches, patch_length)`)
  725. Masked patched input
  726. mask (`torch.Tensor` of shape `(batch_size, num_channels, num_patches)`)
  727. Bool tensor indicating True on masked points
  728. """
  729. def __init__(self, config: PatchTSMixerConfig):
  730. super().__init__()
  731. self.random_mask_ratio = config.random_mask_ratio
  732. self.channel_consistent_masking = config.channel_consistent_masking
  733. self.mask_type = config.mask_type
  734. self.num_forecast_mask_patches = config.num_forecast_mask_patches
  735. self.unmasked_channel_indices = config.unmasked_channel_indices
  736. self.mask_value = config.mask_value
  737. if self.unmasked_channel_indices is not None:
  738. self.unmasked_channel_indices = sorted(self.unmasked_channel_indices)
  739. def forward(self, patch_input: torch.Tensor):
  740. """
  741. Parameters:
  742. patch_input (`torch.Tensor` of shape `(batch_size, num_channels, num_patches, patch_length)`, *required*):
  743. Patch input
  744. Return:
  745. masked_input (`torch.Tensor` of shape `(batch_size, num_channels, num_patches, patch_length)`)
  746. Masked patched input
  747. mask (`torch.Tensor` of shape `(batch_size, num_channels, num_patches)`)
  748. Bool tensor indicating True on masked points
  749. """
  750. if self.mask_type == "random":
  751. masked_input, mask = random_masking(
  752. inputs=patch_input,
  753. mask_ratio=self.random_mask_ratio,
  754. unmasked_channel_indices=self.unmasked_channel_indices,
  755. channel_consistent_masking=self.channel_consistent_masking,
  756. mask_value=self.mask_value,
  757. )
  758. elif self.mask_type == "forecast":
  759. masked_input, mask = forecast_masking(
  760. inputs=patch_input,
  761. num_forecast_mask_patches=self.num_forecast_mask_patches,
  762. unmasked_channel_indices=self.unmasked_channel_indices,
  763. mask_value=self.mask_value,
  764. )
  765. else:
  766. raise ValueError(f"Invalid mask type {self.mask_type}.")
  767. # mask: [bs x num_input_channels x num_patch]
  768. mask = mask.bool()
  769. return masked_input, mask
  770. # Copied from transformers.models.patchtst.modeling_patchtst.PatchTSTStdScaler with PatchTST->PatchTSMixer
  771. class PatchTSMixerStdScaler(nn.Module):
  772. """
  773. Standardize features by calculating the mean and scaling along the first dimension, and then normalizes it by
  774. subtracting from the mean and dividing by the standard deviation.
  775. """
  776. def __init__(self, config: PatchTSMixerConfig):
  777. super().__init__()
  778. self.dim = config.scaling_dim if hasattr(config, "scaling_dim") else 1
  779. self.keepdim = config.keepdim if hasattr(config, "keepdim") else True
  780. self.minimum_scale = config.minimum_scale if hasattr(config, "minimum_scale") else 1e-5
  781. def forward(
  782. self, data: torch.Tensor, observed_indicator: torch.Tensor
  783. ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
  784. """
  785. Parameters:
  786. data (`torch.Tensor` of shape `(batch_size, sequence_length, num_input_channels)`):
  787. input for Batch norm calculation
  788. observed_indicator (`torch.BoolTensor` of shape `(batch_size, sequence_length, num_input_channels)`):
  789. Calculating the scale on the observed indicator.
  790. Returns:
  791. tuple of `torch.Tensor` of shapes
  792. (`(batch_size, sequence_length, num_input_channels)`,`(batch_size, 1, num_input_channels)`,
  793. `(batch_size, 1, num_input_channels)`)
  794. """
  795. denominator = observed_indicator.sum(self.dim, keepdim=self.keepdim)
  796. denominator = denominator.clamp_min(1.0)
  797. loc = (data * observed_indicator).sum(self.dim, keepdim=self.keepdim) / denominator
  798. variance = (((data - loc) * observed_indicator) ** 2).sum(self.dim, keepdim=self.keepdim) / denominator
  799. scale = torch.sqrt(variance + self.minimum_scale)
  800. return (data - loc) / scale, loc, scale
  801. # Copied from transformers.models.patchtst.modeling_patchtst.PatchTSTMeanScaler with PatchTST->PatchTSMixer
  802. class PatchTSMixerMeanScaler(nn.Module):
  803. """
  804. Computes a scaling factor as the weighted average absolute value along the first dimension, and scales the data
  805. accordingly.
  806. """
  807. def __init__(self, config: PatchTSMixerConfig):
  808. super().__init__()
  809. self.dim = config.scaling_dim if hasattr(config, "scaling_dim") else 1
  810. self.keepdim = config.keepdim if hasattr(config, "keepdim") else True
  811. self.minimum_scale = config.minimum_scale if hasattr(config, "minimum_scale") else 1e-10
  812. self.default_scale = config.default_scale if hasattr(config, "default_scale") else None
  813. def forward(
  814. self, data: torch.Tensor, observed_indicator: torch.Tensor
  815. ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
  816. """
  817. Parameters:
  818. data (`torch.Tensor` of shape `(batch_size, sequence_length, num_input_channels)`):
  819. input for Batch norm calculation
  820. observed_indicator (`torch.BoolTensor` of shape `(batch_size, sequence_length, num_input_channels)`):
  821. Calculating the scale on the observed indicator.
  822. Returns:
  823. tuple of `torch.Tensor` of shapes
  824. (`(batch_size, sequence_length, num_input_channels)`,`(batch_size, 1, num_input_channels)`,
  825. `(batch_size, 1, num_input_channels)`)
  826. """
  827. ts_sum = (data * observed_indicator).abs().sum(self.dim, keepdim=True)
  828. num_observed = observed_indicator.sum(self.dim, keepdim=True)
  829. scale = ts_sum / torch.clamp(num_observed, min=1)
  830. # If `default_scale` is provided, we use it, otherwise we use the scale
  831. # of the batch.
  832. if self.default_scale is None:
  833. batch_sum = ts_sum.sum(dim=0)
  834. batch_observations = torch.clamp(num_observed.sum(0), min=1)
  835. default_scale = torch.squeeze(batch_sum / batch_observations)
  836. else:
  837. default_scale = self.default_scale * torch.ones_like(scale)
  838. # apply default scale where there are no observations
  839. scale = torch.where(num_observed > 0, scale, default_scale)
  840. # ensure the scale is at least `self.minimum_scale`
  841. scale = torch.clamp(scale, min=self.minimum_scale)
  842. scaled_data = data / scale
  843. if not self.keepdim:
  844. scale = scale.squeeze(dim=self.dim)
  845. return scaled_data, torch.zeros_like(scale), scale
  846. # Copied from transformers.models.patchtst.modeling_patchtst.PatchTSTNOPScaler with PatchTST->PatchTSMixer
  847. class PatchTSMixerNOPScaler(nn.Module):
  848. """
  849. Assigns a scaling factor equal to 1 along the first dimension, and therefore applies no scaling to the input data.
  850. """
  851. def __init__(self, config: PatchTSMixerConfig):
  852. super().__init__()
  853. self.dim = config.scaling_dim if hasattr(config, "scaling_dim") else 1
  854. self.keepdim = config.keepdim if hasattr(config, "keepdim") else True
  855. def forward(
  856. self, data: torch.Tensor, observed_indicator: Optional[torch.Tensor] = None
  857. ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
  858. """
  859. Parameters:
  860. data (`torch.Tensor` of shape `(batch_size, sequence_length, num_input_channels)`):
  861. input for Batch norm calculation
  862. Returns:
  863. tuple of `torch.Tensor` of shapes
  864. (`(batch_size, sequence_length, num_input_channels)`,`(batch_size, 1, num_input_channels)`,
  865. `(batch_size, 1, num_input_channels)`)
  866. """
  867. scale = torch.ones_like(data, requires_grad=False).mean(dim=self.dim, keepdim=self.keepdim)
  868. loc = torch.zeros_like(data, requires_grad=False).mean(dim=self.dim, keepdim=self.keepdim)
  869. return data, loc, scale
@dataclass
@auto_docstring(
    custom_intro="""
    Base class for `PatchTSMixerEncoderOutput`, with potential hidden states.
    """
)
class PatchTSMixerEncoderOutput(ModelOutput):
    r"""
    last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_channels, num_patches, d_model)`):
        Hidden-state at the output of the last layer of the model.
    hidden_states (`tuple(torch.FloatTensor)`, *optional*):
        Hidden-states of the model at the output of each layer.
    """

    # Field order matters: `PatchTSMixerModel.forward` rebuilds this output positionally from the encoder's
    # tuple return via `PatchTSMixerEncoderOutput(*encoder_output)`.
    last_hidden_state: Optional[torch.FloatTensor] = None
    hidden_states: Optional[tuple[torch.FloatTensor]] = None
  885. class PatchTSMixerEncoder(PatchTSMixerPreTrainedModel):
  886. """
  887. Encoder for PatchTSMixer which inputs patched time-series and outputs patched embeddings.
  888. Args:
  889. config (`PatchTSMixerConfig`):
  890. Configuration.
  891. """
  892. def __init__(self, config: PatchTSMixerConfig):
  893. super().__init__(config)
  894. self.use_return_dict = config.use_return_dict
  895. self.patcher = nn.Linear(config.patch_length, config.d_model)
  896. if config.use_positional_encoding:
  897. self.positional_encoder = PatchTSMixerPositionalEncoding(config=config)
  898. else:
  899. self.positional_encoder = None
  900. self.mlp_mixer_encoder = PatchTSMixerBlock(config=config)
  901. # Initialize weights and apply final processing
  902. if config.post_init:
  903. self.post_init()
  904. @auto_docstring
  905. def forward(
  906. self,
  907. past_values: torch.Tensor,
  908. output_hidden_states: Optional[bool] = False,
  909. return_dict: Optional[bool] = None,
  910. ) -> Union[tuple, PatchTSMixerEncoderOutput]:
  911. r"""
  912. past_values (`torch.FloatTensor` of shape `(batch_size, seq_length, num_input_channels)`):
  913. Context values of the time series. For a pretraining task, this denotes the input time series to
  914. predict the masked portion. For a forecasting task, this denotes the history/past time series values.
  915. Similarly, for classification or regression tasks, it denotes the appropriate context values of the
  916. time series.
  917. For univariate time series, `num_input_channels` dimension should be 1. For multivariate time series,
  918. it is greater than 1.
  919. Returns:
  920. `torch.FloatTensor` of shape `(batch_size, n_vars, num_patches, d_model)`
  921. """
  922. return_dict = return_dict if return_dict is not None else self.use_return_dict
  923. # flatten [bs x num_patch x d_model]. common_channel/mix_channel: [bs x n_vars x num_patch x d_model]
  924. patches = self.patcher(past_values)
  925. # add positional encoder
  926. if self.positional_encoder is not None:
  927. patches = self.positional_encoder(patches)
  928. last_hidden_state, hidden_states = self.mlp_mixer_encoder(patches, output_hidden_states=output_hidden_states)
  929. if not return_dict:
  930. return tuple(
  931. v
  932. for v in [
  933. last_hidden_state,
  934. hidden_states,
  935. ]
  936. )
  937. return PatchTSMixerEncoderOutput(last_hidden_state=last_hidden_state, hidden_states=hidden_states)
@dataclass
@auto_docstring(
    custom_intro="""
    Base class for model's outputs, with potential hidden states.
    """
)
class PatchTSMixerModelOutput(ModelOutput):
    r"""
    last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_channels, num_patches, d_model)`):
        Hidden-state at the output of the last layer of the model.
    hidden_states (`tuple(torch.FloatTensor)`, *optional*):
        Hidden-states of the model at the output of each layer.
    patch_input (`torch.FloatTensor` of shape `(batch_size, num_channels, num_patches, patch_length)`):
        Patched input data to the model (after scaling, before masking).
    mask (`torch.BoolTensor` of shape `(batch_size, num_channels, num_patches)`, *optional*):
        Bool Tensor indicating True in masked patches and False otherwise. `None` when the model does not mask
        its input.
    loc (`torch.FloatTensor` of shape `(batch_size, 1, num_channels)`, *optional*):
        Location used to normalize the context window per channel: the observed mean with std scaling, zeros with
        mean scaling or no scaling. Used for revin denorm outside the model, if revin enabled.
    scale (`torch.FloatTensor` of shape `(batch_size, 1, num_channels)`, *optional*):
        Scale used to normalize the context window per channel: the observed std dev with std scaling, the mean
        absolute value with mean scaling, ones with no scaling. Used for revin denorm outside the model, if revin
        enabled.
    """

    # Field order matters: callers rebuild this output positionally from the model's tuple return via
    # `PatchTSMixerModelOutput(*model_output)`.
    last_hidden_state: Optional[torch.FloatTensor] = None
    hidden_states: Optional[tuple[torch.FloatTensor]] = None
    patch_input: Optional[torch.FloatTensor] = None
    mask: Optional[torch.FloatTensor] = None
    loc: Optional[torch.FloatTensor] = None
    scale: Optional[torch.FloatTensor] = None
  967. @auto_docstring(
  968. custom_intro="""
  969. The PatchTSMixer Model for time-series forecasting.
  970. """
  971. )
  972. class PatchTSMixerModel(PatchTSMixerPreTrainedModel):
  973. def __init__(self, config: PatchTSMixerConfig, mask_input: bool = False):
  974. r"""
  975. mask_input (bool, *optional*, defaults to `False`):
  976. Whether to mask the input using the [`PatchTSMixerMasking`] module.
  977. """
  978. super().__init__(config)
  979. self.use_return_dict = config.use_return_dict
  980. self.encoder = PatchTSMixerEncoder(config)
  981. self.patching = PatchTSMixerPatchify(config)
  982. if mask_input is True:
  983. self.masking = PatchTSMixerMasking(config)
  984. else:
  985. self.masking = None
  986. if config.scaling == "mean":
  987. self.scaler = PatchTSMixerMeanScaler(config)
  988. elif config.scaling == "std" or config.scaling is True:
  989. self.scaler = PatchTSMixerStdScaler(config)
  990. else:
  991. self.scaler = PatchTSMixerNOPScaler(config)
  992. # Initialize weights and apply final processing
  993. if config.post_init:
  994. self.post_init()
  995. @auto_docstring
  996. def forward(
  997. self,
  998. past_values: torch.Tensor,
  999. observed_mask: Optional[torch.Tensor] = None,
  1000. output_hidden_states: Optional[bool] = False,
  1001. return_dict: Optional[bool] = None,
  1002. ) -> PatchTSMixerModelOutput:
  1003. r"""
  1004. past_values (`torch.FloatTensor` of shape `(batch_size, seq_length, num_input_channels)`):
  1005. Context values of the time series. For a pretraining task, this denotes the input time series to predict
  1006. the masked portion. For a forecasting task, this denotes the history/past time series values. Similarly,
  1007. for classification or regression tasks, it denotes the appropriate context values of the time series.
  1008. For univariate time series, `num_input_channels` dimension should be 1. For multivariate time series, it is
  1009. greater than 1.
  1010. observed_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length, num_input_channels)`, *optional*):
  1011. Boolean mask to indicate which `past_values` were observed and which were missing. Mask values selected
  1012. in `[0, 1]`:
  1013. - 1 for values that are **observed**,
  1014. - 0 for values that are **missing** (i.e. NaNs that were replaced by zeros).
  1015. """
  1016. return_dict = return_dict if return_dict is not None else self.use_return_dict
  1017. mask = None
  1018. if observed_mask is None:
  1019. observed_mask = torch.ones_like(past_values)
  1020. scaled_past_values, loc, scale = self.scaler(past_values, observed_mask)
  1021. patched_x = self.patching(scaled_past_values) # [batch_size x num_input_channels x num_patch x patch_length
  1022. enc_input = patched_x
  1023. if self.masking is not None:
  1024. enc_input, mask = self.masking(patched_x)
  1025. # enc_input: [batch_size x num_input_channels x num_patch x patch_length]
  1026. # mask: [batch_size x num_input_channels x num_patch]
  1027. encoder_output = self.encoder(
  1028. enc_input,
  1029. output_hidden_states=output_hidden_states,
  1030. return_dict=return_dict,
  1031. )
  1032. if isinstance(encoder_output, tuple):
  1033. encoder_output = PatchTSMixerEncoderOutput(*encoder_output)
  1034. if not return_dict:
  1035. return tuple(
  1036. v
  1037. for v in [
  1038. encoder_output.last_hidden_state,
  1039. encoder_output.hidden_states,
  1040. patched_x,
  1041. mask,
  1042. loc,
  1043. scale,
  1044. ]
  1045. )
  1046. return PatchTSMixerModelOutput(
  1047. last_hidden_state=encoder_output.last_hidden_state,
  1048. hidden_states=encoder_output.hidden_states,
  1049. patch_input=patched_x,
  1050. mask=mask,
  1051. loc=loc,
  1052. scale=scale,
  1053. )
  1054. @dataclass
  1055. @auto_docstring(
  1056. custom_intro="""
  1057. Output type of [`PatchTSMixerForPreTrainingOutput`].
  1058. """
  1059. )
  1060. class PatchTSMixerForPreTrainingOutput(ModelOutput):
  1061. r"""
  1062. loss (*optional*, returned when `y` is provided, `torch.FloatTensor` of shape `()`):
  1063. Total loss
  1064. prediction_outputs (`torch.FloatTensor` of shape `(batch_size, num_input_channels, num_patches, patch_length)`):
  1065. Prediction output from the pretrain head.
  1066. last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_input_channels, num_patches, d_model)`):
  1067. Backbone embeddings before passing through the head.
  1068. hidden_states (`tuple(torch.FloatTensor)`, *optional*):
  1069. Hidden-states of the model at the output of each layer.
  1070. """
  1071. loss: Optional[torch.FloatTensor] = None
  1072. prediction_outputs: Optional[torch.FloatTensor] = None
  1073. last_hidden_state: Optional[torch.FloatTensor] = None
  1074. hidden_states: Optional[tuple[torch.FloatTensor]] = None
  1075. @auto_docstring(
  1076. custom_intro="""
  1077. `PatchTSMixer` for mask pretraining.
  1078. """
  1079. )
  1080. class PatchTSMixerForPretraining(PatchTSMixerPreTrainedModel):
  1081. def __init__(self, config: PatchTSMixerConfig):
  1082. super().__init__(config)
  1083. self.model = PatchTSMixerModel(config, mask_input=True)
  1084. self.head = PatchTSMixerPretrainHead(config=config)
  1085. self.masked_loss = config.masked_loss
  1086. self.use_return_dict = config.use_return_dict
  1087. # Initialize weights and apply final processing
  1088. if config.post_init:
  1089. self.post_init()
  1090. @auto_docstring
  1091. def forward(
  1092. self,
  1093. past_values: torch.Tensor,
  1094. observed_mask: Optional[torch.Tensor] = None,
  1095. output_hidden_states: Optional[bool] = False,
  1096. return_loss: bool = True,
  1097. return_dict: Optional[bool] = None,
  1098. ) -> PatchTSMixerForPreTrainingOutput:
  1099. r"""
  1100. past_values (`torch.FloatTensor` of shape `(batch_size, seq_length, num_input_channels)`):
  1101. Context values of the time series. For a pretraining task, this denotes the input time series to predict
  1102. the masked portion. For a forecasting task, this denotes the history/past time series values. Similarly,
  1103. for classification or regression tasks, it denotes the appropriate context values of the time series.
  1104. For univariate time series, `num_input_channels` dimension should be 1. For multivariate time series, it is
  1105. greater than 1.
  1106. observed_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length, num_input_channels)`, *optional*):
  1107. Boolean mask to indicate which `past_values` were observed and which were missing. Mask values selected
  1108. in `[0, 1]`:
  1109. - 1 for values that are **observed**,
  1110. - 0 for values that are **missing** (i.e. NaNs that were replaced by zeros).
  1111. return_loss (`bool`, *optional*):
  1112. Whether to return the loss in the `forward` call.
  1113. """
  1114. return_dict = return_dict if return_dict is not None else self.use_return_dict
  1115. if self.masked_loss is True:
  1116. loss = torch.nn.MSELoss(reduction="none")
  1117. else:
  1118. loss = torch.nn.MSELoss(reduction="mean")
  1119. # past_values: tensor [batch_size x context_length x num_input_channels]
  1120. model_output = self.model(
  1121. past_values,
  1122. observed_mask=observed_mask,
  1123. output_hidden_states=output_hidden_states,
  1124. return_dict=return_dict,
  1125. ) # x.last_hidden_state: [batch_size x nvars x num_patch x d_model]
  1126. if isinstance(model_output, tuple):
  1127. model_output = PatchTSMixerModelOutput(*model_output)
  1128. x_hat = self.head(model_output.last_hidden_state) # tensor [batch_size x nvars x num_patch x patch_length]
  1129. if return_loss is True:
  1130. loss_val = loss(x_hat, model_output.patch_input)
  1131. else:
  1132. loss_val = None
  1133. # calculate masked_loss
  1134. if self.masked_loss is True and loss_val is not None:
  1135. loss_val = (loss_val.mean(dim=-1) * model_output.mask).sum() / (model_output.mask.sum() + 1e-10)
  1136. if not return_dict:
  1137. return tuple(
  1138. v
  1139. for v in [
  1140. loss_val,
  1141. x_hat,
  1142. model_output.last_hidden_state,
  1143. model_output.hidden_states,
  1144. ]
  1145. )
  1146. return PatchTSMixerForPreTrainingOutput(
  1147. loss=loss_val,
  1148. prediction_outputs=x_hat, # tensor [batch_size x nvars x num_patch x patch_length]
  1149. last_hidden_state=model_output.last_hidden_state, # x: [batch_size x nvars x num_patch x d_model]
  1150. hidden_states=model_output.hidden_states,
  1151. )
  1152. @dataclass
  1153. @auto_docstring(
  1154. custom_intro="""
  1155. Output type of [`PatchTSMixerForPredictionOutput`].
  1156. """
  1157. )
  1158. class PatchTSMixerForPredictionOutput(ModelOutput):
  1159. r"""
  1160. loss (*optional*, returned when `y` is provided, `torch.FloatTensor` of shape `()`):
  1161. Total loss.
  1162. prediction_outputs (`torch.FloatTensor` of shape `(batch_size, prediction_length, num_input_channels)`):
  1163. Prediction output from the forecast head.
  1164. last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_input_channels, num_patches, d_model)`):
  1165. Backbone embeddings before passing through the head.
  1166. hidden_states (`tuple(torch.FloatTensor)`, *optional*):
  1167. Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
  1168. loc (`torch.FloatTensor`, *optional* of shape `(batch_size, 1, num_input_channels)`):
  1169. Input mean
  1170. scale (`torch.FloatTensor`, *optional* of shape `(batch_size, 1, num_input_channels)`):
  1171. Input std dev
  1172. """
  1173. loss: Optional[torch.FloatTensor] = None
  1174. prediction_outputs: Optional[torch.FloatTensor] = None
  1175. last_hidden_state: Optional[torch.FloatTensor] = None
  1176. hidden_states: Optional[tuple[torch.FloatTensor]] = None
  1177. loc: Optional[torch.FloatTensor] = None
  1178. scale: Optional[torch.FloatTensor] = None
@dataclass
@auto_docstring(
    custom_intro="""
    Base class for time series model's predictions outputs that contains the sampled values from the chosen
    distribution.
    """
)
class SamplePatchTSMixerPredictionOutput(ModelOutput):
    r"""
    sequences (`torch.FloatTensor` of shape `(batch_size, num_samples, prediction_length, number_channels)`):
        Sampled values from the chosen distribution.
    """

    # Single-field container; None until a producer fills it with samples.
    sequences: Optional[torch.FloatTensor] = None
@dataclass
@auto_docstring(
    custom_intro="""
    Base class for time series model's predictions outputs that contains the sampled values from the chosen
    distribution.
    """
)
class SamplePatchTSMixerRegressionOutput(ModelOutput):
    r"""
    sequences (`torch.FloatTensor` of shape `(batch_size, num_samples, prediction_length, number_channels)`):
        Sampled values from the chosen distribution.
    """

    # Single-field container for sampled regression outputs.
    # NOTE(review): the documented shape mirrors the prediction variant; confirm it matches what the
    # regression model actually samples (it may not have a prediction_length axis).
    sequences: Optional[torch.FloatTensor] = None
  1205. # Copied from transformers.models.time_series_transformer.modeling_time_series_transformer.nll
  1206. def nll(input: torch.distributions.Distribution, target: torch.Tensor) -> torch.Tensor:
  1207. """
  1208. Computes the negative log likelihood loss from input distribution with respect to target.
  1209. """
  1210. return -input.log_prob(target)
  1211. # Copied from transformers.models.time_series_transformer.modeling_time_series_transformer.weighted_average
  1212. def weighted_average(input_tensor: torch.Tensor, weights: Optional[torch.Tensor] = None, dim=None) -> torch.Tensor:
  1213. """
  1214. Computes the weighted average of a given tensor across a given `dim`, masking values associated with weight zero,
  1215. meaning instead of `nan * 0 = nan` you will get `0 * 0 = 0`.
  1216. Args:
  1217. input_tensor (`torch.FloatTensor`):
  1218. Input tensor, of which the average must be computed.
  1219. weights (`torch.FloatTensor`, *optional*):
  1220. Weights tensor, of the same shape as `input_tensor`.
  1221. dim (`int`, *optional*):
  1222. The dim along which to average `input_tensor`.
  1223. Returns:
  1224. `torch.FloatTensor`: The tensor with values averaged along the specified `dim`.
  1225. """
  1226. if weights is not None:
  1227. weighted_tensor = torch.where(weights != 0, input_tensor * weights, torch.zeros_like(input_tensor))
  1228. sum_weights = torch.clamp(weights.sum(dim=dim) if dim else weights.sum(), min=1.0)
  1229. return (weighted_tensor.sum(dim=dim) if dim else weighted_tensor.sum()) / sum_weights
  1230. else:
  1231. return input_tensor.mean(dim=dim)
  1232. class PatchTSMixerForPrediction(PatchTSMixerPreTrainedModel):
  1233. r"""
  1234. `PatchTSMixer` for forecasting application.
  1235. Args:
  1236. config (`PatchTSMixerConfig`):
  1237. Configuration.
  1238. Returns:
  1239. `None`.
  1240. """
  1241. def __init__(self, config: PatchTSMixerConfig):
  1242. super().__init__(config)
  1243. self.loss = config.loss
  1244. self.use_return_dict = config.use_return_dict
  1245. self.prediction_channel_indices = config.prediction_channel_indices
  1246. self.num_parallel_samples = config.num_parallel_samples
  1247. if config.loss == "mse":
  1248. self.distribution_output = None
  1249. else:
  1250. dim = config.prediction_length
  1251. distribution_output_map = {
  1252. "student_t": StudentTOutput,
  1253. "normal": NormalOutput,
  1254. "negative_binomial": NegativeBinomialOutput,
  1255. }
  1256. output_class = distribution_output_map.get(config.distribution_output, None)
  1257. if output_class is not None:
  1258. self.distribution_output = output_class(dim=dim)
  1259. else:
  1260. raise ValueError(f"Unknown distribution output {config.distribution_output}")
  1261. self.model = PatchTSMixerModel(config)
  1262. self.head = PatchTSMixerForPredictionHead(
  1263. config=config,
  1264. distribution_output=self.distribution_output,
  1265. )
  1266. # Initialize weights and apply final processing
  1267. if config.post_init:
  1268. self.post_init()
  1269. @auto_docstring
  1270. def forward(
  1271. self,
  1272. past_values: torch.Tensor,
  1273. observed_mask: Optional[torch.Tensor] = None,
  1274. future_values: Optional[torch.Tensor] = None,
  1275. output_hidden_states: Optional[bool] = False,
  1276. return_loss: bool = True,
  1277. return_dict: Optional[bool] = None,
  1278. ) -> PatchTSMixerForPredictionOutput:
  1279. r"""
  1280. past_values (`torch.FloatTensor` of shape `(batch_size, seq_length, num_input_channels)`):
  1281. Context values of the time series. For a pretraining task, this denotes the input time series to predict
  1282. the masked portion. For a forecasting task, this denotes the history/past time series values. Similarly,
  1283. for classification or regression tasks, it denotes the appropriate context values of the time series.
  1284. For univariate time series, `num_input_channels` dimension should be 1. For multivariate time series, it is
  1285. greater than 1.
  1286. observed_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length, num_input_channels)`, *optional*):
  1287. Boolean mask to indicate which `past_values` were observed and which were missing. Mask values selected
  1288. in `[0, 1]`:
  1289. - 1 for values that are **observed**,
  1290. - 0 for values that are **missing** (i.e. NaNs that were replaced by zeros).
  1291. future_values (`torch.FloatTensor` of shape `(batch_size, target_len, num_input_channels)` for forecasting,:
  1292. `(batch_size, num_targets)` for regression, or `(batch_size,)` for classification, *optional*):
  1293. Target values of the time series, that serve as labels for the model. The `future_values` is what the
  1294. Transformer needs during training to learn to output, given the `past_values`. Note that, this is NOT
  1295. required for a pretraining task.
  1296. For a forecasting task, the shape is be `(batch_size, target_len, num_input_channels)`. Even if we want
  1297. to forecast only specific channels by setting the indices in `prediction_channel_indices` parameter,
  1298. pass the target data with all channels, as channel Filtering for both prediction and target will be
  1299. manually applied before the loss computation.
  1300. return_loss (`bool`, *optional*):
  1301. Whether to return the loss in the `forward` call.
  1302. """
  1303. if self.loss == "mse":
  1304. loss = nn.MSELoss(reduction="mean")
  1305. elif self.loss == "nll":
  1306. loss = nll
  1307. else:
  1308. raise ValueError("Invalid loss function: Allowed values: mse and nll")
  1309. return_dict = return_dict if return_dict is not None else self.use_return_dict
  1310. # past_values: tensor [batch_size x context_length x num_input_channels]
  1311. model_output = self.model(
  1312. past_values,
  1313. observed_mask=observed_mask,
  1314. output_hidden_states=output_hidden_states,
  1315. return_dict=return_dict,
  1316. ) # model_output: [batch_size x nvars x num_patch x d_model]
  1317. if isinstance(model_output, tuple):
  1318. model_output = PatchTSMixerModelOutput(*model_output)
  1319. # tensor [batch_size x prediction_length x num_input_channels]
  1320. y_hat = self.head(model_output.last_hidden_state)
  1321. loss_val = None
  1322. if self.prediction_channel_indices is not None:
  1323. if self.distribution_output:
  1324. distribution = self.distribution_output.distribution(
  1325. y_hat,
  1326. loc=model_output.loc[..., self.prediction_channel_indices],
  1327. scale=model_output.scale[..., self.prediction_channel_indices],
  1328. )
  1329. if future_values is not None and return_loss is True:
  1330. loss_val = loss(
  1331. distribution,
  1332. future_values[..., self.prediction_channel_indices],
  1333. )
  1334. # take average of the loss
  1335. loss_val = weighted_average(loss_val)
  1336. else:
  1337. y_hat = (
  1338. y_hat * model_output.scale[..., self.prediction_channel_indices]
  1339. + model_output.loc[..., self.prediction_channel_indices]
  1340. )
  1341. if future_values is not None and return_loss is True:
  1342. loss_val = loss(y_hat, future_values[..., self.prediction_channel_indices])
  1343. else:
  1344. if self.distribution_output:
  1345. distribution = self.distribution_output.distribution(
  1346. y_hat, loc=model_output.loc, scale=model_output.scale
  1347. )
  1348. if future_values is not None and return_loss is True:
  1349. loss_val = loss(distribution, future_values)
  1350. loss_val = weighted_average(loss_val)
  1351. else:
  1352. y_hat = y_hat * model_output.scale + model_output.loc
  1353. if future_values is not None and return_loss is True:
  1354. loss_val = loss(y_hat, future_values)
  1355. if self.prediction_channel_indices is not None:
  1356. loc = model_output.loc[..., self.prediction_channel_indices]
  1357. scale = model_output.scale[..., self.prediction_channel_indices]
  1358. else:
  1359. loc = model_output.loc
  1360. scale = model_output.scale
  1361. if not return_dict:
  1362. return tuple(
  1363. v
  1364. for v in [
  1365. loss_val,
  1366. y_hat,
  1367. model_output.last_hidden_state,
  1368. model_output.hidden_states,
  1369. loc,
  1370. scale,
  1371. ]
  1372. )
  1373. return PatchTSMixerForPredictionOutput(
  1374. loss=loss_val,
  1375. prediction_outputs=y_hat, # tensor [batch_size x prediction_length x num_input_channels]
  1376. last_hidden_state=model_output.last_hidden_state, # x: [batch_size x nvars x num_patch x d_model]
  1377. hidden_states=model_output.hidden_states,
  1378. loc=loc,
  1379. scale=scale,
  1380. )
  1381. @torch.no_grad()
  1382. def generate(
  1383. self,
  1384. past_values: torch.Tensor,
  1385. observed_mask: Optional[torch.Tensor] = None,
  1386. ) -> SamplePatchTSMixerPredictionOutput:
  1387. """
  1388. Generate sequences of sample predictions from a model with a probability distribution head.
  1389. Args:
  1390. past_values (`torch.FloatTensor` of shape `(batch_size, sequence_length, num_input_channels)`):
  1391. Past values of the time series that serves as context in order to predict the future.
  1392. observed_mask (`torch.BoolTensor` of shape `(batch_size, sequence_length, num_input_channels)`, *optional*):
  1393. Boolean mask to indicate which `past_values` were observed and which were missing. Mask values selected
  1394. in `[0, 1]`:
  1395. - 1 for values that are **observed**,
  1396. - 0 for values that are **missing** (i.e. NaNs that were replaced by zeros).
  1397. Return:
  1398. [`SamplePatchTSMixerPredictionOutput`] where the outputs `sequences` tensor will have shape `(batch_size,
  1399. number of samples, prediction_length, num_input_channels)`.
  1400. """
  1401. # get number of samples
  1402. num_parallel_samples = self.num_parallel_samples
  1403. # get model output
  1404. outputs = self(
  1405. past_values=past_values,
  1406. future_values=None,
  1407. observed_mask=observed_mask,
  1408. output_hidden_states=False,
  1409. )
  1410. # get distribution
  1411. distribution = self.distribution_output.distribution(
  1412. outputs.prediction_outputs, loc=outputs.loc, scale=outputs.scale
  1413. )
  1414. # get samples: list of [batch_size x prediction_length x num_channels]
  1415. samples = [distribution.sample() for _ in range(num_parallel_samples)]
  1416. # stack tensors
  1417. samples = torch.stack(samples, dim=1) # [batch_size x num_samples x prediction_length x num_channels]
  1418. return SamplePatchTSMixerPredictionOutput(sequences=samples)
  1419. @dataclass
  1420. @auto_docstring(
  1421. custom_intro="""
  1422. Output type of [`PatchTSMixerForTimeSeriesClassificationOutput`].
  1423. """
  1424. )
  1425. class PatchTSMixerForTimeSeriesClassificationOutput(ModelOutput):
  1426. r"""
  1427. loss (*optional*, returned when `y` is provided, `torch.FloatTensor` of shape `()`):
  1428. Total loss.
  1429. prediction_outputs (`torch.FloatTensor` of shape `(batch_size, num_labels)`):
  1430. Prediction output from the classification head.
  1431. last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_input_channels, num_patches, d_model)`):
  1432. Backbone embeddings before passing through the head.
  1433. hidden_states (`tuple(torch.FloatTensor)`, *optional*):
  1434. Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
  1435. """
  1436. loss: Optional[torch.FloatTensor] = None
  1437. prediction_outputs: Optional[torch.FloatTensor] = None
  1438. last_hidden_state: Optional[torch.FloatTensor] = None
  1439. hidden_states: Optional[tuple[torch.FloatTensor]] = None
  1440. class PatchTSMixerForTimeSeriesClassification(PatchTSMixerPreTrainedModel):
  1441. r"""
  1442. `PatchTSMixer` for classification application.
  1443. Args:
  1444. config (`PatchTSMixerConfig`):
  1445. Configuration.
  1446. Returns:
  1447. `None`.
  1448. """
  1449. def __init__(self, config: PatchTSMixerConfig):
  1450. super().__init__(config)
  1451. self.model = PatchTSMixerModel(config)
  1452. self.head = PatchTSMixerLinearHead(
  1453. config=config,
  1454. )
  1455. self.use_return_dict = config.use_return_dict
  1456. if config.scaling in ["std", "mean", True]:
  1457. self.inject_scale = InjectScalerStatistics4D(d_model=config.d_model, num_patches=config.num_patches)
  1458. else:
  1459. self.inject_scale = None
  1460. # Initialize weights and apply final processing
  1461. if config.post_init:
  1462. self.post_init()
  1463. @auto_docstring
  1464. def forward(
  1465. self,
  1466. past_values: torch.Tensor,
  1467. target_values: Optional[torch.Tensor] = None,
  1468. output_hidden_states: Optional[bool] = False,
  1469. return_loss: bool = True,
  1470. return_dict: Optional[bool] = None,
  1471. ) -> PatchTSMixerForTimeSeriesClassificationOutput:
  1472. r"""
  1473. past_values (`torch.FloatTensor` of shape `(batch_size, seq_length, num_input_channels)`):
  1474. Context values of the time series. For a pretraining task, this denotes the input time series to predict
  1475. the masked portion. For a forecasting task, this denotes the history/past time series values. Similarly,
  1476. for classification or regression tasks, it denotes the appropriate context values of the time series.
  1477. For univariate time series, `num_input_channels` dimension should be 1. For multivariate time series, it is
  1478. greater than 1.
  1479. target_values (`torch.FloatTensor` of shape `(batch_size, target_len, num_input_channels)` for forecasting,
  1480. `(batch_size, num_targets)` for regression, or `(batch_size,)` for classification, *optional*):
  1481. Target
  1482. values of the time series, that serve as labels for the model. The `target_values` is what the
  1483. Transformer needs during training to learn to output, given the `past_values`. Note that, this is NOT
  1484. required for a pretraining task.
  1485. For a forecasting task, the shape is be `(batch_size, target_len, num_input_channels)`. Even if we want
  1486. to forecast only specific channels by setting the indices in `prediction_channel_indices` parameter,
  1487. pass the target data with all channels, as channel Filtering for both prediction and target will be
  1488. manually applied before the loss computation.
  1489. For a classification task, it has a shape of `(batch_size,)`.
  1490. For a regression task, it has a shape of `(batch_size, num_targets)`.
  1491. return_loss (`bool`, *optional*):
  1492. Whether to return the loss in the `forward` call.
  1493. """
  1494. loss = torch.nn.CrossEntropyLoss()
  1495. return_dict = return_dict if return_dict is not None else self.use_return_dict
  1496. model_output = self.model(
  1497. past_values,
  1498. output_hidden_states=output_hidden_states,
  1499. return_dict=return_dict,
  1500. ) # x: [batch_size x nvars x num_patch x d_model]
  1501. if isinstance(model_output, tuple):
  1502. model_output = PatchTSMixerModelOutput(*model_output)
  1503. if self.inject_scale is not None:
  1504. model_output.last_hidden_state = self.inject_scale(
  1505. model_output.last_hidden_state,
  1506. loc=model_output.loc,
  1507. scale=model_output.scale,
  1508. ) # x: [batch_size x nvars x num_patch x d_model]
  1509. y_hat = self.head(model_output.last_hidden_state) # tensor [batch_size x n_labels]
  1510. if target_values is not None and return_loss is True:
  1511. loss_val = loss(y_hat, target_values)
  1512. else:
  1513. loss_val = None
  1514. if not return_dict:
  1515. return tuple(
  1516. v
  1517. for v in [
  1518. loss_val,
  1519. y_hat,
  1520. model_output.last_hidden_state,
  1521. model_output.hidden_states,
  1522. ]
  1523. )
  1524. return PatchTSMixerForTimeSeriesClassificationOutput(
  1525. loss=loss_val,
  1526. prediction_outputs=y_hat, # tensor [batch_size x n_labels]
  1527. last_hidden_state=model_output.last_hidden_state, # x: [batch_size x nvars x num_patch x d_model]
  1528. hidden_states=model_output.hidden_states,
  1529. )
  1530. @dataclass
  1531. @auto_docstring(
  1532. custom_intro="""
  1533. Output type of [`PatchTSMixerForRegressionOutput`].
  1534. """
  1535. )
  1536. class PatchTSMixerForRegressionOutput(ModelOutput):
  1537. r"""
  1538. loss (*optional*, returned when `y` is provided, `torch.FloatTensor` of shape `()`):
  1539. Total loss.
  1540. regression_outputs (`torch.FloatTensor` of shape `(batch_size, num_targets)`):
  1541. Prediction output from the regression head.
  1542. last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_input_channels, num_patches, d_model)`):
  1543. Backbone embeddings before passing through the head.
  1544. hidden_states (`tuple(torch.FloatTensor)`, *optional*):
  1545. Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
  1546. """
  1547. loss: Optional[torch.FloatTensor] = None
  1548. regression_outputs: Optional[torch.FloatTensor] = None
  1549. last_hidden_state: Optional[torch.FloatTensor] = None
  1550. hidden_states: Optional[tuple[torch.FloatTensor]] = None
  1551. class InjectScalerStatistics4D(nn.Module):
  1552. def __init__(self, d_model: int, num_patches: int, expansion: int = 2):
  1553. super().__init__()
  1554. self.inverse_trans_expansion = nn.Linear(d_model + 2, expansion * d_model)
  1555. self.inverse_trans_compression = nn.Linear(expansion * d_model, d_model)
  1556. self.map_scale_expansion = nn.Linear(2, 2 * expansion)
  1557. self.map_scale_compression = nn.Linear(2 * expansion, 2)
  1558. self.num_patches = num_patches
  1559. def forward(self, inputs: torch.Tensor, loc: torch.Tensor, scale: torch.Tensor):
  1560. """
  1561. Args:
  1562. inputs (`torch.Tensor` of shape `(batch_size, num_input_channels, num_patch, d_model)`)
  1563. loc (`torch.Tensor` of shape `(batch_size, 1, num_input_channels)`)
  1564. scale (`torch.Tensor` of shape `(batch_size, 1, num_input_channels)`)
  1565. Returns:
  1566. `torch.Tensor` of shape `(batch_size, num_input_channels, num_patch, d_model)`
  1567. """
  1568. mean = loc.transpose(-1, -2) # [batch_size x n_channels x 1 ]
  1569. mean = mean.unsqueeze(-2) # [batch_size x n_channels x 1 x 1]
  1570. mean = mean.repeat(1, 1, self.num_patches, 1) # [batch_size x n_channels x num_patch x 1]
  1571. stdev = scale.transpose(-1, -2) # [batch_size x n_channels x 1 ]
  1572. stdev = stdev.unsqueeze(-2) # [batch_size x n_channels x 1 x 1]
  1573. stdev = stdev.repeat(1, 1, self.num_patches, 1) # [batch_size x n_channels x num_patch x 1]
  1574. concat_stats = torch.cat([mean, stdev], dim=-1) # [batch_size x n_channels x num_patch x 2]
  1575. concat_stats = self.map_scale_expansion(concat_stats) # [batch_size x n_channels x num_patch x (2*expansion)]
  1576. concat_stats = self.map_scale_compression(concat_stats) # [batch_size x n_channels x num_patch x 2]
  1577. inputs = torch.cat([inputs, concat_stats], dim=-1) # [batch_size x channels x num_patch x d_model+2]
  1578. inputs = self.inverse_trans_expansion(inputs) # [batch_size x channels x num_patch x (expansion*d_model)]
  1579. inputs = self.inverse_trans_compression(inputs) # [batch_size x channels x num_patch x d_model]
  1580. return inputs
  1581. @auto_docstring(
  1582. custom_intro="""
  1583. `PatchTSMixer` for regression application.
  1584. """
  1585. )
  1586. class PatchTSMixerForRegression(PatchTSMixerPreTrainedModel):
  1587. def __init__(self, config: PatchTSMixerConfig):
  1588. super().__init__(config)
  1589. self.model = PatchTSMixerModel(config)
  1590. self.loss = config.loss
  1591. self.distribution_output = config.distribution_output
  1592. self.use_return_dict = config.use_return_dict
  1593. self.num_parallel_samples = config.num_parallel_samples
  1594. if config.loss == "mse":
  1595. self.distribution_output = None
  1596. else:
  1597. distribution_output_map = {
  1598. "student_t": StudentTOutput,
  1599. "normal": NormalOutput,
  1600. "negative_binomial": NegativeBinomialOutput,
  1601. }
  1602. output_class = distribution_output_map.get(config.distribution_output)
  1603. if output_class is not None:
  1604. self.distribution_output = output_class(dim=config.num_targets)
  1605. else:
  1606. raise ValueError(f"Unknown distribution output {config.distribution_output}")
  1607. if config.scaling in ["std", "mean", True]:
  1608. self.inject_scale = InjectScalerStatistics4D(d_model=config.d_model, num_patches=config.num_patches)
  1609. else:
  1610. self.inject_scale = None
  1611. self.head = PatchTSMixerLinearHead(
  1612. config=config,
  1613. distribution_output=self.distribution_output,
  1614. )
  1615. # Initialize weights and apply final processing
  1616. if config.post_init:
  1617. self.post_init()
  1618. @auto_docstring
  1619. def forward(
  1620. self,
  1621. past_values: torch.Tensor,
  1622. target_values: Optional[torch.Tensor] = None,
  1623. output_hidden_states: Optional[bool] = False,
  1624. return_loss: bool = True,
  1625. return_dict: Optional[bool] = None,
  1626. ) -> PatchTSMixerForRegressionOutput:
  1627. r"""
  1628. past_values (`torch.FloatTensor` of shape `(batch_size, seq_length, num_input_channels)`):
  1629. Context values of the time series. For a pretraining task, this denotes the input time series to predict
  1630. the masked portion. For a forecasting task, this denotes the history/past time series values. Similarly,
  1631. for classification or regression tasks, it denotes the appropriate context values of the time series.
  1632. For univariate time series, `num_input_channels` dimension should be 1. For multivariate time series, it is
  1633. greater than 1.
  1634. target_values (`torch.FloatTensor` of shape `(batch_size, target_len, num_input_channels)` for forecasting,
  1635. `(batch_size, num_targets)` for regression, or `(batch_size,)` for classification, *optional*):
  1636. Target values of the time series, that serve as labels for the model. The `target_values` is what the
  1637. Transformer needs during training to learn to output, given the `past_values`. Note that, this is NOT
  1638. required for a pretraining task.
  1639. For a forecasting task, the shape is be `(batch_size, target_len, num_input_channels)`. Even if we want
  1640. to forecast only specific channels by setting the indices in `prediction_channel_indices` parameter,
  1641. pass the target data with all channels, as channel Filtering for both prediction and target will be
  1642. manually applied before the loss computation.
  1643. For a classification task, it has a shape of `(batch_size,)`.
  1644. For a regression task, it has a shape of `(batch_size, num_targets)`.
  1645. return_loss (`bool`, *optional*):
  1646. Whether to return the loss in the `forward` call.
  1647. """
  1648. if self.loss == "mse":
  1649. loss = nn.MSELoss(reduction="mean")
  1650. elif self.loss == "nll":
  1651. loss = nll
  1652. else:
  1653. raise ValueError("Invalid loss function: Allowed values: mse and nll")
  1654. return_dict = return_dict if return_dict is not None else self.use_return_dict
  1655. model_output = self.model(
  1656. past_values,
  1657. output_hidden_states=output_hidden_states,
  1658. return_dict=return_dict,
  1659. ) # model_output: [batch_size x nvars x num_patch x d_model]
  1660. if isinstance(model_output, tuple):
  1661. model_output = PatchTSMixerModelOutput(*model_output)
  1662. if self.inject_scale is not None:
  1663. model_output.last_hidden_state = self.inject_scale(
  1664. model_output.last_hidden_state,
  1665. loc=model_output.loc,
  1666. scale=model_output.scale,
  1667. ) # x: [batch_size x nvars x num_patch x d_model]
  1668. y_hat = self.head(model_output.last_hidden_state) # [batch_size x num_targets]
  1669. if target_values is not None and return_loss is True:
  1670. if self.distribution_output:
  1671. if self.distribution_output == "negative_binomial" and torch.any(target_values < 0):
  1672. raise Exception("target_values cannot be negative for negative_binomial distribution.")
  1673. distribution = self.distribution_output.distribution(y_hat)
  1674. # y_hat should be a 2-tuple, each with dimension [bs, num_targets]
  1675. y_hat = tuple(item.view(-1, self.config.num_targets) for item in y_hat)
  1676. loss_val = loss(distribution, target_values)
  1677. # take average of the loss
  1678. loss_val = weighted_average(loss_val)
  1679. else:
  1680. loss_val = loss(y_hat, target_values)
  1681. else:
  1682. loss_val = None
  1683. if not return_dict:
  1684. return tuple(
  1685. v
  1686. for v in [
  1687. loss_val,
  1688. y_hat,
  1689. model_output.last_hidden_state,
  1690. model_output.hidden_states,
  1691. ]
  1692. )
  1693. return PatchTSMixerForRegressionOutput(
  1694. loss=loss_val,
  1695. regression_outputs=y_hat, # tensor [batch_size x num_targets]
  1696. last_hidden_state=model_output.last_hidden_state, # [batch_size x nvars x num_patch x d_model]
  1697. hidden_states=model_output.hidden_states,
  1698. )
  1699. @torch.no_grad()
  1700. def generate(
  1701. self,
  1702. past_values: torch.Tensor,
  1703. ) -> SamplePatchTSMixerRegressionOutput:
  1704. """
  1705. Generate sequences of sample predictions from a model with a probability distribution head.
  1706. Args:
  1707. past_values (`torch.FloatTensor` of shape `(batch_size, sequence_length, num_input_channels)`):
  1708. Past values of the time series that serves as context in order to predict the target values.
  1709. Return:
  1710. [`SamplePatchTSMixerRegressionOutput`] where the outputs `sequences` tensor will have shape `(batch_size,
  1711. number of samples, num_targets)`.
  1712. """
  1713. # get number of samples
  1714. num_parallel_samples = self.num_parallel_samples
  1715. # get model output
  1716. outputs = self(
  1717. past_values=past_values,
  1718. target_values=None,
  1719. output_hidden_states=False,
  1720. )
  1721. # get distribution
  1722. distribution = self.distribution_output.distribution(outputs.regression_outputs)
  1723. # get samples
  1724. samples = [
  1725. distribution.sample() for _ in range(num_parallel_samples)
  1726. ] # samples: list of [batch_size x num_targets]
  1727. # stack tensors
  1728. # [batch_size x num_samples x num_targets]
  1729. samples = torch.stack(samples, dim=1).view(-1, num_parallel_samples, self.config.num_targets)
  1730. return SamplePatchTSMixerRegressionOutput(sequences=samples)
# Public API of this module; the task-specific heads below all share the `PatchTSMixerModel` backbone.
# Kept as a plain list literal because repository tooling parses `__all__` statically.
__all__ = [
    "PatchTSMixerPreTrainedModel",
    "PatchTSMixerModel",
    "PatchTSMixerForPretraining",
    "PatchTSMixerForPrediction",
    "PatchTSMixerForTimeSeriesClassification",
    "PatchTSMixerForRegression",
]