cli.py 2.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475
  1. # Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import time
  15. from .logging import logger
  16. def str2bool(v, /):
  17. return v.lower() in ("true", "yes", "t", "y", "1")
  18. def get_subcommand_args(args):
  19. args = vars(args).copy()
  20. args.pop("subcommand")
  21. args.pop("executor")
  22. return args
  23. def add_simple_inference_args(subparser, *, input_help=None):
  24. if input_help is None:
  25. input_help = "Input path or URL."
  26. subparser.add_argument(
  27. "-i",
  28. "--input",
  29. type=str,
  30. required=True,
  31. help=input_help,
  32. )
  33. subparser.add_argument(
  34. "--save_path",
  35. type=str,
  36. help="Path to the output directory.",
  37. )
  38. def perform_simple_inference(wrapper_cls, params, predict_param_names=None):
  39. params = params.copy()
  40. input_ = params.pop("input")
  41. save_path = params.pop("save_path")
  42. if predict_param_names is not None:
  43. predict_params = {}
  44. for name in predict_param_names:
  45. predict_params[name] = params.pop(name)
  46. else:
  47. predict_params = {}
  48. init_params = params
  49. wrapper = wrapper_cls(**init_params)
  50. try:
  51. result = wrapper.predict_iter(input_, **predict_params)
  52. t1 = time.time()
  53. for i, res in enumerate(result):
  54. logger.info(f"Processed item {i} in {(time.time()-t1) * 1000} ms")
  55. t1 = time.time()
  56. res.print()
  57. if save_path:
  58. res.save_all(save_path)
  59. finally:
  60. wrapper.close()