create_attn.py 3.4 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889
  1. """ Attention Factory
  2. Hacked together by / Copyright 2021 Ross Wightman
  3. """
  4. import torch
  5. from functools import partial
  6. from .bottleneck_attn import BottleneckAttn
  7. from .cbam import CbamModule, LightCbamModule
  8. from .eca import EcaModule, CecaModule
  9. from .gather_excite import GatherExcite
  10. from .global_context import GlobalContext
  11. from .halo_attn import HaloAttn
  12. from .lambda_layer import LambdaLayer
  13. from .non_local_attn import NonLocalAttn, BatNonLocalAttn
  14. from .selective_kernel import SelectiveKernel
  15. from .split_attn import SplitAttn
  16. from .squeeze_excite import SEModule, EffectiveSEModule
  def get_attn(attn_type):
      """Resolve an attention specifier to an attention module class / factory.

      Args:
          attn_type: What kind of attention is wanted. Accepted forms:
              * a ``torch.nn.Module`` instance -- returned unchanged;
              * any falsy value (``None``, ``False``, ``''``) -- means "no attention";
              * ``True`` -- selects ``SEModule``;
              * a ``str`` name, matched case-insensitively (``'se'``, ``'ese'``,
                ``'eca'``, ``'ecam'``, ``'ceca'``, ``'ge'``, ``'gc'``, ``'gca'``,
                ``'cbam'``, ``'lcbam'``, ``'sk'``, ``'splat'``, ``'lambda'``,
                ``'bottleneck'``, ``'halo'``, ``'nl'``, ``'bat'``);
              * any other truthy object -- returned unchanged (presumably a module
                class or factory callable supplied by the caller).

      Returns:
          The resolved class / ``functools.partial`` / passed-through object, or
          ``None`` when no attention was requested.

      Raises:
          AssertionError: If ``attn_type`` is a string that names no known module.
      """
      if isinstance(attn_type, torch.nn.Module):
          return attn_type
      if not attn_type:
          return None
      if isinstance(attn_type, bool):
          # Only ``True`` can reach this point; it is shorthand for squeeze-excite.
          return SEModule
      if not isinstance(attn_type, str):
          # Caller handed over their own class / factory; use it as-is.
          return attn_type

      name = attn_type.lower()

      # Lightweight attention modules (channel and/or coarse spatial).
      # Typically added to existing network architecture blocks in addition to existing convolutions.
      if name == 'se':
          return SEModule
      if name == 'ese':
          return EffectiveSEModule
      if name == 'eca':
          return EcaModule
      if name == 'ecam':
          return partial(EcaModule, use_mlp=True)
      if name == 'ceca':
          return CecaModule
      if name == 'ge':
          return GatherExcite
      if name == 'gc':
          return GlobalContext
      if name == 'gca':
          return partial(GlobalContext, fuse_add=True, fuse_scale=False)
      if name == 'cbam':
          return CbamModule
      if name == 'lcbam':
          return LightCbamModule

      # Attention / attention-like modules w/ significant params
      # Typically replace some of the existing workhorse convs in a network architecture.
      # All of these accept a stride argument and can spatially downsample the input.
      if name == 'sk':
          return SelectiveKernel
      if name == 'splat':
          return SplitAttn

      # Self-attention / attention-like modules w/ significant compute and/or params
      # Typically replace some of the existing workhorse convs in a network architecture.
      # All of these accept a stride argument and can spatially downsample the input.
      if name == 'lambda':
          return LambdaLayer
      if name == 'bottleneck':
          return BottleneckAttn
      if name == 'halo':
          return HaloAttn
      if name == 'nl':
          return NonLocalAttn
      if name == 'bat':
          return BatNonLocalAttn

      # Raised explicitly rather than via ``assert False`` so the check is not
      # stripped under ``python -O`` (which would silently yield ``None``).
      # AssertionError is kept so existing callers see the same exception type.
      raise AssertionError("Invalid attn module (%s)" % name)
  def create_attn(attn_type, channels, **kwargs):
      """Build an attention layer for ``channels`` input channels.

      Args:
          attn_type: Attention specifier, resolved through ``get_attn``.
          channels: Number of input channels; passed as the first positional
              argument to the resolved class / factory.
          **kwargs: Extra keyword arguments forwarded to the class / factory.

      Returns:
          The constructed attention layer, the given ``torch.nn.Module`` instance
          when ``get_attn`` hands one back, or ``None`` when ``get_attn``
          resolves to ``None``.
      """
      module_cls = get_attn(attn_type)
      if module_cls is None:
          return None
      if isinstance(module_cls, torch.nn.Module):
          # get_attn() passes already-constructed modules through untouched.
          # Calling one here would run its forward pass on ``channels`` instead
          # of constructing a layer, so hand the instance back as-is
          # (``channels`` / ``kwargs`` do not apply to a pre-built module).
          return module_cls
      # NOTE: it's expected the first (positional) argument of all attention layers
      # is the number of input channels
      return module_cls(channels, **kwargs)