model.py 1.3 KB

123456789101112131415161718192021222324252627282930313233343536373839
  1. # -*- coding: utf-8 -*-
  2. # @Time : 2019/8/23 21:57
  3. # @Author : zhoujun
  4. from addict import Dict
  5. from paddle import nn
  6. import paddle.nn.functional as F
  7. from models.backbone import build_backbone
  8. from models.neck import build_neck
  9. from models.head import build_head
  10. class Model(nn.Layer):
  11. def __init__(self, model_config: dict):
  12. """
  13. PANnet
  14. :param model_config: 模型配置
  15. """
  16. super().__init__()
  17. model_config = Dict(model_config)
  18. backbone_type = model_config.backbone.pop("type")
  19. neck_type = model_config.neck.pop("type")
  20. head_type = model_config.head.pop("type")
  21. self.backbone = build_backbone(backbone_type, **model_config.backbone)
  22. self.neck = build_neck(
  23. neck_type, in_channels=self.backbone.out_channels, **model_config.neck
  24. )
  25. self.head = build_head(
  26. head_type, in_channels=self.neck.out_channels, **model_config.head
  27. )
  28. self.name = f"{backbone_type}_{neck_type}_{head_type}"
  29. def forward(self, x):
  30. _, _, H, W = x.shape
  31. backbone_out = self.backbone(x)
  32. neck_out = self.neck(backbone_out)
  33. y = self.head(neck_out)
  34. y = F.interpolate(y, size=(H, W), mode="bilinear", align_corners=True)
  35. return y