pool1d.py 670 B

1234567891011121314151617181920212223242526
  1. import torch
  2. def global_pool_nlc(
  3. x: torch.Tensor,
  4. pool_type: str = 'token',
  5. num_prefix_tokens: int = 1,
  6. reduce_include_prefix: bool = False,
  7. ):
  8. if not pool_type:
  9. return x
  10. if pool_type == 'token':
  11. x = x[:, 0] # class token
  12. else:
  13. x = x if reduce_include_prefix else x[:, num_prefix_tokens:]
  14. if pool_type == 'avg':
  15. x = x.mean(dim=1)
  16. elif pool_type == 'avgmax':
  17. x = 0.5 * (x.amax(dim=1) + x.mean(dim=1))
  18. elif pool_type == 'max':
  19. x = x.amax(dim=1)
  20. else:
  21. assert not pool_type, f'Unknown pool type {pool_type}'
  22. return x