Point Cloud Library (PCL)  1.8.0
regression_variance_stats_estimator.h
1 /*
2  * Software License Agreement (BSD License)
3  *
4  * Point Cloud Library (PCL) - www.pointclouds.org
5  * Copyright (c) 2010-2011, Willow Garage, Inc.
6  *
7  * All rights reserved.
8  *
9  * Redistribution and use in source and binary forms, with or without
10  * modification, are permitted provided that the following conditions
11  * are met:
12  *
13  * * Redistributions of source code must retain the above copyright
14  * notice, this list of conditions and the following disclaimer.
15  * * Redistributions in binary form must reproduce the above
16  * copyright notice, this list of conditions and the following
17  * disclaimer in the documentation and/or other materials provided
18  * with the distribution.
19  * * Neither the name of Willow Garage, Inc. nor the names of its
20  * contributors may be used to endorse or promote products derived
21  * from this software without specific prior written permission.
22  *
23  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
24  * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
25  * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
26  * FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
27  * COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
28  * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
29  * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
30  * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
31  * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
32  * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
33  * ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
34  * POSSIBILITY OF SUCH DAMAGE.
35  *
36  */
37 
38 #ifndef PCL_ML_REGRESSION_VARIANCE_STATS_ESTIMATOR_H_
39 #define PCL_ML_REGRESSION_VARIANCE_STATS_ESTIMATOR_H_
40 
41 #include <pcl/common/common.h>
42 #include <pcl/ml/stats_estimator.h>
43 #include <pcl/ml/branch_estimator.h>
44 
45 #include <istream>
46 #include <ostream>
47 
48 namespace pcl
49 {
50 
51  /** \brief Node for a regression trees which optimizes variance. */
52  template <class FeatureType, class LabelType>
53  class PCL_EXPORTS RegressionVarianceNode
54  {
55  public:
56  /** \brief Constructor. */
57  RegressionVarianceNode () : value(0), variance(0), threshold(0), sub_nodes() {}
58  /** \brief Destructor. */
60 
61  /** \brief Serializes the node to the specified stream.
62  * \param[out] stream The destination for the serialization.
63  */
64  inline void
65  serialize (std::ostream & stream) const
66  {
67  feature.serialize (stream);
68 
69  stream.write (reinterpret_cast<const char*> (&threshold), sizeof (threshold));
70 
71  stream.write (reinterpret_cast<const char*> (&value), sizeof (value));
72  stream.write (reinterpret_cast<const char*> (&variance), sizeof (variance));
73 
74  const int num_of_sub_nodes = static_cast<int> (sub_nodes.size ());
75  stream.write (reinterpret_cast<const char*> (&num_of_sub_nodes), sizeof (num_of_sub_nodes));
76  for (int sub_node_index = 0; sub_node_index < num_of_sub_nodes; ++sub_node_index)
77  {
78  sub_nodes[sub_node_index].serialize (stream);
79  }
80  }
81 
82  /** \brief Deserializes a node from the specified stream.
83  * \param[in] stream The source for the deserialization.
84  */
85  inline void
86  deserialize (std::istream & stream)
87  {
88  feature.deserialize (stream);
89 
90  stream.read (reinterpret_cast<char*> (&threshold), sizeof (threshold));
91 
92  stream.read (reinterpret_cast<char*> (&value), sizeof (value));
93  stream.read (reinterpret_cast<char*> (&variance), sizeof (variance));
94 
95  int num_of_sub_nodes;
96  stream.read (reinterpret_cast<char*> (&num_of_sub_nodes), sizeof (num_of_sub_nodes));
97  sub_nodes.resize (num_of_sub_nodes);
98 
99  if (num_of_sub_nodes > 0)
100  {
101  for (int sub_node_index = 0; sub_node_index < num_of_sub_nodes; ++sub_node_index)
102  {
103  sub_nodes[sub_node_index].deserialize (stream);
104  }
105  }
106  }
107 
108  public:
109  /** \brief The feature associated with the node. */
110  FeatureType feature;
111  /** \brief The threshold applied on the feature response. */
112  float threshold;
113 
114  /** \brief The label value of this node. */
115  LabelType value;
116  /** \brief The variance of the labels that ended up at this node during training. */
117  LabelType variance;
118 
119  /** \brief The child nodes. */
120  std::vector<RegressionVarianceNode> sub_nodes;
121  };
122 
123  /** \brief Statistics estimator for regression trees which optimizes variance. */
124  template <class LabelDataType, class NodeType, class DataSet, class ExampleIndex>
126  : public pcl::StatsEstimator<LabelDataType, NodeType, DataSet, ExampleIndex>
127  {
128 
129  public:
130  /** \brief Constructor. */
132  : branch_estimator_ (branch_estimator)
133  {}
134  /** \brief Destructor. */
136 
137  /** \brief Returns the number of branches the corresponding tree has. */
138  inline size_t
140  {
141  //return 2;
142  return branch_estimator_->getNumOfBranches ();
143  }
144 
145  /** \brief Returns the label of the specified node.
146  * \param[in] node The node which label is returned.
147  */
148  inline LabelDataType
150  NodeType & node) const
151  {
152  return node.value;
153  }
154 
155  /** \brief Computes the information gain obtained by the specified threshold.
156  * \param[in] data_set The data set corresponding to the supplied result data.
157  * \param[in] examples The examples used for extracting the supplied result data.
158  * \param[in] label_data The label data corresponding to the specified examples.
159  * \param[in] results The results computed using the specifed examples.
160  * \param[in] flags The flags corresponding to the results.
161  * \param[in] threshold The threshold for which the information gain is computed.
162  */
163  float
165  DataSet & data_set,
166  std::vector<ExampleIndex> & examples,
167  std::vector<LabelDataType> & label_data,
168  std::vector<float> & results,
169  std::vector<unsigned char> & flags,
170  const float threshold) const
171  {
172  const size_t num_of_examples = examples.size ();
173  const size_t num_of_branches = getNumOfBranches();
174 
175  // compute variance
176  std::vector<LabelDataType> sums (num_of_branches+1, 0);
177  std::vector<LabelDataType> sqr_sums (num_of_branches+1, 0);
178  std::vector<size_t> branch_element_count (num_of_branches+1, 0);
179 
180  for (size_t branch_index = 0; branch_index < num_of_branches; ++branch_index)
181  {
182  branch_element_count[branch_index] = 1;
183  ++branch_element_count[num_of_branches];
184  }
185 
186  for (size_t example_index = 0; example_index < num_of_examples; ++example_index)
187  {
188  unsigned char branch_index;
189  computeBranchIndex (results[example_index], flags[example_index], threshold, branch_index);
190 
191  LabelDataType label = label_data[example_index];
192 
193  sums[branch_index] += label;
194  sums[num_of_branches] += label;
195 
196  sqr_sums[branch_index] += label*label;
197  sqr_sums[num_of_branches] += label*label;
198 
199  ++branch_element_count[branch_index];
200  ++branch_element_count[num_of_branches];
201  }
202 
203  std::vector<float> variances (num_of_branches+1, 0);
204  for (size_t branch_index = 0; branch_index < num_of_branches+1; ++branch_index)
205  {
206  const float mean_sum = static_cast<float>(sums[branch_index]) / branch_element_count[branch_index];
207  const float mean_sqr_sum = static_cast<float>(sqr_sums[branch_index]) / branch_element_count[branch_index];
208  variances[branch_index] = mean_sqr_sum - mean_sum*mean_sum;
209  }
210 
211  float information_gain = variances[num_of_branches];
212  for (size_t branch_index = 0; branch_index < num_of_branches; ++branch_index)
213  {
214  //const float weight = static_cast<float>(sums[branchIndex]) / sums[numOfBranches];
215  const float weight = static_cast<float>(branch_element_count[branch_index]) / static_cast<float>(branch_element_count[num_of_branches]);
216  information_gain -= weight*variances[branch_index];
217  }
218 
219  return information_gain;
220  }
221 
222  /** \brief Computes the branch indices for all supplied results.
223  * \param[in] results The results the branch indices will be computed for.
224  * \param[in] flags The flags corresponding to the specified results.
225  * \param[in] threshold The threshold used to compute the branch indices.
226  * \param[out] branch_indices The destination for the computed branch indices.
227  */
228  void
230  std::vector<float> & results,
231  std::vector<unsigned char> & flags,
232  const float threshold,
233  std::vector<unsigned char> & branch_indices) const
234  {
235  const size_t num_of_results = results.size ();
236  const size_t num_of_branches = getNumOfBranches();
237 
238  branch_indices.resize (num_of_results);
239  for (size_t result_index = 0; result_index < num_of_results; ++result_index)
240  {
241  unsigned char branch_index;
242  computeBranchIndex (results[result_index], flags[result_index], threshold, branch_index);
243  branch_indices[result_index] = branch_index;
244  }
245  }
246 
247  /** \brief Computes the branch index for the specified result.
248  * \param[in] result The result the branch index will be computed for.
249  * \param[in] flag The flag corresponding to the specified result.
250  * \param[in] threshold The threshold used to compute the branch index.
251  * \param[out] branch_index The destination for the computed branch index.
252  */
253  inline void
255  const float result,
256  const unsigned char flag,
257  const float threshold,
258  unsigned char & branch_index) const
259  {
260  branch_estimator_->computeBranchIndex (result, flag, threshold, branch_index);
261  //branch_index = (result > threshold) ? 1 : 0;
262  }
263 
264  /** \brief Computes and sets the statistics for a node.
265  * \param[in] data_set The data set which is evaluated.
266  * \param[in] examples The examples which define which parts of the data set are used for evaluation.
267  * \param[in] label_data The label_data corresponding to the examples.
268  * \param[out] node The destination node for the statistics.
269  */
270  void
272  DataSet & data_set,
273  std::vector<ExampleIndex> & examples,
274  std::vector<LabelDataType> & label_data,
275  NodeType & node) const
276  {
277  const size_t num_of_examples = examples.size ();
278 
279  LabelDataType sum = 0.0f;
280  LabelDataType sqr_sum = 0.0f;
281  for (size_t example_index = 0; example_index < num_of_examples; ++example_index)
282  {
283  const LabelDataType label = label_data[example_index];
284 
285  sum += label;
286  sqr_sum += label*label;
287  }
288 
289  sum /= num_of_examples;
290  sqr_sum /= num_of_examples;
291 
292  const float variance = sqr_sum - sum*sum;
293 
294  node.value = sum;
295  node.variance = variance;
296  }
297 
298  /** \brief Generates code for branch index computation.
299  * \param[in] node The node for which code is generated.
300  * \param[out] stream The destination for the generated code.
301  */
302  void
304  NodeType & node,
305  std::ostream & stream) const
306  {
307  stream << "ERROR: RegressionVarianceStatsEstimator does not implement generateCodeForBranchIndex(...)";
308  }
309 
310  /** \brief Generates code for label output.
311  * \param[in] node The node for which code is generated.
312  * \param[out] stream The destination for the generated code.
313  */
314  void
316  NodeType & node,
317  std::ostream & stream) const
318  {
319  stream << "ERROR: RegressionVarianceStatsEstimator does not implement generateCodeForBranchIndex(...)";
320  }
321 
322  private:
323  /** \brief The branch estimator. */
324  pcl::BranchEstimator * branch_estimator_;
325  };
326 
327 }
328 
329 #endif
Statistics estimator for regression trees which optimizes variance.
LabelType value
The label value of this node.
void computeBranchIndex(const float result, const unsigned char flag, const float threshold, unsigned char &branch_index) const
Computes the branch index for the specified result.
void generateCodeForOutput(NodeType &node, std::ostream &stream) const
Generates code for label output.
Node for a regression tree which optimizes variance.
void computeAndSetNodeStats(DataSet &data_set, std::vector< ExampleIndex > &examples, std::vector< LabelDataType > &label_data, NodeType &node) const
Computes and sets the statistics for a node.
LabelType variance
The variance of the labels that ended up at this node during training.
size_t getNumOfBranches() const
Returns the number of branches the corresponding tree has.
Define standard C methods and C++ classes that are common to all methods.
FeatureType feature
The feature associated with the node.
float computeInformationGain(DataSet &data_set, std::vector< ExampleIndex > &examples, std::vector< LabelDataType > &label_data, std::vector< float > &results, std::vector< unsigned char > &flags, const float threshold) const
Computes the information gain obtained by the specified threshold.
LabelDataType getLabelOfNode(NodeType &node) const
Returns the label of the specified node.
void computeBranchIndices(std::vector< float > &results, std::vector< unsigned char > &flags, const float threshold, std::vector< unsigned char > &branch_indices) const
Computes the branch indices for all supplied results.
void serialize(std::ostream &stream) const
Serializes the node to the specified stream.
float threshold
The threshold applied on the feature response.
Class interface for gathering statistics for decision tree learning.
RegressionVarianceStatsEstimator(BranchEstimator *branch_estimator)
Constructor.
std::vector< RegressionVarianceNode > sub_nodes
The child nodes.
void deserialize(std::istream &stream)
Deserializes a node from the specified stream.
void generateCodeForBranchIndexComputation(NodeType &node, std::ostream &stream) const
Generates code for branch index computation.
Interface for branch estimators.