modeling_patchtst.py 83 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
7177817791780178117821783178417851786178717881789179017911792179317941795179617971798179918001801180218031804180518061807180818091810181118121813181418151816181718181819182018211822182318241825182618271828182918301831183218331834183518361837183818391840184118421843184418451846184718481849185018511852185318541855185618571858185918601861186218631864186518661867186818691870187118721873187418751876187718781879188018811882188318841885188618871888188918901891189218931894189518961897189818991900190119021903190419051906190719081909191019111912191319141915191619171918191919201921192219231924192519261927192819291930193119321933193419351936193719381939194019411942194319441945194619471948194919501951195219531954195519561957
  1. # coding=utf-8
  2. # Copyright 2023 IBM & Hugging Face. All rights reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """PyTorch PatchTST model."""
  16. import math
  17. from dataclasses import dataclass
  18. from typing import Callable, Optional, Union
  19. import torch
  20. from torch import nn
  21. from ...activations import ACT2CLS
  22. from ...modeling_flash_attention_utils import FlashAttentionKwargs
  23. from ...modeling_outputs import BaseModelOutput
  24. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  25. from ...processing_utils import Unpack
  26. from ...time_series_utils import NegativeBinomialOutput, NormalOutput, StudentTOutput
  27. from ...utils import ModelOutput, auto_docstring, logging
  28. from .configuration_patchtst import PatchTSTConfig
  29. logger = logging.get_logger(__name__)
  30. # Copied from transformers.models.bart.modeling_bart.eager_attention_forward
  31. def eager_attention_forward(
  32. module: nn.Module,
  33. query: torch.Tensor,
  34. key: torch.Tensor,
  35. value: torch.Tensor,
  36. attention_mask: Optional[torch.Tensor],
  37. scaling: Optional[float] = None,
  38. dropout: float = 0.0,
  39. head_mask: Optional[torch.Tensor] = None,
  40. **kwargs,
  41. ):
  42. if scaling is None:
  43. scaling = query.size(-1) ** -0.5
  44. attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
  45. if attention_mask is not None:
  46. attn_weights = attn_weights + attention_mask
  47. attn_weights = nn.functional.softmax(attn_weights, dim=-1)
  48. if head_mask is not None:
  49. attn_weights = attn_weights * head_mask.view(1, -1, 1, 1)
  50. attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
  51. attn_output = torch.matmul(attn_weights, value)
  52. attn_output = attn_output.transpose(1, 2).contiguous()
  53. return attn_output, attn_weights
  54. # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Attention with Wav2Vec2->PatchTST
  55. class PatchTSTAttention(nn.Module):
  56. """Multi-headed attention from 'Attention Is All You Need' paper"""
  57. def __init__(
  58. self,
  59. embed_dim: int,
  60. num_heads: int,
  61. dropout: float = 0.0,
  62. is_decoder: bool = False,
  63. bias: bool = True,
  64. is_causal: bool = False,
  65. config: Optional[PatchTSTConfig] = None,
  66. ):
  67. super().__init__()
  68. self.embed_dim = embed_dim
  69. self.num_heads = num_heads
  70. self.dropout = dropout
  71. self.head_dim = embed_dim // num_heads
  72. self.config = config
  73. if (self.head_dim * num_heads) != self.embed_dim:
  74. raise ValueError(
  75. f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
  76. f" and `num_heads`: {num_heads})."
  77. )
  78. self.scaling = self.head_dim**-0.5
  79. self.is_decoder = is_decoder
  80. self.is_causal = is_causal
  81. self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
  82. self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
  83. self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
  84. self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
  85. def forward(
  86. self,
  87. hidden_states: torch.Tensor,
  88. key_value_states: Optional[torch.Tensor] = None,
  89. attention_mask: Optional[torch.Tensor] = None,
  90. layer_head_mask: Optional[torch.Tensor] = None,
  91. output_attentions: Optional[bool] = False,
  92. # TODO: we need a refactor so that the different attention modules can get their specific kwargs
  93. # ATM, we have mixed things encoder, decoder, and encoder-decoder attn
  94. **kwargs: Unpack[FlashAttentionKwargs],
  95. ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
  96. """Input shape: Batch x Time x Channel"""
  97. # if key_value_states are provided this layer is used as a cross-attention layer
  98. # for the decoder
  99. is_cross_attention = key_value_states is not None
  100. # determine input shapes
  101. bsz, tgt_len = hidden_states.shape[:-1]
  102. src_len = key_value_states.shape[1] if is_cross_attention else tgt_len
  103. q_input_shape = (bsz, tgt_len, -1, self.head_dim)
  104. kv_input_shape = (bsz, src_len, -1, self.head_dim)
  105. # get query proj
  106. query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2)
  107. current_states = key_value_states if is_cross_attention else hidden_states
  108. key_states = self.k_proj(current_states).view(*kv_input_shape).transpose(1, 2)
  109. value_states = self.v_proj(current_states).view(*kv_input_shape).transpose(1, 2)
  110. attention_interface: Callable = eager_attention_forward
  111. if self.config._attn_implementation != "eager":
  112. attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
  113. attn_output, attn_weights = attention_interface(
  114. self,
  115. query_states,
  116. key_states,
  117. value_states,
  118. attention_mask,
  119. dropout=0.0 if not self.training else self.dropout,
  120. scaling=self.scaling,
  121. output_attentions=output_attentions,
  122. head_mask=layer_head_mask,
  123. **kwargs,
  124. )
  125. attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous()
  126. attn_output = self.out_proj(attn_output)
  127. return attn_output, attn_weights, None
  128. class PatchTSTBatchNorm(nn.Module):
  129. """
  130. Compute batch normalization over the sequence length (time) dimension.
  131. """
  132. def __init__(self, config: PatchTSTConfig):
  133. super().__init__()
  134. self.batchnorm = nn.BatchNorm1d(config.d_model, eps=config.norm_eps)
  135. def forward(self, inputs: torch.Tensor):
  136. """
  137. Parameters:
  138. inputs (`torch.Tensor` of shape `(batch_size, sequence_length, d_model)`):
  139. input for Batch norm calculation
  140. Returns:
  141. `torch.Tensor` of shape `(batch_size, sequence_length, d_model)`
  142. """
  143. output = inputs.transpose(1, 2) # output: (batch_size, d_model, sequence_length)
  144. output = self.batchnorm(output)
  145. return output.transpose(1, 2)
  146. def random_masking(
  147. inputs: torch.Tensor,
  148. mask_ratio: float,
  149. unmasked_channel_indices: Optional[list] = None,
  150. channel_consistent_masking: bool = False,
  151. mask_value: int = 0,
  152. ):
  153. """random_masking: Mask the input considering the control variables.
  154. Args:
  155. inputs (`torch.Tensor` of shape `(batch_size, num_channels, sequence_length, num_features)`):
  156. The input tensor to mask.
  157. mask_ratio (`float`):
  158. Masking ratio applied to mask the input data during random pretraining. It is the number between 0 and 1.
  159. unmasked_channel_indices (list, *optional*):
  160. Indices of channels that will not be masked.
  161. channel_consistent_masking (bool, *optional*, defaults to `False`):
  162. When true, masking will be same across all channels of a timeseries. Otherwise, masking positions will vary
  163. across channels.
  164. mask_value (int, *optional*, defaults to 0):
  165. Define the value of masked patches for pretraining.
  166. Returns:
  167. `tuple(torch.Tensor)`: inputs_mask, masked input, same shape as input Tensor and mask tensor of shape [bs x c x
  168. n]
  169. """
  170. if mask_ratio < 0 or mask_ratio >= 1:
  171. raise ValueError(f"Mask ratio {mask_ratio} has to be between 0 and 1.")
  172. batch_size, num_channels, sequence_length, num_features = inputs.shape
  173. device = inputs.device
  174. len_keep = int(sequence_length * (1 - mask_ratio))
  175. if channel_consistent_masking:
  176. noise = torch.rand(batch_size, 1, sequence_length, device=device) # noise in [0, 1], bs x 1 x L
  177. noise = noise.repeat(1, num_channels, 1) # bs x num_channels x time
  178. else:
  179. # noise in [0, 1], bs x num_channels x L
  180. noise = torch.rand(batch_size, num_channels, sequence_length, device=device)
  181. # mask: [bs x num_channels x num_patch]
  182. mask = torch.ones(batch_size, num_channels, sequence_length, device=device)
  183. mask[:, :, :len_keep] = 0
  184. # sort noise for each sample
  185. ids_shuffle = torch.argsort(noise, dim=-1) # ascend: small is keep, large is remove
  186. ids_restore = torch.argsort(ids_shuffle, dim=-1) # ids_restore: [bs x num_channels x L]
  187. mask = torch.gather(mask, dim=-1, index=ids_restore)
  188. mask = mask.unsqueeze(-1).repeat(1, 1, 1, num_features) # mask: [bs x num_channels x num_patches x patch_length]
  189. if unmasked_channel_indices is not None:
  190. mask[:, unmasked_channel_indices, :, :] = 0
  191. inputs_mask = inputs.masked_fill(mask.bool(), mask_value)
  192. return inputs_mask, mask[..., 0]
  193. def forecast_masking(
  194. inputs: torch.Tensor,
  195. num_forecast_mask_patches: Union[list, int],
  196. unmasked_channel_indices: Optional[list] = None,
  197. mask_value: int = 0,
  198. ):
  199. """Forecast masking that masks the last K patches where K is from the num_forecast_mask_patches.
  200. If num_forecast_mask_patches is a list, samples in the batch will be randomly masked by numbers defined in the list.
  201. Parameters:
  202. inputs (`torch.Tensor`):
  203. Input of shape `(bs, num_channels, num_patch, patch_length)`
  204. num_forecast_mask_patches (`list`):
  205. Number of patches to be masked at the end of each batch sample. e.g. 4 or [3, 5].
  206. unmasked_channel_indices (`list`, *optional*):
  207. Indices of channels that are not masked.
  208. mask_value (`int`, *optional*, defaults to 0):
  209. Values in the masked patches will be filled by `mask_value`.
  210. Returns:
  211. `tuple(torch.Tensor)`: inputs_mask, masked input, same shape as inputs Tensor and Mask tensor of shape `(bs,
  212. num_channels , num_patch)` or `(bs, tsg1, tsg2, num_channels, num_patch)`
  213. """
  214. if isinstance(num_forecast_mask_patches, int):
  215. num_forecast_mask_patches = [num_forecast_mask_patches]
  216. forecast_mask_ratios = [1 for _ in num_forecast_mask_patches]
  217. batch_size, num_channels, sequence_length, num_features = inputs.shape
  218. mask = torch.zeros(batch_size, num_channels, sequence_length, device=inputs.device)
  219. t_list = []
  220. total_length = 0
  221. total_ratio = sum(forecast_mask_ratios)
  222. for patch_length, ratio in zip(num_forecast_mask_patches, forecast_mask_ratios):
  223. if patch_length <= 0 or patch_length >= sequence_length:
  224. raise ValueError(
  225. f"num_forecast_mask_patches {patch_length} should be greater than 0 and less than total patches."
  226. )
  227. temp_len = int(batch_size * ratio / total_ratio)
  228. t_list.append([patch_length, ratio, temp_len])
  229. total_length += temp_len
  230. t_list = sorted(t_list, key=lambda x: x[2])
  231. if total_length < batch_size:
  232. t_list[0][2] = t_list[0][2] + (batch_size - total_length)
  233. elif total_length > batch_size:
  234. t_list[-1][2] = t_list[-1][2] + (total_length - batch_size)
  235. batch1 = 0
  236. for patch_len, _, temp_len in t_list:
  237. batch2 = batch1 + temp_len
  238. mask[batch1:batch2, :, -patch_len:] = 1
  239. batch1 = batch2
  240. perm = torch.randperm(mask.shape[0])
  241. mask = mask[perm]
  242. mask = mask.unsqueeze(-1).repeat(1, 1, 1, num_features) # mask: [bs x num_channels x num_patch x patch_len]
  243. if unmasked_channel_indices is not None:
  244. mask[:, unmasked_channel_indices, :, :] = 0
  245. inputs_mask = inputs.masked_fill(mask.bool(), mask_value)
  246. return inputs_mask, mask[..., 0]
  247. class PatchTSTPatchify(nn.Module):
  248. """
  249. A class to patchify the time series sequence into different patches
  250. Returns:
  251. `torch.Tensor` of shape `(batch_size, num_channels, num_patches, patch_length)`
  252. """
  253. def __init__(self, config: PatchTSTConfig):
  254. super().__init__()
  255. self.sequence_length = config.context_length
  256. self.patch_length = config.patch_length
  257. self.patch_stride = config.patch_stride
  258. if self.sequence_length <= self.patch_length:
  259. raise ValueError(
  260. f"Sequence length ({self.sequence_length}) has to be greater than the patch length ({self.patch_length})"
  261. )
  262. # get the number of patches
  263. self.num_patches = (max(self.sequence_length, self.patch_length) - self.patch_length) // self.patch_stride + 1
  264. new_sequence_length = self.patch_length + self.patch_stride * (self.num_patches - 1)
  265. self.sequence_start = self.sequence_length - new_sequence_length
  266. def forward(self, past_values: torch.Tensor):
  267. """
  268. Parameters:
  269. past_values (`torch.Tensor` of shape `(batch_size, sequence_length, num_channels)`, *required*):
  270. Input for patchification
  271. Returns:
  272. `torch.Tensor` of shape `(batch_size, num_channels, num_patches, patch_length)`
  273. """
  274. sequence_length = past_values.shape[-2]
  275. if sequence_length != self.sequence_length:
  276. raise ValueError(
  277. f"Input sequence length ({sequence_length}) doesn't match model configuration ({self.sequence_length})."
  278. )
  279. # output: [bs x new_sequence_length x num_channels]
  280. output = past_values[:, self.sequence_start :, :]
  281. # output: [bs x num_patches x num_input_channels x patch_length]
  282. output = output.unfold(dimension=-2, size=self.patch_length, step=self.patch_stride)
  283. # output: [bs x num_input_channels x num_patches x patch_length]
  284. output = output.transpose(-2, -3).contiguous()
  285. return output
  286. class PatchTSTMasking(nn.Module):
  287. """
  288. Class to perform random or forecast masking.
  289. Parameters:
  290. config (`PatchTSTConfig`): model config
  291. Returns:
  292. x_mask (`torch.Tensor` of shape `(batch_size, num_channels, num_patches, patch_length)`)
  293. Masked patched input
  294. mask (`torch.Tensor` of shape `(batch_size, num_channels, num_patches)`)
  295. Bool tensor indicating True on masked points
  296. """
  297. def __init__(self, config: PatchTSTConfig):
  298. super().__init__()
  299. self.random_mask_ratio = config.random_mask_ratio
  300. self.channel_consistent_masking = config.channel_consistent_masking
  301. self.mask_type = config.mask_type
  302. self.num_forecast_mask_patches = config.num_forecast_mask_patches
  303. self.unmasked_channel_indices = config.unmasked_channel_indices
  304. self.mask_value = config.mask_value
  305. if self.unmasked_channel_indices is not None:
  306. self.unmasked_channel_indices = sorted(self.unmasked_channel_indices)
  307. def forward(self, patch_input: torch.Tensor):
  308. """
  309. Parameters:
  310. patch_input (`torch.Tensor` of shape `(batch_size, num_channels, num_patches, patch_length)`, *required*):
  311. Patch input
  312. Return:
  313. masked_input (`torch.Tensor` of shape `(batch_size, num_channels, num_patches, patch_length)`)
  314. Masked patched input
  315. mask (`torch.Tensor` of shape `(batch_size, num_channels, num_patches)`)
  316. Bool tensor indicating True on masked points
  317. """
  318. if self.mask_type == "random":
  319. masked_input, mask = random_masking(
  320. inputs=patch_input,
  321. mask_ratio=self.random_mask_ratio,
  322. unmasked_channel_indices=self.unmasked_channel_indices,
  323. channel_consistent_masking=self.channel_consistent_masking,
  324. mask_value=self.mask_value,
  325. )
  326. elif self.mask_type == "forecast":
  327. masked_input, mask = forecast_masking(
  328. inputs=patch_input,
  329. num_forecast_mask_patches=self.num_forecast_mask_patches,
  330. unmasked_channel_indices=self.unmasked_channel_indices,
  331. mask_value=self.mask_value,
  332. )
  333. else:
  334. raise ValueError(f"Invalid mask type {self.mask_type}.")
  335. # mask: [bs x num_input_channels x num_patch]
  336. mask = mask.bool()
  337. return masked_input, mask
  338. class PatchTSTEncoderLayer(nn.Module):
  339. """
  340. PatchTST encoder layer
  341. """
  342. def __init__(self, config: PatchTSTConfig):
  343. super().__init__()
  344. self.channel_attention = config.channel_attention
  345. # Multi-Head attention
  346. self.self_attn = PatchTSTAttention(
  347. embed_dim=config.d_model,
  348. num_heads=config.num_attention_heads,
  349. dropout=config.attention_dropout,
  350. config=config,
  351. )
  352. # Add & Norm of the sublayer 1
  353. self.dropout_path1 = nn.Dropout(config.path_dropout) if config.path_dropout > 0 else nn.Identity()
  354. if config.norm_type == "batchnorm":
  355. self.norm_sublayer1 = PatchTSTBatchNorm(config)
  356. elif config.norm_type == "layernorm":
  357. self.norm_sublayer1 = nn.LayerNorm(config.d_model, eps=config.norm_eps)
  358. else:
  359. raise ValueError(f"{config.norm_type} is not a supported norm layer type.")
  360. # Add & Norm of the sublayer 2
  361. if self.channel_attention:
  362. self.dropout_path2 = nn.Dropout(config.path_dropout) if config.path_dropout > 0 else nn.Identity()
  363. if config.norm_type == "batchnorm":
  364. self.norm_sublayer2 = PatchTSTBatchNorm(config)
  365. elif config.norm_type == "layernorm":
  366. self.norm_sublayer2 = nn.LayerNorm(config.d_model, eps=config.norm_eps)
  367. else:
  368. raise ValueError(f"{config.norm_type} is not a supported norm layer type.")
  369. # Position-wise Feed-Forward
  370. self.ff = nn.Sequential(
  371. nn.Linear(config.d_model, config.ffn_dim, bias=config.bias),
  372. ACT2CLS[config.activation_function](),
  373. nn.Dropout(config.ff_dropout) if config.ff_dropout > 0 else nn.Identity(),
  374. nn.Linear(config.ffn_dim, config.d_model, bias=config.bias),
  375. )
  376. # Add & Norm of sublayer 3
  377. self.dropout_path3 = nn.Dropout(config.path_dropout) if config.path_dropout > 0 else nn.Identity()
  378. if config.norm_type == "batchnorm":
  379. self.norm_sublayer3 = PatchTSTBatchNorm(config)
  380. elif config.norm_type == "layernorm":
  381. self.norm_sublayer3 = nn.LayerNorm(config.d_model, eps=config.norm_eps)
  382. else:
  383. raise ValueError(f"{config.norm_type} is not a supported norm layer type.")
  384. self.pre_norm = config.pre_norm
  385. def forward(self, hidden_state: torch.Tensor, output_attentions: Optional[bool] = None):
  386. """
  387. Parameters:
  388. hidden_state (`torch.Tensor` of shape `(batch_size, num_channels, sequence_length, d_model)`, *required*):
  389. Past values of the time series
  390. output_attentions (`bool`, *optional*):
  391. Whether or not to return the output attention of all layers
  392. Return:
  393. `torch.Tensor` of shape `(batch_size, num_channels, sequence_length, d_model)`
  394. """
  395. batch_size, num_input_channels, sequence_length, d_model = hidden_state.shape
  396. # First sublayer: attention across time
  397. # hidden_states: [(bs*num_channels) x sequence_length x d_model]
  398. hidden_state = hidden_state.view(batch_size * num_input_channels, sequence_length, d_model)
  399. if self.pre_norm:
  400. ## Norm and Multi-Head attention and Add residual connection
  401. attn_output, attn_weights, _ = self.self_attn(
  402. hidden_states=self.norm_sublayer1(hidden_state), output_attentions=output_attentions
  403. )
  404. # Add: residual connection with residual dropout
  405. hidden_state = hidden_state + self.dropout_path1(attn_output)
  406. else:
  407. ## Multi-Head attention and Add residual connection and Norm - Standard Transformer from BERT
  408. attn_output, attn_weights, _ = self.self_attn(
  409. hidden_states=hidden_state, output_attentions=output_attentions
  410. )
  411. # hidden_states: [(bs*num_channels) x sequence_length x d_model]
  412. hidden_state = self.norm_sublayer1(hidden_state + self.dropout_path1(attn_output))
  413. # hidden_state: [bs x num_channels x sequence_length x d_model]
  414. hidden_state = hidden_state.reshape(batch_size, num_input_channels, sequence_length, d_model)
  415. # second sublayer: attention across variable at any given time
  416. if self.channel_attention:
  417. # hidden_state: [bs x sequence_length x num_channels x d_model]
  418. hidden_state = hidden_state.transpose(2, 1).contiguous()
  419. # hidden_state: [(bs*sequence_length) x num_channels x d_model]
  420. hidden_state = hidden_state.view(batch_size * sequence_length, num_input_channels, d_model)
  421. if self.pre_norm:
  422. ## Norm and Multi-Head attention and Add residual connection
  423. attn_output, channel_attn_weights, _ = self.self_attn(
  424. hidden_states=self.norm_sublayer2(hidden_state), output_attentions=output_attentions
  425. )
  426. # Add: residual connection with residual dropout
  427. hidden_state = hidden_state + self.dropout_path2(attn_output)
  428. else:
  429. ## Multi-Head attention and Add residual connection and Norm
  430. attn_output, channel_attn_weights, _ = self.self_attn(
  431. hidden_states=hidden_state, output_attentions=output_attentions
  432. )
  433. # hidden_states: [(bs*sequence_length) x num_channels x d_model]
  434. hidden_state = self.norm_sublayer2(hidden_state + self.dropout_path2(attn_output))
  435. # Reshape hidden state
  436. # hidden_state: [bs x sequence_length x num_channels x d_model]
  437. hidden_state = hidden_state.reshape(batch_size, sequence_length, num_input_channels, d_model)
  438. # hidden_state: [bs x num_channels x sequence_length x d_model]
  439. hidden_state = hidden_state.transpose(1, 2).contiguous()
  440. # Third sublayer: mixing across hidden
  441. # hidden_state: [(batch_size*num_channels) x sequence_length x d_model]
  442. hidden_state = hidden_state.view(batch_size * num_input_channels, sequence_length, d_model)
  443. if self.pre_norm:
  444. ## Norm and Position-wise Feed-Forward and Add residual connection
  445. # Add: residual connection with residual dropout
  446. hidden_state = hidden_state + self.dropout_path3(self.ff(self.norm_sublayer3(hidden_state)))
  447. else:
  448. ## Position-wise Feed-Forward and Add residual connection and Norm
  449. # Add: residual connection with residual dropout
  450. hidden_state = self.norm_sublayer3(hidden_state + self.dropout_path3(self.ff(hidden_state)))
  451. # [bs x num_channels x sequence_length x d_model]
  452. hidden_state = hidden_state.reshape(batch_size, num_input_channels, sequence_length, d_model)
  453. outputs = (hidden_state,)
  454. if output_attentions:
  455. outputs += (attn_weights, channel_attn_weights) if self.channel_attention else (attn_weights,)
  456. return outputs
  457. @auto_docstring
  458. class PatchTSTPreTrainedModel(PreTrainedModel):
  459. config: PatchTSTConfig
  460. base_model_prefix = "model"
  461. main_input_name = "past_values"
  462. supports_gradient_checkpointing = False
  463. def _init_weights(self, module: nn.Module):
  464. """
  465. Initialize weights
  466. """
  467. if isinstance(module, PatchTSTPositionalEncoding):
  468. # get the number of patches
  469. num_patches = (
  470. max(self.config.context_length, self.config.patch_length) - self.config.patch_length
  471. ) // self.config.patch_stride + 1
  472. # initialize cls_token
  473. if self.config.use_cls_token:
  474. nn.init.normal_(module.cls_token, std=0.02)
  475. num_patches += 1
  476. # initialize positional encoding
  477. module.position_enc = module._init_pe(self.config, num_patches)
  478. elif isinstance(module, nn.LayerNorm):
  479. module.bias.data.zero_()
  480. module.weight.data.fill_(1.0)
  481. elif isinstance(module, PatchTSTBatchNorm):
  482. module.batchnorm.bias.data.zero_()
  483. module.batchnorm.weight.data.fill_(1.0)
  484. elif isinstance(module, nn.Linear):
  485. module.weight.data.normal_(mean=0.0, std=self.config.init_std)
  486. if module.bias is not None:
  487. module.bias.data.zero_()
  488. def _set_gradient_checkpointing(self, module, value=False):
  489. if isinstance(module, (PatchTSTEncoder)):
  490. module.gradient_checkpointing = value
  class PatchTSTEmbedding(nn.Module):
      """
      Projects every patch onto a `d_model`-dimensional vector, either with one linear layer shared by all channels
      (`config.share_embedding=True`) or with a dedicated linear layer per input channel.
      """

      def __init__(self, config: PatchTSTConfig):
          super().__init__()
          self.num_input_channels = config.num_input_channels
          self.share_embedding = config.share_embedding
          if self.share_embedding:
              # one projection reused for all channels
              self.input_embedding = nn.Linear(config.patch_length, config.d_model)
          else:
              # one independent projection per channel
              self.input_embedding = nn.ModuleList(
                  [nn.Linear(config.patch_length, config.d_model) for _ in range(config.num_input_channels)]
              )

      def forward(self, patch_input: torch.Tensor):
          """
          Parameters:
              patch_input (`torch.Tensor` of shape `(batch_size, num_channels, num_patches, patch_length)`, *required*):
                  Patches to embed.

          return:
              `torch.Tensor` of shape `(batch_size, num_channels, num_patches, d_model)`

          Raises:
              ValueError: if the channel dimension of `patch_input` differs from `config.num_input_channels`.
          """
          batch_channels = patch_input.shape[1]
          if batch_channels != self.num_input_channels:
              raise ValueError(
                  f"The defined number of input channels ({self.num_input_channels}) in the config "
                  f"has to be the same as the number of channels in the batch input ({batch_channels})"
              )

          if self.share_embedding:
              # the linear layer acts on the last dimension, so all channels are projected at once
              return self.input_embedding(patch_input)

          # project each channel with its own layer, then restore the channel dimension
          per_channel = [
              projection(patch_input[:, channel, :, :]) for channel, projection in enumerate(self.input_embedding)
          ]
          return torch.stack(per_channel, dim=1)
  class PatchTSTPositionalEncoding(nn.Module):
      """
      Adds a positional encoding to the patch embeddings and, when `config.use_cls_token` is set, prepends a
      learnable CLS token to every channel.
      """

      def __init__(self, config: PatchTSTConfig, num_patches: int):
          super().__init__()
          self.use_cls_token = config.use_cls_token
          self.num_input_channels = config.num_input_channels
          if config.use_cls_token:
              # cls_token: [1 x 1 x 1 x d_model]; it is expanded over batch and channels in `forward`
              self.cls_token = nn.Parameter(torch.zeros(1, 1, 1, config.d_model))
              # reserve one extra position for the CLS token
              num_patches += 1
          # positional encoding: [num_patches x d_model]
          self.position_enc = self._init_pe(config, num_patches)
          # Positional dropout
          self.positional_dropout = (
              nn.Dropout(config.positional_dropout) if config.positional_dropout > 0 else nn.Identity()
          )

      @staticmethod
      def _init_pe(config: PatchTSTConfig, num_patches: int) -> nn.Parameter:
          """
          Build the positional encoding of shape `(num_patches, config.d_model)`.

          `"random"` yields a trainable parameter drawn from a standard normal distribution. `"sincos"` yields a
          frozen sine/cosine table that is shifted to zero mean and scaled to a standard deviation of 0.1.

          Raises:
              ValueError: if `config.positional_encoding_type` is neither `"random"` nor `"sincos"`.
          """
          if config.positional_encoding_type == "random":
              position_enc = nn.Parameter(torch.randn(num_patches, config.d_model), requires_grad=True)
          elif config.positional_encoding_type == "sincos":
              position_enc = torch.zeros(num_patches, config.d_model)
              position = torch.arange(0, num_patches).unsqueeze(1)
              div_term = torch.exp(torch.arange(0, config.d_model, 2) * -(math.log(10000.0) / config.d_model))
              angles = position * div_term
              position_enc[:, 0::2] = torch.sin(angles)
              # with an odd `d_model` there is one cosine column fewer than sine columns, so trim the extra
              # frequency instead of failing on a shape mismatch (a no-op when `d_model` is even)
              position_enc[:, 1::2] = torch.cos(angles)[:, : config.d_model // 2]
              position_enc = position_enc - position_enc.mean()
              position_enc = position_enc / (position_enc.std() * 10)
              position_enc = nn.Parameter(position_enc, requires_grad=False)
          else:
              raise ValueError(
                  f"{config.positional_encoding_type} is not a valid positional encoder. Available types are 'random' and 'sincos'."
              )
          return position_enc

      def forward(self, patch_input: torch.Tensor):
          """
          Parameters:
              patch_input (`torch.Tensor` of shape `(batch_size, num_channels, num_patches, d_model)`, *required*):
                  Embedded patches.

          return:
              `torch.Tensor` of shape `(batch_size, num_channels, num_patches, d_model)`, or
              `(batch_size, num_channels, num_patches + 1, d_model)` when the CLS token is used.
          """
          if self.use_cls_token:
              # position 0 belongs to the CLS token, so the patches use positions 1..num_patches
              patch_input = self.positional_dropout(patch_input + self.position_enc[1:, :])
              # cls_token: [1 x 1 x 1 x d_model], carrying the encoding of position 0
              cls_token = self.cls_token + self.position_enc[:1, :]
              # share the same CLS token across samples and channels: [bs x num_channels x 1 x d_model]
              cls_tokens = cls_token.expand(patch_input.shape[0], self.num_input_channels, -1, -1)
              # hidden_state: [bs x num_channels x (num_patches+1) x d_model]
              hidden_state = torch.cat((cls_tokens, patch_input), dim=2)
          else:
              # hidden_state: [bs x num_channels x num_patches x d_model]
              hidden_state = self.positional_dropout(patch_input + self.position_enc)
          return hidden_state
  class PatchTSTEncoder(PatchTSTPreTrainedModel):
      """
      PatchTST Encoder: embeds the patches, adds the positional encoding and runs the stack of encoder layers.
      """

      def __init__(self, config: PatchTSTConfig, num_patches: int):
          super().__init__(config)
          self.gradient_checkpointing = False

          # projection of the patches onto the d_model space
          self.embedder = PatchTSTEmbedding(config)
          # positional information
          self.positional_encoder = PatchTSTPositionalEncoding(config, num_patches)
          # stack of `config.num_hidden_layers` encoder layers
          self.layers = nn.ModuleList([PatchTSTEncoderLayer(config) for _ in range(config.num_hidden_layers)])

          # Initialize weights and apply final processing
          self.post_init()

      def forward(
          self,
          patch_input: torch.Tensor,
          output_hidden_states: Optional[bool] = None,
          output_attentions: Optional[bool] = None,
      ) -> BaseModelOutput:
          """
          Parameters:
              patch_input (`torch.Tensor` of shape `(batch_size, num_channels, num_patches, patch_length)`, *required*):
                  Past values of the time series
              output_hidden_states (bool, optional): Indicates if hidden states should be outputted. Falls back to
                  `config.output_hidden_states` when `None`.
              output_attentions (bool, optional): Indicates if attentions should be outputted. Falls back to
                  `config.output_attentions` when `None`.

          return:
              `BaseModelOutput`
          """
          if output_attentions is None:
              output_attentions = self.config.output_attentions
          if output_hidden_states is None:
              output_hidden_states = self.config.output_hidden_states

          # embed the patches, then add the positional encoding
          hidden_state = self.positional_encoder(self.embedder(patch_input))

          collected_states = [] if output_hidden_states else None
          collected_attentions = [] if output_attentions else None
          for layer in self.layers:
              if output_hidden_states:
                  # record the hidden state that is fed into this layer
                  collected_states.append(hidden_state)

              layer_outputs = layer(hidden_state=hidden_state, output_attentions=output_attentions)
              # hidden_state: [bs x num_channels x num_patches x d_model]
              # or [bs x num_channels x (num_patches+1) x d_model] if use cls_token
              hidden_state = layer_outputs[0]
              if output_attentions:
                  collected_attentions.append(layer_outputs[1])

          return BaseModelOutput(
              last_hidden_state=hidden_state,
              hidden_states=None if collected_states is None else tuple(collected_states),
              attentions=None if collected_attentions is None else tuple(collected_attentions),
          )
  @dataclass
  @auto_docstring(
      custom_intro="""
      Base class for model's outputs, with potential hidden states.
      """
  )
  class PatchTSTModelOutput(ModelOutput):
      r"""
      last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_channels, num_patches, d_model)`):
          Sequence of hidden-states at the output of the last layer of the model. The patch dimension is
          `num_patches + 1` when a CLS token is used.
      hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
          Tuple of `torch.FloatTensor` of shape `(batch_size, num_channels, num_patches, d_model)`, one per encoder
          layer, holding the hidden state fed into that layer (the embedding output followed by the output of every
          layer but the last).
      mask (`torch.FloatTensor` of shape `(batch_size, num_channels, num_patches)`, *optional*):
          Bool masked tensor indicating which patches are masked
      loc (`torch.FloatTensor` of shape `(batch_size, 1, num_channels)`, *optional*):
          Mean of the input data (batch_size, sequence_length, num_channels) over the sequence_length
      scale (`torch.FloatTensor` of shape `(batch_size, 1, num_channels)`, *optional*):
          Std of the input data (batch_size, sequence_length, num_channels) over the sequence_length
      patch_input (`torch.FloatTensor` of shape `(batch_size, num_channels, num_patches, patch_length)`):
          Patched input to the Transformer
      """

      # Field order defines the order of the tuple form of this output; do not reorder.
      last_hidden_state: Optional[torch.FloatTensor] = None
      hidden_states: Optional[tuple[torch.FloatTensor]] = None
      attentions: Optional[tuple[torch.FloatTensor]] = None
      mask: Optional[torch.FloatTensor] = None
      loc: Optional[torch.FloatTensor] = None
      scale: Optional[torch.FloatTensor] = None
      patch_input: Optional[torch.FloatTensor] = None
  @dataclass
  @auto_docstring(
      custom_intro="""
      Output type of [`PatchTSTForPretraining`].
      """
  )
  class PatchTSTForPretrainingOutput(ModelOutput):
      r"""
      loss (`torch.FloatTensor` of shape `(1,)`, *optional*):
          MSE loss between the reconstructed and the original patches, averaged over the masked patches.
      prediction_output (`torch.FloatTensor` of shape `(batch_size, num_channels, num_patches, patch_length)`):
          Reconstructed patches produced by the pretraining head.
      """

      loss: Optional[torch.FloatTensor] = None
      prediction_output: Optional[torch.FloatTensor] = None
      hidden_states: Optional[tuple[torch.FloatTensor]] = None
      attentions: Optional[tuple[torch.FloatTensor]] = None
  @dataclass
  @auto_docstring(
      custom_intro="""
      Output type of [`PatchTSTForRegression`].
      """
  )
  class PatchTSTForRegressionOutput(ModelOutput):
      r"""
      loss (*optional*, returned when `labels` is provided, `torch.FloatTensor` of shape `(1,)`):
          MSE loss.
      regression_outputs (`torch.FloatTensor` of shape `(batch_size, num_targets)`):
          Values predicted by the regression head of the time series model.
      """

      loss: Optional[torch.FloatTensor] = None
      regression_outputs: Optional[torch.FloatTensor] = None
      hidden_states: Optional[tuple[torch.FloatTensor]] = None
      attentions: Optional[tuple[torch.FloatTensor]] = None
  @dataclass
  @auto_docstring(
      custom_intro="""
      Output type of [`PatchTSTForPrediction`].
      """
  )
  class PatchTSTForPredictionOutput(ModelOutput):
      r"""
      loss (*optional*, returned when `labels` is provided, `torch.FloatTensor` of shape `(1,)`):
          MSE loss.
      prediction_outputs (`torch.FloatTensor` of shape `(batch_size, prediction_length, -1)`):
          Prediction outputs of the time series modeling heads.
      attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
          Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
          sequence_length)`.

          Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
          heads.
      loc (`torch.FloatTensor` of shape `(batch_size, 1, num_channels)`, *optional*):
          Mean of the input data (batch_size, sequence_length, num_channels) over the sequence_length
      scale (`torch.FloatTensor` of shape `(batch_size, 1, num_channels)`, *optional*):
          Std of the input data (batch_size, sequence_length, num_channels) over the sequence_length
      """

      loss: Optional[torch.FloatTensor] = None
      prediction_outputs: Optional[torch.FloatTensor] = None
      hidden_states: Optional[tuple[torch.FloatTensor]] = None
      attentions: Optional[tuple[torch.FloatTensor]] = None
      loc: Optional[torch.FloatTensor] = None
      scale: Optional[torch.FloatTensor] = None
  @dataclass
  @auto_docstring(
      custom_intro="""
      Output type of [`PatchTSTForClassification`].
      """
  )
  class PatchTSTForClassificationOutput(ModelOutput):
      r"""
      loss (*optional*, returned when `labels` is provided, `torch.FloatTensor` of shape `(1,)`):
          Classification loss.
      prediction_logits (`torch.FloatTensor` of shape `(batch_size, num_targets)`):
          Prediction scores of the PatchTST modeling head (scores before SoftMax).
      """

      loss: Optional[torch.FloatTensor] = None
      prediction_logits: Optional[torch.FloatTensor] = None
      hidden_states: Optional[tuple[torch.FloatTensor]] = None
      attentions: Optional[tuple[torch.FloatTensor]] = None
  @dataclass
  @auto_docstring(
      custom_intro="""
      Base class for time series model's predictions outputs that contains the sampled values from the chosen
      distribution.
      """
  )
  class SamplePatchTSTOutput(ModelOutput):
      r"""
      sequences (`torch.FloatTensor` of shape `(batch_size, num_samples, prediction_length, num_targets)`):
          Values sampled from the chosen distribution.
      """

      sequences: Optional[torch.FloatTensor] = None
  750. # Copied from transformers.models.time_series_transformer.modeling_time_series_transformer.nll
  def nll(input: torch.distributions.Distribution, target: torch.Tensor) -> torch.Tensor:
      """
      Return the negative log-likelihood of `target` under the distribution `input`.
      """
      log_likelihood = input.log_prob(target)
      return -log_likelihood
  # Adapted from transformers.models.time_series_transformer.modeling_time_series_transformer.weighted_average
  # (diverges from it by honouring `dim=0`)
  def weighted_average(input_tensor: torch.Tensor, weights: Optional[torch.Tensor] = None, dim=None) -> torch.Tensor:
      """
      Computes the weighted average of a given tensor across a given `dim`, masking values associated with weight zero,
      meaning instead of `nan * 0 = nan` you will get `0 * 0 = 0`.

      Args:
          input_tensor (`torch.FloatTensor`):
              Input tensor, of which the average must be computed.
          weights (`torch.FloatTensor`, *optional*):
              Weights tensor, of the same shape as `input_tensor`. When omitted, the plain mean is returned.
          dim (`int`, *optional*):
              The dim along which to average `input_tensor`. `None` averages over all elements.

      Returns:
          `torch.FloatTensor`: The tensor with values averaged along the specified `dim`.
      """
      if weights is None:
          return input_tensor.mean(dim=dim)

      # compare against None explicitly: a truthiness test would treat `dim=0` as "no dim" and reduce everything
      reduce_all = dim is None
      weighted_tensor = torch.where(weights != 0, input_tensor * weights, torch.zeros_like(input_tensor))
      # clamp so that slices without any weight divide by 1 instead of 0
      sum_weights = torch.clamp(weights.sum() if reduce_all else weights.sum(dim=dim), min=1.0)
      weighted_sum = weighted_tensor.sum() if reduce_all else weighted_tensor.sum(dim=dim)
      return weighted_sum / sum_weights
  777. # Copied from transformers.models.time_series_transformer.modeling_time_series_transformer.TimeSeriesStdScaler with TimeSeriesTransformer->PatchTST,TimeSeries->PatchTST
  class PatchTSTStdScaler(nn.Module):
      """
      Standardizes the data along `dim` (1 unless `config.scaling_dim` is set): the mean of the observed values is
      subtracted and the result is divided by their standard deviation.
      """

      def __init__(self, config: PatchTSTConfig):
          super().__init__()
          # every setting is optional on the config and falls back to a default
          self.dim = getattr(config, "scaling_dim", 1)
          self.keepdim = getattr(config, "keepdim", True)
          self.minimum_scale = getattr(config, "minimum_scale", 1e-5)

      def forward(
          self, data: torch.Tensor, observed_indicator: torch.Tensor
      ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
          """
          Parameters:
              data (`torch.Tensor` of shape `(batch_size, sequence_length, num_input_channels)`):
                  Input data to standardize.
              observed_indicator (`torch.BoolTensor` of shape `(batch_size, sequence_length, num_input_channels)`):
                  Marks the observed values; only those contribute to the mean and the variance.
          Returns:
              tuple of `torch.Tensor` of shapes
                  (`(batch_size, sequence_length, num_input_channels)`,`(batch_size, 1, num_input_channels)`,
                  `(batch_size, 1, num_input_channels)`): the standardized data, the mean and the scale.
          """
          # at least 1 so that series without any observation do not divide by zero
          num_observed = observed_indicator.sum(self.dim, keepdim=self.keepdim).clamp_min(1.0)
          loc = (data * observed_indicator).sum(self.dim, keepdim=self.keepdim) / num_observed

          squared_deviation = ((data - loc) * observed_indicator) ** 2
          variance = squared_deviation.sum(self.dim, keepdim=self.keepdim) / num_observed
          # `minimum_scale` is added to the variance before the square root, keeping the scale strictly positive
          scale = torch.sqrt(variance + self.minimum_scale)
          return (data - loc) / scale, loc, scale
  808. # Copied from transformers.models.time_series_transformer.modeling_time_series_transformer.TimeSeriesMeanScaler with TimeSeriesTransformer->PatchTST,TimeSeries->PatchTST
  class PatchTSTMeanScaler(nn.Module):
      """
      Scales the data by the mean absolute value of the observed entries, computed along `dim` (1 unless
      `config.scaling_dim` is set).
      """

      def __init__(self, config: PatchTSTConfig):
          super().__init__()
          # every setting is optional on the config and falls back to a default
          self.dim = getattr(config, "scaling_dim", 1)
          self.keepdim = getattr(config, "keepdim", True)
          self.minimum_scale = getattr(config, "minimum_scale", 1e-10)
          self.default_scale = getattr(config, "default_scale", None)

      def forward(
          self, data: torch.Tensor, observed_indicator: torch.Tensor
      ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
          """
          Parameters:
              data (`torch.Tensor` of shape `(batch_size, sequence_length, num_input_channels)`):
                  Input data to scale.
              observed_indicator (`torch.BoolTensor` of shape `(batch_size, sequence_length, num_input_channels)`):
                  Marks the observed values; only those contribute to the scale.
          Returns:
              tuple of `torch.Tensor` of shapes
                  (`(batch_size, sequence_length, num_input_channels)`,`(batch_size, 1, num_input_channels)`,
                  `(batch_size, 1, num_input_channels)`): the scaled data, an all-zero location and the scale.
          """
          abs_total = (data * observed_indicator).abs().sum(self.dim, keepdim=True)
          observed_count = observed_indicator.sum(self.dim, keepdim=True)
          scale = abs_total / observed_count.clamp(min=1)

          if self.default_scale is not None:
              # a fixed fallback was configured
              fallback_scale = self.default_scale * torch.ones_like(scale)
          else:
              # no fallback configured: use the scale of the whole batch
              batch_total = abs_total.sum(dim=0)
              batch_count = observed_count.sum(0).clamp(min=1)
              fallback_scale = (batch_total / batch_count).squeeze()

          # series without any observation get the fallback scale
          scale = torch.where(observed_count > 0, scale, fallback_scale)
          # never go below `self.minimum_scale`
          scale = scale.clamp(min=self.minimum_scale)
          scaled_data = data / scale

          if not self.keepdim:
              scale = scale.squeeze(dim=self.dim)
          return scaled_data, torch.zeros_like(scale), scale
  853. # Copied from transformers.models.time_series_transformer.modeling_time_series_transformer.TimeSeriesNOPScaler with TimeSeriesTransformer->PatchTST,TimeSeries->PatchTST
  class PatchTSTNOPScaler(nn.Module):
      """
      No-op scaler: returns the data unchanged together with a location of 0 and a scale of 1, both reduced along
      `dim` (1 unless `config.scaling_dim` is set).
      """

      def __init__(self, config: PatchTSTConfig):
          super().__init__()
          self.dim = getattr(config, "scaling_dim", 1)
          self.keepdim = getattr(config, "keepdim", True)

      def forward(
          self, data: torch.Tensor, observed_indicator: Optional[torch.Tensor] = None
      ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
          """
          Parameters:
              data (`torch.Tensor` of shape `(batch_size, sequence_length, num_input_channels)`):
                  Input data, returned as is.
              observed_indicator (`torch.Tensor`, *optional*):
                  Not used by this scaler.
          Returns:
              tuple of `torch.Tensor` of shapes
                  (`(batch_size, sequence_length, num_input_channels)`,`(batch_size, 1, num_input_channels)`,
                  `(batch_size, 1, num_input_channels)`)
          """
          reduction = {"dim": self.dim, "keepdim": self.keepdim}
          # averaging constant tensors yields zeros / ones with the reduced shape
          loc = torch.zeros_like(data, requires_grad=False).mean(**reduction)
          scale = torch.ones_like(data, requires_grad=False).mean(**reduction)
          return data, loc, scale
  class PatchTSTScaler(nn.Module):
      """
      Wraps the scaler selected by `config.scaling`: `"mean"` or `True` selects `PatchTSTMeanScaler`, `"std"` selects
      `PatchTSTStdScaler`, and any other value selects `PatchTSTNOPScaler`.
      """

      def __init__(self, config: PatchTSTConfig):
          super().__init__()
          scaling = config.scaling
          if scaling is True or scaling == "mean":
              scaler_cls = PatchTSTMeanScaler
          elif scaling == "std":
              scaler_cls = PatchTSTStdScaler
          else:
              scaler_cls = PatchTSTNOPScaler
          self.scaler = scaler_cls(config)

      def forward(
          self, data: torch.Tensor, observed_indicator: torch.Tensor
      ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
          """
          Parameters:
              data (`torch.Tensor` of shape `(batch_size, sequence_length, num_input_channels)`):
                  Input for scaler calculation
              observed_indicator (`torch.BoolTensor` of shape `(batch_size, sequence_length, num_input_channels)`):
                  Calculating the scale on the observed indicator.
          Returns:
              tuple of `torch.Tensor` of shapes
                  (`(batch_size, sequence_length, num_input_channels)`,`(batch_size, 1, num_input_channels)`,
                  `(batch_size, 1, num_input_channels)`)
          """
          scaled_data, loc, scale = self.scaler(data, observed_indicator)
          return scaled_data, loc, scale
  @auto_docstring
  class PatchTSTModel(PatchTSTPreTrainedModel):
      def __init__(self, config: PatchTSTConfig):
          super().__init__(config)

          self.scaler = PatchTSTScaler(config)
          self.patchifier = PatchTSTPatchify(config)
          self.do_mask_input = config.do_mask_input
          # without input masking the patches pass through unchanged
          self.masking = PatchTSTMasking(config) if self.do_mask_input else nn.Identity()
          # the number of patches is taken from the patchifier
          self.encoder = PatchTSTEncoder(config, num_patches=self.patchifier.num_patches)

          # Initialize weights and apply final processing
          self.post_init()

      def forward(
          self,
          past_values: torch.Tensor,
          past_observed_mask: Optional[torch.Tensor] = None,
          future_values: Optional[torch.Tensor] = None,
          output_hidden_states: Optional[bool] = None,
          output_attentions: Optional[bool] = None,
          return_dict: Optional[bool] = None,
      ) -> Union[tuple, PatchTSTModelOutput]:
          r"""
          Parameters:
              past_values (`torch.Tensor` of shape `(bs, sequence_length, num_input_channels)`, *required*):
                  Input sequence to the model
              past_observed_mask (`torch.BoolTensor` of shape `(batch_size, sequence_length, num_input_channels)`, *optional*):
                  Boolean mask to indicate which `past_values` were observed and which were missing. Mask values selected
                  in `[0, 1]`:

                  - 1 for values that are **observed**,
                  - 0 for values that are **missing** (i.e. NaNs that were replaced by zeros).

                  When omitted, every value is treated as observed.
              future_values (`torch.Tensor` of shape `(batch_size, prediction_length, num_input_channels)`, *optional*):
                  Future target values associated with the `past_values`. Not used inside this method.
              output_hidden_states (`bool`, *optional*):
                  Whether or not to return the hidden states of all layers
              output_attentions (`bool`, *optional*):
                  Whether or not to return the output attention of all layers
              return_dict (`bool`, *optional*):
                  Whether or not to return a `ModelOutput` instead of a plain tuple.

          Returns:
              `PatchTSTModelOutput` or tuple of `torch.Tensor` (if `return_dict`=False or `config.return_dict`=False)

          Examples:

          ```python
          >>> from huggingface_hub import hf_hub_download
          >>> import torch
          >>> from transformers import PatchTSTModel

          >>> file = hf_hub_download(
          ...     repo_id="hf-internal-testing/etth1-hourly-batch", filename="train-batch.pt", repo_type="dataset"
          ... )
          >>> batch = torch.load(file)

          >>> model = PatchTSTModel.from_pretrained("namctin/patchtst_etth1_pretrain")

          >>> # during training, one provides both past and future values
          >>> outputs = model(
          ...     past_values=batch["past_values"],
          ...     future_values=batch["future_values"],
          ... )

          >>> last_hidden_state = outputs.last_hidden_state
          ```"""
          if return_dict is None:
              return_dict = self.config.use_return_dict
          if output_attentions is None:
              output_attentions = self.config.output_attentions
          if output_hidden_states is None:
              output_hidden_states = self.config.output_hidden_states

          if past_observed_mask is None:
              past_observed_mask = torch.ones_like(past_values)

          # past_values: [bs x sequence_length x num_input_channels]
          scaled_past_values, loc, scale = self.scaler(past_values, past_observed_mask)

          # patched_values: [bs x num_input_channels x num_patches x patch_length]
          patched_values = self.patchifier(scaled_past_values)

          masking_output = self.masking(patched_values)
          if self.do_mask_input:
              masked_values, mask = masking_output
          else:
              # the identity module returns the patches only, so there is no mask
              masked_values, mask = masking_output, None

          encoder_output = self.encoder(
              patch_input=masked_values, output_hidden_states=output_hidden_states, output_attentions=output_attentions
          )

          if return_dict:
              return PatchTSTModelOutput(
                  last_hidden_state=encoder_output.last_hidden_state,
                  hidden_states=encoder_output.hidden_states,
                  attentions=encoder_output.attentions,
                  mask=mask,
                  loc=loc,
                  scale=scale,
                  patch_input=patched_values,
              )

          # tuple form: same order as the `PatchTSTModelOutput` fields, with the `None` entries dropped
          candidates = (
              encoder_output.last_hidden_state,
              encoder_output.hidden_states,
              encoder_output.attentions,
              mask,
              loc,
              scale,
              patched_values,
          )
          return tuple(item for item in candidates if item is not None)
  class PatchTSTMaskPretrainHead(nn.Module):
      """
      Pretraining head for mask modelling: maps every patch representation back to a patch of raw values.
      """

      def __init__(self, config: PatchTSTConfig):
          super().__init__()
          self.dropout = nn.Dropout(config.head_dropout) if config.head_dropout > 0 else nn.Identity()
          self.linear = nn.Linear(config.d_model, config.patch_length)
          self.use_cls_token = config.use_cls_token

      def forward(self, embedding: torch.Tensor) -> torch.Tensor:
          """
          Parameters:
              embedding (`torch.Tensor` of shape `(bs, num_channels, num_patches, d_model)` or
                      `(bs, num_channels, num_patches+1, d_model)` if `cls_token` is set to True, *required*):
                  Embedding from the model

          Returns:
              `torch.Tensor` of shape `(bs, num_channels, num_patches, patch_length)`. The CLS position, when
              present, is removed from the output.
          """
          # [bs x num_channels x num_patches(+1) x patch_length]
          reconstruction = self.linear(self.dropout(embedding))
          if self.use_cls_token:
              # the first position holds the cls token, which has no patch to reconstruct
              reconstruction = reconstruction[:, :, 1:, :]
          return reconstruction
  @auto_docstring(
      custom_intro="""
      The PatchTST for pretrain model.
      """
  )
  class PatchTSTForPretraining(PatchTSTPreTrainedModel):
      def __init__(self, config: PatchTSTConfig):
          super().__init__(config)

          # the pretraining loss is computed over masked patches, so input masking is always switched on
          # (note that this overwrites the flag on the config object that was passed in)
          config.do_mask_input = True
          self.model = PatchTSTModel(config=config)
          self.head = PatchTSTMaskPretrainHead(config)

          # Initialize weights and apply final processing
          self.post_init()

      def forward(
          self,
          past_values: torch.Tensor,
          past_observed_mask: Optional[torch.Tensor] = None,
          output_hidden_states: Optional[bool] = None,
          output_attentions: Optional[bool] = None,
          return_dict: Optional[bool] = None,
      ) -> Union[tuple, PatchTSTForPretrainingOutput]:
          r"""
          Parameters:
              past_values (`torch.Tensor` of shape `(bs, sequence_length, num_input_channels)`, *required*):
                  Input sequence to the model
              past_observed_mask (`torch.BoolTensor` of shape `(batch_size, sequence_length, num_input_channels)`, *optional*):
                  Boolean mask to indicate which `past_values` were observed and which were missing. Mask values selected
                  in `[0, 1]`:

                  - 1 for values that are **observed**,
                  - 0 for values that are **missing** (i.e. NaNs that were replaced by zeros).
              output_hidden_states (`bool`, *optional*):
                  Whether or not to return the hidden states of all layers
              output_attentions (`bool`, *optional*):
                  Whether or not to return the output attention of all layers
              return_dict (`bool`, *optional*): Whether or not to return a `ModelOutput` instead of a plain tuple.

          Returns:
              `PatchTSTForPretrainingOutput` or tuple of `torch.Tensor` (if `return_dict`=False or
              `config.return_dict`=False)

          Examples:

          ```python
          >>> from huggingface_hub import hf_hub_download
          >>> import torch
          >>> from transformers import PatchTSTConfig, PatchTSTForPretraining

          >>> file = hf_hub_download(
          ...     repo_id="hf-internal-testing/etth1-hourly-batch", filename="train-batch.pt", repo_type="dataset"
          ... )
          >>> batch = torch.load(file)

          >>> # Config for random mask pretraining
          >>> config = PatchTSTConfig(
          ...     num_input_channels=7,
          ...     context_length=512,
          ...     patch_length=12,
          ...     patch_stride=12,
          ...     mask_type='random',
          ...     random_mask_ratio=0.4,
          ...     use_cls_token=True,
          ... )
          >>> # Config for forecast mask pretraining
          >>> config = PatchTSTConfig(
          ...     num_input_channels=7,
          ...     context_length=512,
          ...     patch_length=12,
          ...     patch_stride=12,
          ...     mask_type='forecast',
          ...     num_forecast_mask_patches=5,
          ...     use_cls_token=True,
          ... )
          >>> model = PatchTSTForPretraining(config)

          >>> # masked pretraining only needs the past values
          >>> outputs = model(past_values=batch["past_values"])

          >>> loss = outputs.loss
          >>> loss.backward()
          ```"""
          return_dict = return_dict if return_dict is not None else self.config.use_return_dict

          # always request a `PatchTSTModelOutput` so that its fields can be read by name below
          model_output = self.model(
              past_values=past_values,
              past_observed_mask=past_observed_mask,
              output_hidden_states=output_hidden_states,
              output_attentions=output_attentions,
              return_dict=True,
          )

          # last_hidden_state: [bs x num_channels x num_patches x d_model] or
          # [bs x num_channels x (num_patches+1) x d_model] if use cls_token
          # x_hat: [bs x num_channels x num_patches x patch_length] (the head removes the cls token)
          x_hat = self.head(model_output.last_hidden_state)

          # squared error per element, averaged inside each patch and then over the masked patches only;
          # the small constant keeps the division defined when no patch is masked
          loss_val = nn.MSELoss(reduction="none")(x_hat, model_output.patch_input)
          masked_loss = (loss_val.mean(dim=-1) * model_output.mask).sum() / (model_output.mask.sum() + 1e-10)

          if not return_dict:
              # skip `last_hidden_state` (superseded by `x_hat`) and the four trailing entries
              # (mask, loc, scale, patch_input) of the model output
              return (masked_loss, x_hat) + model_output[1:-4]
          return PatchTSTForPretrainingOutput(
              loss=masked_loss,
              prediction_output=x_hat,
              hidden_states=model_output.hidden_states,
              attentions=model_output.attentions,
          )
  class PatchTSTClassificationHead(nn.Module):
      """
      Classification head: pools the encoder output over the patch dimension, flattens channels and features
      together and maps the result to `config.num_targets` scores.
      """

      def __init__(self, config: PatchTSTConfig):
          super().__init__()
          self.use_cls_token = config.use_cls_token
          self.pooling_type = config.pooling_type
          self.flatten = nn.Flatten(start_dim=1)
          head_dropout = config.head_dropout
          self.dropout = nn.Dropout(head_dropout) if head_dropout > 0 else nn.Identity()
          self.linear = nn.Linear(config.num_input_channels * config.d_model, config.num_targets)

      def _pool(self, embedding: torch.Tensor) -> torch.Tensor:
          """Reduce the patch dimension of `embedding`, giving a `[bs x num_channels x d_model]` tensor."""
          if self.use_cls_token:
              # the first position holds the cls token
              return embedding[:, :, 0, :]
          if self.pooling_type == "mean":
              return embedding.mean(dim=2)
          if self.pooling_type == "max":
              return embedding.max(dim=2).values
          raise ValueError(f"pooling operator {self.pooling_type} is not implemented yet")

      def forward(self, embedding: torch.Tensor):
          """
          Parameters:
              embedding (`torch.Tensor` of shape `(bs, num_channels, num_patches, d_model)` or
                       `(bs, num_channels, num_patches+1, d_model)` if `cls_token` is set to True, *required*):
                  Embedding from the model
          Returns:
              `torch.Tensor` of shape `(bs, num_targets)`

          Raises:
              ValueError: if no cls token is used and `pooling_type` is neither `"mean"` nor `"max"`.
          """
          # [bs x num_channels x d_model] -> [bs x (num_channels * d_model)]
          flat_embedding = self.flatten(self._pool(embedding))
          # [bs x num_targets]
          return self.linear(self.dropout(flat_embedding))
  1148. @auto_docstring(
  1149. custom_intro="""
  1150. The PatchTST for classification model.
  1151. """
  1152. )
  1153. class PatchTSTForClassification(PatchTSTPreTrainedModel):
  1154. def __init__(self, config: PatchTSTConfig):
  1155. super().__init__(config)
  1156. # Turn off masking
  1157. if config.do_mask_input:
  1158. logger.warning("Setting `do_mask_input` parameter to False.")
  1159. config.do_mask_input = False
  1160. self.model = PatchTSTModel(config)
  1161. self.head = PatchTSTClassificationHead(config)
  1162. # Initialize weights and apply final processing
  1163. self.post_init()
  1164. @auto_docstring
  1165. def forward(
  1166. self,
  1167. past_values: torch.Tensor,
  1168. target_values: Optional[torch.Tensor] = None,
  1169. past_observed_mask: Optional[bool] = None,
  1170. output_hidden_states: Optional[bool] = None,
  1171. output_attentions: Optional[bool] = None,
  1172. return_dict: Optional[bool] = None,
  1173. ) -> Union[tuple, PatchTSTForClassificationOutput]:
  1174. r"""
  1175. past_values (`torch.Tensor` of shape `(bs, sequence_length, num_input_channels)`, *required*):
  1176. Input sequence to the model
  1177. target_values (`torch.Tensor`, *optional*):
  1178. Labels associates with the `past_values`
  1179. past_observed_mask (`torch.BoolTensor` of shape `(batch_size, sequence_length, num_input_channels)`, *optional*):
  1180. Boolean mask to indicate which `past_values` were observed and which were missing. Mask values selected
  1181. in `[0, 1]`:
  1182. - 1 for values that are **observed**,
  1183. - 0 for values that are **missing** (i.e. NaNs that were replaced by zeros).
  1184. Examples:
  1185. ```python
  1186. >>> from transformers import PatchTSTConfig, PatchTSTForClassification
  1187. >>> # classification task with two input channel2 and 3 classes
  1188. >>> config = PatchTSTConfig(
  1189. ... num_input_channels=2,
  1190. ... num_targets=3,
  1191. ... context_length=512,
  1192. ... patch_length=12,
  1193. ... stride=12,
  1194. ... use_cls_token=True,
  1195. ... )
  1196. >>> model = PatchTSTForClassification(config=config)
  1197. >>> # during inference, one only provides past values
  1198. >>> past_values = torch.randn(20, 512, 2)
  1199. >>> outputs = model(past_values=past_values)
  1200. >>> labels = outputs.prediction_logits
  1201. ```"""
  1202. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  1203. model_output = self.model(
  1204. past_values=past_values,
  1205. past_observed_mask=past_observed_mask,
  1206. output_hidden_states=output_hidden_states,
  1207. output_attentions=output_attentions,
  1208. return_dict=True,
  1209. )
  1210. y_hat = self.head(model_output.last_hidden_state)
  1211. loss_val = None
  1212. if target_values is not None:
  1213. loss = nn.CrossEntropyLoss()
  1214. loss_val = loss(y_hat, target_values)
  1215. if not return_dict:
  1216. outputs = (y_hat,) + model_output[1:-3]
  1217. outputs = (loss_val,) + outputs if loss_val is not None else outputs
  1218. return outputs
  1219. return PatchTSTForClassificationOutput(
  1220. loss=loss_val,
  1221. prediction_logits=y_hat,
  1222. hidden_states=model_output.hidden_states,
  1223. attentions=model_output.attentions,
  1224. )
  1225. @auto_docstring(
  1226. custom_intro="""
  1227. The PatchTST for regression Model.
  1228. """
  1229. )
  1230. class PatchTSTPredictionHead(nn.Module):
  1231. def __init__(self, config: PatchTSTConfig, num_patches: int, distribution_output=None):
  1232. r"""
  1233. num_patches (`int`):
  1234. The number of patches in the input sequence.
  1235. distribution_output (`DistributionOutput`, *optional*):
  1236. The distribution output layer for probabilistic forecasting. If None, a linear output layer is used.
  1237. """
  1238. super().__init__()
  1239. self.share_projection = config.share_projection
  1240. self.num_input_channels = config.num_input_channels
  1241. self.use_cls_token = config.use_cls_token
  1242. self.pooling_type = config.pooling_type
  1243. if self.pooling_type or self.use_cls_token:
  1244. head_dim = config.d_model
  1245. else:
  1246. head_dim = config.d_model * num_patches
  1247. if not self.share_projection:
  1248. # if each channel has its own head
  1249. self.projections = nn.ModuleList()
  1250. self.dropouts = nn.ModuleList()
  1251. self.flattens = nn.ModuleList()
  1252. for i in range(self.num_input_channels):
  1253. self.flattens.append(nn.Flatten(start_dim=2))
  1254. if distribution_output is None:
  1255. # use linear head
  1256. self.projections.append(nn.Linear(head_dim, config.prediction_length))
  1257. else:
  1258. # use distribution head
  1259. self.projections.append(distribution_output.get_parameter_projection(head_dim))
  1260. self.dropouts.append(nn.Dropout(config.head_dropout) if config.head_dropout > 0 else nn.Identity())
  1261. else:
  1262. # all the channels share the same head
  1263. self.flatten = nn.Flatten(start_dim=2)
  1264. if distribution_output is None:
  1265. # use linear head
  1266. self.projection = nn.Linear(head_dim, config.prediction_length)
  1267. else:
  1268. # use distribution head
  1269. self.projection = distribution_output.get_parameter_projection(head_dim)
  1270. self.dropout = nn.Dropout(config.head_dropout) if config.head_dropout > 0 else nn.Identity()
  1271. def forward(self, embedding: torch.Tensor):
  1272. """
  1273. Parameters:
  1274. embedding (`torch.Tensor` of shape `(bs, num_channels, num_patches, d_model)` or
  1275. `(bs, num_channels, num_patches+1, d_model)` if `cls_token` is set to True, *required*):
  1276. Embedding from the model
  1277. Returns:
  1278. `torch.Tensor` of shape `(bs, forecast_len, num_channels)`
  1279. """
  1280. if self.use_cls_token:
  1281. # pooled_embedding: [bs x num_channels x d_model]
  1282. pooled_embedding = embedding[:, :, 0, :]
  1283. else:
  1284. if self.pooling_type == "mean":
  1285. # pooled_embedding: [bs x num_channels x d_model]
  1286. pooled_embedding = embedding.mean(dim=2)
  1287. elif self.pooling_type == "max":
  1288. # pooled_embedding: [bs x num_channels x d_model]
  1289. pooled_embedding = embedding.max(dim=2).values
  1290. else:
  1291. # pooled_embedding: [bs x num_channels x num_patches x d_model]
  1292. pooled_embedding = embedding
  1293. if not self.share_projection:
  1294. output = []
  1295. for i in range(self.num_input_channels):
  1296. # pooled_embedding: [bs x (d_model * num_patches)] or [bs x d_model)]
  1297. pooled_embedding = self.flattens[i](pooled_embedding[:, i, :])
  1298. pooled_embedding = self.dropouts[i](pooled_embedding)
  1299. # pooled_embedding: [bs x forecast_len]
  1300. # or tuple ([bs x forecast_len], [bs x forecast_len]) if using distribution head
  1301. pooled_embedding = self.projections[i](pooled_embedding)
  1302. output.append(pooled_embedding)
  1303. # output: [bs x num_channels x forecast_len]
  1304. output = torch.stack(output, dim=1)
  1305. else:
  1306. # pooled_embedding: [bs x num_channels x (d_model * num_patches)] or [bs x num_channels x d_model)]
  1307. pooled_embedding = self.flatten(pooled_embedding)
  1308. pooled_embedding = self.dropout(pooled_embedding)
  1309. # output: [bs x num_channels x forecast_len] or
  1310. # tuple ([bs x num_channels x forecast_len], [bs x num_channels x forecast_len]) if using distribution head
  1311. output = self.projection(pooled_embedding)
  1312. if isinstance(output, tuple):
  1313. # output: ([bs x forecast_len x num_channels], [bs x forecast_len x num_channels])
  1314. output = tuple(z.transpose(2, 1) for z in output)
  1315. else:
  1316. output = output.transpose(2, 1) # [bs x forecast_len x num_channels]
  1317. return output
  1318. @auto_docstring(
  1319. custom_intro="""
  1320. The PatchTST for prediction model.
  1321. """
  1322. )
  1323. class PatchTSTForPrediction(PatchTSTPreTrainedModel):
  1324. def __init__(self, config: PatchTSTConfig):
  1325. super().__init__(config)
  1326. # Turn off masking
  1327. if config.do_mask_input:
  1328. logger.warning("Setting `do_mask_input` parameter to False.")
  1329. config.do_mask_input = False
  1330. self.model = PatchTSTModel(config)
  1331. if config.loss == "mse":
  1332. self.distribution_output = None
  1333. else:
  1334. if config.distribution_output == "student_t":
  1335. self.distribution_output = StudentTOutput(dim=config.prediction_length)
  1336. elif config.distribution_output == "normal":
  1337. self.distribution_output = NormalOutput(dim=config.prediction_length)
  1338. elif config.distribution_output == "negative_binomial":
  1339. self.distribution_output = NegativeBinomialOutput(dim=config.prediction_length)
  1340. else:
  1341. raise ValueError(f"Unknown distribution output {config.distribution_output}")
  1342. self.head = PatchTSTPredictionHead(
  1343. config, self.model.patchifier.num_patches, distribution_output=self.distribution_output
  1344. )
  1345. # Initialize weights and apply final processing
  1346. self.post_init()
  1347. def forward(
  1348. self,
  1349. past_values: torch.Tensor,
  1350. past_observed_mask: Optional[torch.Tensor] = None,
  1351. future_values: Optional[torch.Tensor] = None,
  1352. output_hidden_states: Optional[bool] = None,
  1353. output_attentions: Optional[bool] = None,
  1354. return_dict: Optional[bool] = None,
  1355. ) -> Union[tuple, PatchTSTForPredictionOutput]:
  1356. r"""
  1357. Parameters:
  1358. past_values (`torch.Tensor` of shape `(bs, sequence_length, num_input_channels)`, *required*):
  1359. Input sequence to the model
  1360. past_observed_mask (`torch.BoolTensor` of shape `(batch_size, sequence_length, num_input_channels)`, *optional*):
  1361. Boolean mask to indicate which `past_values` were observed and which were missing. Mask values selected
  1362. in `[0, 1]`:
  1363. - 1 for values that are **observed**,
  1364. - 0 for values that are **missing** (i.e. NaNs that were replaced by zeros).
  1365. future_values (`torch.Tensor` of shape `(bs, forecast_len, num_input_channels)`, *optional*):
  1366. Future target values associated with the `past_values`
  1367. output_hidden_states (`bool`, *optional*):
  1368. Whether or not to return the hidden states of all layers
  1369. output_attentions (`bool`, *optional*):
  1370. Whether or not to return the output attention of all layers
  1371. return_dict (`bool`, *optional*):
  1372. Whether or not to return a `ModelOutput` instead of a plain tuple.
  1373. Returns:
  1374. `PatchTSTForPredictionOutput` or tuple of `torch.Tensor` (if `return_dict`=False or
  1375. `config.return_dict`=False)
  1376. Examples:
  1377. ```python
  1378. >>> from huggingface_hub import hf_hub_download
  1379. >>> import torch
  1380. >>> from transformers import PatchTSTConfig, PatchTSTForPrediction
  1381. >>> file = hf_hub_download(
  1382. ... repo_id="hf-internal-testing/etth1-hourly-batch", filename="train-batch.pt", repo_type="dataset"
  1383. ... )
  1384. >>> batch = torch.load(file)
  1385. >>> # Prediction task with 7 input channels and prediction length is 96
  1386. >>> model = PatchTSTForPrediction.from_pretrained("namctin/patchtst_etth1_forecast")
  1387. >>> # during training, one provides both past and future values
  1388. >>> outputs = model(
  1389. ... past_values=batch["past_values"],
  1390. ... future_values=batch["future_values"],
  1391. ... )
  1392. >>> loss = outputs.loss
  1393. >>> loss.backward()
  1394. >>> # during inference, one only provides past values, the model outputs future values
  1395. >>> outputs = model(past_values=batch["past_values"])
  1396. >>> prediction_outputs = outputs.prediction_outputs
  1397. ```"""
  1398. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  1399. # get model output
  1400. model_output = self.model(
  1401. past_values=past_values,
  1402. past_observed_mask=past_observed_mask,
  1403. output_hidden_states=output_hidden_states,
  1404. output_attentions=output_attentions,
  1405. return_dict=True,
  1406. )
  1407. # get output head
  1408. y_hat = self.head(model_output.last_hidden_state)
  1409. loss_val = None
  1410. if self.distribution_output:
  1411. y_hat_out = y_hat
  1412. else:
  1413. y_hat_out = y_hat * model_output.scale + model_output.loc
  1414. if future_values is not None:
  1415. if self.distribution_output:
  1416. distribution = self.distribution_output.distribution(
  1417. y_hat, loc=model_output.loc, scale=model_output.scale
  1418. )
  1419. loss_val = nll(distribution, future_values)
  1420. # take average of the loss
  1421. loss_val = weighted_average(loss_val)
  1422. else:
  1423. loss = nn.MSELoss(reduction="mean")
  1424. loss_val = loss(y_hat_out, future_values)
  1425. loc = model_output.loc
  1426. scale = model_output.scale
  1427. if not return_dict:
  1428. outputs = (y_hat_out,) + model_output[1:-1]
  1429. outputs = (loss_val,) + outputs if loss_val is not None else outputs
  1430. return outputs
  1431. return PatchTSTForPredictionOutput(
  1432. loss=loss_val,
  1433. prediction_outputs=y_hat_out,
  1434. hidden_states=model_output.hidden_states,
  1435. attentions=model_output.attentions,
  1436. loc=loc,
  1437. scale=scale,
  1438. )
  1439. @torch.no_grad()
  1440. def generate(
  1441. self,
  1442. past_values: torch.Tensor,
  1443. past_observed_mask: Optional[torch.Tensor] = None,
  1444. ) -> SamplePatchTSTOutput:
  1445. """
  1446. Generate sequences of sample predictions from a model with a probability distribution head.
  1447. Parameters:
  1448. past_values (`torch.FloatTensor` of shape `(batch_size, sequence_length, num_input_channels)`):
  1449. Past values of the time series that serves as context in order to predict the future.
  1450. past_observed_mask (`torch.BoolTensor` of shape `(batch_size, sequence_length, num_input_channels)`, *optional*):
  1451. Boolean mask to indicate which `past_values` were observed and which were missing. Mask values selected
  1452. in `[0, 1]`:
  1453. - 1 for values that are **observed**,
  1454. - 0 for values that are **missing** (i.e. NaNs that were replaced by zeros).
  1455. Return:
  1456. [`SamplePatchTSTOutput`] where the outputs `sequences` tensor will have shape `(batch_size, number of
  1457. samples, prediction_length, 1)` or `(batch_size, number of samples, prediction_length, num_input_channels)`
  1458. for multivariate predictions.
  1459. """
  1460. # get number of samples
  1461. num_parallel_samples = self.config.num_parallel_samples
  1462. # get model output
  1463. outputs = self(
  1464. past_values=past_values,
  1465. future_values=None,
  1466. past_observed_mask=past_observed_mask,
  1467. output_hidden_states=False,
  1468. )
  1469. if self.distribution_output:
  1470. # get distribution
  1471. distribution = self.distribution_output.distribution(
  1472. outputs.prediction_outputs, loc=outputs.loc, scale=outputs.scale
  1473. )
  1474. # get samples: list of [bs x forecast_len x num_channels]
  1475. samples = [distribution.sample() for _ in range(num_parallel_samples)]
  1476. # samples: [bs x num_samples x forecast_len x num_channels]
  1477. samples = torch.stack(samples, dim=1)
  1478. else:
  1479. samples = outputs.prediction_outputs.unsqueeze(1)
  1480. return SamplePatchTSTOutput(sequences=samples)
  1481. class PatchTSTRegressionHead(nn.Module):
  1482. """
  1483. Regression head
  1484. """
  1485. def __init__(self, config: PatchTSTConfig, distribution_output=None):
  1486. super().__init__()
  1487. self.y_range = config.output_range
  1488. self.use_cls_token = config.use_cls_token
  1489. self.pooling_type = config.pooling_type
  1490. self.distribution_output = distribution_output
  1491. head_dim = config.num_input_channels * config.d_model
  1492. self.flatten = nn.Flatten(start_dim=1)
  1493. self.dropout = nn.Dropout(config.head_dropout) if config.head_dropout > 0 else nn.Identity()
  1494. if distribution_output is None:
  1495. self.projection = nn.Linear(head_dim, config.num_targets)
  1496. else:
  1497. self.projection = distribution_output.get_parameter_projection(head_dim)
  1498. def forward(self, embedding: torch.Tensor):
  1499. """
  1500. Parameters:
  1501. embedding (`torch.Tensor` of shape `(bs, num_channels, num_patches, d_model)` or
  1502. `(bs, num_channels, num_patches+1, d_model)` if `cls_token` is set to True, *required*):
  1503. Embedding from the model
  1504. Returns:
  1505. `torch.Tensor` of shape `(bs, output_dim)`
  1506. """
  1507. if self.use_cls_token:
  1508. # use the first output token, pooled_embedding: [bs x num_channels x d_model]
  1509. pooled_embedding = embedding[:, :, 0, :]
  1510. elif self.pooling_type == "mean":
  1511. # pooled_embedding: [bs x num_channels x d_model]
  1512. pooled_embedding = embedding.mean(dim=2)
  1513. elif self.pooling_type == "max":
  1514. # pooled_embedding: [bs x num_channels x d_model]
  1515. pooled_embedding = embedding.max(dim=2).values
  1516. else:
  1517. raise ValueError(f"pooling operator {self.pooling_type} is not implemented yet")
  1518. # flatten the input
  1519. # pooled_embedding: bs x (num_channels * d_model)
  1520. pooled_embedding = self.dropout(self.flatten(pooled_embedding))
  1521. # projection
  1522. # output: bs x output_dim or a tuple of this shape for distribution head
  1523. output = self.projection(pooled_embedding)
  1524. # apply sigmoid to bound the output if required
  1525. if (self.distribution_output is None) & (self.y_range is not None): # linear head
  1526. output = torch.sigmoid(output) * (self.y_range[1] - self.y_range[0]) + self.y_range[0]
  1527. return output
  1528. @auto_docstring(
  1529. custom_intro="""
  1530. The PatchTST for regression model.
  1531. """
  1532. )
  1533. class PatchTSTForRegression(PatchTSTPreTrainedModel):
  1534. def __init__(self, config: PatchTSTConfig):
  1535. super().__init__(config)
  1536. # Turn off masking
  1537. if config.do_mask_input:
  1538. logger.warning("Setting `do_mask_input` parameter to False.")
  1539. config.do_mask_input = False
  1540. self.model = PatchTSTModel(config)
  1541. if config.loss == "mse":
  1542. self.distribution_output = None
  1543. else:
  1544. if config.distribution_output == "student_t":
  1545. self.distribution_output = StudentTOutput(dim=config.num_targets)
  1546. elif config.distribution_output == "normal":
  1547. self.distribution_output = NormalOutput(dim=config.num_targets)
  1548. elif config.distribution_output == "negative_binomial":
  1549. self.distribution_output = NegativeBinomialOutput(dim=config.num_targets)
  1550. else:
  1551. raise ValueError(f"Unknown distribution output {config.distribution_output}")
  1552. self.head = PatchTSTRegressionHead(config, self.distribution_output)
  1553. # Initialize weights and apply final processing
  1554. self.post_init()
  1555. @auto_docstring
  1556. def forward(
  1557. self,
  1558. past_values: torch.Tensor,
  1559. target_values: Optional[torch.Tensor] = None,
  1560. past_observed_mask: Optional[torch.Tensor] = None,
  1561. output_hidden_states: Optional[bool] = None,
  1562. output_attentions: Optional[bool] = None,
  1563. return_dict: Optional[bool] = None,
  1564. ) -> Union[tuple, PatchTSTForRegressionOutput]:
  1565. r"""
  1566. past_values (`torch.Tensor` of shape `(bs, sequence_length, num_input_channels)`, *required*):
  1567. Input sequence to the model
  1568. target_values (`torch.Tensor` of shape `(bs, num_input_channels)`):
  1569. Target values associates with the `past_values`
  1570. past_observed_mask (`torch.BoolTensor` of shape `(batch_size, sequence_length, num_input_channels)`, *optional*):
  1571. Boolean mask to indicate which `past_values` were observed and which were missing. Mask values selected
  1572. in `[0, 1]`:
  1573. - 1 for values that are **observed**,
  1574. - 0 for values that are **missing** (i.e. NaNs that were replaced by zeros).
  1575. Whether or not to return a `ModelOutput` instead of a plain tuple.
  1576. Examples:
  1577. ```python
  1578. >>> from transformers import PatchTSTConfig, PatchTSTForRegression
  1579. >>> # Regression task with 6 input channels and regress 2 targets
  1580. >>> model = PatchTSTForRegression.from_pretrained("namctin/patchtst_etth1_regression")
  1581. >>> # during inference, one only provides past values, the model outputs future values
  1582. >>> past_values = torch.randn(20, 512, 6)
  1583. >>> outputs = model(past_values=past_values)
  1584. >>> regression_outputs = outputs.regression_outputs
  1585. ```"""
  1586. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  1587. model_output = self.model(
  1588. past_values=past_values,
  1589. past_observed_mask=past_observed_mask,
  1590. output_hidden_states=output_hidden_states,
  1591. output_attentions=output_attentions,
  1592. return_dict=True,
  1593. )
  1594. # get output head. y_hat is of shape [bs x num_targets] or tuple of this shape
  1595. y_hat = self.head(model_output.last_hidden_state)
  1596. loss = None
  1597. if target_values is not None:
  1598. if self.distribution_output:
  1599. distribution = self.distribution_output.distribution(y_hat)
  1600. # y_hat should be a 2-tuple, each with dimension [bs, num_targets]
  1601. y_hat = tuple(item.view(-1, self.config.num_targets) for item in y_hat)
  1602. loss = nll(distribution, target_values)
  1603. # take average of the loss
  1604. loss = weighted_average(loss)
  1605. else:
  1606. loss = nn.MSELoss(reduction="mean")
  1607. loss = loss(y_hat, target_values)
  1608. if not return_dict:
  1609. # hidden_states, attentions, mask
  1610. outputs = (y_hat,) + model_output[1:-3]
  1611. outputs = (loss,) + outputs if loss is not None else outputs
  1612. return outputs
  1613. return PatchTSTForRegressionOutput(
  1614. loss=loss,
  1615. regression_outputs=y_hat,
  1616. hidden_states=model_output.hidden_states,
  1617. attentions=model_output.attentions,
  1618. )
  1619. @torch.no_grad()
  1620. def generate(
  1621. self,
  1622. past_values: torch.Tensor,
  1623. past_observed_mask: Optional[torch.Tensor] = None,
  1624. ) -> SamplePatchTSTOutput:
  1625. """
  1626. Generate sequences of sample predictions from a model with a probability distribution head.
  1627. Parameters:
  1628. past_values (`torch.FloatTensor` of shape `(batch_size, sequence_length, num_input_channels)`):
  1629. Past values of the time series that serves as context in order to predict the future.
  1630. past_observed_mask (`torch.BoolTensor` of shape `(batch_size, sequence_length, num_input_channels)`, *optional*):
  1631. Boolean mask to indicate which `past_values` were observed and which were missing. Mask values selected
  1632. in `[0, 1]`:
  1633. - 1 for values that are **observed**,
  1634. - 0 for values that are **missing** (i.e. NaNs that were replaced by zeros).
  1635. Return:
  1636. [`SamplePatchTSTOutput`] where the outputs `sequences` tensor will have shape `(batch_size, number of
  1637. samples, num_targets)`.
  1638. """
  1639. # get number of samples
  1640. num_parallel_samples = self.config.num_parallel_samples
  1641. # get model output
  1642. outputs = self(
  1643. past_values=past_values,
  1644. target_values=None,
  1645. past_observed_mask=past_observed_mask,
  1646. output_hidden_states=False,
  1647. )
  1648. # get distribution
  1649. distribution = self.distribution_output.distribution(outputs.regression_outputs)
  1650. # get samples: list of [bs x num_targets]
  1651. samples = [distribution.sample() for _ in range(num_parallel_samples)]
  1652. # samples: [bs x num_samples x num_targets]
  1653. samples = torch.stack(samples, dim=1).view(-1, num_parallel_samples, self.config.num_targets)
  1654. return SamplePatchTSTOutput(sequences=samples)
# Names exported by `from <this module> import *`; the head classes defined above are not listed here.
__all__ = [
    "PatchTSTModel",
    "PatchTSTPreTrainedModel",
    "PatchTSTForPrediction",
    "PatchTSTForPretraining",
    "PatchTSTForRegression",
    "PatchTSTForClassification",
]