cv_outputs.py 908 B

1234567891011121314151617181920212223242526272829
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. from dataclasses import dataclass
  3. from typing import Optional, Tuple, Union
  4. import numpy as np
  5. from modelscope.outputs.outputs import ModelOutputBase
  6. Tensor = Union['torch.Tensor', 'tf.Tensor']
  7. @dataclass
  8. class DetectionOutput(ModelOutputBase):
  9. """The output class for object detection models.
  10. Args:
  11. class_ids (`Tensor`, *optional*): class id for each object.
  12. boxes (`Tensor`, *optional*): Bounding box for each detected object in [left, top, right, bottom] format.
  13. scores (`Tensor`, *optional*): Detection score for each object.
  14. keypoints (`Tensor`, *optional*): Keypoints for each object using four corner points in a 8-dim tensor
  15. in the order of (x, y) for each corner point.
  16. """
  17. class_ids: Tensor = None
  18. scores: Tensor = None
  19. boxes: Tensor = None
  20. keypoints: Tensor = None