rec_latexocr_loss.py 1.4 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647
  1. # copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """
  15. This code is adapted from:
  16. https://github.com/lucidrains/x-transformers/blob/main/x_transformers/autoregressive_wrapper.py
  17. """
  18. import paddle
  19. import paddle.nn as nn
  20. import paddle.nn.functional as F
  21. import numpy as np
  22. class LaTeXOCRLoss(nn.Layer):
  23. """
  24. LaTeXOCR adopt CrossEntropyLoss for network training.
  25. """
  26. def __init__(self):
  27. super(LaTeXOCRLoss, self).__init__()
  28. self.ignore_index = -100
  29. self.cross = nn.CrossEntropyLoss(
  30. reduction="mean", ignore_index=self.ignore_index
  31. )
  32. def forward(self, preds, batch):
  33. word_probs = preds
  34. labels = batch[1][:, 1:]
  35. word_loss = self.cross(
  36. paddle.reshape(word_probs, [-1, word_probs.shape[-1]]),
  37. paddle.reshape(labels, [-1]),
  38. )
  39. loss = word_loss
  40. return {"loss": loss}