| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
7037047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207 |
- """PyTorch CspNet
- A PyTorch implementation of Cross Stage Partial Networks including:
- * CSPResNet50
- * CSPResNeXt50
- * CSPDarkNet53
- * and DarkNet53 for good measure
- Based on paper `CSPNet: A New Backbone that can Enhance Learning Capability of CNN` - https://arxiv.org/abs/1911.11929
- Reference impl via darknet cfg files at https://github.com/WongKinYiu/CrossStagePartialNetworks
- Hacked together by / Copyright 2020 Ross Wightman
- """
- from dataclasses import dataclass, asdict, replace
- from functools import partial
- from typing import Any, Dict, List, Optional, Tuple, Type, Union
- import torch
- import torch.nn as nn
- from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
- from timm.layers import ClassifierHead, ConvNormAct, DropPath, calculate_drop_path_rates, get_attn, create_act_layer, make_divisible
- from ._builder import build_model_with_cfg
- from ._manipulate import named_apply, MATCH_PREV_GROUP
- from ._registry import register_model, generate_default_cfgs
# Public names exported by this module.
__all__ = ['CspNet']  # model_registry will add each entrypoint fn to this
@dataclass
class CspStemCfg:
    """Stem configuration for a CspNet model.

    All fields are forwarded as keyword arguments to ``create_csp_stem`` (via ``asdict``).
    """
    # output channels of the stem conv(s); a tuple builds one conv per entry
    out_chs: Union[int, Tuple[int, ...]] = 32
    # total stem stride
    stride: Union[int, Tuple[int, ...]] = 2
    # kernel size used by every stem conv
    kernel_size: int = 3
    # padding for the first stem conv only
    padding: Union[int, str] = ''
    # pooling type; any truthy value appends a stride-2 max-pool stage
    pool: Optional[str] = ''
- def _pad_arg(x, n):
- # pads an argument tuple to specified n by padding with last value
- if not isinstance(x, (tuple, list)):
- x = (x,)
- curr_n = len(x)
- pad_n = n - curr_n
- if pad_n <= 0:
- return x[:n]
- return tuple(x + (x[-1],) * pad_n)
@dataclass
class CspStagesCfg:
    """Per-stage configuration for CspNet.

    ``depth`` and ``out_chs`` must list one entry per stage. Every other field may be a single
    value or a (possibly shorter) tuple; ``__post_init__`` broadcasts each of them to one entry
    per stage by repeating the last value.
    """
    depth: Tuple[int, ...] = (3, 3, 5, 2)  # block depth (number of block repeats in stages)
    out_chs: Tuple[int, ...] = (128, 256, 512, 1024)  # number of output channels for blocks in stage
    stride: Union[int, Tuple[int, ...]] = 2  # stride of stage
    groups: Union[int, Tuple[int, ...]] = 1  # num kxk conv groups
    block_ratio: Union[float, Tuple[float, ...]] = 1.0  # block output chs as a fraction of stage out_chs
    bottle_ratio: Union[float, Tuple[float, ...]] = 1.  # bottleneck-ratio of blocks in stage
    avg_down: Union[bool, Tuple[bool, ...]] = False  # downsample with avg-pool + 1x1 conv
    attn_layer: Optional[Union[str, Tuple[str, ...]]] = None  # attention layer name per stage
    attn_kwargs: Optional[Union[Dict, Tuple[Dict]]] = None  # extra kwargs for the attention layer
    stage_type: Union[str, Tuple[str]] = 'csp'  # stage type ('csp', 'cs3', 'dark')
    block_type: Union[str, Tuple[str]] = 'bottle'  # blocks type for stages ('bottle', 'dark', 'edge')

    # cross-stage only
    expand_ratio: Union[float, Tuple[float, ...]] = 1.0
    cross_linear: Union[bool, Tuple[bool, ...]] = False
    down_growth: Union[bool, Tuple[bool, ...]] = False

    def __post_init__(self):
        num_stages = len(self.depth)
        assert len(self.out_chs) == num_stages
        # Broadcast every per-stage option to a tuple of length num_stages (same order as before).
        broadcast_fields = (
            'stride', 'groups', 'block_ratio', 'bottle_ratio', 'avg_down', 'attn_layer',
            'attn_kwargs', 'stage_type', 'block_type', 'expand_ratio', 'cross_linear', 'down_growth',
        )
        for field_name in broadcast_fields:
            setattr(self, field_name, _pad_arg(getattr(self, field_name), num_stages))
@dataclass
class CspModelCfg:
    """Top-level CspNet architecture config: stem and stage configs plus shared layer types.

    ``CspNet.__init__`` overlays its extra kwargs onto this config with ``dataclasses.replace``.
    """
    stem: CspStemCfg  # stem layout
    stages: CspStagesCfg  # per-stage layout
    # zero init last weight (usually bn) in residual path
    # NOTE(review): CspNet uses its own ``zero_init_last`` argument; this field is not read in the visible code.
    zero_init_last: bool = True
    act_layer: str = 'leaky_relu'  # activation layer name used by stem, stages and blocks
    norm_layer: str = 'batchnorm'  # normalization layer name used by stem, stages and blocks
    aa_layer: Optional[str] = None  # FIXME support string factory for this
def _cs3_cfg(
        width_multiplier=1.0,
        depth_multiplier=1.0,
        avg_down=False,
        act_layer='silu',
        focus=False,
        attn_layer=None,
        attn_kwargs=None,
        bottle_ratio=1.0,
        block_type='dark',
):
    """Build a CspModelCfg for the 'cs3' (CrossStage3) DarkNet family.

    Base channels (stem 32/64 or focus 64, stages 128..1024) are scaled by ``width_multiplier``
    and rounded with ``make_divisible``. Base depths (3, 6, 9, 3) are scaled by
    ``depth_multiplier`` and truncated to int.

    Args:
        width_multiplier: channel scale factor.
        depth_multiplier: block-depth scale factor.
        avg_down: use avg-pool + 1x1 conv downsampling in stages.
        act_layer: activation layer name for the whole model.
        focus: use a single 6x6 stride-2 stem conv instead of two 3x3 convs.
        attn_layer: optional attention layer name for the blocks.
        attn_kwargs: optional kwargs for the attention layer.
        bottle_ratio: bottleneck ratio of the blocks.
        block_type: block type for all stages ('dark', 'edge', 'bottle').
    """
    def _scaled(chs):
        return make_divisible(chs * width_multiplier)

    if not focus:
        # two 3x3 convs, the first one strided
        stem_cfg = CspStemCfg(
            out_chs=tuple(_scaled(c) for c in (32, 64)),
            kernel_size=3,
            stride=2,
            pool='',
        )
    else:
        # single 6x6 stride-2 conv
        stem_cfg = CspStemCfg(
            out_chs=_scaled(64),
            kernel_size=6,
            stride=2,
            padding=2,
            pool='',
        )

    stages_cfg = CspStagesCfg(
        out_chs=tuple(_scaled(c) for c in (128, 256, 512, 1024)),
        depth=tuple(int(d * depth_multiplier) for d in (3, 6, 9, 3)),
        stride=2,
        bottle_ratio=bottle_ratio,
        block_ratio=0.5,
        avg_down=avg_down,
        attn_layer=attn_layer,
        attn_kwargs=attn_kwargs,
        stage_type='cs3',
        block_type=block_type,
    )
    return CspModelCfg(stem=stem_cfg, stages=stages_cfg, act_layer=act_layer)
class BottleneckBlock(nn.Module):
    """ResNe(X)t bottleneck block.

    conv1 (1x1, in_chs -> mid) -> conv2 (3x3, grouped / dilated) -> [attn] -> conv3 (1x1, mid ->
    out_chs, no activation) -> [attn] -> drop_path -> + identity shortcut -> act3.

    ``mid`` is ``round(out_chs * bottle_ratio)``. The optional ``attn_layer`` sits after conv2 by
    default, or after conv3 when ``attn_last`` is True. The shortcut is an un-projected identity,
    so the residual branch output must match the input shape.
    """

    def __init__(
            self,
            in_chs: int,
            out_chs: int,
            dilation: int = 1,
            bottle_ratio: float = 0.25,
            groups: int = 1,
            act_layer: Type[nn.Module] = nn.ReLU,
            norm_layer: Type[nn.Module] = nn.BatchNorm2d,
            attn_last: bool = False,
            attn_layer: Optional[Type[nn.Module]] = None,
            drop_block: Optional[Type[nn.Module]] = None,
            drop_path: float = 0.,
            device=None,
            dtype=None,
    ):
        factory = {'device': device, 'dtype': dtype}
        super().__init__()
        hidden_chs = int(round(out_chs * bottle_ratio))
        layer_kw = dict(act_layer=act_layer, norm_layer=norm_layer)
        use_attn = attn_layer is not None

        # Submodules are registered in the same order as always (it determines state_dict order).
        self.conv1 = ConvNormAct(in_chs, hidden_chs, kernel_size=1, **layer_kw, **factory)
        self.conv2 = ConvNormAct(
            hidden_chs,
            hidden_chs,
            kernel_size=3,
            dilation=dilation,
            groups=groups,
            drop_layer=drop_block,
            **layer_kw,
            **factory,
        )
        if use_attn and not attn_last:
            self.attn2 = attn_layer(hidden_chs, act_layer=act_layer, **factory)
        else:
            self.attn2 = nn.Identity()
        self.conv3 = ConvNormAct(hidden_chs, out_chs, kernel_size=1, apply_act=False, **layer_kw, **factory)
        if use_attn and attn_last:
            self.attn3 = attn_layer(out_chs, act_layer=act_layer, **factory)
        else:
            self.attn3 = nn.Identity()
        self.drop_path = nn.Identity() if not drop_path else DropPath(drop_path)
        # activation applied after the residual add
        self.act3 = create_act_layer(act_layer)

    def zero_init_last(self):
        # zero the last norm's weight so the residual branch starts near zero
        nn.init.zeros_(self.conv3.bn.weight)

    def forward(self, x):
        shortcut = x
        out = self.conv1(x)
        out = self.conv2(out)
        out = self.attn2(out)
        out = self.conv3(out)
        out = self.attn3(out)
        # FIXME partial shortcut needed if first block handled as per original, not used for my current impl
        out = self.drop_path(out) + shortcut
        return self.act3(out)
class DarkBlock(nn.Module):
    """DarkNet residual block.

    conv1 (1x1, in_chs -> mid) -> [attn] -> conv2 (3x3, mid -> out_chs, grouped / dilated)
    -> drop_path -> + identity shortcut. ``mid`` is ``round(out_chs * bottle_ratio)``.
    Unlike BottleneckBlock, no activation is applied after the residual add.
    """

    def __init__(
            self,
            in_chs: int,
            out_chs: int,
            dilation: int = 1,
            bottle_ratio: float = 0.5,
            groups: int = 1,
            act_layer: Type[nn.Module] = nn.ReLU,
            norm_layer: Type[nn.Module] = nn.BatchNorm2d,
            attn_layer: Optional[Type[nn.Module]] = None,
            drop_block: Optional[Type[nn.Module]] = None,
            drop_path: float = 0.,
            device=None,
            dtype=None,
    ):
        factory = {'device': device, 'dtype': dtype}
        super().__init__()
        hidden_chs = int(round(out_chs * bottle_ratio))
        layer_kw = dict(act_layer=act_layer, norm_layer=norm_layer)

        self.conv1 = ConvNormAct(in_chs, hidden_chs, kernel_size=1, **layer_kw, **factory)
        if attn_layer is None:
            self.attn = nn.Identity()
        else:
            self.attn = attn_layer(hidden_chs, act_layer=act_layer, **factory)
        self.conv2 = ConvNormAct(
            hidden_chs,
            out_chs,
            kernel_size=3,
            dilation=dilation,
            groups=groups,
            drop_layer=drop_block,
            **layer_kw,
            **factory,
        )
        self.drop_path = nn.Identity() if not drop_path else DropPath(drop_path)

    def zero_init_last(self):
        # zero the last norm's weight so the residual branch starts near zero
        nn.init.zeros_(self.conv2.bn.weight)

    def forward(self, x):
        residual = self.conv2(self.attn(self.conv1(x)))
        return self.drop_path(residual) + x
class EdgeBlock(nn.Module):
    """EdgeResidual / Fused-MBConv / MobileNetV1-like 3x3 + 1x1 block (w/ activated output).

    conv1 (3x3, in_chs -> mid, grouped / dilated) -> [attn] -> conv2 (1x1, mid -> out_chs)
    -> drop_path -> + identity shortcut. ``mid`` is ``round(out_chs * bottle_ratio)``.
    """

    def __init__(
            self,
            in_chs: int,
            out_chs: int,
            dilation: int = 1,
            bottle_ratio: float = 0.5,
            groups: int = 1,
            act_layer: Type[nn.Module] = nn.ReLU,
            norm_layer: Type[nn.Module] = nn.BatchNorm2d,
            attn_layer: Optional[Type[nn.Module]] = None,
            drop_block: Optional[Type[nn.Module]] = None,
            drop_path: float = 0.,
            device=None,
            dtype=None,
    ):
        factory = {'device': device, 'dtype': dtype}
        super().__init__()
        hidden_chs = int(round(out_chs * bottle_ratio))
        layer_kw = dict(act_layer=act_layer, norm_layer=norm_layer)

        self.conv1 = ConvNormAct(
            in_chs,
            hidden_chs,
            kernel_size=3,
            dilation=dilation,
            groups=groups,
            drop_layer=drop_block,
            **layer_kw,
            **factory,
        )
        if attn_layer is None:
            self.attn = nn.Identity()
        else:
            self.attn = attn_layer(hidden_chs, act_layer=act_layer, **factory)
        self.conv2 = ConvNormAct(hidden_chs, out_chs, kernel_size=1, **layer_kw, **factory)
        self.drop_path = nn.Identity() if not drop_path else DropPath(drop_path)

    def zero_init_last(self):
        # zero the last norm's weight so the residual branch starts near zero
        nn.init.zeros_(self.conv2.bn.weight)

    def forward(self, x):
        residual = self.conv2(self.attn(self.conv1(x)))
        return self.drop_path(residual) + x
class CrossStage(nn.Module):
    """Cross Stage Partial (CSP) stage.

    Layout: optional downsample conv -> 1x1 expansion conv (``exp_chs``) -> channel split into a
    cross half and a block half. The block half runs through ``depth`` blocks and a 1x1
    transition conv. Both halves are then concatenated and fused by a final 1x1 transition conv
    to ``out_chs``.

    Args:
        in_chs: input channels.
        out_chs: output channels.
        stride: stage stride.
        dilation: dilation of the blocks.
        depth: number of blocks.
        block_ratio: block output channels as a fraction of ``out_chs``.
        bottle_ratio: bottleneck ratio passed to each block.
        expand_ratio: expansion conv channels as a fraction of ``out_chs``.
        groups: groups for the downsample conv and the blocks.
        first_dilation: dilation of the downsample conv (defaults to ``dilation``).
        avg_down: downsample with avg-pool + 1x1 conv instead of a strided 3x3 conv.
        down_growth: downsample conv outputs ``out_chs`` instead of keeping ``in_chs``.
        cross_linear: no activation on the expansion conv.
        block_dpr: per-block drop-path rates (``None`` -> 0).
        block_fn: block class.
        **block_kwargs: forwarded to each block; ``act_layer`` / ``norm_layer`` also configure the
            stage convs, and ``aa_layer`` (popped) goes to the strided downsample conv.
    """

    def __init__(
            self,
            in_chs: int,
            out_chs: int,
            stride: int,
            dilation: int,
            depth: int,
            block_ratio: float = 1.,
            bottle_ratio: float = 1.,
            expand_ratio: float = 1.,
            groups: int = 1,
            first_dilation: Optional[int] = None,
            avg_down: bool = False,
            down_growth: bool = False,
            cross_linear: bool = False,
            block_dpr: Optional[List[float]] = None,
            block_fn: Type[nn.Module] = BottleneckBlock,
            device=None,
            dtype=None,
            **block_kwargs,
    ):
        dd = {'device': device, 'dtype': dtype}
        super().__init__()
        first_dilation = first_dilation or dilation
        down_chs = out_chs if down_growth else in_chs  # grow downsample channels to output channels
        self.expand_chs = exp_chs = int(round(out_chs * expand_ratio))
        block_out_chs = int(round(out_chs * block_ratio))
        conv_kwargs = dict(act_layer=block_kwargs.get('act_layer'), norm_layer=block_kwargs.get('norm_layer'))
        aa_layer = block_kwargs.pop('aa_layer', None)

        if stride != 1 or first_dilation != dilation:
            if avg_down:
                # Project to down_chs (same as the strided-conv path) so the width matches the
                # prev_chs that conv_exp is built with below.
                self.conv_down = nn.Sequential(
                    nn.AvgPool2d(2) if stride == 2 else nn.Identity(),  # FIXME dilation handling
                    ConvNormAct(in_chs, down_chs, kernel_size=1, stride=1, groups=groups, **conv_kwargs, **dd)
                )
            else:
                self.conv_down = ConvNormAct(
                    in_chs,
                    down_chs,
                    kernel_size=3,
                    stride=stride,
                    dilation=first_dilation,
                    groups=groups,
                    aa_layer=aa_layer,
                    **conv_kwargs,
                    **dd,
                )
            prev_chs = down_chs
        else:
            self.conv_down = nn.Identity()
            prev_chs = in_chs

        # FIXME this 1x1 expansion is pushed down into the cross and block paths in the darknet cfgs. Also,
        # there is also special case for the first stage for some of the model that results in uneven split
        # across the two paths. I did it this way for simplicity for now.
        self.conv_exp = ConvNormAct(
            prev_chs,
            exp_chs,
            kernel_size=1,
            apply_act=not cross_linear,
            **conv_kwargs,
            **dd,
        )
        prev_chs = exp_chs // 2  # output of conv_exp is always split in two

        self.blocks = nn.Sequential()
        for i in range(depth):
            self.blocks.add_module(str(i), block_fn(
                in_chs=prev_chs,
                out_chs=block_out_chs,
                dilation=dilation,
                bottle_ratio=bottle_ratio,
                groups=groups,
                drop_path=block_dpr[i] if block_dpr is not None else 0.,
                **block_kwargs,
                **dd,
            ))
            prev_chs = block_out_chs

        # transition convs
        self.conv_transition_b = ConvNormAct(prev_chs, exp_chs // 2, kernel_size=1, **conv_kwargs, **dd)
        self.conv_transition = ConvNormAct(exp_chs, out_chs, kernel_size=1, **conv_kwargs, **dd)

    def forward(self, x):
        """Downsample, expand, split; run blocks on one half; concat with the cross half and fuse."""
        x = self.conv_down(x)
        x = self.conv_exp(x)
        # NOTE: unpacking into two parts assumes exp_chs is even
        xs, xb = x.split(self.expand_chs // 2, dim=1)
        xb = self.blocks(xb)
        xb = self.conv_transition_b(xb).contiguous()
        out = self.conv_transition(torch.cat([xs, xb], dim=1))
        return out
class CrossStage3(nn.Module):
    """Cross Stage 3.

    Similar to CrossStage, but with only one transition conv for the output. The block half's
    output is concatenated directly with the cross half and fused by a single 1x1 conv, so
    ``block_out_chs + exp_chs // 2`` must equal ``exp_chs`` for ``conv_transition``.

    Args:
        in_chs: input channels.
        out_chs: output channels.
        stride: stage stride.
        dilation: dilation of the blocks.
        depth: number of blocks.
        block_ratio: block output channels as a fraction of ``out_chs``.
        bottle_ratio: bottleneck ratio passed to each block.
        expand_ratio: expansion conv channels as a fraction of ``out_chs``.
        groups: groups for the downsample conv and the blocks.
        first_dilation: dilation of the downsample conv (defaults to ``dilation``).
        avg_down: downsample with avg-pool + 1x1 conv instead of a strided 3x3 conv.
        down_growth: downsample conv outputs ``out_chs`` instead of keeping ``in_chs``.
        cross_linear: no activation on the expansion conv.
        block_dpr: per-block drop-path rates (``None`` -> 0).
        block_fn: block class.
        **block_kwargs: forwarded to each block; ``act_layer`` / ``norm_layer`` also configure the
            stage convs, and ``aa_layer`` (popped) goes to the strided downsample conv.
    """

    def __init__(
            self,
            in_chs: int,
            out_chs: int,
            stride: int,
            dilation: int,
            depth: int,
            block_ratio: float = 1.,
            bottle_ratio: float = 1.,
            expand_ratio: float = 1.,
            groups: int = 1,
            first_dilation: Optional[int] = None,
            avg_down: bool = False,
            down_growth: bool = False,
            cross_linear: bool = False,
            block_dpr: Optional[List[float]] = None,
            block_fn: Type[nn.Module] = BottleneckBlock,
            device=None,
            dtype=None,
            **block_kwargs,
    ):
        dd = {'device': device, 'dtype': dtype}
        super().__init__()
        first_dilation = first_dilation or dilation
        down_chs = out_chs if down_growth else in_chs  # grow downsample channels to output channels
        self.expand_chs = exp_chs = int(round(out_chs * expand_ratio))
        block_out_chs = int(round(out_chs * block_ratio))
        conv_kwargs = dict(act_layer=block_kwargs.get('act_layer'), norm_layer=block_kwargs.get('norm_layer'))
        aa_layer = block_kwargs.pop('aa_layer', None)

        if stride != 1 or first_dilation != dilation:
            if avg_down:
                # Project to down_chs (same as the strided-conv path) so the width matches the
                # prev_chs that conv_exp is built with below.
                self.conv_down = nn.Sequential(
                    nn.AvgPool2d(2) if stride == 2 else nn.Identity(),  # FIXME dilation handling
                    ConvNormAct(in_chs, down_chs, kernel_size=1, stride=1, groups=groups, **conv_kwargs, **dd)
                )
            else:
                self.conv_down = ConvNormAct(
                    in_chs,
                    down_chs,
                    kernel_size=3,
                    stride=stride,
                    dilation=first_dilation,
                    groups=groups,
                    aa_layer=aa_layer,
                    **conv_kwargs,
                    **dd,
                )
            prev_chs = down_chs
        else:
            # Identity (not None): forward() always calls self.conv_down.
            self.conv_down = nn.Identity()
            prev_chs = in_chs

        # expansion conv
        self.conv_exp = ConvNormAct(
            prev_chs,
            exp_chs,
            kernel_size=1,
            apply_act=not cross_linear,
            **conv_kwargs,
            **dd,
        )
        prev_chs = exp_chs // 2  # expanded output is split in 2 for blocks and cross stage

        self.blocks = nn.Sequential()
        for i in range(depth):
            self.blocks.add_module(str(i), block_fn(
                in_chs=prev_chs,
                out_chs=block_out_chs,
                dilation=dilation,
                bottle_ratio=bottle_ratio,
                groups=groups,
                drop_path=block_dpr[i] if block_dpr is not None else 0.,
                **block_kwargs,
                **dd,
            ))
            prev_chs = block_out_chs

        # transition convs
        self.conv_transition = ConvNormAct(exp_chs, out_chs, kernel_size=1, **conv_kwargs, **dd)

    def forward(self, x):
        """Downsample, expand, split; run blocks on one half; concat with the other and fuse."""
        x = self.conv_down(x)
        x = self.conv_exp(x)
        # NOTE: unpacking into two parts assumes exp_chs is even
        x1, x2 = x.split(self.expand_chs // 2, dim=1)
        x1 = self.blocks(x1)
        out = self.conv_transition(torch.cat([x1, x2], dim=1))
        return out
class DarkStage(nn.Module):
    """DarkNet stage: a downsample conv followed by ``depth`` residual blocks (no cross-stage split).

    Downsampling is a strided / dilated 3x3 ConvNormAct. With ``avg_down`` it is instead a 2x2
    average pool (only when ``stride == 2``) followed by a 1x1 ConvNormAct. The downsample conv is
    built even when ``stride == 1``.
    """

    def __init__(
            self,
            in_chs: int,
            out_chs: int,
            stride: int,
            dilation: int,
            depth: int,
            block_ratio: float = 1.,
            bottle_ratio: float = 1.,
            groups: int = 1,
            first_dilation: Optional[int] = None,
            avg_down: bool = False,
            block_fn: Type[nn.Module] = BottleneckBlock,
            block_dpr: Optional[List[float]] = None,
            device=None,
            dtype=None,
            **block_kwargs,
    ):
        factory = {'device': device, 'dtype': dtype}
        super().__init__()
        first_dilation = first_dilation or dilation
        norm_act = dict(act_layer=block_kwargs.get('act_layer'), norm_layer=block_kwargs.get('norm_layer'))
        aa_layer = block_kwargs.pop('aa_layer', None)

        if not avg_down:
            self.conv_down = ConvNormAct(
                in_chs,
                out_chs,
                kernel_size=3,
                stride=stride,
                dilation=first_dilation,
                groups=groups,
                aa_layer=aa_layer,
                **norm_act,
                **factory,
            )
        else:
            pool = nn.AvgPool2d(2) if stride == 2 else nn.Identity()  # FIXME dilation handling
            self.conv_down = nn.Sequential(
                pool,
                ConvNormAct(in_chs, out_chs, kernel_size=1, stride=1, groups=groups, **norm_act, **factory),
            )

        block_out_chs = int(round(out_chs * block_ratio))
        drop_rates = [0.] * depth if block_dpr is None else block_dpr
        self.blocks = nn.Sequential()
        block_in_chs = out_chs
        for idx in range(depth):
            self.blocks.add_module(str(idx), block_fn(
                in_chs=block_in_chs,
                out_chs=block_out_chs,
                dilation=dilation,
                bottle_ratio=bottle_ratio,
                groups=groups,
                drop_path=drop_rates[idx],
                **block_kwargs,
                **factory,
            ))
            block_in_chs = block_out_chs

    def forward(self, x):
        return self.blocks(self.conv_down(x))
def create_csp_stem(
        in_chans: int = 3,
        out_chs: Union[int, Tuple[int, ...]] = 32,
        kernel_size: int = 3,
        stride: int = 2,
        pool: str = '',
        padding: Union[int, str] = '',
        act_layer: Type[nn.Module] = nn.ReLU,
        norm_layer: Type[nn.Module] = nn.BatchNorm2d,
        aa_layer: Optional[Type[nn.Module]] = None,
        device=None,
        dtype=None,
):
    """Build the CspNet stem: one ConvNormAct per entry of ``out_chs``, plus an optional pool.

    Args:
        in_chans: number of input channels.
        out_chs: output channels of each stem conv; an int builds a single conv.
        kernel_size: kernel size of every stem conv.
        stride: total stem stride, one of (1, 2, 4).
        pool: if truthy, append a stride-2 max pool (requires ``stride`` > 2).
        padding: padding for the first conv only; later convs get ''.
        act_layer: activation passed to every ConvNormAct.
        norm_layer: norm layer passed to every ConvNormAct.
        aa_layer: if set and ``pool`` is used, the max pool runs at stride 1 and
            ``aa_layer(channels=..., stride=2)`` performs the downsample.
            NOTE(review): it is called directly, so it must be a callable, not a string name.
        device: factory kwarg for created modules.
        dtype: factory kwarg for created modules.

    Returns:
        (stem, feature_info): the stem ``nn.Sequential``, and a list of feature dicts
        (num_chs / reduction / module) with one entry for the output just before each
        downsampling step (when one exists) followed by the final stem output.
    """
    dd = {'device': device, 'dtype': dtype}
    stem = nn.Sequential()
    feature_info = []
    if not isinstance(out_chs, (tuple, list)):
        out_chs = [out_chs]
    stem_depth = len(out_chs)
    assert stem_depth
    assert stride in (1, 2, 4)

    prev_feat = None
    prev_chs = in_chans
    last_idx = stem_depth - 1
    stem_stride = 1
    for i, chs in enumerate(out_chs):
        conv_name = f'conv{i + 1}'
        # The first conv takes the first stride-2 step; the last conv takes a second one when
        # stride > 2 and no pool will do it.
        # NOTE(review): with a single conv, stride=4 and no pool, only one stride-2 step is applied.
        conv_stride = 2 if (i == 0 and stride > 1) or (i == last_idx and stride > 2 and not pool) else 1
        if conv_stride > 1 and prev_feat is not None:
            # record the feature map produced right before this downsampling step
            feature_info.append(prev_feat)
        stem.add_module(conv_name, ConvNormAct(
            prev_chs, chs, kernel_size,
            stride=conv_stride,
            padding=padding if i == 0 else '',
            act_layer=act_layer,
            norm_layer=norm_layer,
            **dd,
        ))
        stem_stride *= conv_stride
        prev_chs = chs
        prev_feat = dict(num_chs=prev_chs, reduction=stem_stride, module='.'.join(['stem', conv_name]))
    if pool:
        assert stride > 2
        if prev_feat is not None:
            # the pool downsamples, so the last conv output is a feature level
            feature_info.append(prev_feat)
        if aa_layer is not None:
            # stride-1 max pool, then aa_layer does the stride-2 step
            stem.add_module('pool', nn.MaxPool2d(kernel_size=3, stride=1, padding=1))
            stem.add_module('aa', aa_layer(channels=prev_chs, stride=2, **dd))
            pool_name = 'aa'
        else:
            stem.add_module('pool', nn.MaxPool2d(kernel_size=3, stride=2, padding=1))
            pool_name = 'pool'
        stem_stride *= 2
        prev_feat = dict(num_chs=prev_chs, reduction=stem_stride, module='.'.join(['stem', pool_name]))
    # the final stem output is always the last entry (callers use it as the stage input)
    feature_info.append(prev_feat)
    return stem, feature_info
- def _get_stage_fn(stage_args):
- stage_type = stage_args.pop('stage_type')
- assert stage_type in ('dark', 'csp', 'cs3')
- if stage_type == 'dark':
- stage_args.pop('expand_ratio', None)
- stage_args.pop('cross_linear', None)
- stage_args.pop('down_growth', None)
- stage_fn = DarkStage
- elif stage_type == 'csp':
- stage_fn = CrossStage
- else:
- stage_fn = CrossStage3
- return stage_fn, stage_args
- def _get_block_fn(stage_args):
- block_type = stage_args.pop('block_type')
- assert block_type in ('dark', 'edge', 'bottle')
- if block_type == 'dark':
- return DarkBlock, stage_args
- elif block_type == 'edge':
- return EdgeBlock, stage_args
- else:
- return BottleneckBlock, stage_args
- def _get_attn_fn(stage_args):
- attn_layer = stage_args.pop('attn_layer')
- attn_kwargs = stage_args.pop('attn_kwargs', None) or {}
- if attn_layer is not None:
- attn_layer = get_attn(attn_layer)
- if attn_kwargs:
- attn_layer = partial(attn_layer, **attn_kwargs)
- return attn_layer, stage_args
def create_csp_stages(
        cfg: CspModelCfg,
        drop_path_rate: float,
        output_stride: int,
        stem_feat: Dict[str, Any],
        device=None,
        dtype=None,
):
    """Build all CspNet stages described by ``cfg.stages``.

    Once the running network stride reaches ``output_stride``, further stage strides are
    converted into dilation.

    Args:
        cfg: model config (stage layout plus shared act / norm / aa layers).
        drop_path_rate: max stochastic-depth rate; falsy disables drop path.
        output_stride: target overall network stride.
        stem_feat: feature-info dict of the stem output (``num_chs`` and ``reduction`` are read).
        device: factory kwarg for created modules.
        dtype: factory kwarg for created modules.

    Returns:
        (stages, feature_info): an ``nn.Sequential`` of stages, and feature-info dicts holding
        the output before each downsampling stage followed by the last stage's output.
    """
    dd = {'device': device, 'dtype': dtype}
    cfg_dict = asdict(cfg.stages)
    num_stages = len(cfg.stages.depth)
    cfg_dict['block_dpr'] = [None] * num_stages if not drop_path_rate else \
        calculate_drop_path_rates(drop_path_rate, cfg.stages.depth, stagewise=True)
    # Transpose {field: per-stage tuple} into one kwargs dict per stage.
    per_stage_args = [dict(zip(cfg_dict.keys(), values)) for values in zip(*cfg_dict.values())]
    block_kwargs = dict(
        act_layer=cfg.act_layer,
        norm_layer=cfg.norm_layer,
    )

    dilation = 1
    net_stride = stem_feat['reduction']
    prev_chs = stem_feat['num_chs']
    prev_feat = stem_feat
    feature_info = []
    stages = []
    for stage_idx, stage_args in enumerate(per_stage_args):
        stage_fn, stage_args = _get_stage_fn(stage_args)
        block_fn, stage_args = _get_block_fn(stage_args)
        attn_fn, stage_args = _get_attn_fn(stage_args)
        stride = stage_args.pop('stride')
        if stride != 1 and prev_feat:
            # this stage downsamples, so the previous output is a feature level
            feature_info.append(prev_feat)
        if net_stride >= output_stride and stride > 1:
            # target output stride reached: replace further striding with dilation
            dilation *= stride
            stride = 1
        net_stride *= stride
        first_dilation = 1 if dilation in (1, 2) else 2

        stages.append(stage_fn(
            prev_chs,
            **stage_args,
            stride=stride,
            first_dilation=first_dilation,
            dilation=dilation,
            block_fn=block_fn,
            aa_layer=cfg.aa_layer,
            attn_layer=attn_fn,  # will be passed through stage as block_kwargs
            **block_kwargs,
            **dd,
        ))
        prev_chs = stage_args['out_chs']
        prev_feat = dict(num_chs=prev_chs, reduction=net_stride, module=f'stages.{stage_idx}')

    # the final stage output is always the last feature level
    feature_info.append(prev_feat)
    return nn.Sequential(*stages), feature_info
class CspNet(nn.Module):
    """Cross Stage Partial base model.

    Paper: `CSPNet: A New Backbone that can Enhance Learning Capability of CNN` - https://arxiv.org/abs/1911.11929
    Ref Impl: https://github.com/WongKinYiu/CrossStagePartialNetworks

    NOTE: There are differences in the way I handle the 1x1 'expansion' conv in this impl vs the
    darknet impl. I did it this way for simplicity and less special cases.
    """

    def __init__(
            self,
            cfg: CspModelCfg,
            in_chans: int = 3,
            num_classes: int = 1000,
            output_stride: int = 32,
            global_pool: str = 'avg',
            drop_rate: float = 0.,
            drop_path_rate: float = 0.,
            zero_init_last: bool = True,
            device=None,
            dtype=None,
            **kwargs,
    ):
        """
        Args:
            cfg (CspModelCfg): Model architecture configuration
            in_chans (int): Number of input channels (default: 3)
            num_classes (int): Number of classifier classes (default: 1000)
            output_stride (int): Output stride of network, one of (8, 16, 32) (default: 32)
            global_pool (str): Global pooling type (default: 'avg')
            drop_rate (float): Dropout rate (default: 0.)
            drop_path_rate (float): Stochastic depth drop-path rate (default: 0.)
            zero_init_last (bool): Zero-init last weight of residual path
            device: Factory kwarg for created modules
            dtype: Factory kwarg for created modules
            kwargs (dict): Extra kwargs overlayed onto cfg (via dataclasses.replace)
        """
        super().__init__()
        factory_kwargs = {'device': device, 'dtype': dtype}
        self.num_classes = num_classes
        self.drop_rate = drop_rate
        assert output_stride in (8, 16, 32)

        cfg = replace(cfg, **kwargs)  # overlay kwargs onto cfg
        common_layer_args = dict(
            act_layer=cfg.act_layer,
            norm_layer=cfg.norm_layer,
            aa_layer=cfg.aa_layer,
        )
        self.feature_info = []

        # Stem. All but the last stem feature are exposed here; the last one is handed to the
        # stage builder, which records it when the first stage downsamples.
        self.stem, stem_feat_info = create_csp_stem(
            in_chans, **asdict(cfg.stem), **common_layer_args, **factory_kwargs)
        *leading_stem_feats, stem_out_feat = stem_feat_info
        self.feature_info.extend(leading_stem_feats)

        # Stages
        self.stages, stage_feat_info = create_csp_stages(
            cfg,
            drop_path_rate=drop_path_rate,
            output_stride=output_stride,
            stem_feat=stem_out_feat,
            **factory_kwargs,
        )
        self.feature_info.extend(stage_feat_info)

        # Head, sized from the last stage's channels
        final_chs = stage_feat_info[-1]['num_chs']
        self.num_features = self.head_hidden_size = final_chs
        self.head = ClassifierHead(
            in_features=final_chs,
            num_classes=num_classes,
            pool_type=global_pool,
            drop_rate=drop_rate,
            **factory_kwargs,
        )

        named_apply(partial(_init_weights, zero_init_last=zero_init_last), self)

    @torch.jit.ignore
    def group_matcher(self, coarse=False):
        """Regex groupings of parameter names (stem / per-stage or per-block)."""
        if coarse:
            blocks = r'^stages\.(\d+)'
        else:
            blocks = [
                (r'^stages\.(\d+)\.blocks\.(\d+)', None),
                (r'^stages\.(\d+)\..*transition', MATCH_PREV_GROUP),  # map to last block in stage
                (r'^stages\.(\d+)', (0,)),
            ]
        return dict(stem=r'^stem', blocks=blocks)

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        assert not enable, 'gradient checkpointing not supported'

    @torch.jit.ignore
    def get_classifier(self) -> nn.Module:
        return self.head.fc

    def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
        self.num_classes = num_classes
        self.head.reset(num_classes, global_pool)

    def forward_features(self, x):
        return self.stages(self.stem(x))

    def forward_head(self, x, pre_logits: bool = False):
        if pre_logits:
            return self.head(x, pre_logits=pre_logits)
        return self.head(x)

    def forward(self, x):
        return self.forward_head(self.forward_features(x))
- def _init_weights(module, name, zero_init_last=False):
- if isinstance(module, nn.Conv2d):
- nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
- if module.bias is not None:
- nn.init.zeros_(module.bias)
- elif isinstance(module, nn.Linear):
- nn.init.normal_(module.weight, mean=0.0, std=0.01)
- if module.bias is not None:
- nn.init.zeros_(module.bias)
- elif zero_init_last and hasattr(module, 'zero_init_last'):
- module.zero_init_last()
# Architecture configs for every registered CspNet variant, keyed by model name
# (`_create_cspnet` looks up `model_cfgs[variant]`). Per-stage fields given as a scalar or a
# short tuple are broadcast to all stages by CspStagesCfg (last value repeated).
model_cfgs = dict(
    # CSP-ResNet50: single 7x7 stem conv + max-pool (stride 4); the first stage does not downsample.
    cspresnet50=CspModelCfg(
        stem=CspStemCfg(out_chs=64, kernel_size=7, stride=4, pool='max'),
        stages=CspStagesCfg(
            depth=(3, 3, 5, 2),
            out_chs=(128, 256, 512, 1024),
            stride=(1, 2),
            expand_ratio=2.,
            bottle_ratio=0.5,
            cross_linear=True,
        ),
    ),
    # 'd' variant: stem of three 3x3 convs (32, 32, 64) instead of a single 7x7 conv.
    cspresnet50d=CspModelCfg(
        stem=CspStemCfg(out_chs=(32, 32, 64), kernel_size=3, stride=4, pool='max'),
        stages=CspStagesCfg(
            depth=(3, 3, 5, 2),
            out_chs=(128, 256, 512, 1024),
            stride=(1,) + (2,),
            expand_ratio=2.,
            bottle_ratio=0.5,
            block_ratio=1.,
            cross_linear=True,
        ),
    ),
    # 'w' variant: wider stage outputs (256..2048) with no expansion and half-width blocks.
    cspresnet50w=CspModelCfg(
        stem=CspStemCfg(out_chs=(32, 32, 64), kernel_size=3, stride=4, pool='max'),
        stages=CspStagesCfg(
            depth=(3, 3, 5, 2),
            out_chs=(256, 512, 1024, 2048),
            stride=(1,) + (2,),
            expand_ratio=1.,
            bottle_ratio=0.25,
            block_ratio=0.5,
            cross_linear=True,
        ),
    ),
    # CSP-ResNeXt50: grouped (32) 3x3 convs in the blocks.
    cspresnext50=CspModelCfg(
        stem=CspStemCfg(out_chs=64, kernel_size=7, stride=4, pool='max'),
        stages=CspStagesCfg(
            depth=(3, 3, 5, 2),
            out_chs=(256, 512, 1024, 2048),
            stride=(1,) + (2,),
            groups=32,
            expand_ratio=1.,
            bottle_ratio=1.,
            block_ratio=0.5,
            cross_linear=True,
        ),
    ),
    # CSPDarkNet53: stride-1 stem, five stride-2 CSP stages of DarkBlocks; the first stage uses
    # 2x expansion, the rest 1x.
    cspdarknet53=CspModelCfg(
        stem=CspStemCfg(out_chs=32, kernel_size=3, stride=1, pool=''),
        stages=CspStagesCfg(
            depth=(1, 2, 8, 8, 4),
            out_chs=(64, 128, 256, 512, 1024),
            stride=2,
            expand_ratio=(2.,) + (1.,),
            bottle_ratio=(0.5,) + (1.,),
            block_ratio=(1.,) + (0.5,),
            down_growth=True,
            block_type='dark',
        ),
    ),
    # Plain DarkNet variants below use 'dark' stages (no cross-stage split).
    darknet17=CspModelCfg(
        stem=CspStemCfg(out_chs=32, kernel_size=3, stride=1, pool=''),
        stages=CspStagesCfg(
            depth=(1,) * 5,
            out_chs=(64, 128, 256, 512, 1024),
            stride=(2,),
            bottle_ratio=(0.5,),
            block_ratio=(1.,),
            stage_type='dark',
            block_type='dark',
        ),
    ),
    darknet21=CspModelCfg(
        stem=CspStemCfg(out_chs=32, kernel_size=3, stride=1, pool=''),
        stages=CspStagesCfg(
            depth=(1, 1, 1, 2, 2),
            out_chs=(64, 128, 256, 512, 1024),
            stride=(2,),
            bottle_ratio=(0.5,),
            block_ratio=(1.,),
            stage_type='dark',
            block_type='dark',
        ),
    ),
    # darknet21 layout with 'se' attention in every block.
    sedarknet21=CspModelCfg(
        stem=CspStemCfg(out_chs=32, kernel_size=3, stride=1, pool=''),
        stages=CspStagesCfg(
            depth=(1, 1, 1, 2, 2),
            out_chs=(64, 128, 256, 512, 1024),
            stride=2,
            bottle_ratio=0.5,
            block_ratio=1.,
            attn_layer='se',
            stage_type='dark',
            block_type='dark',
        ),
    ),
    darknet53=CspModelCfg(
        stem=CspStemCfg(out_chs=32, kernel_size=3, stride=1, pool=''),
        stages=CspStagesCfg(
            depth=(1, 2, 8, 8, 4),
            out_chs=(64, 128, 256, 512, 1024),
            stride=2,
            bottle_ratio=0.5,
            block_ratio=1.,
            stage_type='dark',
            block_type='dark',
        ),
    ),
    # darknet53 with avg-pool + 1x1 conv stage downsampling instead of a strided 3x3 conv.
    darknetaa53=CspModelCfg(
        stem=CspStemCfg(out_chs=32, kernel_size=3, stride=1, pool=''),
        stages=CspStagesCfg(
            depth=(1, 2, 8, 8, 4),
            out_chs=(64, 128, 256, 512, 1024),
            stride=2,
            bottle_ratio=0.5,
            block_ratio=1.,
            avg_down=True,
            stage_type='dark',
            block_type='dark',
        ),
    ),
    # CS3 (CrossStage3) DarkNet family; s/m/l/x scale width and depth via _cs3_cfg.
    cs3darknet_s=_cs3_cfg(width_multiplier=0.5, depth_multiplier=0.5),
    cs3darknet_m=_cs3_cfg(width_multiplier=0.75, depth_multiplier=0.67),
    cs3darknet_l=_cs3_cfg(),
    cs3darknet_x=_cs3_cfg(width_multiplier=1.25, depth_multiplier=1.33),
    # 'focus' variants use a single 6x6 stride-2 stem conv.
    cs3darknet_focus_s=_cs3_cfg(width_multiplier=0.5, depth_multiplier=0.5, focus=True),
    cs3darknet_focus_m=_cs3_cfg(width_multiplier=0.75, depth_multiplier=0.67, focus=True),
    cs3darknet_focus_l=_cs3_cfg(focus=True),
    cs3darknet_focus_x=_cs3_cfg(width_multiplier=1.25, depth_multiplier=1.33, focus=True),
    # SE-attention CS3 variants.
    cs3sedarknet_l=_cs3_cfg(attn_layer='se', attn_kwargs=dict(rd_ratio=.25)),
    cs3sedarknet_x=_cs3_cfg(attn_layer='se', width_multiplier=1.25, depth_multiplier=1.33),
    # Wide CSP ('csp' stages, bottleneck blocks) with SE. Groups equal the block mid channels in
    # the last two stages (1024*0.5*0.5=256, 2048*0.5*0.5=512), i.e. depthwise 3x3 convs there.
    cs3sedarknet_xdw=CspModelCfg(
        stem=CspStemCfg(out_chs=(32, 64), kernel_size=3, stride=2, pool=''),
        stages=CspStagesCfg(
            depth=(3, 6, 12, 4),
            out_chs=(256, 512, 1024, 2048),
            stride=2,
            groups=(1, 1, 256, 512),
            bottle_ratio=0.5,
            block_ratio=0.5,
            attn_layer='se',
        ),
        act_layer='silu',
    ),
    # EdgeBlock (3x3 + 1x1) CS3 variants.
    cs3edgenet_x=_cs3_cfg(width_multiplier=1.25, depth_multiplier=1.33, bottle_ratio=1.5, block_type='edge'),
    cs3se_edgenet_x=_cs3_cfg(
        width_multiplier=1.25, depth_multiplier=1.33, bottle_ratio=1.5, block_type='edge',
        attn_layer='se', attn_kwargs=dict(rd_ratio=.25)),
)
def _create_cspnet(variant, pretrained=False, **kwargs):
    """Build the registered CspNet ``variant`` through ``build_model_with_cfg``.

    Args:
        variant: model name; also the key into ``model_cfgs``.
        pretrained: passed through to ``build_model_with_cfg``.
        **kwargs: forwarded to ``build_model_with_cfg``. ``out_indices`` is popped and used for the
            feature config.
    """
    # NOTE: DarkNet is one of few models with stride==1 features w/ 6 out_indices [0..5].
    # Every DarkNet-style cfg (stride-1 stem + five stages) needs them. That includes sedarknet21,
    # which has the darknet21 layout but whose name lacks the 'darknet' prefix; with 5 indices it
    # would silently drop its final stage from the feature outputs.
    if variant.startswith(('darknet', 'cspdarknet', 'sedarknet')):
        default_out_indices = (0, 1, 2, 3, 4, 5)
    else:
        default_out_indices = (0, 1, 2, 3, 4)
    out_indices = kwargs.pop('out_indices', default_out_indices)
    return build_model_with_cfg(
        CspNet,
        variant,
        pretrained,
        model_cfg=model_cfgs[variant],
        feature_cfg=dict(flatten_sequential=True, out_indices=out_indices),
        **kwargs,
    )
def _cfg(url='', **kwargs):
    """Return a default/pretrained config dict pre-filled with CspNet defaults.

    The defaults are:

    * 1000 classes.
    * A (3, 256, 256) input and an (8, 8) pool size.
    * crop_pct 0.887 and bilinear interpolation.
    * ImageNet default mean/std.
    * First conv ``stem.conv1.conv`` and classifier ``head.fc``.
    * An ``apache-2.0`` license.

    ``url`` is stored under ``'url'``. Any extra ``kwargs`` are merged last, so they
    override a default of the same name or add a new key.
    """
    cfg = dict(
        url=url,
        num_classes=1000,
        input_size=(3, 256, 256),
        pool_size=(8, 8),
        crop_pct=0.887,
        interpolation='bilinear',
        mean=IMAGENET_DEFAULT_MEAN,
        std=IMAGENET_DEFAULT_STD,
        first_conv='stem.conv1.conv',
        classifier='head.fc',
        license='apache-2.0',
    )
    # caller-supplied values win over the defaults above
    cfg.update(kwargs)
    return cfg
# Default/pretrained configs, keyed '<model_name>.<tag>'. Every entry is produced by _cfg(),
# so any keyword given here overrides the corresponding _cfg() default.
# Entries tagged '.untrained' specify no url and no hf_hub_id.
# Entries with test_input_size / test_crop_pct set a second resolution/crop alongside the
# train-time input_size/crop_pct (presumably used for evaluation — confirm with consumers).
default_cfgs = generate_default_cfgs({
    'cspresnet50.ra_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/cspresnet50_ra-d3e8d487.pth'),
    'cspresnet50d.untrained': _cfg(),
    'cspresnet50w.untrained': _cfg(),
    'cspresnext50.ra_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/cspresnext50_ra_224-648b4713.pth',
    ),
    'cspdarknet53.ra_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/cspdarknet53_ra_256-d05c7c21.pth'),
    'darknet17.untrained': _cfg(),
    'darknet21.untrained': _cfg(),
    'sedarknet21.untrained': _cfg(),
    'darknet53.c2ns_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/darknet53_256_c2ns-3aeff817.pth',
        interpolation='bicubic', test_input_size=(3, 288, 288), test_crop_pct=1.0),
    # NOTE: unlike darknet53 above, no interpolation override here, so the _cfg() default applies.
    'darknetaa53.c2ns_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/darknetaa53_c2ns-5c28ec8a.pth',
        test_input_size=(3, 288, 288), test_crop_pct=1.0),
    'cs3darknet_s.untrained': _cfg(interpolation='bicubic'),
    'cs3darknet_m.c2ns_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/cs3darknet_m_c2ns-43f06604.pth',
        interpolation='bicubic', test_input_size=(3, 288, 288), test_crop_pct=0.95,
    ),
    'cs3darknet_l.c2ns_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/cs3darknet_l_c2ns-16220c5d.pth',
        interpolation='bicubic', test_input_size=(3, 288, 288), test_crop_pct=0.95),
    'cs3darknet_x.c2ns_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/cs3darknet_x_c2ns-4e4490aa.pth',
        interpolation='bicubic', crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0),
    # Weights referenced only through hf_hub_id (no direct url). Overrides the ImageNet
    # mean/std defaults from _cfg() with 0.5 per channel.
    'cs3darknet_focus_s.ra4_e3600_r256_in1k': _cfg(
        hf_hub_id='timm/',
        mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
        interpolation='bicubic', test_input_size=(3, 320, 320), test_crop_pct=1.0),
    'cs3darknet_focus_m.c2ns_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/cs3darknet_focus_m_c2ns-e23bed41.pth',
        interpolation='bicubic', test_input_size=(3, 288, 288), test_crop_pct=0.95),
    'cs3darknet_focus_l.c2ns_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/cs3darknet_focus_l_c2ns-65ef8888.pth',
        interpolation='bicubic', test_input_size=(3, 288, 288), test_crop_pct=0.95),
    'cs3darknet_focus_x.untrained': _cfg(interpolation='bicubic'),
    'cs3sedarknet_l.c2ns_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/cs3sedarknet_l_c2ns-e8d1dc13.pth',
        interpolation='bicubic', test_input_size=(3, 288, 288), test_crop_pct=0.95),
    'cs3sedarknet_x.c2ns_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/cs3sedarknet_x_c2ns-b4d0abc0.pth',
        interpolation='bicubic', test_input_size=(3, 288, 288), test_crop_pct=1.0),
    'cs3sedarknet_xdw.untrained': _cfg(interpolation='bicubic'),
    'cs3edgenet_x.c2_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/cs3edgenet_x_c2-2e1610a9.pth',
        interpolation='bicubic', test_input_size=(3, 288, 288), test_crop_pct=1.0),
    'cs3se_edgenet_x.c2ns_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/cs3se_edgenet_x_c2ns-76f8e3ac.pth',
        interpolation='bicubic', crop_pct=0.95, test_input_size=(3, 320, 320), test_crop_pct=1.0),
})
@register_model
def cspresnet50(pretrained=False, **kwargs) -> CspNet:
    """Create the ``cspresnet50`` CspNet; ``pretrained`` and ``kwargs`` go to ``_create_cspnet``."""
    model = _create_cspnet('cspresnet50', pretrained=pretrained, **kwargs)
    return model
@register_model
def cspresnet50d(pretrained=False, **kwargs) -> CspNet:
    """Create the ``cspresnet50d`` CspNet; ``pretrained`` and ``kwargs`` go to ``_create_cspnet``."""
    model = _create_cspnet('cspresnet50d', pretrained=pretrained, **kwargs)
    return model
@register_model
def cspresnet50w(pretrained=False, **kwargs) -> CspNet:
    """Create the ``cspresnet50w`` CspNet; ``pretrained`` and ``kwargs`` go to ``_create_cspnet``."""
    model = _create_cspnet('cspresnet50w', pretrained=pretrained, **kwargs)
    return model
@register_model
def cspresnext50(pretrained=False, **kwargs) -> CspNet:
    """Create the ``cspresnext50`` CspNet; ``pretrained`` and ``kwargs`` go to ``_create_cspnet``."""
    model = _create_cspnet('cspresnext50', pretrained=pretrained, **kwargs)
    return model
@register_model
def cspdarknet53(pretrained=False, **kwargs) -> CspNet:
    """Create the ``cspdarknet53`` CspNet; ``pretrained`` and ``kwargs`` go to ``_create_cspnet``."""
    model = _create_cspnet('cspdarknet53', pretrained=pretrained, **kwargs)
    return model
@register_model
def darknet17(pretrained=False, **kwargs) -> CspNet:
    """Create the ``darknet17`` CspNet; ``pretrained`` and ``kwargs`` go to ``_create_cspnet``."""
    model = _create_cspnet('darknet17', pretrained=pretrained, **kwargs)
    return model
@register_model
def darknet21(pretrained=False, **kwargs) -> CspNet:
    """Create the ``darknet21`` CspNet; ``pretrained`` and ``kwargs`` go to ``_create_cspnet``."""
    model = _create_cspnet('darknet21', pretrained=pretrained, **kwargs)
    return model
@register_model
def sedarknet21(pretrained=False, **kwargs) -> CspNet:
    """Create the ``sedarknet21`` CspNet; ``pretrained`` and ``kwargs`` go to ``_create_cspnet``."""
    model = _create_cspnet('sedarknet21', pretrained=pretrained, **kwargs)
    return model
@register_model
def darknet53(pretrained=False, **kwargs) -> CspNet:
    """Create the ``darknet53`` CspNet; ``pretrained`` and ``kwargs`` go to ``_create_cspnet``."""
    model = _create_cspnet('darknet53', pretrained=pretrained, **kwargs)
    return model
@register_model
def darknetaa53(pretrained=False, **kwargs) -> CspNet:
    """Create the ``darknetaa53`` CspNet.

    Its cfg has:

    * A stride-1 stem.
    * Five stride-2 'dark' stages with depths (1, 2, 8, 8, 4).
    * ``avg_down=True``.

    ``pretrained`` and ``kwargs`` go to ``_create_cspnet``.
    """
    model = _create_cspnet('darknetaa53', pretrained=pretrained, **kwargs)
    return model
@register_model
def cs3darknet_s(pretrained=False, **kwargs) -> CspNet:
    """Create ``cs3darknet_s`` (``_cs3_cfg`` with width 0.5 / depth 0.5 multipliers)."""
    model = _create_cspnet('cs3darknet_s', pretrained=pretrained, **kwargs)
    return model
@register_model
def cs3darknet_m(pretrained=False, **kwargs) -> CspNet:
    """Create ``cs3darknet_m`` (``_cs3_cfg`` with width 0.75 / depth 0.67 multipliers)."""
    model = _create_cspnet('cs3darknet_m', pretrained=pretrained, **kwargs)
    return model
@register_model
def cs3darknet_l(pretrained=False, **kwargs) -> CspNet:
    """Create ``cs3darknet_l`` (``_cs3_cfg`` with its default arguments)."""
    model = _create_cspnet('cs3darknet_l', pretrained=pretrained, **kwargs)
    return model
@register_model
def cs3darknet_x(pretrained=False, **kwargs) -> CspNet:
    """Create ``cs3darknet_x`` (``_cs3_cfg`` with width 1.25 / depth 1.33 multipliers)."""
    model = _create_cspnet('cs3darknet_x', pretrained=pretrained, **kwargs)
    return model
@register_model
def cs3darknet_focus_s(pretrained=False, **kwargs) -> CspNet:
    """Create ``cs3darknet_focus_s`` (``_cs3_cfg``, width 0.5 / depth 0.5, ``focus=True``)."""
    model = _create_cspnet('cs3darknet_focus_s', pretrained=pretrained, **kwargs)
    return model
@register_model
def cs3darknet_focus_m(pretrained=False, **kwargs) -> CspNet:
    """Create ``cs3darknet_focus_m`` (``_cs3_cfg``, width 0.75 / depth 0.67, ``focus=True``)."""
    model = _create_cspnet('cs3darknet_focus_m', pretrained=pretrained, **kwargs)
    return model
@register_model
def cs3darknet_focus_l(pretrained=False, **kwargs) -> CspNet:
    """Create ``cs3darknet_focus_l`` (``_cs3_cfg`` with ``focus=True``)."""
    model = _create_cspnet('cs3darknet_focus_l', pretrained=pretrained, **kwargs)
    return model
@register_model
def cs3darknet_focus_x(pretrained=False, **kwargs) -> CspNet:
    """Create ``cs3darknet_focus_x`` (``_cs3_cfg``, width 1.25 / depth 1.33, ``focus=True``)."""
    model = _create_cspnet('cs3darknet_focus_x', pretrained=pretrained, **kwargs)
    return model
@register_model
def cs3sedarknet_l(pretrained=False, **kwargs) -> CspNet:
    """Create ``cs3sedarknet_l`` (``_cs3_cfg`` with 'se' attention, ``rd_ratio=.25``)."""
    model = _create_cspnet('cs3sedarknet_l', pretrained=pretrained, **kwargs)
    return model
@register_model
def cs3sedarknet_x(pretrained=False, **kwargs) -> CspNet:
    """Create ``cs3sedarknet_x`` (``_cs3_cfg``, 'se' attention, width 1.25 / depth 1.33)."""
    model = _create_cspnet('cs3sedarknet_x', pretrained=pretrained, **kwargs)
    return model
@register_model
def cs3sedarknet_xdw(pretrained=False, **kwargs) -> CspNet:
    """Create the ``cs3sedarknet_xdw`` CspNet.

    Its cfg has:

    * A stride-2 stem with ``out_chs=(32, 64)``.
    * Four stages with ``depth=(3, 6, 12, 4)`` and ``groups=(1, 1, 256, 512)``.
    * 'se' attention and 'silu' activation.

    ``pretrained`` and ``kwargs`` go to ``_create_cspnet``.
    """
    model = _create_cspnet('cs3sedarknet_xdw', pretrained=pretrained, **kwargs)
    return model
@register_model
def cs3edgenet_x(pretrained=False, **kwargs) -> CspNet:
    """Create ``cs3edgenet_x`` (``_cs3_cfg``, width 1.25 / depth 1.33, 'edge' blocks, bottle_ratio 1.5)."""
    model = _create_cspnet('cs3edgenet_x', pretrained=pretrained, **kwargs)
    return model
@register_model
def cs3se_edgenet_x(pretrained=False, **kwargs) -> CspNet:
    """Create the ``cs3se_edgenet_x`` CspNet.

    It is ``cs3edgenet_x``'s cfg plus 'se' attention with ``rd_ratio=.25``.
    ``pretrained`` and ``kwargs`` go to ``_create_cspnet``.
    """
    model = _create_cspnet('cs3se_edgenet_x', pretrained=pretrained, **kwargs)
    return model
|