streaming_output.py 26 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import types
  3. import warnings
  4. from contextlib import contextmanager
  5. from typing import Any, Dict, Generator, List, Optional, Union
  6. import torch
  7. import torch.distributed as dist
  8. import transformers
  9. from packaging import version
  10. from torch import nn
  11. from transformers import PreTrainedModel
  12. from transformers.generation import GreedySearchDecoderOnlyOutput # noqa
  13. from transformers.generation import (GreedySearchEncoderDecoderOutput,
  14. LogitsProcessorList,
  15. SampleDecoderOnlyOutput,
  16. SampleEncoderDecoderOutput,
  17. StoppingCriteriaList,
  18. validate_stopping_criteria)
  19. from modelscope.pipelines.base import Input
  20. from modelscope.utils.constant import Frameworks
  21. from modelscope.utils.device import device_placement
  22. class StreamingOutputMixin:
  23. def stream_generate(self, *args, **kwargs) -> Generator:
  24. """
  25. Support the input of Model and Pipeline.
  26. The output is a `Generator` type,
  27. which conforms to the output standard of modelscope.
  28. """
  29. raise NotImplementedError
  30. class PipelineStreamingOutputMixin(StreamingOutputMixin):
  31. def stream_generate(self, input: Union[Input, List[Input]], *args,
  32. **kwargs) -> Generator:
  33. """
  34. Similar to the `Pipeline.__call__` method.
  35. it supports the input that the pipeline can accept,
  36. and also supports batch input.
  37. self.model must be a subclass of StreamingOutputMixin
  38. and implement the stream method.
  39. """
  40. assert isinstance(self.model, StreamingOutputMixin
  41. ), 'pipeline.model must be StreamingOutputMixin!'
  42. if (self.model or (self.has_multiple_models and self.models[0])):
  43. if not self._model_prepare:
  44. self.prepare_model()
  45. batch_size = kwargs.pop('batch_size', None)
  46. preprocess_params, forward_params, postprocess_params = self._sanitize_parameters(
  47. **kwargs)
  48. if isinstance(input, list):
  49. model_input_list = [
  50. self._preprocess_with_check(i, preprocess_params)
  51. for i in input
  52. ]
  53. if batch_size is None:
  54. output = []
  55. for ele in model_input_list:
  56. output.append(
  57. self._stream_single(ele, forward_params,
  58. postprocess_params))
  59. else:
  60. output = self._stream_batch(model_input_list, batch_size,
  61. forward_params, postprocess_params)
  62. else:
  63. model_input = self._preprocess_with_check(input, preprocess_params)
  64. output = self._stream_single(model_input, forward_params,
  65. postprocess_params)
  66. return output
  67. def _preprocess_with_check(
  68. self, input: Input,
  69. preprocess_params: Dict[str, Any]) -> Dict[str, Any]:
  70. self._check_input(input)
  71. return self.preprocess(input, **preprocess_params)
  72. def _stream_single(self, model_input: Dict[str, Any],
  73. forward_params: Dict[str, Any],
  74. postprocess_params: Dict[str, Any]) -> Generator:
  75. with device_placement(self.framework, self.device_name):
  76. if self.framework == Frameworks.torch:
  77. with torch.no_grad():
  78. if self._auto_collate:
  79. model_input = self._collate_fn(model_input)
  80. stream = self.model.stream_generate(
  81. model_input, **forward_params)
  82. else:
  83. stream = self.model.stream_generate(model_input,
  84. **forward_params)
  85. for out in stream:
  86. out = self.postprocess(out, **postprocess_params)
  87. self._check_output(out)
  88. yield out
  89. def _stream_batch(self, model_input_list: List[Dict[str, Any]],
  90. batch_size: int, forward_params: Dict[str, Any],
  91. postprocess_params: Dict[str, Any]) -> Generator:
  92. stream_list = []
  93. real_batch_sizes = []
  94. with device_placement(self.framework, self.device_name):
  95. for i in range(0, len(model_input_list), batch_size):
  96. end = min(i + batch_size, len(model_input_list))
  97. real_batch_size = end - i
  98. real_batch_sizes.append(real_batch_size)
  99. batched_out = self._batch(model_input_list[i:end])
  100. if self.framework == Frameworks.torch:
  101. with torch.no_grad():
  102. if self._auto_collate:
  103. batched_out = self._collate_fn(batched_out)
  104. stream_list.append(
  105. self.model.stream_generate(batched_out,
  106. **forward_params))
  107. else:
  108. stream_list.append(
  109. self.model.stream_generate(batched_out,
  110. **forward_params))
  111. output_list = [None] * len(model_input_list)
  112. stop_streams = 0
  113. while stop_streams < len(stream_list):
  114. stop_streams = 0
  115. for i, (stream, real_batch_size) in enumerate(
  116. zip(stream_list, real_batch_sizes)):
  117. try:
  118. batched_out = next(stream)
  119. for batch_idx in range(real_batch_size):
  120. out = {}
  121. for k, element in batched_out.items():
  122. if element is not None:
  123. if isinstance(element, (tuple, list)):
  124. if isinstance(element[0],
  125. torch.Tensor):
  126. out[k] = type(element)(
  127. e[batch_idx:batch_idx + 1]
  128. for e in element)
  129. else:
  130. # Compatible with traditional pipelines
  131. out[k] = element[batch_idx]
  132. else:
  133. out[k] = element[batch_idx:batch_idx
  134. + 1]
  135. out = self.postprocess(out, **postprocess_params)
  136. self._check_output(out)
  137. output_index = i * batch_size + batch_idx
  138. output_list[output_index] = out
  139. except StopIteration:
  140. stop_streams += 1
  141. yield output_list
  142. return output_list
  143. class PretrainedModelStreamingOutputMixin(StreamingOutputMixin):
  144. def stream_generate(self, *args, **kwargs) -> Generator:
  145. model = self if isinstance(self, PreTrainedModel) else self.model
  146. assert isinstance(model, PreTrainedModel), \
  147. 'self or self.model must be `PretrainedModel`!'
  148. with self._replace_generate(model):
  149. return model.generate(*args, **kwargs)
  150. @contextmanager
  151. def _replace_generate(self, model: PreTrainedModel) -> Generator:
  152. if version.parse(transformers.__version__) >= version.parse('4.43.0'):
  153. greedy_search_name = 'stream_greedy_search'
  154. sample_name = '_sample'
  155. elif version.parse(
  156. transformers.__version__) >= version.parse('4.39.0'):
  157. greedy_search_name = '_greedy_search'
  158. sample_name = '_sample'
  159. else:
  160. greedy_search_name = 'greedy_search'
  161. sample_name = 'sample'
  162. origin_greedy_search = getattr(model, greedy_search_name)
  163. origin_sample = getattr(model, sample_name)
  164. setattr(model, greedy_search_name,
  165. types.MethodType(self.stream_greedy_search, model))
  166. setattr(model, sample_name, types.MethodType(self.stream_sample,
  167. model))
  168. yield
  169. setattr(model, greedy_search_name, origin_greedy_search)
  170. setattr(model, sample_name, origin_sample)
  171. @staticmethod
  172. def stream_greedy_search(
  173. self,
  174. input_ids: torch.LongTensor,
  175. logits_processor: Optional[LogitsProcessorList] = None,
  176. stopping_criteria: Optional[StoppingCriteriaList] = None,
  177. max_length: Optional[int] = None,
  178. pad_token_id: Optional[int] = None,
  179. eos_token_id: Optional[Union[int, List[int]]] = None,
  180. output_attentions: Optional[bool] = None,
  181. output_hidden_states: Optional[bool] = None,
  182. output_scores: Optional[bool] = None,
  183. return_dict_in_generate: Optional[bool] = None,
  184. synced_gpus: bool = False,
  185. **model_kwargs,
  186. ) -> Generator:
  187. logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList(
  188. )
  189. stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList(
  190. )
  191. if max_length is not None:
  192. warnings.warn(
  193. '`max_length` is deprecated in this function, use'
  194. ' `stopping_criteria=StoppingCriteriaList([MaxLengthCriteria(max_length=max_length)])` instead.',
  195. UserWarning,
  196. )
  197. stopping_criteria = validate_stopping_criteria(
  198. stopping_criteria, max_length)
  199. pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
  200. eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
  201. if isinstance(eos_token_id, int):
  202. eos_token_id = [eos_token_id]
  203. eos_token_id_tensor = torch.tensor(eos_token_id).to(
  204. input_ids.device) if eos_token_id is not None else None
  205. output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
  206. output_attentions = (
  207. output_attentions if output_attentions is not None else
  208. self.generation_config.output_attentions)
  209. output_hidden_states = (
  210. output_hidden_states if output_hidden_states is not None else
  211. self.generation_config.output_hidden_states)
  212. return_dict_in_generate = (
  213. return_dict_in_generate if return_dict_in_generate is not None else
  214. self.generation_config.return_dict_in_generate)
  215. # init attention / hidden states / scores tuples
  216. scores = () if (return_dict_in_generate and output_scores) else None
  217. decoder_attentions = () if (return_dict_in_generate
  218. and output_attentions) else None
  219. cross_attentions = () if (return_dict_in_generate
  220. and output_attentions) else None
  221. decoder_hidden_states = () if (return_dict_in_generate
  222. and output_hidden_states) else None
  223. # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
  224. if return_dict_in_generate and self.config.is_encoder_decoder:
  225. encoder_attentions = model_kwargs['encoder_outputs'].get(
  226. 'attentions') if output_attentions else None
  227. encoder_hidden_states = (
  228. model_kwargs['encoder_outputs'].get('hidden_states')
  229. if output_hidden_states else None)
  230. # keep track of which sequences are already finished
  231. unfinished_sequences = torch.ones(
  232. input_ids.shape[0], dtype=torch.long, device=input_ids.device)
  233. this_peer_finished = False # used by synced_gpus only
  234. while True:
  235. if synced_gpus:
  236. # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
  237. # The following logic allows an early break if all peers finished generating their sequence
  238. this_peer_finished_flag = torch.tensor(
  239. 0.0 if this_peer_finished else 1.0).to(input_ids.device)
  240. # send 0.0 if we finished, 1.0 otherwise
  241. dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
  242. # did all peers finish? the reduced sum will be 0.0 then
  243. if this_peer_finished_flag.item() == 0.0:
  244. break
  245. # prepare model inputs
  246. model_inputs = self.prepare_inputs_for_generation(
  247. input_ids, **model_kwargs)
  248. # forward pass to get next token
  249. outputs = self(
  250. **model_inputs,
  251. return_dict=True,
  252. output_attentions=output_attentions,
  253. output_hidden_states=output_hidden_states,
  254. )
  255. if synced_gpus and this_peer_finished:
  256. continue # don't waste resources running the code we don't need
  257. next_token_logits = outputs.logits[:, -1, :]
  258. # pre-process distribution
  259. next_tokens_scores = logits_processor(input_ids, next_token_logits)
  260. # Store scores, attentions and hidden_states when required
  261. if return_dict_in_generate:
  262. if output_scores:
  263. scores += (next_tokens_scores, )
  264. if output_attentions:
  265. decoder_attentions += ((outputs.decoder_attentions, ) if
  266. self.config.is_encoder_decoder else
  267. (outputs.attentions, ))
  268. if self.config.is_encoder_decoder:
  269. cross_attentions += (outputs.cross_attentions, )
  270. if output_hidden_states:
  271. decoder_hidden_states += ((outputs.decoder_hidden_states, )
  272. if self.config.is_encoder_decoder
  273. else (outputs.hidden_states, ))
  274. # argmax
  275. next_tokens = torch.argmax(next_tokens_scores, dim=-1)
  276. # finished sentences should have their next token be a padding token
  277. if eos_token_id is not None:
  278. if pad_token_id is None:
  279. raise ValueError(
  280. 'If `eos_token_id` is defined, make sure that `pad_token_id` is defined.'
  281. )
  282. next_tokens = next_tokens * unfinished_sequences + pad_token_id * (
  283. 1 - unfinished_sequences)
  284. # update generated ids, model inputs, and length for next step
  285. input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
  286. # return Generator for stream
  287. if return_dict_in_generate:
  288. if self.config.is_encoder_decoder:
  289. yield GreedySearchEncoderDecoderOutput(
  290. sequences=input_ids,
  291. scores=scores,
  292. encoder_attentions=encoder_attentions,
  293. encoder_hidden_states=encoder_hidden_states,
  294. decoder_attentions=decoder_attentions,
  295. cross_attentions=cross_attentions,
  296. decoder_hidden_states=decoder_hidden_states,
  297. )
  298. else:
  299. yield GreedySearchDecoderOnlyOutput(
  300. sequences=input_ids,
  301. scores=scores,
  302. attentions=decoder_attentions,
  303. hidden_states=decoder_hidden_states,
  304. )
  305. else:
  306. yield input_ids
  307. model_kwargs = self._update_model_kwargs_for_generation(
  308. outputs,
  309. model_kwargs,
  310. is_encoder_decoder=self.config.is_encoder_decoder)
  311. # if eos_token was found in one sentence, set sentence to finished
  312. if eos_token_id_tensor is not None:
  313. unfinished_sequences = unfinished_sequences.mul(
  314. next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(
  315. eos_token_id_tensor.unsqueeze(1)).prod(dim=0))
  316. # stop when each sentence is finished
  317. if unfinished_sequences.max() == 0:
  318. this_peer_finished = True
  319. # stop if we exceed the maximum length
  320. if stopping_criteria(input_ids, scores):
  321. this_peer_finished = True
  322. if this_peer_finished and not synced_gpus:
  323. break
  324. @staticmethod
  325. def stream_sample(
  326. self,
  327. input_ids: torch.LongTensor,
  328. logits_processor: Optional[LogitsProcessorList] = None,
  329. stopping_criteria: Optional[StoppingCriteriaList] = None,
  330. logits_warper: Optional[LogitsProcessorList] = None,
  331. max_length: Optional[int] = None,
  332. pad_token_id: Optional[int] = None,
  333. eos_token_id: Optional[Union[int, List[int]]] = None,
  334. output_attentions: Optional[bool] = None,
  335. output_hidden_states: Optional[bool] = None,
  336. output_scores: Optional[bool] = None,
  337. return_dict_in_generate: Optional[bool] = None,
  338. synced_gpus: bool = False,
  339. **model_kwargs,
  340. ) -> Generator:
  341. logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList(
  342. )
  343. stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList(
  344. )
  345. if max_length is not None:
  346. warnings.warn(
  347. '`max_length` is deprecated in this function, use'
  348. ' `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))` instead.',
  349. UserWarning,
  350. )
  351. stopping_criteria = validate_stopping_criteria(
  352. stopping_criteria, max_length)
  353. logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList(
  354. )
  355. pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
  356. eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
  357. if isinstance(eos_token_id, int):
  358. eos_token_id = [eos_token_id]
  359. eos_token_id_tensor = torch.tensor(eos_token_id).to(
  360. input_ids.device) if eos_token_id is not None else None
  361. output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
  362. output_attentions = (
  363. output_attentions if output_attentions is not None else
  364. self.generation_config.output_attentions)
  365. output_hidden_states = (
  366. output_hidden_states if output_hidden_states is not None else
  367. self.generation_config.output_hidden_states)
  368. return_dict_in_generate = (
  369. return_dict_in_generate if return_dict_in_generate is not None else
  370. self.generation_config.return_dict_in_generate)
  371. # init attention / hidden states / scores tuples
  372. scores = () if (return_dict_in_generate and output_scores) else None
  373. decoder_attentions = () if (return_dict_in_generate
  374. and output_attentions) else None
  375. cross_attentions = () if (return_dict_in_generate
  376. and output_attentions) else None
  377. decoder_hidden_states = () if (return_dict_in_generate
  378. and output_hidden_states) else None
  379. # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
  380. if return_dict_in_generate and self.config.is_encoder_decoder:
  381. encoder_attentions = model_kwargs['encoder_outputs'].get(
  382. 'attentions') if output_attentions else None
  383. encoder_hidden_states = (
  384. model_kwargs['encoder_outputs'].get('hidden_states')
  385. if output_hidden_states else None)
  386. # keep track of which sequences are already finished
  387. unfinished_sequences = torch.ones(
  388. input_ids.shape[0], dtype=torch.long, device=input_ids.device)
  389. this_peer_finished = False # used by synced_gpus only
  390. # auto-regressive generation
  391. while True:
  392. if synced_gpus:
  393. # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
  394. # The following logic allows an early break if all peers finished generating their sequence
  395. this_peer_finished_flag = torch.tensor(
  396. 0.0 if this_peer_finished else 1.0).to(input_ids.device)
  397. # send 0.0 if we finished, 1.0 otherwise
  398. dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
  399. # did all peers finish? the reduced sum will be 0.0 then
  400. if this_peer_finished_flag.item() == 0.0:
  401. break
  402. # prepare model inputs
  403. model_kwargs = self._get_initial_cache_position(
  404. input_ids, model_kwargs)
  405. model_inputs = self.prepare_inputs_for_generation(
  406. input_ids, **model_kwargs)
  407. # forward pass to get next token
  408. outputs = self(
  409. **model_inputs,
  410. return_dict=True,
  411. output_attentions=output_attentions,
  412. output_hidden_states=output_hidden_states,
  413. )
  414. if synced_gpus and this_peer_finished:
  415. continue # don't waste resources running the code we don't need
  416. next_token_logits = outputs.logits[:, -1, :]
  417. # pre-process distribution
  418. next_token_scores = logits_processor(input_ids, next_token_logits)
  419. next_token_scores = logits_warper(input_ids, next_token_scores)
  420. # Store scores, attentions and hidden_states when required
  421. if return_dict_in_generate:
  422. if output_scores:
  423. scores += (next_token_scores, )
  424. if output_attentions:
  425. decoder_attentions += ((outputs.decoder_attentions, ) if
  426. self.config.is_encoder_decoder else
  427. (outputs.attentions, ))
  428. if self.config.is_encoder_decoder:
  429. cross_attentions += (outputs.cross_attentions, )
  430. if output_hidden_states:
  431. decoder_hidden_states += ((outputs.decoder_hidden_states, )
  432. if self.config.is_encoder_decoder
  433. else (outputs.hidden_states, ))
  434. # sample
  435. probs = nn.functional.softmax(next_token_scores, dim=-1)
  436. next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
  437. # finished sentences should have their next token be a padding token
  438. if eos_token_id is not None:
  439. if pad_token_id is None:
  440. raise ValueError(
  441. 'If `eos_token_id` is defined, make sure that `pad_token_id` is defined.'
  442. )
  443. next_tokens = next_tokens * unfinished_sequences + pad_token_id * (
  444. 1 - unfinished_sequences)
  445. # update generated ids, model inputs, and length for next step
  446. input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
  447. # return Generator for stream
  448. if return_dict_in_generate:
  449. if self.config.is_encoder_decoder:
  450. yield SampleEncoderDecoderOutput(
  451. sequences=input_ids,
  452. scores=scores,
  453. encoder_attentions=encoder_attentions,
  454. encoder_hidden_states=encoder_hidden_states,
  455. decoder_attentions=decoder_attentions,
  456. cross_attentions=cross_attentions,
  457. decoder_hidden_states=decoder_hidden_states,
  458. )
  459. else:
  460. yield SampleDecoderOnlyOutput(
  461. sequences=input_ids,
  462. scores=scores,
  463. attentions=decoder_attentions,
  464. hidden_states=decoder_hidden_states,
  465. )
  466. else:
  467. yield input_ids
  468. model_kwargs = self._update_model_kwargs_for_generation(
  469. outputs,
  470. model_kwargs,
  471. is_encoder_decoder=self.config.is_encoder_decoder)
  472. # if eos_token was found in one sentence, set sentence to finished
  473. if eos_token_id_tensor is not None:
  474. unfinished_sequences = unfinished_sequences.mul(
  475. next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(
  476. eos_token_id_tensor.unsqueeze(1)).prod(dim=0))
  477. # stop when each sentence is finished
  478. if unfinished_sequences.max() == 0:
  479. this_peer_finished = True
  480. # stop if we exceed the maximum length
  481. if stopping_criteria(input_ids, scores):
  482. this_peer_finished = True
  483. if this_peer_finished and not synced_gpus:
  484. break
  485. def add_stream_generate(model: PreTrainedModel):
  486. pretrained_class = type(model)
  487. parent_classes = (pretrained_class, PretrainedModelStreamingOutputMixin)
  488. new_model = type(pretrained_class.__name__, parent_classes, {})(
  489. model.config)
  490. new_model.__dict__.update(model.__dict__)
  491. return new_model