io.py 3.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100
  1. # Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from .. import core
  15. from ..executor import global_scope
  16. from ..framework import default_main_program, default_startup_program
  17. from ..unique_name import generate as unique_name
# Deliberately empty: `from <this module> import *` binds no names, so the
# reader helpers below are only reachable through explicit imports.
__all__ = []
  19. def monkey_patch_reader_methods(reader):
  20. def __get_reader__():
  21. scope = global_scope()
  22. var = scope.find_var(reader.name)
  23. return var.get_reader()
  24. def reset():
  25. return __get_reader__().reset()
  26. reader.reset = reset
  27. reader.stop_gradient = True
  28. reader.persistable = True
  29. return reader
  30. def _copy_reader_var_(block, var):
  31. new_var = block.create_var(name=var.name, type=core.VarDesc.VarType.READER)
  32. new_var.desc.set_shapes(var.desc.shapes())
  33. new_var.desc.set_dtypes(var.desc.dtypes())
  34. new_var.desc.set_lod_levels(var.desc.lod_levels())
  35. new_var.persistable = True
  36. return new_var
  37. def _copy_reader_create_op_(block, op):
  38. input_param_names = op.input_names
  39. new_input_map = {}
  40. for param_name in input_param_names:
  41. new_input_map[param_name] = []
  42. arg_names = op.input(param_name)
  43. for arg_name in arg_names:
  44. new_input_map[param_name].append(block.var(arg_name))
  45. output_param_names = op.output_names
  46. new_output_map = {}
  47. for param_name in output_param_names:
  48. new_output_map[param_name] = []
  49. arg_names = op.output(param_name)
  50. for arg_name in arg_names:
  51. new_output_map[param_name].append(block.var(arg_name))
  52. new_op = block.append_op(
  53. type=op.type,
  54. inputs=new_input_map,
  55. outputs=new_output_map,
  56. attrs=op.all_attrs(),
  57. )
  58. return new_op
  59. def __create_shared_decorated_reader__(op_type, reader, attrs):
  60. var_name = unique_name(op_type)
  61. startup_blk = default_startup_program().current_block()
  62. startup_var = startup_blk.create_var(name=var_name)
  63. startup_op = startup_blk.append_op(
  64. type=op_type,
  65. inputs={'UnderlyingReader': reader},
  66. outputs={'Out': [startup_var]},
  67. attrs=attrs,
  68. )
  69. startup_var.persistable = True
  70. main_prog_block = default_main_program().current_block()
  71. main_prog_var = _copy_reader_var_(main_prog_block, startup_var)
  72. _copy_reader_create_op_(main_prog_block, startup_op)
  73. return monkey_patch_reader_methods(main_prog_var)
  74. def __create_unshared_decorated_reader__(op_type, reader, attrs, name=None):
  75. new_reader_name = name if name is not None else unique_name(op_type)
  76. main_blk = default_main_program().current_block()
  77. new_reader = main_blk.create_var(name=new_reader_name)
  78. main_blk.append_op(
  79. type=op_type,
  80. inputs={'UnderlyingReader': reader},
  81. outputs={'Out': [new_reader]},
  82. attrs=attrs,
  83. )
  84. return monkey_patch_reader_methods(new_reader)