parser.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356
#!/usr/bin/env python
# coding: utf-8
"""
A pure-Python parser that is functionally equivalent to the numpy.einsum input parser.
"""
  6. import itertools
  7. from collections import OrderedDict
  8. import numpy as np
# Names exported by a star-import of this module.  ``convert_subscripts`` and
# ``convert_interleaved_input`` are defined in this file but are not listed here.
__all__ = [
    "is_valid_einsum_char", "has_valid_einsum_chars_only", "get_symbol", "gen_unused_symbols",
    "convert_to_valid_einsum_chars", "alpha_canonicalize", "find_output_str", "find_output_shape",
    "possibly_convert_to_numpy", "parse_einsum_input"
]

# The 52 ASCII letters, lowercase first.  ``get_symbol`` indexes into this
# string for ``i < 52`` and ``is_valid_einsum_char`` tests membership in it.
_einsum_symbols_base = 'abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ'
  15. def is_valid_einsum_char(x):
  16. """Check if the character ``x`` is valid for numpy einsum.
  17. Examples
  18. --------
  19. >>> is_valid_einsum_char("a")
  20. True
  21. >>> is_valid_einsum_char("Ǵ")
  22. False
  23. """
  24. return (x in _einsum_symbols_base) or (x in ',->.')
  25. def has_valid_einsum_chars_only(einsum_str):
  26. """Check if ``einsum_str`` contains only valid characters for numpy einsum.
  27. Examples
  28. --------
  29. >>> has_valid_einsum_chars_only("abAZ")
  30. True
  31. >>> has_valid_einsum_chars_only("Över")
  32. False
  33. """
  34. return all(map(is_valid_einsum_char, einsum_str))
  35. def get_symbol(i):
  36. """Get the symbol corresponding to int ``i`` - runs through the usual 52
  37. letters before resorting to unicode characters, starting at ``chr(192)``.
  38. Examples
  39. --------
  40. >>> get_symbol(2)
  41. 'c'
  42. >>> get_symbol(200)
  43. 'Ŕ'
  44. >>> get_symbol(20000)
  45. '京'
  46. """
  47. if i < 52:
  48. return _einsum_symbols_base[i]
  49. return chr(i + 140)
  50. def gen_unused_symbols(used, n):
  51. """Generate ``n`` symbols that are not already in ``used``.
  52. Examples
  53. --------
  54. >>> list(oe.parser.gen_unused_symbols("abd", 2))
  55. ['c', 'e']
  56. """
  57. i = cnt = 0
  58. while cnt < n:
  59. s = get_symbol(i)
  60. i += 1
  61. if s in used:
  62. continue
  63. yield s
  64. cnt += 1
  65. def convert_to_valid_einsum_chars(einsum_str):
  66. """Convert the str ``einsum_str`` to contain only the alphabetic characters
  67. valid for numpy einsum. If there are too many symbols, let the backend
  68. throw an error.
  69. Examples
  70. --------
  71. >>> oe.parser.convert_to_valid_einsum_chars("Ĥěļļö")
  72. 'cbdda'
  73. """
  74. symbols = sorted(set(einsum_str) - set(',->'))
  75. replacer = {x: get_symbol(i) for i, x in enumerate(symbols)}
  76. return "".join(replacer.get(x, x) for x in einsum_str)
  77. def alpha_canonicalize(equation):
  78. """Alpha convert an equation in an order-independent canonical way.
  79. Examples
  80. --------
  81. >>> oe.parser.alpha_canonicalize("dcba")
  82. 'abcd'
  83. >>> oe.parser.alpha_canonicalize("Ĥěļļö")
  84. 'abccd'
  85. """
  86. rename = OrderedDict()
  87. for name in equation:
  88. if name in '.,->':
  89. continue
  90. if name not in rename:
  91. rename[name] = get_symbol(len(rename))
  92. return ''.join(rename.get(x, x) for x in equation)
  93. def find_output_str(subscripts):
  94. """
  95. Find the output string for the inputs ``subscripts`` under canonical einstein summation rules. That is, repeated indices are summed over by default.
  96. Examples
  97. --------
  98. >>> oe.parser.find_output_str("ab,bc")
  99. 'ac'
  100. >>> oe.parser.find_output_str("a,b")
  101. 'ab'
  102. >>> oe.parser.find_output_str("a,a,b,b")
  103. ''
  104. """
  105. tmp_subscripts = subscripts.replace(",", "")
  106. return "".join(s for s in sorted(set(tmp_subscripts)) if tmp_subscripts.count(s) == 1)
  107. def find_output_shape(inputs, shapes, output):
  108. """Find the output shape for given inputs, shapes and output string, taking
  109. into account broadcasting.
  110. Examples
  111. --------
  112. >>> oe.parser.find_output_shape(["ab", "bc"], [(2, 3), (3, 4)], "ac")
  113. (2, 4)
  114. # Broadcasting is accounted for
  115. >>> oe.parser.find_output_shape(["a", "a"], [(4, ), (1, )], "a")
  116. (4,)
  117. """
  118. return tuple(
  119. max(shape[loc] for shape, loc in zip(shapes, [x.find(c) for x in inputs]) if loc >= 0) for c in output)
  120. def possibly_convert_to_numpy(x):
  121. """Convert things without a 'shape' to ndarrays, but leave everything else.
  122. Examples
  123. --------
  124. >>> oe.parser.possibly_convert_to_numpy(5)
  125. array(5)
  126. >>> oe.parser.possibly_convert_to_numpy([5, 3])
  127. array([5, 3])
  128. >>> oe.parser.possibly_convert_to_numpy(np.array([5, 3]))
  129. array([5, 3])
  130. # Any class with a shape is passed through
  131. >>> class Shape:
  132. ... def __init__(self, shape):
  133. ... self.shape = shape
  134. ...
  135. >>> myshape = Shape((5, 5))
  136. >>> oe.parser.possibly_convert_to_numpy(myshape)
  137. <__main__.Shape object at 0x10f850710>
  138. """
  139. if not hasattr(x, 'shape'):
  140. return np.asanyarray(x)
  141. else:
  142. return x
  143. def convert_subscripts(old_sub, symbol_map):
  144. """Convert user custom subscripts list to subscript string according to `symbol_map`.
  145. Examples
  146. --------
  147. >>> oe.parser.convert_subscripts(['abc', 'def'], {'abc':'a', 'def':'b'})
  148. 'ab'
  149. >>> oe.parser.convert_subscripts([Ellipsis, object], {object:'a'})
  150. '...a'
  151. """
  152. new_sub = ""
  153. for s in old_sub:
  154. if s is Ellipsis:
  155. new_sub += "..."
  156. else:
  157. # no need to try/except here because symbol_map has already been checked
  158. new_sub += symbol_map[s]
  159. return new_sub
  160. def convert_interleaved_input(operands):
  161. """Convert 'interleaved' input to standard einsum input.
  162. """
  163. tmp_operands = list(operands)
  164. operand_list = []
  165. subscript_list = []
  166. for p in range(len(operands) // 2):
  167. operand_list.append(tmp_operands.pop(0))
  168. subscript_list.append(tmp_operands.pop(0))
  169. output_list = tmp_operands[-1] if len(tmp_operands) else None
  170. operands = [possibly_convert_to_numpy(x) for x in operand_list]
  171. # build a map from user symbols to single-character symbols based on `get_symbol`
  172. # The map retains the intrinsic order of user symbols
  173. try:
  174. # collect all user symbols
  175. symbol_set = set(itertools.chain.from_iterable(subscript_list))
  176. # remove Ellipsis because it can not be compared with other objects
  177. symbol_set.discard(Ellipsis)
  178. # build the map based on sorted user symbols, retaining the order we lost in the `set`
  179. symbol_map = {symbol: get_symbol(idx) for idx, symbol in enumerate(sorted(symbol_set))}
  180. except TypeError: # unhashable or uncomparable object
  181. raise TypeError("For this input type lists must contain either Ellipsis "
  182. "or hashable and comparable object (e.g. int, str).")
  183. subscripts = ','.join(convert_subscripts(sub, symbol_map) for sub in subscript_list)
  184. if output_list is not None:
  185. subscripts += "->"
  186. subscripts += convert_subscripts(output_list, symbol_map)
  187. return subscripts, operands
  188. def parse_einsum_input(operands):
  189. """
  190. A reproduction of einsum c side einsum parsing in python.
  191. Returns
  192. -------
  193. input_strings : str
  194. Parsed input strings
  195. output_string : str
  196. Parsed output string
  197. operands : list of array_like
  198. The operands to use in the numpy contraction
  199. Examples
  200. --------
  201. The operand list is simplified to reduce printing:
  202. >>> a = np.random.rand(4, 4)
  203. >>> b = np.random.rand(4, 4, 4)
  204. >>> parse_einsum_input(('...a,...a->...', a, b))
  205. ('za,xza', 'xz', [a, b])
  206. >>> parse_einsum_input((a, [Ellipsis, 0], b, [Ellipsis, 0]))
  207. ('za,xza', 'xz', [a, b])
  208. """
  209. if len(operands) == 0:
  210. raise ValueError("No input operands")
  211. if isinstance(operands[0], str):
  212. subscripts = operands[0].replace(" ", "")
  213. operands = [possibly_convert_to_numpy(x) for x in operands[1:]]
  214. else:
  215. subscripts, operands = convert_interleaved_input(operands)
  216. # Check for proper "->"
  217. if ("-" in subscripts) or (">" in subscripts):
  218. invalid = (subscripts.count("-") > 1) or (subscripts.count(">") > 1)
  219. if invalid or (subscripts.count("->") != 1):
  220. raise ValueError("Subscripts can only contain one '->'.")
  221. # Parse ellipses
  222. if "." in subscripts:
  223. used = subscripts.replace(".", "").replace(",", "").replace("->", "")
  224. ellipse_inds = "".join(gen_unused_symbols(used, max(len(x.shape) for x in operands)))
  225. longest = 0
  226. # Do we have an output to account for?
  227. if "->" in subscripts:
  228. input_tmp, output_sub = subscripts.split("->")
  229. split_subscripts = input_tmp.split(",")
  230. out_sub = True
  231. else:
  232. split_subscripts = subscripts.split(',')
  233. out_sub = False
  234. for num, sub in enumerate(split_subscripts):
  235. if "." in sub:
  236. if (sub.count(".") != 3) or (sub.count("...") != 1):
  237. raise ValueError("Invalid Ellipses.")
  238. # Take into account numerical values
  239. if operands[num].shape == ():
  240. ellipse_count = 0
  241. else:
  242. ellipse_count = max(len(operands[num].shape), 1) - (len(sub) - 3)
  243. if ellipse_count > longest:
  244. longest = ellipse_count
  245. if ellipse_count < 0:
  246. raise ValueError("Ellipses lengths do not match.")
  247. elif ellipse_count == 0:
  248. split_subscripts[num] = sub.replace('...', '')
  249. else:
  250. split_subscripts[num] = sub.replace('...', ellipse_inds[-ellipse_count:])
  251. subscripts = ",".join(split_subscripts)
  252. # Figure out output ellipses
  253. if longest == 0:
  254. out_ellipse = ""
  255. else:
  256. out_ellipse = ellipse_inds[-longest:]
  257. if out_sub:
  258. subscripts += "->" + output_sub.replace("...", out_ellipse)
  259. else:
  260. # Special care for outputless ellipses
  261. output_subscript = find_output_str(subscripts)
  262. normal_inds = ''.join(sorted(set(output_subscript) - set(out_ellipse)))
  263. subscripts += "->" + out_ellipse + normal_inds
  264. # Build output string if does not exist
  265. if "->" in subscripts:
  266. input_subscripts, output_subscript = subscripts.split("->")
  267. else:
  268. input_subscripts, output_subscript = subscripts, find_output_str(subscripts)
  269. # Make sure output subscripts are in the input
  270. for char in output_subscript:
  271. if char not in input_subscripts:
  272. raise ValueError("Output character '{}' did not appear in the input".format(char))
  273. # Make sure number operands is equivalent to the number of terms
  274. if len(input_subscripts.split(',')) != len(operands):
  275. raise ValueError("Number of einsum subscripts must be equal to the " "number of operands.")
  276. return input_subscripts, output_subscript, operands