configuration_upernet.py 6.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148
  1. # coding=utf-8
  2. # Copyright 2022 The HuggingFace Inc. team. All rights reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """UperNet model configuration"""
  16. from ...configuration_utils import PretrainedConfig
  17. from ...utils import logging
  18. from ...utils.backbone_utils import verify_backbone_config_arguments
  19. from ..auto.configuration_auto import CONFIG_MAPPING
# Module-level logger, named after this module, used for informational messages emitted by the config.
logger = logging.get_logger(__name__)
  21. class UperNetConfig(PretrainedConfig):
  22. r"""
  23. This is the configuration class to store the configuration of an [`UperNetForSemanticSegmentation`]. It is used to
  24. instantiate an UperNet model according to the specified arguments, defining the model architecture. Instantiating a
  25. configuration with the defaults will yield a similar configuration to that of the UperNet
  26. [openmmlab/upernet-convnext-tiny](https://huggingface.co/openmmlab/upernet-convnext-tiny) architecture.
  27. Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
  28. documentation from [`PretrainedConfig`] for more information.
  29. Args:
  30. backbone_config (`PretrainedConfig` or `dict`, *optional*, defaults to `ResNetConfig()`):
  31. The configuration of the backbone model.
  32. backbone (`str`, *optional*):
  33. Name of backbone to use when `backbone_config` is `None`. If `use_pretrained_backbone` is `True`, this
  34. will load the corresponding pretrained weights from the timm or transformers library. If `use_pretrained_backbone`
  35. is `False`, this loads the backbone's config and uses that to initialize the backbone with random weights.
  36. use_pretrained_backbone (`bool`, *optional*, `False`):
  37. Whether to use pretrained weights for the backbone.
  38. use_timm_backbone (`bool`, *optional*, `False`):
  39. Whether to load `backbone` from the timm library. If `False`, the backbone is loaded from the transformers
  40. library.
  41. backbone_kwargs (`dict`, *optional*):
  42. Keyword arguments to be passed to AutoBackbone when loading from a checkpoint
  43. e.g. `{'out_indices': (0, 1, 2, 3)}`. Cannot be specified if `backbone_config` is set.
  44. hidden_size (`int`, *optional*, defaults to 512):
  45. The number of hidden units in the convolutional layers.
  46. initializer_range (`float`, *optional*, defaults to 0.02):
  47. The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
  48. pool_scales (`tuple[int]`, *optional*, defaults to `[1, 2, 3, 6]`):
  49. Pooling scales used in Pooling Pyramid Module applied on the last feature map.
  50. use_auxiliary_head (`bool`, *optional*, defaults to `True`):
  51. Whether to use an auxiliary head during training.
  52. auxiliary_loss_weight (`float`, *optional*, defaults to 0.4):
  53. Weight of the cross-entropy loss of the auxiliary head.
  54. auxiliary_channels (`int`, *optional*, defaults to 256):
  55. Number of channels to use in the auxiliary head.
  56. auxiliary_num_convs (`int`, *optional*, defaults to 1):
  57. Number of convolutional layers to use in the auxiliary head.
  58. auxiliary_concat_input (`bool`, *optional*, defaults to `False`):
  59. Whether to concatenate the output of the auxiliary head with the input before the classification layer.
  60. loss_ignore_index (`int`, *optional*, defaults to 255):
  61. The index that is ignored by the loss function.
  62. Examples:
  63. ```python
  64. >>> from transformers import UperNetConfig, UperNetForSemanticSegmentation
  65. >>> # Initializing a configuration
  66. >>> configuration = UperNetConfig()
  67. >>> # Initializing a model (with random weights) from the configuration
  68. >>> model = UperNetForSemanticSegmentation(configuration)
  69. >>> # Accessing the model configuration
  70. >>> configuration = model.config
  71. ```"""
  72. model_type = "upernet"
  73. def __init__(
  74. self,
  75. backbone_config=None,
  76. backbone=None,
  77. use_pretrained_backbone=False,
  78. use_timm_backbone=False,
  79. backbone_kwargs=None,
  80. hidden_size=512,
  81. initializer_range=0.02,
  82. pool_scales=[1, 2, 3, 6],
  83. use_auxiliary_head=True,
  84. auxiliary_loss_weight=0.4,
  85. auxiliary_in_channels=None,
  86. auxiliary_channels=256,
  87. auxiliary_num_convs=1,
  88. auxiliary_concat_input=False,
  89. loss_ignore_index=255,
  90. **kwargs,
  91. ):
  92. super().__init__(**kwargs)
  93. if backbone_config is None and backbone is None:
  94. logger.info("`backbone_config` is `None`. Initializing the config with the default `ResNet` backbone.")
  95. backbone_config = CONFIG_MAPPING["resnet"](out_features=["stage1", "stage2", "stage3", "stage4"])
  96. elif isinstance(backbone_config, dict):
  97. backbone_model_type = backbone_config.get("model_type")
  98. config_class = CONFIG_MAPPING[backbone_model_type]
  99. backbone_config = config_class.from_dict(backbone_config)
  100. verify_backbone_config_arguments(
  101. use_timm_backbone=use_timm_backbone,
  102. use_pretrained_backbone=use_pretrained_backbone,
  103. backbone=backbone,
  104. backbone_config=backbone_config,
  105. backbone_kwargs=backbone_kwargs,
  106. )
  107. self.backbone_config = backbone_config
  108. self.backbone = backbone
  109. self.use_pretrained_backbone = use_pretrained_backbone
  110. self.use_timm_backbone = use_timm_backbone
  111. self.backbone_kwargs = backbone_kwargs
  112. self.hidden_size = hidden_size
  113. self.initializer_range = initializer_range
  114. self.pool_scales = pool_scales
  115. self.use_auxiliary_head = use_auxiliary_head
  116. self.auxiliary_loss_weight = auxiliary_loss_weight
  117. self.auxiliary_in_channels = auxiliary_in_channels
  118. self.auxiliary_channels = auxiliary_channels
  119. self.auxiliary_num_convs = auxiliary_num_convs
  120. self.auxiliary_concat_input = auxiliary_concat_input
  121. self.loss_ignore_index = loss_ignore_index
  122. @property
  123. def sub_configs(self):
  124. return (
  125. {"backbone_config": type(self.backbone_config)}
  126. if getattr(self, "backbone_config", None) is not None
  127. else {}
  128. )
# Public API of this module: only the configuration class is exported.
__all__ = ["UperNetConfig"]