_object_detection.py 2.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687
  1. # Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import abc
  15. from .._utils.cli import (
  16. add_simple_inference_args,
  17. get_subcommand_args,
  18. perform_simple_inference,
  19. str2bool,
  20. )
  21. from .base import PaddleXPredictorWrapper, PredictorCLISubcommandExecutor
  22. class ObjectDetection(PaddleXPredictorWrapper):
  23. def __init__(
  24. self,
  25. *,
  26. img_size=None,
  27. threshold=None,
  28. layout_nms=None,
  29. layout_unclip_ratio=None,
  30. layout_merge_bboxes_mode=None,
  31. **kwargs,
  32. ):
  33. self._extra_init_args = {
  34. "img_size": img_size,
  35. "threshold": threshold,
  36. "layout_nms": layout_nms,
  37. "layout_unclip_ratio": layout_unclip_ratio,
  38. "layout_merge_bboxes_mode": layout_merge_bboxes_mode,
  39. }
  40. super().__init__(**kwargs)
  41. def _get_extra_paddlex_predictor_init_args(self):
  42. return self._extra_init_args
  43. class ObjectDetectionSubcommandExecutor(PredictorCLISubcommandExecutor):
  44. def _update_subparser(self, subparser):
  45. add_simple_inference_args(subparser)
  46. subparser.add_argument(
  47. "--img_size",
  48. type=int,
  49. help="Input image size (w, h).",
  50. )
  51. subparser.add_argument(
  52. "--threshold",
  53. type=float,
  54. help="Threshold for filtering out low-confidence predictions.",
  55. )
  56. subparser.add_argument(
  57. "--layout_nms",
  58. type=str2bool,
  59. help="Whether to use layout-aware NMS.",
  60. )
  61. subparser.add_argument(
  62. "--layout_unclip_ratio",
  63. type=float,
  64. help="Ratio of unclipping the bounding box.",
  65. )
  66. subparser.add_argument(
  67. "--layout_merge_bboxes_mode",
  68. type=str,
  69. help="Mode for merging bounding boxes.",
  70. )
  71. @property
  72. @abc.abstractmethod
  73. def wrapper_cls(self):
  74. raise NotImplementedError
  75. def execute_with_args(self, args):
  76. params = get_subcommand_args(args)
  77. perform_simple_inference(self.wrapper_cls, params)