00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031 #ifndef TOON_INCLUDE_LU_H
00032 #define TOON_INCLUDE_LU_H
00033
00034 #include <iostream>
00035
00036 #include <TooN/lapack.h>
00037
00038 #include <TooN/TooN.h>
00039
00040 namespace TooN {
00066 template <int Size=-1, class Precision=double>
00067 class LU {
00068 public:
00069
00072 template<int S1, int S2, class Base>
00073 LU(const Matrix<S1,S2,Precision, Base>& m)
00074 :my_lu(m.num_rows(),m.num_cols()),my_IPIV(m.num_rows()){
00075 compute(m);
00076 }
00077
00079 template<int S1, int S2, class Base>
00080 void compute(const Matrix<S1,S2,Precision,Base>& m){
00081
00082 SizeMismatch<Size, S1>::test(my_lu.num_rows(),m.num_rows());
00083 SizeMismatch<Size, S2>::test(my_lu.num_rows(),m.num_cols());
00084
00085
00086 my_lu=m;
00087 int lda = m.num_rows();
00088 int M = m.num_rows();
00089 int N = m.num_rows();
00090
00091 getrf_(&M,&N,&my_lu[0][0],&lda,&my_IPIV[0],&my_info);
00092
00093 if(my_info < 0){
00094 std::cerr << "error in LU, INFO was " << my_info << std::endl;
00095 }
00096 }
00097
00100 template <int Rows, int NRHS, class Base>
00101 Matrix<Size,NRHS,Precision> backsub(const Matrix<Rows,NRHS,Precision,Base>& rhs){
00102
00103 SizeMismatch<Size, Rows>::test(my_lu.num_rows(), rhs.num_rows());
00104
00105 Matrix<Size, NRHS, Precision> result(rhs);
00106
00107 int M=rhs.num_cols();
00108 int N=my_lu.num_rows();
00109 double alpha=1;
00110 int lda=my_lu.num_rows();
00111 int ldb=rhs.num_cols();
00112 trsm_("R","U","N","N",&M,&N,&alpha,&my_lu[0][0],&lda,&result[0][0],&ldb);
00113 trsm_("R","L","N","U",&M,&N,&alpha,&my_lu[0][0],&lda,&result[0][0],&ldb);
00114
00115
00116 for(int i=N-1; i>=0; i--){
00117 const int swaprow = my_IPIV[i]-1;
00118 for(int j=0; j<NRHS; j++){
00119 Precision temp = result[i][j];
00120 result[i][j] = result[swaprow][j];
00121 result[swaprow][j] = temp;
00122 }
00123 }
00124 return result;
00125 }
00126
00129 template <int Rows, class Base>
00130 Vector<Size,Precision> backsub(const Vector<Rows,Precision,Base>& rhs){
00131
00132 SizeMismatch<Size, Rows>::test(my_lu.num_rows(), rhs.size());
00133
00134 Vector<Size, Precision> result(rhs);
00135
00136 int M=1;
00137 int N=my_lu.num_rows();
00138 double alpha=1;
00139 int lda=my_lu.num_rows();
00140 int ldb=1;
00141 trsm_("R","U","N","N",&M,&N,&alpha,&my_lu[0][0],&lda,&result[0],&ldb);
00142 trsm_("R","L","N","U",&M,&N,&alpha,&my_lu[0][0],&lda,&result[0],&ldb);
00143
00144
00145 for(int i=N-1; i>=0; i--){
00146 const int swaprow = my_IPIV[i]-1;
00147 Precision temp = result[i];
00148 result[i] = result[swaprow];
00149 result[swaprow] = temp;
00150 }
00151 return result;
00152 }
00153
00156 Matrix<Size,Size,Precision> get_inverse(){
00157 Matrix<Size,Size,Precision> Inverse(my_lu);
00158 int N = my_lu.num_rows();
00159 int lda=my_lu.num_rows();
00160 int lwork=-1;
00161 Precision size;
00162 getri_(&N, &Inverse[0][0], &lda, &my_IPIV[0], &size, &lwork, &my_info);
00163 lwork=int(size);
00164 Precision* WORK = new Precision[lwork];
00165 getri_(&N, &Inverse[0][0], &lda, &my_IPIV[0], WORK, &lwork, &my_info);
00166 delete [] WORK;
00167 return Inverse;
00168 }
00169
00175 const Matrix<Size,Size,Precision>& get_lu()const {return my_lu;}
00176
00177 inline int get_sign() const {
00178 int result=1;
00179 for(int i=0; i<my_lu.num_rows()-1; i++){
00180 if(my_IPIV[i] > i+1){
00181 result=-result;
00182 }
00183 }
00184 return result;
00185 }
00186
00188 inline Precision determinant() const {
00189 Precision result = get_sign();
00190 for (int i=0; i<my_lu.num_rows(); i++){
00191 result*=my_lu(i,i);
00192 }
00193 return result;
00194 }
00195
00197 int get_info() const { return my_info; }
00198
00199 private:
00200
00201 Matrix<Size,Size,Precision> my_lu;
00202 int my_info;
00203 Vector<Size, int> my_IPIV;
00204
00205 };
00206 }
00207
00208
00209 #endif