lbfgs.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496
  1. # mypy: allow-untyped-defs
  2. from typing import Optional, Union
  3. import torch
  4. from torch import Tensor
  5. from .optimizer import _to_scalar, Optimizer, ParamsT
  6. __all__ = ["LBFGS"]
  7. def _cubic_interpolate(x1, f1, g1, x2, f2, g2, bounds=None):
  8. # ported from https://github.com/torch/optim/blob/master/polyinterp.lua
  9. # Compute bounds of interpolation area
  10. if bounds is not None:
  11. xmin_bound, xmax_bound = bounds
  12. else:
  13. xmin_bound, xmax_bound = (x1, x2) if x1 <= x2 else (x2, x1)
  14. # Code for most common case: cubic interpolation of 2 points
  15. # w/ function and derivative values for both
  16. # Solution in this case (where x2 is the farthest point):
  17. # d1 = g1 + g2 - 3*(f1-f2)/(x1-x2);
  18. # d2 = sqrt(d1^2 - g1*g2);
  19. # min_pos = x2 - (x2 - x1)*((g2 + d2 - d1)/(g2 - g1 + 2*d2));
  20. # t_new = min(max(min_pos,xmin_bound),xmax_bound);
  21. d1 = g1 + g2 - 3 * (f1 - f2) / (x1 - x2)
  22. d2_square = d1**2 - g1 * g2
  23. if d2_square >= 0:
  24. d2 = d2_square.sqrt()
  25. if x1 <= x2:
  26. min_pos = x2 - (x2 - x1) * ((g2 + d2 - d1) / (g2 - g1 + 2 * d2))
  27. else:
  28. min_pos = x1 - (x1 - x2) * ((g1 + d2 - d1) / (g1 - g2 + 2 * d2))
  29. return min(max(min_pos, xmin_bound), xmax_bound)
  30. else:
  31. return (xmin_bound + xmax_bound) / 2.0
def _strong_wolfe(
    obj_func, x, t, d, f, g, gtd, c1=1e-4, c2=0.9, tolerance_change=1e-9, max_ls=25
):
    """Line search for a step length satisfying the strong Wolfe conditions.

    Ported from https://github.com/torch/optim/blob/master/lswolfe.lua.

    Phase 1 (bracketing) extrapolates the step by cubic interpolation, with
    bounds ``[t + 0.01 * (t - t_prev), 10 * t]``. It stops when it either finds
    a point satisfying both conditions or an interval ``bracket`` that
    contains one. Phase 2 (zoom) shrinks that interval by cubic
    interpolation. It stops when a point satisfies the conditions, when the
    interval becomes too small, or when ``max_ls`` iterations are used up.

    Conditions checked at a trial step ``t``:
      * sufficient decrease: ``f_new <= f + c1 * t * gtd``
      * curvature:           ``abs(gtd_new) <= -c2 * gtd``

    Args:
        obj_func: called as ``obj_func(x, t, d)``. Must return
            ``(f_new, g_new)``, the loss and flat gradient at the trial step.
            In ``LBFGS.step`` this evaluates at ``x + t * d`` and restores the
            parameters to ``x`` afterwards.
        x: passed through unchanged to ``obj_func``.
        t: initial trial step length.
        d: search direction, a flat tensor dotted with the gradients.
        f: loss at step 0.
        g: flat gradient at step 0. It is cloned, so the caller's tensor is
            not aliased.
        gtd: directional derivative ``g . d`` at step 0. Expected to be
            negative: ``LBFGS.step`` only calls this when
            ``gtd <= -tolerance_change``.
        c1: sufficient-decrease constant.
        c2: curvature constant.
        tolerance_change: zooming stops once
            ``|bracket width| * max|d|`` drops below this.
        max_ls: maximum number of line-search iterations, i.e. evaluations
            after the first one.

    Returns:
        ``(f_new, g_new, t, ls_func_evals)``:
          * the loss, flat gradient and step length at the bracket point with
            the lowest loss (index ``low_pos``);
          * the total number of ``obj_func`` calls.

        If ``max_ls`` runs out during bracketing, the result may be step 0,
        returning ``f`` and ``g``. ``t`` may be a Python number or a 0-dim
        tensor, because it can come from ``_cubic_interpolate``.
    """
    # ported from https://github.com/torch/optim/blob/master/lswolfe.lua
    # Scale used to turn a bracket width (in step units) into a parameter-space change.
    d_norm = d.abs().max()
    g = g.clone(memory_format=torch.contiguous_format)
    # evaluate objective and gradient using initial step
    f_new, g_new = obj_func(x, t, d)
    ls_func_evals = 1
    gtd_new = g_new.dot(d)

    # bracket an interval containing a point satisfying the Wolfe criteria
    t_prev, f_prev, g_prev, gtd_prev = 0, f, g, gtd
    done = False
    ls_iter = 0
    while ls_iter < max_ls:
        # check conditions
        # Sufficient decrease violated, or (after two iterations) the loss
        # stopped decreasing: the minimizer lies in [t_prev, t].
        if f_new > (f + c1 * t * gtd) or (ls_iter > 1 and f_new >= f_prev):
            bracket = [t_prev, t]
            bracket_f = [f_prev, f_new]
            bracket_g = [g_prev, g_new.clone(memory_format=torch.contiguous_format)]
            bracket_gtd = [gtd_prev, gtd_new]
            break

        # Strong curvature condition holds too: accept t directly.
        if abs(gtd_new) <= -c2 * gtd:
            bracket = [t]
            bracket_f = [f_new]
            bracket_g = [g_new]
            done = True
            break

        # Slope turned non-negative: we stepped past a minimizer.
        if gtd_new >= 0:
            bracket = [t_prev, t]
            bracket_f = [f_prev, f_new]
            bracket_g = [g_prev, g_new.clone(memory_format=torch.contiguous_format)]
            bracket_gtd = [gtd_prev, gtd_new]
            break

        # interpolate
        # Extrapolate: the next trial step is confined to [t + 1% of last move, 10 * t].
        min_step = t + 0.01 * (t - t_prev)
        max_step = t * 10
        tmp = t
        t = _cubic_interpolate(
            t_prev, f_prev, gtd_prev, t, f_new, gtd_new, bounds=(min_step, max_step)
        )

        # next step
        t_prev = tmp
        f_prev = f_new
        g_prev = g_new.clone(memory_format=torch.contiguous_format)
        gtd_prev = gtd_new
        f_new, g_new = obj_func(x, t, d)
        ls_func_evals += 1
        gtd_new = g_new.dot(d)
        ls_iter += 1

    # reached max number of iterations?
    # Fall back to the interval [0, t]. bracket_gtd is not set here, which is
    # safe because the zoom loop below does not run when ls_iter == max_ls.
    if ls_iter == max_ls:
        bracket = [0, t]
        bracket_f = [f, f_new]
        bracket_g = [g, g_new]

    # zoom phase: we now have a point satisfying the criteria, or
    # a bracket around it. We refine the bracket until we find the
    # exact point satisfying the criteria
    insuf_progress = False
    # find high and low points in bracket
    low_pos, high_pos = (0, 1) if bracket_f[0] <= bracket_f[-1] else (1, 0)  # type: ignore[possibly-undefined]
    while not done and ls_iter < max_ls:
        # line-search bracket is so small
        if abs(bracket[1] - bracket[0]) * d_norm < tolerance_change:  # type: ignore[possibly-undefined]
            break

        # compute new trial value
        t = _cubic_interpolate(
            bracket[0],
            bracket_f[0],
            bracket_gtd[0],  # type: ignore[possibly-undefined]
            bracket[1],
            bracket_f[1],
            bracket_gtd[1],
        )

        # test that we are making sufficient progress:
        # in case `t` is so close to boundary, we mark that we are making
        # insufficient progress, and if
        #   + we have made insufficient progress in the last step, or
        #   + `t` is at one of the boundary,
        # we will move `t` to a position which is `0.1 * len(bracket)`
        # away from the nearest boundary point.
        eps = 0.1 * (max(bracket) - min(bracket))
        if min(max(bracket) - t, t - min(bracket)) < eps:
            # interpolation close to boundary
            if insuf_progress or t >= max(bracket) or t <= min(bracket):
                # evaluate at 0.1 away from boundary
                if abs(t - max(bracket)) < abs(t - min(bracket)):
                    t = max(bracket) - eps
                else:
                    t = min(bracket) + eps
                insuf_progress = False
            else:
                insuf_progress = True
        else:
            insuf_progress = False

        # Evaluate new point
        f_new, g_new = obj_func(x, t, d)
        ls_func_evals += 1
        gtd_new = g_new.dot(d)
        ls_iter += 1

        if f_new > (f + c1 * t * gtd) or f_new >= bracket_f[low_pos]:
            # Armijo condition not satisfied or not lower than lowest point
            bracket[high_pos] = t
            bracket_f[high_pos] = f_new
            bracket_g[high_pos] = g_new.clone(memory_format=torch.contiguous_format)  # type: ignore[possibly-undefined]
            bracket_gtd[high_pos] = gtd_new
            low_pos, high_pos = (0, 1) if bracket_f[0] <= bracket_f[1] else (1, 0)
        else:
            if abs(gtd_new) <= -c2 * gtd:
                # Wolfe conditions satisfied
                done = True
            elif gtd_new * (bracket[high_pos] - bracket[low_pos]) >= 0:
                # Slope at t points toward the old high end, so the old low
                # end becomes the new high end to keep a minimizer bracketed.
                # old high becomes new low
                bracket[high_pos] = bracket[low_pos]
                bracket_f[high_pos] = bracket_f[low_pos]
                bracket_g[high_pos] = bracket_g[low_pos]  # type: ignore[possibly-undefined]
                bracket_gtd[high_pos] = bracket_gtd[low_pos]

            # new point becomes new low
            # (reached only when f_new < bracket_f[low_pos], so it is the new best)
            bracket[low_pos] = t
            bracket_f[low_pos] = f_new
            bracket_g[low_pos] = g_new.clone(memory_format=torch.contiguous_format)  # type: ignore[possibly-undefined]
            bracket_gtd[low_pos] = gtd_new

    # return stuff
    t = bracket[low_pos]  # type: ignore[possibly-undefined]
    f_new = bracket_f[low_pos]
    g_new = bracket_g[low_pos]  # type: ignore[possibly-undefined]
    return f_new, g_new, t, ls_func_evals
  159. class LBFGS(Optimizer):
  160. """Implements L-BFGS algorithm.
  161. Heavily inspired by `minFunc
  162. <https://www.cs.ubc.ca/~schmidtm/Software/minFunc.html>`_.
  163. .. warning::
  164. This optimizer doesn't support per-parameter options and parameter
  165. groups (there can be only one).
  166. .. warning::
  167. Right now all parameters have to be on a single device. This will be
  168. improved in the future.
  169. .. note::
  170. This is a very memory intensive optimizer (it requires additional
  171. ``param_bytes * (history_size + 1)`` bytes). If it doesn't fit in memory
  172. try reducing the history size, or use a different algorithm.
  173. Args:
  174. params (iterable): iterable of parameters to optimize. Parameters must be real.
  175. lr (float, optional): learning rate (default: 1)
  176. max_iter (int, optional): maximal number of iterations per optimization step
  177. (default: 20)
  178. max_eval (int, optional): maximal number of function evaluations per optimization
  179. step (default: max_iter * 1.25).
  180. tolerance_grad (float, optional): termination tolerance on first order optimality
  181. (default: 1e-7).
  182. tolerance_change (float, optional): termination tolerance on function
  183. value/parameter changes (default: 1e-9).
  184. history_size (int, optional): update history size (default: 100).
  185. line_search_fn (str, optional): either 'strong_wolfe' or None (default: None).
  186. """
  187. def __init__(
  188. self,
  189. params: ParamsT,
  190. lr: Union[float, Tensor] = 1,
  191. max_iter: int = 20,
  192. max_eval: Optional[int] = None,
  193. tolerance_grad: float = 1e-7,
  194. tolerance_change: float = 1e-9,
  195. history_size: int = 100,
  196. line_search_fn: Optional[str] = None,
  197. ):
  198. if isinstance(lr, Tensor) and lr.numel() != 1:
  199. raise ValueError("Tensor lr must be 1-element")
  200. if not 0.0 <= lr:
  201. raise ValueError(f"Invalid learning rate: {lr}")
  202. if max_eval is None:
  203. max_eval = max_iter * 5 // 4
  204. defaults = {
  205. "lr": lr,
  206. "max_iter": max_iter,
  207. "max_eval": max_eval,
  208. "tolerance_grad": tolerance_grad,
  209. "tolerance_change": tolerance_change,
  210. "history_size": history_size,
  211. "line_search_fn": line_search_fn,
  212. }
  213. super().__init__(params, defaults)
  214. if len(self.param_groups) != 1:
  215. raise ValueError(
  216. "LBFGS doesn't support per-parameter options (parameter groups)"
  217. )
  218. self._params = self.param_groups[0]["params"]
  219. self._numel_cache = None
  220. def _numel(self):
  221. if self._numel_cache is None:
  222. self._numel_cache = sum(
  223. 2 * p.numel() if torch.is_complex(p) else p.numel()
  224. for p in self._params
  225. )
  226. return self._numel_cache
  227. def _gather_flat_grad(self):
  228. views = []
  229. for p in self._params:
  230. if p.grad is None:
  231. view = p.new(p.numel()).zero_()
  232. elif p.grad.is_sparse:
  233. view = p.grad.to_dense().view(-1)
  234. else:
  235. view = p.grad.view(-1)
  236. if torch.is_complex(view):
  237. view = torch.view_as_real(view).view(-1)
  238. views.append(view)
  239. return torch.cat(views, 0)
  240. def _add_grad(self, step_size, update):
  241. offset = 0
  242. for p in self._params:
  243. if torch.is_complex(p):
  244. p = torch.view_as_real(p)
  245. numel = p.numel()
  246. # view as to avoid deprecated pointwise semantics
  247. p.add_(update[offset : offset + numel].view_as(p), alpha=step_size)
  248. offset += numel
  249. assert offset == self._numel()
  250. def _clone_param(self):
  251. return [p.clone(memory_format=torch.contiguous_format) for p in self._params]
  252. def _set_param(self, params_data):
  253. for p, pdata in zip(self._params, params_data):
  254. p.copy_(pdata)
  255. def _directional_evaluate(self, closure, x, t, d):
  256. self._add_grad(t, d)
  257. loss = float(closure())
  258. flat_grad = self._gather_flat_grad()
  259. self._set_param(x)
  260. return loss, flat_grad
  261. @torch.no_grad()
  262. def step(self, closure): # type: ignore[override]
  263. """Perform a single optimization step.
  264. Args:
  265. closure (Callable): A closure that reevaluates the model
  266. and returns the loss.
  267. """
  268. assert len(self.param_groups) == 1
  269. # Make sure the closure is always called with grad enabled
  270. closure = torch.enable_grad()(closure)
  271. group = self.param_groups[0]
  272. lr = _to_scalar(group["lr"])
  273. max_iter = group["max_iter"]
  274. max_eval = group["max_eval"]
  275. tolerance_grad = group["tolerance_grad"]
  276. tolerance_change = group["tolerance_change"]
  277. line_search_fn = group["line_search_fn"]
  278. history_size = group["history_size"]
  279. # NOTE: LBFGS has only global state, but we register it as state for
  280. # the first param, because this helps with casting in load_state_dict
  281. state = self.state[self._params[0]]
  282. state.setdefault("func_evals", 0)
  283. state.setdefault("n_iter", 0)
  284. # evaluate initial f(x) and df/dx
  285. orig_loss = closure()
  286. loss = float(orig_loss)
  287. current_evals = 1
  288. state["func_evals"] += 1
  289. flat_grad = self._gather_flat_grad()
  290. opt_cond = flat_grad.abs().max() <= tolerance_grad
  291. # optimal condition
  292. if opt_cond:
  293. return orig_loss
  294. # tensors cached in state (for tracing)
  295. d = state.get("d")
  296. t = state.get("t")
  297. old_dirs = state.get("old_dirs")
  298. old_stps = state.get("old_stps")
  299. ro = state.get("ro")
  300. H_diag = state.get("H_diag")
  301. prev_flat_grad = state.get("prev_flat_grad")
  302. prev_loss = state.get("prev_loss")
  303. n_iter = 0
  304. # optimize for a max of max_iter iterations
  305. while n_iter < max_iter:
  306. # keep track of nb of iterations
  307. n_iter += 1
  308. state["n_iter"] += 1
  309. ############################################################
  310. # compute gradient descent direction
  311. ############################################################
  312. if state["n_iter"] == 1:
  313. d = flat_grad.neg()
  314. old_dirs = []
  315. old_stps = []
  316. ro = []
  317. H_diag = 1
  318. else:
  319. # do lbfgs update (update memory)
  320. y = flat_grad.sub(prev_flat_grad)
  321. s = d.mul(t)
  322. ys = y.dot(s) # y*s
  323. if ys > 1e-10:
  324. # updating memory
  325. if len(old_dirs) == history_size:
  326. # shift history by one (limited-memory)
  327. old_dirs.pop(0)
  328. old_stps.pop(0)
  329. ro.pop(0)
  330. # store new direction/step
  331. old_dirs.append(y)
  332. old_stps.append(s)
  333. ro.append(1.0 / ys)
  334. # update scale of initial Hessian approximation
  335. H_diag = ys / y.dot(y) # (y*y)
  336. # compute the approximate (L-BFGS) inverse Hessian
  337. # multiplied by the gradient
  338. num_old = len(old_dirs)
  339. if "al" not in state:
  340. state["al"] = [None] * history_size
  341. al = state["al"]
  342. # iteration in L-BFGS loop collapsed to use just one buffer
  343. q = flat_grad.neg()
  344. for i in range(num_old - 1, -1, -1):
  345. al[i] = old_stps[i].dot(q) * ro[i]
  346. q.add_(old_dirs[i], alpha=-al[i])
  347. # multiply by initial Hessian
  348. # r/d is the final direction
  349. d = r = torch.mul(q, H_diag)
  350. for i in range(num_old):
  351. be_i = old_dirs[i].dot(r) * ro[i]
  352. r.add_(old_stps[i], alpha=al[i] - be_i)
  353. if prev_flat_grad is None:
  354. prev_flat_grad = flat_grad.clone(memory_format=torch.contiguous_format)
  355. else:
  356. prev_flat_grad.copy_(flat_grad)
  357. prev_loss = loss
  358. ############################################################
  359. # compute step length
  360. ############################################################
  361. # reset initial guess for step size
  362. if state["n_iter"] == 1:
  363. t = min(1.0, 1.0 / flat_grad.abs().sum()) * lr
  364. else:
  365. t = lr
  366. # directional derivative
  367. gtd = flat_grad.dot(d) # g * d
  368. # directional derivative is below tolerance
  369. if gtd > -tolerance_change:
  370. break
  371. # optional line search: user function
  372. ls_func_evals = 0
  373. if line_search_fn is not None:
  374. # perform line search, using user function
  375. if line_search_fn != "strong_wolfe":
  376. raise RuntimeError("only 'strong_wolfe' is supported")
  377. else:
  378. x_init = self._clone_param()
  379. def obj_func(x, t, d):
  380. return self._directional_evaluate(closure, x, t, d)
  381. loss, flat_grad, t, ls_func_evals = _strong_wolfe(
  382. obj_func, x_init, t, d, loss, flat_grad, gtd
  383. )
  384. self._add_grad(t, d)
  385. opt_cond = flat_grad.abs().max() <= tolerance_grad
  386. else:
  387. # no line search, simply move with fixed-step
  388. self._add_grad(t, d)
  389. if n_iter != max_iter:
  390. # re-evaluate function only if not in last iteration
  391. # the reason we do this: in a stochastic setting,
  392. # no use to re-evaluate that function here
  393. with torch.enable_grad():
  394. loss = closure()
  395. loss = float(loss)
  396. flat_grad = self._gather_flat_grad()
  397. opt_cond = flat_grad.abs().max() <= tolerance_grad
  398. ls_func_evals = 1
  399. # update func eval
  400. current_evals += ls_func_evals
  401. state["func_evals"] += ls_func_evals
  402. ############################################################
  403. # check conditions
  404. ############################################################
  405. if n_iter == max_iter:
  406. break
  407. if current_evals >= max_eval:
  408. break
  409. # optimal condition
  410. if opt_cond:
  411. break
  412. # lack of progress
  413. if d.mul(t).abs().max() <= tolerance_change:
  414. break
  415. if abs(loss - prev_loss) < tolerance_change:
  416. break
  417. state["d"] = d
  418. state["t"] = t
  419. state["old_dirs"] = old_dirs
  420. state["old_stps"] = old_stps
  421. state["ro"] = ro
  422. state["H_diag"] = H_diag
  423. state["prev_flat_grad"] = prev_flat_grad
  424. state["prev_loss"] = prev_loss
  425. return orig_loss