# config.py (3.2 KB)
  1. import logging
  2. import os
  3. import os.path as osp
  4. from datetime import datetime
  5. import torch
  6. from easydict import EasyDict
  7. cfg = EasyDict(__name__='Config: VideoComposer')
  8. pmi_world_size = int(os.getenv('WORLD_SIZE', 1))
  9. gpus_per_machine = torch.cuda.device_count()
  10. world_size = pmi_world_size * gpus_per_machine
  11. cfg.video_compositions = [
  12. 'text', 'mask', 'depthmap', 'sketch', 'motion', 'image', 'local_image',
  13. 'single_sketch'
  14. ]
  15. # dataset
  16. cfg.root_dir = 'webvid10m/'
  17. cfg.alpha = 0.7
  18. cfg.misc_size = 384
  19. cfg.depth_std = 20.0
  20. cfg.depth_clamp = 10.0
  21. cfg.hist_sigma = 10.0
  22. cfg.use_image_dataset = False
  23. cfg.alpha_img = 0.7
  24. cfg.resolution = 256
  25. cfg.mean = [0.5, 0.5, 0.5]
  26. cfg.std = [0.5, 0.5, 0.5]
  27. # sketch
  28. cfg.sketch_mean = [0.485, 0.456, 0.406]
  29. cfg.sketch_std = [0.229, 0.224, 0.225]
  30. # dataloader
  31. cfg.max_words = 1000
  32. cfg.frame_lens = [
  33. 16,
  34. 16,
  35. 16,
  36. 16,
  37. ]
  38. cfg.feature_framerates = [
  39. 4,
  40. ]
  41. cfg.feature_framerate = 4
  42. cfg.batch_sizes = {
  43. str(1): 1,
  44. str(4): 1,
  45. str(8): 1,
  46. str(16): 1,
  47. }
  48. cfg.chunk_size = 64
  49. cfg.num_workers = 8
  50. cfg.prefetch_factor = 2
  51. cfg.seed = 8888
  52. # diffusion
  53. cfg.num_timesteps = 1000
  54. cfg.mean_type = 'eps'
  55. cfg.var_type = 'fixed_small'
  56. cfg.loss_type = 'mse'
  57. cfg.ddim_timesteps = 50
  58. cfg.ddim_eta = 0.0
  59. cfg.clamp = 1.0
  60. cfg.share_noise = False
  61. cfg.use_div_loss = False
  62. # classifier-free guidance
  63. cfg.p_zero = 0.9
  64. cfg.guide_scale = 6.0
  65. # stable diffusion
  66. cfg.sd_checkpoint = 'v2-1_512-ema-pruned.ckpt'
  67. # clip vision encoder
  68. cfg.vit_image_size = 336
  69. cfg.vit_patch_size = 14
  70. cfg.vit_dim = 1024
  71. cfg.vit_out_dim = 768
  72. cfg.vit_heads = 16
  73. cfg.vit_layers = 24
  74. cfg.vit_mean = [0.48145466, 0.4578275, 0.40821073]
  75. cfg.vit_std = [0.26862954, 0.26130258, 0.27577711]
  76. cfg.clip_checkpoint = 'open_clip_pytorch_model.bin'
  77. cfg.mvs_visual = False
  78. # unet
  79. cfg.unet_in_dim = 4
  80. cfg.unet_concat_dim = 8
  81. cfg.unet_y_dim = cfg.vit_out_dim
  82. cfg.unet_context_dim = 1024
  83. cfg.unet_out_dim = 8 if cfg.var_type.startswith('learned') else 4
  84. cfg.unet_dim = 320
  85. cfg.unet_dim_mult = [1, 2, 4, 4]
  86. cfg.unet_res_blocks = 2
  87. cfg.unet_num_heads = 8
  88. cfg.unet_head_dim = 64
  89. cfg.unet_attn_scales = [1 / 1, 1 / 2, 1 / 4]
  90. cfg.unet_dropout = 0.1
  91. cfg.misc_dropout = 0.5
  92. cfg.p_all_zero = 0.1
  93. cfg.p_all_keep = 0.1
  94. cfg.temporal_conv = False
  95. cfg.temporal_attn_times = 1
  96. cfg.temporal_attention = True
  97. cfg.use_fps_condition = False
  98. cfg.use_sim_mask = False
  99. # Default: load 2d pretrain
  100. cfg.pretrained = False
  101. cfg.fix_weight = False
  102. # Default resume
  103. cfg.resume = True
  104. cfg.resume_step = 148000
  105. cfg.resume_check_dir = '.'
  106. cfg.resume_checkpoint = os.path.join(
  107. cfg.resume_check_dir,
  108. f'step_{cfg.resume_step}/non_ema_{cfg.resume_step}.pth')
  109. cfg.resume_optimizer = False
  110. if cfg.resume_optimizer:
  111. cfg.resume_optimizer = os.path.join(
  112. cfg.resume_check_dir, f'optimizer_step_{cfg.resume_step}.pt')
  113. # acceleration
  114. cfg.use_ema = True
  115. # for debug, no ema
  116. if world_size < 2:
  117. cfg.use_ema = False
  118. cfg.load_from = None
  119. cfg.use_checkpoint = True
  120. cfg.use_sharded_ddp = False
  121. cfg.use_fsdp = False
  122. cfg.use_fp16 = True
  123. # training
  124. cfg.ema_decay = 0.9999
  125. cfg.viz_interval = 1000
  126. cfg.save_ckp_interval = 1000
  127. # logging
  128. cfg.log_interval = 100
  129. composition_strings = '_'.join(cfg.video_compositions)
  130. # Default log_dir
  131. cfg.log_dir = 'outputs/'