Java 实现 BP 神经网络完成 Iris 数据分类

xiaoxiao2021-02-28  120

继了解了 BP 神经网络的原理后,笔者之前用 Java 实现三层的 BP 神经网络完成 Iris 鸢尾花数据集的分类预测,特此记录了实现过程,附源码。

1. Iris 鸢尾花数据集

Iris 也称鸢尾花卉数据集,是一类多重变量分析的数据集,来自 UCI 机器学习库,下载地址请戳这里。通过 sepal length(花萼长度),sepal width (花萼宽度),petal length (花瓣长度),petal width (花瓣宽度)4个属性预测鸢尾花卉属于(Setosa,Versicolour,Virginica)三个种类中的哪一类。

该数据集一共有150条记录,选取 Iris 数据集中的124条数据作为训练集,剩余的26条数据作为测试集。

注:选取训练集时尽量覆盖全面,不要出现只包含一类的情况。

测试集

训练集

2. BP 算法模型的建立

输入层和输出层节点数量分别为数据集的属性数量和类别数量,采用一个隐层,隐层节点数=√(输入节点数+输出节点数)+5求得;激活函数选择单极性S型函数;学习率 η η =0.5;初始权值随机生成,值在-0.5~0.5之间,初始阈值设为0;设置最大训练次数为2000次;误差允许范围:Iris:0.015;动量常数 α α =0.1;输入数据归一化处理:(0.1,0.9)范围内;输出层节点处理,进行one-hot编程:

3. Java 实现代码

一共包含三个类: BPNN.java 、DataUtil.java 、Test.java

BPNN.java

BP 神经网络核心代码以及预测处理代码,注释部分是附加动量项的处理代码:

import java.io.IOException; import java.util.ArrayList; import java.util.Random; class BPNN { // private static int LAYER = 3; // 三层神经网络 private static int NodeNum = 10; // 每层的最多节点数 private static final int ADJUST = 5; // 隐层节点数调节常数 private static final int MaxTrain = 2000; // 最大训练次数 private static final double ACCU = 0.015; // 每次迭代允许的误差 iris:0.015 private double ETA_W = 0.5; // 权值学习效率0.5 private double ETA_T = 0.5; // 阈值学习效率 private double accu; // 附加动量项 //private static final double ETA_A = 0.3; // 动量常数0.1 //private double[][] in_hd_last; // 上一次的权值调整量 //private double[][] hd_out_last; private int in_num; // 输入层节点数 private int hd_num; // 隐层节点数 private int out_num; // 输入出节点数 private ArrayList<ArrayList<Double>> list = new ArrayList<>(); // 输入输出数据 private double[][] in_hd_weight; // BP网络in-hidden突触权值 private double[][] hd_out_weight; // BP网络hidden_out突触权值 private double[] in_hd_th; // BP网络in-hidden阈值 private double[] hd_out_th; // BP网络hidden-out阈值 private double[][] out; // 每个神经元的值经S型函数转化后的输出值,输入层就为原值 private double[][] delta; // delta学习规则中的值 // 获得网络三层中神经元最多的数量 public int GetMaxNum() { return Math.max(Math.max(in_num, hd_num), out_num); } // 设置权值学习率 public void SetEtaW() { ETA_W = 0.5; } // 设置阈值学习率 public void SetEtaT() { ETA_T = 0.5; } // BPNN训练 public void Train(int in_number, int out_number, ArrayList<ArrayList<Double>> arraylist) throws IOException { list = arraylist; in_num = in_number; out_num = out_number; GetNums(in_num, out_num); // 获取输入层、隐层、输出层的节点数 // SetEtaW(); // 设置学习率 // SetEtaT(); InitNetWork(); // 初始化网络的权值和阈值 int datanum = list.size(); // 训练数据的组数 int createsize = GetMaxNum(); // 比较创建存储每一层输出数据的数组 out = new double[3][createsize]; for (int iter = 0; iter < MaxTrain; iter++) { for (int cnd = 0; cnd < datanum; cnd++) { // 第一层输入节点赋值 for (int i = 0; i < in_num; i++) { out[0][i] = list.get(cnd).get(i); // 为输入层节点赋值,其输入与输出相同 } Forward(); // 前向传播 Backward(cnd); // 误差反向传播 } System.out.println("This is the " + (iter + 1) + " th trainning NetWork !"); accu = GetAccu(); System.out.println("All Samples Accuracy is " + accu); if (accu < ACCU) break; } } // 获取输入层、隐层、输出层的节点数,in_number、out_number分别为输入层节点数和输出层节点数 public void GetNums(int in_number, int out_number) { in_num = in_number; out_num = out_number; hd_num = (int) Math.sqrt(in_num + out_num) + ADJUST; if (hd_num > NodeNum) hd_num = NodeNum; // 隐层节点数不能大于最大节点数 } // 初始化网络的权值和阈值 public void InitNetWork() { // 初始化上一次权值量,范围为-0.5-0.5之间 //in_hd_last = new double[in_num][hd_num]; //hd_out_last = new double[hd_num][out_num]; in_hd_weight = new double[in_num][hd_num]; for (int i = 0; i < in_num; i++) for (int j = 0; j < hd_num; j++) { int flag = 1; // 符号标志位(-1或者1) if ((new Random().nextInt(2)) == 1) flag = 1; else flag = -1; in_hd_weight[i][j] = (new Random().nextDouble() / 2) * flag; // 初始化in-hidden的权值 //in_hd_last[i][j] = 0; } hd_out_weight = new double[hd_num][out_num]; for (int i = 0; i < hd_num; i++) for (int j = 0; j < out_num; j++) { int flag = 1; // 符号标志位(-1或者1) if ((new Random().nextInt(2)) == 1) flag = 1; else flag = -1; hd_out_weight[i][j] = (new Random().nextDouble() / 2) * flag; // 初始化hidden-out的权值 //hd_out_last[i][j] = 0; } // 阈值均初始化为0 in_hd_th = new double[hd_num]; for (int k = 0; k < hd_num; k++) in_hd_th[k] = 0; hd_out_th = new double[out_num]; for (int k = 0; k < out_num; k++) hd_out_th[k] = 0; } // 计算单个样本的误差 public double GetError(int cnd) { double ans = 0; for (int i = 0; i < out_num; i++) ans += 0.5 * (out[2][i] - list.get(cnd).get(in_num + i)) * (out[2][i] - list.get(cnd).get(in_num + i)); return ans; } // 计算所有样本的平均精度 public double GetAccu() { double ans = 0; int num = list.size(); for (int i = 0; i < num; i++) { int m = in_num; for (int j = 0; j < m; j++) out[0][j] = list.get(i).get(j); Forward(); int n = out_num; for (int k = 0; k < n; k++) ans += 0.5 * (list.get(i).get(in_num + k) - out[2][k]) * (list.get(i).get(in_num + k) - out[2][k]); } return ans / num; } // 前向传播 public void Forward() { // 计算隐层节点的输出值 for (int j = 0; j < hd_num; j++) { double v = 0; for (int i = 0; i < in_num; i++) v += in_hd_weight[i][j] * out[0][i]; v += in_hd_th[j]; out[1][j] = Sigmoid(v); } // 计算输出层节点的输出值 for (int j = 0; j < out_num; j++) { double v = 0; for (int i = 0; i < hd_num; i++) v += hd_out_weight[i][j] * out[1][i]; v += hd_out_th[j]; out[2][j] = Sigmoid(v); } } // 误差反向传播 public void Backward(int cnd) { CalcDelta(cnd); // 计算权值调整量 UpdateNetWork(); // 更新BP神经网络的权值和阈值 } // 计算delta调整量 public void CalcDelta(int cnd) { int createsize = GetMaxNum(); // 比较创建数组 delta = new double[3][createsize]; // 计算输出层的delta值 for (int i = 0; i < out_num; i++) { delta[2][i] = (list.get(cnd).get(in_num + i) - out[2][i]) * SigmoidDerivative(out[2][i]); } // 计算隐层的delta值 for (int i = 0; i < hd_num; i++) { double t = 0; for (int j = 0; j < out_num; j++) t += hd_out_weight[i][j] * delta[2][j]; delta[1][i] = t * SigmoidDerivative(out[1][i]); } } // 更新BP神经网络的权值和阈值 public void UpdateNetWork() { // 隐含层和输出层之间权值和阀值调整 for (int i = 0; i < hd_num; i++) { for (int j = 0; j < out_num; j++) { hd_out_weight[i][j] += ETA_W * delta[2][j] * out[1][i]; // 未加权值动量项 /* 动量项 * hd_out_weight[i][j] += (ETA_A * hd_out_last[i][j] + ETA_W * delta[2][j] * out[1][i]); hd_out_last[i][j] = ETA_A * * hd_out_last[i][j] + ETA_W delta[2][j] * out[1][i]; */ } } for (int i = 0; i < out_num; i++) hd_out_th[i] += ETA_T * delta[2][i]; // 输入层和隐含层之间权值和阀值调整 for (int i = 0; i < in_num; i++) { for (int j = 0; j < hd_num; j++) { in_hd_weight[i][j] += ETA_W * delta[1][j] * out[0][i]; // 未加权值动量项 /* 动量项 * in_hd_weight[i][j] += (ETA_A * in_hd_last[i][j] + ETA_W * delta[1][j] * out[0][i]); in_hd_last[i][j] = ETA_A * * in_hd_last[i][j] + ETA_W delta[1][j] * out[0][i]; */ } } for (int i = 0; i < hd_num; i++) in_hd_th[i] += ETA_T * delta[1][i]; } // 符号函数sign public int Sign(double x) { if (x > 0) return 1; else if (x < 0) return -1; else return 0; } // 返回最大值 public double Maximum(double x, double y) { if (x >= y) return x; else return y; } // 返回最小值 public double Minimum(double x, double y) { if (x <= y) return x; else return y; } // log-sigmoid函数 public double Sigmoid(double x) { return (double) (1 / (1 + Math.exp(-x))); } // log-sigmoid函数的倒数 public double SigmoidDerivative(double y) { return (double) (y * (1 - y)); } // tan-sigmoid函数 public double TSigmoid(double x) { return (double) ((1 - Math.exp(-x)) / (1 + Math.exp(-x))); } // tan-sigmoid函数的倒数 public double TSigmoidDerivative(double y) { return (double) (1 - (y * y)); } // 分类预测函数 public ArrayList<ArrayList<Double>> ForeCast( ArrayList<ArrayList<Double>> arraylist) { ArrayList<ArrayList<Double>> alloutlist = new ArrayList<>(); ArrayList<Double> outlist = new ArrayList<Double>(); int datanum = arraylist.size(); for (int cnd = 0; cnd < datanum; cnd++) { for (int i = 0; i < in_num; i++) out[0][i] = arraylist.get(cnd).get(i); // 为输入节点赋值 Forward(); for (int i = 0; i < out_num; i++) { if (out[2][i] > 0 && out[2][i] < 0.5) out[2][i] = 0; else if (out[2][i] > 0.5 && out[2][i] < 1) { out[2][i] = 1; } outlist.add(out[2][i]); } alloutlist.add(outlist); outlist = new ArrayList<Double>(); outlist.clear(); } return alloutlist; } }

DataUtil.java

数据处理类,将训练数据和测试数据进行处理。

import java.io.BufferedReader; import java.io.BufferedWriter; import java.io.File; import java.io.FileInputStream; import java.io.FileWriter; import java.io.IOException; import java.io.InputStreamReader; import java.util.ArrayList; class DataUtil { private ArrayList<ArrayList<Double>> alllist = new ArrayList<ArrayList<Double>>(); // 存放所有数据 private ArrayList<String> outlist = new ArrayList<String>(); // 存放输出数据,索引对应每个everylist的输出 private ArrayList<String> checklist = new ArrayList<String>(); //存放测试集的真实输出字符串 private int in_num = 0; private int out_num = 0; // 输入输出数据的个数 private int type_num = 0; // 输出的类型数量 private double[][] nom_data; //归一化输入数据中的最大值和最小值 private int in_data_num = 0; //提前获得输入数据的个数 // 获取输出类型的个数 public int GetTypeNum() { return type_num; } // 设置输出类型的个数 public void SetTypeNum(int type_num) { this.type_num = type_num; } // 获取输入数据的个数 public int GetInNum() { return in_num; } // 获取输出数据的个数 public int GetOutNum() { return out_num; } // 获取所有数据的数组 public ArrayList<ArrayList<Double>> GetList() { return alllist; } // 获取输出为字符串形式的数据 public ArrayList<String> GetOutList() { return outlist; } // 获取输出为字符串形式的数据 public ArrayList<String> GetCheckList() { return checklist; } //返回归一化数据所需最大最小值 public double[][] GetMaxMin(){ return nom_data; } // 读取文件初始化数据 public void ReadFile(String filepath, String sep, int flag) throws Exception { ArrayList<Double> everylist = new ArrayList<Double>(); // 存放每一组输入输出数据 int readflag = flag; // flag=0,train;flag=1,test String encoding = "GBK"; File file = new File(filepath); if (file.isFile() && file.exists()) { // 判断文件是否存在 InputStreamReader read = new InputStreamReader(new FileInputStream( file), encoding);// 考虑到编码格式 BufferedReader bufferedReader = new BufferedReader(read); String lineTxt = null; while ((lineTxt = bufferedReader.readLine()) != null) { int in_number = 0; String splits[] = lineTxt.split(sep); // 按','截取字符串 if (readflag == 0) { for (int i = 0; i < splits.length; i++) try { everylist.add(Normalize(Double.valueOf(splits[i]),nom_data[i][0],nom_data[i][1])); in_number++; } catch (Exception e) { if (!outlist.contains(splits[i])) outlist.add(splits[i]); // 存放字符串形式的输出数据 for (int k = 0; k < type_num; k++) { everylist.add(0.0); } everylist .set(in_number + outlist.indexOf(splits[i]), 1.0); } } else if (readflag == 1) { for (int i = 0; i < splits.length; i++) try { everylist.add(Normalize(Double.valueOf(splits[i]),nom_data[i][0],nom_data[i][1])); in_number++; } catch (Exception e) { checklist.add(splits[i]); // 存放字符串形式的输出数据 } } alllist.add(everylist); // 存放所有数据 in_num = in_number; out_num = type_num; everylist = new ArrayList<Double>(); everylist.clear(); } bufferedReader.close(); } } //向文件写入分类结果 public void WriteFile(String filepath, ArrayList<ArrayList<Double>> list, int in_number, ArrayList<String> resultlist) throws IOException{ File file = new File(filepath); FileWriter fw = null; BufferedWriter writer = null; try { fw = new FileWriter(file); writer = new BufferedWriter(fw); for(int i=0;i<list.size();i++){ for(int j=0;j<in_number;j++) writer.write(list.get(i).get(j)+","); writer.write(resultlist.get(i)); writer.newLine(); } writer.flush(); } catch (IOException e) { e.printStackTrace(); }finally{ writer.close(); fw.close(); } } //学习样本归一化,找到输入样本数据的最大值和最小值 public void NormalizeData(String filepath) throws IOException{ //提前获得输入数据的个数 GetBeforIn(filepath); int flag=1; nom_data = new double[in_data_num][2]; String encoding = "GBK"; File file = new File(filepath); if (file.isFile() && file.exists()) { // 判断文件是否存在 InputStreamReader read = new InputStreamReader(new FileInputStream( file), encoding);// 考虑到编码格式 BufferedReader bufferedReader = new BufferedReader(read); String lineTxt = null; while ((lineTxt = bufferedReader.readLine()) != null) { String splits[] = lineTxt.split(","); // 按','截取字符串 for (int i = 0; i < splits.length-1; i++){ if(flag==1){ nom_data[i][0]=Double.valueOf(splits[i]); nom_data[i][1]=Double.valueOf(splits[i]); } else{ if(Double.valueOf(splits[i])>nom_data[i][0]) nom_data[i][0]=Double.valueOf(splits[i]); if(Double.valueOf(splits[i])<nom_data[i][1]) nom_data[i][1]=Double.valueOf(splits[i]); } } flag=0; } bufferedReader.close(); } } //归一化前获得输入数据的个数 public void GetBeforIn(String filepath) throws IOException{ String encoding = "GBK"; File file = new File(filepath); if (file.isFile() && file.exists()) { // 判断文件是否存在 InputStreamReader read = new InputStreamReader(new FileInputStream( file), encoding);// 考虑到编码格式 //提前获得输入数据的个数 BufferedReader beforeReader = new BufferedReader(read); String beforetext = beforeReader.readLine(); String splits[] = beforetext.split(","); in_data_num = splits.length-1; beforeReader.close(); } } //归一化公式 public double Normalize(double x, double max, double min){ double y = 0.1+0.8*(x-min)/(max-min); return y; } }

Test.java

import java.util.ArrayList; public class Test { public static void main(String args[]) throws Exception { ArrayList<ArrayList<Double>> alllist = new ArrayList<ArrayList<Double>>(); // 存放所有数据 ArrayList<String> outlist = new ArrayList<String>(); // 存放分类的字符串 int in_num = 0, out_num = 0; // 输入输出数据的个数 DataUtil dataUtil = new DataUtil(); // 初始化数据 dataUtil.NormalizeData("E:\\BP_data\\train.txt"); dataUtil.SetTypeNum(3); // 设置输出类型的数量 dataUtil.ReadFile("E:\\BP_data\\train.txt", ",", 0); in_num = dataUtil.GetInNum(); // 获得输入数据的个数 out_num = dataUtil.GetOutNum(); // 获得输出数据的个数(个数代表类型个数) alllist = dataUtil.GetList(); // 获得初始化后的数据 outlist = dataUtil.GetOutList(); System.out.print("分类的类型:"); for(int i =0 ;i<outlist.size();i++) System.out.print(outlist.get(i)+" "); System.out.println(); System.out.println("训练集的数量:"+alllist.size()); BPNN bpnn = new BPNN(); // 训练 System.out.println("Train Start!"); System.out.println("............."); bpnn.Train(in_num, out_num, alllist); System.out.println("Train End!"); // 测试 DataUtil testUtil = new DataUtil(); testUtil.NormalizeData("E:\\BP_data\\test.txt"); testUtil.SetTypeNum(3); // 设置输出类型的数量 testUtil.ReadFile("E:\\BP_data\\test.txt", ",", 1); ArrayList<ArrayList<Double>> testList = new ArrayList<ArrayList<Double>>(); ArrayList<ArrayList<Double>> resultList = new ArrayList<ArrayList<Double>>(); ArrayList<String> normallist = new ArrayList<String>(); // 存放测试集标准的输出字符串 ArrayList<String> resultlist = new ArrayList<String>(); // 存放测试集计算后的输出字符串 double right = 0; // 分类正确的数量 int type_num = 0; // 类型的数量 double all_num = 0; //测试集的数量 type_num = outlist.size(); testList = testUtil.GetList(); // 获取测试数据 normallist = testUtil.GetCheckList(); int errorcount=0; // 分类错误的数量 resultList = bpnn.ForeCast(testList); // 测试 all_num=resultList.size(); for (int i = 0; i < resultList.size(); i++) { String checkString = "unknow"; for (int j = 0; j < type_num; j++) { if(resultList.get(i).get(j)==1.0){ checkString = outlist.get(j); resultlist.add(checkString); } /*else{ resultlist.add(checkString); }*/ } /* if(checkString.equals("unknow")) errorcount++; */ if(checkString.equals(normallist.get(i))) right++; } testUtil.WriteFile("E:\\BP_data\\result.txt",testList,in_num,resultlist); System.out.println("测试集的数量:"+ (new Double(all_num)).intValue()); System.out.println("分类正确的数量:"+(new Double(right)).intValue()); System.out.println("算法的分类正确率为:"+right/all_num); System.out.println("分类结果存储在:E:\\BP_data\\result.txt"); } }

在这里笔者只通过 Java 代码建立了 BP 神经网络的基本模型,实现 Iris 数据集的分类预测,效果如下:

….

其实,也可以用交叉预测去判断模型的分类性能。通过简单的代码可以对 BP 神经网络的数学原理有一个更好的巩固。

转载请注明原文地址: https://www.6miu.com/read-38851.html

最新回复(0)