Caffe Source Code - SegAccuracyLayer


SegAccuracyLayer

SegAccuracyLayer is an evaluation layer for semantic segmentation (it ships with the DeepLab variants of Caffe rather than with upstream BVLC Caffe). For every pixel it takes the arg-max over the class-score channels as the predicted label, accumulates a confusion matrix against the ground-truth label while skipping any pixel whose label is in a configurable ignore set (typically 255), and reports overall pixel accuracy, mean per-class recall, and mean Jaccard index (IoU).

seg_accuracy_layer.hpp

#ifndef CAFFE_SEG_ACCURACY_LAYER_HPP_
#define CAFFE_SEG_ACCURACY_LAYER_HPP_

#include <set>
#include <vector>

#include "caffe/blob.hpp"
#include "caffe/common.hpp"
#include "caffe/layer.hpp"
#include "caffe/util/confusion_matrix.hpp"
#include "caffe/proto/caffe.pb.h"

namespace caffe {

template <typename Dtype>
class SegAccuracyLayer : public Layer<Dtype> {
 public:
  explicit SegAccuracyLayer(const LayerParameter& param)
      : Layer<Dtype>(param) {}
  virtual void LayerSetUp(const vector<Blob<Dtype>*>& bottom,
      const vector<Blob<Dtype>*>& top);
  virtual void Reshape(const vector<Blob<Dtype>*>& bottom,
      const vector<Blob<Dtype>*>& top);

  virtual inline const char* type() const { return "SegAccuracy"; }
  virtual inline int ExactNumBottomBlobs() const { return 2; }
  virtual inline int ExactNumTopBlobs() const { return 1; }

 protected:
  /**
   * @param bottom input Blob vector (length 2)
   *   -# @f$ (N \times C \times H \times W) @f$
   *      the predictions @f$ x @f$, a Blob with values in
   *      @f$ [-\infty, +\infty] @f$ indicating the predicted score for each of
   *      the @f$ K = CHW @f$ classes. Each @f$ x_n @f$ is mapped to a predicted
   *      label @f$ \hat{l}_n @f$ given by its maximal index:
   *      @f$ \hat{l}_n = \arg\max\limits_k x_{nk} @f$
   *   -# @f$ (N \times 1 \times 1 \times 1) @f$
   *      the labels @f$ l @f$, an integer-valued Blob with values
   *      @f$ l_n \in [0, 1, 2, ..., K - 1] @f$
   *      indicating the correct class label among the @f$ K @f$ classes
   * @param top output Blob vector (length 1)
   *   -# @f$ (1 \times 1 \times 1 \times 1) @f$
   *      the computed accuracy: @f$
   *        \frac{1}{N} \sum\limits_{n=1}^N \delta\{ \hat{l}_n = l_n \}
   *      @f$, where @f$
   *      \delta\{\mathrm{condition}\} = \left\{
   *      \begin{array}{lr}
   *         1 & \mbox{if condition} \\
   *         0 & \mbox{otherwise}
   *      \end{array} \right.
   *      @f$
   */
  virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
      const vector<Blob<Dtype>*>& top);

  /// @brief Not implemented -- AccuracyLayer cannot be used as a loss.
  virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
      const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom) {
    for (int i = 0; i < propagate_down.size(); ++i) {
      if (propagate_down[i]) { NOT_IMPLEMENTED; }
    }
  }

  ConfusionMatrix confusion_matrix_;
  // set of ignore labels
  std::set<int> ignore_label_;
};

}  // namespace caffe

#endif  // CAFFE_SEG_ACCURACY_LAYER_HPP_
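Two remarks on this header. First, the Doxygen comment appears to be inherited verbatim from Caffe's stock AccuracyLayer: as the .cpp below shows, the top blob is actually reshaped to 1 x 1 x 1 x 3 and carries three metrics (overall accuracy, mean recall, mean Jaccard), and the label blob is expected to be N x 1 x H x W rather than N x 1 x 1 x 1. Second, the ConfusionMatrix member is declared in caffe/util/confusion_matrix.hpp, which this post does not reproduce. The class below is a minimal, self-contained sketch of what that utility would need to provide for this layer; the metric definitions are the standard ones and are my assumption, not the actual DeepLab implementation (whose avgRecall, for instance, takes a boolean flag not modeled here).

#include <vector>

class ConfusionMatrixSketch {
 public:
  void resize(int num_classes) {
    counts_.assign(num_classes, std::vector<double>(num_classes, 0.0));
  }
  void clear() {
    for (int i = 0; i < size(); ++i) {
      counts_[i].assign(size(), 0.0);
    }
  }
  // Row = ground-truth class, column = predicted class.
  void accumulate(int gt, int pred) { counts_[gt][pred] += 1.0; }

  // Overall pixel accuracy: trace(M) / sum of all counts.
  double accuracy() const {
    double correct = 0.0, total = 0.0;
    for (int i = 0; i < size(); ++i) {
      for (int j = 0; j < size(); ++j) {
        total += counts_[i][j];
        if (i == j) { correct += counts_[i][j]; }
      }
    }
    return total > 0.0 ? correct / total : 0.0;
  }

  // Mean per-class recall: average over classes of M[c][c] / row_sum(c).
  double avgRecall() const {
    double sum = 0.0;
    for (int c = 0; c < size(); ++c) {
      double row = 0.0;
      for (int j = 0; j < size(); ++j) { row += counts_[c][j]; }
      sum += (row > 0.0) ? counts_[c][c] / row : 0.0;
    }
    return size() > 0 ? sum / size() : 0.0;
  }

  // Mean Jaccard index (IoU): average of M[c][c] / (row_sum + col_sum - M[c][c]).
  double avgJaccard() const {
    double sum = 0.0;
    for (int c = 0; c < size(); ++c) {
      double row = 0.0, col = 0.0;
      for (int j = 0; j < size(); ++j) {
        row += counts_[c][j];
        col += counts_[j][c];
      }
      const double denom = row + col - counts_[c][c];
      sum += (denom > 0.0) ? counts_[c][c] / denom : 0.0;
    }
    return size() > 0 ? sum / size() : 0.0;
  }

 private:
  int size() const { return static_cast<int>(counts_.size()); }
  std::vector<std::vector<double> > counts_;
};

With an interface like this, accumulate(gt, pred) is called once per non-ignored pixel, and the three metrics can be read off at any time without rescanning the data.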

seg_accuracy_layer.cpp

#include <algorithm>
#include <functional>
#include <utility>
#include <vector>

#include "caffe/layer.hpp"
#include "caffe/util/io.hpp"
#include "caffe/util/math_functions.hpp"
#include "caffe/layers/seg_accuracy_layer.hpp"

namespace caffe {

template <typename Dtype>
void SegAccuracyLayer<Dtype>::LayerSetUp(
  const vector<Blob<Dtype>*>& bottom, const vector<Blob<Dtype>*>& top) {
  confusion_matrix_.clear();                        // clear the confusion matrix
  confusion_matrix_.resize(bottom[0]->channels());  // one row/column per class
  SegAccuracyParameter seg_accuracy_param = this->layer_param_.seg_accuracy_param();
  for (int c = 0; c < seg_accuracy_param.ignore_label_size(); ++c) {
    ignore_label_.insert(seg_accuracy_param.ignore_label(c));  // labels to ignore
  }
}

template <typename Dtype>
void SegAccuracyLayer<Dtype>::Reshape(
  const vector<Blob<Dtype>*>& bottom, const vector<Blob<Dtype>*>& top) {
  CHECK_LE(1, bottom[0]->channels())
      << "top_k must be less than or equal to the number of channels (classes).";
  CHECK_EQ(bottom[0]->num(), bottom[1]->num())
      << "The data and label should have the same number.";
  CHECK_EQ(bottom[1]->channels(), 1)
      << "The label should have one channel.";
  CHECK_EQ(bottom[0]->height(), bottom[1]->height())
      << "The data should have the same height as label.";
  CHECK_EQ(bottom[0]->width(), bottom[1]->width())
      << "The data should have the same width as label.";
  //confusion_matrix_.clear();
  top[0]->Reshape(1, 1, 1, 3);  // three metrics: accuracy, mean recall, mean Jaccard
}

template <typename Dtype>
void SegAccuracyLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
    const vector<Blob<Dtype>*>& top) {
  const Dtype* bottom_data = bottom[0]->cpu_data();
  const Dtype* bottom_label = bottom[1]->cpu_data();

  int num = bottom[0]->num();
  int channels = bottom[0]->channels();
  int height = bottom[0]->height();
  int width = bottom[0]->width();

  int data_index, label_index;
  int top_k = 1;  // only top_k = 1 is supported

  // remove old predictions if the reset() flag is true
  if (this->layer_param_.seg_accuracy_param().reset()) {
    confusion_matrix_.clear();
  }

  for (int i = 0; i < num; ++i) {
    for (int h = 0; h < height; ++h) {
      for (int w = 0; w < width; ++w) {
        // Top-k accuracy: collect the scores of all channels at this pixel
        std::vector<std::pair<Dtype, int> > bottom_data_vector;
        for (int c = 0; c < channels; ++c) {
          data_index = (c * height + h) * width + w;
          bottom_data_vector.push_back(std::make_pair(bottom_data[data_index], c));
        }
        std::partial_sort(
            bottom_data_vector.begin(), bottom_data_vector.begin() + top_k,
            bottom_data_vector.end(), std::greater<std::pair<Dtype, int> >());

        // check if the true label is in the top k predictions
        label_index = h * width + w;
        const int gt_label = static_cast<int>(bottom_label[label_index]);

        if (ignore_label_.count(gt_label) != 0) {
          // ignore the pixel with this gt_label
          continue;
        } else if (gt_label >= 0 && gt_label < channels) {
          // valid label (i.e., not an ambiguous label such as 255)
          confusion_matrix_.accumulate(gt_label, bottom_data_vector[0].second);
        } else {
          LOG(FATAL) << "Unexpected label " << gt_label << ". num: " << i
              << ". row: " << h << ". col: " << w;
        }
      }
    }

    bottom_data += bottom[0]->offset(1);
    bottom_label += bottom[1]->offset(1);
  }

  /* for debug
  LOG(INFO) << "confusion matrix info:" << confusion_matrix_.numRows()
      << "," << confusion_matrix_.numCols();
  confusion_matrix_.printCounts();
  */

  // we report all the results
  top[0]->mutable_cpu_data()[0] = (Dtype)confusion_matrix_.accuracy();        // overall pixel accuracy
  top[0]->mutable_cpu_data()[1] = (Dtype)confusion_matrix_.avgRecall(false);  // mean per-class recall
  top[0]->mutable_cpu_data()[2] = (Dtype)confusion_matrix_.avgJaccard();      // mean Jaccard index (IoU)
}

INSTANTIATE_CLASS(SegAccuracyLayer);
REGISTER_LAYER_CLASS(SegAccuracy);

}  // namespace caffe
Please credit the original source when reposting: https://www.6miu.com/read-19152.html
