GPLIB++
mtunn.cpp
Go to the documentation of this file.
1 #include <fstream>
2 #include <iostream>
3 #include <string>
4 #include <algorithm>
5 #include <boost/bind.hpp>
6 #include <boost/date_time/posix_time/posix_time.hpp>
7 #include <boost/program_options.hpp>
8 #include "ApplyFilter.h"
9 #include "Adaptors.h"
10 #include "NetCDFTools.h"
11 #include "NeuralNetwork.h"
12 #include "TimeSeriesData.h"
13 
14 using namespace std;
15 using namespace gplib;
16 
// Version identifier of this program; main() prints it in the start-up banner.
// NOTE(review): looks like an expanded version-control $Id$ keyword - confirm.
17 string version = "$Id: mtunn.cpp 1838 2010-03-04 16:19:34Z mmoorkamp $";
18 
19 void Restore(const TimeSeriesComponent &Input, TimeSeriesComponent &Output,
20  const double factor)
21  {
22  transform(Input.GetData().begin(), Input.GetData().end(),
23  Output.GetData().begin(),
24  boost::bind(multiplies<double> (), _1, factor));
25  }
26 
27 double Normalize(TimeSeriesComponent &Component)
28  {
29  double themax = *max_element(Component.GetData().begin(),
30  Component.GetData().end(), gplib::absLess<double, double>());
31  double factor = 1. / themax;
32  transform(Component.GetData().begin(), Component.GetData().end(),
33  Component.GetData().begin(), boost::bind(multiplies<double> (), _1,
34  factor));
35  return themax;
36  }
37 
38 void GetNNSetup(const size_t filterlength, const size_t hiddenlayers,
39  const size_t ntimeseries, NeuralNetwork::ttypeArray &NNLayers,
40  double &NNmaxinit)
41  {
42  NeuralNetwork::ttypeVector typeVector(filterlength,
43  SigmoidalNeuron::bipolar); // we want filterlength number of bipolar neurons per hidden layer
44  for (size_t i = 0; i < hiddenlayers; ++i) //intialize the type array for the hidden layers
45  {
46  NNLayers.push_back(typeVector); // all layers are the same, so we copy the same vector there
47  }
48  typeVector.assign(1, SigmoidalNeuron::identity);
49  NNLayers.push_back(typeVector); // and then we add it to the type Array
50  }
51 
52 namespace po = boost::program_options;
53 
54 int main(int argc, char *argv[])
55  {
56  cout << "This is mtunn: Perform neural network filtering on MT time-series"
57  << endl << endl;
58  cout
59  << " The program will ask for reference and input filename, further settings are read from 'mtuadaptive.conf' "
60  << endl;
61  cout
62  << " Output will be 1 Phoenix format file with ending '.clean' where all channels are overwritten"
63  << endl;
64  cout
65  << " Network weights are stored in a file with ending '.weights.nc' and network topology in '.dot"
66  << endl << endl;
67  cout << " This is Version: " << version << endl << endl;
68 
69  int filterlength = 0, shift = 0, hiddenlayers;
70  double mu, alpha;
71 
72  po::options_description desc("General options");
73  desc.add_options()("help", "produce help message")("filterlength",
74  po::value<int>(&filterlength)->default_value(10),
75  "The length of the adaptive filter")("shift",
76  po::value<int>(&shift)->default_value(0),
77  "The shift in samples between the time series")("mu",
78  po::value<double>(&mu)->default_value(1.0),
79  "Stepsize for LMS adaptive filter")("alpha", po::value<
80  double>(&alpha)->default_value(1.0), "")("hiddenlayers", po::value<int>(
81  &hiddenlayers)->default_value(1),
82  "The number of hiddenlayers for the neural network");
83 
84  po::variables_map vm;
85  po::store(po::parse_command_line(argc, argv, desc), vm);
86  po::notify(vm);
87 
88  if (vm.count("help"))
89  {
90  std::cout << desc << "\n";
91  return 1;
92  }
93 
94  TimeSeriesData InputData, ReferenceData;
95  string tsfilename, noisefilename;
96 
97  if (argc == 3)
98  {
99  noisefilename = argv[1];
100  tsfilename = argv[2];
101  }
102  else
103  {
104  cout << "Reference Data: ";
105  cin >> noisefilename;
106  cout << "Input Time Series Filename: ";
107  cin >> tsfilename;
108  }
109 
110  ReferenceData.GetData(noisefilename);
111  InputData.GetData(tsfilename);
112 
113  cout << "Input Start time: " << InputData.GetData().GetTime().front()
114  << endl;
115  cout << "Reference Start time: "
116  << ReferenceData.GetData().GetTime().front() << endl;
117  if (InputData.GetData().GetTime().front()
118  != ReferenceData.GetData().GetTime().front())
119  {
120  cerr << "Time series not synchronized !" << endl;
121  return 100;
122  }
123  int lengthdiff = ReferenceData.GetData().Size()
124  - InputData.GetData().Size();
125  if (lengthdiff > 0)
126  {
127  cout << "Removing " << lengthdiff
128  << " datapoints from Reference time-series." << endl;
129  ReferenceData.GetData().erase(ReferenceData.GetData().Size()
130  - lengthdiff, ReferenceData.GetData().Size());
131  }
132  if (lengthdiff < 0)
133  {
134  cout << "Removing " << lengthdiff
135  << " datapoints from Input time-series." << endl;
136  InputData.GetData().erase(InputData.GetData().Size() + lengthdiff,
137  InputData.GetData().Size());
138  }
139  cout << "Input End time: " << InputData.GetData().GetTime().back() << endl;
140  cout << "Reference End time: " << ReferenceData.GetData().GetTime().back()
141  << endl;
142 
143  NeuralNetwork::ttypeArray NNLayers;
144  double NNmaxinit = 1.0;
145  const int ntimeseries = 4;
146  GetNNSetup(filterlength, hiddenlayers, ntimeseries, NNLayers, NNmaxinit);
147  NeuralNetwork NN(filterlength, ntimeseries, mu, NNLayers, NNmaxinit, true);
148  NN.SetAlpha(alpha);
149 
150  ApplyFilter Canceller(NN, true);
151 
152  double rexmax = Normalize(ReferenceData.GetData().GetEx());
153  double reymax = Normalize(ReferenceData.GetData().GetEy());
154  double rhxmax = Normalize(ReferenceData.GetData().GetHx());
155  double rhymax = Normalize(ReferenceData.GetData().GetHy());
156 
157  Canceller.AddReferenceChannel(ReferenceData.GetData().GetHx());
158  Canceller.AddReferenceChannel(ReferenceData.GetData().GetHy());
159  Canceller.AddReferenceChannel(ReferenceData.GetData().GetEx());
160  Canceller.AddReferenceChannel(ReferenceData.GetData().GetEy());
161 
162  double ihxmax = Normalize(InputData.GetData().GetHx());
163  double ihymax = Normalize(InputData.GetData().GetHy());
164  Canceller.AddInputChannel(InputData.GetData().GetHx());
165  Canceller.AddInputChannel(InputData.GetData().GetHy());
166 
167  Canceller.SetWeightSaveIntervall(1000);
168  Canceller.SetShift(shift);
169  Canceller.ShowProgress(true);
170  ofstream weightfile((noisefilename + "weights").c_str());
171  NN.PrintWeights(weightfile);
172  cout << " First iteration: " << endl << endl;
173 
174  Canceller.FilterData();
175  NN.PrintWeights(weightfile);
176  cout << endl << endl << " Second iteration: " << endl << endl;
177 
178  Canceller.FilterData();
179  NN.PrintWeights(weightfile);
180 
181  Restore(*Canceller.GetOutChannels().at(0).get(),
182  ReferenceData.GetData().GetHx(), rhxmax);
183  Restore(*Canceller.GetOutChannels().at(1).get(),
184  ReferenceData.GetData().GetHy(), rhymax);
185  Restore(*Canceller.GetOutChannels().at(2).get(),
186  ReferenceData.GetData().GetEx(), rexmax);
187  Restore(*Canceller.GetOutChannels().at(3).get(),
188  ReferenceData.GetData().GetEy(), reymax);
189  ReferenceData.WriteBack(noisefilename + ".clean");
190 
191  //ofstream epsfile((noisefilename+".eps").c_str());
192  //copy(Canceller.GetEpsValues().front().begin(),Canceller.GetEpsValues().front().end(),ostream_iterator<double>(epsfile,"\n"));
193 
194  //WriteMatrixAsNetCDF(noisefilename+".weights.nc",Canceller.GetWeightHistory());
195  //NN.PrintTopology(noisefilename+".dot");
196  }
std::vector< double > & GetData()
Access for data vector, for ease of use and efficiency we return a reference.
void AddReferenceChannel(TimeSeriesComponent &Channel)
Add a reference channel to the filter, some AdaptiveFilter objects require only one reference...
Definition: ApplyFilter.cpp:45
Apply an adaptive filter to a time-series.
Definition: ApplyFilter.h:15
TimeSeries & GetData()
return a reference to the actual object stored in the pointer
TimeSeriesComponent & GetEx()
Definition: TimeSeries.h:47
virtual void PrintWeights(std::ostream &output)
Print the weights of the network to the specified output stream.
void ShowProgress(const bool what)
Do we want visual feedback of the progress on the screen? If yes, draw a simple progress indicator in t...
Definition: ApplyFilter.h:30
void SetAlpha(const double a)
Set the momentum multiplier.
Definition: NeuralNetwork.h:45
int main(int argc, char *argv[])
Definition: mtunn.cpp:54
void Restore(const TimeSeriesComponent &Input, TimeSeriesComponent &Output, const double factor)
Definition: mtunn.cpp:19
void erase(const int startindex, const int endindex)
Erase data between startindex and endindex.
Definition: TimeSeries.cpp:59
void AddInputChannel(TimeSeriesComponent &Channel)
Add an input channel to the filter.
Definition: ApplyFilter.cpp:36
std::vector< ttypeVector > ttypeArray
Definition: NeuralNetwork.h:23
TimeSeriesComponent & GetHy()
Definition: TimeSeries.h:39
void SetWeightSaveIntervall(const int intervall)
Set the distance between iterations at which the weights are saved.
Definition: ApplyFilter.h:37
TimeSeriesComponent is the base storage class for all types of time series data.
std::vector< SigmoidalNeuron::tneurontype > ttypeVector
Definition: NeuralNetwork.h:22
void GetNNSetup(const size_t filterlength, const size_t hiddenlayers, const size_t ntimeseries, NeuralNetwork::ttypeArray &NNLayers, double &NNmaxinit)
Definition: mtunn.cpp:38
string version
Definition: mtunn.cpp:17
void FilterData()
Filter the input channels with the current settings.
Definition: ApplyFilter.cpp:57
void SetShift(const int theshift)
Set the shift between the input time series and the reference time series.
Definition: ApplyFilter.h:52
void WriteBack(std::string filename_base)
Write in the format it was originally read in.
TimeSeriesData stores a pointer to the different components of magnetotelluric data and provides func...
double Normalize(TimeSeriesComponent &Component)
Definition: mtunn.cpp:27
ttimedata & GetTime()
Definition: TimeSeries.h:55
size_t Size()
Return the size of the time series, throws if one of the components has a different size...
Definition: TimeSeries.cpp:74
const std::vector< boost::shared_ptr< TimeSeriesComponent > > & GetOutChannels()
Return the vector of output channels.
Definition: ApplyFilter.h:42
TimeSeriesComponent & GetEy()
Definition: TimeSeries.h:51
TimeSeriesComponent & GetHx()
Access function for Hx, returns reference for efficiency.
Definition: TimeSeries.h:35