lbfgs.py 29 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785
  1. # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from collections import defaultdict
  15. from functools import reduce
  16. import paddle
  17. from ..base import framework
  18. from .optimizer import Optimizer
  # Empty on purpose: ``from <this module> import *`` exports no names.
  __all__ = list()
  def dot(x, y):
      r"""Return the inner product of ``x`` and ``y`` taken along their last axis.

      NOTE: This is a temporary workaround for the unstable result computed by
      ``paddle.dot``; it will be reverted once that problem is fixed.
      """
      elementwise = x * y
      return elementwise.sum(axis=-1)
  26. def _cubic_interpolate(x1, f1, g1, x2, f2, g2, bounds=None):
  27. r"""Cubic interpolation between (x1, f1, g1) and (x2, f2, g2).
  28. Use two points and their gradient to determine a cubic function and get the minimum point
  29. between them in the cubic curve.
  30. Reference:
  31. Jorge Nocedal, Stephen J. Wright, Numerical Optimization, Second Edition, 2006.
  32. pp59: formula 3.59
  33. Args:
  34. x1, f1, g1: point1's position, value and gradient.
  35. x2, f2, g2: point2's position, value and gradient.
  36. bounds: bounds of interpolation area
  37. Returns:
  38. min_pos: the minimum point between the specified points in the cubic curve.
  39. """
  40. # Compute bounds of interpolation area
  41. if bounds is not None:
  42. xmin_bound, xmax_bound = bounds
  43. else:
  44. xmin_bound, xmax_bound = (x1, x2) if x1 <= x2 else (x2, x1)
  45. d1 = g1 + g2 - 3 * (f1 - f2) / (x1 - x2)
  46. d2_square = d1**2 - g1 * g2
  47. if d2_square >= 0:
  48. d2 = d2_square.sqrt()
  49. if x1 <= x2:
  50. min_pos = x2 - (x2 - x1) * ((g2 + d2 - d1) / (g2 - g1 + 2 * d2))
  51. else:
  52. min_pos = x1 - (x1 - x2) * ((g1 + d2 - d1) / (g1 - g2 + 2 * d2))
  53. return min(max(min_pos, xmin_bound), xmax_bound)
  54. else:
  55. return (xmin_bound + xmax_bound) / 2.0
  def _strong_wolfe(
      obj_func,
      xk,
      alpha,
      d,
      loss,
      grad,
      gtd,
      c1=1e-4,
      c2=0.9,
      tolerance_change=1e-9,
      max_ls=25,
  ):
      r"""Line search satisfying the strong Wolfe conditions, using bracketing and zoom.

      Reference:
          Jorge Nocedal, Stephen J. Wright, Numerical Optimization, Second Edition, 2006.
          pp60: Algorithm 3.5 (Line Search Algorithm).

      Args:
          obj_func: callable invoked as ``obj_func(xk, alpha, d)``. It must return
              ``(loss, flat_grad)`` of the objective evaluated at step ``alpha``
              along ``d`` from ``xk``.
          xk: the starting point of the iterates; forwarded unchanged to ``obj_func``.
          alpha (float|Tensor): the initial step size.
          d (Tensor): search direction.
          loss (float|Tensor): the objective value at the starting point.
          grad (Tensor): the flattened gradient at the starting point.
          gtd (Tensor): the directional derivative at the starting point, expected
              to be ``dot(grad, d)``. It should be negative (descent direction) for
              the Wolfe conditions below to be meaningful.
          c1 (float): parameter for the sufficient decrease (Armijo) condition.
          c2 (float): parameter for the curvature condition.
          tolerance_change (float): the zoom phase stops once
              ``|bracket width| * max(|d|)`` falls below this value.
          max_ls (int): maximum number of line search iterations, i.e. objective
              evaluations performed after the initial one.

      Returns:
          loss_new (float|Tensor): objective value at the returned step.
          grad_new (Tensor): flattened gradient at the returned step.
          alpha (float|Tensor): the returned step length, i.e. the lowest-loss
              bracket point. It may be ``0`` (the starting point) if no trial
              step improved on ``loss``.
          ls_func_evals (int): number of ``obj_func`` calls made.

      Notation used below:
          - ``phi(a) = func(xk + a * d)`` restricts the objective to the search
            line; ``phi'(a)`` is its derivative (``gtd`` / ``gtd_new`` in code).
          - ``a_prev`` / ``a``: previous / current trial step of the bracketing phase.
          - ``a_lo`` / ``a_hi``: bracket end points with the lower / higher ``phi``
            in the zoom phase.

      Bracketing phase (repeat while ls_iter < max_ls):
          1. If phi(a) > phi(0) + c1 * a * phi'(0), or [phi(a) >= phi(a_prev) and ls_iter > 1],
             bracket = [a_prev, a] and go to zoom;
          2. If |phi'(a)| <= -c2 * phi'(0), accept a and stop;
          3. If phi'(a) >= 0, bracket = [a_prev, a] and go to zoom;
          4. Otherwise extrapolate: a_prev, a = a, cubic_interpolate(a_prev, a),
             restricted to [a + 0.01 * (a - a_prev), 10 * a].
          If max_ls is reached, bracket = [0, a].

      Zoom phase (repeat while not done and ls_iter < max_ls):
          a_j = cubic_interpolate(a_lo, a_hi), moved 10% of the bracket width away
          from an end point when it lands too close to it without progress.
          1. If phi(a_j) > phi(0) + c1 * a_j * phi'(0) or phi(a_j) >= phi(a_lo),
             then a_hi = a_j;
          2. Otherwise:
             2.1. If |phi'(a_j)| <= -c2 * phi'(0), accept a_j and stop;
             2.2. If phi'(a_j) * (a_hi - a_lo) >= 0, then a_hi = a_lo;
             a_lo = a_j.
          Stops early when the bracket becomes too small (see ``tolerance_change``).

      reference: https://github.com/pytorch/pytorch
      """
      d_norm = d.abs().max()
      grad = grad.clone()
      # evaluate objective and gradient using initial step
      loss_new, grad_new = obj_func(xk, alpha, d)
      ls_func_evals = 1
      gtd_new = dot(grad_new, d)

      # bracket an interval containing a point satisfying the Wolfe criteria
      t_prev, f_prev, g_prev, gtd_prev = (0, loss, grad, gtd)
      done = False
      ls_iter = 0
      while ls_iter < max_ls:
          # check conditions
          if loss_new > (loss + c1 * alpha * gtd) or (
              ls_iter > 1 and loss_new >= f_prev
          ):
              bracket = [t_prev, alpha]
              bracket_f = [f_prev, loss_new]
              bracket_g = [g_prev, grad_new.clone()]
              bracket_gtd = [gtd_prev, gtd_new]
              break

          if paddle.abs(gtd_new) <= -c2 * gtd:
              bracket = [alpha]
              bracket_f = [loss_new]
              bracket_g = [grad_new]
              done = True
              break

          if gtd_new >= 0:
              bracket = [t_prev, alpha]
              bracket_f = [f_prev, loss_new]
              bracket_g = [g_prev, grad_new.clone()]
              bracket_gtd = [gtd_prev, gtd_new]
              break

          # interpolate
          min_step = alpha + 0.01 * (alpha - t_prev)
          max_step = alpha * 10
          tmp = alpha
          alpha = _cubic_interpolate(
              t_prev,
              f_prev,
              gtd_prev,
              alpha,
              loss_new,
              gtd_new,
              bounds=(min_step, max_step),
          )

          # next step
          t_prev = tmp
          f_prev = loss_new
          g_prev = grad_new.clone()
          gtd_prev = gtd_new
          loss_new, grad_new = obj_func(xk, alpha, d)
          ls_func_evals += 1
          gtd_new = dot(grad_new, d)
          ls_iter += 1

      # reached max number of iterations?
      if ls_iter == max_ls:
          # bracket_gtd is not set on this path; that is safe because the zoom
          # loop below only runs while ls_iter < max_ls.
          bracket = [0, alpha]
          bracket_f = [loss, loss_new]
          bracket_g = [grad, grad_new]

      def _as_tensor(value):
          # Bracket end points and trial steps can be plain Python numbers
          # (e.g. the initial step 0 or a float learning rate), but paddle.abs
          # expects a Tensor, so convert them using the gradient dtype.
          if not isinstance(value, paddle.Tensor):
              value = paddle.to_tensor(value, dtype=gtd_new.dtype)
          return value

      # zoom phase: we now have a point satisfying the criteria, or
      # a bracket around it. We refine the bracket until we find the
      # exact point satisfying the criteria
      insuf_progress = False
      # find high and low points in bracket
      low_pos, high_pos = (0, 1) if bracket_f[0] <= bracket_f[-1] else (1, 0)
      while not done and ls_iter < max_ls:
          # line-search bracket is so small
          bracket_ls = _as_tensor(bracket[1] - bracket[0])
          if paddle.abs(bracket_ls) * d_norm < tolerance_change:
              break

          # compute new trial value
          alpha = _cubic_interpolate(
              bracket[0],
              bracket_f[0],
              bracket_gtd[0],
              bracket[1],
              bracket_f[1],
              bracket_gtd[1],
          )

          # test that we are making sufficient progress:
          # in case `alpha` is so close to boundary, we mark that we are making
          # insufficient progress, and if
          #   + we have made insufficient progress in the last step, or
          #   + `alpha` is at one of the boundary,
          # we will move `alpha` to a position which is `0.1 * len(bracket)`
          # away from the nearest boundary point.
          eps = 0.1 * (max(bracket) - min(bracket))
          if min(max(bracket) - alpha, alpha - min(bracket)) < eps:
              # interpolation close to boundary
              if insuf_progress or alpha >= max(bracket) or alpha <= min(bracket):
                  # evaluate at 0.1 away from boundary
                  dist_to_max = paddle.abs(_as_tensor(alpha - max(bracket)))
                  dist_to_min = paddle.abs(_as_tensor(alpha - min(bracket)))
                  if dist_to_max < dist_to_min:
                      alpha = max(bracket) - eps
                  else:
                      alpha = min(bracket) + eps
                  insuf_progress = False
              else:
                  insuf_progress = True
          else:
              insuf_progress = False

          # Evaluate new point
          loss_new, grad_new = obj_func(xk, alpha, d)
          ls_func_evals += 1
          gtd_new = dot(grad_new, d)
          ls_iter += 1

          if (
              loss_new > (loss + c1 * alpha * gtd)
              or loss_new >= bracket_f[low_pos]
          ):
              # Armijo condition not satisfied or not lower than lowest point
              bracket[high_pos] = alpha
              bracket_f[high_pos] = loss_new
              bracket_g[high_pos] = grad_new.clone()
              bracket_gtd[high_pos] = gtd_new
              low_pos, high_pos = (
                  (0, 1) if bracket_f[0] <= bracket_f[1] else (1, 0)
              )
          else:
              if paddle.abs(gtd_new) <= -c2 * gtd:
                  # Wolfe conditions satisfied
                  done = True
              elif gtd_new * (bracket[high_pos] - bracket[low_pos]) >= 0:
                  # old high becomes new low
                  bracket[high_pos] = bracket[low_pos]
                  bracket_f[high_pos] = bracket_f[low_pos]
                  bracket_g[high_pos] = bracket_g[low_pos]
                  bracket_gtd[high_pos] = bracket_gtd[low_pos]

              # new point becomes new low
              bracket[low_pos] = alpha
              bracket_f[low_pos] = loss_new
              bracket_g[low_pos] = grad_new.clone()
              bracket_gtd[low_pos] = gtd_new

      # return stuff
      alpha = bracket[low_pos]
      loss_new = bracket_f[low_pos]
      grad_new = bracket_g[low_pos]
      return loss_new, grad_new, alpha, ls_func_evals
  class LBFGS(Optimizer):
      r"""
      The L-BFGS is a quasi-Newton method for solving an unconstrained optimization problem over a differentiable function.
      Closely related is the Newton method for minimization. Consider the iterate update formula:

      .. math::
          x_{k+1} = x_{k} - H_k \nabla{f_k}

      If :math:`H_k` is the inverse Hessian of :math:`f` at :math:`x_k`, then it's the Newton method.
      If :math:`H_k` is symmetric and positive definite, used as an approximation of the inverse Hessian, then
      it's a quasi-Newton. In practice, the approximated Hessians are obtained
      by only using the gradients, over either whole or part of the search
      history, the former is BFGS, the latter is L-BFGS.

      Reference:
          Jorge Nocedal, Stephen J. Wright, Numerical Optimization, Second Edition, 2006. pp179: Algorithm 7.5 (L-BFGS).

      Args:
          learning_rate (float, optional): learning rate. The default value is 1.
          max_iter (int, optional): maximal number of iterations per optimization step.
              The default value is 20.
          max_eval (int, optional): maximal number of function evaluations per optimization
              step. The default value is ``max_iter * 5 // 4`` (max_iter * 1.25, rounded down).
          tolerance_grad (float, optional): termination tolerance on first order optimality.
              The default value is 1e-7.
          tolerance_change (float, optional): termination tolerance on function
              value/parameter changes. The default value is 1e-9.
          history_size (int, optional): update history size. The default value is 100.
          line_search_fn (string, optional): either 'strong_wolfe' or None. The default value is None.
          parameters (list|tuple, optional): List/Tuple of ``Tensor`` names to update to minimize ``loss``. \
              This parameter is required in dygraph mode. The default value is None.
          weight_decay (float|WeightDecayRegularizer, optional): The strategy of regularization. \
              It can be a float value as coeff of L2 regularization or \
              :ref:`api_paddle_regularizer_L1Decay`, :ref:`api_paddle_regularizer_L2Decay`.
              If a parameter has set regularizer using :ref:`api_paddle_ParamAttr` already, \
              the regularization setting here in optimizer will be ignored for this parameter. \
              Otherwise, the regularization setting here in optimizer will take effect. \
              Default None, meaning there is no regularization.
          grad_clip (GradientClipBase, optional): Gradient clipping strategy, it's an instance of \
              some derived class of ``GradientClipBase`` . There are three clipping strategies \
              ( :ref:`api_paddle_nn_ClipGradByGlobalNorm` , :ref:`api_paddle_nn_ClipGradByNorm` , \
              :ref:`api_paddle_nn_ClipGradByValue` ). Default None, meaning there is no gradient clipping.
          name (str, optional): Normally there is no need for user to set this property.
              For more information, please refer to :ref:`api_guide_Name`.
              The default value is None.

      Return:
          loss (Tensor): the loss returned by the first call of ``closure`` inside
          :meth:`step`, i.e. the loss before any parameter update of that step.

      Examples:
          .. code-block:: python

              >>> import paddle
              >>> import numpy as np

              >>> paddle.disable_static()
              >>> np.random.seed(0)
              >>> np_w = np.random.rand(1).astype(np.float32)
              >>> np_x = np.random.rand(1).astype(np.float32)

              >>> inputs = [np.random.rand(1).astype(np.float32) for i in range(10)]
              >>> # y = 2x
              >>> targets = [2 * x for x in inputs]

              >>> class Net(paddle.nn.Layer):
              ...     def __init__(self):
              ...         super().__init__()
              ...         w = paddle.to_tensor(np_w)
              ...         self.w = paddle.create_parameter(shape=w.shape, dtype=w.dtype, default_initializer=paddle.nn.initializer.Assign(w))
              ...
              ...     def forward(self, x):
              ...         return self.w * x
              ...
              >>> net = Net()
              >>> opt = paddle.optimizer.LBFGS(learning_rate=1, max_iter=1, max_eval=None, tolerance_grad=1e-07, tolerance_change=1e-09, history_size=100, line_search_fn='strong_wolfe', parameters=net.parameters())
              >>> def train_step(inputs, targets):
              ...     def closure():
              ...         outputs = net(inputs)
              ...         loss = paddle.nn.functional.mse_loss(outputs, targets)
              ...         print('loss: ', loss.item())
              ...         opt.clear_grad()
              ...         loss.backward()
              ...         return loss
              ...     opt.step(closure)
              ...
              >>> for input, target in zip(inputs, targets):
              ...     input = paddle.to_tensor(input)
              ...     target = paddle.to_tensor(target)
              ...     train_step(input, target)
      """

      def __init__(
          self,
          learning_rate=1.0,
          max_iter=20,
          max_eval=None,
          tolerance_grad=1e-7,
          tolerance_change=1e-9,
          history_size=100,
          line_search_fn=None,
          parameters=None,
          weight_decay=None,
          grad_clip=None,
          name=None,
      ):
          if max_eval is None:
              max_eval = max_iter * 5 // 4

          self.learning_rate = learning_rate
          self.max_iter = max_iter
          self.max_eval = max_eval
          self.tolerance_grad = tolerance_grad
          self.tolerance_change = tolerance_change
          self.history_size = history_size
          self.line_search_fn = line_search_fn

          if isinstance(parameters, paddle.Tensor):
              # type(...) must be converted to text before concatenation;
              # otherwise building the message itself raises a different TypeError.
              raise TypeError(
                  "parameters argument given to the optimizer should be "
                  "an iterable of Tensors or dicts, but got "
                  + type(parameters).__name__
              )

          self.state = defaultdict(dict)

          # The base optimizer's learning rate is fixed to 1.0; the actual
          # step size is driven by self.learning_rate inside step().
          super().__init__(
              learning_rate=1.0,
              parameters=parameters,
              weight_decay=weight_decay,
              grad_clip=grad_clip,
              name=name,
          )

          if not isinstance(self._parameter_list[0], dict):
              self._params = self._parameter_list
          else:
              # NOTE(review): this loop overwrites self._params on every
              # iteration, so only the LAST parameter group is optimized.
              # Confirm this is intended.
              for idx, param_group in enumerate(self._param_groups):
                  self._params = param_group['params']

          # Lazily computed total number of parameter elements (see _numel).
          self._numel_cache = None

      def state_dict(self):
          r"""Returns the state of the optimizer as a :class:`dict`.

          Return:
              state, a dict of the form ``{'state': {...}}`` holding a shallow
              copy of the optimizer's ``self.state`` entries.

          Examples:
              .. code-block:: python

                  >>> import paddle

                  >>> paddle.disable_static()

                  >>> net = paddle.nn.Linear(10, 10)
                  >>> opt = paddle.optimizer.LBFGS(
                  ...     learning_rate=1,
                  ...     max_iter=1,
                  ...     max_eval=None,
                  ...     tolerance_grad=1e-07,
                  ...     tolerance_change=1e-09,
                  ...     history_size=100,
                  ...     line_search_fn='strong_wolfe',
                  ...     parameters=net.parameters(),
                  ... )

                  >>> def train_step(inputs, targets):
                  ...     def closure():
                  ...         outputs = net(inputs)
                  ...         loss = paddle.nn.functional.mse_loss(outputs, targets)
                  ...         opt.clear_grad()
                  ...         loss.backward()
                  ...         return loss
                  ...
                  ...     opt.step(closure)
                  ...
                  >>> inputs = paddle.rand([10, 10], dtype="float32")
                  >>> targets = paddle.to_tensor([2 * x for x in inputs])

                  >>> n_iter = 0
                  >>> while n_iter < 20:
                  ...     loss = train_step(inputs, targets)
                  ...     n_iter = opt.state_dict()["state"]["func_evals"]
                  ...     print("n_iter:", n_iter)
                  ...
          """
          packed_state = {}
          for k, v in self.state.items():
              packed_state.update({k: v})

          return {'state': packed_state}

      def _numel(self):
          # compute the number of all parameters (cached after the first call)
          if self._numel_cache is None:
              self._numel_cache = reduce(
                  lambda total, p: total + p.numel(), self._params, 0
              )
          return self._numel_cache

      # flatten grad of all parameters
      def _gather_flat_grad(self):
          # Concatenate every parameter's gradient into one 1-D tensor;
          # parameters without a gradient contribute zeros of their size.
          views = []
          for p in self._params:
              if p.grad is None:
                  view = paddle.zeros_like(p).reshape([-1])
              else:
                  view = p.grad.reshape([-1])
              views.append(view)
          return paddle.concat(views, axis=0)

      # compute xk = xk + alpha * direction
      def _add_grad(self, alpha, direction):
          # In-place update of each parameter with its slice of the flat
          # ``direction`` vector, scaled by ``alpha``.
          offset = 0
          for p in self._params:
              numel = reduce(lambda x, y: x * y, p.shape) if p.shape != [] else 1
              p = paddle.assign(
                  p.add(
                      direction[offset : offset + numel].reshape(p.shape) * alpha
                  ),
                  p,
              )
              offset += numel
          assert offset == self._numel()

      def _clone_param(self):
          # Snapshot of the current parameter values.
          return [p.clone() for p in self._params]

      def _set_param(self, params_data):
          # Copy the given values back into the parameters in place.
          for p, pdata in zip(self._params, params_data):
              paddle.assign(pdata, p)

      def _directional_evaluate(self, closure, x, alpha, d):
          # Step the parameters by ``alpha * d``, evaluate ``closure`` there,
          # collect the flat gradient, then reset the parameters to ``x``.
          self._add_grad(alpha, d)
          loss = float(closure())
          flat_grad = self._gather_flat_grad()
          self._set_param(x)
          return loss, flat_grad

      @framework.non_static_only
      def step(self, closure):
          """Performs a single optimization step.

          Args:
              closure (callable): A closure that reevaluates the model
                  and returns the loss.

          Returns:
              The loss returned by the first call of ``closure`` in this step
              (before any parameter update).

          Examples:
              .. code-block:: python

                  >>> import paddle

                  >>> paddle.disable_static()

                  >>> inputs = paddle.rand([10, 10], dtype="float32")
                  >>> targets = paddle.to_tensor([2 * x for x in inputs])

                  >>> net = paddle.nn.Linear(10, 10)
                  >>> opt = paddle.optimizer.LBFGS(
                  ...     learning_rate=1,
                  ...     max_iter=1,
                  ...     max_eval=None,
                  ...     tolerance_grad=1e-07,
                  ...     tolerance_change=1e-09,
                  ...     history_size=100,
                  ...     line_search_fn='strong_wolfe',
                  ...     parameters=net.parameters(),
                  ... )

                  >>> def closure():
                  ...     outputs = net(inputs)
                  ...     loss = paddle.nn.functional.mse_loss(outputs, targets)
                  ...     print("loss:", loss.item())
                  ...     opt.clear_grad()
                  ...     loss.backward()
                  ...     return loss
                  ...
                  >>> opt.step(closure)
          """
          with paddle.no_grad():
              # Make sure the closure is always called with grad enabled
              closure = paddle.enable_grad()(closure)

              learning_rate = self.learning_rate
              max_iter = self.max_iter
              max_eval = self.max_eval
              tolerance_grad = self.tolerance_grad
              tolerance_change = self.tolerance_change
              line_search_fn = self.line_search_fn
              history_size = self.history_size
              state = self.state
              state.setdefault('func_evals', 0)
              state.setdefault('n_iter', 0)

              # evaluate initial f(x) and df/dx
              orig_loss = closure()
              loss = float(orig_loss)

              current_evals = 1
              state['func_evals'] += 1

              flat_grad = self._gather_flat_grad()
              opt_cond = flat_grad.abs().max() <= tolerance_grad

              # optimal condition
              if opt_cond:
                  return orig_loss

              # tensors cached in state (for tracing)
              d = state.get('d')
              alpha = state.get('alpha')
              old_yk = state.get('old_yk')
              old_sk = state.get('old_sk')
              ro = state.get('ro')
              H_diag = state.get('H_diag')
              prev_flat_grad = state.get('prev_flat_grad')
              prev_loss = state.get('prev_loss')

              n_iter = 0
              # optimize for a max of max_iter iterations
              while n_iter < max_iter:
                  # keep track of nb of iterations
                  n_iter += 1
                  state['n_iter'] += 1

                  ############################################################
                  # compute gradient descent direction
                  ############################################################
                  if state['n_iter'] == 1:
                      d = flat_grad.neg()
                      old_yk = []
                      old_sk = []
                      ro = []
                      H_diag = paddle.to_tensor(1.0, dtype=orig_loss.dtype)
                  else:
                      # do lbfgs update (update memory)
                      y = flat_grad.subtract(prev_flat_grad)
                      s = d.multiply(paddle.to_tensor(alpha, dtype=d.dtype))
                      ys = dot(y, s)
                      # only accept the pair when the curvature condition holds
                      if ys > 1e-10:
                          # updating memory
                          if len(old_yk) == history_size:
                              # shift history by one (limited-memory)
                              old_yk.pop(0)
                              old_sk.pop(0)
                              ro.pop(0)

                          # store new direction/step
                          old_yk.append(y)
                          old_sk.append(s)
                          ro.append(1.0 / ys)

                          # update scale of initial Hessian approximation
                          H_diag = ys / dot(y, y)  # (y*y)

                      # compute the approximate (L-BFGS) inverse Hessian
                      # multiplied by the gradient
                      num_old = len(old_yk)

                      if 'al' not in state:
                          state['al'] = [None] * history_size
                      al = state['al']

                      # iteration in L-BFGS loop collapsed to use just one buffer
                      q = flat_grad.neg()
                      for i in range(num_old - 1, -1, -1):
                          al[i] = dot(old_sk[i], q) * ro[i]
                          paddle.assign(q.add(old_yk[i] * (-al[i])), q)

                      # multiply by initial Hessian
                      # r/d is the final direction
                      d = r = paddle.multiply(q, H_diag)
                      for i in range(num_old):
                          be_i = dot(old_yk[i], r) * ro[i]
                          paddle.assign(r.add(old_sk[i] * (al[i] - be_i)), r)

                  if prev_flat_grad is None:
                      prev_flat_grad = flat_grad.clone()
                  else:
                      paddle.assign(flat_grad, prev_flat_grad)

                  prev_loss = loss

                  ############################################################
                  # compute step length
                  ############################################################
                  # reset initial guess for step size
                  if state['n_iter'] == 1:
                      alpha = (
                          min(1.0, 1.0 / flat_grad.abs().sum()) * learning_rate
                      )
                  else:
                      alpha = learning_rate

                  # directional derivative
                  gtd = dot(flat_grad, d)

                  # directional derivative is below tolerance
                  if gtd > -tolerance_change:
                      break

                  # optional line search: user function
                  ls_func_evals = 0
                  if line_search_fn is not None:
                      # perform line search, using user function
                      if line_search_fn != "strong_wolfe":
                          raise RuntimeError("only 'strong_wolfe' is supported")
                      else:
                          x_init = self._clone_param()

                          def obj_func(x, alpha, d):
                              return self._directional_evaluate(
                                  closure, x, alpha, d
                              )

                          loss, flat_grad, alpha, ls_func_evals = _strong_wolfe(
                              obj_func, x_init, alpha, d, loss, flat_grad, gtd
                          )
                      self._add_grad(alpha, d)
                      opt_cond = flat_grad.abs().max() <= tolerance_grad
                  else:
                      # no line search, simply move with fixed-step
                      self._add_grad(alpha, d)
                      # re-evaluate only when another iteration may follow
                      if n_iter != max_iter:
                          with paddle.enable_grad():
                              loss = float(closure())
                          flat_grad = self._gather_flat_grad()
                          opt_cond = flat_grad.abs().max() <= tolerance_grad
                          ls_func_evals = 1

                  # update func eval
                  current_evals += ls_func_evals
                  state['func_evals'] += ls_func_evals

                  # optimal condition
                  if opt_cond:
                      break

                  # lack of progress
                  if (d * alpha).abs().max() <= tolerance_change:
                      break

                  if abs(loss - prev_loss) < tolerance_change:
                      break

                  # check conditions
                  if current_evals >= max_eval:
                      break

                  if n_iter == max_iter:
                      break

              # persist the search state so the next step() call resumes it
              state['d'] = d
              state['alpha'] = alpha
              state['old_yk'] = old_yk
              state['old_sk'] = old_sk
              state['ro'] = ro
              state['H_diag'] = H_diag
              state['prev_flat_grad'] = prev_flat_grad
              state['prev_loss'] = prev_loss

          return orig_loss

      def minimize(
          self, loss, startup_program=None, parameters=None, no_grad_set=None
      ):
          """Empty method. LBFGS optimizer does not use this way to minimize ``loss``. Please refer 'Examples' of LBFGS() above for usage."""
          raise NotImplementedError(
              "LBFGS optimizer does not use this way to minimize loss. Please refer 'Examples' of LBFGS() for usage."
          )