loss_utils.py 6.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169
  1. # Copyright 2024 The HuggingFace Team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from typing import Optional
  15. import torch
  16. import torch.nn as nn
  17. from torch.nn import BCEWithLogitsLoss, MSELoss
  18. from .loss_d_fine import DFineForObjectDetectionLoss
  19. from .loss_deformable_detr import DeformableDetrForObjectDetectionLoss, DeformableDetrForSegmentationLoss
  20. from .loss_for_object_detection import ForObjectDetectionLoss, ForSegmentationLoss
  21. from .loss_grounding_dino import GroundingDinoForObjectDetectionLoss
  22. from .loss_rt_detr import RTDetrForObjectDetectionLoss
  23. def fixed_cross_entropy(
  24. source: torch.Tensor,
  25. target: torch.Tensor,
  26. num_items_in_batch: Optional[torch.Tensor] = None,
  27. ignore_index: int = -100,
  28. **kwargs,
  29. ) -> torch.Tensor:
  30. reduction = "sum" if num_items_in_batch is not None else "mean"
  31. loss = nn.functional.cross_entropy(source, target, ignore_index=ignore_index, reduction=reduction)
  32. if reduction == "sum":
  33. # just in case users pass an int for num_items_in_batch, which could be the case for custom trainer
  34. if torch.is_tensor(num_items_in_batch):
  35. num_items_in_batch = num_items_in_batch.to(loss.device)
  36. loss = loss / num_items_in_batch
  37. return loss
  38. def ForCausalLMLoss(
  39. logits,
  40. labels,
  41. vocab_size: int,
  42. num_items_in_batch: Optional[torch.Tensor] = None,
  43. ignore_index: int = -100,
  44. shift_labels: Optional[torch.Tensor] = None,
  45. **kwargs,
  46. ) -> torch.Tensor:
  47. # Upcast to float if we need to compute the loss to avoid potential precision issues
  48. logits = logits.float()
  49. if shift_labels is None:
  50. # Shift so that tokens < n predict n
  51. labels = nn.functional.pad(labels, (0, 1), value=ignore_index)
  52. shift_labels = labels[..., 1:].contiguous()
  53. # Flatten the tokens
  54. logits = logits.view(-1, vocab_size)
  55. shift_labels = shift_labels.view(-1)
  56. # Enable model parallelism
  57. shift_labels = shift_labels.to(logits.device)
  58. loss = fixed_cross_entropy(logits, shift_labels, num_items_in_batch, ignore_index, **kwargs)
  59. return loss
  60. def ForMaskedLMLoss(
  61. logits: torch.Tensor,
  62. labels: torch.Tensor,
  63. vocab_size: int,
  64. num_items_in_batch: Optional[torch.Tensor] = None,
  65. ignore_index: int = -100,
  66. **kwargs,
  67. ):
  68. # Upcast to float if we need to compute the loss to avoid potential precision issues
  69. logits = logits.float()
  70. # Flatten the tokens
  71. logits = logits.view(-1, vocab_size)
  72. labels = labels.view(-1)
  73. # Enable model parallelism
  74. labels = labels.to(logits.device)
  75. loss = fixed_cross_entropy(logits, labels, num_items_in_batch, ignore_index, **kwargs)
  76. return loss
  77. def ForSequenceClassificationLoss(labels: torch.Tensor, pooled_logits: torch.Tensor, config, **kwargs) -> torch.Tensor:
  78. num_labels = config.num_labels
  79. if config.problem_type is None:
  80. if num_labels == 1:
  81. config.problem_type = "regression"
  82. elif num_labels > 1 and (labels.dtype in (torch.long, torch.int)):
  83. config.problem_type = "single_label_classification"
  84. else:
  85. config.problem_type = "multi_label_classification"
  86. labels = labels.to(pooled_logits.device)
  87. if config.problem_type == "regression":
  88. loss_fct = MSELoss()
  89. if num_labels == 1:
  90. return loss_fct(pooled_logits.squeeze(), labels.squeeze())
  91. else:
  92. return loss_fct(pooled_logits, labels)
  93. if config.problem_type == "single_label_classification":
  94. return fixed_cross_entropy(pooled_logits.view(-1, num_labels), labels.view(-1), **kwargs)
  95. if config.problem_type == "multi_label_classification":
  96. loss_fct = BCEWithLogitsLoss()
  97. return loss_fct(pooled_logits, labels)
  98. raise RuntimeError(f"Invalid problem type: {config.problem_type}")
  99. def ForQuestionAnsweringLoss(start_logits, end_logits, start_positions, end_positions, **kwargs):
  100. total_loss = None
  101. if start_positions is not None and end_positions is not None:
  102. # If we are on multi-GPU, split add a dimension
  103. if len(start_positions.size()) > 1:
  104. start_positions = start_positions.squeeze(-1).to(start_logits.device)
  105. if len(end_positions.size()) > 1:
  106. end_positions = end_positions.squeeze(-1).to(end_logits.device)
  107. # sometimes the start/end positions are outside our model inputs, we ignore these terms
  108. ignored_index = start_logits.size(1)
  109. start_positions = start_positions.clamp(0, ignored_index)
  110. end_positions = end_positions.clamp(0, ignored_index)
  111. start_loss = fixed_cross_entropy(start_logits, start_positions, ignore_index=ignored_index, **kwargs)
  112. end_loss = fixed_cross_entropy(end_logits, end_positions, ignore_index=ignored_index, **kwargs)
  113. total_loss = (start_loss + end_loss) / 2
  114. return total_loss
  115. def ForTokenClassification(logits: torch.Tensor, labels, config, **kwargs):
  116. # Upcast to float if we need to compute the loss to avoid potential precision issues
  117. logits = logits.view(-1, config.num_labels)
  118. labels = labels.view(-1).to(logits.device)
  119. logits = logits.float()
  120. # Flatten the tokens
  121. return fixed_cross_entropy(logits, labels, **kwargs)
  122. LOSS_MAPPING = {
  123. "ForCausalLM": ForCausalLMLoss,
  124. "ForMaskedLM": ForMaskedLMLoss,
  125. "ForQuestionAnswering": ForQuestionAnsweringLoss,
  126. "ForSequenceClassification": ForSequenceClassificationLoss,
  127. "ForImageClassification": ForSequenceClassificationLoss,
  128. "ForVideoClassification": ForSequenceClassificationLoss,
  129. "ForAudioClassification": ForSequenceClassificationLoss,
  130. "ForTokenClassification": ForTokenClassification,
  131. "ForSegmentation": ForSegmentationLoss,
  132. "ForObjectDetection": ForObjectDetectionLoss,
  133. "ForConditionalGeneration": ForCausalLMLoss,
  134. "DeformableDetrForObjectDetection": DeformableDetrForObjectDetectionLoss,
  135. "ConditionalDetrForObjectDetection": DeformableDetrForObjectDetectionLoss,
  136. "DabDetrForObjectDetection": DeformableDetrForObjectDetectionLoss,
  137. "GroundingDinoForObjectDetection": GroundingDinoForObjectDetectionLoss,
  138. "MMGroundingDinoForObjectDetection": GroundingDinoForObjectDetectionLoss,
  139. "ConditionalDetrForSegmentation": DeformableDetrForSegmentationLoss,
  140. "RTDetrForObjectDetection": RTDetrForObjectDetectionLoss,
  141. "RTDetrV2ForObjectDetection": RTDetrForObjectDetectionLoss,
  142. "DFineForObjectDetection": DFineForObjectDetectionLoss,
  143. "CsmForConditionalGeneration": ForCausalLMLoss,
  144. }