vision_tf.py 25 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572
  1. # coding=utf-8
  2. # Copyright 2021 The OpenAI Team Authors and The HuggingFace Team. All rights reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """TF IdeficsVision model: a copy of CLIPVisionModel using a simpler config object"""
  16. import math
  17. from dataclasses import dataclass
  18. from typing import Optional, Union
  19. import tensorflow as tf
  20. from ...activations_tf import get_tf_activation
  21. from ...modeling_tf_outputs import TFBaseModelOutput, TFBaseModelOutputWithPooling
  22. from ...modeling_tf_utils import TFPreTrainedModel, shape_list
  23. from ...tf_utils import flatten
  24. from ...utils import ModelOutput, logging
  25. from .configuration_idefics import IdeficsVisionConfig
  # Module-level logger, named after this module's import path.
  logger = logging.get_logger(__name__)
  @dataclass
  class TFIdeficsVisionModelOutput(ModelOutput):
      """
      Base class for vision model's outputs that also contains image embeddings of the pooling of the last hidden states.

      Args:
          image_embeds (`tf.Tensor` of shape `(batch_size, output_dim)`, *optional*, returned when model is initialized with `with_projection=True`):
              The image embeddings obtained by applying the projection layer to the pooler_output.
          last_hidden_state (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`):
              Sequence of hidden-states at the output of the last layer of the model.
          hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
              Tuple of `tf.Tensor` (one for the output of the embeddings, if the model has an embedding layer, +
              one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

              Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
          attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
              Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
              sequence_length)`.

              Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
              heads.
      """

      image_embeds: Optional[tf.Tensor] = None
      last_hidden_state: Optional[tf.Tensor] = None
      # Variable-length tuples (one entry per layer), hence `tuple[tf.Tensor, ...]` rather than a 1-tuple type.
      hidden_states: Optional[tuple[tf.Tensor, ...]] = None
      attentions: Optional[tuple[tf.Tensor, ...]] = None
  class TFIdeficsVisionEmbeddings(tf.keras.layers.Layer):
      """
      Patch + class-token embeddings with learned absolute position embeddings.

      Images are cut into non-overlapping `patch_size` x `patch_size` patches by a strided convolution, a learned
      class embedding is prepended to the patch sequence, and a learned position embedding is added to every token.
      """

      def __init__(self, config: IdeficsVisionConfig, **kwargs):
          super().__init__(**kwargs)
          self.config = config
          self.embed_dim = config.hidden_size
          self.image_size = config.image_size
          self.patch_size = config.patch_size

          # kernel_size == strides with "valid" padding: every output position covers exactly one patch.
          self.patch_embedding = tf.keras.layers.Conv2D(
              filters=self.embed_dim,
              kernel_size=self.patch_size,
              strides=self.patch_size,
              use_bias=False,
              padding="valid",
              data_format="channels_last",
              name="patch_embedding",
          )

          self.num_patches = (self.image_size // self.patch_size) ** 2
          # One extra position for the class token prepended in `call`.
          self.num_positions = self.num_patches + 1
          self.position_embedding = tf.keras.layers.Embedding(
              self.num_positions, self.embed_dim, name="position_embedding"
          )
          # `position_ids` and `class_embedding` are created in `build`.

      def interpolate_pos_encoding(self, embeddings: tf.Tensor, height: int, width: int) -> tf.Tensor:
          """
          Resize the pre-trained patch position embeddings to fit an input of `height` x `width` pixels.

          Args:
              embeddings (`tf.Tensor`): token embeddings `(batch_size, 1 + num_patches, embed_dim)`; only the
                  sequence length and the last dimension are read.
              height (`int`): input image height in pixels.
              width (`int`): input image width in pixels.

          Returns:
              `tf.Tensor` of shape `(1, 1 + num_h_patches * num_w_patches, embed_dim)`: the class-token position
              embedding followed by the bicubically resized patch position embeddings. If no resizing is needed,
              the stored position embeddings are returned unchanged.

          Raises:
              ValueError: if the resized grid does not match the input's patch grid.
          """
          num_patches = shape_list(embeddings)[1] - 1
          pos_embed = self.position_embedding(self.position_ids)
          num_positions = shape_list(pos_embed)[1] - 1
          # Nothing to interpolate: same number of patches, laid out as a square grid.
          if num_patches == num_positions and height == width:
              return pos_embed
          class_pos_embed = pos_embed[:, 0]
          patch_pos_embed = pos_embed[:, 1:]

          embed_dim = shape_list(embeddings)[-1]
          num_h_patches = height // self.config.patch_size
          num_w_patches = width // self.config.patch_size
          # Small offset so that the truncating int casts below do not lose a patch to floating point error.
          num_h_patches, num_w_patches = num_h_patches + 0.1, num_w_patches + 0.1
          sqrt_num_positions = math.sqrt(float(num_positions))
          # Lay the stored patch positions out as a square (1, grid, grid, embed_dim) image (NHWC).
          patch_pos_embed = tf.reshape(patch_pos_embed, (1, int(sqrt_num_positions), int(sqrt_num_positions), embed_dim))

          scale_height = num_h_patches / sqrt_num_positions
          scale_width = num_w_patches / sqrt_num_positions
          original_height = tf.cast(tf.shape(patch_pos_embed)[1], tf.float32)
          original_width = tf.cast(tf.shape(patch_pos_embed)[2], tf.float32)
          # Apply scaling
          new_height = tf.cast(original_height * scale_height, tf.int32)
          new_width = tf.cast(original_width * scale_width, tf.int32)

          patch_pos_embed = tf.image.resize(
              patch_pos_embed, size=[new_height, new_width], method=tf.image.ResizeMethod.BICUBIC
          )

          # patch_pos_embed is NHWC, so the spatial grid lives in dims -3 (height) and -2 (width).
          if (
              int(num_h_patches) != shape_list(patch_pos_embed)[-3]
              or int(num_w_patches) != shape_list(patch_pos_embed)[-2]
          ):
              raise ValueError(
                  f"Number of patches for images ({int(num_h_patches), int(num_w_patches)}) don't match the "
                  f"shape of position embedding ({shape_list(patch_pos_embed)[-3], shape_list(patch_pos_embed)[-2]})"
              )
          patch_pos_embed = tf.reshape(patch_pos_embed, (1, -1, embed_dim))
          return tf.concat((class_pos_embed[tf.newaxis, :], patch_pos_embed), axis=1)

      def call(self, pixel_values: tf.Tensor, interpolate_pos_encoding: bool = False) -> tf.Tensor:
          """
          Embed a batch of images.

          Args:
              pixel_values (`tf.Tensor` or `dict`): images in NCHW layout `(batch_size, num_channels, height, width)`,
                  or a dict holding them under the `"pixel_values"` key.
              interpolate_pos_encoding (`bool`, defaults to `False`): resize the position embeddings to the input
                  resolution instead of requiring `height == width == config.image_size`.

          Returns:
              `tf.Tensor` of shape `(batch_size, 1 + num_patches, embed_dim)`.

          Raises:
              ValueError: if the image size differs from `config.image_size` and `interpolate_pos_encoding` is off.
          """
          # Input `pixel_values` is NCHW format which doesn't run on CPU so first thing we do is
          # transpose it to change it to NHWC. We don't care to transpose it back because
          # the Conv2D layer is only hit once for each query
          if isinstance(pixel_values, dict):
              pixel_values = pixel_values["pixel_values"]

          pixel_values = tf.transpose(pixel_values, perm=(0, 2, 3, 1))
          batch_size, height, width, _ = shape_list(pixel_values)
          if not interpolate_pos_encoding:
              if height != self.image_size or width != self.image_size:
                  raise ValueError(
                      f"Input image size ({height}*{width}) doesn't match model"
                      f" ({self.image_size}*{self.image_size}). You should try to set `interpolate_pos_encoding=True`"
                  )

          patch_embeds = self.patch_embedding(pixel_values)  # shape = (batch_size, grid_h, grid_w, embed_dim)
          # Flatten the 2D patch grid into a single sequence dimension.
          # shape = (batch_size, num_patches, embed_dim)
          patch_embeds = flatten(patch_embeds, 1, 2)

          # Prepend the learned class token to every sequence in the batch.
          class_embeds = tf.broadcast_to(
              self.class_embedding[tf.newaxis, tf.newaxis, :], [batch_size, 1, self.embed_dim]
          )
          embeddings = tf.concat([class_embeds, patch_embeds], axis=1)

          # add positional encoding to each token
          if interpolate_pos_encoding:
              embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
          else:
              embeddings = embeddings + self.position_embedding(self.position_ids)

          return embeddings

      def build(self, input_shape=None):
          if self.built:
              return
          self.built = True
          # Position indices 0..num_positions-1 with a leading batch axis: shape (1, num_positions).
          self.position_ids = tf.range(self.num_positions, name="self.position_ids")[tf.newaxis, :]
          self.class_embedding = self.add_weight(shape=(self.embed_dim,), name="class_embedding")
          if getattr(self, "patch_embedding", None) is not None:
              with tf.name_scope(self.patch_embedding.name):
                  self.patch_embedding.build([None, None, None, self.config.num_channels])
          if getattr(self, "position_embedding", None) is not None:
              with tf.name_scope(self.position_embedding.name):
                  self.position_embedding.build(None)
  class TFIdeficsVisionAttention(tf.keras.layers.Layer):
      """Multi-headed attention from 'Attention Is All You Need' paper"""

      def __init__(self, config, **kwargs):
          super().__init__(**kwargs)
          self.config = config
          self.embed_dim = config.hidden_size
          self.num_heads = config.num_attention_heads
          self.head_dim = self.embed_dim // self.num_heads
          if self.head_dim * self.num_heads != self.embed_dim:
              raise ValueError(
                  f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
                  f" {self.num_heads})."
              )
          # Queries are multiplied by 1/sqrt(head_dim) before the dot product.
          self.scale = self.head_dim**-0.5
          # Dropout rate, kept as a plain float attribute for backward compatibility.
          self.dropout = config.attention_dropout

          self.k_proj = tf.keras.layers.Dense(self.embed_dim, name="k_proj")
          self.v_proj = tf.keras.layers.Dense(self.embed_dim, name="v_proj")
          self.q_proj = tf.keras.layers.Dense(self.embed_dim, name="q_proj")
          self.out_proj = tf.keras.layers.Dense(self.embed_dim, name="out_proj")
          # Weight-free Keras dropout: only drops when called with `training=True`, unlike `tf.nn.dropout`,
          # which always drops (including at inference).
          self.attn_dropout = tf.keras.layers.Dropout(rate=self.dropout, name="attn_dropout")

      def _shape(self, tensor: tf.Tensor, seq_len: int, bsz: int):
          """Split the last axis into heads: `(bsz, seq_len, embed_dim)` -> `(bsz, num_heads, seq_len, head_dim)`."""
          return tf.transpose(tf.reshape(tensor, (bsz, seq_len, self.num_heads, self.head_dim)), perm=[0, 2, 1, 3])

      def call(
          self,
          hidden_states: tf.Tensor,
          attention_mask: Optional[tf.Tensor] = None,
          causal_attention_mask: Optional[tf.Tensor] = None,
          output_attentions: Optional[bool] = False,
          training: Optional[bool] = False,
      ) -> tuple[tf.Tensor, Optional[tf.Tensor]]:
          """
          Input shape: Batch x Time x Channel

          Args:
              hidden_states (`tf.Tensor`): `(batch_size, tgt_len, embed_dim)`.
              attention_mask (`tf.Tensor`, *optional*): additive mask of shape `(batch_size, 1, tgt_len, src_len)`.
              causal_attention_mask (`tf.Tensor`, *optional*): additive mask of the same shape, added before
                  `attention_mask`.
              output_attentions (`bool`, *optional*): also return the attention probabilities.
              training (`bool`, *optional*): apply attention dropout only when true.

          Returns:
              `(attn_output, attn_weights)`: `attn_output` is `(batch_size, tgt_len, embed_dim)`. `attn_weights` is
              the softmax output (taken before dropout) of shape `(batch_size, num_heads, tgt_len, src_len)` when
              `output_attentions` is set, else `None`.

          Raises:
              ValueError: if a mask does not have shape `(batch_size, 1, tgt_len, src_len)`.
          """
          bsz, tgt_len, embed_dim = shape_list(hidden_states)

          # get query proj
          query_states = self.q_proj(hidden_states) * self.scale
          key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
          value_states = self._shape(self.v_proj(hidden_states), -1, bsz)

          # Fold the heads into the batch axis: (bsz * num_heads, seq_len, head_dim).
          proj_shape = (bsz * self.num_heads, -1, self.head_dim)
          query_states = tf.reshape(self._shape(query_states, tgt_len, bsz), proj_shape)
          key_states = tf.reshape(key_states, proj_shape)
          value_states = tf.reshape(value_states, proj_shape)

          src_len = shape_list(key_states)[1]
          attn_weights = tf.linalg.matmul(query_states, key_states, transpose_b=True)

          tf.debugging.assert_equal(
              tf.shape(attn_weights),
              [bsz * self.num_heads, tgt_len, src_len],
              message=f"Attention weights should be of size {[bsz * self.num_heads, tgt_len, src_len]}, but is {tf.shape(attn_weights)}",
          )

          # apply the causal_attention_mask first
          if causal_attention_mask is not None:
              if shape_list(causal_attention_mask) != [bsz, 1, tgt_len, src_len]:
                  raise ValueError(
                      f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
                      f" {shape_list(causal_attention_mask)}"
                  )
              attn_weights = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) + causal_attention_mask
              attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len))

          if attention_mask is not None:
              if shape_list(attention_mask) != [bsz, 1, tgt_len, src_len]:
                  raise ValueError(
                      f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {shape_list(attention_mask)}"
                  )
              attn_weights = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) + attention_mask
              attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len))

          attn_weights = tf.nn.softmax(attn_weights, axis=-1)

          if output_attentions:
              # this operation is a bit awkward, but it's required to
              # make sure that attn_weights keeps its gradient.
              # In order to do so, attn_weights have to reshaped
              # twice and have to be reused in the following
              attn_weights_reshaped = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len))
              attn_weights = tf.reshape(attn_weights_reshaped, (bsz * self.num_heads, tgt_len, src_len))
          else:
              attn_weights_reshaped = None

          # Identity unless `training` is true.
          attn_probs = self.attn_dropout(attn_weights, training=training)
          attn_output = tf.linalg.matmul(attn_probs, value_states)

          tf.debugging.assert_equal(
              tf.shape(attn_output),
              [bsz * self.num_heads, tgt_len, self.head_dim],
              message=f"`attn_output` should be of size {[bsz * self.num_heads, tgt_len, self.head_dim]}, but is {tf.shape(attn_output)}",
          )

          # Merge the heads back: (bsz, tgt_len, embed_dim).
          attn_output = tf.reshape(attn_output, (bsz, self.num_heads, tgt_len, self.head_dim))
          attn_output = tf.transpose(attn_output, perm=[0, 2, 1, 3])
          attn_output = tf.reshape(attn_output, (bsz, tgt_len, embed_dim))

          attn_output = self.out_proj(attn_output)

          return attn_output, attn_weights_reshaped

      def build(self, input_shape=None):
          if self.built:
              return
          self.built = True
          if getattr(self, "k_proj", None) is not None:
              with tf.name_scope(self.k_proj.name):
                  self.k_proj.build((self.embed_dim, self.embed_dim))
          if getattr(self, "v_proj", None) is not None:
              with tf.name_scope(self.v_proj.name):
                  self.v_proj.build((self.embed_dim, self.embed_dim))
          if getattr(self, "q_proj", None) is not None:
              with tf.name_scope(self.q_proj.name):
                  self.q_proj.build((self.embed_dim, self.embed_dim))
          if getattr(self, "out_proj", None) is not None:
              with tf.name_scope(self.out_proj.name):
                  self.out_proj.build((self.embed_dim, self.embed_dim))
  class TFIdeficsVisionMLP(tf.keras.layers.Layer):
      """Feed-forward block: `fc1` -> activation (`config.hidden_act`) -> `fc2`."""

      def __init__(self, config, **kwargs):
          super().__init__(**kwargs)
          self.config = config
          self.activation_fn = get_tf_activation(config.hidden_act)
          # Widen to `intermediate_size`, then project back to `hidden_size`.
          self.fc1 = tf.keras.layers.Dense(config.intermediate_size, name="fc1")
          self.fc2 = tf.keras.layers.Dense(config.hidden_size, name="fc2")

      def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
          """Apply the feed-forward block; the output's last dimension is `config.hidden_size`."""
          activated = self.activation_fn(self.fc1(hidden_states))
          return self.fc2(activated)

      def build(self, input_shape=None):
          if self.built:
              return
          self.built = True
          # Input feature width of each projection, in build order.
          input_widths = {
              "fc1": self.config.hidden_size,
              "fc2": self.config.intermediate_size,
          }
          for attr_name, in_width in input_widths.items():
              sublayer = getattr(self, attr_name, None)
              if sublayer is None:
                  continue
              with tf.name_scope(sublayer.name):
                  sublayer.build(in_width)
  class TFIdeficsVisionEncoderLayer(tf.keras.layers.Layer):
      """Pre-norm transformer block: LayerNorm -> self-attention -> residual, then LayerNorm -> MLP -> residual."""

      def __init__(self, config: IdeficsVisionConfig, **kwargs):
          super().__init__(**kwargs)
          self.embed_dim = config.hidden_size
          self.self_attn = TFIdeficsVisionAttention(config, name="self_attn")
          self.layer_norm1 = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layer_norm1")
          self.mlp = TFIdeficsVisionMLP(config, name="mlp")
          self.layer_norm2 = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layer_norm2")

      def call(
          self,
          hidden_states: tf.Tensor,
          attention_mask: tf.Tensor,
          causal_attention_mask: tf.Tensor,
          output_attentions: Optional[bool] = False,
      ) -> tuple[tf.Tensor, ...]:
          """
          Args:
              hidden_states (`tf.Tensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
              attention_mask (`tf.Tensor` or `None`): additive attention mask of size
                  `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
              causal_attention_mask (`tf.Tensor` or `None`): additive causal mask of size `(batch, 1, tgt_len, src_len)`.
              output_attentions (`bool`, *optional*):
                  Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                  returned tensors for more detail.

          Returns:
              `tuple(tf.Tensor)`: `(hidden_states,)`, or `(hidden_states, attn_weights)` when `output_attentions` is set.
          """
          residual = hidden_states

          hidden_states = self.layer_norm1(hidden_states)
          hidden_states, attn_weights = self.self_attn(
              hidden_states=hidden_states,
              attention_mask=attention_mask,
              causal_attention_mask=causal_attention_mask,
              output_attentions=output_attentions,
          )
          hidden_states = residual + hidden_states

          residual = hidden_states
          hidden_states = self.layer_norm2(hidden_states)
          hidden_states = self.mlp(hidden_states)
          hidden_states = residual + hidden_states

          outputs = (hidden_states,)

          if output_attentions:
              outputs += (attn_weights,)

          return outputs

      def build(self, input_shape=None):
          if self.built:
              return
          self.built = True
          # Build every sublayer, including attention and MLP, so all of this layer's weights exist
          # (under this layer's name scope) after an explicit build, not only after the first call.
          if getattr(self, "self_attn", None) is not None:
              with tf.name_scope(self.self_attn.name):
                  self.self_attn.build(None)
          if getattr(self, "layer_norm1", None) is not None:
              with tf.name_scope(self.layer_norm1.name):
                  self.layer_norm1.build([None, None, self.embed_dim])
          if getattr(self, "mlp", None) is not None:
              with tf.name_scope(self.mlp.name):
                  self.mlp.build(None)
          if getattr(self, "layer_norm2", None) is not None:
              with tf.name_scope(self.layer_norm2.name):
                  self.layer_norm2.build([None, None, self.embed_dim])
  class TFIdeficsVisionEncoder(tf.keras.layers.Layer):
      """
      Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
      [`TFIdeficsVisionEncoderLayer`].

      Args:
          config: IdeficsVisionConfig
      """

      def __init__(self, config: IdeficsVisionConfig, **kwargs):
          super().__init__(**kwargs)
          self.config = config
          self.layers = [
              TFIdeficsVisionEncoderLayer(config, name=f"layers.{i}") for i in range(config.num_hidden_layers)
          ]
          # Off by default. When set (and `training` is truthy), each layer runs under `tf.recompute_grad`,
          # so its activations are recomputed in the backward pass instead of being kept in memory.
          self.gradient_checkpointing = False

      @staticmethod
      def _checkpointed_layer_call(
          encoder_layer,
          hidden_states: tf.Tensor,
          attention_mask: Optional[tf.Tensor],
          causal_attention_mask: Optional[tf.Tensor],
          output_attentions: Optional[bool],
      ):
          """Run `encoder_layer` on `hidden_states` through `tf.recompute_grad` and return its output tuple."""

          # `tf.recompute_grad(f)` returns a wrapped callable; it does not take the call arguments itself.
          # Only the `hidden_states` tensor is routed through the wrapper. The masks (which may be `None`)
          # and the Python bool are captured by closure, since the wrapped function's inputs should be tensors.
          def forward(layer_input):
              return encoder_layer(
                  layer_input,
                  attention_mask,
                  causal_attention_mask,
                  output_attentions=output_attentions,
              )

          return tf.recompute_grad(forward)(hidden_states)

      def call(
          self,
          inputs_embeds,
          attention_mask: Optional[tf.Tensor] = None,
          causal_attention_mask: Optional[tf.Tensor] = None,
          output_attentions: Optional[bool] = None,
          output_hidden_states: Optional[bool] = None,
          return_dict: Optional[bool] = None,
          training: Optional[bool] = None,
      ) -> Union[tuple, TFBaseModelOutput]:
          r"""
          Args:
              inputs_embeds (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`):
                  Embedded token sequence fed to the first encoder layer.
              attention_mask (`tf.Tensor` of shape `(batch_size, 1, sequence_length, sequence_length)`, *optional*):
                  Additive mask added to the attention scores; masked positions hold very large negative values.
              causal_attention_mask (`tf.Tensor` of shape `(batch_size, 1, sequence_length, sequence_length)`, *optional*):
                  Additive causal mask, applied in the same way as `attention_mask`.
              output_attentions (`bool`, *optional*):
                  Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                  returned tensors for more detail.
              output_hidden_states (`bool`, *optional*):
                  Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
                  for more detail.
              return_dict (`bool`, *optional*):
                  Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
              training (`bool`, *optional*):
                  Training mode; together with `self.gradient_checkpointing` it enables activation recomputation.
          """
          output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
          output_hidden_states = (
              output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
          )
          return_dict = return_dict if return_dict is not None else self.config.use_return_dict

          encoder_states = () if output_hidden_states else None
          all_attentions = () if output_attentions else None

          hidden_states = inputs_embeds
          for encoder_layer in self.layers:
              # Record the input of each layer (the last layer's output is appended after the loop).
              if output_hidden_states:
                  encoder_states = encoder_states + (hidden_states,)
              if self.gradient_checkpointing and training:
                  layer_outputs = self._checkpointed_layer_call(
                      encoder_layer,
                      hidden_states,
                      attention_mask,
                      causal_attention_mask,
                      output_attentions,
                  )
              else:
                  layer_outputs = encoder_layer(
                      hidden_states,
                      attention_mask,
                      causal_attention_mask,
                      output_attentions=output_attentions,
                  )

              hidden_states = layer_outputs[0]

              if output_attentions:
                  all_attentions = all_attentions + (layer_outputs[1],)

          if output_hidden_states:
              encoder_states = encoder_states + (hidden_states,)

          if not return_dict:
              return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
          return TFBaseModelOutput(
              last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
          )

      def build(self, input_shape=None):
          if self.built:
              return
          self.built = True
          if getattr(self, "layers", None) is not None:
              for layer in self.layers:
                  with tf.name_scope(layer.name):
                      layer.build(None)
  class TFIdeficsVisionTransformer(TFPreTrainedModel):
      """Vision tower: patch embeddings -> pre-LayerNorm -> encoder, pooled on the class token (position 0)."""

      def __init__(self, config: IdeficsVisionConfig, **kwargs):
          super().__init__(config, **kwargs)
          self.config = config
          self.embed_dim = config.hidden_size

          self.embeddings = TFIdeficsVisionEmbeddings(config, name="embeddings")
          self.pre_layrnorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="pre_layrnorm")
          self.encoder = TFIdeficsVisionEncoder(config, name="encoder")
          self.post_layernorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="post_layernorm")

      # Adapted from transformers.models.clip.modeling_clip.CLIPVisionTransformer.forward
      def call(
          self,
          pixel_values: Optional[tf.Tensor] = None,
          output_attentions: Optional[bool] = None,
          output_hidden_states: Optional[bool] = None,
          interpolate_pos_encoding: Optional[bool] = False,
          return_dict: Optional[bool] = None,
          training: Optional[bool] = False,
      ) -> Union[tuple, TFBaseModelOutputWithPooling]:
          r"""
          Encode a batch of images.

          Args:
              pixel_values (`tf.Tensor`): images forwarded to the embeddings layer (required).
              output_attentions, output_hidden_states, return_dict (`bool`, *optional*): fall back to the
                  corresponding `config` values when `None`.
              interpolate_pos_encoding (`bool`, defaults to `False`): forwarded to the embeddings layer.
              training (`bool`, defaults to `False`): forwarded to the encoder.

          Returns:
              [`TFBaseModelOutputWithPooling`] when `return_dict` is truthy, otherwise the tuple
              `(last_hidden_state, pooled_output, *encoder extras)`. `pooled_output` is the post-LayerNorm
              class-token state.

          Raises:
              ValueError: if `pixel_values` is `None`.
          """
          cfg = self.config
          if output_attentions is None:
              output_attentions = cfg.output_attentions
          if output_hidden_states is None:
              output_hidden_states = cfg.output_hidden_states
          if return_dict is None:
              return_dict = cfg.use_return_dict

          if pixel_values is None:
              raise ValueError("You have to specify pixel_values")

          embedded = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
          encoder_outputs = self.encoder(
              inputs_embeds=self.pre_layrnorm(embedded),
              output_attentions=output_attentions,
              output_hidden_states=output_hidden_states,
              return_dict=return_dict,
              training=training,
          )

          sequence_output = encoder_outputs[0]
          # Pool by normalizing the class-token (first position) state.
          pooled_output = self.post_layernorm(sequence_output[:, 0, :])

          if return_dict:
              return TFBaseModelOutputWithPooling(
                  last_hidden_state=sequence_output,
                  pooler_output=pooled_output,
                  hidden_states=encoder_outputs.hidden_states,
                  attentions=encoder_outputs.attentions,
              )
          return (sequence_output, pooled_output) + encoder_outputs[1:]

      def build(self, input_shape=None):
          if self.built:
              return
          self.built = True
          # (attribute name, input shape handed to its `build`), in build order.
          build_plan = (
              ("embeddings", None),
              ("pre_layrnorm", [None, None, self.embed_dim]),
              ("encoder", None),
              ("post_layernorm", [None, self.embed_dim]),
          )
          for attr_name, build_shape in build_plan:
              sublayer = getattr(self, attr_name, None)
              if sublayer is not None:
                  with tf.name_scope(sublayer.name):
                      sublayer.build(build_shape)