| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289 |
- # YOLOv5 🚀 by Ultralytics, GPL-3.0 license
- """
- Common modules
- """
- import json
- import math
- import platform
- import warnings
- from collections import OrderedDict, namedtuple
- from copy import copy
- from pathlib import Path
- import cv2
- import numpy as np
- import requests
- import torch
- import torch.nn as nn
- from PIL import Image
- from torch.cuda import amp
- from utils.yolov5_utils import make_divisible, initialize_weights, check_anchor_order, check_version, fuse_conv_and_bn
def autopad(k, p=None):  # kernel, padding
    """Return the padding to use for kernel size ``k`` (auto-pad to 'same').

    Args:
        k: Kernel size, either an int or an iterable of per-dimension sizes.
        p: Explicit padding. When it is not None it is returned unchanged.

    Returns:
        ``p`` when given; otherwise ``k // 2`` for an int kernel, or a list
        holding ``x // 2`` for every entry ``x`` of an iterable kernel.
    """
    if p is not None:
        return p  # caller-supplied padding always wins
    if isinstance(k, int):
        return k // 2
    return [size // 2 for size in k]
class Conv(nn.Module):
    """Standard convolution block: Conv2d (no bias) -> BatchNorm2d -> activation.

    Args:
        c1: Number of input channels.
        c2: Number of output channels.
        k: Kernel size.
        s: Stride.
        p: Padding; ``None`` lets ``autopad`` derive it from ``k``.
        g: Number of convolution groups.
        act: Activation selector:
            * ``True``              -> ``nn.SiLU()``
            * an ``nn.Module``      -> used as is
            * ``'leaky'``           -> ``nn.LeakyReLU(0.1, inplace=True)``
            * ``'relu'``            -> ``nn.ReLU(inplace=True)``
            * anything else that is not a string (``False``, ``None``, ...)
                                    -> ``nn.Identity()``

    Raises:
        ValueError: If ``act`` is a string other than ``'leaky'`` or ``'relu'``.
    """

    def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True):  # ch_in, ch_out, kernel, stride, padding, groups
        super().__init__()
        self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p), groups=g, bias=False)
        self.bn = nn.BatchNorm2d(c2)

        # self.act must always end up callable: forward() and forward_fuse()
        # invoke it unconditionally, so it may never be left as None or unset.
        if act is True:
            self.act = nn.SiLU()
        elif isinstance(act, nn.Module):
            self.act = act
        elif isinstance(act, str):
            if act == 'leaky':
                self.act = nn.LeakyReLU(0.1, inplace=True)
            elif act == 'relu':
                self.act = nn.ReLU(inplace=True)
            else:
                # Fail at construction time instead of with an opaque
                # "'NoneType' object is not callable" during forward().
                raise ValueError(f"Unsupported activation '{act}'; expected 'leaky' or 'relu'")
        else:
            self.act = nn.Identity()

    def forward(self, x):
        """Apply convolution, batch norm and activation to ``x``."""
        return self.act(self.bn(self.conv(x)))

    def forward_fuse(self, x):
        """Apply convolution and activation only, skipping ``self.bn``."""
        return self.act(self.conv(x))
class DWConv(Conv):
    """Depth-wise convolution: a ``Conv`` whose group count is ``gcd(c1, c2)``.

    Args:
        c1: Number of input channels.
        c2: Number of output channels.
        k: Kernel size.
        s: Stride.
        act: Activation selector, forwarded unchanged to ``Conv``.
    """

    def __init__(self, c1, c2, k=1, s=1, act=True):  # ch_in, ch_out, kernel, stride, activation
        groups = math.gcd(c1, c2)
        super().__init__(c1, c2, k, s, g=groups, act=act)
class TransformerLayer(nn.Module):
    """Transformer layer https://arxiv.org/abs/2010.11929 (LayerNorm layers removed for better performance).

    Args:
        c: Embedding width; every linear layer maps ``c -> c`` without bias.
        num_heads: Number of heads given to ``nn.MultiheadAttention``.
    """

    def __init__(self, c, num_heads):
        super().__init__()
        # Query / key / value projections fed into the attention module.
        self.q = nn.Linear(c, c, bias=False)
        self.k = nn.Linear(c, c, bias=False)
        self.v = nn.Linear(c, c, bias=False)
        self.ma = nn.MultiheadAttention(embed_dim=c, num_heads=num_heads)
        # Two-layer linear feed-forward part (no non-linearity in between).
        self.fc1 = nn.Linear(c, c, bias=False)
        self.fc2 = nn.Linear(c, c, bias=False)

    def forward(self, x):
        """Run attention then the linear feed-forward part, each with a residual add."""
        attended, _ = self.ma(self.q(x), self.k(x), self.v(x))  # attention weights are discarded
        x = attended + x
        return self.fc2(self.fc1(x)) + x
class TransformerBlock(nn.Module):
    """Vision Transformer https://arxiv.org/abs/2010.11929

    Args:
        c1: Number of input channels.
        c2: Number of channels the transformer layers operate on.
        num_heads: Attention heads per ``TransformerLayer``.
        num_layers: Number of stacked ``TransformerLayer`` modules.
    """

    def __init__(self, c1, c2, num_heads, num_layers):
        super().__init__()
        # A Conv is only needed when the channel count has to change.
        self.conv = Conv(c1, c2) if c1 != c2 else None
        self.linear = nn.Linear(c2, c2)  # learnable position embedding
        layers = [TransformerLayer(c2, num_heads) for _ in range(num_layers)]
        self.tr = nn.Sequential(*layers)
        self.c2 = c2

    def forward(self, x):
        """Flatten the trailing dims of a 4-D input, run the layers, restore the shape."""
        if self.conv is not None:
            x = self.conv(x)
        batch, _, dim2, dim3 = x.shape
        # (b, c2, d2, d3) -> (b, c2, d2*d3) -> (d2*d3, b, c2)
        tokens = x.flatten(2).permute(2, 0, 1)
        encoded = self.tr(tokens + self.linear(tokens))
        # (d2*d3, b, c2) -> (b, c2, d2*d3) -> (b, c2, d2, d3)
        return encoded.permute(1, 2, 0).reshape(batch, self.c2, dim2, dim3)
class Bottleneck(nn.Module):
    """Standard bottleneck: 1x1 Conv -> 3x3 Conv with an optional residual add.

    Args:
        c1: Number of input channels.
        c2: Number of output channels.
        shortcut: Request the residual add (only used when ``c1 == c2``).
        g: Groups for the 3x3 convolution.
        e: Expansion ratio; the hidden width is ``int(c2 * e)``.
        act: Activation selector forwarded to both ``Conv`` layers.
    """

    def __init__(self, c1, c2, shortcut=True, g=1, e=0.5, act=True):  # ch_in, ch_out, shortcut, groups, expansion
        super().__init__()
        hidden = int(c2 * e)  # hidden channels
        self.cv1 = Conv(c1, hidden, 1, 1, act=act)
        self.cv2 = Conv(hidden, c2, 3, 1, g=g, act=act)
        # The residual add is only shape-safe when channel counts match.
        self.add = shortcut and c1 == c2

    def forward(self, x):
        """Return ``cv2(cv1(x))``, plus ``x`` when the shortcut is enabled."""
        out = self.cv2(self.cv1(x))
        if self.add:
            return x + out
        return out
class BottleneckCSP(nn.Module):
    """CSP Bottleneck https://github.com/WongKinYiu/CrossStagePartialNetworks

    Args:
        c1: Number of input channels.
        c2: Number of output channels.
        n: Number of stacked ``Bottleneck`` modules.
        shortcut: Forwarded to each ``Bottleneck``.
        g: Groups forwarded to each ``Bottleneck``.
        e: Expansion ratio; the hidden width is ``int(c2 * e)``.
    """

    def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):  # ch_in, ch_out, number, shortcut, groups, expansion
        super().__init__()
        hidden = int(c2 * e)  # hidden channels
        self.cv1 = Conv(c1, hidden, 1, 1)
        self.cv2 = nn.Conv2d(c1, hidden, 1, 1, bias=False)
        self.cv3 = nn.Conv2d(hidden, hidden, 1, 1, bias=False)
        self.cv4 = Conv(2 * hidden, c2, 1, 1)
        self.bn = nn.BatchNorm2d(2 * hidden)  # applied to cat(cv2, cv3)
        self.act = nn.SiLU()
        blocks = [Bottleneck(hidden, hidden, shortcut, g, e=1.0) for _ in range(n)]
        self.m = nn.Sequential(*blocks)

    def forward(self, x):
        """Merge the bottleneck path and the plain 1x1 path, then project with cv4."""
        deep = self.cv3(self.m(self.cv1(x)))  # path through the bottleneck stack
        skip = self.cv2(x)  # plain 1x1 convolution path
        merged = torch.cat((deep, skip), dim=1)
        return self.cv4(self.act(self.bn(merged)))
class C3(nn.Module):
    """CSP Bottleneck with 3 convolutions.

    Args:
        c1: Number of input channels.
        c2: Number of output channels.
        n: Number of stacked ``Bottleneck`` modules.
        shortcut: Forwarded to each ``Bottleneck``.
        g: Groups forwarded to each ``Bottleneck``.
        e: Expansion ratio; the hidden width is ``int(c2 * e)``.
        act: Activation selector forwarded to every ``Conv`` / ``Bottleneck``.
    """

    def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5, act=True):  # ch_in, ch_out, number, shortcut, groups, expansion
        super().__init__()
        hidden = int(c2 * e)  # hidden channels
        self.cv1 = Conv(c1, hidden, 1, 1, act=act)
        self.cv2 = Conv(c1, hidden, 1, 1, act=act)
        self.cv3 = Conv(2 * hidden, c2, 1, act=act)  # act=FReLU(c2)
        blocks = [Bottleneck(hidden, hidden, shortcut, g, e=1.0, act=act) for _ in range(n)]
        self.m = nn.Sequential(*blocks)

    def forward(self, x):
        """Concatenate ``m(cv1(x))`` with ``cv2(x)`` on dim 1 and project with cv3."""
        deep = self.m(self.cv1(x))
        skip = self.cv2(x)
        return self.cv3(torch.cat((deep, skip), dim=1))
class C3TR(C3):
    """C3 module whose inner stack ``m`` is replaced by a ``TransformerBlock``.

    Parameters match ``C3`` (without ``act``); ``n`` becomes the number of
    transformer layers and the head count is fixed at 4.
    """

    def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):
        super().__init__(c1, c2, n, shortcut, g, e)
        hidden = int(c2 * e)  # same hidden width C3 computed
        # Overwrite the Bottleneck stack built by C3.__init__.
        self.m = TransformerBlock(hidden, hidden, 4, n)
class C3SPP(C3):
    """C3 module whose inner stack ``m`` is replaced by an ``SPP`` layer.

    Args:
        k: Pooling kernel sizes forwarded to ``SPP``.
        Other parameters are forwarded to ``C3`` (without ``act``).
    """

    def __init__(self, c1, c2, k=(5, 9, 13), n=1, shortcut=True, g=1, e=0.5):
        super().__init__(c1, c2, n, shortcut, g, e)
        hidden = int(c2 * e)  # same hidden width C3 computed
        # Overwrite the Bottleneck stack built by C3.__init__.
        self.m = SPP(hidden, hidden, k)
class C3Ghost(C3):
    """C3 module whose inner stack ``m`` is built from ``GhostBottleneck`` modules.

    Parameters match ``C3`` (without ``act``); ``n`` is the number of
    ``GhostBottleneck`` modules in the replacement stack.
    """

    def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):
        super().__init__(c1, c2, n, shortcut, g, e)
        hidden = int(c2 * e)  # hidden channels
        ghost_blocks = [GhostBottleneck(hidden, hidden) for _ in range(n)]
        # Overwrite the Bottleneck stack built by C3.__init__.
        self.m = nn.Sequential(*ghost_blocks)
class SPP(nn.Module):
    """Spatial Pyramid Pooling (SPP) layer https://arxiv.org/abs/1406.4729

    Args:
        c1: Number of input channels.
        c2: Number of output channels.
        k: Kernel sizes of the stride-1 max-pool branches.
    """

    def __init__(self, c1, c2, k=(5, 9, 13)):
        super().__init__()
        hidden = c1 // 2  # hidden channels
        self.cv1 = Conv(c1, hidden, 1, 1)
        # cv2 consumes the un-pooled tensor plus one pooled tensor per kernel.
        self.cv2 = Conv(hidden * (len(k) + 1), c2, 1, 1)
        pools = [nn.MaxPool2d(kernel_size=size, stride=1, padding=size // 2) for size in k]
        self.m = nn.ModuleList(pools)

    def forward(self, x):
        """Concatenate ``cv1(x)`` with each pooled copy on dim 1, then apply cv2."""
        x = self.cv1(x)
        with warnings.catch_warnings():
            warnings.simplefilter('ignore')  # suppress torch 1.9.0 max_pool2d() warning
            pooled = [pool(x) for pool in self.m]
            return self.cv2(torch.cat([x, *pooled], 1))
class SPPF(nn.Module):
    """Spatial Pyramid Pooling - Fast (SPPF) layer for YOLOv5 by Glenn Jocher.

    Args:
        c1: Number of input channels.
        c2: Number of output channels.
        k: Kernel size of the single max-pool that is applied three times in
            sequence (the original note states the default is equivalent to
            ``SPP(k=(5, 9, 13))``).
    """

    def __init__(self, c1, c2, k=5):  # equivalent to SPP(k=(5, 9, 13))
        super().__init__()
        hidden = c1 // 2  # hidden channels
        self.cv1 = Conv(c1, hidden, 1, 1)
        self.cv2 = Conv(hidden * 4, c2, 1, 1)  # input + three pooled tensors
        self.m = nn.MaxPool2d(kernel_size=k, stride=1, padding=k // 2)

    def forward(self, x):
        """Concatenate ``cv1(x)`` with three successive poolings on dim 1, then apply cv2."""
        x = self.cv1(x)
        with warnings.catch_warnings():
            warnings.simplefilter('ignore')  # suppress torch 1.9.0 max_pool2d() warning
            pooled_once = self.m(x)
            pooled_twice = self.m(pooled_once)
            pooled_thrice = self.m(pooled_twice)
            stacked = torch.cat([x, pooled_once, pooled_twice, pooled_thrice], 1)
            return self.cv2(stacked)
class Focus(nn.Module):
    """Focus wh information into c-space.

    The four stride-2 sub-grids of the last two dims are stacked on dim 1,
    so the wrapped ``Conv`` receives ``c1 * 4`` channels.

    Args:
        c1: Number of input channels.
        c2: Number of output channels.
        k, s, p, g, act: Forwarded positionally to ``Conv``.
    """

    def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True):  # ch_in, ch_out, kernel, stride, padding, groups
        super().__init__()
        self.conv = Conv(c1 * 4, c2, k, s, p, g, act)

    def forward(self, x):  # x(b,c,w,h) -> y(b,4c,w/2,h/2)
        """Stack the four parity sub-grids on dim 1 and convolve the result."""
        # Names give the index parity on (dim -2, dim -1).
        even_even = x[..., ::2, ::2]
        odd_even = x[..., 1::2, ::2]
        even_odd = x[..., ::2, 1::2]
        odd_odd = x[..., 1::2, 1::2]
        return self.conv(torch.cat([even_even, odd_even, even_odd, odd_odd], 1))
class GhostConv(nn.Module):
    """Ghost Convolution https://github.com/huawei-noah/ghostnet

    Args:
        c1: Number of input channels.
        c2: Number of output channels; each internal Conv produces ``c2 // 2``.
        k: Kernel size of the first convolution.
        s: Stride of the first convolution.
        g: Groups of the first convolution.
        act: Activation selector forwarded to both ``Conv`` layers.
    """

    def __init__(self, c1, c2, k=1, s=1, g=1, act=True):  # ch_in, ch_out, kernel, stride, groups
        super().__init__()
        half = c2 // 2  # hidden channels
        self.cv1 = Conv(c1, half, k, s, p=None, g=g, act=act)
        # 5x5, stride 1, one group per channel.
        self.cv2 = Conv(half, half, 5, 1, p=None, g=half, act=act)

    def forward(self, x):
        """Return ``cv1(x)`` concatenated on dim 1 with ``cv2(cv1(x))``."""
        primary = self.cv1(x)
        ghost = self.cv2(primary)
        return torch.cat([primary, ghost], 1)
class GhostBottleneck(nn.Module):
    """Ghost Bottleneck https://github.com/huawei-noah/ghostnet

    Args:
        c1: Number of input channels.
        c2: Number of output channels.
        k: Kernel size of the depth-wise convolutions (only built when ``s == 2``).
        s: Stride; ``2`` enables the depth-wise stage and the convolutional shortcut.
    """

    def __init__(self, c1, c2, k=3, s=1):  # ch_in, ch_out, kernel, stride
        super().__init__()
        half = c2 // 2
        strided = s == 2
        self.conv = nn.Sequential(
            GhostConv(c1, half, 1, 1),  # pw
            DWConv(half, half, k, s, act=False) if strided else nn.Identity(),  # dw
            GhostConv(half, c2, 1, 1, act=False),  # pw-linear
        )
        if strided:
            self.shortcut = nn.Sequential(
                DWConv(c1, c1, k, s, act=False),
                Conv(c1, c2, 1, 1, act=False),
            )
        else:
            self.shortcut = nn.Identity()

    def forward(self, x):
        """Return the sum of the main path and the shortcut path."""
        main = self.conv(x)
        return main + self.shortcut(x)
class Contract(nn.Module):
    """Contract width-height into channels, i.e. x(1,64,80,80) to x(1,256,40,40).

    Args:
        gain: Factor ``s`` by which each of the last two dims shrinks; the
            channel count grows by ``s * s``.
    """

    def __init__(self, gain=2):
        super().__init__()
        self.gain = gain

    def forward(self, x):
        """Move every ``gain x gain`` spatial patch of a 4-D tensor into channels."""
        batch, chans, height, width = x.shape  # assert (h / s == 0) and (W / s == 0), 'Indivisible gain'
        s = self.gain
        out_h, out_w = height // s, width // s
        patches = x.view(batch, chans, out_h, s, out_w, s)  # x(1,64,40,2,40,2)
        patches = patches.permute(0, 3, 5, 1, 2, 4).contiguous()  # x(1,2,2,64,40,40)
        return patches.view(batch, chans * s * s, out_h, out_w)  # x(1,256,40,40)
class Expand(nn.Module):
    """Expand channels into width-height, i.e. x(1,64,80,80) to x(1,16,160,160).

    Args:
        gain: Factor ``s`` by which each of the last two dims grows; the
            channel count shrinks by ``s ** 2``.
    """

    def __init__(self, gain=2):
        super().__init__()
        self.gain = gain

    def forward(self, x):
        """Spread groups of ``gain ** 2`` channels of a 4-D tensor over the spatial dims."""
        batch, chans, height, width = x.shape  # assert C / s ** 2 == 0, 'Indivisible gain'
        s = self.gain
        out_c = chans // s ** 2
        grid = x.view(batch, s, s, out_c, height, width)  # x(1,2,2,16,80,80)
        grid = grid.permute(0, 3, 4, 1, 5, 2).contiguous()  # x(1,16,80,2,80,2)
        return grid.view(batch, out_c, height * s, width * s)  # x(1,16,160,160)
class Concat(nn.Module):
    """Concatenate a list of tensors along dimension.

    Args:
        dimension: Dimension index handed to ``torch.cat``; stored as ``self.d``.
    """

    def __init__(self, dimension=1):
        super().__init__()
        self.d = dimension

    def forward(self, x):
        """Return the tensors in ``x`` joined along ``self.d``."""
        joined = torch.cat(x, self.d)
        return joined
class Classify(nn.Module):
    """Classification head, i.e. x(b,c1,20,20) to x(b,c2).

    Args:
        c1: Number of input channels (summed over all inputs when a list is passed).
        c2: Number of output channels.
        k: Kernel size of the convolution.
        s: Stride of the convolution.
        p: Padding; ``None`` lets ``autopad`` derive it from ``k``.
        g: Groups of the convolution.
    """

    def __init__(self, c1, c2, k=1, s=1, p=None, g=1):  # ch_in, ch_out, kernel, stride, padding, groups
        super().__init__()
        self.aap = nn.AdaptiveAvgPool2d(1)  # to x(b,c1,1,1)
        padding = autopad(k, p)
        self.conv = nn.Conv2d(c1, c2, k, s, padding, groups=g)  # to x(b,c2,1,1)
        self.flat = nn.Flatten()

    def forward(self, x):
        """Pool each input to 1x1, concatenate on dim 1, convolve and flatten."""
        inputs = x if isinstance(x, list) else [x]  # accept a single tensor or a list
        pooled = torch.cat([self.aap(feature) for feature in inputs], 1)  # cat if list
        return self.flat(self.conv(pooled))  # flatten to x(b,c2)
|