| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962 |
- # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- from __future__ import absolute_import
- from __future__ import division
- from __future__ import print_function
- import os
- import gc
- import sys
- import platform
- import yaml
- import time
- import datetime
- import paddle
- import paddle.distributed as dist
- from tqdm import tqdm
- import cv2
- import numpy as np
- import copy
- from argparse import ArgumentParser, RawDescriptionHelpFormatter
- from ppocr.utils.stats import TrainingStats
- from ppocr.utils.save_load import save_model
- from ppocr.utils.utility import print_dict, AverageMeter
- from ppocr.utils.logging import get_logger
- from ppocr.utils.loggers import WandbLogger, Loggers
- from ppocr.utils import profiler
- from ppocr.data import build_dataloader
- from ppocr.utils.export_model import export
class ArgsParser(ArgumentParser):
    """Command-line parser shared by the training / evaluation tools.

    Options:
        -c/--config            path to the YAML config file (required).
        -o/--opt               one or more ``key=value`` overrides; ``key`` may
                               be dotted (e.g. ``Global.epoch_num``) and
                               ``value`` is parsed as YAML.
        -p/--profiler_options  profiler option string, passed through as-is.
    """

    def __init__(self):
        super(ArgsParser, self).__init__(formatter_class=RawDescriptionHelpFormatter)
        self.add_argument("-c", "--config", help="configuration file to use")
        self.add_argument("-o", "--opt", nargs="+", help="set configuration options")
        self.add_argument(
            "-p",
            "--profiler_options",
            type=str,
            default=None,
            help="The option of profiler, which should be in format "
            '"key1=value1;key2=value2;key3=value3".',
        )

    def parse_args(self, argv=None):
        """Parse ``argv`` and convert ``--opt`` into a dict of overrides.

        Returns:
            argparse.Namespace: ``opt`` is a dict (empty when no ``-o`` given).
        """
        args = super(ArgsParser, self).parse_args(argv)
        assert args.config is not None, "Please specify --config=configure_file_path."
        args.opt = self._parse_opt(args.opt)
        return args

    def _parse_opt(self, opts):
        """Turn ``["k1=v1", "k2=v2", ...]`` into ``{k1: yaml(v1), k2: yaml(v2)}``.

        Only the first ``=`` separates key from value, so values that contain
        ``=`` themselves (paths, YAML flow mappings) are kept intact.

        Raises:
            ValueError: if an option has no ``=``.
        """
        config = {}
        if not opts:
            return config
        for s in opts:
            s = s.strip()
            k, sep, v = s.partition("=")
            if not sep:
                raise ValueError(
                    "Invalid option {!r}: expected the form key=value".format(s)
                )
            # NOTE: yaml.Loader is not a safe loader (it can build arbitrary
            # Python objects); acceptable only because these values come from
            # the local command line.
            config[k] = yaml.load(v, Loader=yaml.Loader)
        return config
def load_config(file_path):
    """Load a YAML config file.

    Args:
        file_path (str): Path of the config file; must end in ``.yml`` or
            ``.yaml``.

    Returns:
        The parsed YAML document (the global config).

    Raises:
        AssertionError: if the extension is not ``.yml``/``.yaml``.
    """
    _, ext = os.path.splitext(file_path)
    assert ext in [".yml", ".yaml"], "only support yaml files for now"
    # Use a context manager so the handle is closed deterministically instead
    # of leaking until garbage collection.
    # NOTE: yaml.Loader is not a safe loader; only load config files you trust.
    with open(file_path, "rb") as f:
        config = yaml.load(f, Loader=yaml.Loader)
    return config
def merge_config(config, opts):
    """Apply override entries to ``config`` in place and return it.

    Args:
        config (dict): The global config to update; it is mutated.
        opts (dict): Overrides. A key without a dot either shallow-updates an
            existing entry (when the new value is a dict and the key exists)
            or replaces/creates it. A dotted key such as ``"Global.epoch_num"``
            walks the nested dicts and assigns the last component.

    Returns:
        dict: The same ``config`` object.

    Raises:
        AssertionError: if the first component of a dotted key is not a
            top-level key of ``config``.
    """
    for dotted_key, new_value in opts.items():
        path = dotted_key.split(".")
        if len(path) == 1:
            if isinstance(new_value, dict) and dotted_key in config:
                config[dotted_key].update(new_value)
            else:
                config[dotted_key] = new_value
            continue

        head, *middle, leaf = path
        assert head in config, (
            "the sub_keys can only be one of global_config: {}, but get: "
            "{}, please check your running command".format(config.keys(), head)
        )
        node = config[head]
        for part in middle:
            node = node[part]
        node[leaf] = new_value
    return config
def check_device(
    use_gpu,
    use_xpu=False,
    use_npu=False,
    use_mlu=False,
    use_gcu=False,
    use_iluvatar_gpu=False,
    use_metax_gpu=False,
):
    """Exit if a requested accelerator is not supported by this Paddle build.

    Each enabled ``use_*`` flag is checked against a Paddle capability query.
    On the first unsupported one, a hint is printed and ``sys.exit(1)`` is
    called. ``use_iluvatar_gpu`` is accepted but not checked here.

    The probing is best-effort. Any ordinary ``Exception`` raised while probing
    (for example a query function that the installed Paddle lacks) silently
    stops the remaining checks. ``SystemExit`` is not an ``Exception``
    subclass, so the exits still propagate.
    """
    message = (
        "Config {} cannot be set as true while your paddle "
        "is not compiled with {} ! \nPlease try: \n"
        "\t1. Install paddlepaddle to run model on {} \n"
        "\t2. Set {} as false in config file to run "
        "model on CPU"
    )

    def npu_supported():
        # Paddle <= 2.4 exposes is_compiled_with_npu(); later releases provide
        # NPU support as a custom device instead.
        major = int(paddle.version.major)
        legacy_api = major != 0 and major <= 2 and int(paddle.version.minor) <= 4
        if legacy_api:
            return paddle.device.is_compiled_with_npu()
        return paddle.device.is_compiled_with_custom_device("npu")

    # (enabled, config flag, backend name, device name, capability probe).
    # Probes are callables so paddle attributes are only looked up for flags
    # that are actually enabled, in the same order as before.
    requirements = [
        (use_gpu, "use_gpu", "cuda", "gpu", lambda: paddle.is_compiled_with_cuda()),
        (
            use_xpu,
            "use_xpu",
            "xpu",
            "xpu",
            lambda: paddle.device.is_compiled_with_xpu(),
        ),
        (use_npu, "use_npu", "npu", "npu", npu_supported),
        (
            use_mlu,
            "use_mlu",
            "mlu",
            "mlu",
            lambda: paddle.device.is_compiled_with_mlu(),
        ),
        (
            use_gcu,
            "use_gcu",
            "gcu",
            "gcu",
            lambda: paddle.device.is_compiled_with_custom_device("gcu"),
        ),
        (
            use_metax_gpu,
            "use_metax_gpu",
            "metax_gpu",
            "metax_gpu",
            lambda: paddle.device.is_compiled_with_custom_device("metax_gpu"),
        ),
    ]

    try:
        if use_gpu and use_xpu:
            print("use_xpu and use_gpu can not both be true.")
        for enabled, flag, backend, target, probe in requirements:
            if enabled and not probe():
                print(message.format(flag, backend, target, flag))
                sys.exit(1)
    except Exception:
        pass
def to_float32(preds):
    """Cast every ``paddle.Tensor`` reachable through nested dicts/lists to float32.

    Dicts and lists are updated in place and returned as the same object. A
    bare tensor is returned cast. Anything else, including tuples and any
    tensors inside them, is returned unchanged.
    """
    if isinstance(preds, (dict, list)):
        slots = preds.keys() if isinstance(preds, dict) else range(len(preds))
        for slot in slots:
            value = preds[slot]
            if isinstance(value, (dict, list)):
                preds[slot] = to_float32(value)
            elif isinstance(value, paddle.Tensor):
                preds[slot] = value.astype(paddle.float32)
        return preds
    if isinstance(preds, paddle.Tensor):
        return preds.astype(paddle.float32)
    return preds
# Algorithms whose training forward consumes the whole batch (formula models).
_FORMULA_ALGORITHMS = [
    "LaTeXOCR",
    "UniMERNet",
    "PP-FormulaNet-S",
    "PP-FormulaNet-L",
    "PP-FormulaNet_plus-S",
    "PP-FormulaNet_plus-M",
    "PP-FormulaNet_plus-L",
]

# PP-FormulaNet variants, which share one train-time metric path.
_PP_FORMULANET_ALGORITHMS = [
    "PP-FormulaNet-S",
    "PP-FormulaNet-L",
    "PP-FormulaNet_plus-S",
    "PP-FormulaNet_plus-M",
    "PP-FormulaNet_plus-L",
]

# Algorithms that need a dedicated ``model_type`` when passed to ``eval``.
_ALGORITHM_EVAL_MODEL_TYPES = {
    "CAN": "can",
    "LaTeXOCR": "latexocr",
    "UniMERNet": "unimernet",
    "PP-FormulaNet-S": "pp_formulanet",
    "PP-FormulaNet-L": "pp_formulanet",
    "PP-FormulaNet_plus-S": "pp_formulanet",
    "PP-FormulaNet_plus-M": "pp_formulanet",
    "PP-FormulaNet_plus-L": "pp_formulanet",
}


def _train_forward(model, images, batch, model_type, algorithm, extra_input):
    """Run one training forward pass, routing the batch by model type/algorithm.

    Shared by the AMP and full-precision paths so both dispatch identically.
    """
    if model_type == "table" or extra_input:
        return model(images, data=batch[1:])
    if model_type in ["kie", "sr"]:
        return model(batch)
    if algorithm in ["CAN"]:
        return model(batch[:3])
    if algorithm in _FORMULA_ALGORITHMS:
        return model(batch)
    return model(images)


def _save_checkpoint(
    model,
    optimizer,
    config,
    logger,
    save_model_dir,
    prefix,
    is_best,
    epoch,
    global_step,
    best_model_dict,
    uniform_output_enabled,
    **extra_save_kwargs,
):
    """Save a checkpoint under ``prefix``.

    With uniform output enabled, this first exports an inference model to
    ``<save_model_dir>/<prefix>/inference`` and saves into that sub-directory.
    """
    if uniform_output_enabled:
        export(config, model, os.path.join(save_model_dir, prefix, "inference"))
        gc.collect()
        model_info = {"epoch": epoch, "metric": best_model_dict}
        target_dir = os.path.join(save_model_dir, prefix)
    else:
        model_info = None
        target_dir = save_model_dir
    save_model(
        model,
        optimizer,
        target_dir,
        logger,
        config,
        is_best=is_best,
        prefix=prefix,
        save_model_info=model_info,
        best_model_dict=best_model_dict,
        epoch=epoch,
        global_step=global_step,
        **extra_save_kwargs,
    )


def train(
    config,
    train_dataloader,
    valid_dataloader,
    device,
    model,
    loss_class,
    optimizer,
    lr_scheduler,
    post_process_class,
    eval_class,
    pre_best_model_dict,
    logger,
    step_pre_epoch,
    log_writer=None,
    scaler=None,
    amp_level="O2",
    amp_custom_black_list=None,
    amp_custom_white_list=None,
    amp_dtype="float16",
):
    """Run the training loop with periodic evaluation and checkpointing.

    For each epoch in ``start_epoch..epoch_num``, every batch goes through a
    forward/backward/optimizer step. When ``scaler`` is given, the forward runs
    under ``paddle.amp.auto_cast`` and the loss is scaled. Smoothed stats are
    logged every ``print_batch_step`` steps and at the end of each epoch. On
    rank 0, ``eval`` runs every ``eval_batch_step`` steps after
    ``start_eval_step``. Also on rank 0, ``best_accuracy``, ``latest`` and
    ``iter_epoch_N`` checkpoints are saved.

    Args:
        config (dict): Global config; reads ``Global``, ``Architecture``,
            ``Loss`` and ``profiler_options``.
        train_dataloader / valid_dataloader: Training / evaluation loaders.
            The training loader is rebuilt each epoch when
            ``dataset.need_reset`` is set.
        device: Device handle, forwarded to ``build_dataloader``.
        model, loss_class, optimizer, lr_scheduler: Training components.
            ``lr_scheduler`` is stepped per batch unless it is a plain float.
        post_process_class, eval_class: Used for train-time metrics and ``eval``.
        pre_best_model_dict (dict): State restored from a checkpoint. It may
            carry ``global_step`` and ``start_epoch``.
        logger: Python logger.
        step_pre_epoch (int): Steps per epoch, used when ``eval_batch_epoch``
            is set.
        log_writer: Optional metrics/model logger (closed at the end on rank 0).
        scaler: Optional AMP grad scaler; enables mixed precision when truthy.
        amp_level, amp_dtype: Forwarded to ``paddle.amp.auto_cast``.
        amp_custom_black_list, amp_custom_white_list: Op lists forwarded to
            ``paddle.amp.auto_cast`` (``None`` means no custom list).
    """
    global_config = config["Global"]
    cal_metric_during_train = global_config.get("cal_metric_during_train", False)
    calc_epoch_interval = global_config.get("calc_epoch_interval", 1)
    log_smooth_window = global_config["log_smooth_window"]
    epoch_num = global_config["epoch_num"]
    print_batch_step = global_config["print_batch_step"]
    eval_batch_step = global_config["eval_batch_step"]
    eval_batch_epoch = global_config.get("eval_batch_epoch", None)
    profiler_options = config["profiler_options"]
    print_mem_info = global_config.get("print_mem_info", True)
    uniform_output_enabled = global_config.get("uniform_output_enabled", False)

    global_step = pre_best_model_dict.get("global_step", 0)

    # eval_batch_step may be [start_step, interval]; eval_batch_epoch, when set,
    # overrides it with an epoch-based interval starting at step 0.
    start_eval_step = 0
    if isinstance(eval_batch_step, list) and len(eval_batch_step) >= 2:
        start_eval_step = eval_batch_step[0] if not eval_batch_epoch else 0
        eval_batch_step = (
            eval_batch_step[1]
            if not eval_batch_epoch
            else step_pre_epoch * eval_batch_epoch
        )
        if len(valid_dataloader) == 0:
            logger.info(
                "No Images in eval dataset, evaluation during training "
                "will be disabled"
            )
            start_eval_step = 1e111
        logger.info(
            "During the training process, after the {}th iteration, "
            "an evaluation is run every {} iterations".format(
                start_eval_step, eval_batch_step
            )
        )
    save_epoch_step = global_config["save_epoch_step"]
    save_model_dir = global_config["save_model_dir"]
    os.makedirs(save_model_dir, exist_ok=True)

    main_indicator = eval_class.main_indicator
    best_model_dict = {main_indicator: 0}
    best_model_dict.update(pre_best_model_dict)
    train_stats = TrainingStats(log_smooth_window, ["lr"])
    model_average = False
    model.train()

    use_srn = config["Architecture"]["algorithm"] == "SRN"
    extra_input_models = [
        "SRN",
        "NRTR",
        "SAR",
        "SEED",
        "SVTR",
        "SVTR_LCNet",
        "SPIN",
        "VisionLAN",
        "RobustScanner",
        "RFL",
        "DRRG",
        "SATRN",
        "SVTR_HGNet",
        "ParseQ",
        "CPPD",
    ]
    if config["Architecture"]["algorithm"] == "Distillation":
        extra_input = any(
            sub_model["algorithm"] in extra_input_models
            for sub_model in config["Architecture"]["Models"].values()
        )
    else:
        extra_input = config["Architecture"]["algorithm"] in extra_input_models
    algorithm = config["Architecture"]["algorithm"]
    # ``eval`` dispatches on model_type. Previously these algorithms only got
    # their eval model_type as a side effect of the train-metric branch, so
    # eval misrouted them when cal_metric_during_train was off.
    model_type = _ALGORITHM_EVAL_MODEL_TYPES.get(
        algorithm, config["Architecture"].get("model_type")
    )

    start_epoch = best_model_dict.get("start_epoch", 1)

    total_samples = 0
    train_reader_cost = 0.0
    train_batch_cost = 0.0
    steps_since_log = 0
    reader_start = time.time()
    eta_meter = AverageMeter()

    def usable_iters(loader):
        # The last batch is skipped on Windows (kept from the original logic).
        return len(loader) - 1 if platform.system() == "Windows" else len(loader)

    max_iter = usable_iters(train_dataloader)

    for epoch in range(start_epoch, epoch_num + 1):
        if train_dataloader.dataset.need_reset:
            train_dataloader = build_dataloader(
                config, "Train", device, logger, seed=epoch
            )
            max_iter = usable_iters(train_dataloader)
        for idx, batch in enumerate(train_dataloader):
            model.train()
            profiler.add_profiler_step(profiler_options)
            train_reader_cost += time.time() - reader_start
            if idx >= max_iter:
                break
            lr = optimizer.get_lr()
            images = batch[0]
            if use_srn:
                model_average = True

            if scaler:
                # Mixed precision: forward under autocast, loss in float32.
                with paddle.amp.auto_cast(
                    level=amp_level,
                    custom_black_list=amp_custom_black_list,
                    custom_white_list=amp_custom_white_list,
                    dtype=amp_dtype,
                ):
                    preds = _train_forward(
                        model, images, batch, model_type, algorithm, extra_input
                    )
                preds = to_float32(preds)
                loss = loss_class(preds, batch)
                avg_loss = loss["loss"]
                scaled_avg_loss = scaler.scale(avg_loss)
                scaled_avg_loss.backward()
                scaler.minimize(optimizer, scaled_avg_loss)
            else:
                preds = _train_forward(
                    model, images, batch, model_type, algorithm, extra_input
                )
                loss = loss_class(preds, batch)
                avg_loss = loss["loss"]
                avg_loss.backward()
                optimizer.step()
            optimizer.clear_grad()

            if (
                cal_metric_during_train and epoch % calc_epoch_interval == 0
            ):  # only rec and cls need
                batch = [item.numpy() for item in batch]
                if model_type in ["kie", "sr"]:
                    eval_class(preds, batch)
                elif model_type in ["table"]:
                    post_result = post_process_class(preds, batch)
                    eval_class(post_result, batch)
                elif algorithm in ["CAN"]:
                    eval_class(preds[0], batch[2:], epoch_reset=(idx == 0))
                elif algorithm in ["LaTeXOCR"]:
                    post_result = post_process_class(preds, batch[1], mode="train")
                    eval_class(post_result[0], post_result[1], epoch_reset=(idx == 0))
                elif algorithm in ["UniMERNet"] or algorithm in _PP_FORMULANET_ALGORITHMS:
                    post_result = post_process_class(
                        preds[0], batch[1], mode="train"
                    )
                    eval_class(post_result[0], post_result[1], epoch_reset=(idx == 0))
                else:
                    loss_name = config["Loss"]["name"]
                    if loss_name in ["MultiLoss", "MultiLoss_v2"]:
                        # multi-head loss: evaluate the CTC head output
                        post_result = post_process_class(preds["ctc"], batch[1])
                    elif loss_name in ["VLLoss"]:
                        post_result = post_process_class(preds, batch[1], batch[-1])
                    else:
                        post_result = post_process_class(preds, batch[1])
                    eval_class(post_result, batch)
                metric = eval_class.get_metric()
                train_stats.update(metric)

            train_batch_time = time.time() - reader_start
            train_batch_cost += train_batch_time
            eta_meter.update(train_batch_time)
            global_step += 1
            steps_since_log += 1
            total_samples += len(images)

            if not isinstance(lr_scheduler, float):
                lr_scheduler.step()

            # logger and visualdl
            stats = {
                k: float(v) if v.shape == [] else v.numpy().mean()
                for k, v in loss.items()
            }
            stats["lr"] = lr
            train_stats.update(stats)

            if log_writer is not None and dist.get_rank() == 0:
                log_writer.log_metrics(
                    metrics=train_stats.get(), prefix="TRAIN", step=global_step
                )

            if (global_step > 0 and global_step % print_batch_step == 0) or (
                idx >= len(train_dataloader) - 1
            ):
                logs = train_stats.log()
                # Remaining iterations of this and later epochs times mean step time.
                eta_sec = (
                    (epoch_num + 1 - epoch) * len(train_dataloader) - idx - 1
                ) * eta_meter.avg
                eta_sec_format = str(datetime.timedelta(seconds=int(eta_sec)))
                max_mem_reserved_str = ""
                max_mem_allocated_str = ""
                if paddle.device.is_compiled_with_cuda() and print_mem_info:
                    max_mem_reserved_str = f", max_mem_reserved: {paddle.device.cuda.max_memory_reserved() // (1024 ** 2)} MB,"
                    max_mem_allocated_str = f" max_mem_allocated: {paddle.device.cuda.max_memory_allocated() // (1024 ** 2)} MB"
                # Average over the steps actually accumulated since the last
                # log (fewer than print_batch_step at epoch end or after resume).
                strs = (
                    "epoch: [{}/{}], global_step: {}, {}, avg_reader_cost: "
                    "{:.5f} s, avg_batch_cost: {:.5f} s, avg_samples: {}, "
                    "ips: {:.5f} samples/s, eta: {}{}{}".format(
                        epoch,
                        epoch_num,
                        global_step,
                        logs,
                        train_reader_cost / steps_since_log,
                        train_batch_cost / steps_since_log,
                        total_samples / steps_since_log,
                        total_samples / train_batch_cost,
                        eta_sec_format,
                        max_mem_reserved_str,
                        max_mem_allocated_str,
                    )
                )
                logger.info(strs)

                total_samples = 0
                train_reader_cost = 0.0
                train_batch_cost = 0.0
                steps_since_log = 0

            # eval
            if (
                global_step > start_eval_step
                and (global_step - start_eval_step) % eval_batch_step == 0
                and dist.get_rank() == 0
            ):
                if model_average:
                    model_averager = paddle.incubate.ModelAverage(
                        0.15,
                        parameters=model.parameters(),
                        min_average_window=10000,
                        max_average_window=15625,
                    )
                    # NOTE(review): ModelAverage.apply() appears to be a context
                    # manager in Paddle; called outside a ``with`` it may not
                    # swap in averaged parameters at all. Kept as-is; confirm.
                    model_averager.apply()
                cur_metric = eval(
                    model,
                    valid_dataloader,
                    post_process_class,
                    eval_class,
                    model_type,
                    extra_input=extra_input,
                    scaler=scaler,
                    amp_level=amp_level,
                    amp_custom_black_list=amp_custom_black_list,
                    amp_custom_white_list=amp_custom_white_list,
                    amp_dtype=amp_dtype,
                )
                cur_metric_str = "cur metric, {}".format(
                    ", ".join(["{}: {}".format(k, v) for k, v in cur_metric.items()])
                )
                logger.info(cur_metric_str)

                # logger metric
                if log_writer is not None:
                    log_writer.log_metrics(
                        metrics=cur_metric, prefix="EVAL", step=global_step
                    )

                if cur_metric[main_indicator] >= best_model_dict[main_indicator]:
                    best_model_dict.update(cur_metric)
                    best_model_dict["best_epoch"] = epoch
                    _save_checkpoint(
                        model,
                        optimizer,
                        config,
                        logger,
                        save_model_dir,
                        "best_accuracy",
                        is_best=True,
                        epoch=epoch,
                        global_step=global_step,
                        best_model_dict=best_model_dict,
                        uniform_output_enabled=uniform_output_enabled,
                    )
                best_str = "best metric, {}".format(
                    ", ".join(
                        ["{}: {}".format(k, v) for k, v in best_model_dict.items()]
                    )
                )
                logger.info(best_str)
                # logger best metric
                if log_writer is not None:
                    log_writer.log_metrics(
                        metrics={
                            "best_{}".format(main_indicator): best_model_dict[
                                main_indicator
                            ]
                        },
                        prefix="EVAL",
                        step=global_step,
                    )
                    log_writer.log_model(
                        is_best=True, prefix="best_accuracy", metadata=best_model_dict
                    )

            reader_start = time.time()

        if dist.get_rank() == 0:
            _save_checkpoint(
                model,
                optimizer,
                config,
                logger,
                save_model_dir,
                "latest",
                is_best=False,
                epoch=epoch,
                global_step=global_step,
                best_model_dict=best_model_dict,
                uniform_output_enabled=uniform_output_enabled,
            )
            if log_writer is not None:
                log_writer.log_model(is_best=False, prefix="latest")

        if dist.get_rank() == 0 and epoch > 0 and epoch % save_epoch_step == 0:
            prefix = "iter_epoch_{}".format(epoch)
            _save_checkpoint(
                model,
                optimizer,
                config,
                logger,
                save_model_dir,
                prefix,
                is_best=False,
                epoch=epoch,
                global_step=global_step,
                best_model_dict=best_model_dict,
                uniform_output_enabled=uniform_output_enabled,
                done_flag=epoch == epoch_num,
            )
            if log_writer is not None:
                log_writer.log_model(is_best=False, prefix=prefix)

    best_str = "best metric, {}".format(
        ", ".join(["{}: {}".format(k, v) for k, v in best_model_dict.items()])
    )
    logger.info(best_str)
    if dist.get_rank() == 0 and log_writer is not None:
        log_writer.close()
    return
def _eval_forward(model, images, batch, model_type, extra_input):
    """Run one evaluation forward pass, routing the batch by ``model_type``.

    Shared by the AMP and full-precision paths so both dispatch identically.
    """
    if model_type == "table" or extra_input:
        return model(images, data=batch[1:])
    if model_type in ["can"]:
        return model(batch[:3])
    if model_type in ["kie", "sr", "latexocr", "unimernet", "pp_formulanet"]:
        return model(batch)
    return model(images)


def eval(
    model,
    valid_dataloader,
    post_process_class,
    eval_class,
    model_type=None,
    extra_input=False,
    scaler=None,
    amp_level="O2",
    amp_custom_black_list=None,
    amp_custom_white_list=None,
    amp_dtype="float16",
):
    """Evaluate ``model`` on ``valid_dataloader`` and return the metric dict.

    Runs under ``paddle.no_grad()`` with the model in eval mode. Each batch
    goes through the model (under ``paddle.amp.auto_cast`` with the same
    black/white lists as training when ``scaler`` is truthy). The output is
    optionally post-processed and accumulated into ``eval_class``. The model
    is switched back to train mode afterwards, even if evaluation raises.

    Returns:
        dict: ``eval_class.get_metric()`` plus ``"fps"``, which is the summed
        ``len(batch[0])`` divided by the measured forward time (0 when no time
        was measured).
    """
    model.eval()
    pbar = tqdm(
        total=len(valid_dataloader), desc="eval model:", position=0, leave=True
    )
    try:
        with paddle.no_grad():
            total_frame = 0.0
            total_time = 0.0
            # The last batch is skipped on Windows (kept from the original logic).
            max_iter = (
                len(valid_dataloader) - 1
                if platform.system() == "Windows"
                else len(valid_dataloader)
            )
            for idx, batch in enumerate(valid_dataloader):
                if idx >= max_iter:
                    break
                images = batch[0]
                start = time.time()

                if scaler:
                    with paddle.amp.auto_cast(
                        level=amp_level,
                        custom_black_list=amp_custom_black_list,
                        custom_white_list=amp_custom_white_list,
                        dtype=amp_dtype,
                    ):
                        preds = _eval_forward(
                            model, images, batch, model_type, extra_input
                        )
                    preds = to_float32(preds)
                else:
                    preds = _eval_forward(model, images, batch, model_type, extra_input)

                batch_numpy = [
                    item.numpy() if isinstance(item, paddle.Tensor) else item
                    for item in batch
                ]
                total_time += time.time() - start

                # Post-process (where applicable) and accumulate this batch.
                if model_type in ["table", "kie"]:
                    if post_process_class is None:
                        eval_class(preds, batch_numpy)
                    else:
                        post_result = post_process_class(preds, batch_numpy)
                        eval_class(post_result, batch_numpy)
                elif model_type in ["sr"]:
                    eval_class(preds, batch_numpy)
                elif model_type in ["can"]:
                    eval_class(preds[0], batch_numpy[2:], epoch_reset=(idx == 0))
                elif model_type in ["latexocr", "unimernet", "pp_formulanet"]:
                    post_result = post_process_class(preds, batch[1], "eval")
                    eval_class(post_result[0], post_result[1], epoch_reset=(idx == 0))
                else:
                    post_result = post_process_class(preds, batch_numpy[1])
                    eval_class(post_result, batch_numpy)

                pbar.update(1)
                total_frame += len(images)
            # Get final metric, e.g. acc or hmean
            metric = eval_class.get_metric()
    finally:
        pbar.close()
        model.train()

    # Avoid ZeroDivisionError when no batch was timed.
    metric["fps"] = total_frame / total_time if total_time > 0 else 0
    return metric
def update_center(char_center, post_result, preds):
    """Fold one batch into running per-class feature means.

    Only samples where ``result[i][0] == label[i][0]`` contribute. For each
    such sample, the argmax class id at every position of its logits selects a
    ``[mean_feature, count]`` entry in ``char_center``. That entry is updated
    as an incremental mean, or created with count 1 if it is new.

    Args:
        char_center (dict): class id -> ``[mean_feature, count]``; updated in place.
        post_result: ``(result, label)`` pair from the post-processor.
        preds: ``(feats, logits)`` pair of paddle tensors. Presumably feats is
            (batch, time, dim) and logits is (batch, time, classes); TODO confirm.

    Returns:
        dict: ``char_center`` (the same object).
    """
    decoded, labels = post_result
    feats, logits = preds
    class_ids = paddle.argmax(logits, axis=-1).numpy()
    feats = feats.numpy()
    for sample in range(len(labels)):
        if not decoded[sample][0] == labels[sample][0]:
            continue
        sample_feats = feats[sample]
        for step, class_id in enumerate(class_ids[sample]):
            if class_id in char_center:
                entry = char_center[class_id]
                entry[0] = (entry[0] * entry[1] + sample_feats[step]) / (entry[1] + 1)
                entry[1] += 1
            else:
                char_center[class_id] = [sample_feats[step], 1]
    return char_center
def get_center(model, eval_dataloader, post_process_class):
    """Compute a mean feature vector per predicted class over a dataset.

    Runs ``model`` on each batch, post-processes against ``batch[1]``, and
    accumulates running means with ``update_center``.

    Returns:
        dict: class id -> mean feature (the count is dropped).
    """
    total = len(eval_dataloader)
    progress = tqdm(total=total, desc="get center:")
    # The last batch is skipped on Windows (kept from the original logic).
    usable = total - 1 if platform.system() == "Windows" else total
    running = dict()
    for step, batch in enumerate(eval_dataloader):
        if step >= usable:
            break
        preds = model(batch[0])
        numpy_batch = [item.numpy() for item in batch]
        # Obtain usable results from post-processing methods
        post_result = post_process_class(preds, numpy_batch[1])
        running = update_center(running, post_result, preds)
        progress.update(1)
    progress.close()
    return {class_id: entry[0] for class_id, entry in running.items()}
# Algorithms accepted by ``preprocess`` (value of Architecture.algorithm).
_SUPPORTED_ALGORITHMS = [
    "EAST",
    "DB",
    "SAST",
    "Rosetta",
    "CRNN",
    "STARNet",
    "RARE",
    "SRN",
    "CLS",
    "PGNet",
    "Distillation",
    "NRTR",
    "TableAttn",
    "SAR",
    "PSE",
    "SEED",
    "SDMGR",
    "LayoutXLM",
    "LayoutLM",
    "LayoutLMv2",
    "PREN",
    "FCE",
    "SVTR",
    "SVTR_LCNet",
    "ViTSTR",
    "ABINet",
    "DB++",
    "TableMaster",
    "SPIN",
    "VisionLAN",
    "Gestalt",
    "SLANet",
    "RobustScanner",
    "CT",
    "RFL",
    "DRRG",
    "CAN",
    "Telescope",
    "SATRN",
    "SVTR_HGNet",
    "ParseQ",
    "CPPD",
    "LaTeXOCR",
    "UniMERNet",
    "SLANeXt",
    "PP-FormulaNet-S",
    "PP-FormulaNet-L",
    "PP-FormulaNet_plus-S",
    "PP-FormulaNet_plus-M",
    "PP-FormulaNet_plus-L",
]


def preprocess(is_train=False):
    """Parse CLI args, build the config, and set up logging and the device.

    Args:
        is_train (bool): When True, writes the merged config to
            ``<save_model_dir>/config.yml`` and logs to
            ``<save_model_dir>/train.log``.

    Returns:
        tuple: ``(config, device, logger, log_writer)``. ``log_writer`` is a
        ``Loggers`` wrapping a ``WandbLogger`` when wandb is enabled, otherwise
        ``None``.
    """
    FLAGS = ArgsParser().parse_args()
    config = load_config(FLAGS.config)
    config = merge_config(config, FLAGS.opt)
    profile_dic = {"profiler_options": FLAGS.profiler_options}
    config = merge_config(config, profile_dic)

    if is_train:
        # save_config
        save_model_dir = config["Global"]["save_model_dir"]
        os.makedirs(save_model_dir, exist_ok=True)
        with open(os.path.join(save_model_dir, "config.yml"), "w") as f:
            yaml.dump(dict(config), f, default_flow_style=False, sort_keys=False)
        log_file = "{}/train.log".format(save_model_dir)
    else:
        log_file = None

    log_ranks = config["Global"].get("log_ranks", "0")
    logger = get_logger(log_file=log_file, log_ranks=log_ranks)

    # check if set use_gpu=True in paddlepaddle cpu version
    use_gpu = config["Global"].get("use_gpu", False)
    use_xpu = config["Global"].get("use_xpu", False)
    use_npu = config["Global"].get("use_npu", False)
    use_mlu = config["Global"].get("use_mlu", False)
    use_gcu = config["Global"].get("use_gcu", False)
    use_metax_gpu = config["Global"].get("use_metax_gpu", False)
    use_iluvatar_gpu = config["Global"].get("use_iluvatar_gpu", False)

    alg = config["Architecture"]["algorithm"]
    assert alg in _SUPPORTED_ALGORITHMS, "Unsupported algorithm: {}".format(alg)

    # The first enabled accelerator flag wins; otherwise GPU (if use_gpu) or CPU.
    if use_xpu:
        device = "xpu:{0}".format(os.getenv("FLAGS_selected_xpus", 0))
    elif use_npu:
        device = "npu:{0}".format(os.getenv("FLAGS_selected_npus", 0))
    elif use_mlu:
        device = "mlu:{0}".format(os.getenv("FLAGS_selected_mlus", 0))
    elif use_gcu:  # Use Enflame GCU(General Compute Unit)
        device = "gcu:{0}".format(os.getenv("FLAGS_selected_gcus", 0))
    elif use_metax_gpu:  # Use MetaX GPU
        # NOTE(review): check_device probes the custom device "metax_gpu" but
        # the device string here is "metax"; confirm which name Paddle expects.
        device = "metax:{0}".format(os.getenv("FLAGS_selected_metaxs", 0))
    elif use_iluvatar_gpu:
        device = "iluvatar_gpu:{0}".format(dist.ParallelEnv().dev_id)
    else:
        device = "gpu:{}".format(dist.ParallelEnv().dev_id) if use_gpu else "cpu"
    check_device(
        use_gpu, use_xpu, use_npu, use_mlu, use_gcu, use_iluvatar_gpu, use_metax_gpu
    )

    device = paddle.set_device(device)

    config["Global"]["distributed"] = dist.get_world_size() != 1

    loggers = []
    if config["Global"].get("use_visualdl"):
        logger.warning(
            "You are using VisualDL, the VisualDL is deprecated and "
            "removed in ppocr!"
        )
    if config["Global"].get("use_wandb") or "wandb" in config:
        save_dir = config["Global"]["save_model_dir"]
        # Note: when a "wandb" section exists it is updated in place with save_dir.
        wandb_params = config["wandb"] if "wandb" in config else dict()
        wandb_params.update({"save_dir": save_dir})
        loggers.append(WandbLogger(**wandb_params, config=config))

    print_dict(config, logger)

    log_writer = Loggers(loggers) if loggers else None

    logger.info("train with paddle {} and device {}".format(paddle.__version__, device))
    return config, device, logger, log_writer
|