path_random.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388
  1. """
  2. Support for random optimizers, including the random-greedy path.
  3. """
  4. import functools
  5. import heapq
  6. import math
  7. import numbers
  8. import time
  9. from collections import deque
  10. from . import helpers, paths
  11. # random.choices was introduced in python 3.6
  12. try:
  13. from random import choices as random_choices
  14. from random import seed as random_seed
  15. except ImportError:
  16. import numpy as np
  17. def random_choices(population, weights):
  18. norm = sum(weights)
  19. return np.random.choice(population, p=[w / norm for w in weights], size=1)
  20. random_seed = np.random.seed
# Public API of this module: names exported by ``from ... import *``.
__all__ = ["RandomGreedy", "random_greedy", "random_greedy_128"]
  22. class RandomOptimizer(paths.PathOptimizer):
  23. """Base class for running any random path finder that benefits
  24. from repeated calling, possibly in a parallel fashion. Custom random
  25. optimizers should subclass this, and the ``setup`` method should be
  26. implemented with the following signature::
  27. def setup(self, inputs, output, size_dict):
  28. # custom preparation here ...
  29. return trial_fn, trial_args
  30. Where ``trial_fn`` itself should have the signature::
  31. def trial_fn(r, *trial_args):
  32. # custom computation of path here
  33. return ssa_path, cost, size
  34. Where ``r`` is the run number and could for example be used to seed a
  35. random number generator. See ``RandomGreedy`` for an example.
  36. Parameters
  37. ----------
  38. max_repeats : int, optional
  39. The maximum number of repeat trials to have.
  40. max_time : float, optional
  41. The maximum amount of time to run the algorithm for.
  42. minimize : {'flops', 'size'}, optional
  43. Whether to favour paths that minimize the total estimated flop-count or
  44. the size of the largest intermediate created.
  45. parallel : {bool, int, or executor-pool like}, optional
  46. Whether to parallelize the random trials, by default ``False``. If
  47. ``True``, use a ``concurrent.futures.ProcessPoolExecutor`` with the same
  48. number of processes as cores. If an integer is specified, use that many
  49. processes instead. Finally, you can supply a custom executor-pool which
  50. should have an API matching that of the python 3 standard library
  51. module ``concurrent.futures``. Namely, a ``submit`` method that returns
  52. ``Future`` objects, themselves with ``result`` and ``cancel`` methods.
  53. pre_dispatch : int, optional
  54. If running in parallel, how many jobs to pre-dispatch so as to avoid
  55. submitting all jobs at once. Should also be more than twice the number
  56. of workers to avoid under-subscription. Default: 128.
  57. Attributes
  58. ----------
  59. path : list[tuple[int]]
  60. The best path found so far.
  61. costs : list[int]
  62. The list of each trial's costs found so far.
  63. sizes : list[int]
  64. The list of each trial's largest intermediate size so far.
  65. See Also
  66. --------
  67. RandomGreedy
  68. """
  69. def __init__(self, max_repeats=32, max_time=None, minimize='flops', parallel=False, pre_dispatch=128):
  70. if minimize not in ('flops', 'size'):
  71. raise ValueError("`minimize` should be one of {'flops', 'size'}.")
  72. self.max_repeats = max_repeats
  73. self.max_time = max_time
  74. self.minimize = minimize
  75. self.better = paths.get_better_fn(minimize)
  76. self.parallel = parallel
  77. self.pre_dispatch = pre_dispatch
  78. self.costs = []
  79. self.sizes = []
  80. self.best = {'flops': float('inf'), 'size': float('inf')}
  81. self._repeats_start = 0
  82. @property
  83. def path(self):
  84. """The best path found so far.
  85. """
  86. return paths.ssa_to_linear(self.best['ssa_path'])
  87. @property
  88. def parallel(self):
  89. return self._parallel
  90. @parallel.setter
  91. def parallel(self, parallel):
  92. # shutdown any previous executor if we are managing it
  93. if getattr(self, '_managing_executor', False):
  94. self._executor.shutdown()
  95. self._parallel = parallel
  96. self._managing_executor = False
  97. if parallel is False:
  98. self._executor = None
  99. return
  100. if parallel is True:
  101. from concurrent.futures import ProcessPoolExecutor
  102. self._executor = ProcessPoolExecutor()
  103. self._managing_executor = True
  104. return
  105. if isinstance(parallel, numbers.Number):
  106. from concurrent.futures import ProcessPoolExecutor
  107. self._executor = ProcessPoolExecutor(parallel)
  108. self._managing_executor = True
  109. return
  110. # assume a pool-executor has been supplied
  111. self._executor = parallel
  112. def _gen_results_parallel(self, repeats, trial_fn, args):
  113. """Lazily generate results from an executor without submitting all jobs at once.
  114. """
  115. self._futures = deque()
  116. # the idea here is to submit at least ``pre_dispatch`` jobs *before* we
  117. # yield any results, then do both in tandem, before draining the queue
  118. for r in repeats:
  119. if len(self._futures) < self.pre_dispatch:
  120. self._futures.append(self._executor.submit(trial_fn, r, *args))
  121. continue
  122. yield self._futures.popleft().result()
  123. while self._futures:
  124. yield self._futures.popleft().result()
  125. def _cancel_futures(self):
  126. if self._executor is not None:
  127. for f in self._futures:
  128. f.cancel()
  129. def setup(self, inputs, output, size_dict):
  130. raise NotImplementedError
  131. def __call__(self, inputs, output, size_dict, memory_limit):
  132. self._check_args_against_first_call(inputs, output, size_dict)
  133. # start a timer?
  134. if self.max_time is not None:
  135. t0 = time.time()
  136. trial_fn, trial_args = self.setup(inputs, output, size_dict)
  137. r_start = self._repeats_start + len(self.costs)
  138. r_stop = r_start + self.max_repeats
  139. repeats = range(r_start, r_stop)
  140. # create the trials lazily
  141. if self._executor is not None:
  142. trials = self._gen_results_parallel(repeats, trial_fn, trial_args)
  143. else:
  144. trials = (trial_fn(r, *trial_args) for r in repeats)
  145. # assess the trials
  146. for ssa_path, cost, size in trials:
  147. # keep track of all costs and sizes
  148. self.costs.append(cost)
  149. self.sizes.append(size)
  150. # check if we have found a new best
  151. found_new_best = self.better(cost, size, self.best['flops'], self.best['size'])
  152. if found_new_best:
  153. self.best['flops'] = cost
  154. self.best['size'] = size
  155. self.best['ssa_path'] = ssa_path
  156. # check if we have run out of time
  157. if (self.max_time is not None) and (time.time() > t0 + self.max_time):
  158. break
  159. self._cancel_futures()
  160. return self.path
  161. def __del__(self):
  162. # if we created the parallel pool-executor, shut it down
  163. if getattr(self, '_managing_executor', False):
  164. self._executor.shutdown()
  165. def thermal_chooser(queue, remaining, nbranch=8, temperature=1, rel_temperature=True):
  166. """A contraction 'chooser' that weights possible contractions using a
  167. Boltzmann distribution. Explicitly, given costs ``c_i`` (with ``c_0`` the
  168. smallest), the relative weights, ``w_i``, are computed as:
  169. w_i = exp( -(c_i - c_0) / temperature)
  170. Additionally, if ``rel_temperature`` is set, scale ``temperature`` by
  171. ``abs(c_0)`` to account for likely fluctuating cost magnitudes during the
  172. course of a contraction.
  173. Parameters
  174. ----------
  175. queue : list
  176. The heapified list of candidate contractions.
  177. remaining : dict[str, int]
  178. Mapping of remaining inputs' indices to the ssa id.
  179. temperature : float, optional
  180. When choosing a possible contraction, its relative probability will be
  181. proportional to ``exp(-cost / temperature)``. Thus the larger
  182. ``temperature`` is, the further random paths will stray from the normal
  183. 'greedy' path. Conversely, if set to zero, only paths with exactly the
  184. same cost as the best at each step will be explored.
  185. rel_temperature : bool, optional
  186. Whether to normalize the ``temperature`` at each step to the scale of
  187. the best cost. This is generally beneficial as the magnitude of costs
  188. can vary significantly throughout a contraction.
  189. nbranch : int, optional
  190. How many potential paths to calculate probability for and choose from
  191. at each step.
  192. Returns
  193. -------
  194. cost, k1, k2, k12
  195. """
  196. n = 0
  197. choices = []
  198. while queue and n < nbranch:
  199. cost, k1, k2, k12 = heapq.heappop(queue)
  200. if k1 not in remaining or k2 not in remaining:
  201. continue # candidate is obsolete
  202. choices.append((cost, k1, k2, k12))
  203. n += 1
  204. if n == 0:
  205. return None
  206. if n == 1:
  207. return choices[0]
  208. costs = [choice[0][0] for choice in choices]
  209. cmin = costs[0]
  210. # adjust by the overall scale to account for fluctuating absolute costs
  211. if rel_temperature:
  212. temperature *= max(1, abs(cmin))
  213. # compute relative probability for each potential contraction
  214. if temperature == 0.0:
  215. energies = [1 if c == cmin else 0 for c in costs]
  216. else:
  217. # shift by cmin for numerical reasons
  218. energies = [math.exp(-(c - cmin) / temperature) for c in costs]
  219. # randomly choose a contraction based on energies
  220. chosen, = random_choices(range(n), weights=energies)
  221. cost, k1, k2, k12 = choices.pop(chosen)
  222. # put the other choise back in the heap
  223. for other in choices:
  224. heapq.heappush(queue, other)
  225. return cost, k1, k2, k12
  226. def ssa_path_compute_cost(ssa_path, inputs, output, size_dict):
  227. """Compute the flops and max size of an ssa path.
  228. """
  229. inputs = list(map(frozenset, inputs))
  230. output = frozenset(output)
  231. remaining = set(range(len(inputs)))
  232. total_cost = 0
  233. max_size = 0
  234. for i, j in ssa_path:
  235. k12, flops12 = paths.calc_k12_flops(inputs, output, remaining, i, j, size_dict)
  236. remaining.discard(i)
  237. remaining.discard(j)
  238. remaining.add(len(inputs))
  239. inputs.append(k12)
  240. total_cost += flops12
  241. max_size = max(max_size, helpers.compute_size_by_dict(k12, size_dict))
  242. return total_cost, max_size
  243. def _trial_greedy_ssa_path_and_cost(r, inputs, output, size_dict, choose_fn, cost_fn):
  244. """A single, repeatable, greedy trial run. Returns ``ssa_path`` and cost.
  245. """
  246. if r == 0:
  247. # always start with the standard greedy approach
  248. choose_fn = None
  249. random_seed(r)
  250. ssa_path = paths.ssa_greedy_optimize(inputs, output, size_dict, choose_fn, cost_fn)
  251. cost, size = ssa_path_compute_cost(ssa_path, inputs, output, size_dict)
  252. return ssa_path, cost, size
  253. class RandomGreedy(RandomOptimizer):
  254. """
  255. Parameters
  256. ----------
  257. cost_fn : callable, optional
  258. A function that returns a heuristic 'cost' of a potential contraction
  259. with which to sort candidates. Should have signature
  260. ``cost_fn(size12, size1, size2, k12, k1, k2)``.
  261. temperature : float, optional
  262. When choosing a possible contraction, its relative probability will be
  263. proportional to ``exp(-cost / temperature)``. Thus the larger
  264. ``temperature`` is, the further random paths will stray from the normal
  265. 'greedy' path. Conversely, if set to zero, only paths with exactly the
  266. same cost as the best at each step will be explored.
  267. rel_temperature : bool, optional
  268. Whether to normalize the ``temperature`` at each step to the scale of
  269. the best cost. This is generally beneficial as the magnitude of costs
  270. can vary significantly throughout a contraction. If False, the
  271. algorithm will end up branching when the absolute cost is low, but
  272. stick to the 'greedy' path when the cost is high - this can also be
  273. beneficial.
  274. nbranch : int, optional
  275. How many potential paths to calculate probability for and choose from
  276. at each step.
  277. kwargs
  278. Supplied to RandomOptimizer.
  279. See Also
  280. --------
  281. RandomOptimizer
  282. """
  283. def __init__(self, cost_fn='memory-removed-jitter', temperature=1.0, rel_temperature=True, nbranch=8, **kwargs):
  284. self.cost_fn = cost_fn
  285. self.temperature = temperature
  286. self.rel_temperature = rel_temperature
  287. self.nbranch = nbranch
  288. super().__init__(**kwargs)
  289. @property
  290. def choose_fn(self):
  291. """The function that chooses which contraction to take - make this a
  292. property so that ``temperature`` and ``nbranch`` etc. can be updated
  293. between runs.
  294. """
  295. if self.nbranch == 1:
  296. return None
  297. return functools.partial(thermal_chooser,
  298. temperature=self.temperature,
  299. nbranch=self.nbranch,
  300. rel_temperature=self.rel_temperature)
  301. def setup(self, inputs, output, size_dict):
  302. fn = _trial_greedy_ssa_path_and_cost
  303. args = (inputs, output, size_dict, self.choose_fn, self.cost_fn)
  304. return fn, args
  305. def random_greedy(inputs, output, idx_dict, memory_limit=None, **optimizer_kwargs):
  306. """
  307. """
  308. optimizer = RandomGreedy(**optimizer_kwargs)
  309. return optimizer(inputs, output, idx_dict, memory_limit)
# ``random_greedy`` preconfigured to run up to 128 repeat trials.
random_greedy_128 = functools.partial(random_greedy, max_repeats=128)