sudoku.py 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. from typing import Any, Dict
  3. import numpy as np
  4. import torch
  5. from modelscope.utils.constant import ModeKeys
  6. from .base import OfaBasePreprocessor
  class OfaSudokuPreprocessor(OfaBasePreprocessor):
      r"""
      OFA preprocessor for sudoku tasks.

      Turns a textual sudoku puzzle (plus, optionally, its solution) into the
      token tensors consumed by the model. Tokenization, the ``bos_item`` /
      ``eos_item`` tensors, ``column_map`` and the ``max_*_length`` limits are
      taken from ``OfaBasePreprocessor`` (defined outside this file).
      """

      def __init__(self,
                   cfg,
                   model_dir,
                   mode=ModeKeys.INFERENCE,
                   *args,
                   **kwargs):
          """preprocess the data

          Args:
              cfg(modelscope.utils.config.ConfigDict) : model config
              model_dir (str): model path,
              mode: preprocessor mode (model mode)
          """
          super(OfaSudokuPreprocessor, self).__init__(cfg, model_dir, mode,
                                                      *args, **kwargs)

          # Prompt appended to every puzzle text before tokenization.
          self.instruction_text = self.cfg.model.get('prompt',
                                                     ' solve the sudoku .')
          self.seg_embedding = self.cfg.get('seg_embedding', False)
          # Maximum number of whitespace-separated puzzle tokens kept.
          self.max_struct_length = self.cfg.get('max_struct_length', 256)

          if self.seg_embedding:
              # One id per puzzle token: every cell gets its 1-based position
              # (``jdx + 1`` in the "row" sequence, ``idx + 1`` in the "col"
              # sequence) and every separator between two cells gets 0.
              # There is no separator after the very last cell, so each
              # sequence holds 81 cell ids + 80 separator ids.
              row_ids = []
              col_ids = []
              for idx in range(9):
                  for jdx in range(9):
                      row_ids.append(jdx + 1)
                      col_ids.append(idx + 1)
                      if not (idx == 8 and jdx == 8):
                          row_ids.append(0)
                          col_ids.append(0)

              # The instruction tokens carry no grid position: all zeros.
              instruct_seg = torch.zeros_like(
                  self.tokenize_text(self.instruction_text))

              # Final layout: bos | puzzle ids | instruction zeros | eos.
              self.input_puzzle_col = torch.cat([
                  self.bos_item,
                  torch.tensor(col_ids), instruct_seg, self.eos_item
              ])
              self.input_puzzle_row = torch.cat([
                  self.bos_item,
                  torch.tensor(row_ids), instruct_seg, self.eos_item
              ])

      def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
          """Build a training sample in TRAIN mode, else an inference sample.

          Args:
              data: one raw record, keyed by the dataset's column names.

          Returns:
              The sample dict produced by ``_build_train_sample`` or
              ``_build_infer_sample``.
          """
          if self.mode == ModeKeys.TRAIN:
              return self._build_train_sample(data)
          return self._build_infer_sample(data)

      def _build_train_sample(self, data: Dict[str, Any]) -> Dict[str, Any]:
          r"""
          build sample for training tasks.

          step 1. execute the `_build_infer_sample` function to get a batch
              sample for inference.
          step 2. process the label data for training.

          Raises:
              KeyError: if the record carries no solution column, so no
                  ``label`` could be attached to the sample.
          """
          sample = self._build_infer_sample(data)
          if 'label' not in sample:
              # Same exception type as the former bare ``sample['label']``
              # lookup, but with an actionable message.
              raise KeyError(
                  'training needs a `solution` column in the task key map '
                  'and in the source data')
          target = sample['label']
          # Normalise whitespace/case and cap the number of target tokens.
          target_token_list = target.lower().strip().split()
          target = ' '.join(target_token_list[:self.max_tgt_length])
          sample['target'] = self.tokenize_text(target, add_bos=False)
          # Decoder input: bos followed by the target without its last token.
          sample['prev_output_tokens'] = torch.cat(
              [self.bos_item, sample['target'][:-1]])
          return sample

      def _build_infer_sample(self, data: Dict[str, Any]) -> Dict[str, Any]:
          r"""
          build sample for inference tasks.

          step 1. Get the input random masked sudoku text input, which should
              be generated like the pseudo code below.
              >>> sudo = np.random.randint(1, 10, size=(9, 9))  # pseudo sudoku
              >>> sudo_text = " | ".join(
              ...     " : ".join(str(c) for c in row) for row in sudo)
          step 2. Limit the length, tokenize the input text and add the bos
              token to the front of the input as source input.
          step 3. Add a pseudo id for every input.

          Args:
              data: one raw record; the puzzle text is read from the column
                  that ``column_map['text']`` points to, the optional
                  solution from the column ``column_map['solution']`` points
                  to.

          Returns:
              A dict with ``id`` and ``source``; plus ``seg_row_tokens`` /
              ``seg_col_tokens`` when segment embeddings are enabled, and
              ``label`` when a solution is present in ``data``.
          """
          assert 'text' in self.column_map, \
              'there must be `text` column in task key map and source data'
          text_key = self.column_map['text']
          # Check the *mapped* column name (as done for `solution` below), so
          # a renamed text column is accepted and a missing one is reported
          # here instead of surfacing as a bare KeyError.
          assert text_key in data, \
              'there must be `text` column in task key map and source data'
          text = data[text_key]
          text = ' '.join(
              text.lower().strip().split()[:self.max_struct_length])
          src_item = self.tokenize_text(text + self.instruction_text)
          src_item = src_item[:(self.max_src_length + self.max_struct_length)]
          # `id` is a constant placeholder; no per-sample id is available.
          sample = {'id': 0.0, 'source': src_item}
          if self.seg_embedding:
              sample['seg_row_tokens'] = self.input_puzzle_row
              sample['seg_col_tokens'] = self.input_puzzle_col
          if 'solution' in self.column_map and self.column_map[
                  'solution'] in data:
              sample['label'] = ' {}'.format(data[self.column_map['solution']])
          return sample