00001 #include "WienerFilter.h"
00002 #include "miscfunc.h"
00003
00004 using namespace boost::numeric::ublas;
00005
00006
00007 WienerFilter::WienerFilter(const int inputsize):
00008 AdaptiveFilter(inputsize,inputsize),
00009 CorrMatrix(inputsize,inputsize),
00010 Weights(inputsize),
00011 lambda(0)
00012 {
00013 }
00014
00015 WienerFilter::~WienerFilter()
00016 {
00017 }
00018
00019 void WienerFilter::PrintWeights(std::ostream &output)
00020 {
00021 std::copy(Weights.begin(),Weights.end(),std::ostream_iterator<double>(output,"\n"));
00022 }
00023 void WienerFilter::AdaptFilter(const gplib::rvec &Input, const gplib::rvec &Desired)
00024 {
00025 const int inputsize = Input.size();
00026 vector<double> Cross(Input.size());
00027 vector<double> Auto(Input.size());
00028 Correl(Desired,Desired,Auto);
00029 Correl(Input,Desired,Cross);
00030 for (int i = 0; i < inputsize; ++i)
00031 for (int j = i; j < inputsize; ++j)
00032 {
00033 CorrMatrix(i,j) = Auto(j-i);
00034 CorrMatrix(j,i) = Auto(j-i);
00035 }
00036 CorrMatrix += lambda * identity_matrix<double>(inputsize);
00037 matrix<double> Inverse(inputsize,inputsize);
00038
00039
00040
00041
00042
00043
00044
00045
00046 axpy_prod(Inverse,Cross,Weights,true);
00047
00048 }
00049
00050 void WienerFilter::CalcOutput(const gplib::rvec &Input, gplib::rvec &Output)
00051 {
00052 vector<double> output(Input.size());
00053 Convolve(Input,Weights,Output);
00054 SetOutput(Output);
00055 }