network.py 5.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155
  1. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import os
  15. import sys
  16. import time
  17. import shutil
  18. import tarfile
  19. import requests
  20. import os.path as osp
  21. import paddle.distributed as dist
  22. from tqdm import tqdm
  23. from ppocr.utils.logging import get_logger
  24. MODELS_DIR = os.path.join(
  25. os.environ.get("PADDLE_OCR_BASE_DIR", os.path.expanduser("~/.paddleocr/")), "models"
  26. )
  27. DOWNLOAD_RETRY_LIMIT = 3
  28. def download_with_progressbar(url, save_path):
  29. logger = get_logger()
  30. if save_path and os.path.exists(save_path):
  31. logger.info(f"Path {save_path} already exists. Skipping...")
  32. return
  33. else:
  34. # Mainly used to solve the problem of downloading data from different
  35. # machines in the case of multiple machines. Different nodes will download
  36. # data, and the same node will only download data once.
  37. if dist.get_rank() == 0:
  38. _download(url, save_path)
  39. else:
  40. while not os.path.exists(save_path):
  41. time.sleep(1)
  42. def _download(url, save_path):
  43. """
  44. Download from url, save to path.
  45. url (str): download url
  46. save_path (str): download to given path
  47. """
  48. logger = get_logger()
  49. fname = osp.split(url)[-1]
  50. retry_cnt = 0
  51. while not osp.exists(save_path):
  52. if retry_cnt < DOWNLOAD_RETRY_LIMIT:
  53. retry_cnt += 1
  54. else:
  55. raise RuntimeError(
  56. "Download from {} failed. " "Retry limit reached".format(url)
  57. )
  58. try:
  59. req = requests.get(url, stream=True)
  60. except Exception as e: # requests.exceptions.ConnectionError
  61. logger.info(
  62. "Downloading {} from {} failed {} times with exception {}".format(
  63. fname, url, retry_cnt + 1, str(e)
  64. )
  65. )
  66. time.sleep(1)
  67. continue
  68. if req.status_code != 200:
  69. raise RuntimeError(
  70. "Downloading from {} failed with code "
  71. "{}!".format(url, req.status_code)
  72. )
  73. # For protecting download interrupted, download to
  74. # tmp_file firstly, move tmp_file to save_path
  75. # after download finished
  76. tmp_file = save_path + ".tmp"
  77. total_size = req.headers.get("content-length")
  78. with open(tmp_file, "wb") as f:
  79. if total_size:
  80. with tqdm(total=(int(total_size) + 1023) // 1024) as pbar:
  81. for chunk in req.iter_content(chunk_size=1024):
  82. f.write(chunk)
  83. pbar.update(1)
  84. else:
  85. for chunk in req.iter_content(chunk_size=1024):
  86. if chunk:
  87. f.write(chunk)
  88. shutil.move(tmp_file, save_path)
  89. return save_path
  90. def maybe_download(model_storage_directory, url):
  91. # using custom model
  92. tar_file_name_list = [".pdiparams", ".pdiparams.info", ".pdmodel"]
  93. if not os.path.exists(
  94. os.path.join(model_storage_directory, "inference.pdiparams")
  95. ) or not os.path.exists(os.path.join(model_storage_directory, "inference.pdmodel")):
  96. assert url.endswith(".tar"), "Only supports tar compressed package"
  97. tmp_path = os.path.join(model_storage_directory, url.split("/")[-1])
  98. print("download {} to {}".format(url, tmp_path))
  99. os.makedirs(model_storage_directory, exist_ok=True)
  100. download_with_progressbar(url, tmp_path)
  101. with tarfile.open(tmp_path, "r") as tarObj:
  102. for member in tarObj.getmembers():
  103. filename = None
  104. for tar_file_name in tar_file_name_list:
  105. if member.name.endswith(tar_file_name):
  106. filename = "inference" + tar_file_name
  107. if filename is None:
  108. continue
  109. file = tarObj.extractfile(member)
  110. with open(os.path.join(model_storage_directory, filename), "wb") as f:
  111. f.write(file.read())
  112. os.remove(tmp_path)
  113. def maybe_download_params(model_path):
  114. if os.path.exists(model_path) or not is_link(model_path):
  115. return model_path
  116. else:
  117. url = model_path
  118. tmp_path = os.path.join(MODELS_DIR, url.split("/")[-1])
  119. print("download {} to {}".format(url, tmp_path))
  120. os.makedirs(MODELS_DIR, exist_ok=True)
  121. download_with_progressbar(url, tmp_path)
  122. return tmp_path
  123. def is_link(s):
  124. return s is not None and s.startswith("http")
  125. def confirm_model_dir_url(model_dir, default_model_dir, default_url):
  126. url = default_url
  127. if model_dir is None or is_link(model_dir):
  128. if is_link(model_dir):
  129. url = model_dir
  130. file_name = url.split("/")[-1][:-4]
  131. model_dir = default_model_dir
  132. model_dir = os.path.join(model_dir, file_name)
  133. return model_dir, url