00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019 #ifndef _NNLS_H
00020 #define _NNLS_H
00021
00022 #include <lsp/least_squares.h>
00023 #include <lsp/utils.h>
00024
00025 #include <boost/numeric/ublas/vector.hpp>
00026 #include <boost/numeric/ublas/matrix.hpp>
00027
00028 #include <algorithm>
00029 #include <list>
00030 #include <limits>
00031
00032 using namespace boost::numeric::ublas;
00033
00034 namespace lsp {
00035
00050 template<class M, class V> class nnls {
00051 public:
00052 typedef M matrix_type;
00053 typedef V vector_type;
00054 typedef typename matrix_type::value_type value_type;
00055 typedef typename matrix_type::size_type size_type;
00056 private:
00057 const matrix_type& m_matrix;
00058 const vector_type& m_vector;
00059 public:
00070 nnls( const matrix_type& matrix, const vector_type& vector ):
00071 m_matrix( matrix ),
00072 m_vector( vector ) {
00073
00074 assert( vector.size() == matrix.size1() );
00075
00076 }
00077
00088 template<class sV, class sM> void solve( sV& ret, sM& cov ) const {
00089 typedef std::list< size_type > index_space_type;
00090 typedef vector< value_type > vector_type;
00091 typedef least_squares< matrix_type, vector_type > least_squares_type;
00092
00093 value_type lim;
00094 vector_type w,z;
00095 index_space_type positive,zero;
00096
00097 for( size_type i = 0; i < m_matrix.size2(); ++i ) zero.push_back( i );
00098 ret = zero_vector< value_type >( m_matrix.size2() );
00099 w = prod( trans( m_matrix ), m_vector - prod( m_matrix, ret ) );
00100 lim = std::numeric_limits< value_type >::epsilon() * ( norm_2(m_vector) * ( 2*m_matrix.size1()*m_matrix.size2() - m_matrix.size2() )+ norm_2(ret) * ( 4*m_matrix.size1()*m_matrix.size2() - m_matrix.size1() - m_matrix.size2() ) );
00101 for( typename vector_type::iterator it = w.begin(); it != w.end(); ++it )
00102 if( std::abs(*it) < lim ) *it=0;
00103
00104 while( ! is_vector_elem< vector_type, index_space_type, std::less_equal<value_type> >( w, zero ) ){
00105 size_type max_w = *(std::max_element( zero.begin(), zero.end(), vector_less< vector_type, std::less< typename vector_type::value_type > >( w ) ));
00106 swap_indexes(zero,positive,max_w);
00107
00108 do {
00109 vector_type f = m_vector;
00110 matrix< value_type > Ep( m_matrix.size1(), m_matrix.size2() );
00111 least_squares_type least_squares(Ep,f);
00112 for( typename index_space_type::const_iterator it = positive.begin();it != positive.end(); ++it )
00113 column(Ep, (*it)) = column(m_matrix, (*it));
00114 for( typename index_space_type::const_iterator it = zero.begin();it != zero.end(); ++it )
00115 column(Ep, (*it)) = zero_vector< value_type >( m_matrix.size1() );
00116
00117 least_squares.solve( z, cov );
00118 for( typename index_space_type::const_iterator it = zero.begin();it != zero.end(); ++it )
00119 z( *it ) = 0;
00120
00121 if( is_vector_elem< vector_type, index_space_type, std::greater<value_type> >( z, positive ) ) {
00122 ret = z;
00123 w = prod( trans( m_matrix ), m_vector - prod( m_matrix, ret ) );
00124 lim = std::numeric_limits< value_type >::epsilon() * ( norm_2(m_vector) * ( 2*m_matrix.size1()*m_matrix.size2() - m_matrix.size2() )+ norm_2(ret) * ( 4*m_matrix.size1()*m_matrix.size2() - m_matrix.size1() - m_matrix.size2() ) );
00125 for( typename vector_type::iterator it = w.begin(); it != w.end(); ++it )
00126 if( std::abs(*it) < lim ) *it=0;
00127 break;
00128 }
00129
00130 size_type min_1 = *(std::min_element( positive.begin(), positive.end(), vector_less_nnls1< vector_type, std::less< typename vector_type::value_type > >(ret,z) ));
00131 value_type min_1_value = ret(min_1) / (ret(min_1)-z(min_1));
00132 ret = ret + min_1_value * ( z - ret );
00133
00134 ret(min_1) = 0;
00135 swap_indexes(positive,zero,min_1);
00136
00137 for( typename index_space_type::const_iterator it = positive.begin();it != positive.end(); ++it ) {
00138 if( ret(*it) <= 0 ){
00139 ret(*it) = 0;
00140 swap_indexes(positive,zero,*it);
00141 }
00142 }
00143 } while( true );
00144 }
00145 }
00146 template<class sV> void solve( sV& ret ) const {
00147 solve( ret, null_type::s_null );
00148 }
00149
00150 };
00151
00152 };
00153
00154 #endif // _NNLS_H
00155