00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030 #ifndef __SVD_H
00031 #define __SVD_H
00032
#include <algorithm>
#include <vector>

#include <TooN/TooN.h>
#include <TooN/lapack.h>
00035
00036 namespace TooN {
00037
00038
/// Default condition number for SVD::backsub, SVD::get_pinv, SVD::rank and
/// SVD::get_inv_diag: a singular value s[i] is treated as zero whenever
/// s[i] * condition_no <= s[0].
static const double condition_no = 1.0e9;
00040
00041
00042
00043
00044
00045
00046
00047
00048
00049
00050
00051
00052
00053
00054
00055
00056
00057
00058
00059
00060
00061
00062
00063
00064
00065
00066
00067
00068
00069
00070
00071
00072
00073
00074
00075
00076
00077
00078
00079
00080
00081
00082
00083
00084
00085
00086
00087 template<int Rows=Dynamic, int Cols=Rows, typename Precision=DefaultPrecision>
00088 class SVD {
00089
00090
00091 static const int Min_Dim = Rows<Cols?Rows:Cols;
00092
00093 public:
00094
00095
00096 SVD() {}
00097
00098
00099 SVD(int rows, int cols)
00100 : my_copy(rows,cols),
00101 my_diagonal(std::min(rows,cols)),
00102 my_square(std::min(rows,cols), std::min(rows,cols))
00103 {}
00104
00105
00106
00107 template <int R2, int C2, typename P2, typename B2>
00108 SVD(const Matrix<R2,C2,P2,B2>& m)
00109 : my_copy(m),
00110 my_diagonal(std::min(m.num_rows(),m.num_cols())),
00111 my_square(std::min(m.num_rows(),m.num_cols()),std::min(m.num_rows(),m.num_cols()))
00112 {
00113 do_compute();
00114 }
00115
00116
00117 template <int R2, int C2, typename P2, typename B2>
00118 void compute(const Matrix<R2,C2,P2,B2>& m){
00119 my_copy=m;
00120 do_compute();
00121 }
00122
00123 private:
00124 void do_compute(){
00125 Precision* const a = my_copy.my_data;
00126 int lda = my_copy.num_cols();
00127 int m = my_copy.num_cols();
00128 int n = my_copy.num_rows();
00129 Precision* const uorvt = my_square.my_data;
00130 Precision* const s = my_diagonal.my_data;
00131 int ldu;
00132 int ldvt = lda;
00133 int LWORK;
00134 int INFO;
00135 char JOBU;
00136 char JOBVT;
00137
00138 if(is_vertical()){
00139 JOBU='O';
00140 JOBVT='S';
00141 ldu = lda;
00142 } else {
00143 JOBU='S';
00144 JOBVT='O';
00145 ldu = my_square.num_cols();
00146 }
00147
00148 Precision* wk;
00149
00150 Precision size;
00151 LWORK = -1;
00152
00153
00154
00155 dgesvd_( &JOBVT, &JOBU, &m, &n, a, &lda, s, uorvt,
00156 &ldvt, uorvt, &ldu, &size, &LWORK, &INFO);
00157
00158 LWORK = (long int)(size);
00159 wk = new Precision[LWORK];
00160
00161 dgesvd_( &JOBVT, &JOBU, &m, &n, a, &lda, s, uorvt,
00162 &ldvt, uorvt, &ldu, wk, &LWORK, &INFO);
00163
00164 delete[] wk;
00165 }
00166
00167 bool is_vertical(){
00168 return (my_copy.num_rows() >= my_copy.num_cols());
00169 }
00170
00171 int min_dim(){ return std::min(my_copy.num_rows(), my_copy.num_cols()); }
00172
00173 public:
00174
00175
00176
00177
00178
00179 template <int Rows2, int Cols2, typename P2, typename B2>
00180 Matrix<Cols,Cols2, typename Internal::MultiplyType<Precision,P2>::type >
00181 backsub(const Matrix<Rows2,Cols2,P2,B2>& rhs, const Precision condition=condition_no)
00182 {
00183 Vector<Min_Dim> inv_diag(min_dim());
00184 get_inv_diag(inv_diag,condition);
00185 return (get_VT().T() * diagmult(inv_diag, (get_U().T() * rhs)));
00186 }
00187
00188
00189
00190
00191
00192 template <int Size, typename P2, typename B2>
00193 Vector<Cols, typename Internal::MultiplyType<Precision,P2>::type >
00194 backsub(const Vector<Size,P2,B2>& rhs, const Precision condition=condition_no)
00195 {
00196 Vector<Min_Dim> inv_diag(min_dim());
00197 get_inv_diag(inv_diag,condition);
00198 return (get_VT().T() * diagmult(inv_diag, (get_U().T() * rhs)));
00199 }
00200
00201
00202
00203
00204
00205 Matrix<Cols,Rows> get_pinv(const Precision condition = condition_no){
00206 Vector<Min_Dim> inv_diag(min_dim());
00207 get_inv_diag(inv_diag,condition);
00208 return diagmult(get_VT().T(),inv_diag) * get_U().T();
00209 }
00210
00211
00212
00213 Precision determinant() {
00214 Precision result = my_diagonal[0];
00215 for(int i=1; i<my_diagonal.size(); i++){
00216 result *= my_diagonal[i];
00217 }
00218 return result;
00219 }
00220
00221
00222
00223 int rank(const Precision condition = condition_no) {
00224 if (my_diagonal[0] == 0) return 0;
00225 int result=1;
00226 for(int i=0; i<min_dim(); i++){
00227 if(my_diagonal[i] * condition <= my_diagonal[0]){
00228 result++;
00229 }
00230 }
00231 return result;
00232 }
00233
00234
00235
00236
00237 Matrix<Rows,Min_Dim,Precision,Reference::RowMajor> get_U(){
00238 if(is_vertical()){
00239 return Matrix<Rows,Min_Dim,Precision,Reference::RowMajor>
00240 (my_copy.my_data,my_copy.num_rows(),my_copy.num_cols());
00241 } else {
00242 return Matrix<Rows,Min_Dim,Precision,Reference::RowMajor>
00243 (my_square.my_data, my_square.num_rows(), my_square.num_cols());
00244 }
00245 }
00246
00247
00248 Vector<Min_Dim,Precision>& get_diagonal(){ return my_diagonal; }
00249
00250
00251
00252
00253 Matrix<Min_Dim,Cols,Precision,Reference::RowMajor> get_VT(){
00254 if(is_vertical()){
00255 return Matrix<Min_Dim,Cols,Precision,Reference::RowMajor>
00256 (my_square.my_data, my_square.num_rows(), my_square.num_cols());
00257 } else {
00258 return Matrix<Min_Dim,Cols,Precision,Reference::RowMajor>
00259 (my_copy.my_data,my_copy.num_rows(),my_copy.num_cols());
00260 }
00261 }
00262
00263
00264
00265
00266
00267
00268 void get_inv_diag(Vector<Min_Dim>& inv_diag, const Precision condition){
00269 for(int i=0; i<min_dim(); i++){
00270 if(my_diagonal[i] * condition <= my_diagonal[0]){
00271 inv_diag[i]=0;
00272 } else {
00273 inv_diag[i]=static_cast<Precision>(1)/my_diagonal[i];
00274 }
00275 }
00276 }
00277
00278 private:
00279 Matrix<Rows,Cols,Precision,RowMajor> my_copy;
00280 Vector<Min_Dim,Precision> my_diagonal;
00281 Matrix<Min_Dim,Min_Dim,Precision,RowMajor> my_square;
00282 };
00283
00284
00285
00286
00287
00288
00289
00290
00291
00292 template<int Size, typename Precision>
00293 struct SQSVD : public SVD<Size, Size, Precision> {
00294
00295
00296
00297 SQSVD() {}
00298 SQSVD(int size) : SVD<Size,Size,Precision>(size, size) {}
00299
00300 template <int R2, int C2, typename P2, typename B2>
00301 SQSVD(const Matrix<R2,C2,P2,B2>& m) : SVD<Size,Size,Precision>(m) {}
00302
00303 };
00304
00305
00306 }
00307
00308
00309 #endif