00001 #ifndef TOON_DOWNHILL_SIMPLEX_H
00002 #define TOON_DOWNHILL_SIMPLEX_H
00003 #include <TooN/TooN.h>
00004 #include <TooN/helpers.h>
00005 #include <algorithm>
00006 #include <cstdlib>
00007
00008 namespace TooN
00009 {
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032
00033
00034
00035
00036
00037
00038
00039
00040
00041
00042
00043
00044
00045
00046
00047
00048
00049
00050
00051
00052
00053
00054
00055
00056
00057
00058
00059
00060
00061
00062
00063
00064
00065
00066
00067
00068
00069
00070
00071
00072
00073
00074
00075
00076
00077 template<int N=-1, typename Precision=double> class DownhillSimplex
00078 {
00079 static const int Vertices = (N==-1?-1:N+1);
00080 typedef Matrix<Vertices, N, Precision> Simplex;
00081 typedef Vector<Vertices, Precision> Values;
00082
00083 public:
00084
00085
00086
00087
00088
00089
00090
00091
00092 template<class Function> DownhillSimplex(const Function& func, const Vector<N>& c, Precision spread=1)
00093 :simplex(c.size()+1, c.size()),values(c.size()+1)
00094 {
00095 alpha = 1.0;
00096 rho = 2.0;
00097 gamma = 0.5;
00098 sigma = 0.5;
00099
00100 epsilon = sqrt(numeric_limits<Precision>::epsilon());
00101 zero_epsilon = 1e-20;
00102
00103 restart(func, c, spread);
00104 }
00105
00106
00107
00108
00109
00110
00111
00112 template<class Function> void restart(const Function& func, const Vector<N>& c, Precision spread)
00113 {
00114 for(int i=0; i < simplex.num_rows(); i++)
00115 simplex[i] = c;
00116
00117 for(int i=0; i < simplex.num_cols(); i++)
00118 simplex[i][i] += spread;
00119
00120 for(int i=0; i < values.size(); i++)
00121 values[i] = func(simplex[i]);
00122 }
00123
00124
00125
00126
00127
00128
00129 bool finished()
00130 {
00131 Precision span = norm(simplex[get_best()] - simplex[get_worst()]);
00132 Precision scale = norm(simplex[get_best()]);
00133
00134 if(span/scale < epsilon || span < zero_epsilon)
00135 return 1;
00136 else
00137 return 0;
00138 }
00139
00140
00141
00142
00143
00144 template<class Function> void restart(const Function& func, Precision spread)
00145 {
00146 restart(func, simplex[get_best()], spread);
00147 }
00148
00149
00150 const Simplex& get_simplex() const
00151 {
00152 return simplex;
00153 }
00154
00155
00156 const Values& get_values() const
00157 {
00158 return values;
00159 }
00160
00161
00162 int get_best() const
00163 {
00164 return std::min_element(&values[0], &values[0] + values.size()) - &values[0];
00165 }
00166
00167
00168 int get_worst() const
00169 {
00170 return std::max_element(&values[0], &values[0] + values.size()) - &values[0];
00171 }
00172
00173
00174
00175 template<class Function> void find_next_point(const Function& func)
00176 {
00177
00178
00179
00180
00181
00182 int worst = get_worst();
00183 Precision second_worst_val=-HUGE_VAL, bestval = HUGE_VAL, worst_val = values[worst];
00184 int best=0;
00185 Vector<N> x0 = Zeros(simplex.num_cols());
00186
00187
00188 for(int i=0; i < simplex.num_rows(); i++)
00189 {
00190 if(values[i] < bestval)
00191 {
00192 bestval = values[i];
00193 best = i;
00194 }
00195
00196 if(i != worst)
00197 {
00198 if(values[i] > second_worst_val)
00199 second_worst_val = values[i];
00200
00201
00202 x0 += simplex[i];
00203 }
00204 }
00205 x0 *= 1.0 / simplex.num_cols();
00206
00207
00208
00209 Vector<N> xr = (1 + alpha) * x0 - alpha * simplex[worst];
00210 Precision fr = func(xr);
00211
00212 if(fr < bestval)
00213 {
00214
00215 Vector<N> xe = rho * xr + (1-rho) * x0;
00216 Precision fe = func(xe);
00217
00218
00219 if(fe < fr)
00220 {
00221 simplex[worst] = xe;
00222 values[worst] = fe;
00223 }
00224 else
00225 {
00226 simplex[worst] = xr;
00227 values[worst] = fr;
00228 }
00229
00230 return;
00231 }
00232
00233
00234
00235 if(fr < second_worst_val)
00236 {
00237 simplex[worst] = xr;
00238 values[worst] = fr;
00239 return;
00240 }
00241
00242
00243
00244
00245
00246 if(fr < worst_val)
00247 {
00248 Vector<N> xc = (1 + gamma) * x0 - gamma * simplex[worst];
00249 Precision fc = func(xc);
00250
00251
00252 if(fc <= fr)
00253 {
00254 simplex[worst] = xc;
00255 values[worst] = fc;
00256 return;
00257 }
00258 }
00259
00260
00261
00262 for(int i=0; i < simplex.num_rows(); i++)
00263 if(i != best)
00264 {
00265 simplex[i] = simplex[best] + sigma * (simplex[i] - simplex[best]);
00266 values[i] = func(simplex[i]);
00267 }
00268 }
00269
00270
00271
00272
00273 template<class Function> bool iterate(const Function& func)
00274 {
00275 find_next_point(func);
00276 return !finished();
00277 }
00278
00279 Precision alpha;
00280 Precision rho;
00281 Precision gamma;
00282 Precision sigma;
00283 Precision epsilon;
00284 Precision zero_epsilon;
00285
00286 private:
00287
00288
00289 Simplex simplex;
00290
00291
00292 Values values;
00293
00294
00295 };
00296 }
00297 #endif