token_classification.py 30 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668
  1. import types
  2. import warnings
  3. from typing import Any, Optional, Union, overload
  4. import numpy as np
  5. from ..models.bert.tokenization_bert import BasicTokenizer
  6. from ..utils import (
  7. ExplicitEnum,
  8. add_end_docstrings,
  9. is_tf_available,
  10. is_torch_available,
  11. )
  12. from .base import ArgumentHandler, ChunkPipeline, Dataset, build_pipeline_init_args
  13. if is_tf_available():
  14. import tensorflow as tf
  15. from ..models.auto.modeling_tf_auto import TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES
  16. if is_torch_available():
  17. import torch
  18. from ..models.auto.modeling_auto import MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES
  19. class TokenClassificationArgumentHandler(ArgumentHandler):
  20. """
  21. Handles arguments for token classification.
  22. """
  23. def __call__(self, inputs: Union[str, list[str]], **kwargs):
  24. is_split_into_words = kwargs.get("is_split_into_words", False)
  25. delimiter = kwargs.get("delimiter")
  26. if inputs is not None and isinstance(inputs, (list, tuple)) and len(inputs) > 0:
  27. inputs = list(inputs)
  28. batch_size = len(inputs)
  29. elif isinstance(inputs, str):
  30. inputs = [inputs]
  31. batch_size = 1
  32. elif Dataset is not None and isinstance(inputs, Dataset) or isinstance(inputs, types.GeneratorType):
  33. return inputs, is_split_into_words, None, delimiter
  34. else:
  35. raise ValueError("At least one input is required.")
  36. offset_mapping = kwargs.get("offset_mapping")
  37. if offset_mapping:
  38. if isinstance(offset_mapping, list) and isinstance(offset_mapping[0], tuple):
  39. offset_mapping = [offset_mapping]
  40. if len(offset_mapping) != batch_size:
  41. raise ValueError("offset_mapping should have the same batch size as the input")
  42. return inputs, is_split_into_words, offset_mapping, delimiter
class AggregationStrategy(ExplicitEnum):
    """All the valid aggregation strategies for TokenClassificationPipeline"""

    # No aggregation: raw per-token results are returned.
    NONE = "none"
    # Group adjacent tokens into entities following the B-/I- tag schema.
    SIMPLE = "simple"
    # Like SIMPLE, but a word takes the tag of its first token.
    FIRST = "first"
    # Like SIMPLE, but token scores are averaged per word before picking the label.
    AVERAGE = "average"
    # Like SIMPLE, but a word takes the tag of its highest-scoring token.
    MAX = "max"
  50. @add_end_docstrings(
  51. build_pipeline_init_args(has_tokenizer=True),
  52. r"""
  53. ignore_labels (`list[str]`, defaults to `["O"]`):
  54. A list of labels to ignore.
  55. grouped_entities (`bool`, *optional*, defaults to `False`):
  56. DEPRECATED, use `aggregation_strategy` instead. Whether or not to group the tokens corresponding to the
  57. same entity together in the predictions or not.
  58. stride (`int`, *optional*):
  59. If stride is provided, the pipeline is applied on all the text. The text is split into chunks of size
  60. model_max_length. Works only with fast tokenizers and `aggregation_strategy` different from `NONE`. The
  61. value of this argument defines the number of overlapping tokens between chunks. In other words, the model
  62. will shift forward by `tokenizer.model_max_length - stride` tokens each step.
  63. aggregation_strategy (`str`, *optional*, defaults to `"none"`):
  64. The strategy to fuse (or not) tokens based on the model prediction.
  65. - "none" : Will simply not do any aggregation and simply return raw results from the model
  66. - "simple" : Will attempt to group entities following the default schema. (A, B-TAG), (B, I-TAG), (C,
  67. I-TAG), (D, B-TAG2) (E, B-TAG2) will end up being [{"word": ABC, "entity": "TAG"}, {"word": "D",
  68. "entity": "TAG2"}, {"word": "E", "entity": "TAG2"}] Notice that two consecutive B tags will end up as
  69. different entities. On word based languages, we might end up splitting words undesirably : Imagine
  70. Microsoft being tagged as [{"word": "Micro", "entity": "ENTERPRISE"}, {"word": "soft", "entity":
  71. "NAME"}]. Look for FIRST, MAX, AVERAGE for ways to mitigate that and disambiguate words (on languages
  72. that support that meaning, which is basically tokens separated by a space). These mitigations will
  73. only work on real words, "New york" might still be tagged with two different entities.
  74. - "first" : (works only on word based models) Will use the `SIMPLE` strategy except that words, cannot
  75. end up with different tags. Words will simply use the tag of the first token of the word when there
  76. is ambiguity.
  77. - "average" : (works only on word based models) Will use the `SIMPLE` strategy except that words,
  78. cannot end up with different tags. scores will be averaged first across tokens, and then the maximum
  79. label is applied.
  80. - "max" : (works only on word based models) Will use the `SIMPLE` strategy except that words, cannot
  81. end up with different tags. Word entity will simply be the token with the maximum score.""",
  82. )
  83. class TokenClassificationPipeline(ChunkPipeline):
  84. """
  85. Named Entity Recognition pipeline using any `ModelForTokenClassification`. See the [named entity recognition
  86. examples](../task_summary#named-entity-recognition) for more information.
  87. Example:
  88. ```python
  89. >>> from transformers import pipeline
  90. >>> token_classifier = pipeline(model="Jean-Baptiste/camembert-ner", aggregation_strategy="simple")
  91. >>> sentence = "Je m'appelle jean-baptiste et je vis à montréal"
  92. >>> tokens = token_classifier(sentence)
  93. >>> tokens
  94. [{'entity_group': 'PER', 'score': 0.9931, 'word': 'jean-baptiste', 'start': 12, 'end': 26}, {'entity_group': 'LOC', 'score': 0.998, 'word': 'montréal', 'start': 38, 'end': 47}]
  95. >>> token = tokens[0]
  96. >>> # Start and end provide an easy way to highlight words in the original text.
  97. >>> sentence[token["start"] : token["end"]]
  98. ' jean-baptiste'
  99. >>> # Some models use the same idea to do part of speech.
  100. >>> syntaxer = pipeline(model="vblagoje/bert-english-uncased-finetuned-pos", aggregation_strategy="simple")
  101. >>> syntaxer("My name is Sarah and I live in London")
  102. [{'entity_group': 'PRON', 'score': 0.999, 'word': 'my', 'start': 0, 'end': 2}, {'entity_group': 'NOUN', 'score': 0.997, 'word': 'name', 'start': 3, 'end': 7}, {'entity_group': 'AUX', 'score': 0.994, 'word': 'is', 'start': 8, 'end': 10}, {'entity_group': 'PROPN', 'score': 0.999, 'word': 'sarah', 'start': 11, 'end': 16}, {'entity_group': 'CCONJ', 'score': 0.999, 'word': 'and', 'start': 17, 'end': 20}, {'entity_group': 'PRON', 'score': 0.999, 'word': 'i', 'start': 21, 'end': 22}, {'entity_group': 'VERB', 'score': 0.998, 'word': 'live', 'start': 23, 'end': 27}, {'entity_group': 'ADP', 'score': 0.999, 'word': 'in', 'start': 28, 'end': 30}, {'entity_group': 'PROPN', 'score': 0.999, 'word': 'london', 'start': 31, 'end': 37}]
  103. ```
  104. Learn more about the basics of using a pipeline in the [pipeline tutorial](../pipeline_tutorial)
  105. This token recognition pipeline can currently be loaded from [`pipeline`] using the following task identifier:
  106. `"ner"` (for predicting the classes of tokens in a sequence: person, organisation, location or miscellaneous).
  107. The models that this pipeline can use are models that have been fine-tuned on a token classification task. See the
  108. up-to-date list of available models on
  109. [huggingface.co/models](https://huggingface.co/models?filter=token-classification).
  110. """
  111. default_input_names = "sequences"
  112. _load_processor = False
  113. _load_image_processor = False
  114. _load_feature_extractor = False
  115. _load_tokenizer = True
  116. def __init__(self, args_parser=TokenClassificationArgumentHandler(), **kwargs):
  117. super().__init__(**kwargs)
  118. self.check_model_type(
  119. TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES
  120. if self.framework == "tf"
  121. else MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES
  122. )
  123. self._basic_tokenizer = BasicTokenizer(do_lower_case=False)
  124. self._args_parser = args_parser
  125. def _sanitize_parameters(
  126. self,
  127. ignore_labels=None,
  128. grouped_entities: Optional[bool] = None,
  129. ignore_subwords: Optional[bool] = None,
  130. aggregation_strategy: Optional[AggregationStrategy] = None,
  131. offset_mapping: Optional[list[tuple[int, int]]] = None,
  132. is_split_into_words: bool = False,
  133. stride: Optional[int] = None,
  134. delimiter: Optional[str] = None,
  135. ):
  136. preprocess_params = {}
  137. preprocess_params["is_split_into_words"] = is_split_into_words
  138. if is_split_into_words:
  139. preprocess_params["delimiter"] = " " if delimiter is None else delimiter
  140. if offset_mapping is not None:
  141. preprocess_params["offset_mapping"] = offset_mapping
  142. postprocess_params = {}
  143. if grouped_entities is not None or ignore_subwords is not None:
  144. if grouped_entities and ignore_subwords:
  145. aggregation_strategy = AggregationStrategy.FIRST
  146. elif grouped_entities and not ignore_subwords:
  147. aggregation_strategy = AggregationStrategy.SIMPLE
  148. else:
  149. aggregation_strategy = AggregationStrategy.NONE
  150. if grouped_entities is not None:
  151. warnings.warn(
  152. "`grouped_entities` is deprecated and will be removed in version v5.0.0, defaulted to"
  153. f' `aggregation_strategy="{aggregation_strategy}"` instead.'
  154. )
  155. if ignore_subwords is not None:
  156. warnings.warn(
  157. "`ignore_subwords` is deprecated and will be removed in version v5.0.0, defaulted to"
  158. f' `aggregation_strategy="{aggregation_strategy}"` instead.'
  159. )
  160. if aggregation_strategy is not None:
  161. if isinstance(aggregation_strategy, str):
  162. aggregation_strategy = AggregationStrategy[aggregation_strategy.upper()]
  163. if (
  164. aggregation_strategy
  165. in {AggregationStrategy.FIRST, AggregationStrategy.MAX, AggregationStrategy.AVERAGE}
  166. and not self.tokenizer.is_fast
  167. ):
  168. raise ValueError(
  169. "Slow tokenizers cannot handle subwords. Please set the `aggregation_strategy` option"
  170. ' to `"simple"` or use a fast tokenizer.'
  171. )
  172. postprocess_params["aggregation_strategy"] = aggregation_strategy
  173. if ignore_labels is not None:
  174. postprocess_params["ignore_labels"] = ignore_labels
  175. if stride is not None:
  176. if stride >= self.tokenizer.model_max_length:
  177. raise ValueError(
  178. "`stride` must be less than `tokenizer.model_max_length` (or even lower if the tokenizer adds special tokens)"
  179. )
  180. if aggregation_strategy == AggregationStrategy.NONE:
  181. raise ValueError(
  182. "`stride` was provided to process all the text but `aggregation_strategy="
  183. f'"{aggregation_strategy}"`, please select another one instead.'
  184. )
  185. else:
  186. if self.tokenizer.is_fast:
  187. tokenizer_params = {
  188. "return_overflowing_tokens": True,
  189. "padding": True,
  190. "stride": stride,
  191. }
  192. preprocess_params["tokenizer_params"] = tokenizer_params
  193. else:
  194. raise ValueError(
  195. "`stride` was provided to process all the text but you're using a slow tokenizer."
  196. " Please use a fast tokenizer."
  197. )
  198. return preprocess_params, {}, postprocess_params
  199. @overload
  200. def __call__(self, inputs: str, **kwargs: Any) -> list[dict[str, str]]: ...
  201. @overload
  202. def __call__(self, inputs: list[str], **kwargs: Any) -> list[list[dict[str, str]]]: ...
  203. def __call__(
  204. self, inputs: Union[str, list[str]], **kwargs: Any
  205. ) -> Union[list[dict[str, str]], list[list[dict[str, str]]]]:
  206. """
  207. Classify each token of the text(s) given as inputs.
  208. Args:
  209. inputs (`str` or `List[str]`):
  210. One or several texts (or one list of texts) for token classification. Can be pre-tokenized when
  211. `is_split_into_words=True`.
  212. Return:
  213. A list or a list of list of `dict`: Each result comes as a list of dictionaries (one for each token in the
  214. corresponding input, or each entity if this pipeline was instantiated with an aggregation_strategy) with
  215. the following keys:
  216. - **word** (`str`) -- The token/word classified. This is obtained by decoding the selected tokens. If you
  217. want to have the exact string in the original sentence, use `start` and `end`.
  218. - **score** (`float`) -- The corresponding probability for `entity`.
  219. - **entity** (`str`) -- The entity predicted for that token/word (it is named *entity_group* when
  220. *aggregation_strategy* is not `"none"`.
  221. - **index** (`int`, only present when `aggregation_strategy="none"`) -- The index of the corresponding
  222. token in the sentence.
  223. - **start** (`int`, *optional*) -- The index of the start of the corresponding entity in the sentence. Only
  224. exists if the offsets are available within the tokenizer
  225. - **end** (`int`, *optional*) -- The index of the end of the corresponding entity in the sentence. Only
  226. exists if the offsets are available within the tokenizer
  227. """
  228. _inputs, is_split_into_words, offset_mapping, delimiter = self._args_parser(inputs, **kwargs)
  229. kwargs["is_split_into_words"] = is_split_into_words
  230. kwargs["delimiter"] = delimiter
  231. if is_split_into_words and not all(isinstance(input, list) for input in inputs):
  232. return super().__call__([inputs], **kwargs)
  233. if offset_mapping:
  234. kwargs["offset_mapping"] = offset_mapping
  235. return super().__call__(inputs, **kwargs)
  236. def preprocess(self, sentence, offset_mapping=None, **preprocess_params):
  237. tokenizer_params = preprocess_params.pop("tokenizer_params", {})
  238. truncation = self.tokenizer.model_max_length and self.tokenizer.model_max_length > 0
  239. word_to_chars_map = None
  240. is_split_into_words = preprocess_params["is_split_into_words"]
  241. if is_split_into_words:
  242. delimiter = preprocess_params["delimiter"]
  243. if not isinstance(sentence, list):
  244. raise ValueError("When `is_split_into_words=True`, `sentence` must be a list of tokens.")
  245. words = sentence
  246. sentence = delimiter.join(words) # Recreate the sentence string for later display and slicing
  247. # This map will allows to convert back word => char indices
  248. word_to_chars_map = []
  249. delimiter_len = len(delimiter)
  250. char_offset = 0
  251. for word in words:
  252. word_to_chars_map.append((char_offset, char_offset + len(word)))
  253. char_offset += len(word) + delimiter_len
  254. # We use `words` as the actual input for the tokenizer
  255. text_to_tokenize = words
  256. tokenizer_params["is_split_into_words"] = True
  257. else:
  258. if not isinstance(sentence, str):
  259. raise ValueError("When `is_split_into_words=False`, `sentence` must be an untokenized string.")
  260. text_to_tokenize = sentence
  261. inputs = self.tokenizer(
  262. text_to_tokenize,
  263. return_tensors=self.framework,
  264. truncation=truncation,
  265. return_special_tokens_mask=True,
  266. return_offsets_mapping=self.tokenizer.is_fast,
  267. **tokenizer_params,
  268. )
  269. if is_split_into_words and not self.tokenizer.is_fast:
  270. raise ValueError("is_split_into_words=True is only supported with fast tokenizers.")
  271. inputs.pop("overflow_to_sample_mapping", None)
  272. num_chunks = len(inputs["input_ids"])
  273. for i in range(num_chunks):
  274. if self.framework == "tf":
  275. model_inputs = {k: tf.expand_dims(v[i], 0) for k, v in inputs.items()}
  276. else:
  277. model_inputs = {k: v[i].unsqueeze(0) for k, v in inputs.items()}
  278. if offset_mapping is not None:
  279. model_inputs["offset_mapping"] = offset_mapping
  280. model_inputs["sentence"] = sentence if i == 0 else None
  281. model_inputs["is_last"] = i == num_chunks - 1
  282. if word_to_chars_map is not None:
  283. model_inputs["word_ids"] = inputs.word_ids(i)
  284. model_inputs["word_to_chars_map"] = word_to_chars_map
  285. yield model_inputs
  286. def _forward(self, model_inputs):
  287. # Forward
  288. special_tokens_mask = model_inputs.pop("special_tokens_mask")
  289. offset_mapping = model_inputs.pop("offset_mapping", None)
  290. sentence = model_inputs.pop("sentence")
  291. is_last = model_inputs.pop("is_last")
  292. word_ids = model_inputs.pop("word_ids", None)
  293. word_to_chars_map = model_inputs.pop("word_to_chars_map", None)
  294. if self.framework == "tf":
  295. logits = self.model(**model_inputs)[0]
  296. else:
  297. output = self.model(**model_inputs)
  298. logits = output["logits"] if isinstance(output, dict) else output[0]
  299. return {
  300. "logits": logits,
  301. "special_tokens_mask": special_tokens_mask,
  302. "offset_mapping": offset_mapping,
  303. "sentence": sentence,
  304. "is_last": is_last,
  305. "word_ids": word_ids,
  306. "word_to_chars_map": word_to_chars_map,
  307. **model_inputs,
  308. }
  309. def postprocess(self, all_outputs, aggregation_strategy=AggregationStrategy.NONE, ignore_labels=None):
  310. if ignore_labels is None:
  311. ignore_labels = ["O"]
  312. all_entities = []
  313. # Get map from the first output, it's the same for all chunks
  314. word_to_chars_map = all_outputs[0].get("word_to_chars_map")
  315. for model_outputs in all_outputs:
  316. if self.framework == "pt" and model_outputs["logits"][0].dtype in (torch.bfloat16, torch.float16):
  317. logits = model_outputs["logits"][0].to(torch.float32).numpy()
  318. else:
  319. logits = model_outputs["logits"][0].numpy()
  320. sentence = all_outputs[0]["sentence"]
  321. input_ids = model_outputs["input_ids"][0]
  322. offset_mapping = (
  323. model_outputs["offset_mapping"][0] if model_outputs["offset_mapping"] is not None else None
  324. )
  325. special_tokens_mask = model_outputs["special_tokens_mask"][0].numpy()
  326. word_ids = model_outputs.get("word_ids")
  327. maxes = np.max(logits, axis=-1, keepdims=True)
  328. shifted_exp = np.exp(logits - maxes)
  329. scores = shifted_exp / shifted_exp.sum(axis=-1, keepdims=True)
  330. if self.framework == "tf":
  331. input_ids = input_ids.numpy()
  332. offset_mapping = offset_mapping.numpy() if offset_mapping is not None else None
  333. pre_entities = self.gather_pre_entities(
  334. sentence,
  335. input_ids,
  336. scores,
  337. offset_mapping,
  338. special_tokens_mask,
  339. aggregation_strategy,
  340. word_ids=word_ids,
  341. word_to_chars_map=word_to_chars_map,
  342. )
  343. grouped_entities = self.aggregate(pre_entities, aggregation_strategy)
  344. # Filter anything that is in self.ignore_labels
  345. entities = [
  346. entity
  347. for entity in grouped_entities
  348. if entity.get("entity", None) not in ignore_labels
  349. and entity.get("entity_group", None) not in ignore_labels
  350. ]
  351. all_entities.extend(entities)
  352. num_chunks = len(all_outputs)
  353. if num_chunks > 1:
  354. all_entities = self.aggregate_overlapping_entities(all_entities)
  355. return all_entities
  356. def aggregate_overlapping_entities(self, entities):
  357. if len(entities) == 0:
  358. return entities
  359. entities = sorted(entities, key=lambda x: x["start"])
  360. aggregated_entities = []
  361. previous_entity = entities[0]
  362. for entity in entities:
  363. if previous_entity["start"] <= entity["start"] < previous_entity["end"]:
  364. current_length = entity["end"] - entity["start"]
  365. previous_length = previous_entity["end"] - previous_entity["start"]
  366. if (
  367. current_length > previous_length
  368. or current_length == previous_length
  369. and entity["score"] > previous_entity["score"]
  370. ):
  371. previous_entity = entity
  372. else:
  373. aggregated_entities.append(previous_entity)
  374. previous_entity = entity
  375. aggregated_entities.append(previous_entity)
  376. return aggregated_entities
  377. def gather_pre_entities(
  378. self,
  379. sentence: str,
  380. input_ids: np.ndarray,
  381. scores: np.ndarray,
  382. offset_mapping: Optional[list[tuple[int, int]]],
  383. special_tokens_mask: np.ndarray,
  384. aggregation_strategy: AggregationStrategy,
  385. word_ids: Optional[list[Optional[int]]] = None,
  386. word_to_chars_map: Optional[list[tuple[int, int]]] = None,
  387. ) -> list[dict]:
  388. """Fuse various numpy arrays into dicts with all the information needed for aggregation"""
  389. pre_entities = []
  390. for idx, token_scores in enumerate(scores):
  391. # Filter special_tokens
  392. if special_tokens_mask[idx]:
  393. continue
  394. word = self.tokenizer.convert_ids_to_tokens(int(input_ids[idx]))
  395. if offset_mapping is not None:
  396. start_ind, end_ind = offset_mapping[idx]
  397. # If the input is pre-tokenized, we need to rescale the offsets to the absolute sentence.
  398. if word_ids is not None and word_to_chars_map is not None:
  399. word_index = word_ids[idx]
  400. if word_index is not None:
  401. start_char, _ = word_to_chars_map[word_index]
  402. start_ind += start_char
  403. end_ind += start_char
  404. if not isinstance(start_ind, int):
  405. if self.framework == "pt":
  406. start_ind = start_ind.item()
  407. end_ind = end_ind.item()
  408. word_ref = sentence[start_ind:end_ind]
  409. if getattr(self.tokenizer, "_tokenizer", None) and getattr(
  410. self.tokenizer._tokenizer.model, "continuing_subword_prefix", None
  411. ):
  412. # This is a BPE, word aware tokenizer, there is a correct way
  413. # to fuse tokens
  414. is_subword = len(word) != len(word_ref)
  415. else:
  416. # This is a fallback heuristic. This will fail most likely on any kind of text + punctuation mixtures that will be considered "words". Non word aware models cannot do better than this unfortunately.
  417. if aggregation_strategy in {
  418. AggregationStrategy.FIRST,
  419. AggregationStrategy.AVERAGE,
  420. AggregationStrategy.MAX,
  421. }:
  422. warnings.warn(
  423. "Tokenizer does not support real words, using fallback heuristic",
  424. UserWarning,
  425. )
  426. is_subword = start_ind > 0 and " " not in sentence[start_ind - 1 : start_ind + 1]
  427. if int(input_ids[idx]) == self.tokenizer.unk_token_id:
  428. word = word_ref
  429. is_subword = False
  430. else:
  431. start_ind = None
  432. end_ind = None
  433. is_subword = False
  434. pre_entity = {
  435. "word": word,
  436. "scores": token_scores,
  437. "start": start_ind,
  438. "end": end_ind,
  439. "index": idx,
  440. "is_subword": is_subword,
  441. }
  442. pre_entities.append(pre_entity)
  443. return pre_entities
  444. def aggregate(self, pre_entities: list[dict], aggregation_strategy: AggregationStrategy) -> list[dict]:
  445. if aggregation_strategy in {AggregationStrategy.NONE, AggregationStrategy.SIMPLE}:
  446. entities = []
  447. for pre_entity in pre_entities:
  448. entity_idx = pre_entity["scores"].argmax()
  449. score = pre_entity["scores"][entity_idx]
  450. entity = {
  451. "entity": self.model.config.id2label[entity_idx],
  452. "score": score,
  453. "index": pre_entity["index"],
  454. "word": pre_entity["word"],
  455. "start": pre_entity["start"],
  456. "end": pre_entity["end"],
  457. }
  458. entities.append(entity)
  459. else:
  460. entities = self.aggregate_words(pre_entities, aggregation_strategy)
  461. if aggregation_strategy == AggregationStrategy.NONE:
  462. return entities
  463. return self.group_entities(entities)
  464. def aggregate_word(self, entities: list[dict], aggregation_strategy: AggregationStrategy) -> dict:
  465. word = self.tokenizer.convert_tokens_to_string([entity["word"] for entity in entities])
  466. if aggregation_strategy == AggregationStrategy.FIRST:
  467. scores = entities[0]["scores"]
  468. idx = scores.argmax()
  469. score = scores[idx]
  470. entity = self.model.config.id2label[idx]
  471. elif aggregation_strategy == AggregationStrategy.MAX:
  472. max_entity = max(entities, key=lambda entity: entity["scores"].max())
  473. scores = max_entity["scores"]
  474. idx = scores.argmax()
  475. score = scores[idx]
  476. entity = self.model.config.id2label[idx]
  477. elif aggregation_strategy == AggregationStrategy.AVERAGE:
  478. scores = np.stack([entity["scores"] for entity in entities])
  479. average_scores = np.nanmean(scores, axis=0)
  480. entity_idx = average_scores.argmax()
  481. entity = self.model.config.id2label[entity_idx]
  482. score = average_scores[entity_idx]
  483. else:
  484. raise ValueError("Invalid aggregation_strategy")
  485. new_entity = {
  486. "entity": entity,
  487. "score": score,
  488. "word": word,
  489. "start": entities[0]["start"],
  490. "end": entities[-1]["end"],
  491. }
  492. return new_entity
  493. def aggregate_words(self, entities: list[dict], aggregation_strategy: AggregationStrategy) -> list[dict]:
  494. """
  495. Override tokens from a given word that disagree to force agreement on word boundaries.
  496. Example: micro|soft| com|pany| B-ENT I-NAME I-ENT I-ENT will be rewritten with first strategy as microsoft|
  497. company| B-ENT I-ENT
  498. """
  499. if aggregation_strategy in {
  500. AggregationStrategy.NONE,
  501. AggregationStrategy.SIMPLE,
  502. }:
  503. raise ValueError("NONE and SIMPLE strategies are invalid for word aggregation")
  504. word_entities = []
  505. word_group = None
  506. for entity in entities:
  507. if word_group is None:
  508. word_group = [entity]
  509. elif entity["is_subword"]:
  510. word_group.append(entity)
  511. else:
  512. word_entities.append(self.aggregate_word(word_group, aggregation_strategy))
  513. word_group = [entity]
  514. # Last item
  515. if word_group is not None:
  516. word_entities.append(self.aggregate_word(word_group, aggregation_strategy))
  517. return word_entities
  518. def group_sub_entities(self, entities: list[dict]) -> dict:
  519. """
  520. Group together the adjacent tokens with the same entity predicted.
  521. Args:
  522. entities (`dict`): The entities predicted by the pipeline.
  523. """
  524. # Get the first entity in the entity group
  525. entity = entities[0]["entity"].split("-", 1)[-1]
  526. scores = np.nanmean([entity["score"] for entity in entities])
  527. tokens = [entity["word"] for entity in entities]
  528. entity_group = {
  529. "entity_group": entity,
  530. "score": np.mean(scores),
  531. "word": self.tokenizer.convert_tokens_to_string(tokens),
  532. "start": entities[0]["start"],
  533. "end": entities[-1]["end"],
  534. }
  535. return entity_group
  536. def get_tag(self, entity_name: str) -> tuple[str, str]:
  537. if entity_name.startswith("B-"):
  538. bi = "B"
  539. tag = entity_name[2:]
  540. elif entity_name.startswith("I-"):
  541. bi = "I"
  542. tag = entity_name[2:]
  543. else:
  544. # It's not in B-, I- format
  545. # Default to I- for continuation.
  546. bi = "I"
  547. tag = entity_name
  548. return bi, tag
  549. def group_entities(self, entities: list[dict]) -> list[dict]:
  550. """
  551. Find and group together the adjacent tokens with the same entity predicted.
  552. Args:
  553. entities (`dict`): The entities predicted by the pipeline.
  554. """
  555. entity_groups = []
  556. entity_group_disagg = []
  557. for entity in entities:
  558. if not entity_group_disagg:
  559. entity_group_disagg.append(entity)
  560. continue
  561. # If the current entity is similar and adjacent to the previous entity,
  562. # append it to the disaggregated entity group
  563. # The split is meant to account for the "B" and "I" prefixes
  564. # Shouldn't merge if both entities are B-type
  565. bi, tag = self.get_tag(entity["entity"])
  566. last_bi, last_tag = self.get_tag(entity_group_disagg[-1]["entity"])
  567. if tag == last_tag and bi != "B":
  568. # Modify subword type to be previous_type
  569. entity_group_disagg.append(entity)
  570. else:
  571. # If the current entity is different from the previous entity
  572. # aggregate the disaggregated entity group
  573. entity_groups.append(self.group_sub_entities(entity_group_disagg))
  574. entity_group_disagg = [entity]
  575. if entity_group_disagg:
  576. # it's the last entity, add it to the entity groups
  577. entity_groups.append(self.group_sub_entities(entity_group_disagg))
  578. return entity_groups
  579. NerPipeline = TokenClassificationPipeline