| 12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879 |
- from typing import Dict
- import edit_distance as ed
- import numpy as np
- import torch
- import torch.nn.functional as F
- from modelscope.metainfo import Metrics
- from modelscope.utils.registry import default_group
- from .base import Metric
- from .builder import METRICS, MetricKeys
def cal_distance(label_list, pre_list):
    """Count the edit operations that turn ``label_list`` into ``pre_list``.

    The two sequences are aligned with ``ed.SequenceMatcher`` and the
    lengths of all 'insert', 'delete' and 'replace' opcodes are summed.

    Args:
        label_list: The reference sequence (passed as ``a``).
        pre_list: The predicted sequence (passed as ``b``).

    Returns:
        A tuple ``(distance, (delete, replace, insert))`` where ``distance``
        is the sum of the three per-operation counts.
    """
    matcher = ed.SequenceMatcher(a=label_list, b=pre_list)
    n_insert = n_delete = n_replace = 0
    # Opcodes are indexed as [tag, i1, i2, j1, j2]: positions 1-2 span the
    # label side, the last two positions span the prediction side.
    for opcode in matcher.get_opcodes():
        tag = opcode[0]
        if tag == 'delete':
            n_delete += opcode[2] - opcode[1]
        elif tag == 'insert':
            n_insert += opcode[-1] - opcode[-2]
        elif tag == 'replace':
            n_replace += opcode[-1] - opcode[-2]
    total = n_insert + n_delete + n_replace
    return total, (n_delete, n_replace, n_insert)
@METRICS.register_module(
    group_key=default_group, module_name=Metrics.ocr_recognition_metric)
class OCRRecognitionMetric(Metric):
    """The metric computation class for ocr recognition.

    Accumulates predictions, targets and losses batch by batch via ``add``
    and reports full-match accuracy, an edit-distance based rate and the
    average loss via ``evaluate``.
    """

    def __init__(self, *args, **kwargs):
        """Initialize empty accumulators; positional/keyword args are unused."""
        self.preds = []
        self.targets = []
        # Running sum of the per-batch loss values.
        self.loss_sum = 0.
        # Number of predictions seen so far.
        self.nsample = 0
        # Number of ``add`` calls (batches) seen so far.
        self.iter_sum = 0

    def add(self, outputs: Dict, inputs: Dict):
        """Accumulate one batch of results.

        Args:
            outputs: Must contain ``'preds'`` (a sized iterable of
                predictions) and ``'loss'`` (an object supporting
                ``.data.cpu().numpy()`` -- presumably a torch tensor).
            inputs: Must contain ``'labels'`` (an iterable of targets).
        """
        pred = outputs['preds']
        loss = outputs['loss']
        target = inputs['labels']
        self.preds.extend(pred)
        self.targets.extend(target)
        self.loss_sum += loss.data.cpu().numpy()
        self.nsample += len(pred)
        self.iter_sum += 1

    def evaluate(self):
        """Compute the metrics over everything accumulated so far.

        Returns:
            A dict with ``MetricKeys.ACCURACY`` (fraction of predictions
            equal to their target), ``MetricKeys.AR``
            (``1 - total_distance / total_chars``) and
            ``MetricKeys.AVERAGE_LOSS`` (loss sum divided by the number of
            batches). Each value falls back to 0 when its denominator is
            zero instead of raising ``ZeroDivisionError``.
        """
        total_chars = 0
        total_distance = 0
        total_fullmatch = 0
        for pred, target in zip(self.preds, self.targets):
            distance, _ = cal_distance(target, pred)
            total_chars += len(target)
            total_distance += distance
            total_fullmatch += (target == pred)
        # Guard every division the same way the loss average is guarded so
        # that evaluating an empty metric does not crash.
        accuracy = (
            float(total_fullmatch) / self.nsample if self.nsample > 0 else 0.)
        AR = (1 - float(total_distance) / total_chars
              if total_chars > 0 else 0.)
        average_loss = (
            self.loss_sum / self.iter_sum if self.iter_sum > 0 else 0)
        return {
            MetricKeys.ACCURACY: accuracy,
            MetricKeys.AR: AR,
            MetricKeys.AVERAGE_LOSS: average_loss
        }

    def merge(self, other: 'OCRRecognitionMetric'):
        """Fold the accumulated state of ``other`` into this metric.

        Args:
            other: Another metric instance whose predictions, targets, loss
                sum and counters are appended/added to this one.
        """
        self.preds.extend(other.preds)
        self.targets.extend(other.targets)
        self.loss_sum += other.loss_sum
        self.nsample += other.nsample
        self.iter_sum += other.iter_sum

    def __getstate__(self):
        """Return the accumulators as a tuple so the metric can be pickled."""
        return (self.preds, self.targets, self.loss_sum, self.nsample,
                self.iter_sum)

    def __setstate__(self, state):
        """Restore the accumulators from a tuple built by ``__getstate__``."""
        (self.preds, self.targets, self.loss_sum, self.nsample,
         self.iter_sum) = state
|