model.py 2.1 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465
# This code is borrowed and modified from the Human Motion Diffusion Model (MDM),
# made publicly available under the MIT license at https://github.com/GuyTevet/motion-diffusion-model
  3. from .modules import gaussian_diffusion as gd
  4. from .modules.mdm import MDM
  5. from .modules.respace import SpacedDiffusion, space_timesteps
  6. def load_model_wo_clip(model, state_dict):
  7. missing_keys, unexpected_keys = model.load_state_dict(
  8. state_dict, strict=False)
  9. assert len(unexpected_keys) == 0
  10. assert all([k.startswith('clip_model.') for k in missing_keys])
  11. def create_model(cfg):
  12. model = MDM(
  13. '',
  14. njoints=263,
  15. nfeats=1,
  16. num_actions=1,
  17. translation=True,
  18. pose_rep='rot6d',
  19. glob=True,
  20. glob_rot=True,
  21. latent_dim=512,
  22. ff_size=1024,
  23. smpl_data_path=cfg.smpl_data_path,
  24. data_rep='hml_vec',
  25. dataset='humanml',
  26. clip_version='ViT-B/32',
  27. **{
  28. 'cond_mode': 'text',
  29. 'cond_mask_prob': 0.1,
  30. 'action_emb': 'tensor'
  31. })
  32. predict_xstart = True # we always predict x_start (a.k.a. x0), that's our deal!
  33. steps = cfg.sample_steps
  34. scale_beta = 1. # no scaling
  35. timestep_respacing = '' # can be used for ddim sampling, we don't use it.
  36. learn_sigma = False
  37. rescale_timesteps = False
  38. betas = gd.get_named_beta_schedule('cosine', steps, scale_beta)
  39. loss_type = gd.LossType.MSE
  40. if not timestep_respacing:
  41. timestep_respacing = [steps]
  42. diffusion = SpacedDiffusion(
  43. use_timesteps=space_timesteps(steps, timestep_respacing),
  44. betas=betas,
  45. model_mean_type=(gd.ModelMeanType.EPSILON
  46. if not predict_xstart else gd.ModelMeanType.START_X),
  47. model_var_type=((gd.ModelVarType.FIXED_LARGE
  48. if not True else gd.ModelVarType.FIXED_SMALL)
  49. if not learn_sigma else gd.ModelVarType.LEARNED_RANGE),
  50. loss_type=loss_type,
  51. rescale_timesteps=rescale_timesteps,
  52. lambda_vel=0.0,
  53. lambda_rcxyz=0.0,
  54. lambda_fc=0.0,
  55. )
  56. return model, diffusion