blas.py 7.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243
  1. """
  2. Determines if a contraction can use BLAS or not
  3. """
  4. import numpy as np
  5. from . import helpers
  6. __all__ = ["can_blas", "tensor_blas"]
  7. def can_blas(inputs, result, idx_removed, shapes=None):
  8. """
  9. Checks if we can use a BLAS call.
  10. Parameters
  11. ----------
  12. inputs : list of str
  13. Specifies the subscripts for summation.
  14. result : str
  15. Resulting summation.
  16. idx_removed : set
  17. Indices that are removed in the summation
  18. shapes : sequence of tuple[int], optional
  19. If given, check also that none of the indices are broadcast dimensions.
  20. Returns
  21. -------
  22. type : str or bool
  23. The type of BLAS call to be used or False if none.
  24. Notes
  25. -----
  26. We assume several operations are not efficient such as a transposed
  27. DDOT, therefore 'ijk,jki->' should prefer einsum. These return the blas
  28. type appended with "/EINSUM" to differentiate when they can still be done
  29. with tensordot if required, e.g. when a backend has no einsum.
  30. Examples
  31. --------
  32. >>> can_blas(['ij', 'jk'], 'ik', set('j'))
  33. 'GEMM'
  34. >>> can_blas(['ijj', 'jk'], 'ik', set('j'))
  35. False
  36. >>> can_blas(['ab', 'cd'], 'abcd', set())
  37. 'OUTER/EINSUM'
  38. >>> # looks like GEMM but actually 'j' is broadcast:
  39. >>> can_blas(['ij', 'jk'], 'ik', set('j'), shapes=[(4, 1), (5, 6)])
  40. False
  41. """
  42. # Can only do two
  43. if len(inputs) != 2:
  44. return False
  45. input_left, input_right = inputs
  46. for c in set(input_left + input_right):
  47. # can't deal with repeated indices on same input or more than 2 total
  48. nl, nr = input_left.count(c), input_right.count(c)
  49. if (nl > 1) or (nr > 1) or (nl + nr > 2):
  50. return False
  51. # can't do implicit summation or dimension collapse e.g.
  52. # "ab,bc->c" (implicitly sum over 'a')
  53. # "ab,ca->ca" (take diagonal of 'a')
  54. if nl + nr - 1 == int(c in result):
  55. return False
  56. # check for broadcast indices e.g:
  57. # "ij,jk->ik" (but one of the 'j' dimensions is broadcast up)
  58. if shapes is not None:
  59. for c in idx_removed:
  60. if shapes[0][input_left.find(c)] != shapes[1][input_right.find(c)]:
  61. return False
  62. # Prefer einsum if not removing indices
  63. # (N.B. tensordot outer faster for large arrays?)
  64. if len(idx_removed) == 0:
  65. return 'OUTER/EINSUM'
  66. # Build a few temporaries
  67. sets = [set(x) for x in inputs]
  68. keep_left = sets[0] - idx_removed
  69. keep_right = sets[1] - idx_removed
  70. rs = len(idx_removed)
  71. # DDOT
  72. if inputs[0] == inputs[1]:
  73. return 'DOT'
  74. # DDOT doesnt make sense if you have to tranpose - prefer einsum
  75. elif sets[0] == sets[1]:
  76. return 'DOT/EINSUM'
  77. # GEMM no transpose
  78. if input_left[-rs:] == input_right[:rs]:
  79. return 'GEMM'
  80. # GEMM transpose both
  81. elif input_left[:rs] == input_right[-rs:]:
  82. return 'GEMM'
  83. # GEMM transpose right
  84. elif input_left[-rs:] == input_right[-rs:]:
  85. return 'GEMM'
  86. # GEMM tranpose left
  87. elif input_left[:rs] == input_right[:rs]:
  88. return 'GEMM'
  89. # Einsum is faster than vectordot if we have to copy
  90. elif (len(keep_left) == 0) or (len(keep_right) == 0):
  91. return 'GEMV/EINSUM'
  92. # Conventional tensordot
  93. else:
  94. return 'TDOT'
  95. def tensor_blas(view_left, input_left, view_right, input_right, index_result, idx_removed):
  96. """
  97. Computes the dot product between two tensors, attempts to use np.dot and
  98. then tensordot if that fails.
  99. Parameters
  100. ----------
  101. view_left : array_like
  102. The left hand view
  103. input_left : str
  104. Indices of the left view
  105. view_right : array_like
  106. The right hand view
  107. input_right : str
  108. Indices of the right view
  109. index_result : str
  110. The resulting indices
  111. idx_removed : set
  112. Indices removed in the contraction
  113. Returns
  114. -------
  115. type : array
  116. The resulting BLAS operation.
  117. Notes
  118. -----
  119. Interior function for tensor BLAS.
  120. This function will attempt to use `np.dot` by the iterating through the
  121. four possible transpose cases. If this fails all inner and matrix-vector
  122. operations will be handed off to einsum while all matrix-matrix operations will
  123. first copy the data, perform the DGEMM, and then copy the data to the required
  124. order.
  125. Examples
  126. --------
  127. >>> a = np.random.rand(4, 4)
  128. >>> b = np.random.rand(4, 4)
  129. >>> tmp = tensor_blas(a, 'ij', b, 'jk', 'ik', set('j'))
  130. >>> np.allclose(tmp, np.dot(a, b))
  131. """
  132. idx_removed = set(idx_removed)
  133. keep_left = set(input_left) - idx_removed
  134. keep_right = set(input_right) - idx_removed
  135. # We trust this must be called correctly
  136. dimension_dict = {}
  137. for i, s in zip(input_left, view_left.shape):
  138. dimension_dict[i] = s
  139. for i, s in zip(input_right, view_right.shape):
  140. dimension_dict[i] = s
  141. # Do we want to be able to do this?
  142. # Check for duplicate indices, cannot do einsum('iij,jkk->ik') operations here
  143. # if (len(set(input_left)) != len(input_left)):
  144. # new_inds = ''.join(keep_left) + ''.join(idx_removed)
  145. # view_left = np.einsum(input_left + '->' + new_inds, view_left, order='C')
  146. # input_left = new_inds
  147. # if (len(set(input_right)) != len(input_right)):
  148. # new_inds = ''.join(idx_removed) + ''.join(keep_right)
  149. # view_right = np.einsum(input_right + '->' + new_inds, view_right, order='C')
  150. # input_right = new_inds
  151. # Tensordot guarantees a copy for ndim > 2, should avoid skip if possible
  152. rs = len(idx_removed)
  153. dim_left = helpers.compute_size_by_dict(keep_left, dimension_dict)
  154. dim_right = helpers.compute_size_by_dict(keep_right, dimension_dict)
  155. dim_removed = helpers.compute_size_by_dict(idx_removed, dimension_dict)
  156. tensor_result = input_left + input_right
  157. for s in idx_removed:
  158. tensor_result = tensor_result.replace(s, "")
  159. # This is ugly, but can vastly speed up certain operations
  160. # Vectordot
  161. if input_left == input_right:
  162. new_view = np.dot(view_left.ravel(), view_right.ravel())
  163. # Matrix multiply
  164. # No transpose needed
  165. elif input_left[-rs:] == input_right[:rs]:
  166. new_view = np.dot(view_left.reshape(dim_left, dim_removed), view_right.reshape(dim_removed, dim_right))
  167. # Transpose both
  168. elif input_left[:rs] == input_right[-rs:]:
  169. new_view = np.dot(view_left.reshape(dim_removed, dim_left).T, view_right.reshape(dim_right, dim_removed).T)
  170. # Transpose right
  171. elif input_left[-rs:] == input_right[-rs:]:
  172. new_view = np.dot(view_left.reshape(dim_left, dim_removed), view_right.reshape(dim_right, dim_removed).T)
  173. # Tranpose left
  174. elif input_left[:rs] == input_right[:rs]:
  175. new_view = np.dot(view_left.reshape(dim_removed, dim_left).T, view_right.reshape(dim_removed, dim_right))
  176. # Conventional tensordot
  177. else:
  178. # Find indices to contract over
  179. left_pos, right_pos = (), ()
  180. for s in idx_removed:
  181. left_pos += (input_left.find(s), )
  182. right_pos += (input_right.find(s), )
  183. new_view = np.tensordot(view_left, view_right, axes=(left_pos, right_pos))
  184. # Make sure the resulting shape is correct
  185. tensor_shape = tuple(dimension_dict[x] for x in tensor_result)
  186. if new_view.shape != tensor_shape:
  187. if len(tensor_result) > 0:
  188. new_view.shape = tensor_shape
  189. else:
  190. new_view = np.squeeze(new_view)
  191. if tensor_result != index_result:
  192. new_view = np.einsum(tensor_result + '->' + index_result, new_view)
  193. return new_view