# task_utils.py (2.5 KB)
  1. from modelscope.metainfo import TaskModels
  2. from modelscope.utils import registry
  3. from modelscope.utils.constant import Tasks
  4. SUB_TASKS = 'sub_tasks'
  5. PARENT_TASK = 'parent_task'
  6. TASK_MODEL = 'task_model'
  7. DEFAULT_TASKS_LEVEL = {
  8. Tasks.text_classification: {
  9. SUB_TASKS: [
  10. Tasks.text_classification,
  11. Tasks.sentence_similarity,
  12. Tasks.sentiment_classification,
  13. Tasks.sentiment_analysis,
  14. Tasks.nli,
  15. ],
  16. TASK_MODEL:
  17. TaskModels.text_classification,
  18. },
  19. Tasks.token_classification: {
  20. SUB_TASKS: [
  21. Tasks.token_classification,
  22. Tasks.named_entity_recognition,
  23. Tasks.word_segmentation,
  24. Tasks.part_of_speech,
  25. ],
  26. TASK_MODEL:
  27. TaskModels.text_classification,
  28. },
  29. Tasks.token_classification: {
  30. SUB_TASKS: [
  31. Tasks.token_classification,
  32. Tasks.named_entity_recognition,
  33. Tasks.word_segmentation,
  34. Tasks.part_of_speech,
  35. ],
  36. TASK_MODEL:
  37. TaskModels.text_classification,
  38. },
  39. Tasks.text_generation: {
  40. SUB_TASKS: [
  41. Tasks.text_generation,
  42. Tasks.text2text_generation,
  43. ],
  44. TASK_MODEL: TaskModels.text_generation,
  45. },
  46. Tasks.information_extraction: {
  47. SUB_TASKS: [
  48. Tasks.information_extraction,
  49. Tasks.relation_extraction,
  50. ],
  51. TASK_MODEL: TaskModels.information_extraction,
  52. },
  53. Tasks.fill_mask: {
  54. SUB_TASKS: [
  55. Tasks.fill_mask,
  56. ],
  57. TASK_MODEL: TaskModels.fill_mask,
  58. },
  59. Tasks.text_ranking: {
  60. SUB_TASKS: [
  61. Tasks.text_ranking,
  62. ],
  63. TASK_MODEL: TaskModels.text_ranking,
  64. }
  65. # TODO: add other tasks with their sub tasks in different domains
  66. }
  67. def _inverted_index(forward_index):
  68. inverted_index = dict()
  69. for index in forward_index:
  70. for item in forward_index[index][SUB_TASKS]:
  71. inverted_index[item] = {
  72. PARENT_TASK: index,
  73. TASK_MODEL: forward_index[index][TASK_MODEL],
  74. }
  75. return inverted_index
  76. INVERTED_TASKS_LEVEL = _inverted_index(DEFAULT_TASKS_LEVEL)
  77. def is_embedding_task(task: str):
  78. return task == Tasks.sentence_embedding
  79. def get_task_by_subtask_name(group_key):
  80. if group_key in INVERTED_TASKS_LEVEL:
  81. return INVERTED_TASKS_LEVEL[group_key][
  82. PARENT_TASK], INVERTED_TASKS_LEVEL[group_key][TASK_MODEL]
  83. else:
  84. return group_key, None