base.py 3.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107
  1. # Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import abc
  15. from paddlex import create_predictor
  16. from paddlex.utils.deps import DependencyError
  17. from .._abstract import CLISubcommandExecutor
  18. from .._common_args import (
  19. add_common_cli_opts,
  20. parse_common_args,
  21. prepare_common_init_args,
  22. )
  23. _DEFAULT_ENABLE_HPI = False
  24. class PaddleXPredictorWrapper(metaclass=abc.ABCMeta):
  25. def __init__(
  26. self,
  27. *,
  28. model_name=None,
  29. model_dir=None,
  30. **common_args,
  31. ):
  32. super().__init__()
  33. self._model_name = (
  34. model_name if model_name is not None else self.default_model_name
  35. )
  36. self._model_dir = model_dir
  37. self._common_args = parse_common_args(
  38. common_args, default_enable_hpi=_DEFAULT_ENABLE_HPI
  39. )
  40. self.paddlex_predictor = self._create_paddlex_predictor()
  41. @property
  42. @abc.abstractmethod
  43. def default_model_name(self):
  44. raise NotImplementedError
  45. def predict_iter(self, *args, **kwargs):
  46. return self.paddlex_predictor.predict(*args, **kwargs)
  47. def predict(self, *args, **kwargs):
  48. result = list(self.predict_iter(*args, **kwargs))
  49. return result
  50. def close(self):
  51. self.paddlex_predictor.close()
  52. @classmethod
  53. @abc.abstractmethod
  54. def get_cli_subcommand_executor(cls):
  55. raise NotImplementedError
  56. def _get_extra_paddlex_predictor_init_args(self):
  57. return {}
  58. def _create_paddlex_predictor(self):
  59. kwargs = prepare_common_init_args(self._model_name, self._common_args)
  60. kwargs = {**self._get_extra_paddlex_predictor_init_args(), **kwargs}
  61. # Should we check model names?
  62. try:
  63. return create_predictor(
  64. model_name=self._model_name, model_dir=self._model_dir, **kwargs
  65. )
  66. except DependencyError as e:
  67. raise RuntimeError(
  68. "A dependency error occurred during predictor creation. Please refer to the installation documentation to ensure all required dependencies are installed."
  69. ) from e
  70. class PredictorCLISubcommandExecutor(CLISubcommandExecutor):
  71. @property
  72. @abc.abstractmethod
  73. def subparser_name(self):
  74. raise NotImplementedError
  75. def add_subparser(self, subparsers):
  76. subparser = subparsers.add_parser(name=self.subparser_name)
  77. self._update_subparser(subparser)
  78. subparser.add_argument("--model_name", type=str, help="Name of the model.")
  79. subparser.add_argument(
  80. "--model_dir", type=str, help="Directory where the model is stored."
  81. )
  82. add_common_cli_opts(
  83. subparser,
  84. default_enable_hpi=_DEFAULT_ENABLE_HPI,
  85. allow_multiple_devices=False,
  86. )
  87. return subparser
  88. @abc.abstractmethod
  89. def _update_subparser(self, subparser):
  90. raise NotImplementedError