utils_dst.py 1.0 KB

123456789101112131415161718192021222324252627282930313233343536
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. from typing import List
  3. from modelscope.outputs import OutputKeys
  4. from modelscope.pipelines.nlp import DialogStateTrackingPipeline
  5. def tracking_and_print_dialog_states(
  6. test_case, pipelines: List[DialogStateTrackingPipeline]):
  7. import json
  8. pipelines_len = len(pipelines)
  9. history_states = [{}]
  10. utter = {}
  11. for step, item in enumerate(test_case):
  12. utter.update(item)
  13. result = pipelines[step % pipelines_len]({
  14. 'utter':
  15. utter,
  16. 'history_states':
  17. history_states
  18. })
  19. print(json.dumps(result))
  20. history_states.extend([result[OutputKeys.OUTPUT], {}])
  21. def batch_to_device(batch, device):
  22. batch_on_device = []
  23. for element in batch:
  24. if isinstance(element, dict):
  25. batch_on_device.append(
  26. {k: v.to(device)
  27. for k, v in element.items()})
  28. else:
  29. batch_on_device.append(element.to(device))
  30. return tuple(batch_on_device)