__init__.py 5.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153
  1. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from __future__ import absolute_import
  15. from __future__ import division
  16. from __future__ import print_function
  17. from __future__ import unicode_literals
  18. import os
  19. import sys
  20. import numpy as np
  21. import skimage
  22. import paddle
  23. import signal
  24. import random
  25. __dir__ = os.path.dirname(os.path.abspath(__file__))
  26. sys.path.append(os.path.abspath(os.path.join(__dir__, "../..")))
  27. import copy
  28. from paddle.io import Dataset, DataLoader, BatchSampler, DistributedBatchSampler
  29. import paddle.distributed as dist
  30. from ppocr.data.imaug import transform, create_operators
  31. from ppocr.data.simple_dataset import SimpleDataSet, MultiScaleDataSet
  32. from ppocr.data.lmdb_dataset import LMDBDataSet, LMDBDataSetSR, LMDBDataSetTableMaster
  33. from ppocr.data.pgnet_dataset import PGDataSet
  34. from ppocr.data.pubtab_dataset import PubTabDataSet
  35. from ppocr.data.multi_scale_sampler import MultiScaleSampler
  36. from ppocr.data.latexocr_dataset import LaTeXOCRDataSet
  37. # for PaddleX dataset_type
  38. TextDetDataset = SimpleDataSet
  39. TextRecDataset = SimpleDataSet
  40. MSTextRecDataset = MultiScaleDataSet
  41. PubTabTableRecDataset = PubTabDataSet
  42. KieDataset = SimpleDataSet
  43. LaTeXOCRDataSet = LaTeXOCRDataSet
  44. __all__ = ["build_dataloader", "transform", "create_operators", "set_signal_handlers"]
  45. def term_mp(sig_num, frame):
  46. """kill all child processes"""
  47. pid = os.getpid()
  48. pgid = os.getpgid(os.getpid())
  49. print("main proc {} exit, kill process group " "{}".format(pid, pgid))
  50. os.killpg(pgid, signal.SIGKILL)
  51. def set_signal_handlers():
  52. pid = os.getpid()
  53. try:
  54. pgid = os.getpgid(pid)
  55. except AttributeError:
  56. # In case `os.getpgid` is not available, no signal handler will be set,
  57. # because we cannot do safe cleanup.
  58. pass
  59. else:
  60. # XXX: `term_mp` kills all processes in the process group, which in
  61. # some cases includes the parent process of current process and may
  62. # cause unexpected results. To solve this problem, we set signal
  63. # handlers only when current process is the group leader. In the
  64. # future, it would be better to consider killing only descendants of
  65. # the current process.
  66. if pid == pgid:
  67. # support exit using ctrl+c
  68. signal.signal(signal.SIGINT, term_mp)
  69. signal.signal(signal.SIGTERM, term_mp)
  70. def build_dataloader(config, mode, device, logger, seed=None):
  71. config = copy.deepcopy(config)
  72. support_dict = [
  73. "SimpleDataSet",
  74. "LMDBDataSet",
  75. "PGDataSet",
  76. "PubTabDataSet",
  77. "LMDBDataSetSR",
  78. "LMDBDataSetTableMaster",
  79. "MultiScaleDataSet",
  80. "TextDetDataset",
  81. "TextRecDataset",
  82. "MSTextRecDataset",
  83. "PubTabTableRecDataset",
  84. "KieDataset",
  85. "LaTeXOCRDataSet",
  86. ]
  87. module_name = config[mode]["dataset"]["name"]
  88. assert module_name in support_dict, Exception(
  89. "DataSet only support {}".format(support_dict)
  90. )
  91. assert mode in ["Train", "Eval", "Test"], "Mode should be Train, Eval or Test."
  92. dataset = eval(module_name)(config, mode, logger, seed)
  93. loader_config = config[mode]["loader"]
  94. batch_size = loader_config["batch_size_per_card"]
  95. drop_last = loader_config["drop_last"]
  96. shuffle = loader_config["shuffle"]
  97. num_workers = loader_config["num_workers"]
  98. if "use_shared_memory" in loader_config.keys():
  99. use_shared_memory = loader_config["use_shared_memory"]
  100. else:
  101. use_shared_memory = True
  102. if mode == "Train":
  103. # Distribute data to multiple cards
  104. if "sampler" in config[mode]:
  105. config_sampler = config[mode]["sampler"]
  106. sampler_name = config_sampler.pop("name")
  107. batch_sampler = eval(sampler_name)(dataset, **config_sampler)
  108. else:
  109. batch_sampler = DistributedBatchSampler(
  110. dataset=dataset,
  111. batch_size=batch_size,
  112. shuffle=shuffle,
  113. drop_last=drop_last,
  114. )
  115. else:
  116. # Distribute data to single card
  117. batch_sampler = BatchSampler(
  118. dataset=dataset, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last
  119. )
  120. if "collate_fn" in loader_config:
  121. from . import collate_fn
  122. collate_fn = getattr(collate_fn, loader_config["collate_fn"])()
  123. else:
  124. collate_fn = None
  125. data_loader = DataLoader(
  126. dataset=dataset,
  127. batch_sampler=batch_sampler,
  128. places=device,
  129. num_workers=num_workers,
  130. return_list=True,
  131. use_shared_memory=use_shared_memory,
  132. collate_fn=collate_fn,
  133. )
  134. return data_loader