sdafnet.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442
  1. import random
  2. import numpy as np
  3. import torch
  4. import torch.nn as nn
  5. import torch.nn.functional as F
  6. from modelscope.metainfo import Models
  7. from modelscope.models import MODELS
  8. from modelscope.utils.constant import ModelFile, Tasks
  9. def apply_offset(offset):
  10. sizes = list(offset.size()[2:])
  11. grid_list = torch.meshgrid(
  12. [torch.arange(size, device=offset.device) for size in sizes])
  13. grid_list = reversed(grid_list)
  14. # apply offset
  15. grid_list = [
  16. grid.float().unsqueeze(0) + offset[:, dim, ...]
  17. for dim, grid in enumerate(grid_list)
  18. ]
  19. # normalize
  20. grid_list = [
  21. grid / ((size - 1.0) / 2.0) - 1.0
  22. for grid, size in zip(grid_list, reversed(sizes))
  23. ]
  24. return torch.stack(grid_list, dim=-1)
  25. # backbone
  26. class ResBlock(nn.Module):
  27. def __init__(self, in_channels):
  28. super(ResBlock, self).__init__()
  29. self.block = nn.Sequential(
  30. nn.BatchNorm2d(in_channels), nn.ReLU(inplace=True),
  31. nn.Conv2d(
  32. in_channels, in_channels, kernel_size=3,
  33. padding=1, bias=False), nn.BatchNorm2d(in_channels),
  34. nn.ReLU(inplace=True),
  35. nn.Conv2d(
  36. in_channels, in_channels, kernel_size=3, padding=1,
  37. bias=False))
  38. def forward(self, x):
  39. return self.block(x) + x
  40. class Downsample(nn.Module):
  41. def __init__(self, in_channels, out_channels):
  42. super(Downsample, self).__init__()
  43. self.block = nn.Sequential(
  44. nn.BatchNorm2d(in_channels), nn.ReLU(inplace=True),
  45. nn.Conv2d(
  46. in_channels,
  47. out_channels,
  48. kernel_size=3,
  49. stride=2,
  50. padding=1,
  51. bias=False))
  52. def forward(self, x):
  53. return self.block(x)
  54. class FeatureEncoder(nn.Module):
  55. def __init__(self, in_channels, chns=[64, 128, 256, 256, 256]):
  56. # in_channels = 3 for images, and is larger (e.g., 17+1+1) for agnostic representation
  57. super(FeatureEncoder, self).__init__()
  58. self.encoders = []
  59. for i, out_chns in enumerate(chns):
  60. if i == 0:
  61. encoder = nn.Sequential(
  62. Downsample(in_channels, out_chns), ResBlock(out_chns),
  63. ResBlock(out_chns))
  64. else:
  65. encoder = nn.Sequential(
  66. Downsample(chns[i - 1], out_chns), ResBlock(out_chns),
  67. ResBlock(out_chns))
  68. self.encoders.append(encoder)
  69. self.encoders = nn.ModuleList(self.encoders)
  70. def forward(self, x):
  71. encoder_features = []
  72. for encoder in self.encoders:
  73. x = encoder(x)
  74. encoder_features.append(x)
  75. return encoder_features
  76. class RefinePyramid(nn.Module):
  77. def __init__(self, chns=[64, 128, 256, 256, 256], fpn_dim=256):
  78. super(RefinePyramid, self).__init__()
  79. self.chns = chns
  80. # adaptive
  81. self.adaptive = []
  82. for in_chns in list(reversed(chns)):
  83. adaptive_layer = nn.Conv2d(in_chns, fpn_dim, kernel_size=1)
  84. self.adaptive.append(adaptive_layer)
  85. self.adaptive = nn.ModuleList(self.adaptive)
  86. # output conv
  87. self.smooth = []
  88. for i in range(len(chns)):
  89. smooth_layer = nn.Conv2d(
  90. fpn_dim, fpn_dim, kernel_size=3, padding=1)
  91. self.smooth.append(smooth_layer)
  92. self.smooth = nn.ModuleList(self.smooth)
  93. def forward(self, x):
  94. conv_ftr_list = x
  95. feature_list = []
  96. last_feature = None
  97. for i, conv_ftr in enumerate(list(reversed(conv_ftr_list))):
  98. # adaptive
  99. feature = self.adaptive[i](conv_ftr)
  100. # fuse
  101. if last_feature is not None:
  102. feature = feature + F.interpolate(
  103. last_feature, scale_factor=2, mode='nearest')
  104. # smooth
  105. feature = self.smooth[i](feature)
  106. last_feature = feature
  107. feature_list.append(feature)
  108. return tuple(reversed(feature_list))
  109. def DAWarp(feat, offsets, att_maps, sample_k, out_ch):
  110. att_maps = torch.repeat_interleave(att_maps, out_ch, 1)
  111. B, C, H, W = feat.size()
  112. multi_feat = torch.repeat_interleave(feat, sample_k, 0)
  113. multi_warp_feat = F.grid_sample(
  114. multi_feat,
  115. offsets.detach().permute(0, 2, 3, 1),
  116. mode='bilinear',
  117. padding_mode='border')
  118. multi_att_warp_feat = multi_warp_feat.reshape(B, -1, H, W) * att_maps
  119. att_warp_feat = sum(torch.split(multi_att_warp_feat, out_ch, 1))
  120. return att_warp_feat
  121. class MFEBlock(nn.Module):
  122. def __init__(self,
  123. in_channels,
  124. out_channels,
  125. kernel_size=3,
  126. num_filters=[128, 64, 32]):
  127. super(MFEBlock, self).__init__()
  128. layers = []
  129. for i in range(len(num_filters)):
  130. if i == 0:
  131. layers.append(
  132. torch.nn.Conv2d(
  133. in_channels=in_channels,
  134. out_channels=num_filters[i],
  135. kernel_size=3,
  136. stride=1,
  137. padding=1))
  138. else:
  139. layers.append(
  140. torch.nn.Conv2d(
  141. in_channels=num_filters[i - 1],
  142. out_channels=num_filters[i],
  143. kernel_size=kernel_size,
  144. stride=1,
  145. padding=kernel_size // 2))
  146. layers.append(
  147. torch.nn.LeakyReLU(inplace=False, negative_slope=0.1))
  148. layers.append(
  149. torch.nn.Conv2d(
  150. in_channels=num_filters[-1],
  151. out_channels=out_channels,
  152. kernel_size=kernel_size,
  153. stride=1,
  154. padding=kernel_size // 2))
  155. self.layers = torch.nn.Sequential(*layers)
  156. def forward(self, input):
  157. return self.layers(input)
  158. class DAFlowNet(nn.Module):
  159. def __init__(self, num_pyramid, fpn_dim=256, head_nums=1):
  160. super(DAFlowNet, self).__init__()
  161. self.Self_MFEs = []
  162. self.Cross_MFEs = []
  163. self.Refine_MFEs = []
  164. self.k = head_nums
  165. self.out_ch = fpn_dim
  166. for i in range(num_pyramid):
  167. # self-MFE for model img 2k:flow 1k:att_map
  168. Self_MFE_layer = MFEBlock(
  169. in_channels=2 * fpn_dim,
  170. out_channels=self.k * 3,
  171. kernel_size=7)
  172. # cross-MFE for cloth img
  173. Cross_MFE_layer = MFEBlock(
  174. in_channels=2 * fpn_dim, out_channels=self.k * 3)
  175. # refine-MFE for cloth and model imgs
  176. Refine_MFE_layer = MFEBlock(
  177. in_channels=2 * fpn_dim, out_channels=self.k * 6)
  178. self.Self_MFEs.append(Self_MFE_layer)
  179. self.Cross_MFEs.append(Cross_MFE_layer)
  180. self.Refine_MFEs.append(Refine_MFE_layer)
  181. self.Self_MFEs = nn.ModuleList(self.Self_MFEs)
  182. self.Cross_MFEs = nn.ModuleList(self.Cross_MFEs)
  183. self.Refine_MFEs = nn.ModuleList(self.Refine_MFEs)
  184. self.lights_decoder = torch.nn.Sequential(
  185. torch.nn.Conv2d(64, out_channels=32, kernel_size=1, stride=1),
  186. torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),
  187. torch.nn.Conv2d(
  188. in_channels=32,
  189. out_channels=3,
  190. kernel_size=3,
  191. stride=1,
  192. padding=1))
  193. self.lights_encoder = torch.nn.Sequential(
  194. torch.nn.Conv2d(
  195. 3, out_channels=32, kernel_size=3, stride=1, padding=1),
  196. torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),
  197. torch.nn.Conv2d(
  198. in_channels=32, out_channels=64, kernel_size=1, stride=1))
  199. def forward(self,
  200. source_image,
  201. reference_image,
  202. source_feats,
  203. reference_feats,
  204. return_all=False,
  205. warp_feature=True,
  206. use_light_en_de=True):
  207. r"""
  208. Args:
  209. source_image: cloth rgb image for tryon
  210. reference_image: model rgb image for try on
  211. source_feats: cloth FPN features
  212. reference_feats: model and pose features
  213. return_all: bool return all intermediate try-on results in training phase
  214. warp_feature: use DAFlow for both features and images
  215. use_light_en_de: use shallow encoder and decoder to project the images from RGB to high dimensional space
  216. """
  217. # reference branch inputs model img using self-DAFlow
  218. last_multi_self_offsets = None
  219. # source branch inputs cloth img using cross-DAFlow
  220. last_multi_cross_offsets = None
  221. if return_all:
  222. results_all = []
  223. for i in range(len(source_feats)):
  224. feat_source = source_feats[len(source_feats) - 1 - i]
  225. feat_ref = reference_feats[len(reference_feats) - 1 - i]
  226. B, C, H, W = feat_source.size()
  227. # Pre-DAWarp for Pyramid feature
  228. if last_multi_cross_offsets is not None and warp_feature:
  229. att_source_feat = DAWarp(feat_source, last_multi_cross_offsets,
  230. cross_att_maps, self.k, self.out_ch)
  231. att_reference_feat = DAWarp(feat_ref, last_multi_self_offsets,
  232. self_att_maps, self.k, self.out_ch)
  233. else:
  234. att_source_feat = feat_source
  235. att_reference_feat = feat_ref
  236. # Cross-MFE
  237. input_feat = torch.cat([att_source_feat, feat_ref], 1)
  238. offsets_att = self.Cross_MFEs[i](input_feat)
  239. cross_att_maps = F.softmax(
  240. offsets_att[:, self.k * 2:, :, :], dim=1)
  241. offsets = apply_offset(offsets_att[:, :self.k * 2, :, :].reshape(
  242. -1, 2, H, W))
  243. if last_multi_cross_offsets is not None:
  244. offsets = F.grid_sample(
  245. last_multi_cross_offsets,
  246. offsets,
  247. mode='bilinear',
  248. padding_mode='border')
  249. else:
  250. offsets = offsets.permute(0, 3, 1, 2)
  251. last_multi_cross_offsets = offsets
  252. att_source_feat = DAWarp(feat_source, last_multi_cross_offsets,
  253. cross_att_maps, self.k, self.out_ch)
  254. # Self-MFE
  255. input_feat = torch.cat([att_source_feat, att_reference_feat], 1)
  256. offsets_att = self.Self_MFEs[i](input_feat)
  257. self_att_maps = F.softmax(offsets_att[:, self.k * 2:, :, :], dim=1)
  258. offsets = apply_offset(offsets_att[:, :self.k * 2, :, :].reshape(
  259. -1, 2, H, W))
  260. if last_multi_self_offsets is not None:
  261. offsets = F.grid_sample(
  262. last_multi_self_offsets,
  263. offsets,
  264. mode='bilinear',
  265. padding_mode='border')
  266. else:
  267. offsets = offsets.permute(0, 3, 1, 2)
  268. last_multi_self_offsets = offsets
  269. att_reference_feat = DAWarp(feat_ref, last_multi_self_offsets,
  270. self_att_maps, self.k, self.out_ch)
  271. # Refine-MFE
  272. input_feat = torch.cat([att_source_feat, att_reference_feat], 1)
  273. offsets_att = self.Refine_MFEs[i](input_feat)
  274. att_maps = F.softmax(offsets_att[:, self.k * 4:, :, :], dim=1)
  275. cross_offsets = apply_offset(
  276. offsets_att[:, :self.k * 2, :, :].reshape(-1, 2, H, W))
  277. self_offsets = apply_offset(
  278. offsets_att[:,
  279. self.k * 2:self.k * 4, :, :].reshape(-1, 2, H, W))
  280. last_multi_cross_offsets = F.grid_sample(
  281. last_multi_cross_offsets,
  282. cross_offsets,
  283. mode='bilinear',
  284. padding_mode='border')
  285. last_multi_self_offsets = F.grid_sample(
  286. last_multi_self_offsets,
  287. self_offsets,
  288. mode='bilinear',
  289. padding_mode='border')
  290. # Upsampling
  291. last_multi_cross_offsets = F.interpolate(
  292. last_multi_cross_offsets, scale_factor=2, mode='bilinear')
  293. last_multi_self_offsets = F.interpolate(
  294. last_multi_self_offsets, scale_factor=2, mode='bilinear')
  295. self_att_maps = F.interpolate(
  296. att_maps[:, :self.k, :, :], scale_factor=2, mode='bilinear')
  297. cross_att_maps = F.interpolate(
  298. att_maps[:, self.k:, :, :], scale_factor=2, mode='bilinear')
  299. # Post-DAWarp for source and reference images
  300. if return_all:
  301. cur_source_image = F.interpolate(
  302. source_image, (H * 2, W * 2), mode='bilinear')
  303. cur_reference_image = F.interpolate(
  304. reference_image, (H * 2, W * 2), mode='bilinear')
  305. if use_light_en_de:
  306. cur_source_image = self.lights_encoder(cur_source_image)
  307. cur_reference_image = self.lights_encoder(
  308. cur_reference_image)
  309. # the feat dim in light encoder is 64
  310. warp_att_source_image = DAWarp(cur_source_image,
  311. last_multi_cross_offsets,
  312. cross_att_maps, self.k, 64)
  313. warp_att_reference_image = DAWarp(cur_reference_image,
  314. last_multi_self_offsets,
  315. self_att_maps, self.k,
  316. 64)
  317. result_tryon = self.lights_decoder(
  318. warp_att_source_image + warp_att_reference_image)
  319. else:
  320. warp_att_source_image = DAWarp(cur_source_image,
  321. last_multi_cross_offsets,
  322. cross_att_maps, self.k, 3)
  323. warp_att_reference_image = DAWarp(cur_reference_image,
  324. last_multi_self_offsets,
  325. self_att_maps, self.k, 3)
  326. result_tryon = warp_att_source_image + warp_att_reference_image
  327. results_all.append(result_tryon)
  328. last_multi_self_offsets = F.interpolate(
  329. last_multi_self_offsets,
  330. reference_image.size()[2:],
  331. mode='bilinear')
  332. last_multi_cross_offsets = F.interpolate(
  333. last_multi_cross_offsets, source_image.size()[2:], mode='bilinear')
  334. self_att_maps = F.interpolate(
  335. self_att_maps, reference_image.size()[2:], mode='bilinear')
  336. cross_att_maps = F.interpolate(
  337. cross_att_maps, source_image.size()[2:], mode='bilinear')
  338. if use_light_en_de:
  339. source_image = self.lights_encoder(source_image)
  340. reference_image = self.lights_encoder(reference_image)
  341. warp_att_source_image = DAWarp(source_image,
  342. last_multi_cross_offsets,
  343. cross_att_maps, self.k, 64)
  344. warp_att_reference_image = DAWarp(reference_image,
  345. last_multi_self_offsets,
  346. self_att_maps, self.k, 64)
  347. result_tryon = self.lights_decoder(warp_att_source_image
  348. + warp_att_reference_image)
  349. else:
  350. warp_att_source_image = DAWarp(source_image,
  351. last_multi_cross_offsets,
  352. cross_att_maps, self.k, 3)
  353. warp_att_reference_image = DAWarp(reference_image,
  354. last_multi_self_offsets,
  355. self_att_maps, self.k, 3)
  356. result_tryon = warp_att_source_image + warp_att_reference_image
  357. if return_all:
  358. return result_tryon, return_all
  359. return result_tryon
  360. class SDAFNet_Tryon(nn.Module):
  361. def __init__(self, ref_in_channel, source_in_channel=3, head_nums=6):
  362. super(SDAFNet_Tryon, self).__init__()
  363. num_filters = [64, 128, 256, 256, 256]
  364. self.source_features = FeatureEncoder(source_in_channel, num_filters)
  365. self.reference_features = FeatureEncoder(ref_in_channel, num_filters)
  366. self.source_FPN = RefinePyramid(num_filters)
  367. self.reference_FPN = RefinePyramid(num_filters)
  368. self.dafnet = DAFlowNet(len(num_filters), head_nums=head_nums)
  369. def forward(self,
  370. ref_input,
  371. source_image,
  372. ref_image,
  373. use_light_en_de=True,
  374. return_all=False,
  375. warp_feature=True):
  376. reference_feats = self.reference_FPN(
  377. self.reference_features(ref_input))
  378. source_feats = self.source_FPN(self.source_features(source_image))
  379. result = self.dafnet(
  380. source_image,
  381. ref_image,
  382. source_feats,
  383. reference_feats,
  384. use_light_en_de=use_light_en_de,
  385. return_all=return_all,
  386. warp_feature=warp_feature)
  387. return result