input_output.py 30 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import ast
  3. import base64
  4. import importlib
  5. import inspect
  6. import os
  7. from io import BytesIO
  8. from typing import Any
  9. from urllib.parse import urlparse
  10. import json
  11. import numpy as np
  12. from modelscope.hub.file_download import model_file_download
  13. from modelscope.outputs.outputs import (TASK_OUTPUTS, OutputKeys, OutputTypes,
  14. OutputTypeSchema)
  15. from modelscope.pipeline_inputs import (INPUT_TYPE, INPUT_TYPE_SCHEMA,
  16. TASK_INPUTS, InputType)
  17. from modelscope.pipelines import pipeline
  18. from modelscope.pipelines.base import Pipeline
  19. from modelscope.utils.config import Config
  20. from modelscope.utils.constant import ModelFile, Tasks
  21. from modelscope.utils.logger import get_logger
# Module-level logger used by the helpers below to emit warnings.
logger = get_logger()
  23. """Support webservice integration pipeline.
  24. This module provides a support library when webservice uses pipeline,
  25. converts webservice input into pipeline input, and converts pipeline
  26. output into webservice output, which automatically encodes and
  27. decodes relevant fields.
  28. Example:
  29. # create pipeline instance and pipeline information, save it to app
  30. pipeline_instance = create_pipeline('damo/cv_gpen_image-portrait-enhancement', 'v1.0.0')
  31. # get pipeline information, input,output, request example.
  32. pipeline_info = get_pipeline_information_by_pipeline(pipeline_instance)
  33. # save the pipeline and info to the app for use in subsequent request processing
  34. app.state.pipeline = pipeline_instance
  35. app.state.pipeline_info = pipeline_info
  36. # for inference request, use call_pipeline_with_json to decode input and
  37. # call pipeline, call pipeline_output_to_service_base64_output
  38. # to encode necessary fields, and return the result.
  39. # request and response are json format.
  40. @router.post('/call')
  41. async def inference(request: Request):
  42. pipeline_service = request.app.state.pipeline
  43. pipeline_info = request.app.state.pipeline_info
  44. request_json = await request.json()
  45. result = call_pipeline_with_json(pipeline_info,
  46. pipeline_service,
  47. request_json)
  48. # convert output to json, if binary field, we need encoded.
  49. output = pipeline_output_to_service_base64_output(pipeline_info.task_name, result)
  50. return output
  51. # Inference service input and output and sample information can be obtained through the docs interface
  52. @router.get('/describe')
  53. async def index(request: Request):
  54. pipeline_info = request.app.state.pipeline_info
  55. return pipeline_info.schema
  56. Todo:
  57. * Support more service input type, such as form.
  58. """
  59. def create_pipeline(model_id: str,
  60. revision: str,
  61. external_engine_for_llm: bool = True):
  62. model_configuration_file = model_file_download(
  63. model_id=model_id,
  64. file_path=ModelFile.CONFIGURATION,
  65. revision=revision)
  66. cfg = Config.from_file(model_configuration_file)
  67. return pipeline(
  68. task=cfg.task,
  69. model=model_id,
  70. model_revision=revision,
  71. external_engine_for_llm=external_engine_for_llm)
  72. def get_class_user_attributes(cls):
  73. attributes = inspect.getmembers(cls, lambda a: not (inspect.isroutine(a)))
  74. user_attributes = [
  75. a for a in attributes
  76. if (not (a[0].startswith('__') and a[0].endswith('__')))
  77. ]
  78. return user_attributes
  79. def get_input_type(task_inputs: Any):
  80. """Get task input schema.
  81. Args:
  82. task_name (str): The task name.
  83. """
  84. if isinstance(task_inputs, str): # no input key
  85. input_type = INPUT_TYPE[task_inputs]
  86. return input_type
  87. elif isinstance(task_inputs, tuple) or isinstance(task_inputs, list):
  88. for item in task_inputs:
  89. if isinstance(item,
  90. dict): # for list, server only support dict format.
  91. return get_input_type(item)
  92. else:
  93. continue
  94. elif isinstance(task_inputs, dict):
  95. input_info = {} # key input key, value input type
  96. for k, v in task_inputs.items():
  97. input_info[k] = get_input_type(v)
  98. return input_info
  99. else:
  100. raise ValueError(f'invalid input_type definition {task_inputs}')
  101. def get_input_schema(task_name: str, input_type: type):
  102. """Get task input schema.
  103. Args:
  104. task_name (str): The task name.
  105. input_type (type): The input type
  106. """
  107. if input_type is None:
  108. task_inputs = TASK_INPUTS[task_name]
  109. if isinstance(task_inputs,
  110. str): # only one input field, key is task_inputs
  111. return {
  112. 'type': 'object',
  113. 'properties': {
  114. task_inputs: INPUT_TYPE_SCHEMA[task_inputs]
  115. }
  116. }
  117. else:
  118. task_inputs = input_type
  119. if isinstance(task_inputs, str): # no input key
  120. return INPUT_TYPE_SCHEMA[task_inputs]
  121. elif input_type is None and isinstance(task_inputs, list):
  122. for item in task_inputs:
  123. # for list, server only support dict format.
  124. if isinstance(item, dict):
  125. return get_input_schema(None, item)
  126. elif isinstance(task_inputs, tuple) or isinstance(task_inputs, list):
  127. input_schema = {'type': 'array', 'items': {}}
  128. for item in task_inputs:
  129. if isinstance(item, dict):
  130. item_schema = get_input_schema(None, item)
  131. input_schema['items']['type'] = item_schema
  132. return input_schema
  133. else:
  134. input_schema['items'] = INPUT_TYPE_SCHEMA[item]
  135. return input_schema
  136. elif isinstance(task_inputs, dict):
  137. input_schema = {
  138. 'type': 'object',
  139. 'properties': {}
  140. } # key input key, value input type
  141. for k, v in task_inputs.items():
  142. input_schema['properties'][k] = get_input_schema(None, v)
  143. return input_schema
  144. else:
  145. raise ValueError(f'invalid input_type definition {task_inputs}')
  146. def get_output_schema(task_name: str):
  147. """Get task output schema.
  148. Args:
  149. task_name (str): The task name.
  150. """
  151. task_outputs = TASK_OUTPUTS[task_name]
  152. output_schema = {'type': 'object', 'properties': {}}
  153. if not isinstance(task_outputs, list):
  154. raise ValueError('TASK_OUTPUTS for %s is not list.' % task_name)
  155. else:
  156. for output_key in task_outputs:
  157. output_schema['properties'][output_key] = OutputTypeSchema[
  158. output_key]
  159. return output_schema
  160. def get_input_info(task_name: str):
  161. task_inputs = TASK_INPUTS[task_name]
  162. if isinstance(task_inputs, str): # no input key default input key input
  163. input_type = INPUT_TYPE[task_inputs]
  164. return input_type
  165. elif isinstance(task_inputs, tuple):
  166. return task_inputs
  167. elif isinstance(task_inputs, list):
  168. for item in task_inputs:
  169. if isinstance(item,
  170. dict): # for list, server only support dict format.
  171. return {'input': get_input_type(item)}
  172. else:
  173. continue
  174. elif isinstance(task_inputs, dict):
  175. input_info = {} # key input key, value input type
  176. for k, v in task_inputs.items():
  177. input_info[k] = get_input_type(v)
  178. return {'input': input_info}
  179. else:
  180. raise ValueError(f'invalid input_type definition {task_inputs}')
  181. def get_output_info(task_name: str):
  182. output_keys = TASK_OUTPUTS[task_name]
  183. output_type = {}
  184. if not isinstance(output_keys, list):
  185. raise ValueError('TASK_OUTPUTS for %s is not list.' % task_name)
  186. else:
  187. for output_key in output_keys:
  188. output_type[output_key] = OutputTypes[output_key]
  189. return output_type
  190. def get_task_io_info(task_name: str):
  191. """Get task input output schema.
  192. Args:
  193. task_name (str): The task name.
  194. """
  195. tasks = get_class_user_attributes(Tasks)
  196. task_exist = False
  197. for key, value in tasks:
  198. if key == task_name or value == task_name:
  199. task_exist = True
  200. break
  201. if not task_exist:
  202. return None, None
  203. task_inputs = get_input_info(task_name)
  204. task_outputs = get_output_info(task_name)
  205. return task_inputs, task_outputs
  206. def process_arg_type_annotation(arg, default_value):
  207. if arg.annotation is not None:
  208. if isinstance(arg.annotation, ast.Subscript):
  209. return arg.arg, arg.annotation.value.id
  210. elif isinstance(arg.annotation, ast.Name):
  211. return arg.arg, arg.annotation.id
  212. elif isinstance(arg.annotation, ast.Attribute):
  213. return arg.arg, arg.annotation.attr
  214. else:
  215. raise Exception('Invalid annotation: %s' % arg.annotation)
  216. else:
  217. if default_value is not None:
  218. return arg.arg, type(default_value).__name__
  219. # Irregular, assuming no type hint no default value type is object
  220. logger.warning('arg: %s has no data type annotation, use default!' %
  221. (arg.arg))
  222. return arg.arg, 'object'
  223. def convert_to_value(item):
  224. if isinstance(item, ast.Str):
  225. return item.s
  226. elif hasattr(ast, 'Bytes') and isinstance(item, ast.Bytes):
  227. return item.s
  228. elif isinstance(item, ast.Tuple):
  229. return tuple(convert_to_value(i) for i in item.elts)
  230. elif isinstance(item, ast.Num):
  231. return item.n
  232. elif isinstance(item, ast.Name):
  233. result = VariableKey(item=item)
  234. constants_lookup = {
  235. 'True': True,
  236. 'False': False,
  237. 'None': None,
  238. }
  239. return constants_lookup.get(
  240. result.name,
  241. result,
  242. )
  243. elif isinstance(item, ast.NameConstant):
  244. # None, True, False are nameconstants in python3, but names in 2
  245. return item.value
  246. else:
  247. return UnhandledKeyType()
  248. def process_args(args):
  249. arguments = []
  250. # name, type, has_default, default
  251. n_args = len(args.args)
  252. n_args_default = len(args.defaults)
  253. # no default
  254. for arg in args.args[0:n_args - n_args_default]:
  255. if arg.arg == 'self':
  256. continue
  257. else:
  258. arg_name, arg_type = process_arg_type_annotation(arg, None)
  259. arguments.append((arg_name, arg_type, False, None))
  260. # process defaults arg.
  261. for arg, dft in zip(args.args[n_args - n_args_default:], args.defaults):
  262. # compatible with python3.7 ast.Num no value.
  263. value = convert_to_value(dft)
  264. arg_name, arg_type = process_arg_type_annotation(arg, value)
  265. arguments.append((arg_name, arg_type, True, value))
  266. # kwargs
  267. n_kwargs = len(args.kwonlyargs)
  268. n_kwargs_default = len(args.kw_defaults)
  269. for kwarg in args.kwonlyargs[0:n_kwargs - n_kwargs_default]:
  270. arg_name, arg_type = process_arg_type_annotation(kwarg)
  271. arguments.append((arg_name, arg_type, False, None))
  272. for kwarg, dft in zip(args.kwonlyargs[n_kwargs - n_kwargs_default:],
  273. args.kw_defaults):
  274. arg_name, arg_type = process_arg_type_annotation(kwarg)
  275. arguments.append((arg_name, arg_type, True, dft.value))
  276. return arguments
  277. class PipelineClassAnalyzer(ast.NodeVisitor):
  278. """Analysis pipeline class define get inputs and parameters.
  279. """
  280. def __init__(self) -> None:
  281. super().__init__()
  282. self.parameters = []
  283. self.has_call = False
  284. self.preprocess_parameters = []
  285. self.has_preprocess = False
  286. self.has_postprocess = False
  287. self.has_forward = False
  288. self.forward_parameters = []
  289. self.postprocess_parameters = []
  290. self.lineno = 0
  291. self.end_lineno = 0
  292. def visit_FunctionDef(self, node: ast.FunctionDef) -> Any:
  293. if node.name == '__call__':
  294. self.parameters = process_args(node.args)
  295. self.has_call = True
  296. if node.name == 'preprocess':
  297. self.preprocess_parameters = process_args(node.args)
  298. self.has_preprocess = True
  299. elif node.name == 'postprocess':
  300. self.postprocess_parameters = process_args(node.args)
  301. self.has_postprocess = True
  302. elif node.name == 'forward':
  303. self.forward_parameters = process_args(node.args)
  304. self.has_forward = True
  305. def get_input_parameters(self):
  306. if self.has_call:
  307. # custom define __call__ inputs and parameter are control by the
  308. # custom __call__, all parameter is input.
  309. return self.parameters, None
  310. parameters = []
  311. if self.has_preprocess:
  312. parameters.extend(self.preprocess_parameters[1:])
  313. if self.has_forward:
  314. parameters.extend(self.forward_parameters[1:])
  315. if self.has_postprocess:
  316. parameters.extend(self.postprocess_parameters[1:])
  317. if len(parameters) > 0:
  318. return None, parameters
  319. else:
  320. return None, []
  321. class AnalysisSourceFileRegisterModules(ast.NodeVisitor):
  322. """Get register_module call of the python source file.
  323. Args:
  324. ast (NodeVisitor): The ast node.
  325. Examples:
  326. >>> with open(source_file_path, "rb") as f:
  327. >>> src = f.read()
  328. >>> analyzer = AnalysisSourceFileRegisterModules(source_file_path)
  329. >>> analyzer.visit(ast.parse(src, filename=source_file_path))
  330. """
  331. def __init__(self, source_file_path, class_name) -> None:
  332. super().__init__()
  333. self.source_file_path = source_file_path
  334. self.class_name = class_name
  335. self.class_define = None
  336. def visit_ClassDef(self, node: ast.ClassDef):
  337. if node.name == self.class_name:
  338. self.class_define = node
  339. def get_pipeline_input_parameters(
  340. source_file_path: str,
  341. class_name: str,
  342. ):
  343. """Get pipeline input and parameter
  344. Args:
  345. source_file_path (str): The pipeline source code path
  346. class_name (str): The pipeline class name
  347. """
  348. with open(source_file_path, 'rb') as f:
  349. src = f.read()
  350. analyzer = AnalysisSourceFileRegisterModules(source_file_path,
  351. class_name)
  352. analyzer.visit(
  353. ast.parse(
  354. src,
  355. filename=source_file_path,
  356. # python3.7 no type_comments parameter .
  357. # type_comments=True
  358. ))
  359. clz = PipelineClassAnalyzer()
  360. clz.visit(analyzer.class_define)
  361. input, pipeline_parameters = clz.get_input_parameters()
  362. # remove the first input parameter, the input is defined by task.
  363. return input, pipeline_parameters
  364. meta_type_schema_map = {
  365. # For parameters, current only support types.
  366. 'str': 'string',
  367. 'int': 'integer',
  368. 'float': 'number',
  369. 'bool': 'boolean',
  370. 'Dict': 'object',
  371. 'dict': 'object',
  372. 'list': 'array',
  373. 'List': 'array',
  374. 'Union': 'object',
  375. 'Input': 'object',
  376. 'object': 'object',
  377. }
  378. def generate_pipeline_parameters_schema(parameters):
  379. parameters_schema = {'type': 'object', 'properties': {}}
  380. if parameters is None or len(parameters) == 0:
  381. return {}
  382. for param in parameters:
  383. name, param_type, has_default, default_value = param
  384. # 'max_length': ('int', True, 1024)
  385. prop = {'type': meta_type_schema_map[param_type]}
  386. if has_default:
  387. prop['default'] = default_value
  388. parameters_schema['properties'][name] = prop
  389. return parameters_schema
  390. def get_pipeline_information_by_pipeline(pipeline: Pipeline, ):
  391. """Get pipeline input output schema.
  392. Args:
  393. pipeline (Pipeline): The pipeline object.
  394. """
  395. task_name = pipeline.group_key
  396. pipeline_class = pipeline.__class__.__name__
  397. spec = importlib.util.find_spec(pipeline.__module__)
  398. pipeline_file_path = spec.origin
  399. info = PipelineInfomation(task_name, pipeline_class, pipeline_file_path)
  400. return info
  401. class PipelineInfomation():
  402. """Analyze pipeline information, task_name, schema.
  403. """
  404. def __init__(self, task_name: str, class_name, source_path):
  405. self._task_name = task_name
  406. self._class_name = class_name
  407. self._source_path = source_path
  408. self._is_custom_call_method = False
  409. self._analyze()
  410. def _analyze(self):
  411. input, parameters = get_pipeline_input_parameters(
  412. self._source_path, self._class_name)
  413. # use base pipeline __call__ if inputs and outputs are defined in modelscope lib
  414. if self._task_name in TASK_INPUTS and self._task_name in TASK_OUTPUTS:
  415. # delete the first default input which is defined by task.
  416. if parameters is None:
  417. self._parameters_schema = {}
  418. else:
  419. self._parameters_schema = generate_pipeline_parameters_schema(
  420. parameters)
  421. self._input_schema = get_input_schema(self._task_name, None)
  422. self._output_schema = get_output_schema(self._task_name)
  423. elif input is not None: # custom pipeline implemented it's own __call__ method
  424. self._is_custom_call_method = True
  425. self._input_schema = generate_pipeline_parameters_schema(input)
  426. self._input_schema[
  427. 'description'] = 'For binary input such as image audio video, only url is supported.'
  428. self._parameters_schema = {}
  429. self._output_schema = {
  430. 'type': 'object',
  431. }
  432. if self._task_name in TASK_OUTPUTS:
  433. self._output_schema = get_output_schema(self._task_name)
  434. else:
  435. logger.warning(
  436. 'Task: %s input is defined: %s, output is defined: %s which is not completed'
  437. % (self._task_name, self._task_name
  438. in TASK_INPUTS, self._task_name in TASK_OUTPUTS))
  439. self._input_schema = None
  440. self._output_schema = None
  441. if self._task_name in TASK_INPUTS:
  442. self._input_schema = get_input_schema(self._task_name, None)
  443. if self._task_name in TASK_OUTPUTS:
  444. self._output_schema = get_output_schema(self._task_name)
  445. self._parameters_schema = generate_pipeline_parameters_schema(
  446. parameters)
  447. @property
  448. def task_name(self):
  449. return self._task_name
  450. @property
  451. def is_custom_call(self):
  452. return self._is_custom_call_method
  453. @property
  454. def input_schema(self):
  455. return self._input_schema
  456. @property
  457. def output_schema(self):
  458. return self._output_schema
  459. @property
  460. def parameters_schema(self):
  461. return self._parameters_schema
  462. @property
  463. def schema(self):
  464. return {
  465. 'input': self._input_schema if self._input_schema else
  466. self._parameters_schema, # all parameter is input
  467. 'parameters':
  468. self._parameters_schema if self._input_schema else {},
  469. 'output': self._output_schema if self._output_schema else {
  470. 'type': 'object',
  471. },
  472. }
  473. def __getitem__(self, key):
  474. return self.__dict__.get('_%s' % key)
  475. def is_url(url: str):
  476. """Check the input url is valid url.
  477. Args:
  478. url (str): The url
  479. Returns:
  480. bool: If is url return True, otherwise False.
  481. """
  482. url_parsed = urlparse(url)
  483. if url_parsed.scheme in ('http', 'https', 'oss'):
  484. return True
  485. else:
  486. return False
  487. def decode_base64_to_image(content):
  488. if content.startswith('http') or content.startswith(
  489. 'oss') or os.path.exists(content):
  490. return content
  491. from PIL import Image
  492. image_file_content = base64.b64decode(content, '-_')
  493. return Image.open(BytesIO(image_file_content))
  494. def decode_base64_to_audio(content):
  495. if content.startswith('http') or content.startswith(
  496. 'oss') or os.path.exists(content):
  497. return content
  498. file_content = base64.b64decode(content)
  499. return file_content
  500. def decode_base64_to_video(content):
  501. if content.startswith('http') or content.startswith(
  502. 'oss') or os.path.exists(content):
  503. return content
  504. file_content = base64.b64decode(content)
  505. return file_content
  506. def return_origin(content):
  507. return content
  508. def decode_box(content):
  509. pass
  510. def service_multipart_input_to_pipeline_input(body):
  511. """Convert multipart data to pipeline input.
  512. Args:
  513. body (dict): The multipart data body
  514. """
  515. pass
  516. def pipeline_output_to_service_multipart_output(output):
  517. """Convert multipart data to service multipart output.
  518. Args:
  519. output (dict): Multipart body.
  520. """
  521. pass
# Maps an InputType to the callable applied to a raw service value of that
# type before it is handed to the pipeline.
# IMAGE/AUDIO/VIDEO use the decode_base64_to_* helpers; TEXT, DICT, LIST and
# NUMBER values are passed through unchanged; BOX maps to decode_box, which
# is currently a stub returning None.
base64_decoder_map = {
    InputType.IMAGE: decode_base64_to_image,
    InputType.TEXT: return_origin,
    InputType.AUDIO: decode_base64_to_audio,
    InputType.VIDEO: decode_base64_to_video,
    InputType.BOX: decode_box,
    InputType.DICT: return_origin,
    InputType.LIST: return_origin,
    InputType.NUMBER: return_origin,
}
  532. def call_pipeline_with_json(pipeline_info: PipelineInfomation,
  533. pipeline: Pipeline, body: str):
  534. """Call pipeline with json input.
  535. Args:
  536. pipeline_info (PipelineInfomation): The pipeline information object.
  537. pipeline (Pipeline): The pipeline object.
  538. body (Dict): The input object, include input and parameters
  539. """
  540. # TODO: is_custom_call misjudgment
  541. # if pipeline_info.is_custom_call:
  542. # pipeline_inputs = body['input']
  543. # result = pipeline(**pipeline_inputs)
  544. # else:
  545. pipeline_inputs, parameters = service_base64_input_to_pipeline_input(
  546. pipeline_info['task_name'], body)
  547. result = pipeline(pipeline_inputs, **parameters)
  548. return result
  549. def service_base64_input_to_pipeline_input(task_name, body):
  550. """Convert service base64 input to pipeline input and parameters
  551. Args:
  552. task_name (str): The task name.
  553. body (Dict): The input object, include input and parameters
  554. """
  555. if 'input' not in body:
  556. raise ValueError('No input data!')
  557. service_input = body['input']
  558. if 'parameters' in body:
  559. parameters = body['parameters']
  560. else:
  561. parameters = {}
  562. pipeline_input = {}
  563. if isinstance(service_input, (str, int, float)):
  564. return service_input, parameters
  565. task_input_info = TASK_INPUTS.get(task_name, None)
  566. if isinstance(task_input_info, str): # no input key default
  567. if isinstance(service_input, dict):
  568. return base64_decoder_map[task_input_info](list(
  569. service_input.values())[0]), parameters
  570. else:
  571. return base64_decoder_map[task_input_info](
  572. service_input), parameters
  573. elif isinstance(task_input_info, tuple):
  574. pipeline_input = tuple(service_input)
  575. return pipeline_input, parameters
  576. elif isinstance(task_input_info, dict):
  577. for key, value in service_input.items(
  578. ): # task input has no nesting field.
  579. # get input filed type
  580. input_type = task_input_info[key]
  581. # TODO recursion for list, dict if need.
  582. if not isinstance(input_type, str):
  583. pipeline_input[key] = value
  584. continue
  585. if input_type not in INPUT_TYPE:
  586. raise ValueError('Invalid input field: %s' % input_type)
  587. pipeline_input[key] = base64_decoder_map[input_type](value)
  588. return pipeline_input, parameters
  589. elif isinstance(task_input_info,
  590. list): # one of input format, we use dict.
  591. for item in task_input_info:
  592. if isinstance(item, dict):
  593. for key, value in service_input.items(
  594. ): # task input has no nesting field.
  595. # get input filed type
  596. input_type = item[key]
  597. if input_type not in INPUT_TYPE:
  598. raise ValueError('Invalid input field: %s'
  599. % input_type)
  600. pipeline_input[key] = base64_decoder_map[input_type](value)
  601. return pipeline_input, parameters
  602. else:
  603. return service_input, parameters
  604. def encode_numpy_image_to_base64(image):
  605. import cv2
  606. _, img_encode = cv2.imencode('.png', image)
  607. bytes_data = img_encode.tobytes()
  608. base64_str = str(base64.b64encode(bytes_data), 'utf-8')
  609. return base64_str
  610. def encode_video_to_base64(video):
  611. return str(base64.b64encode(video), 'utf-8')
  612. def encode_pcm_to_base64(pcm):
  613. return str(base64.b64encode(pcm), 'utf-8')
  614. def encode_wav_to_base64(wav):
  615. return str(base64.b64encode(wav), 'utf-8')
  616. def encode_bytes_to_base64(bts):
  617. return str(base64.b64encode(bts), 'utf-8')
# Maps an output type name to the helper that base64-encodes a value of that
# type. pipeline_output_to_service_base64_output indexes this map with
# OutputTypes[...] entries.
base64_encoder_map = {
    'image': encode_numpy_image_to_base64,
    'video': encode_video_to_base64,
    'pcm': encode_pcm_to_base64,
    'wav': encode_wav_to_base64,
    'bytes': encode_bytes_to_base64,
}
  625. # convert numpy etc type to python type.
  626. type_to_python_type = {
  627. np.int64: int,
  628. }
  629. def _convert_to_python_type(inputs):
  630. if isinstance(inputs, (list, tuple)):
  631. res = []
  632. for item in inputs:
  633. res.append(_convert_to_python_type(item))
  634. return res
  635. elif isinstance(inputs, dict):
  636. res = {}
  637. for k, v in inputs.items():
  638. if type(v) in type_to_python_type:
  639. res[k] = type_to_python_type[type(v)](v)
  640. else:
  641. res[k] = _convert_to_python_type(v)
  642. return res
  643. elif isinstance(inputs, np.ndarray):
  644. return inputs.tolist()
  645. elif isinstance(inputs, np.floating):
  646. return float(inputs)
  647. elif isinstance(inputs, np.integer):
  648. return int(inputs)
  649. else:
  650. return inputs
  651. def pipeline_output_to_service_base64_output(task_name, pipeline_output):
  652. """Convert pipeline output to service output,
  653. convert binary fields to base64 encoding。
  654. Args:
  655. task_name (str): The output task name.
  656. pipeline_output (object): The pipeline output.
  657. """
  658. json_serializable_output = {}
  659. task_outputs = TASK_OUTPUTS.get(task_name, [])
  660. # TODO: for batch
  661. if isinstance(pipeline_output, list):
  662. pipeline_output = pipeline_output[0]
  663. for key, value in pipeline_output.items():
  664. if key not in task_outputs:
  665. import torch
  666. if isinstance(value, torch.Tensor):
  667. v = np.array(value.cpu()).tolist()
  668. else:
  669. v = value
  670. json_serializable_output[key] = v
  671. continue # skip the output not defined.
  672. if key in [
  673. OutputKeys.OUTPUT_IMG, OutputKeys.OUTPUT_IMGS,
  674. OutputKeys.OUTPUT_VIDEO, OutputKeys.OUTPUT_PCM,
  675. OutputKeys.OUTPUT_PCM_LIST, OutputKeys.OUTPUT_WAV
  676. ]:
  677. if isinstance(value, list):
  678. items = []
  679. if key == OutputKeys.OUTPUT_IMGS:
  680. output_item_type = OutputKeys.OUTPUT_IMG
  681. else:
  682. output_item_type = OutputKeys.OUTPUT_PCM
  683. for item in value:
  684. items.append(base64_encoder_map[
  685. OutputTypes[output_item_type]](item))
  686. json_serializable_output[key] = items
  687. else:
  688. json_serializable_output[key] = base64_encoder_map[
  689. OutputTypes[key]](
  690. value)
  691. elif OutputTypes[key] in [np.ndarray] and isinstance(
  692. value, np.ndarray):
  693. json_serializable_output[key] = value.tolist()
  694. elif isinstance(value, np.ndarray):
  695. json_serializable_output[key] = value.tolist()
  696. else:
  697. json_serializable_output[key] = value
  698. return _convert_to_python_type(json_serializable_output)
  699. def get_task_input_examples(task):
  700. current_work_dir = os.path.dirname(__file__)
  701. with open(current_work_dir + '/pipeline_inputs.json', 'r') as f:
  702. input_examples = json.load(f)
  703. if task in input_examples:
  704. return input_examples[task]
  705. return None
  706. def get_task_schemas(task):
  707. current_work_dir = os.path.dirname(__file__)
  708. with open(current_work_dir + '/pipeline_schema.json', 'r') as f:
  709. schema = json.load(f)
  710. if task in schema:
  711. return schema[task]
  712. return None
  713. if __name__ == '__main__':
  714. from modelscope.utils.ast_utils import load_index
  715. index = load_index()
  716. task_schemas = {}
  717. for key, value in index['index'].items():
  718. reg, task_name, class_name = key
  719. if reg == 'PIPELINES' and task_name != 'default':
  720. print(
  721. f"value['filepath']: {value['filepath']}, class_name: {class_name}"
  722. )
  723. input, parameters = get_pipeline_input_parameters(
  724. value['filepath'], class_name)
  725. try:
  726. if task_name in TASK_INPUTS and task_name in TASK_OUTPUTS:
  727. # delete the first default input which is defined by task.
  728. # parameters.pop(0)
  729. parameters_schema = generate_pipeline_parameters_schema(
  730. parameters)
  731. input_schema = get_input_schema(task_name, None)
  732. output_schema = get_output_schema(task_name)
  733. schema = {
  734. 'input': input_schema,
  735. 'parameters': parameters_schema,
  736. 'output': output_schema
  737. }
  738. else:
  739. logger.warning(
  740. 'Task: %s input is defined: %s, output is defined: %s which is not completed'
  741. % (task_name, task_name in TASK_INPUTS, task_name
  742. in TASK_OUTPUTS))
  743. input_schema = None
  744. output_schema = None
  745. if task_name in TASK_INPUTS:
  746. input_schema = get_input_schema(task_name, None)
  747. if task_name in TASK_OUTPUTS:
  748. output_schema = get_output_schema(task_name)
  749. parameters_schema = generate_pipeline_parameters_schema(
  750. parameters)
  751. schema = {
  752. 'input': input_schema if input_schema else
  753. parameters_schema, # all parameter is input
  754. 'parameters':
  755. parameters_schema if input_schema else {},
  756. 'output': output_schema if output_schema else {
  757. 'type': 'object',
  758. },
  759. }
  760. except BaseException:
  761. continue
  762. task_schemas[task_name] = schema
  763. s = json.dumps(task_schemas)
  764. with open('./task_schema.json', 'w') as f:
  765. f.write(s)