functions.py 1.7 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import numpy as np
  3. import torch
  4. import torch.nn.functional as F
  5. def unsqueeze(input, dims):
  6. """ Implement multi-dimension unsqueeze function. """
  7. if isinstance(dims, (list, tuple)):
  8. dims = [
  9. dim if dim >= 0 else dim + len(input.shape) + 1 for dim in dims
  10. ]
  11. dims = sorted(dims, reverse=True)
  12. shape = list(input.shape)
  13. for dim in dims:
  14. shape.insert(dim, 1)
  15. return torch.reshape(input, shape)
  16. elif isinstance(dims, int):
  17. return input.unsqueeze(dims)
  18. else:
  19. raise ValueError('Warning: type(dims) must in (list, tuple, int)!')
  20. def gumbel_softmax(input, tau=1, eps=1e-10):
  21. """ Basic implement of gumbel_softmax. """
  22. U = torch.tensor(np.random.rand(*input.shape))
  23. gumbel = 0.0 - torch.log(eps - torch.log(U + eps))
  24. y = input + gumbel
  25. return F.softmax(y / tau)
  26. def equal(x, y, dtype=None):
  27. """ Implement equal in dygraph mode. (paddle) """
  28. if dtype is None:
  29. dtype = 'float32'
  30. if isinstance(x, torch.Tensor):
  31. x = x.numpy()
  32. if isinstance(y, torch.Tensor):
  33. y = y.numpy()
  34. out = np.equal(x, y).astype(dtype)
  35. return torch.tensor(out)
  36. def not_equal(x, y, dtype=None):
  37. """ Implement not_equal in dygraph mode. (paddle) """
  38. return 1 - equal(x, y, dtype)
  39. if __name__ == '__main__':
  40. a = torch.tensor([[1, 1], [3, 4]])
  41. b = torch.tensor([[1, 1], [3, 4]])
  42. c = torch.equal(a, a)
  43. c1 = equal(a, 3)
  44. d = 1 - torch.not_equal(a, 3).float()
  45. print(c)
  46. print(c1)
  47. print(d)
  48. e = F.gumbel_softmax(a)
  49. f = a.unsqueeze(a)
  50. g = unsqueeze(a, dims=[0, 0, 1])
  51. print(g, g.shape)