doc.py 51 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483148414851486148714881489149014911492149314941495149614971498149915001501150215031504150515061507150815091510151115121513151415151516151715181519152015211522152315241525152615271528152915301531153215331534153515361537153815391540154115421543154415451546154715481549155015511552155315541555155615571558155915601561156215631564156515661567156815691570157115721573157415751576157715781579158015811582
  1. # Copyright 2022 The HuggingFace Team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """
  15. Doc utilities: Utilities related to documentation
  16. """
  17. import functools
  18. import inspect
  19. import re
  20. import textwrap
  21. import types
  22. from collections import OrderedDict
  23. def get_docstring_indentation_level(func):
  24. """Return the indentation level of the start of the docstring of a class or function (or method)."""
  25. # We assume classes are always defined in the global scope
  26. if inspect.isclass(func):
  27. return 4
  28. source = inspect.getsource(func)
  29. first_line = source.splitlines()[0]
  30. function_def_level = len(first_line) - len(first_line.lstrip())
  31. return 4 + function_def_level
  32. def add_start_docstrings(*docstr):
  33. def docstring_decorator(fn):
  34. fn.__doc__ = "".join(docstr) + (fn.__doc__ if fn.__doc__ is not None else "")
  35. return fn
  36. return docstring_decorator
  37. def add_start_docstrings_to_model_forward(*docstr):
  38. def docstring_decorator(fn):
  39. class_name = f"[`{fn.__qualname__.split('.')[0]}`]"
  40. intro = rf""" The {class_name} forward method, overrides the `__call__` special method.
  41. <Tip>
  42. Although the recipe for forward pass needs to be defined within this function, one should call the [`Module`]
  43. instance afterwards instead of this since the former takes care of running the pre and post processing steps while
  44. the latter silently ignores them.
  45. </Tip>
  46. """
  47. correct_indentation = get_docstring_indentation_level(fn)
  48. current_doc = fn.__doc__ if fn.__doc__ is not None else ""
  49. try:
  50. first_non_empty = next(line for line in current_doc.splitlines() if line.strip() != "")
  51. doc_indentation = len(first_non_empty) - len(first_non_empty.lstrip())
  52. except StopIteration:
  53. doc_indentation = correct_indentation
  54. docs = docstr
  55. # In this case, the correct indentation level (class method, 2 Python levels) was respected, and we should
  56. # correctly reindent everything. Otherwise, the doc uses a single indentation level
  57. if doc_indentation == 4 + correct_indentation:
  58. docs = [textwrap.indent(textwrap.dedent(doc), " " * correct_indentation) for doc in docstr]
  59. intro = textwrap.indent(textwrap.dedent(intro), " " * correct_indentation)
  60. docstring = "".join(docs) + current_doc
  61. fn.__doc__ = intro + docstring
  62. return fn
  63. return docstring_decorator
  64. def add_end_docstrings(*docstr):
  65. def docstring_decorator(fn):
  66. fn.__doc__ = (fn.__doc__ if fn.__doc__ is not None else "") + "".join(docstr)
  67. return fn
  68. return docstring_decorator
# "Returns:" headers used by `_prepare_output_docstrings` when `add_intro=True`. The TF variant is chosen when the
# output type's name starts with "TF", the PT variant otherwise. Both are filled with `str.format` using the
# `full_output_type` and `config_class` fields.
PT_RETURN_INTRODUCTION = r"""
Returns:
[`{full_output_type}`] or `tuple(torch.FloatTensor)`: A [`{full_output_type}`] or a tuple of
`torch.FloatTensor` (if `return_dict=False` is passed or when `config.return_dict=False`) comprising various
elements depending on the configuration ([`{config_class}`]) and inputs.
"""
TF_RETURN_INTRODUCTION = r"""
Returns:
[`{full_output_type}`] or `tuple(tf.Tensor)`: A [`{full_output_type}`] or a tuple of `tf.Tensor` (if
`return_dict=False` is passed or when `config.return_dict=False`) comprising various elements depending on the
configuration ([`{config_class}`]) and inputs.
"""
  81. def _get_indent(t):
  82. """Returns the indentation in the first line of t"""
  83. search = re.search(r"^(\s*)\S", t)
  84. return "" if search is None else search.groups()[0]
  85. def _convert_output_args_doc(output_args_doc):
  86. """Convert output_args_doc to display properly."""
  87. # Split output_arg_doc in blocks argument/description
  88. indent = _get_indent(output_args_doc)
  89. blocks = []
  90. current_block = ""
  91. for line in output_args_doc.split("\n"):
  92. # If the indent is the same as the beginning, the line is the name of new arg.
  93. if _get_indent(line) == indent:
  94. if len(current_block) > 0:
  95. blocks.append(current_block[:-1])
  96. current_block = f"{line}\n"
  97. else:
  98. # Otherwise it's part of the description of the current arg.
  99. # We need to remove 2 spaces to the indentation.
  100. current_block += f"{line[2:]}\n"
  101. blocks.append(current_block[:-1])
  102. # Format each block for proper rendering
  103. for i in range(len(blocks)):
  104. blocks[i] = re.sub(r"^(\s+)(\S+)(\s+)", r"\1- **\2**\3", blocks[i])
  105. blocks[i] = re.sub(r":\s*\n\s*(\S)", r" -- \1", blocks[i])
  106. return "\n".join(blocks)
  107. def _prepare_output_docstrings(output_type, config_class, min_indent=None, add_intro=True):
  108. """
  109. Prepares the return part of the docstring using `output_type`.
  110. """
  111. output_docstring = output_type.__doc__
  112. params_docstring = None
  113. if output_docstring is not None:
  114. # Remove the head of the docstring to keep the list of args only
  115. lines = output_docstring.split("\n")
  116. i = 0
  117. while i < len(lines) and re.search(r"^\s*(Args|Parameters):\s*$", lines[i]) is None:
  118. i += 1
  119. if i < len(lines):
  120. params_docstring = "\n".join(lines[(i + 1) :])
  121. params_docstring = _convert_output_args_doc(params_docstring)
  122. elif add_intro:
  123. raise ValueError(
  124. f"No `Args` or `Parameters` section is found in the docstring of `{output_type.__name__}`. Make sure it has "
  125. "docstring and contain either `Args` or `Parameters`."
  126. )
  127. # Add the return introduction
  128. if add_intro:
  129. full_output_type = f"{output_type.__module__}.{output_type.__name__}"
  130. intro = TF_RETURN_INTRODUCTION if output_type.__name__.startswith("TF") else PT_RETURN_INTRODUCTION
  131. intro = intro.format(full_output_type=full_output_type, config_class=config_class)
  132. else:
  133. full_output_type = str(output_type)
  134. intro = f"\nReturns:\n `{full_output_type}`"
  135. if params_docstring is not None:
  136. intro += ":\n"
  137. result = intro
  138. if params_docstring is not None:
  139. result += params_docstring
  140. # Apply minimum indent if necessary
  141. if min_indent is not None:
  142. lines = result.split("\n")
  143. # Find the indent of the first nonempty line
  144. i = 0
  145. while len(lines[i]) == 0:
  146. i += 1
  147. indent = len(_get_indent(lines[i]))
  148. # If too small, add indentation to all nonempty lines
  149. if indent < min_indent:
  150. to_add = " " * (min_indent - indent)
  151. lines = [(f"{to_add}{line}" if len(line) > 0 else line) for line in lines]
  152. result = "\n".join(lines)
  153. return result
# Warning tip with `{real_checkpoint}` / `{fake_checkpoint}` placeholders (note: a plain, not raw, string).
# NOTE(review): `{true}` in `<Tip warning={true}>` is itself a `str.format` field. If this text is passed through
# `str.format`, the caller must supply `true` (e.g. `true="{true}"`) or it raises KeyError. Confirm at the call site,
# which is outside this view.
FAKE_MODEL_DISCLAIMER = """
<Tip warning={true}>
This example uses a random model as the real ones are all very big. To get proper results, you should use
{real_checkpoint} instead of {fake_checkpoint}. If you get out-of-memory when loading that checkpoint, you can try
adding `device_map="auto"` in the `from_pretrained` call.
</Tip>
"""
  161. PT_TOKEN_CLASSIFICATION_SAMPLE = r"""
  162. Example:
  163. ```python
  164. >>> from transformers import AutoTokenizer, {model_class}
  165. >>> import torch
  166. >>> tokenizer = AutoTokenizer.from_pretrained("{checkpoint}")
  167. >>> model = {model_class}.from_pretrained("{checkpoint}")
  168. >>> inputs = tokenizer(
  169. ... "HuggingFace is a company based in Paris and New York", add_special_tokens=False, return_tensors="pt"
  170. ... )
  171. >>> with torch.no_grad():
  172. ... logits = model(**inputs).logits
  173. >>> predicted_token_class_ids = logits.argmax(-1)
  174. >>> # Note that tokens are classified rather then input words which means that
  175. >>> # there might be more predicted token classes than words.
  176. >>> # Multiple token classes might account for the same word
  177. >>> predicted_tokens_classes = [model.config.id2label[t.item()] for t in predicted_token_class_ids[0]]
  178. >>> predicted_tokens_classes
  179. {expected_output}
  180. >>> labels = predicted_token_class_ids
  181. >>> loss = model(**inputs, labels=labels).loss
  182. >>> round(loss.item(), 2)
  183. {expected_loss}
  184. ```
  185. """
  186. PT_QUESTION_ANSWERING_SAMPLE = r"""
  187. Example:
  188. ```python
  189. >>> from transformers import AutoTokenizer, {model_class}
  190. >>> import torch
  191. >>> tokenizer = AutoTokenizer.from_pretrained("{checkpoint}")
  192. >>> model = {model_class}.from_pretrained("{checkpoint}")
  193. >>> question, text = "Who was Jim Henson?", "Jim Henson was a nice puppet"
  194. >>> inputs = tokenizer(question, text, return_tensors="pt")
  195. >>> with torch.no_grad():
  196. ... outputs = model(**inputs)
  197. >>> answer_start_index = outputs.start_logits.argmax()
  198. >>> answer_end_index = outputs.end_logits.argmax()
  199. >>> predict_answer_tokens = inputs.input_ids[0, answer_start_index : answer_end_index + 1]
  200. >>> tokenizer.decode(predict_answer_tokens, skip_special_tokens=True)
  201. {expected_output}
  202. >>> # target is "nice puppet"
  203. >>> target_start_index = torch.tensor([{qa_target_start_index}])
  204. >>> target_end_index = torch.tensor([{qa_target_end_index}])
  205. >>> outputs = model(**inputs, start_positions=target_start_index, end_positions=target_end_index)
  206. >>> loss = outputs.loss
  207. >>> round(loss.item(), 2)
  208. {expected_loss}
  209. ```
  210. """
  211. PT_SEQUENCE_CLASSIFICATION_SAMPLE = r"""
  212. Example of single-label classification:
  213. ```python
  214. >>> import torch
  215. >>> from transformers import AutoTokenizer, {model_class}
  216. >>> tokenizer = AutoTokenizer.from_pretrained("{checkpoint}")
  217. >>> model = {model_class}.from_pretrained("{checkpoint}")
  218. >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
  219. >>> with torch.no_grad():
  220. ... logits = model(**inputs).logits
  221. >>> predicted_class_id = logits.argmax().item()
  222. >>> model.config.id2label[predicted_class_id]
  223. {expected_output}
  224. >>> # To train a model on `num_labels` classes, you can pass `num_labels=num_labels` to `.from_pretrained(...)`
  225. >>> num_labels = len(model.config.id2label)
  226. >>> model = {model_class}.from_pretrained("{checkpoint}", num_labels=num_labels)
  227. >>> labels = torch.tensor([1])
  228. >>> loss = model(**inputs, labels=labels).loss
  229. >>> round(loss.item(), 2)
  230. {expected_loss}
  231. ```
  232. Example of multi-label classification:
  233. ```python
  234. >>> import torch
  235. >>> from transformers import AutoTokenizer, {model_class}
  236. >>> tokenizer = AutoTokenizer.from_pretrained("{checkpoint}")
  237. >>> model = {model_class}.from_pretrained("{checkpoint}", problem_type="multi_label_classification")
  238. >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
  239. >>> with torch.no_grad():
  240. ... logits = model(**inputs).logits
  241. >>> predicted_class_ids = torch.arange(0, logits.shape[-1])[torch.sigmoid(logits).squeeze(dim=0) > 0.5]
  242. >>> # To train a model on `num_labels` classes, you can pass `num_labels=num_labels` to `.from_pretrained(...)`
  243. >>> num_labels = len(model.config.id2label)
  244. >>> model = {model_class}.from_pretrained(
  245. ... "{checkpoint}", num_labels=num_labels, problem_type="multi_label_classification"
  246. ... )
  247. >>> labels = torch.sum(
  248. ... torch.nn.functional.one_hot(predicted_class_ids[None, :].clone(), num_classes=num_labels), dim=1
  249. ... ).to(torch.float)
  250. >>> loss = model(**inputs, labels=labels).loss
  251. ```
  252. """
  253. PT_MASKED_LM_SAMPLE = r"""
  254. Example:
  255. ```python
  256. >>> from transformers import AutoTokenizer, {model_class}
  257. >>> import torch
  258. >>> tokenizer = AutoTokenizer.from_pretrained("{checkpoint}")
  259. >>> model = {model_class}.from_pretrained("{checkpoint}")
  260. >>> inputs = tokenizer("The capital of France is {mask}.", return_tensors="pt")
  261. >>> with torch.no_grad():
  262. ... logits = model(**inputs).logits
  263. >>> # retrieve index of {mask}
  264. >>> mask_token_index = (inputs.input_ids == tokenizer.mask_token_id)[0].nonzero(as_tuple=True)[0]
  265. >>> predicted_token_id = logits[0, mask_token_index].argmax(axis=-1)
  266. >>> tokenizer.decode(predicted_token_id)
  267. {expected_output}
  268. >>> labels = tokenizer("The capital of France is Paris.", return_tensors="pt")["input_ids"]
  269. >>> # mask labels of non-{mask} tokens
  270. >>> labels = torch.where(inputs.input_ids == tokenizer.mask_token_id, labels, -100)
  271. >>> outputs = model(**inputs, labels=labels)
  272. >>> round(outputs.loss.item(), 2)
  273. {expected_loss}
  274. ```
  275. """
# PyTorch base-model example template (last hidden state). Fields: `{model_class}` and `{checkpoint}`.
PT_BASE_MODEL_SAMPLE = r"""
Example:
```python
>>> from transformers import AutoTokenizer, {model_class}
>>> import torch
>>> tokenizer = AutoTokenizer.from_pretrained("{checkpoint}")
>>> model = {model_class}.from_pretrained("{checkpoint}")
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
>>> outputs = model(**inputs)
>>> last_hidden_states = outputs.last_hidden_state
```
"""
# PyTorch multiple-choice example template. Fields: `{model_class}` and `{checkpoint}`. The doubled braces in
# `{{k: v ...}}` suggest this text goes through `str.format`, which turns them into a literal dict comprehension.
PT_MULTIPLE_CHOICE_SAMPLE = r"""
Example:
```python
>>> from transformers import AutoTokenizer, {model_class}
>>> import torch
>>> tokenizer = AutoTokenizer.from_pretrained("{checkpoint}")
>>> model = {model_class}.from_pretrained("{checkpoint}")
>>> prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
>>> choice0 = "It is eaten with a fork and a knife."
>>> choice1 = "It is eaten while held in the hand."
>>> labels = torch.tensor(0).unsqueeze(0) # choice0 is correct (according to Wikipedia ;)), batch size 1
>>> encoding = tokenizer([prompt, prompt], [choice0, choice1], return_tensors="pt", padding=True)
>>> outputs = model(**{{k: v.unsqueeze(0) for k, v in encoding.items()}}, labels=labels) # batch size is 1
>>> # the linear classifier still needs to be trained
>>> loss = outputs.loss
>>> logits = outputs.logits
```
"""
# PyTorch causal-LM example template (loss computed with the inputs as labels). Fields: `{model_class}` and
# `{checkpoint}`.
PT_CAUSAL_LM_SAMPLE = r"""
Example:
```python
>>> import torch
>>> from transformers import AutoTokenizer, {model_class}
>>> tokenizer = AutoTokenizer.from_pretrained("{checkpoint}")
>>> model = {model_class}.from_pretrained("{checkpoint}")
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
>>> outputs = model(**inputs, labels=inputs["input_ids"])
>>> loss = outputs.loss
>>> logits = outputs.logits
```
"""
  319. PT_SPEECH_BASE_MODEL_SAMPLE = r"""
  320. Example:
  321. ```python
  322. >>> from transformers import AutoProcessor, {model_class}
  323. >>> import torch
  324. >>> from datasets import load_dataset
  325. >>> dataset = load_dataset("hf-internal-testing/librispeech_asr_demo", "clean", split="validation")
  326. >>> dataset = dataset.sort("id")
  327. >>> sampling_rate = dataset.features["audio"].sampling_rate
  328. >>> processor = AutoProcessor.from_pretrained("{checkpoint}")
  329. >>> model = {model_class}.from_pretrained("{checkpoint}")
  330. >>> # audio file is decoded on the fly
  331. >>> inputs = processor(dataset[0]["audio"]["array"], sampling_rate=sampling_rate, return_tensors="pt")
  332. >>> with torch.no_grad():
  333. ... outputs = model(**inputs)
  334. >>> last_hidden_states = outputs.last_hidden_state
  335. >>> list(last_hidden_states.shape)
  336. {expected_output}
  337. ```
  338. """
  339. PT_SPEECH_CTC_SAMPLE = r"""
  340. Example:
  341. ```python
  342. >>> from transformers import AutoProcessor, {model_class}
  343. >>> from datasets import load_dataset
  344. >>> import torch
  345. >>> dataset = load_dataset("hf-internal-testing/librispeech_asr_demo", "clean", split="validation")
  346. >>> dataset = dataset.sort("id")
  347. >>> sampling_rate = dataset.features["audio"].sampling_rate
  348. >>> processor = AutoProcessor.from_pretrained("{checkpoint}")
  349. >>> model = {model_class}.from_pretrained("{checkpoint}")
  350. >>> # audio file is decoded on the fly
  351. >>> inputs = processor(dataset[0]["audio"]["array"], sampling_rate=sampling_rate, return_tensors="pt")
  352. >>> with torch.no_grad():
  353. ... logits = model(**inputs).logits
  354. >>> predicted_ids = torch.argmax(logits, dim=-1)
  355. >>> # transcribe speech
  356. >>> transcription = processor.batch_decode(predicted_ids)
  357. >>> transcription[0]
  358. {expected_output}
  359. >>> inputs["labels"] = processor(text=dataset[0]["text"], return_tensors="pt").input_ids
  360. >>> # compute loss
  361. >>> loss = model(**inputs).loss
  362. >>> round(loss.item(), 2)
  363. {expected_loss}
  364. ```
  365. """
  366. PT_SPEECH_SEQ_CLASS_SAMPLE = r"""
  367. Example:
  368. ```python
  369. >>> from transformers import AutoFeatureExtractor, {model_class}
  370. >>> from datasets import load_dataset
  371. >>> import torch
  372. >>> dataset = load_dataset("hf-internal-testing/librispeech_asr_demo", "clean", split="validation")
  373. >>> dataset = dataset.sort("id")
  374. >>> sampling_rate = dataset.features["audio"].sampling_rate
  375. >>> feature_extractor = AutoFeatureExtractor.from_pretrained("{checkpoint}")
  376. >>> model = {model_class}.from_pretrained("{checkpoint}")
  377. >>> # audio file is decoded on the fly
  378. >>> inputs = feature_extractor(dataset[0]["audio"]["array"], sampling_rate=sampling_rate, return_tensors="pt")
  379. >>> with torch.no_grad():
  380. ... logits = model(**inputs).logits
  381. >>> predicted_class_ids = torch.argmax(logits, dim=-1).item()
  382. >>> predicted_label = model.config.id2label[predicted_class_ids]
  383. >>> predicted_label
  384. {expected_output}
  385. >>> # compute loss - target_label is e.g. "down"
  386. >>> target_label = model.config.id2label[0]
  387. >>> inputs["labels"] = torch.tensor([model.config.label2id[target_label]])
  388. >>> loss = model(**inputs).loss
  389. >>> round(loss.item(), 2)
  390. {expected_loss}
  391. ```
  392. """
  393. PT_SPEECH_FRAME_CLASS_SAMPLE = r"""
  394. Example:
  395. ```python
  396. >>> from transformers import AutoFeatureExtractor, {model_class}
  397. >>> from datasets import load_dataset
  398. >>> import torch
  399. >>> dataset = load_dataset("hf-internal-testing/librispeech_asr_demo", "clean", split="validation")
  400. >>> dataset = dataset.sort("id")
  401. >>> sampling_rate = dataset.features["audio"].sampling_rate
  402. >>> feature_extractor = AutoFeatureExtractor.from_pretrained("{checkpoint}")
  403. >>> model = {model_class}.from_pretrained("{checkpoint}")
  404. >>> # audio file is decoded on the fly
  405. >>> inputs = feature_extractor(dataset[0]["audio"]["array"], return_tensors="pt", sampling_rate=sampling_rate)
  406. >>> with torch.no_grad():
  407. ... logits = model(**inputs).logits
  408. >>> probabilities = torch.sigmoid(logits[0])
  409. >>> # labels is a one-hot array of shape (num_frames, num_speakers)
  410. >>> labels = (probabilities > 0.5).long()
  411. >>> labels[0].tolist()
  412. {expected_output}
  413. ```
  414. """
  415. PT_SPEECH_XVECTOR_SAMPLE = r"""
  416. Example:
  417. ```python
  418. >>> from transformers import AutoFeatureExtractor, {model_class}
  419. >>> from datasets import load_dataset
  420. >>> import torch
  421. >>> dataset = load_dataset("hf-internal-testing/librispeech_asr_demo", "clean", split="validation")
  422. >>> dataset = dataset.sort("id")
  423. >>> sampling_rate = dataset.features["audio"].sampling_rate
  424. >>> feature_extractor = AutoFeatureExtractor.from_pretrained("{checkpoint}")
  425. >>> model = {model_class}.from_pretrained("{checkpoint}")
  426. >>> # audio file is decoded on the fly
  427. >>> inputs = feature_extractor(
  428. ... [d["array"] for d in dataset[:2]["audio"]], sampling_rate=sampling_rate, return_tensors="pt", padding=True
  429. ... )
  430. >>> with torch.no_grad():
  431. ... embeddings = model(**inputs).embeddings
  432. >>> embeddings = torch.nn.functional.normalize(embeddings, dim=-1).cpu()
  433. >>> # the resulting embeddings can be used for cosine similarity-based retrieval
  434. >>> cosine_sim = torch.nn.CosineSimilarity(dim=-1)
  435. >>> similarity = cosine_sim(embeddings[0], embeddings[1])
  436. >>> threshold = 0.7 # the optimal threshold is dataset-dependent
  437. >>> if similarity < threshold:
  438. ... print("Speakers are not the same!")
  439. >>> round(similarity.item(), 2)
  440. {expected_output}
  441. ```
  442. """
  443. PT_VISION_BASE_MODEL_SAMPLE = r"""
  444. Example:
  445. ```python
  446. >>> from transformers import AutoImageProcessor, {model_class}
  447. >>> import torch
  448. >>> from datasets import load_dataset
  449. >>> dataset = load_dataset("huggingface/cats-image")
  450. >>> image = dataset["test"]["image"][0]
  451. >>> image_processor = AutoImageProcessor.from_pretrained("{checkpoint}")
  452. >>> model = {model_class}.from_pretrained("{checkpoint}")
  453. >>> inputs = image_processor(image, return_tensors="pt")
  454. >>> with torch.no_grad():
  455. ... outputs = model(**inputs)
  456. >>> last_hidden_states = outputs.last_hidden_state
  457. >>> list(last_hidden_states.shape)
  458. {expected_output}
  459. ```
  460. """
  461. PT_VISION_SEQ_CLASS_SAMPLE = r"""
  462. Example:
  463. ```python
  464. >>> from transformers import AutoImageProcessor, {model_class}
  465. >>> import torch
  466. >>> from datasets import load_dataset
  467. >>> dataset = load_dataset("huggingface/cats-image")
  468. >>> image = dataset["test"]["image"][0]
  469. >>> image_processor = AutoImageProcessor.from_pretrained("{checkpoint}")
  470. >>> model = {model_class}.from_pretrained("{checkpoint}")
  471. >>> inputs = image_processor(image, return_tensors="pt")
  472. >>> with torch.no_grad():
  473. ... logits = model(**inputs).logits
  474. >>> # model predicts one of the 1000 ImageNet classes
  475. >>> predicted_label = logits.argmax(-1).item()
  476. >>> print(model.config.id2label[predicted_label])
  477. {expected_output}
  478. ```
  479. """
# Lookup table from a model-head key (e.g. "SequenceClassification", "CTC", "BaseModel") to the PyTorch example
# template defined for it in this module.
# NOTE(review): presumably the caller derives the key from a suffix of the model class name; that code is outside
# this view — confirm.
PT_SAMPLE_DOCSTRINGS = {
    "SequenceClassification": PT_SEQUENCE_CLASSIFICATION_SAMPLE,
    "QuestionAnswering": PT_QUESTION_ANSWERING_SAMPLE,
    "TokenClassification": PT_TOKEN_CLASSIFICATION_SAMPLE,
    "MultipleChoice": PT_MULTIPLE_CHOICE_SAMPLE,
    "MaskedLM": PT_MASKED_LM_SAMPLE,
    "LMHead": PT_CAUSAL_LM_SAMPLE,
    "BaseModel": PT_BASE_MODEL_SAMPLE,
    "SpeechBaseModel": PT_SPEECH_BASE_MODEL_SAMPLE,
    "CTC": PT_SPEECH_CTC_SAMPLE,
    "AudioClassification": PT_SPEECH_SEQ_CLASS_SAMPLE,
    "AudioFrameClassification": PT_SPEECH_FRAME_CLASS_SAMPLE,
    "AudioXVector": PT_SPEECH_XVECTOR_SAMPLE,
    "VisionBaseModel": PT_VISION_BASE_MODEL_SAMPLE,
    "ImageClassification": PT_VISION_SEQ_CLASS_SAMPLE,
}
  496. TEXT_TO_AUDIO_SPECTROGRAM_SAMPLE = r"""
  497. Example:
  498. ```python
  499. >>> from transformers import AutoProcessor, {model_class}, SpeechT5HifiGan
  500. >>> model = {model_class}.from_pretrained("{checkpoint}")
  501. >>> processor = AutoProcessor.from_pretrained("{checkpoint}")
  502. >>> vocoder = SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan")
  503. >>> inputs = processor(text="Hello, my dog is cute", return_tensors="pt")
  504. >>> # generate speech
  505. >>> speech = model.generate(inputs["input_ids"], speaker_embeddings=speaker_embeddings, vocoder=vocoder)
  506. ```
  507. """
# Text-to-speech (direct waveform) example template. Fields: `{model_class}` and `{checkpoint}`.
TEXT_TO_AUDIO_WAVEFORM_SAMPLE = r"""
Example:
```python
>>> from transformers import AutoProcessor, {model_class}
>>> model = {model_class}.from_pretrained("{checkpoint}")
>>> processor = AutoProcessor.from_pretrained("{checkpoint}")
>>> inputs = processor(text="Hello, my dog is cute", return_tensors="pt")
>>> # generate speech
>>> speech = model(inputs["input_ids"])
```
"""
# Task-named aliases that reuse the PyTorch speech templates.
AUDIO_FRAME_CLASSIFICATION_SAMPLE = PT_SPEECH_FRAME_CLASS_SAMPLE
AUDIO_XVECTOR_SAMPLE = PT_SPEECH_XVECTOR_SAMPLE
# Image-to-text example template (processor + model forward on a COCO image). Fields: `{model_class}` and
# `{checkpoint}`.
IMAGE_TO_TEXT_SAMPLE = r"""
Example:
```python
>>> from PIL import Image
>>> import requests
>>> from transformers import AutoProcessor, {model_class}
>>> processor = AutoProcessor.from_pretrained("{checkpoint}")
>>> model = {model_class}.from_pretrained("{checkpoint}")
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> inputs = processor(images=image, return_tensors="pt")
>>> outputs = model(**inputs)
```
"""
  535. DEPTH_ESTIMATION_SAMPLE = r"""
  536. Example:
  537. ```python
  538. >>> from transformers import AutoImageProcessor, {model_class}
  539. >>> import torch
  540. >>> from PIL import Image
  541. >>> import requests
  542. >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
  543. >>> image = Image.open(requests.get(url, stream=True).raw)
  544. >>> processor = AutoImageProcessor.from_pretrained("{checkpoint}")
  545. >>> model = {model_class}.from_pretrained("{checkpoint}")
  546. >>> device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
  547. >>> model.to(device)
  548. >>> # prepare image for the model
  549. >>> inputs = processor(images=image, return_tensors="pt").to(device)
  550. >>> with torch.no_grad():
  551. ... outputs = model(**inputs)
  552. >>> # interpolate to original size
  553. >>> post_processed_output = processor.post_process_depth_estimation(
  554. ... outputs, [(image.height, image.width)],
  555. ... )
  556. >>> predicted_depth = post_processed_output[0]["predicted_depth"]
  557. ```
  558. """
# Task-level example templates. Entries that are string literals below are placeholders: each holds only an
# "Example:" header and an empty ```python fence, with no example code yet.
# TODO: fill in real examples. The other entries are aliases that reuse the PT_* templates defined in this module.
VIDEO_CLASSIFICATION_SAMPLE = r"""
Example:
```python
```
"""
ZERO_SHOT_OBJECT_DETECTION_SAMPLE = r"""
Example:
```python
```
"""
IMAGE_TO_IMAGE_SAMPLE = r"""
Example:
```python
```
"""
IMAGE_FEATURE_EXTRACTION_SAMPLE = r"""
Example:
```python
```
"""
DOCUMENT_QUESTION_ANSWERING_SAMPLE = r"""
Example:
```python
```
"""
NEXT_SENTENCE_PREDICTION_SAMPLE = r"""
Example:
```python
```
"""
# Alias: reuses the PyTorch multiple-choice template.
MULTIPLE_CHOICE_SAMPLE = PT_MULTIPLE_CHOICE_SAMPLE
PRETRAINING_SAMPLE = r"""
Example:
```python
```
"""
MASK_GENERATION_SAMPLE = r"""
Example:
```python
```
"""
VISUAL_QUESTION_ANSWERING_SAMPLE = r"""
Example:
```python
```
"""
TEXT_GENERATION_SAMPLE = r"""
Example:
```python
```
"""
# Alias: reuses the PyTorch image-classification template.
IMAGE_CLASSIFICATION_SAMPLE = PT_VISION_SEQ_CLASS_SAMPLE
IMAGE_SEGMENTATION_SAMPLE = r"""
Example:
```python
```
"""
FILL_MASK_SAMPLE = r"""
Example:
```python
```
"""
OBJECT_DETECTION_SAMPLE = r"""
Example:
```python
```
"""
# Alias: reuses the PyTorch question-answering template.
QUESTION_ANSWERING_SAMPLE = PT_QUESTION_ANSWERING_SAMPLE
TEXT2TEXT_GENERATION_SAMPLE = r"""
Example:
```python
```
"""
# Alias: reuses the PyTorch sequence-classification template.
TEXT_CLASSIFICATION_SAMPLE = PT_SEQUENCE_CLASSIFICATION_SAMPLE
TABLE_QUESTION_ANSWERING_SAMPLE = r"""
Example:
```python
```
"""
# Aliases: reuse the PyTorch token-classification, audio-classification and CTC templates.
TOKEN_CLASSIFICATION_SAMPLE = PT_TOKEN_CLASSIFICATION_SAMPLE
AUDIO_CLASSIFICATION_SAMPLE = PT_SPEECH_SEQ_CLASS_SAMPLE
AUTOMATIC_SPEECH_RECOGNITION_SAMPLE = PT_SPEECH_CTC_SAMPLE
ZERO_SHOT_IMAGE_CLASSIFICATION_SAMPLE = r"""
Example:
```python
```
"""
  646. IMAGE_TEXT_TO_TEXT_GENERATION_SAMPLE = r"""
  647. Example:
  648. ```python
  649. >>> from PIL import Image
  650. >>> import requests
  651. >>> from transformers import AutoProcessor, {model_class}
  652. >>> model = {model_class}.from_pretrained("{checkpoint}")
  653. >>> processor = AutoProcessor.from_pretrained("{checkpoint}")
  654. >>> messages = [
  655. ... {{
  656. ... "role": "user", "content": [
  657. ... {{"type": "image", "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg"}},
  658. ... {{"type": "text", "text": "Where is the cat standing?"}},
  659. ... ]
  660. ... }},
  661. ... ]
  662. >>> inputs = processor.apply_chat_template(
  663. ... messages,
  664. ... tokenize=True,
  665. ... return_dict=True,
  666. ... return_tensors="pt",
  667. ... add_generation_prompt=True
  668. ... )
  669. >>> # Generate
  670. >>> generate_ids = model.generate(**inputs)
  671. >>> processor.batch_decode(generate_ids, skip_special_tokens=True)[0]
  672. ```
  673. """
# Sample template for each pipeline task, keyed by task name.
# The keys are exactly the task names used as values in MODELS_TO_PIPELINE, listed in the same order.
PIPELINE_TASKS_TO_SAMPLE_DOCSTRINGS = OrderedDict(
    [
        ("text-to-audio-spectrogram", TEXT_TO_AUDIO_SPECTROGRAM_SAMPLE),
        ("text-to-audio-waveform", TEXT_TO_AUDIO_WAVEFORM_SAMPLE),
        ("automatic-speech-recognition", AUTOMATIC_SPEECH_RECOGNITION_SAMPLE),
        ("audio-frame-classification", AUDIO_FRAME_CLASSIFICATION_SAMPLE),
        ("audio-classification", AUDIO_CLASSIFICATION_SAMPLE),
        ("audio-xvector", AUDIO_XVECTOR_SAMPLE),
        ("image-text-to-text", IMAGE_TEXT_TO_TEXT_GENERATION_SAMPLE),
        ("image-to-text", IMAGE_TO_TEXT_SAMPLE),
        ("visual-question-answering", VISUAL_QUESTION_ANSWERING_SAMPLE),
        ("depth-estimation", DEPTH_ESTIMATION_SAMPLE),
        ("video-classification", VIDEO_CLASSIFICATION_SAMPLE),
        ("zero-shot-image-classification", ZERO_SHOT_IMAGE_CLASSIFICATION_SAMPLE),
        ("image-classification", IMAGE_CLASSIFICATION_SAMPLE),
        ("zero-shot-object-detection", ZERO_SHOT_OBJECT_DETECTION_SAMPLE),
        ("object-detection", OBJECT_DETECTION_SAMPLE),
        ("image-segmentation", IMAGE_SEGMENTATION_SAMPLE),
        ("image-to-image", IMAGE_TO_IMAGE_SAMPLE),
        ("image-feature-extraction", IMAGE_FEATURE_EXTRACTION_SAMPLE),
        ("text-generation", TEXT_GENERATION_SAMPLE),
        ("table-question-answering", TABLE_QUESTION_ANSWERING_SAMPLE),
        ("document-question-answering", DOCUMENT_QUESTION_ANSWERING_SAMPLE),
        ("question-answering", QUESTION_ANSWERING_SAMPLE),
        ("text2text-generation", TEXT2TEXT_GENERATION_SAMPLE),
        ("next-sentence-prediction", NEXT_SENTENCE_PREDICTION_SAMPLE),
        ("multiple-choice", MULTIPLE_CHOICE_SAMPLE),
        ("text-classification", TEXT_CLASSIFICATION_SAMPLE),
        ("token-classification", TOKEN_CLASSIFICATION_SAMPLE),
        ("fill-mask", FILL_MASK_SAMPLE),
        ("mask-generation", MASK_GENERATION_SAMPLE),
        ("pretraining", PRETRAINING_SAMPLE),
    ]
)
# Maps the name of a model-mapping constant (e.g. "MODEL_FOR_CAUSAL_LM_MAPPING_NAMES") to the pipeline task
# it serves. Several mappings may share a task (seq2seq and CTC models both map to
# "automatic-speech-recognition"). Every task named here has a template in PIPELINE_TASKS_TO_SAMPLE_DOCSTRINGS.
# Ordered dict to look for more specialized model mappings first
# before falling back to the more generic ones.
MODELS_TO_PIPELINE = OrderedDict(
    [
        # Audio
        ("MODEL_FOR_TEXT_TO_SPECTROGRAM_MAPPING_NAMES", "text-to-audio-spectrogram"),
        ("MODEL_FOR_TEXT_TO_WAVEFORM_MAPPING_NAMES", "text-to-audio-waveform"),
        ("MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES", "automatic-speech-recognition"),
        ("MODEL_FOR_CTC_MAPPING_NAMES", "automatic-speech-recognition"),
        ("MODEL_FOR_AUDIO_FRAME_CLASSIFICATION_MAPPING_NAMES", "audio-frame-classification"),
        ("MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES", "audio-classification"),
        ("MODEL_FOR_AUDIO_XVECTOR_MAPPING_NAMES", "audio-xvector"),
        # Vision
        ("MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES", "image-text-to-text"),
        ("MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES", "image-to-text"),
        ("MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING_NAMES", "visual-question-answering"),
        ("MODEL_FOR_DEPTH_ESTIMATION_MAPPING_NAMES", "depth-estimation"),
        ("MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING_NAMES", "video-classification"),
        ("MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES", "zero-shot-image-classification"),
        ("MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES", "image-classification"),
        ("MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING_NAMES", "zero-shot-object-detection"),
        ("MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES", "object-detection"),
        ("MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES", "image-segmentation"),
        ("MODEL_FOR_IMAGE_TO_IMAGE_MAPPING_NAMES", "image-to-image"),
        ("MODEL_FOR_IMAGE_MAPPING_NAMES", "image-feature-extraction"),
        # Text/tokens
        ("MODEL_FOR_CAUSAL_LM_MAPPING_NAMES", "text-generation"),
        ("MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING_NAMES", "table-question-answering"),
        ("MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING_NAMES", "document-question-answering"),
        ("MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES", "question-answering"),
        ("MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES", "text2text-generation"),
        ("MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES", "next-sentence-prediction"),
        ("MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES", "multiple-choice"),
        ("MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES", "text-classification"),
        ("MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES", "token-classification"),
        ("MODEL_FOR_MASKED_LM_MAPPING_NAMES", "fill-mask"),
        ("MODEL_FOR_MASK_GENERATION_MAPPING_NAMES", "mask-generation"),
        ("MODEL_FOR_PRETRAINING_MAPPING_NAMES", "pretraining"),
    ]
)
  748. TF_TOKEN_CLASSIFICATION_SAMPLE = r"""
  749. Example:
  750. ```python
  751. >>> from transformers import AutoTokenizer, {model_class}
  752. >>> import tensorflow as tf
  753. >>> tokenizer = AutoTokenizer.from_pretrained("{checkpoint}")
  754. >>> model = {model_class}.from_pretrained("{checkpoint}")
  755. >>> inputs = tokenizer(
  756. ... "HuggingFace is a company based in Paris and New York", add_special_tokens=False, return_tensors="tf"
  757. ... )
  758. >>> logits = model(**inputs).logits
  759. >>> predicted_token_class_ids = tf.math.argmax(logits, axis=-1)
  760. >>> # Note that tokens are classified rather then input words which means that
  761. >>> # there might be more predicted token classes than words.
  762. >>> # Multiple token classes might account for the same word
  763. >>> predicted_tokens_classes = [model.config.id2label[t] for t in predicted_token_class_ids[0].numpy().tolist()]
  764. >>> predicted_tokens_classes
  765. {expected_output}
  766. ```
  767. ```python
  768. >>> labels = predicted_token_class_ids
  769. >>> loss = tf.math.reduce_mean(model(**inputs, labels=labels).loss)
  770. >>> round(float(loss), 2)
  771. {expected_loss}
  772. ```
  773. """
# TensorFlow extractive QA sample: decodes the argmax answer span, then computes a loss against the
# `{qa_target_start_index}` / `{qa_target_end_index}` span (add_code_sample_docstrings defaults them to 14 / 15).
TF_QUESTION_ANSWERING_SAMPLE = r"""
Example:
```python
>>> from transformers import AutoTokenizer, {model_class}
>>> import tensorflow as tf
>>> tokenizer = AutoTokenizer.from_pretrained("{checkpoint}")
>>> model = {model_class}.from_pretrained("{checkpoint}")
>>> question, text = "Who was Jim Henson?", "Jim Henson was a nice puppet"
>>> inputs = tokenizer(question, text, return_tensors="tf")
>>> outputs = model(**inputs)
>>> answer_start_index = int(tf.math.argmax(outputs.start_logits, axis=-1)[0])
>>> answer_end_index = int(tf.math.argmax(outputs.end_logits, axis=-1)[0])
>>> predict_answer_tokens = inputs.input_ids[0, answer_start_index : answer_end_index + 1]
>>> tokenizer.decode(predict_answer_tokens)
{expected_output}
```
```python
>>> # target is "nice puppet"
>>> target_start_index = tf.constant([{qa_target_start_index}])
>>> target_end_index = tf.constant([{qa_target_end_index}])
>>> outputs = model(**inputs, start_positions=target_start_index, end_positions=target_end_index)
>>> loss = tf.math.reduce_mean(outputs.loss)
>>> round(float(loss), 2)
{expected_loss}
```
"""
# TensorFlow sequence-classification sample: predicts a label, then reloads the model with an explicit
# `num_labels` and computes a loss for label 1.
TF_SEQUENCE_CLASSIFICATION_SAMPLE = r"""
Example:
```python
>>> from transformers import AutoTokenizer, {model_class}
>>> import tensorflow as tf
>>> tokenizer = AutoTokenizer.from_pretrained("{checkpoint}")
>>> model = {model_class}.from_pretrained("{checkpoint}")
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="tf")
>>> logits = model(**inputs).logits
>>> predicted_class_id = int(tf.math.argmax(logits, axis=-1)[0])
>>> model.config.id2label[predicted_class_id]
{expected_output}
```
```python
>>> # To train a model on `num_labels` classes, you can pass `num_labels=num_labels` to `.from_pretrained(...)`
>>> num_labels = len(model.config.id2label)
>>> model = {model_class}.from_pretrained("{checkpoint}", num_labels=num_labels)
>>> labels = tf.constant(1)
>>> loss = model(**inputs, labels=labels).loss
>>> round(float(loss), 2)
{expected_loss}
```
"""
# TensorFlow masked-LM sample: `{mask}` is the mask token string (add_code_sample_docstrings defaults it to
# "[MASK]"). Predicts the masked word, then computes a loss only on the masked position (other labels -100).
TF_MASKED_LM_SAMPLE = r"""
Example:
```python
>>> from transformers import AutoTokenizer, {model_class}
>>> import tensorflow as tf
>>> tokenizer = AutoTokenizer.from_pretrained("{checkpoint}")
>>> model = {model_class}.from_pretrained("{checkpoint}")
>>> inputs = tokenizer("The capital of France is {mask}.", return_tensors="tf")
>>> logits = model(**inputs).logits
>>> # retrieve index of {mask}
>>> mask_token_index = tf.where((inputs.input_ids == tokenizer.mask_token_id)[0])
>>> selected_logits = tf.gather_nd(logits[0], indices=mask_token_index)
>>> predicted_token_id = tf.math.argmax(selected_logits, axis=-1)
>>> tokenizer.decode(predicted_token_id)
{expected_output}
```
```python
>>> labels = tokenizer("The capital of France is Paris.", return_tensors="tf")["input_ids"]
>>> # mask labels of non-{mask} tokens
>>> labels = tf.where(inputs.input_ids == tokenizer.mask_token_id, labels, -100)
>>> outputs = model(**inputs, labels=labels)
>>> round(float(outputs.loss), 2)
{expected_loss}
```
"""
# TensorFlow base-model sample: a text forward pass exposing `last_hidden_state`.
TF_BASE_MODEL_SAMPLE = r"""
Example:
```python
>>> from transformers import AutoTokenizer, {model_class}
>>> import tensorflow as tf
>>> tokenizer = AutoTokenizer.from_pretrained("{checkpoint}")
>>> model = {model_class}.from_pretrained("{checkpoint}")
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="tf")
>>> outputs = model(inputs)
>>> last_hidden_states = outputs.last_hidden_state
```
"""
# TensorFlow multiple-choice sample: encodes two (prompt, choice) pairs and adds a leading batch axis with
# `tf.expand_dims`. The comprehension braces are doubled so the rendered example shows single braces.
TF_MULTIPLE_CHOICE_SAMPLE = r"""
Example:
```python
>>> from transformers import AutoTokenizer, {model_class}
>>> import tensorflow as tf
>>> tokenizer = AutoTokenizer.from_pretrained("{checkpoint}")
>>> model = {model_class}.from_pretrained("{checkpoint}")
>>> prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
>>> choice0 = "It is eaten with a fork and a knife."
>>> choice1 = "It is eaten while held in the hand."
>>> encoding = tokenizer([prompt, prompt], [choice0, choice1], return_tensors="tf", padding=True)
>>> inputs = {{k: tf.expand_dims(v, 0) for k, v in encoding.items()}}
>>> outputs = model(inputs) # batch size is 1
>>> # the linear classifier still needs to be trained
>>> logits = outputs.logits
```
"""
# TensorFlow causal-LM sample: a forward pass exposing the LM logits.
TF_CAUSAL_LM_SAMPLE = r"""
Example:
```python
>>> from transformers import AutoTokenizer, {model_class}
>>> import tensorflow as tf
>>> tokenizer = AutoTokenizer.from_pretrained("{checkpoint}")
>>> model = {model_class}.from_pretrained("{checkpoint}")
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="tf")
>>> outputs = model(inputs)
>>> logits = outputs.logits
```
"""
# TensorFlow speech base-model sample: encodes the first sorted sample of the
# `hf-internal-testing/librispeech_asr_demo` validation split and shows the hidden-state shape.
TF_SPEECH_BASE_MODEL_SAMPLE = r"""
Example:
```python
>>> from transformers import AutoProcessor, {model_class}
>>> from datasets import load_dataset
>>> dataset = load_dataset("hf-internal-testing/librispeech_asr_demo", "clean", split="validation")
>>> dataset = dataset.sort("id")
>>> sampling_rate = dataset.features["audio"].sampling_rate
>>> processor = AutoProcessor.from_pretrained("{checkpoint}")
>>> model = {model_class}.from_pretrained("{checkpoint}")
>>> # audio file is decoded on the fly
>>> inputs = processor(dataset[0]["audio"]["array"], sampling_rate=sampling_rate, return_tensors="tf")
>>> outputs = model(**inputs)
>>> last_hidden_states = outputs.last_hidden_state
>>> list(last_hidden_states.shape)
{expected_output}
```
"""
# TensorFlow CTC sample: greedy (argmax) transcription of the same librispeech sample, then a loss against
# its reference text.
TF_SPEECH_CTC_SAMPLE = r"""
Example:
```python
>>> from transformers import AutoProcessor, {model_class}
>>> from datasets import load_dataset
>>> import tensorflow as tf
>>> dataset = load_dataset("hf-internal-testing/librispeech_asr_demo", "clean", split="validation")
>>> dataset = dataset.sort("id")
>>> sampling_rate = dataset.features["audio"].sampling_rate
>>> processor = AutoProcessor.from_pretrained("{checkpoint}")
>>> model = {model_class}.from_pretrained("{checkpoint}")
>>> # audio file is decoded on the fly
>>> inputs = processor(dataset[0]["audio"]["array"], sampling_rate=sampling_rate, return_tensors="tf")
>>> logits = model(**inputs).logits
>>> predicted_ids = tf.math.argmax(logits, axis=-1)
>>> # transcribe speech
>>> transcription = processor.batch_decode(predicted_ids)
>>> transcription[0]
{expected_output}
```
```python
>>> inputs["labels"] = processor(text=dataset[0]["text"], return_tensors="tf").input_ids
>>> # compute loss
>>> loss = model(**inputs).loss
>>> round(float(loss), 2)
{expected_loss}
```
"""
# TensorFlow vision base-model sample: encodes the `huggingface/cats-image` test image and shows the
# hidden-state shape.
TF_VISION_BASE_MODEL_SAMPLE = r"""
Example:
```python
>>> from transformers import AutoImageProcessor, {model_class}
>>> from datasets import load_dataset
>>> dataset = load_dataset("huggingface/cats-image")
>>> image = dataset["test"]["image"][0]
>>> image_processor = AutoImageProcessor.from_pretrained("{checkpoint}")
>>> model = {model_class}.from_pretrained("{checkpoint}")
>>> inputs = image_processor(image, return_tensors="tf")
>>> outputs = model(**inputs)
>>> last_hidden_states = outputs.last_hidden_state
>>> list(last_hidden_states.shape)
{expected_output}
```
"""
  951. TF_VISION_SEQ_CLASS_SAMPLE = r"""
  952. Example:
  953. ```python
  954. >>> from transformers import AutoImageProcessor, {model_class}
  955. >>> import tensorflow as tf
  956. >>> from datasets import load_dataset
  957. >>> dataset = load_dataset("huggingface/cats-image"))
  958. >>> image = dataset["test"]["image"][0]
  959. >>> image_processor = AutoImageProcessor.from_pretrained("{checkpoint}")
  960. >>> model = {model_class}.from_pretrained("{checkpoint}")
  961. >>> inputs = image_processor(image, return_tensors="tf")
  962. >>> logits = model(**inputs).logits
  963. >>> # model predicts one of the 1000 ImageNet classes
  964. >>> predicted_label = int(tf.math.argmax(logits, axis=-1))
  965. >>> print(model.config.id2label[predicted_label])
  966. {expected_output}
  967. ```
  968. """
# TensorFlow code-sample template for each sample kind; add_code_sample_docstrings uses this table for
# classes whose name starts with "TF" and picks the key from the class name and `modality`.
# NOTE: there are no "AudioClassification", "AudioFrameClassification" or "AudioXVector" entries, so a TF
# class routed to one of those kinds fails with a KeyError.
TF_SAMPLE_DOCSTRINGS = {
    "SequenceClassification": TF_SEQUENCE_CLASSIFICATION_SAMPLE,
    "QuestionAnswering": TF_QUESTION_ANSWERING_SAMPLE,
    "TokenClassification": TF_TOKEN_CLASSIFICATION_SAMPLE,
    "MultipleChoice": TF_MULTIPLE_CHOICE_SAMPLE,
    "MaskedLM": TF_MASKED_LM_SAMPLE,
    "LMHead": TF_CAUSAL_LM_SAMPLE,
    "BaseModel": TF_BASE_MODEL_SAMPLE,
    "SpeechBaseModel": TF_SPEECH_BASE_MODEL_SAMPLE,
    "CTC": TF_SPEECH_CTC_SAMPLE,
    "VisionBaseModel": TF_VISION_BASE_MODEL_SAMPLE,
    "ImageClassification": TF_VISION_SEQ_CLASS_SAMPLE,
}
# Flax code-sample templates. Unlike the TensorFlow ones they only run a forward pass (no expected
# outputs or losses) and request JAX tensors from the tokenizer.
# Flax token-classification sample: forward pass exposing per-token logits.
FLAX_TOKEN_CLASSIFICATION_SAMPLE = r"""
Example:
```python
>>> from transformers import AutoTokenizer, {model_class}
>>> tokenizer = AutoTokenizer.from_pretrained("{checkpoint}")
>>> model = {model_class}.from_pretrained("{checkpoint}")
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="jax")
>>> outputs = model(**inputs)
>>> logits = outputs.logits
```
"""
# Flax extractive QA sample: forward pass exposing start/end span logits.
FLAX_QUESTION_ANSWERING_SAMPLE = r"""
Example:
```python
>>> from transformers import AutoTokenizer, {model_class}
>>> tokenizer = AutoTokenizer.from_pretrained("{checkpoint}")
>>> model = {model_class}.from_pretrained("{checkpoint}")
>>> question, text = "Who was Jim Henson?", "Jim Henson was a nice puppet"
>>> inputs = tokenizer(question, text, return_tensors="jax")
>>> outputs = model(**inputs)
>>> start_scores = outputs.start_logits
>>> end_scores = outputs.end_logits
```
"""
# Flax sequence-classification sample: forward pass exposing the class logits.
FLAX_SEQUENCE_CLASSIFICATION_SAMPLE = r"""
Example:
```python
>>> from transformers import AutoTokenizer, {model_class}
>>> tokenizer = AutoTokenizer.from_pretrained("{checkpoint}")
>>> model = {model_class}.from_pretrained("{checkpoint}")
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="jax")
>>> outputs = model(**inputs)
>>> logits = outputs.logits
```
"""
# Flax masked-LM sample: `{mask}` is the mask token string (default "[MASK]").
FLAX_MASKED_LM_SAMPLE = r"""
Example:
```python
>>> from transformers import AutoTokenizer, {model_class}
>>> tokenizer = AutoTokenizer.from_pretrained("{checkpoint}")
>>> model = {model_class}.from_pretrained("{checkpoint}")
>>> inputs = tokenizer("The capital of France is {mask}.", return_tensors="jax")
>>> outputs = model(**inputs)
>>> logits = outputs.logits
```
"""
# Flax base-model sample: forward pass exposing `last_hidden_state`.
FLAX_BASE_MODEL_SAMPLE = r"""
Example:
```python
>>> from transformers import AutoTokenizer, {model_class}
>>> tokenizer = AutoTokenizer.from_pretrained("{checkpoint}")
>>> model = {model_class}.from_pretrained("{checkpoint}")
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="jax")
>>> outputs = model(**inputs)
>>> last_hidden_states = outputs.last_hidden_state
```
"""
# Flax multiple-choice sample: adds a leading batch axis with `v[None, :]`. The comprehension braces are
# doubled so the rendered example shows single braces.
FLAX_MULTIPLE_CHOICE_SAMPLE = r"""
Example:
```python
>>> from transformers import AutoTokenizer, {model_class}
>>> tokenizer = AutoTokenizer.from_pretrained("{checkpoint}")
>>> model = {model_class}.from_pretrained("{checkpoint}")
>>> prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
>>> choice0 = "It is eaten with a fork and a knife."
>>> choice1 = "It is eaten while held in the hand."
>>> encoding = tokenizer([prompt, prompt], [choice0, choice1], return_tensors="jax", padding=True)
>>> outputs = model(**{{k: v[None, :] for k, v in encoding.items()}})
>>> logits = outputs.logits
```
"""
  1053. FLAX_CAUSAL_LM_SAMPLE = r"""
  1054. Example:
  1055. ```python
  1056. >>> from transformers import AutoTokenizer, {model_class}
  1057. >>> tokenizer = AutoTokenizer.from_pretrained("{checkpoint}")
  1058. >>> model = {model_class}.from_pretrained("{checkpoint}")
  1059. >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="np")
  1060. >>> outputs = model(**inputs)
  1061. >>> # retrieve logts for next token
  1062. >>> next_token_logits = outputs.logits[:, -1]
  1063. ```
  1064. """
# Flax code-sample template for each sample kind; add_code_sample_docstrings uses this table for classes
# whose name starts with "Flax".
# NOTE: only text kinds are covered (no audio, CTC, vision or image-classification entries), so Flax
# classes routed to those kinds fail with a KeyError.
FLAX_SAMPLE_DOCSTRINGS = {
    "SequenceClassification": FLAX_SEQUENCE_CLASSIFICATION_SAMPLE,
    "QuestionAnswering": FLAX_QUESTION_ANSWERING_SAMPLE,
    "TokenClassification": FLAX_TOKEN_CLASSIFICATION_SAMPLE,
    "MultipleChoice": FLAX_MULTIPLE_CHOICE_SAMPLE,
    "MaskedLM": FLAX_MASKED_LM_SAMPLE,
    "BaseModel": FLAX_BASE_MODEL_SAMPLE,
    "LMHead": FLAX_CAUSAL_LM_SAMPLE,
}
  1074. def filter_outputs_from_example(docstring, **kwargs):
  1075. """
  1076. Removes the lines testing an output with the doctest syntax in a code sample when it's set to `None`.
  1077. """
  1078. for key, value in kwargs.items():
  1079. if value is not None:
  1080. continue
  1081. doc_key = "{" + key + "}"
  1082. docstring = re.sub(rf"\n([^\n]+)\n\s+{doc_key}\n", "\n", docstring)
  1083. return docstring
  1084. def add_code_sample_docstrings(
  1085. *docstr,
  1086. processor_class=None,
  1087. checkpoint=None,
  1088. output_type=None,
  1089. config_class=None,
  1090. mask="[MASK]",
  1091. qa_target_start_index=14,
  1092. qa_target_end_index=15,
  1093. model_cls=None,
  1094. modality=None,
  1095. expected_output=None,
  1096. expected_loss=None,
  1097. real_checkpoint=None,
  1098. revision=None,
  1099. ):
  1100. def docstring_decorator(fn):
  1101. # model_class defaults to function's class if not specified otherwise
  1102. model_class = fn.__qualname__.split(".")[0] if model_cls is None else model_cls
  1103. if model_class[:2] == "TF":
  1104. sample_docstrings = TF_SAMPLE_DOCSTRINGS
  1105. elif model_class[:4] == "Flax":
  1106. sample_docstrings = FLAX_SAMPLE_DOCSTRINGS
  1107. else:
  1108. sample_docstrings = PT_SAMPLE_DOCSTRINGS
  1109. # putting all kwargs for docstrings in a dict to be used
  1110. # with the `.format(**doc_kwargs)`. Note that string might
  1111. # be formatted with non-existing keys, which is fine.
  1112. doc_kwargs = {
  1113. "model_class": model_class,
  1114. "processor_class": processor_class,
  1115. "checkpoint": checkpoint,
  1116. "mask": mask,
  1117. "qa_target_start_index": qa_target_start_index,
  1118. "qa_target_end_index": qa_target_end_index,
  1119. "expected_output": expected_output,
  1120. "expected_loss": expected_loss,
  1121. "real_checkpoint": real_checkpoint,
  1122. "fake_checkpoint": checkpoint,
  1123. "true": "{true}", # For <Tip warning={true}> syntax that conflicts with formatting.
  1124. }
  1125. if ("SequenceClassification" in model_class or "AudioClassification" in model_class) and modality == "audio":
  1126. code_sample = sample_docstrings["AudioClassification"]
  1127. elif "SequenceClassification" in model_class:
  1128. code_sample = sample_docstrings["SequenceClassification"]
  1129. elif "QuestionAnswering" in model_class:
  1130. code_sample = sample_docstrings["QuestionAnswering"]
  1131. elif "TokenClassification" in model_class:
  1132. code_sample = sample_docstrings["TokenClassification"]
  1133. elif "MultipleChoice" in model_class:
  1134. code_sample = sample_docstrings["MultipleChoice"]
  1135. elif "MaskedLM" in model_class or model_class in ["FlaubertWithLMHeadModel", "XLMWithLMHeadModel"]:
  1136. code_sample = sample_docstrings["MaskedLM"]
  1137. elif "LMHead" in model_class or "CausalLM" in model_class:
  1138. code_sample = sample_docstrings["LMHead"]
  1139. elif "CTC" in model_class:
  1140. code_sample = sample_docstrings["CTC"]
  1141. elif "AudioFrameClassification" in model_class:
  1142. code_sample = sample_docstrings["AudioFrameClassification"]
  1143. elif "XVector" in model_class and modality == "audio":
  1144. code_sample = sample_docstrings["AudioXVector"]
  1145. elif "Model" in model_class and modality == "audio":
  1146. code_sample = sample_docstrings["SpeechBaseModel"]
  1147. elif "Model" in model_class and modality == "vision":
  1148. code_sample = sample_docstrings["VisionBaseModel"]
  1149. elif "Model" in model_class or "Encoder" in model_class:
  1150. code_sample = sample_docstrings["BaseModel"]
  1151. elif "ImageClassification" in model_class:
  1152. code_sample = sample_docstrings["ImageClassification"]
  1153. else:
  1154. raise ValueError(f"Docstring can't be built for model {model_class}")
  1155. code_sample = filter_outputs_from_example(
  1156. code_sample, expected_output=expected_output, expected_loss=expected_loss
  1157. )
  1158. if real_checkpoint is not None:
  1159. code_sample = FAKE_MODEL_DISCLAIMER + code_sample
  1160. func_doc = (fn.__doc__ or "") + "".join(docstr)
  1161. output_doc = "" if output_type is None else _prepare_output_docstrings(output_type, config_class)
  1162. built_doc = code_sample.format(**doc_kwargs)
  1163. if revision is not None:
  1164. if re.match(r"^refs/pr/\\d+", revision):
  1165. raise ValueError(
  1166. f"The provided revision '{revision}' is incorrect. It should point to"
  1167. " a pull request reference on the hub like 'refs/pr/6'"
  1168. )
  1169. built_doc = built_doc.replace(
  1170. f'from_pretrained("{checkpoint}")', f'from_pretrained("{checkpoint}", revision="{revision}")'
  1171. )
  1172. fn.__doc__ = func_doc + output_doc + built_doc
  1173. return fn
  1174. return docstring_decorator
  1175. def replace_return_docstrings(output_type=None, config_class=None):
  1176. def docstring_decorator(fn):
  1177. func_doc = fn.__doc__
  1178. lines = func_doc.split("\n")
  1179. i = 0
  1180. while i < len(lines) and re.search(r"^\s*Returns?:\s*$", lines[i]) is None:
  1181. i += 1
  1182. if i < len(lines):
  1183. indent = len(_get_indent(lines[i]))
  1184. lines[i] = _prepare_output_docstrings(output_type, config_class, min_indent=indent)
  1185. func_doc = "\n".join(lines)
  1186. else:
  1187. raise ValueError(
  1188. f"The function {fn} should have an empty 'Return:' or 'Returns:' in its docstring as placeholder, "
  1189. f"current docstring is:\n{func_doc}"
  1190. )
  1191. fn.__doc__ = func_doc
  1192. return fn
  1193. return docstring_decorator
  1194. def copy_func(f):
  1195. """Returns a copy of a function f."""
  1196. # Based on http://stackoverflow.com/a/6528148/190597 (Glenn Maynard)
  1197. g = types.FunctionType(f.__code__, f.__globals__, name=f.__name__, argdefs=f.__defaults__, closure=f.__closure__)
  1198. g = functools.update_wrapper(g, f)
  1199. g.__kwdefaults__ = f.__kwdefaults__
  1200. return g