00001 
00002 
00003 
00004 
00005 
00006 
00007 
00008 
00009 #ifndef NEURAL_NET_H
00010 #define NEURAL_NET_H
00011 
00012 #include <string>
00013 #include <vector>
00014 #include "neuron.h"
00015 #include "input_file_buffer.h"
00016 
00017 namespace tesseract {
00018 
00019 
// Lower bound on an input's range. Nothing in this header reads it;
// presumably the implementation uses it to guard input normalization
// against a (near) zero max-min span -- TODO confirm in the .cpp file.
// A namespace-scope const already has internal linkage in C++, so every
// translation unit including this header gets its own private copy.
const float kMinInputRange = 1e-6f;
00021 
00022 class NeuralNet {
00023   public:
00024     NeuralNet();
00025     virtual ~NeuralNet();
00026     
00027     static NeuralNet *FromFile(const string file_name);
00028     
00029     static NeuralNet *FromInputBuffer(InputFileBuffer *ib);
00030     
00031     template <typename Type> bool FeedForward(const Type *inputs,
00032                                               Type *outputs);
00033     
00034     
00035     
00036     template <typename Type> bool GetNetOutput(const Type *inputs,
00037                                                int output_id,
00038                                                Type *output);
00039     
00040     int in_cnt() const { return in_cnt_; }
00041     int out_cnt() const { return out_cnt_; }
00042 
00043   protected:
00044     struct Node;
00045     
00046     struct WeightedNode {
00047       Node *input_node;
00048       float input_weight;
00049     };
00050     
00051     
00052     struct Node {
00053       float out;
00054       float bias;
00055       int fan_in_cnt;
00056       WeightedNode *inputs;
00057     };
00058     
00059     
00060     
00061     bool read_only_;
00062     
00063     int in_cnt_;
00064     
00065     int out_cnt_;
00066     
00067     int neuron_cnt_;
00068     
00069     int  wts_cnt_;
00070     
00071     Neuron *neurons_;
00072     
00073     
00074     
00075     
00076     static const int kWgtChunkSize = 0x10000;
00077     
00078     
00079     static const unsigned int kNetSignature = 0xFEFEABD0;
00080     
00081     int alloc_wgt_cnt_;
00082     
00083     vector<vector<float> *>wts_vec_;
00084     
00085     bool auto_encoder_;
00086     
00087     vector<float> inputs_max_;
00088     
00089     vector<float> inputs_min_;
00090     
00091     vector<float> inputs_mean_;
00092     
00093     vector<float> inputs_std_dev_;
00094     
00095     
00096     vector<Node> fast_nodes_;
00097     
00098     void Init();
00099     
00100     void Clear() {
00101       for (int node = 0; node < neuron_cnt_; node++) {
00102         neurons_[node].Clear();
00103       }
00104     }
00105     
00106     template<class ReadBuffType> bool ReadBinary(ReadBuffType *input_buff) {
00107       
00108       Init();
00109       
00110       unsigned int read_val;
00111       unsigned int auto_encode;
00112       
00113       if (input_buff->Read(&read_val, sizeof(read_val)) != sizeof(read_val)) {
00114         return false;
00115       }
00116       if (read_val != kNetSignature) {
00117         return false;
00118       }
00119       if (input_buff->Read(&auto_encode, sizeof(auto_encode)) !=
00120           sizeof(auto_encode)) {
00121         return false;
00122       }
00123       auto_encoder_ = auto_encode;
00124       
00125       if (input_buff->Read(&read_val, sizeof(read_val)) != sizeof(read_val)) {
00126         return false;
00127       }
00128       neuron_cnt_ = read_val;
00129       if (neuron_cnt_ <= 0) {
00130         return false;
00131       }
00132       
00133       neurons_ = new Neuron[neuron_cnt_];
00134       if (neurons_ == NULL) {
00135         return false;
00136       }
00137       
00138       if (input_buff->Read(&read_val, sizeof(read_val)) != sizeof(read_val)) {
00139         return false;
00140       }
00141       in_cnt_ = read_val;
00142       if (in_cnt_ <= 0) {
00143         return false;
00144       }
00145       
00146       if (input_buff->Read(&read_val, sizeof(read_val)) != sizeof(read_val)) {
00147         return false;
00148       }
00149       out_cnt_ = read_val;
00150       if (out_cnt_ <= 0) {
00151         return false;
00152       }
00153       
00154       for (int idx = 0; idx < neuron_cnt_; idx++) {
00155         neurons_[idx].set_id(idx);
00156         
00157         if (idx < in_cnt_) {
00158           neurons_[idx].set_node_type(Neuron::Input);
00159         } else if (idx >= (neuron_cnt_ - out_cnt_)) {
00160           neurons_[idx].set_node_type(Neuron::Output);
00161         } else {
00162           neurons_[idx].set_node_type(Neuron::Hidden);
00163         }
00164       }
00165       
00166       for (int node_idx = 0; node_idx < neuron_cnt_; node_idx++) {
00167         
00168         if (input_buff->Read(&read_val, sizeof(read_val)) != sizeof(read_val)) {
00169           return false;
00170         }
00171         
00172         int fan_out_cnt = read_val;
00173         for (int fan_out_idx = 0; fan_out_idx < fan_out_cnt; fan_out_idx++) {
00174           
00175           if (input_buff->Read(&read_val, sizeof(read_val)) != sizeof(read_val)) {
00176             return false;
00177           }
00178           
00179           if (!SetConnection(node_idx, read_val)) {
00180             return false;
00181           }
00182         }
00183       }
00184       
00185       for (int node_idx = 0; node_idx < neuron_cnt_; node_idx++) {
00186         
00187         if (!neurons_[node_idx].ReadBinary(input_buff)) {
00188           return false;
00189         }
00190       }
00191       
00192       inputs_mean_.resize(in_cnt_);
00193       inputs_std_dev_.resize(in_cnt_);
00194       inputs_min_.resize(in_cnt_);
00195       inputs_max_.resize(in_cnt_);
00196       
00197       if (input_buff->Read(&(inputs_mean_.front()),
00198           sizeof(inputs_mean_[0]) * in_cnt_) !=
00199           sizeof(inputs_mean_[0]) * in_cnt_) {
00200         return false;
00201       }
00202       if (input_buff->Read(&(inputs_std_dev_.front()),
00203           sizeof(inputs_std_dev_[0]) * in_cnt_) !=
00204           sizeof(inputs_std_dev_[0]) * in_cnt_) {
00205         return false;
00206       }
00207       if (input_buff->Read(&(inputs_min_.front()),
00208           sizeof(inputs_min_[0]) * in_cnt_) !=
00209           sizeof(inputs_min_[0]) * in_cnt_) {
00210         return false;
00211       }
00212       if (input_buff->Read(&(inputs_max_.front()),
00213           sizeof(inputs_max_[0]) * in_cnt_) !=
00214           sizeof(inputs_max_[0]) * in_cnt_) {
00215         return false;
00216       }
00217       
00218       if (read_only_) {
00219         return CreateFastNet();
00220       }
00221       return true;
00222     }
00223 
00224     
00225     bool SetConnection(int from, int to);
00226     
00227     
00228     bool CreateFastNet();
00229     
00230     
00231     
00232     float *AllocWgt(int wgt_cnt);
00233     
00234     template <typename Type> bool FastFeedForward(const Type *inputs,
00235                                                   Type *outputs);
00236     
00237     
00238     
00239     
00240     template <typename Type> bool FastGetNetOutput(const Type *inputs,
00241                                                    int output_id,
00242                                                    Type *output);
00243 };
00244 }
00245 
#endif  // NEURAL_NET_H