api_server.py 2.1 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859
import argparse
import sys
  2. def add_server_args(parser: argparse.ArgumentParser):
  3. parser.add_argument(
  4. '--model_id', required=True, type=str, help='The target model id')
  5. parser.add_argument(
  6. '--revision', required=True, type=str, help='Model revision')
  7. parser.add_argument('--host', default='0.0.0.0', help='Host to listen')
  8. parser.add_argument('--port', type=int, default=8000, help='Server port')
  9. parser.add_argument('--debug', default='debug', help='Set debug level.')
  10. parser.add_argument(
  11. '--external_engine_for_llm',
  12. type=bool,
  13. default=True,
  14. help='Use LLMPipeline first for llm models.')
  15. def run_server(args):
  16. try:
  17. import uvicorn
  18. app = get_app(args)
  19. uvicorn.run(app, host=args.host, port=args.port)
  20. except ModuleNotFoundError as e:
  21. print(e)
  22. print(
  23. 'To execute the server command, first '
  24. 'install the domain dependencies with: '
  25. 'pip install modelscope[DOMAIN] -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html '
  26. 'the "DOMAIN" include [cv|nlp|audio|multi-modal|science] '
  27. 'and then install server dependencies with: pip install modelscope[server]'
  28. )
  29. def get_app(args):
  30. from fastapi import FastAPI
  31. from modelscope.server.api.routers.router import api_router
  32. from modelscope.server.core.event_handlers import (start_app_handler,
  33. stop_app_handler)
  34. app = FastAPI(
  35. title='modelscope_server',
  36. version='0.1',
  37. debug=True,
  38. swagger_ui_parameters={'tryItOutEnabled': True})
  39. app.state.args = args
  40. app.include_router(api_router)
  41. app.add_event_handler('startup', start_app_handler(app))
  42. app.add_event_handler('shutdown', stop_app_handler(app))
  43. return app
  44. if __name__ == '__main__':
  45. import uvicorn
  46. parser = argparse.ArgumentParser('modelscope_server')
  47. add_server_args(parser)
  48. args = parser.parse_args()
  49. app = get_app(args)
  50. uvicorn.run(app, host=args.host, port=args.port)