Point Cloud Library (PCL)  1.10.1-dev
fern_trainer.hpp
1 /*
2  * Software License Agreement (BSD License)
3  *
4  * Point Cloud Library (PCL) - www.pointclouds.org
5  * Copyright (c) 2010-2011, Willow Garage, Inc.
6  *
7  * All rights reserved.
8  *
9  * Redistribution and use in source and binary forms, with or without
10  * modification, are permitted provided that the following conditions
11  * are met:
12  *
13  * * Redistributions of source code must retain the above copyright
14  * notice, this list of conditions and the following disclaimer.
15  * * Redistributions in binary form must reproduce the above
16  * copyright notice, this list of conditions and the following
17  * disclaimer in the documentation and/or other materials provided
18  * with the distribution.
19  * * Neither the name of Willow Garage, Inc. nor the names of its
20  * contributors may be used to endorse or promote products derived
21  * from this software without specific prior written permission.
22  *
23  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
24  * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
25  * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
26  * FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
27  * COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
28  * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
29  * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
30  * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
31  * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
32  * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
33  * ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
34  * POSSIBILITY OF SUCH DAMAGE.
35  *
36  */
37 
38 #pragma once
39 
40 template <class FeatureType,
41  class DataSet,
42  class LabelType,
43  class ExampleIndex,
44  class NodeType>
46 : fern_depth_(10)
47 , num_of_features_(1000)
48 , num_of_thresholds_(10)
49 , feature_handler_(nullptr)
50 , stats_estimator_(nullptr)
51 , data_set_()
52 , label_data_()
53 , examples_()
54 {}
55 
56 template <class FeatureType,
57  class DataSet,
58  class LabelType,
59  class ExampleIndex,
60  class NodeType>
63 {}
64 
65 template <class FeatureType,
66  class DataSet,
67  class LabelType,
68  class ExampleIndex,
69  class NodeType>
70 void
73 {
74  const std::size_t num_of_branches = stats_estimator_->getNumOfBranches();
75  const std::size_t num_of_examples = examples_.size();
76 
77  // create random features
78  std::vector<FeatureType> features;
79  feature_handler_->createRandomFeatures(num_of_features_, features);
80 
81  // setup fern
82  fern.initialize(fern_depth_);
83 
84  // evaluate all features
85  std::vector<std::vector<float>> feature_results(num_of_features_);
86  std::vector<std::vector<unsigned char>> flags(num_of_features_);
87 
88  for (std::size_t feature_index = 0; feature_index < num_of_features_;
89  ++feature_index) {
90  feature_results[feature_index].reserve(num_of_examples);
91  flags[feature_index].reserve(num_of_examples);
92 
93  feature_handler_->evaluateFeature(features[feature_index],
94  data_set_,
95  examples_,
96  feature_results[feature_index],
97  flags[feature_index]);
98  }
99 
100  // iteratively select features and thresholds
101  std::vector<std::vector<std::vector<float>>> branch_feature_results(
102  num_of_features_); // [feature_index][branch_index][result_index]
103  std::vector<std::vector<std::vector<unsigned char>>> branch_flags(
104  num_of_features_); // [feature_index][branch_index][flag_index]
105  std::vector<std::vector<std::vector<ExampleIndex>>> branch_examples(
106  num_of_features_); // [feature_index][branch_index][result_index]
107  std::vector<std::vector<std::vector<LabelType>>> branch_label_data(
108  num_of_features_); // [feature_index][branch_index][flag_index]
109 
110  // - initialize branch feature results and flags
111  for (std::size_t feature_index = 0; feature_index < num_of_features_;
112  ++feature_index) {
113  branch_feature_results[feature_index].resize(1);
114  branch_flags[feature_index].resize(1);
115  branch_examples[feature_index].resize(1);
116  branch_label_data[feature_index].resize(1);
117 
118  branch_feature_results[feature_index][0] = feature_results[feature_index];
119  branch_flags[feature_index][0] = flags[feature_index];
120  branch_examples[feature_index][0] = examples_;
121  branch_label_data[feature_index][0] = label_data_;
122  }
123 
124  for (std::size_t depth_index = 0; depth_index < fern_depth_; ++depth_index) {
125  // get thresholds
126  std::vector<std::vector<float>> thresholds(num_of_features_);
127 
128  for (std::size_t feature_index = 0; feature_index < num_of_features_;
129  ++feature_index) {
130  thresholds.reserve(num_of_thresholds_);
131  createThresholdsUniform(num_of_thresholds_,
132  feature_results[feature_index],
133  thresholds[feature_index]);
134  }
135 
136  // compute information gain
137  int best_feature_index = -1;
138  float best_feature_threshold = 0.0f;
139  float best_feature_information_gain = 0.0f;
140 
141  for (std::size_t feature_index = 0; feature_index < num_of_features_;
142  ++feature_index) {
143  for (std::size_t threshold_index = 0; threshold_index < num_of_thresholds_;
144  ++threshold_index) {
145  float information_gain = 0.0f;
146  for (std::size_t branch_index = 0;
147  branch_index < branch_feature_results[feature_index].size();
148  ++branch_index) {
149  const float branch_information_gain =
150  stats_estimator_->computeInformationGain(
151  data_set_,
152  branch_examples[feature_index][branch_index],
153  branch_label_data[feature_index][branch_index],
154  branch_feature_results[feature_index][branch_index],
155  branch_flags[feature_index][branch_index],
156  thresholds[feature_index][threshold_index]);
157 
158  information_gain +=
159  branch_information_gain *
160  branch_feature_results[feature_index][branch_index].size();
161  }
162 
163  if (information_gain > best_feature_information_gain) {
164  best_feature_information_gain = information_gain;
165  best_feature_index = static_cast<int>(feature_index);
166  best_feature_threshold = thresholds[feature_index][threshold_index];
167  }
168  }
169  }
170 
171  // add feature to the feature list of the fern
172  fern.accessFeature(depth_index) = features[best_feature_index];
173  fern.accessThreshold(depth_index) = best_feature_threshold;
174 
175  // update branch feature results and flags
176  for (std::size_t feature_index = 0; feature_index < num_of_features_;
177  ++feature_index) {
178  std::vector<std::vector<float>>& cur_branch_feature_results =
179  branch_feature_results[feature_index];
180  std::vector<std::vector<unsigned char>>& cur_branch_flags =
181  branch_flags[feature_index];
182  std::vector<std::vector<ExampleIndex>>& cur_branch_examples =
183  branch_examples[feature_index];
184  std::vector<std::vector<LabelType>>& cur_branch_label_data =
185  branch_label_data[feature_index];
186 
187  const std::size_t total_num_of_new_branches =
188  num_of_branches * cur_branch_feature_results.size();
189 
190  std::vector<std::vector<float>> new_branch_feature_results(
191  total_num_of_new_branches); // [branch_index][example_index]
192  std::vector<std::vector<unsigned char>> new_branch_flags(
193  total_num_of_new_branches); // [branch_index][example_index]
194  std::vector<std::vector<ExampleIndex>> new_branch_examples(
195  total_num_of_new_branches); // [branch_index][example_index]
196  std::vector<std::vector<LabelType>> new_branch_label_data(
197  total_num_of_new_branches); // [branch_index][example_index]
198 
199  for (std::size_t branch_index = 0;
200  branch_index < cur_branch_feature_results.size();
201  ++branch_index) {
202  const std::size_t num_of_examples_in_this_branch =
203  cur_branch_feature_results[branch_index].size();
204 
205  std::vector<unsigned char> branch_indices;
206  branch_indices.reserve(num_of_examples_in_this_branch);
207 
208  stats_estimator_->computeBranchIndices(cur_branch_feature_results[branch_index],
209  cur_branch_flags[branch_index],
210  best_feature_threshold,
211  branch_indices);
212 
213  // split results into different branches
214  const std::size_t base_branch_index = branch_index * num_of_branches;
215  for (std::size_t example_index = 0;
216  example_index < num_of_examples_in_this_branch;
217  ++example_index) {
218  const std::size_t combined_branch_index =
219  base_branch_index + branch_indices[example_index];
220 
221  new_branch_feature_results[combined_branch_index].push_back(
222  cur_branch_feature_results[branch_index][example_index]);
223  new_branch_flags[combined_branch_index].push_back(
224  cur_branch_flags[branch_index][example_index]);
225  new_branch_examples[combined_branch_index].push_back(
226  cur_branch_examples[branch_index][example_index]);
227  new_branch_label_data[combined_branch_index].push_back(
228  cur_branch_label_data[branch_index][example_index]);
229  }
230  }
231 
232  branch_feature_results[feature_index] = new_branch_feature_results;
233  branch_flags[feature_index] = new_branch_flags;
234  branch_examples[feature_index] = new_branch_examples;
235  branch_label_data[feature_index] = new_branch_label_data;
236  }
237  }
238 
239  // set node statistics
240  // - re-evaluate selected features
241  std::vector<std::vector<float>> final_feature_results(
242  fern_depth_); // [feature_index][example_index]
243  std::vector<std::vector<unsigned char>> final_flags(
244  fern_depth_); // [feature_index][example_index]
245  std::vector<std::vector<unsigned char>> final_branch_indices(
246  fern_depth_); // [feature_index][example_index]
247  for (std::size_t depth_index = 0; depth_index < fern_depth_; ++depth_index) {
248  final_feature_results[depth_index].reserve(num_of_examples);
249  final_flags[depth_index].reserve(num_of_examples);
250  final_branch_indices[depth_index].reserve(num_of_examples);
251 
252  feature_handler_->evaluateFeature(fern.accessFeature(depth_index),
253  data_set_,
254  examples_,
255  final_feature_results[depth_index],
256  final_flags[depth_index]);
257 
258  stats_estimator_->computeBranchIndices(final_feature_results[depth_index],
259  final_flags[depth_index],
260  fern.accessThreshold(depth_index),
261  final_branch_indices[depth_index]);
262  }
263 
264  // - distribute examples to nodes
265  std::vector<std::vector<LabelType>> node_labels(
266  0x1 << fern_depth_); // [node_index][example_index]
267  std::vector<std::vector<ExampleIndex>> node_examples(
268  0x1 << fern_depth_); // [node_index][example_index]
269 
270  for (std::size_t example_index = 0; example_index < num_of_examples;
271  ++example_index) {
272  std::size_t node_index = 0;
273  for (std::size_t depth_index = 0; depth_index < fern_depth_; ++depth_index) {
274  node_index *= num_of_branches;
275  node_index += final_branch_indices[depth_index][example_index];
276  }
277 
278  node_labels[node_index].push_back(label_data_[example_index]);
279  node_examples[node_index].push_back(examples_[example_index]);
280  }
281 
282  // - compute and set statistics for every node
283  const std::size_t num_of_nodes = 0x1 << fern_depth_;
284  for (std::size_t node_index = 0; node_index < num_of_nodes; ++node_index) {
285  stats_estimator_->computeAndSetNodeStats(data_set_,
286  node_examples[node_index],
287  node_labels[node_index],
288  fern[node_index]);
289  }
290 }
291 
292 template <class FeatureType,
293  class DataSet,
294  class LabelType,
295  class ExampleIndex,
296  class NodeType>
297 void
299  createThresholdsUniform(const std::size_t num_of_thresholds,
300  std::vector<float>& values,
301  std::vector<float>& thresholds)
302 {
303  // estimate range of values
304  float min_value = ::std::numeric_limits<float>::max();
305  float max_value = -::std::numeric_limits<float>::max();
306 
307  const std::size_t num_of_values = values.size();
308  for (int value_index = 0; value_index < num_of_values; ++value_index) {
309  const float value = values[value_index];
310 
311  if (value < min_value)
312  min_value = value;
313  if (value > max_value)
314  max_value = value;
315  }
316 
317  const float range = max_value - min_value;
318  const float step = range / (num_of_thresholds + 2);
319 
320  // compute thresholds
321  thresholds.resize(num_of_thresholds);
322 
323  for (int threshold_index = 0; threshold_index < num_of_thresholds;
324  ++threshold_index) {
325  thresholds[threshold_index] = min_value + step * (threshold_index + 1);
326  }
327 }
void initialize(const std::size_t num_of_decisions)
Initializes the fern.
Definition: fern.h:62
virtual void computeBranchIndices(std::vector< float > &results, std::vector< unsigned char > &flags, const float threshold, std::vector< unsigned char > &branch_indices) const =0
Computes the branch indices obtained by the specified threshold on the supplied feature evaluation re...
FeatureType & accessFeature(const std::size_t feature_index)
Access operator for features.
Definition: fern.h:164
virtual float computeInformationGain(DataSet &data_set, std::vector< ExampleIndex > &examples, std::vector< LabelDataType > &label_data, std::vector< float > &results, std::vector< unsigned char > &flags, const float threshold) const =0
Computes the information gain obtained by the specified threshold on the supplied feature evaluation ...
virtual std::size_t getNumOfBranches() const =0
Returns the number of branches a node can have (e.g.
Class representing a Fern.
Definition: fern.h:49
static void createThresholdsUniform(const std::size_t num_of_thresholds, std::vector< float > &values, std::vector< float > &thresholds)
Creates uniformly distributed thresholds over the range of the supplied values.
float & accessThreshold(const std::size_t threshold_index)
Access operator for thresholds.
Definition: fern.h:184
virtual ~FernTrainer()
Destructor.
void train(Fern< FeatureType, NodeType > &fern)
Trains a fern using the set training data and settings.
FernTrainer()
Constructor.
virtual void computeAndSetNodeStats(DataSet &data_set, std::vector< ExampleIndex > &examples, std::vector< LabelDataType > &label_data, NodeType &node) const =0
Computes and sets the statistics for a node.