| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100 |
- # Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- from .. import core
- from ..executor import global_scope
- from ..framework import default_main_program, default_startup_program
- from ..unique_name import generate as unique_name
# Deliberately empty: a star import (``from <this module> import *``) exports
# no names from this module; callers must import helpers explicitly.
__all__ = []
def monkey_patch_reader_methods(reader):
    """Attach a ``reset()`` helper to ``reader`` and mark it persistent.

    The attached ``reset()`` looks up the variable named ``reader.name`` in
    ``global_scope()``. It takes that variable's runtime reader via
    ``get_reader()`` and returns the result of calling its ``reset()``.
    Nothing is looked up when the patch is applied. The lookup happens
    afresh on every ``reset()`` call.

    ``reader.stop_gradient`` and ``reader.persistable`` are both set to
    ``True``.

    Args:
        reader: the reader variable to patch; it is modified in place.

    Returns:
        The same ``reader`` object.
    """

    def reset():
        # Resolve the runtime reader by name at call time, not at patch time.
        runtime_var = global_scope().find_var(reader.name)
        return runtime_var.get_reader().reset()

    reader.reset = reset
    reader.stop_gradient = True
    reader.persistable = True
    return reader
def _copy_reader_var_(block, var):
    """Create a persistable READER variable in ``block`` that mirrors ``var``.

    The copy keeps ``var.name``. Its desc receives the shapes, dtypes and
    LoD levels read from ``var.desc``.

    Args:
        block: the block in which the copy is created.
        var: the source reader variable.

    Returns:
        The newly created variable.
    """
    copied = block.create_var(name=var.name, type=core.VarDesc.VarType.READER)
    src_desc = var.desc
    dst_desc = copied.desc
    dst_desc.set_shapes(src_desc.shapes())
    dst_desc.set_dtypes(src_desc.dtypes())
    dst_desc.set_lod_levels(src_desc.lod_levels())
    copied.persistable = True
    return copied
def _copy_reader_create_op_(block, op):
    """Append to ``block`` a copy of ``op`` rewired to ``block``'s variables.

    Every argument name in each of ``op``'s input and output slots is resolved
    with ``block.var(name)``. Presumably those variables must already exist in
    ``block``; confirm against ``Block.var``. The op type and the full
    attribute dict (``op.all_attrs()``) are copied unchanged.

    Args:
        block: destination block that receives the new op.
        op: the op to replicate.

    Returns:
        Whatever ``block.append_op`` returns for the new op.
    """
    # slot name -> list of this block's variables bound to that slot
    inputs = {
        slot: [block.var(arg) for arg in op.input(slot)]
        for slot in op.input_names
    }
    outputs = {
        slot: [block.var(arg) for arg in op.output(slot)]
        for slot in op.output_names
    }
    return block.append_op(
        type=op.type,
        inputs=inputs,
        outputs=outputs,
        attrs=op.all_attrs(),
    )
def __create_shared_decorated_reader__(op_type, reader, attrs, name=None):
    """Create a decorated reader in the startup program and mirror it in main.

    An ``op_type`` op is appended to the current block of the default startup
    program. It reads ``reader`` as its ``UnderlyingReader`` input and writes
    a new persistable variable as ``Out``. The variable and the op are then
    copied into the current block of the default main program, with
    ``_copy_reader_var_`` and ``_copy_reader_create_op_``, under the same
    variable name.

    Args:
        op_type: type of the decorating reader-creation op to append.
        reader: the underlying reader fed to the op's ``UnderlyingReader``
            input.
        attrs: attribute dict passed to the op.
        name: optional name for the new reader variable. When omitted, a
            unique name derived from ``op_type`` is generated. This matches
            the ``name`` parameter of ``__create_unshared_decorated_reader__``.

    Returns:
        The main-program reader variable, patched by
        ``monkey_patch_reader_methods``.
    """
    var_name = name if name is not None else unique_name(op_type)
    startup_blk = default_startup_program().current_block()
    startup_var = startup_blk.create_var(name=var_name)
    startup_op = startup_blk.append_op(
        type=op_type,
        inputs={'UnderlyingReader': reader},
        outputs={'Out': [startup_var]},
        attrs=attrs,
    )
    startup_var.persistable = True
    # Mirror the variable and its creating op into the main program, so that
    # the main program refers to the same named reader.
    main_prog_block = default_main_program().current_block()
    main_prog_var = _copy_reader_var_(main_prog_block, startup_var)
    _copy_reader_create_op_(main_prog_block, startup_op)
    return monkey_patch_reader_methods(main_prog_var)
def __create_unshared_decorated_reader__(op_type, reader, attrs, name=None):
    """Create a decorated reader that exists only in the main program.

    A fresh variable is created in the current block of the default main
    program. An ``op_type`` op is appended that reads ``reader`` as its
    ``UnderlyingReader`` input and writes the new variable as ``Out``.

    Args:
        op_type: type of the decorating reader-creation op to append.
        reader: the underlying reader fed to the op.
        attrs: attribute dict passed to the op.
        name: optional variable name. When ``None``, a unique name derived
            from ``op_type`` is generated.

    Returns:
        The new reader variable, patched by ``monkey_patch_reader_methods``.
    """
    if name is None:
        name = unique_name(op_type)
    main_block = default_main_program().current_block()
    decorated = main_block.create_var(name=name)
    main_block.append_op(
        type=op_type,
        inputs={'UnderlyingReader': reader},
        outputs={'Out': [decorated]},
        attrs=attrs,
    )
    return monkey_patch_reader_methods(decorated)
|