base_head.py 1.1 KB

12345678910111213141516171819202122232425262728293031323334353637383940
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. from abc import ABC, abstractmethod
  3. from typing import Any, Dict, Union
  4. from modelscope.models.base.base_model import Model
  5. from modelscope.utils.config import ConfigDict
  6. from modelscope.utils.logger import get_logger
  7. logger = get_logger()
  8. Tensor = Union['torch.Tensor', 'tf.Tensor']
  9. Input = Union[Dict[str, Tensor], Model]
  10. class Head(ABC):
  11. """The head base class is for the tasks head method definition
  12. """
  13. def __init__(self, **kwargs):
  14. self.config = ConfigDict(kwargs)
  15. @abstractmethod
  16. def forward(self, *args, **kwargs) -> Dict[str, Any]:
  17. """
  18. This method will use the output from backbone model to do any
  19. downstream tasks. Receive The output from backbone model.
  20. Returns (Dict[str, Any]): The output from downstream task.
  21. """
  22. pass
  23. @abstractmethod
  24. def compute_loss(self, *args, **kwargs) -> Dict[str, Any]:
  25. """
  26. compute loss for head during the finetuning.
  27. Returns (Dict[str, Any]): The loss dict
  28. """
  29. pass