io_utils.py 2.9 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import os
  15. import pickle
  16. import warnings
  17. import paddle
  18. from paddle import pir
  19. from paddle.base import (
  20. CompiledProgram,
  21. Variable,
  22. )
  23. def _check_args(caller, args, supported_args=None, deprecated_args=None):
  24. supported_args = [] if supported_args is None else supported_args
  25. deprecated_args = [] if deprecated_args is None else deprecated_args
  26. for arg in args:
  27. if arg in deprecated_args:
  28. raise ValueError(
  29. f"argument '{arg}' in function '{caller}' is deprecated, only {supported_args} are supported."
  30. )
  31. elif arg not in supported_args:
  32. raise ValueError(
  33. f"function '{caller}' doesn't support argument '{arg}',\n only {supported_args} are supported."
  34. )
  35. def _check_vars(name, var_list):
  36. if not isinstance(var_list, list):
  37. var_list = [var_list]
  38. if not all(isinstance(var, (Variable, pir.Value)) for var in var_list):
  39. raise ValueError(
  40. f"'{name}' should be a Variable or a list of Variable."
  41. )
  42. def _normalize_path_prefix(path_prefix):
  43. """
  44. convert path_prefix to absolute path.
  45. """
  46. if not isinstance(path_prefix, str):
  47. raise ValueError("'path_prefix' should be a string.")
  48. if path_prefix.endswith("/"):
  49. raise ValueError("'path_prefix' should not be a directory")
  50. path_prefix = os.path.normpath(path_prefix)
  51. path_prefix = os.path.abspath(path_prefix)
  52. return path_prefix
  53. def _get_valid_program(program=None):
  54. """
  55. return default main program if program is None.
  56. """
  57. if program is None:
  58. program = paddle.static.default_main_program()
  59. elif isinstance(program, CompiledProgram):
  60. program = program._program
  61. if program is None:
  62. raise TypeError(
  63. "The type of input program is invalid, expected type is Program, but received None"
  64. )
  65. warnings.warn(
  66. "The input is a CompiledProgram, this is not recommended."
  67. )
  68. if not isinstance(program, paddle.static.Program):
  69. raise TypeError(
  70. "The type of input program is invalid, expected type is base.Program, but received %s"
  71. % type(program)
  72. )
  73. return program
  74. def _safe_load_pickle(file, encoding="ASCII"):
  75. load_dict = pickle.Unpickler(file, encoding=encoding).load()
  76. return load_dict