helpers.py 7.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283
  1. """
  2. Contains helper functions for opt_einsum testing scripts
  3. """
  4. from collections import OrderedDict
  5. import numpy as np
  6. from .parser import get_symbol
  7. __all__ = ["build_views", "compute_size_by_dict", "find_contraction", "flop_count"]
  8. _valid_chars = "abcdefghijklmopqABC"
  9. _sizes = np.array([2, 3, 4, 5, 4, 3, 2, 6, 5, 4, 3, 2, 5, 7, 4, 3, 2, 3, 4])
  10. _default_dim_dict = {c: s for c, s in zip(_valid_chars, _sizes)}
  11. def build_views(string, dimension_dict=None):
  12. """
  13. Builds random numpy arrays for testing.
  14. Parameters
  15. ----------
  16. string : list of str
  17. List of tensor strings to build
  18. dimension_dict : dictionary
  19. Dictionary of index _sizes
  20. Returns
  21. -------
  22. ret : list of np.ndarry's
  23. The resulting views.
  24. Examples
  25. --------
  26. >>> view = build_views(['abbc'], {'a': 2, 'b':3, 'c':5})
  27. >>> view[0].shape
  28. (2, 3, 3, 5)
  29. """
  30. if dimension_dict is None:
  31. dimension_dict = _default_dim_dict
  32. views = []
  33. terms = string.split('->')[0].split(',')
  34. for term in terms:
  35. dims = [dimension_dict[x] for x in term]
  36. views.append(np.random.rand(*dims))
  37. return views
  38. def compute_size_by_dict(indices, idx_dict):
  39. """
  40. Computes the product of the elements in indices based on the dictionary
  41. idx_dict.
  42. Parameters
  43. ----------
  44. indices : iterable
  45. Indices to base the product on.
  46. idx_dict : dictionary
  47. Dictionary of index _sizes
  48. Returns
  49. -------
  50. ret : int
  51. The resulting product.
  52. Examples
  53. --------
  54. >>> compute_size_by_dict('abbc', {'a': 2, 'b':3, 'c':5})
  55. 90
  56. """
  57. ret = 1
  58. for i in indices: # lgtm [py/iteration-string-and-sequence]
  59. ret *= idx_dict[i]
  60. return ret
  61. def find_contraction(positions, input_sets, output_set):
  62. """
  63. Finds the contraction for a given set of input and output sets.
  64. Parameters
  65. ----------
  66. positions : iterable
  67. Integer positions of terms used in the contraction.
  68. input_sets : list
  69. List of sets that represent the lhs side of the einsum subscript
  70. output_set : set
  71. Set that represents the rhs side of the overall einsum subscript
  72. Returns
  73. -------
  74. new_result : set
  75. The indices of the resulting contraction
  76. remaining : list
  77. List of sets that have not been contracted, the new set is appended to
  78. the end of this list
  79. idx_removed : set
  80. Indices removed from the entire contraction
  81. idx_contraction : set
  82. The indices used in the current contraction
  83. Examples
  84. --------
  85. # A simple dot product test case
  86. >>> pos = (0, 1)
  87. >>> isets = [set('ab'), set('bc')]
  88. >>> oset = set('ac')
  89. >>> find_contraction(pos, isets, oset)
  90. ({'a', 'c'}, [{'a', 'c'}], {'b'}, {'a', 'b', 'c'})
  91. # A more complex case with additional terms in the contraction
  92. >>> pos = (0, 2)
  93. >>> isets = [set('abd'), set('ac'), set('bdc')]
  94. >>> oset = set('ac')
  95. >>> find_contraction(pos, isets, oset)
  96. ({'a', 'c'}, [{'a', 'c'}, {'a', 'c'}], {'b', 'd'}, {'a', 'b', 'c', 'd'})
  97. """
  98. remaining = list(input_sets)
  99. inputs = (remaining.pop(i) for i in sorted(positions, reverse=True))
  100. idx_contract = set.union(*inputs)
  101. idx_remain = output_set.union(*remaining)
  102. new_result = idx_remain & idx_contract
  103. idx_removed = (idx_contract - new_result)
  104. remaining.append(new_result)
  105. return new_result, remaining, idx_removed, idx_contract
  106. def flop_count(idx_contraction, inner, num_terms, size_dictionary):
  107. """
  108. Computes the number of FLOPS in the contraction.
  109. Parameters
  110. ----------
  111. idx_contraction : iterable
  112. The indices involved in the contraction
  113. inner : bool
  114. Does this contraction require an inner product?
  115. num_terms : int
  116. The number of terms in a contraction
  117. size_dictionary : dict
  118. The size of each of the indices in idx_contraction
  119. Returns
  120. -------
  121. flop_count : int
  122. The total number of FLOPS required for the contraction.
  123. Examples
  124. --------
  125. >>> flop_count('abc', False, 1, {'a': 2, 'b':3, 'c':5})
  126. 90
  127. >>> flop_count('abc', True, 2, {'a': 2, 'b':3, 'c':5})
  128. 270
  129. """
  130. overall_size = compute_size_by_dict(idx_contraction, size_dictionary)
  131. op_factor = max(1, num_terms - 1)
  132. if inner:
  133. op_factor += 1
  134. return overall_size * op_factor
  135. def rand_equation(n, reg, n_out=0, d_min=2, d_max=9, seed=None, global_dim=False, return_size_dict=False):
  136. """Generate a random contraction and shapes.
  137. Parameters
  138. ----------
  139. n : int
  140. Number of array arguments.
  141. reg : int
  142. 'Regularity' of the contraction graph. This essentially determines how
  143. many indices each tensor shares with others on average.
  144. n_out : int, optional
  145. Number of output indices (i.e. the number of non-contracted indices).
  146. Defaults to 0, i.e., a contraction resulting in a scalar.
  147. d_min : int, optional
  148. Minimum dimension size.
  149. d_max : int, optional
  150. Maximum dimension size.
  151. seed: int, optional
  152. If not None, seed numpy's random generator with this.
  153. global_dim : bool, optional
  154. Add a global, 'broadcast', dimension to every operand.
  155. return_size_dict : bool, optional
  156. Return the mapping of indices to sizes.
  157. Returns
  158. -------
  159. eq : str
  160. The equation string.
  161. shapes : list[tuple[int]]
  162. The array shapes.
  163. size_dict : dict[str, int]
  164. The dict of index sizes, only returned if ``return_size_dict=True``.
  165. Examples
  166. --------
  167. >>> eq, shapes = rand_equation(n=10, reg=4, n_out=5, seed=42)
  168. >>> eq
  169. 'oyeqn,tmaq,skpo,vg,hxui,n,fwxmr,hitplcj,kudlgfv,rywjsb->cebda'
  170. >>> shapes
  171. [(9, 5, 4, 5, 4),
  172. (4, 4, 8, 5),
  173. (9, 4, 6, 9),
  174. (6, 6),
  175. (6, 9, 7, 8),
  176. (4,),
  177. (9, 3, 9, 4, 9),
  178. (6, 8, 4, 6, 8, 6, 3),
  179. (4, 7, 8, 8, 6, 9, 6),
  180. (9, 5, 3, 3, 9, 5)]
  181. """
  182. if seed is not None:
  183. np.random.seed(seed)
  184. # total number of indices
  185. num_inds = n * reg // 2 + n_out
  186. inputs = ["" for _ in range(n)]
  187. output = []
  188. size_dict = OrderedDict((get_symbol(i), np.random.randint(d_min, d_max + 1)) for i in range(num_inds))
  189. # generate a list of indices to place either once or twice
  190. def gen():
  191. for i, ix in enumerate(size_dict):
  192. # generate an outer index
  193. if i < n_out:
  194. output.append(ix)
  195. yield ix
  196. # generate a bond
  197. else:
  198. yield ix
  199. yield ix
  200. # add the indices randomly to the inputs
  201. for i, ix in enumerate(np.random.permutation(list(gen()))):
  202. # make sure all inputs have at least one index
  203. if i < n:
  204. inputs[i] += ix
  205. else:
  206. # don't add any traces on same op
  207. where = np.random.randint(0, n)
  208. while ix in inputs[where]:
  209. where = np.random.randint(0, n)
  210. inputs[where] += ix
  211. # possibly add the same global dim to every arg
  212. if global_dim:
  213. gdim = get_symbol(num_inds)
  214. size_dict[gdim] = np.random.randint(d_min, d_max + 1)
  215. for i in range(n):
  216. inputs[i] += gdim
  217. output += gdim
  218. # randomly transpose the output indices and form equation
  219. output = "".join(np.random.permutation(output))
  220. eq = "{}->{}".format(",".join(inputs), output)
  221. # make the shapes
  222. shapes = [tuple(size_dict[ix] for ix in op) for op in inputs]
  223. ret = (eq, shapes)
  224. if return_size_dict:
  225. ret += (size_dict, )
  226. return ret