testing_utils.py 833 B

123456789101112131415161718192021222324252627282930313233343536
  1. from pathlib import Path
  2. TEST_DATA_DIR = Path(__file__).parent / "test_files"
  3. def check_simple_inference_result(result, *, expected_length=1):
  4. assert result is not None
  5. assert isinstance(result, list)
  6. assert len(result) == expected_length
  7. for res in result:
  8. assert isinstance(res, dict)
  9. def check_wrapper_simple_inference_param_forwarding(
  10. monkeypatch,
  11. wrapper,
  12. wrapped_obj_attr_name,
  13. input,
  14. params,
  15. ):
  16. def _dummy_predict(input, **params):
  17. yield params
  18. monkeypatch.setattr(
  19. getattr(wrapper, wrapped_obj_attr_name), "predict", _dummy_predict
  20. )
  21. result = getattr(wrapper, "predict")(
  22. input,
  23. **params,
  24. )
  25. assert isinstance(result, list)
  26. assert len(result) == 1
  27. for k, v in params.items():
  28. assert result[0][k] == v