Neural Network Example: Handwritten Digit Recognition

xiaoxiao2021-02-28  109

[**This example is adapted from the ex3 programming exercise of the Stanford Machine Learning course**](http://download.csdn.net/detail/the_lastest/9893091)

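This example classifies a dataset of handwritten digits (0-9). The end result is the same as in the previous example; the two differ only in how they are implemented.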

In the previous part of this exercise, you implemented multi-class logistic regression to recognize handwritten digits. However, logistic regression cannot form more complex hypotheses as it is only a linear classifier. You could add more features (such as polynomial features) to logistic regression, but that can be very expensive to train.

1.Model Representation

The handwritten digit images in this example are $20 \times 20$ pixels, so each image has $20\times20=400$ features. In other words, the input layer has 400 activation units, or 401 once the extra bias unit is added. The model diagram of the whole neural network is shown below:
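As a rough illustration of the unit count, here is a minimal Python sketch (the exercise itself uses Octave). The helper name `to_input_vector`, the all-zero `image`, and the row-by-row flattening order are assumptions made for illustration; a real pipeline would have to flatten in whatever order the dataset uses.

```python
def to_input_vector(image):
    """Flatten a 2-D grid of pixel intensities and prepend the bias unit (1.0)."""
    features = [pixel for row in image for pixel in row]
    return [1.0] + features

# Hypothetical all-zero 20x20 image, used only to check the sizes.
image = [[0.0] * 20 for _ in range(20)]
a1 = to_input_vector(image)

print(len(a1) - 1)  # number of pixel features
print(len(a1))      # number of input units including the bias unit
```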

2.Predict

In this example, the weights for the first and second layers, i.e. the parameters $\Theta_1$ ($25\times401$) and $\Theta_2$ ($10\times26$), are already provided, so we only need to compute the output by forward propagation.

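The training data X is a $5000 \times 400$ matrix. To make the computation easier to follow, we first take an arbitrary row $x_{(1\times400)}$ as an example. From the mathematical definition of a neural network (point 3), we have:

$$
\begin{align*}
X &= [\mathrm{ones}(m, 1)\ \ X]; && \text{add the bias column, so each row becomes } 1\times401\\
a_1 &= x';\\
z_2 &= \Theta_1 a_1;\\
a_2 &= \mathrm{sigmoid}(z_2);\\
a_2 &= [1; a_2]; && \text{add the bias term for } a_2\text{; } \color{red}{\text{note that this is not } a(0,0)=1 \text{, nor } a(1,1)=0}\\
z_3 &= \Theta_2 a_2;\\
a_3 &= \mathrm{sigmoid}(z_3);
\end{align*}
$$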
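The steps above can be sketched for a single example in plain Python. This is only an illustration under stated assumptions: the helper names (`sigmoid`, `mat_vec`, `forward`) and the tiny weight matrices are hypothetical, chosen so the arithmetic can be checked by hand, and are not the $25\times401$ and $10\times26$ matrices provided with the exercise.

```python
import math

def sigmoid(z):
    """Logistic function g(z) = 1 / (1 + e^(-z))."""
    return 1.0 / (1.0 + math.exp(-z))

def mat_vec(theta, a):
    """Multiply a matrix (list of rows) by a column vector (list)."""
    return [sum(w * x for w, x in zip(row, a)) for row in theta]

def forward(theta1, theta2, x):
    """Forward propagation through a three-layer network for one example."""
    a1 = [1.0] + list(x)                    # prepend the input bias unit
    z2 = mat_vec(theta1, a1)
    a2 = [1.0] + [sigmoid(z) for z in z2]   # prepend the hidden-layer bias unit
    z3 = mat_vec(theta2, a2)
    return [sigmoid(z) for z in z3]         # a3: one value per output unit

# Hypothetical toy weights: 2 inputs, 2 hidden units, 3 outputs.
theta1 = [[0.0, 1.0, -1.0],
          [0.5, -0.5, 0.5]]
theta2 = [[0.0, 1.0, 1.0],
          [1.0, -1.0, 0.0],
          [-1.0, 0.0, 2.0]]

a3 = forward(theta1, theta2, [1.0, 0.0])
print(a3)
```

With these toy weights the third output has $z_3 = -1 + 2 \cdot 0.5 = 0$, so its activation is exactly 0.5, which gives a quick sanity check on the bias handling.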

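At this point, a single image has passed through the three-layer network, which yields the ten output-layer probabilities for that image. Picking the output with the largest probability tells us which handwritten digit the image shows.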

The following illustrates the computation more concretely:

As in the earlier One-vs-all example, $g(z_i^{(3)})$ is the probability that the handwritten digit is the digit $i$ (where 0 is mapped to 10).
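Because 0 is stored under label 10, a predicted label has to be mapped back before it is shown as a digit. A minimal Python sketch, assuming a hypothetical helper named `label_to_digit` that is not part of the exercise code:

```python
def label_to_digit(label):
    """Map a class label in 1..10 to the digit it stands for (10 means 0)."""
    if not 1 <= label <= 10:
        raise ValueError("label must be between 1 and 10")
    return label % 10

print([label_to_digit(k) for k in range(1, 11)])
```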

    % Loop 5000 times to predict the digit for every image
    % (X already includes the bias column, so each row has 401 entries)
    for i = 1:m
        a1 = X(i,:)';       % 401 by 1
        z2 = Theta1 * a1;   % 25 by 401 * 401 by 1
        a2 = sigmoid(z2);   % 25 by 1
        a2 = [1; a2];       % column vector, 26 by 1
        z3 = Theta2 * a2;   % 10 by 26 * 26 by 1
        a3 = sigmoid(z3);   % 10 by 1
        [temp, p(i)] = max(a3);
    end

Here, in [temp p(i)] = max(a3), temp holds the largest probability and p(i) holds the digit it corresponds to. For example, if $a_3=[0.21,0.11,0.04,0.51,0.34,0.66,0.71,0.88,0.17,0.32]^T$, then $temp = 0.88$ and $p(i)=8$.
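For readers more familiar with Python, the two-output form of Octave's max can be imitated as below. The helper name `max_with_index` is hypothetical, and the values are the example vector from the text.

```python
def max_with_index(values):
    """Return (largest value, its 1-based position), like [m, i] = max(v) in Octave."""
    best = max(range(len(values)), key=lambda k: values[k])  # first maximum wins on ties
    return values[best], best + 1

a3 = [0.21, 0.11, 0.04, 0.51, 0.34, 0.66, 0.71, 0.88, 0.17, 0.32]
temp, p = max_with_index(a3)
print(temp, p)  # 0.88 8
```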

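    test = X(3454,:);
    [temp, pp] = max(predict(Theta1, Theta2, test))
    y(3454,1)    % compare against the known label
    % Output:
    temp = 6
    pp = 1
    ans = 6
    % temp = 6 cannot be a sigmoid output, so predict evidently returns the
    % predicted label here: temp is that label (6), which matches y(3454,1),
    % and pp = 1 is only the position of the maximum in a one-element result,
    % not a probability.

    % Vectorized form: all examples are computed at once, without a loop
    % (X already includes the bias column)
    a1 = X';
    z2 = Theta1 * a1;
    a2 = sigmoid(z2);
    a2 = [ones(1,m); a2];
    z3 = Theta2 * a2;
    a3 = sigmoid(z3);
    [temp, p] = max(a3);
    p = p(:);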
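The vectorized form can also be sketched in plain Python as matrix products over all examples at once. Everything here is an assumption for illustration: the helper names (`matmul`, `predict_batch`), the toy weights, and the three toy examples are hypothetical, and unlike the Octave snippet this sketch adds the input bias row itself. Nested lists are used only to keep the example dependency-free; a matrix library would be the practical choice.

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def matmul(a, b):
    """Plain matrix product of two matrices stored as lists of rows."""
    cols = list(zip(*b))
    return [[sum(x * y for x, y in zip(row, col)) for col in cols] for row in a]

def predict_batch(theta1, theta2, X):
    """Predict 1-based labels for all rows of X (X has no bias column)."""
    m = len(X)
    a1 = [[1.0] * m] + [list(col) for col in zip(*X)]   # transpose, bias row on top
    a2 = [[sigmoid(z) for z in row] for row in matmul(theta1, a1)]
    a2 = [[1.0] * m] + a2                               # hidden-layer bias row
    a3 = [[sigmoid(z) for z in row] for row in matmul(theta2, a2)]
    # Column-wise argmax, reported as 1-based labels as Octave's max would
    return [max(range(len(a3)), key=lambda k: a3[k][j]) + 1 for j in range(m)]

# Hypothetical toy network: 2 inputs, 2 hidden units, 3 output classes.
theta1 = [[-5.0, 10.0, -10.0],
          [-5.0, -10.0, 10.0]]
theta2 = [[-5.0, 10.0, 0.0],
          [-5.0, 0.0, 10.0],
          [5.0, -10.0, -10.0]]
X = [[1.0, 0.0],
     [0.0, 1.0],
     [0.0, 0.0]]

labels = predict_batch(theta1, theta2, X)
print(labels)  # [1, 2, 3]
```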
When reposting, please cite the original address: https://www.6miu.com/read-27488.html
