"""
Line searches: find the minimum of a multivariate function.

Optionally depends on scipy.optimize for some line searches.

Available:

- optimal step (exact minimum if the Criterion is quadratic, i.e. uses
  only Norm2 norms)

- Backtracking: starts with the optimal step and reduces the step until
  the criterion decreases.

If scipy.optimize is in PYTHONPATH:

- LineSearch, LineSearchArmijo, LineSearchWolfe1, LineSearchWolfe2
"""
import numpy as np

import norms
20
def optimal_step(algo):
    """
    Find the quadratic optimal step of a criterion.

    Parameters
    ----------
    algo : Algorithm instance
        Must expose ``current_descent``, ``current_gradient`` and
        ``criterion``.  The ``criterion`` attribute must itself expose
        ``model``, ``priors``, ``hypers`` and ``norms``.

    Returns
    -------
    a : float
        The optimal step.  Exact only when every norm is a Norm2; other
        norms are replaced by Norm2 instances, so the step is then an
        approximation.
    """
    d = algo.current_descent
    g = algo.current_gradient
    H = algo.criterion.model
    Ds = algo.criterion.priors
    hypers = algo.criterion.hypers
    algo_norms = algo.criterion.norms
    # The quadratic formula below is only valid for Norm2 norms: replace
    # any other norm by a default Norm2 instance.
    algo_norms = [n if isinstance(n, norms.Norm2) else norms.Norm2()
                  for n in algo_norms]
    # NOTE(review): ``algo_norms[0]`` is used for the model term, yet
    # ``zip`` pairs the priors with ``algo_norms`` starting at index 0 as
    # well -- confirm whether the priors should use ``algo_norms[1:]``.
    a = -.5 * np.dot(d.T, g)
    a /= (algo_norms[0](H * d)
          + np.sum([h * n(D * d) for h, D, n in zip(hypers, Ds, algo_norms)]))
    return a
52
class Backtracking(object):
    """
    Backtracking line search.

    Starts from the quadratic optimal step (see ``optimal_step``) and
    shrinks it by a factor ``tau`` until the criterion decreases, or
    ``maxiter`` reductions have been tried.
    """
    def __init__(self, maxiter=10, tau=.5):
        # NOTE(review): header reconstructed; defaults assumed from the
        # package's conventions -- confirm against history.
        self.maxiter = maxiter  # maximum number of step reductions
        self.tau = tau          # multiplicative step-reduction factor (0 < tau < 1)

    def __call__(self, algo):
        """
        Return a step length for ``algo``'s current descent direction.
        """
        x = algo.current_solution
        d = algo.current_descent
        a = optimal_step(algo)
        i = 0
        f0 = algo.current_criterion
        # Seed so the loop runs at least once; ``2 * f0 > f0`` only holds
        # when f0 > 0, i.e. the criterion is assumed positive -- TODO
        # confirm this assumption.
        fi = 2 * f0
        while (i < self.maxiter) and (fi > f0):
            i += 1
            # NOTE(review): the step is shrunk *before* the first
            # evaluation, so the unshrunk optimal step is never tried.
            a *= self.tau
            xi = x + a * d
            fi = algo.criterion(xi)
        return a

# Module-level default instance used by solvers.
default_backtracking = Backtracking()
72
73
try:
    from scipy.optimize import linesearch
except ImportError:
    # scipy is optional: the wrappers below are simply not defined
    # when it is absent.
    pass

if 'linesearch' in locals():
    class LineSearch(object):
        """
        Wraps scipy.optimize.linesearch.line_search.
        """
        def __init__(self, args=(), **kwargs):
            # Extra positional/keyword arguments forwarded to the scipy
            # line-search routine.  A tuple default avoids the shared
            # mutable-default pitfall of ``args=[]``.
            self.args = args
            self.kwargs = kwargs
            # State extracted from the algorithm at call time.
            self.f = None
            self.fprime = None
            self.xk = None
            self.pk = None
            self.gfk = None
            self.old_fval = None
            self.old_old_fval = None
            self.step = None

        def get_values(self, algo):
            """Copy the current state of ``algo`` onto the instance."""
            self.f = algo.criterion
            self.fprime = algo.gradient
            self.xk = algo.current_solution
            self.pk = algo.current_descent
            self.gfk = algo.current_gradient
            self.old_fval = algo.current_criterion
            self.old_old_fval = algo.last_criterion

        def __call__(s, algo):
            s.get_values(algo)
            line_search = linesearch.line_search
            out = line_search(s.f, s.fprime, s.xk, s.pk, gfk=s.gfk,
                              old_fval=s.old_fval,
                              old_old_fval=s.old_old_fval,
                              args=s.args, **s.kwargs)
            # First element of the scipy output tuple is the step length.
            s.step = out[0]
            return s.step

    # NOTE(review): default instance reconstructed by symmetry with
    # ``default_backtracking`` -- confirm against history.
    default_linesearch = LineSearch()

    class LineSearchArmijo(LineSearch):
        """
        Wraps scipy.optimize.linesearch.line_search_armijo.
        """
        def __call__(s, algo):
            s.get_values(algo)
            armijo = linesearch.line_search_armijo
            out = armijo(s.f, s.xk, s.pk, s.gfk, s.old_fval, args=s.args,
                         **s.kwargs)
            s.step = out[0]
            return s.step

    class LineSearchWolfe1(LineSearch):
        """
        Wraps scipy.optimize.linesearch.line_search_wolfe1.
        """
        def __call__(s, algo):
            s.get_values(algo)
            wolfe1 = linesearch.line_search_wolfe1
            out = wolfe1(s.f, s.fprime, s.xk, s.pk, s.gfk, s.old_fval,
                         s.old_old_fval, args=s.args, **s.kwargs)
            s.step = out[0]
            return s.step

    class LineSearchWolfe2(LineSearch):
        """
        Wraps scipy.optimize.linesearch.line_search_wolfe2.
        """
        def __call__(s, algo):
            s.get_values(algo)
            wolfe2 = linesearch.line_search_wolfe2
            out = wolfe2(s.f, s.fprime, s.xk, s.pk, s.gfk, s.old_fval,
                         s.old_old_fval, args=s.args, **s.kwargs)
            s.step = out[0]
            return s.step
149