00001 #include <TooN/optimization/brent.h>
00002 #include <utility>
00003 #include <cmath>
00004 #include <cassert>
00005 #include <cstdlib>
00006
00007 namespace TooN{
00008 namespace Internal{
00009
00010
00011
00012
00013
00014
00015
00016
00017 template<int Size, typename Precision, typename Func> struct LineSearch
00018 {
00019 const Vector<Size, Precision>& start;
00020 const Vector<Size, Precision>& direction;
00021
00022 const Func& f;
00023
00024
00025
00026
00027
00028 LineSearch(const Vector<Size, Precision>& s, const Vector<Size, Precision>& d, const Func& func)
00029 :start(s),direction(d),f(func)
00030 {}
00031
00032
00033
00034 Precision operator()(Precision x) const
00035 {
00036 return f(start + x * direction);
00037 }
00038 };
00039
00040
00041
00042
00043
00044
00045
00046
00047
00048
00049
00050
00051 template<typename Precision, typename Func> Matrix<3,2,Precision> bracket_minimum_forward(Precision a_val, const Func& func, Precision initial_lambda, Precision zeps)
00052 {
00053
00054 Precision a, b, c, b_val, c_val;
00055
00056 a=0;
00057
00058
00059 Precision lambda=initial_lambda;
00060 b = lambda;
00061 b_val = func(b);
00062
00063 while(std::isnan(b_val))
00064 {
00065
00066
00067
00068 lambda*=.5;
00069 b = lambda;
00070 b_val = func(b);
00071
00072 }
00073
00074
00075 if(b_val < a_val)
00076 {
00077 double last_good_lambda = lambda;
00078
00079 for(;;)
00080 {
00081 lambda *= 2;
00082 c = lambda;
00083 c_val = func(c);
00084
00085 if(std::isnan(c_val))
00086 break;
00087 last_good_lambda = lambda;
00088 if(c_val > b_val)
00089 break;
00090 else
00091 {
00092 a = b;
00093 a_val = b_val;
00094 b=c;
00095 b_val=c_val;
00096
00097 }
00098 }
00099
00100
00101
00102 if(std::isnan(c_val))
00103 {
00104 double bad_lambda=lambda;
00105 double l=1;
00106
00107 for(;;)
00108 {
00109 l*=.5;
00110 c = last_good_lambda + (bad_lambda - last_good_lambda)*l;
00111 c_val = func(c);
00112
00113 if(!std::isnan(c_val))
00114 break;
00115 }
00116
00117
00118 }
00119
00120 }
00121 else
00122 {
00123 c = b;
00124 c_val = b_val;
00125
00126
00127 for(;;)
00128 {
00129 lambda *= .5;
00130 b = lambda;
00131 b_val = func(b);
00132
00133 if(b_val < a_val)
00134 break;
00135 else if(lambda < zeps)
00136 return Zeros;
00137 else
00138 {
00139 c = b;
00140 c_val = b_val;
00141 }
00142 }
00143 }
00144
00145 Matrix<3,2> ret;
00146 ret[0] = makeVector(a, a_val);
00147 ret[1] = makeVector(b, b_val);
00148 ret[2] = makeVector(c, c_val);
00149
00150 return ret;
00151 }
00152
00153 }
00154
00155
00156
00157
00158
00159
00160
00161
00162
00163
00164
00165
00166
00167
00168
00169
00170
00171
00172
00173
00174
00175
00176
00177
00178
00179
00180
00181
00182
00183
00184
00185
00186
00187
00188
00189
00190
00191
00192
00193
00194
00195
00196
00197
00198
00199
00200 template<int Size, class Precision=double> struct ConjugateGradient
00201 {
00202 const int size;
00203 Vector<Size> g;
00204 Vector<Size> h;
00205 Vector<Size> minus_h;
00206 Vector<Size> old_g;
00207 Vector<Size> old_h;
00208 Vector<Size> x;
00209 Vector<Size> old_x;
00210 Precision y;
00211 Precision old_y;
00212
00213 Precision tolerance;
00214 Precision epsilon;
00215 int max_iterations;
00216
00217 Precision bracket_initial_lambda;
00218 Precision linesearch_tolerance;
00219 Precision linesearch_epsilon;
00220 int linesearch_max_iterations;
00221
00222 Precision bracket_epsilon;
00223
00224 int iterations;
00225
00226
00227
00228
00229
00230 template<class Func, class Deriv> ConjugateGradient(const Vector<Size>& start, const Func& func, const Deriv& deriv)
00231 : size(start.size()),
00232 g(size),h(size),minus_h(size),old_g(size),old_h(size),x(start),old_x(size)
00233 {
00234 init(start, func(start), deriv(start));
00235 }
00236
00237
00238
00239
00240
00241 template<class Func> ConjugateGradient(const Vector<Size>& start, const Func& func, const Vector<Size>& deriv)
00242 : size(start.size()),
00243 g(size),h(size),minus_h(size),old_g(size),old_h(size),x(start),old_x(size)
00244 {
00245 init(start, func(start), deriv);
00246 }
00247
00248
00249
00250
00251
00252 void init(const Vector<Size>& start, const Precision& func, const Vector<Size>& deriv)
00253 {
00254
00255 using std::numeric_limits;
00256 x = start;
00257
00258
00259
00260 g = deriv;
00261 h = g;
00262 minus_h=-h;
00263
00264 y = func;
00265 old_y = y;
00266
00267 tolerance = sqrt(numeric_limits<Precision>::epsilon());
00268 epsilon = 1e-20;
00269 max_iterations = size * 100;
00270
00271 bracket_initial_lambda = 1;
00272
00273 linesearch_tolerance = sqrt(numeric_limits<Precision>::epsilon());
00274 linesearch_epsilon = 1e-20;
00275 linesearch_max_iterations=100;
00276
00277 bracket_epsilon=1e-20;
00278
00279 iterations=0;
00280 }
00281
00282
00283
00284
00285
00286
00287
00288
00289
00290
00291
00292
00293
00294
00295
00296 template<class Func> void find_next_point(const Func& func)
00297 {
00298 Internal::LineSearch<Size, Precision, Func> line(x, minus_h, func);
00299
00300
00301
00302 Matrix<3,2,Precision> bracket = Internal::bracket_minimum_forward(y, line, bracket_initial_lambda, bracket_epsilon);
00303
00304 double a = bracket[0][0];
00305 double b = bracket[1][0];
00306 double c = bracket[2][0];
00307
00308 double a_val = bracket[0][1];
00309 double b_val = bracket[1][1];
00310 double c_val = bracket[2][1];
00311
00312 old_y = y;
00313 old_x = x;
00314 iterations++;
00315
00316
00317 if(a==0 && b== 0 && c == 0)
00318 return;
00319
00320
00321
00322 if(c < b)
00323 {
00324
00325
00326 x-=h * c;
00327 y=c_val;
00328
00329 }
00330 else
00331 {
00332 assert(a < b && b < c);
00333 assert(a_val > b_val && b_val < c_val);
00334
00335
00336 Vector<2, Precision> m = brent_line_search(a, b, c, b_val, line, linesearch_max_iterations, linesearch_tolerance, linesearch_epsilon);
00337
00338 assert(m[0] >= a && m[0] <= c);
00339 assert(m[1] <= b_val);
00340
00341
00342 x -= m[0] * h;
00343 y = m[1];
00344 }
00345 }
00346
00347
00348
00349 bool finished()
00350 {
00351 using std::abs;
00352 return iterations > max_iterations || 2*abs(y - old_y) <= tolerance * (abs(y) + abs(old_y) + epsilon);
00353 }
00354
00355
00356
00357
00358
00359
00360
00361
00362
00363 void update_vectors_PR(const Vector<Size>& grad)
00364 {
00365
00366 old_g = g;
00367 old_h = h;
00368
00369 g = grad;
00370
00371 Precision gamma = (g * g - old_g*g)/(old_g * old_g);
00372 h = g + gamma * old_h;
00373 minus_h=-h;
00374 }
00375
00376
00377
00378
00379
00380
00381
00382
00383
00384
00385
00386
00387
00388
00389
00390
00391
00392
00393 template<class Func, class Deriv> bool iterate(const Func& func, const Deriv& deriv)
00394 {
00395 find_next_point(func);
00396
00397 if(!finished())
00398 {
00399 update_vectors_PR(deriv(x));
00400 return 1;
00401 }
00402 else
00403 return 0;
00404 }
00405 };
00406
00407 }