| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365 |
- # -*- coding: utf-8 -*-
- # @Time : 2019/8/23 21:59
- # @Author : zhoujun
- import json
- import pathlib
- import time
- import os
- import glob
- import cv2
- import yaml
- from typing import Mapping
- import matplotlib.pyplot as plt
- import numpy as np
- from argparse import ArgumentParser, RawDescriptionHelpFormatter
- def _check_image_file(path):
- img_end = {"jpg", "bmp", "png", "jpeg", "rgb", "tif", "tiff", "gif", "pdf"}
- return any([path.lower().endswith(e) for e in img_end])
def get_image_file_list(img_file):
    """Return a sorted list of image file paths found at *img_file*.

    *img_file* may be a single image file or a directory. For a directory,
    only its direct children (no recursion) that are regular files with an
    image extension (as judged by ``_check_image_file``) are returned.

    Raises:
        Exception: if *img_file* is None, does not exist, or contains no
            image files.
    """
    if img_file is None or not os.path.exists(img_file):
        raise Exception("not found any img file in {}".format(img_file))
    if os.path.isfile(img_file):
        candidates = [img_file]
    elif os.path.isdir(img_file):
        candidates = [os.path.join(img_file, name) for name in os.listdir(img_file)]
    else:
        candidates = []
    imgs_lists = sorted(
        p for p in candidates if os.path.isfile(p) and _check_image_file(p)
    )
    if len(imgs_lists) == 0:
        raise Exception("not found any img file in {}".format(img_file))
    return imgs_lists
def setup_logger(log_file_path: str = None):
    """Return the ``"DBNet.paddle"`` logger, configured at DEBUG level.

    A console handler is attached, plus a file handler when *log_file_path*
    is given. Each destination is attached at most once, so calling this
    function repeatedly does not make every record print several times.

    :param log_file_path: optional path of a log file to also write to.
    :return: the configured ``logging.Logger``.
    """
    import logging

    # NOTE(review): appears to be a Python 2 leftover; kept for compatibility.
    logging._warn_preinit_stderr = 0
    logger = logging.getLogger("DBNet.paddle")
    formatter = logging.Formatter("%(asctime)s %(name)s %(levelname)s: %(message)s")

    # FileHandler subclasses StreamHandler, so compare exact types here.
    if not any(type(h) is logging.StreamHandler for h in logger.handlers):
        ch = logging.StreamHandler()
        ch.setFormatter(formatter)
        logger.addHandler(ch)

    if log_file_path is not None:
        # FileHandler records the absolute path of its file in baseFilename.
        target = os.path.abspath(log_file_path)
        already_attached = any(
            isinstance(h, logging.FileHandler) and h.baseFilename == target
            for h in logger.handlers
        )
        if not already_attached:
            file_handle = logging.FileHandler(log_file_path)
            file_handle.setFormatter(formatter)
            logger.addHandler(file_handle)
    logger.setLevel(logging.DEBUG)
    return logger
# --exeTime
def exe_time(func):
    """Decorator that prints how long each call to *func* took.

    The wrapped function's return value is passed through unchanged. After
    each successful call, ``"<func name> cost <seconds>s"`` is printed.
    """
    from functools import wraps

    @wraps(func)  # keep func.__name__ / __doc__ on the wrapper
    def newFunc(*args, **args2):
        # perf_counter is monotonic, unlike time.time, so intervals are reliable.
        t0 = time.perf_counter()
        back = func(*args, **args2)
        print("{} cost {:.3f}s".format(func.__name__, time.perf_counter() - t0))
        return back

    return newFunc
def load(file_path: str):
    """Read a file, picking the reader from its suffix.

    ``.txt`` and ``.list`` files are read as a list of stripped lines and
    ``.json`` files are parsed as JSON. Any other suffix fails the assertion.
    """
    path = pathlib.Path(file_path)
    loaders = {".txt": _load_txt, ".json": _load_json, ".list": _load_txt}
    assert path.suffix in loaders
    loader = loaders[path.suffix]
    return loader(path)
- def _load_txt(file_path: str):
- with open(file_path, "r", encoding="utf8") as f:
- content = [
- x.strip().strip("\ufeff").strip("\xef\xbb\xbf") for x in f.readlines()
- ]
- return content
- def _load_json(file_path: str):
- with open(file_path, "r", encoding="utf8") as f:
- content = json.load(f)
- return content
def save(data, file_path):
    """Write *data* to *file_path*, picking the writer from the file suffix.

    ``.txt`` writes one item per line and ``.json`` writes pretty-printed
    JSON. Any other suffix fails the assertion.
    """
    file_path = pathlib.Path(file_path)
    writers = {".txt": _save_txt, ".json": _save_json}
    assert file_path.suffix in writers
    writer = writers[file_path.suffix]
    return writer(data, file_path)
- def _save_txt(data, file_path):
- """
- 将一个list的数组写入txt文件里
- :param data:
- :param file_path:
- :return:
- """
- if not isinstance(data, list):
- data = [data]
- with open(file_path, mode="w", encoding="utf8") as f:
- f.write("\n".join(data))
- def _save_json(data, file_path):
- with open(file_path, "w", encoding="utf-8") as json_file:
- json.dump(data, json_file, ensure_ascii=False, indent=4)
def show_img(imgs: np.ndarray, title="img"):
    """Display one image, or a batch of images, with matplotlib.

    Accepted shapes:
      * ``(H, W)``: a single grayscale image;
      * ``(H, W, 3)``: a single colour image;
      * ``(N, H, W)``: a batch of grayscale images (last axis not 3);
      * ``(N, H, W, 3)``: a batch of colour images.
    Each image gets its own figure titled ``"{title}_{i}"``, and
    ``plt.show()`` is called once at the end.
    """
    color = imgs.ndim in (3, 4) and imgs.shape[-1] == 3
    # Only a single image needs a leading batch axis. Expanding a real batch
    # would hand the whole stack to plt.imshow as one "image" and fail.
    if imgs.ndim == 2 or (imgs.ndim == 3 and color):
        imgs = np.expand_dims(imgs, axis=0)
    for i, img in enumerate(imgs):
        plt.figure()
        plt.title("{}_{}".format(title, i))
        plt.imshow(img, cmap=None if color else "gray")
    plt.show()
def draw_bbox(img_path, result, color=(255, 0, 0), thickness=2):
    """Draw closed polygons on an image and return the annotated copy.

    Args:
        img_path: path (``str`` or ``os.PathLike``) of an image file, read
            with ``cv2.imread``, or an already-loaded image array.
        result: iterable of point arrays, one polygon each. Coordinates are
            truncated to ints before drawing.
        color: line colour, passed to ``cv2.polylines`` as-is.
        thickness: line thickness passed to ``cv2.polylines``.
    Returns:
        A copy of the image with the polygons drawn. The input array is not
        modified.
    Raises:
        FileNotFoundError: if *img_path* is a path ``cv2.imread`` cannot read.
    """
    if isinstance(img_path, (str, os.PathLike)):
        loaded = cv2.imread(os.fspath(img_path))
        # cv2.imread signals failure by returning None rather than raising.
        if loaded is None:
            raise FileNotFoundError("cannot read image: {}".format(img_path))
        img_path = loaded
    canvas = img_path.copy()
    for point in result:
        pts = point.astype(int)
        cv2.polylines(canvas, [pts], True, color, thickness)
    return canvas
def cal_text_score(texts, gt_texts, training_masks, running_metric_text, thred=0.5):
    """Update a running text-segmentation metric and return its first score.

    Predictions and ground truth are both multiplied by ``training_masks``.
    Masked predictions are binarized at ``thred`` (values ``> thred`` become
    1, the rest 0), both maps are cast to ``int32``, and they are passed to
    ``running_metric_text.update(gt, pred)``. The first element of
    ``running_metric_text.get_scores()`` is returned.

    Every tensor argument only needs to provide ``.numpy()``.
    """
    mask = training_masks.numpy()
    scores = texts.numpy() * mask
    scores[scores <= thred] = 0
    scores[scores > thred] = 1
    pred_text = scores.astype(np.int32)
    gt_text = (gt_texts.numpy() * mask).astype(np.int32)
    running_metric_text.update(gt_text, pred_text)
    score_text, _unused = running_metric_text.get_scores()
    return score_text
def order_points_clockwise(pts):
    """Order four corner points as [min x+y, min y-x, max x+y, max y-x].

    With image coordinates (y growing downwards) this is top-left, top-right,
    bottom-right, bottom-left. The result is a new ``(4, 2)`` float32 array.
    """
    coord_sum = pts.sum(axis=1)
    coord_diff = np.diff(pts, axis=1)  # y - x for each point
    ordered = np.zeros((4, 2), dtype="float32")
    ordered[0] = pts[np.argmin(coord_sum)]
    ordered[1] = pts[np.argmin(coord_diff)]
    ordered[2] = pts[np.argmax(coord_sum)]
    ordered[3] = pts[np.argmax(coord_diff)]
    return ordered
def order_points_clockwise_list(pts):
    """Order points by sorting: the two with the smallest ``(y, x)`` first.

    Those two are ordered by ascending x, followed by the remaining points in
    descending x. For a quadrilateral in image coordinates this gives
    top-left, top-right, bottom-right, bottom-left. Returns a new numpy array.
    """
    by_row = sorted(pts.tolist(), key=lambda p: (p[1], p[0]))
    top = sorted(by_row[:2], key=lambda p: p[0])
    bottom = sorted(by_row[2:], key=lambda p: -p[0])
    return np.array(top + bottom)
def get_datalist(train_data_path):
    """
    Build the list of (image, label) samples for training or validation.

    :param train_data_path: iterable of dataset list files. Each line has the
        form ``path/to/img<TAB>path/to/label``. A single space directly after
        a ``.jpg`` extension is also accepted as the separator.
    :return: list of ``(img_path, label_path)`` string tuples. Lines with
        fewer than two fields are skipped, as are entries whose image or
        label is not an existing, non-empty regular file.
    """
    train_data = []
    for p in train_data_path:
        with open(p, "r", encoding="utf-8") as f:
            for line in f:
                fields = line.strip("\n").replace(".jpg ", ".jpg\t").split("\t")
                if len(fields) <= 1:
                    continue
                img_path = pathlib.Path(fields[0].strip(" "))
                label_path = pathlib.Path(fields[1].strip(" "))
                # is_file() rather than exists(): directories usually report a
                # non-zero st_size and would otherwise be accepted as samples.
                if (
                    img_path.is_file()
                    and img_path.stat().st_size > 0
                    and label_path.is_file()
                    and label_path.stat().st_size > 0
                ):
                    train_data.append((str(img_path), str(label_path)))
    return train_data
def save_result(result_path, box_list, score_list, is_output_polygon):
    """Write detection results to *result_path*, one box per line.

    Each line contains the box's flattened coordinates truncated to ints,
    followed by its score, all comma-separated: ``x1,y1,x2,y2,...,score``.
    ``score_list[i]`` belongs to ``box_list[i]``.

    ``is_output_polygon`` is kept for API compatibility. Polygons and
    quadrilaterals are serialized identically, so it does not change the
    output.
    """
    with open(result_path, "wt") as res:
        for i, box in enumerate(box_list):
            coords = ",".join(str(int(x)) for x in box.reshape(-1).tolist())
            res.write(coords + "," + str(score_list[i]) + "\n")
def expand_polygon(polygon):
    """Widen the box of a text region that holds a single character.

    Takes the polygon's minimum-area rectangle and lengthens it along the
    side whose angle lies within [-45, 45] degrees by the length of the other
    side (``new_w = w + h``). The centre and angle are kept. The four corners
    are returned ordered by ``order_points_clockwise``.
    """
    (x, y), (w, h), angle = cv2.minAreaRect(np.float32(polygon))
    # Normalise the angle into [-45, 45] so ``w`` is the near-horizontal side.
    # OpenCV < 4.5.1 reports angles in [-90, 0); newer releases use (0, 90].
    # Swapping w/h while shifting the angle by 90 degrees describes the same
    # rectangle, so only the labelling of the sides changes.
    if angle < -45:
        w, h = h, w
        angle += 90
    elif angle > 45:
        w, h = h, w
        angle -= 90
    new_w = w + h
    box = ((x, y), (new_w, h), angle)
    points = cv2.boxPoints(box)
    return order_points_clockwise(points)
- def _merge_dict(config, merge_dct):
- """Recursive dict merge. Inspired by :meth:``dict.update()``, instead of
- updating only top-level keys, dict_merge recurses down into dicts nested
- to an arbitrary depth, updating keys. The ``merge_dct`` is merged into
- ``dct``.
- Args:
- config: dict onto which the merge is executed
- merge_dct: dct merged into config
- Returns: dct
- """
- for key, value in merge_dct.items():
- sub_keys = key.split(".")
- key = sub_keys[0]
- if key in config and len(sub_keys) > 1:
- _merge_dict(config[key], {".".join(sub_keys[1:]): value})
- elif (
- key in config
- and isinstance(config[key], dict)
- and isinstance(value, Mapping)
- ):
- _merge_dict(config[key], value)
- else:
- config[key] = value
- return config
def print_dict(cfg, print_func=print, delimiter=0):
    """Recursively print a dict, one key per line, in sorted key order.

    ``delimiter`` is the number of leading spaces for the current level.
    A nested dict, or a list whose first element is a dict, is printed
    beneath its key with 4 more spaces of indentation.
    """
    pad = delimiter * " "
    for k, v in sorted(cfg.items()):
        is_dict_list = isinstance(v, list) and len(v) >= 1 and isinstance(v[0], dict)
        if not (isinstance(v, dict) or is_dict_list):
            print_func("{}{} : {}".format(pad, k, v))
            continue
        print_func("{}{} : ".format(pad, str(k)))
        children = [v] if isinstance(v, dict) else v
        for child in children:
            print_dict(child, print_func, delimiter + 4)
class Config(object):
    """YAML configuration with recursive inheritance from base files.

    A config file may list parent files under ``BASE_KEY`` (default
    ``"base"``). Parents are loaded recursively and merged in order, and the
    file's own keys override anything inherited. The result is stored in
    ``self.cfg``.
    """

    def __init__(self, config_path, BASE_KEY="base"):
        # Key naming the parent config file(s) to inherit from.
        self.BASE_KEY = BASE_KEY
        self.cfg = self._load_config_with_base(config_path)

    def _load_config_with_base(self, file_path):
        """
        Load config from file, recursively merging its base files.

        Args:
            file_path (str): Path of the config file to be loaded.
        Returns:
            dict: the merged config. A ``"filename"`` key is set to the
            file's base name without extension.
        """
        _, ext = os.path.splitext(file_path)
        assert ext in [".yml", ".yaml"], "only support yaml files for now"
        with open(file_path, encoding="utf-8") as f:
            # NOTE(review): yaml.Loader can construct arbitrary Python objects;
            # only load config files from trusted sources.
            file_cfg = yaml.load(f, Loader=yaml.Loader)
        # An empty YAML document loads as None; treat it as an empty config.
        if file_cfg is None:
            file_cfg = {}
        # NOTE: cfgs outside have higher priority than cfgs in _BASE_
        if self.BASE_KEY in file_cfg:
            base_ymls = file_cfg.pop(self.BASE_KEY)
            # A single base may be written as a scalar string; list() on it
            # would iterate its characters, so wrap it instead.
            if isinstance(base_ymls, str):
                base_ymls = [base_ymls]
            all_base_cfg = dict()
            # Base paths are opened as given, i.e. relative to the current
            # working directory rather than to this config file.
            for base_yml in base_ymls:
                base_cfg = self._load_config_with_base(base_yml)
                all_base_cfg = _merge_dict(all_base_cfg, base_cfg)
            file_cfg = _merge_dict(all_base_cfg, file_cfg)
        file_cfg["filename"] = os.path.splitext(os.path.split(file_path)[-1])[0]
        return file_cfg

    def merge_dict(self, args):
        """Recursively merge ``args`` (dotted keys allowed) into ``self.cfg``."""
        self.cfg = _merge_dict(self.cfg, args)

    def print_cfg(self, print_func=print):
        """
        Print the config as an indented tree, one key per line, between a
        header and a footer line.
        """
        print_func("----------- Config -----------")
        print_dict(self.cfg, print_func)
        print_func("---------------------------------------------")

    def save(self, p):
        """Dump ``self.cfg`` as block-style YAML to ``p``, keeping key order."""
        with open(p, "w") as f:
            yaml.dump(dict(self.cfg), f, default_flow_style=False, sort_keys=False)
class ArgsParser(ArgumentParser):
    """Command-line parser with ``-c/--config_file``, ``-o/--opt`` and ``-p/--profiler_options``.

    ``-o`` takes ``key=value`` pairs. Dotted keys (``Global.lr=0.1``) are
    turned into nested dicts, and each value is parsed as YAML.
    """

    def __init__(self):
        super(ArgsParser, self).__init__(formatter_class=RawDescriptionHelpFormatter)
        self.add_argument("-c", "--config_file", help="configuration file to use")
        self.add_argument("-o", "--opt", nargs="*", help="set configuration options")
        self.add_argument(
            "-p",
            "--profiler_options",
            type=str,
            default=None,
            help="The option of profiler, which should be in format "
            '"key1=value1;key2=value2;key3=value3".',
        )

    def parse_args(self, argv=None):
        """Parse *argv*; ``--config_file`` is required and ``opt`` becomes a nested dict."""
        args = super(ArgsParser, self).parse_args(argv)
        assert (
            args.config_file is not None
        ), "Please specify --config_file=configure_file_path."
        args.opt = self._parse_opt(args.opt)
        return args

    def _parse_opt(self, opts):
        """Convert e.g. ``["a.b=1", "c=x"]`` into ``{"a": {"b": 1}, "c": "x"}``.

        Options sharing a key prefix are merged (``a.b.c=1 a.b.d=2`` keeps
        both leaves). On a conflict, the later option wins.

        Raises:
            ValueError: if an option does not contain ``=``.
        """
        config = {}
        if not opts:
            return config
        for s in opts:
            s = s.strip()
            if "=" not in s:
                raise ValueError("option '{}' must have the form key=value".format(s))
            k, v = s.split("=", 1)
            *parents, leaf = k.split(".")
            cur = config
            for key in parents:
                # Reuse an existing sub-dict instead of resetting it to {}, so
                # earlier options under the same prefix are not lost.
                child = cur.get(key)
                if not isinstance(child, dict):
                    child = cur[key] = {}
                cur = child
            # NOTE(review): yaml.Loader can construct arbitrary Python
            # objects; acceptable for trusted command-line input only.
            cur[leaf] = yaml.load(v, Loader=yaml.Loader)
        return config
if __name__ == "__main__":
    # Smoke test: display a 640x640 slice of an all-zero (1, 3, 640, 640)
    # array. show_img already calls plt.show(), so no second call is needed.
    img = np.zeros((1, 3, 640, 640))
    show_img(img[0][0])
|