XGBoost for text classification prediction, Java interface


I spent two days over the weekend going from installing XGBoost to using it for text prediction, and this post records the steps. The pipeline: segment the text, remove stop words, and compute tf-idf values; then train the model, save it, load it, and run predictions.
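The tf-idf computation is delegated to a TFIDF helper class that this post doesn't show. Here is a minimal sketch, assuming the standard definitions (tf as term count over document length, idf as the log of total documents over document frequency); the package and method signatures are guesses matched to the calls in the prediction code further below:

package com.meituan.model.libsvm;

public class TFIDF {

    // term frequency: occurrences of the term, normalized by document length
    public static double tf(long termCount, int totalTerms) {
        return (double) termCount / totalTerms;
    }

    // inverse document frequency: log of total documents over documents containing the term
    public static double idf(long documentTotal, long documentFreq) {
        return Math.log((double) documentTotal / documentFreq);
    }

    public static double tfidf(double tf, double idf) {
        return tf * idf;
    }
}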

Training code:

package com.meituan.model.xgboost;

import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;

import org.apache.commons.io.FileUtils;

import ml.dmlc.xgboost4j.java.Booster;
import ml.dmlc.xgboost4j.java.DMatrix;
import ml.dmlc.xgboost4j.java.XGBoost;
import ml.dmlc.xgboost4j.java.XGBoostError;

public class TrainXgboost {

    public static void main(String[] args) throws XGBoostError, IOException {
        // training and test files in LibSVM format (see the example lines below)
        DMatrix trainMat = new DMatrix("file/train.txt");
        DMatrix testMat = new DMatrix("file/test.txt");

        // specify parameters
        Map<String, Object> params = new HashMap<String, Object>();
        params.put("booster", "gbtree");
        params.put("eta", 0.6);      // shrinkage step applied to the new weights after each boosting round, to guard against overfitting
        params.put("max_depth", 22); // maximum tree depth
        params.put("silent", 0);     // 1 suppresses the per-iteration messages, 0 prints them
        params.put("lambda", 2.5);   // L2 regularization term (applies to the logistic objective)
        params.put("min_child_weight", 6);
        // params.put("nthread", 6); // defaults to the maximum available threads if unset
        params.put("objective", "binary:logistic"); // objective function

        // specify watchList
        HashMap<String, DMatrix> watches = new HashMap<String, DMatrix>();
        watches.put("train", trainMat);
        // watches.put("test", testMat);

        // train a booster
        int round = 100;
        Booster booster = XGBoost.train(trainMat, params, round, watches, null, null);
        booster.saveModel("xgboost/xgboost.model"); // the Prediction class below loads this file

        // dump feature importance scores, sorted in descending order
        Map<String, Integer> map = booster.getFeatureScore(null);
        List<Map.Entry<String, Integer>> list =
                new ArrayList<Map.Entry<String, Integer>>(map.entrySet());
        Collections.sort(list, new Comparator<Map.Entry<String, Integer>>() {
            @Override
            public int compare(Entry<String, Integer> o1, Entry<String, Integer> o2) {
                return o2.getValue().compareTo(o1.getValue());
            }
        });
        FileUtils.writeLines(new File("xgboost/keyword.txt"), list);

        float[][] result = booster.predict(testMat);
        System.out.println("length is:" + result.length);
    }
}
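The DMatrix(String) constructor reads a LibSVM-format text file: one example per line, a label followed by sparse feature_index:value pairs. After the tf-idf step, file/train.txt would look roughly like this (the indices and values here are only illustrative):

1 12:0.2817 345:0.0931 1024:0.1408
0 7:0.3342 345:0.0520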

Training is the easy part. Now for the prediction side. Two helper methods (they sit in the DataLoader class that the Prediction class below calls) convert a piece of text into LibSVM sparse form and then into a DMatrix:

public static CSRSparseData getSparseData(String content, Map<String, Terms> termsmap) {
    if (StringUtils.isBlank(content)) {
        return null;
    }
    // segment the text (lowercased, normalized, synonyms replaced),
    // drop stop words, and count how often each term occurs
    Map<String, Long> maps = ToAnalysis
            .parse(WordUtil.replaceAllSynonyms(TextUtil.fan2Jian(WordUtil
                    .replaceAll(content.toLowerCase()))))
            .getTerms()
            .stream()
            .map(x -> x.getName())
            .filter(x -> !WordUtil.isStopword(x))
            .collect(Collectors.groupingBy(p -> p, Collectors.counting()));
    if (maps == null || maps.size() == 0) {
        return null;
    }
    // total token count of this document, used as the tf denominator
    int sum = maps.values().stream()
            .reduce((result, element) -> result = result + element).get()
            .intValue();
    // TreeMap keeps the feature ids sorted, as the CSR layout requires
    Map<Integer, Double> treemap = new TreeMap<Integer, Double>();
    for (Entry<String, Long> map : maps.entrySet()) {
        String key = map.getKey();
        Terms keyword = termsmap.get(key);
        double tf = TFIDF.tf(map.getValue(), sum);
        if (keyword == null) {
            continue; // term is not in the vocabulary built at training time
        }
        int id = keyword.getId();
        double idf = TFIDF.idf(termsmap.get("documentTotal").getFreq(),
                keyword.getFreq());
        double tfidf = TFIDF.tfidf(tf, idf);
        treemap.put(id, tfidf);
    }
    if (treemap.size() == 0) {
        return null;
    }
    // build a one-row CSR matrix: rowHeaders = [0, nnz]
    CSRSparseData spData = new CSRSparseData();
    List<Float> tlabels = new ArrayList<>(); // labels are unused at prediction time
    List<Float> tdata = new ArrayList<>();
    List<Long> theaders = new ArrayList<>();
    List<Integer> tindex = new ArrayList<>();
    theaders.add(0L);
    theaders.add((long) treemap.size());
    for (Entry<Integer, Double> map : treemap.entrySet()) {
        BigDecimal b = new BigDecimal(Double.toString(map.getValue()));
        tdata.add(b.floatValue());
        tindex.add(Integer.valueOf(map.getKey()));
    }
    spData.labels = ArrayUtils.toPrimitive(tlabels.toArray(new Float[tlabels.size()]));
    spData.data = ArrayUtils.toPrimitive(tdata.toArray(new Float[tdata.size()]));
    spData.colIndex = ArrayUtils.toPrimitive(tindex.toArray(new Integer[tindex.size()]));
    spData.rowHeaders = ArrayUtils.toPrimitive(theaders.toArray(new Long[theaders.size()]));
    return spData;
}

public static double getClassification(Booster booster, String content,
        Map<String, Terms> termsmap) throws XGBoostError {
    CSRSparseData spData = getSparseData(content, termsmap);
    if (spData == null) {
        return 0.0;
    }
    // shapeParam 0 lets XGBoost infer the column count from the largest index;
    // passing the vocabulary size here is safer, so the feature dimension
    // always matches the trained model
    DMatrix data = new DMatrix(spData.rowHeaders, spData.colIndex, spData.data,
            DMatrix.SparseType.CSR, 0);
    return booster.predict(data)[0][0];
}
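As a concrete example: if a document's surviving terms map to feature ids 3 and 17 with tf-idf values 0.41 and 0.28, the one-row CSR arrays come out as rowHeaders = [0, 2], colIndex = [3, 17], data = [0.41f, 0.28f]; the second row header is simply the number of non-zero entries in the row.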

Using the model for text prediction:

package com.meituan.model.xgboost;

import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.OutputStreamWriter;
import java.util.Map;

import ml.dmlc.xgboost4j.java.Booster;
import ml.dmlc.xgboost4j.java.XGBoost;
import ml.dmlc.xgboost4j.java.XGBoostError;

import com.meituan.model.libsvm.DocumentTransForm;
import com.meituan.model.libsvm.Terms;

public class Prediction {

    public static void main(String[] args) throws XGBoostError, IOException {
        // load the term -> (id, document frequency) vocabulary and the trained model
        Map<String, Terms> termsmap = DocumentTransForm.readmap("file/model");
        Booster booster = XGBoost.loadModel("xgboost/xgboost.model");
        System.out.println(DataLoader.getClassification(booster, "我们在吃饭", termsmap));
        test(termsmap, booster, "/Users/shuubiasahi/Desktop/测试文件.csv");
    }

    public static void test(Map<String, Terms> termsmap, Booster booster, String path)
            throws IOException, XGBoostError {
        // the input file is tab-separated: flag, ?, category, content
        try (BufferedReader buffer = new BufferedReader(new InputStreamReader(
                     new FileInputStream(path)));
             // label,probability pairs for every scored line
             BufferedWriter bufferwrite = new BufferedWriter(new OutputStreamWriter(
                     new FileOutputStream("xgboost/merge.txt")));
             // only the misclassified lines, for error analysis
             BufferedWriter bufferwriteresult = new BufferedWriter(new OutputStreamWriter(
                     new FileOutputStream("xgboost/result.txt")))) {
            String line;
            while ((line = buffer.readLine()) != null) {
                String[] lines = line.split("\t");
                if (lines.length < 4) {
                    continue; // skip malformed lines
                }
                String label;
                if ("1".equalsIgnoreCase(lines[0]) && "美食".equalsIgnoreCase(lines[2])) {
                    label = "1";
                } else {
                    label = "0";
                }
                String content = lines[3];
                double p = DataLoader.getClassification(booster, content, termsmap);
                if (p > 0) { // 0.0 means the text produced no usable features
                    bufferwrite.write(label + "," + p + "\n");
                    String prString = p > 0.5 ? "1" : "0";
                    if (!label.equals(prString)) {
                        bufferwriteresult.write(label + "\t" + prString + "\t" + p
                                + "\t" + lines[3] + "\n");
                    }
                }
            }
        }
    }
}
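The test method writes one label,probability pair per scored line to xgboost/merge.txt. A minimal sketch (the Evaluate class name is hypothetical, not from the original post) that turns that file into an accuracy number, using the same 0.5 threshold:

package com.meituan.model.xgboost;

import java.io.BufferedReader;
import java.io.FileReader;
import java.io.IOException;

public class Evaluate {
    public static void main(String[] args) throws IOException {
        int total = 0, correct = 0;
        try (BufferedReader reader = new BufferedReader(new FileReader("xgboost/merge.txt"))) {
            String line;
            while ((line = reader.readLine()) != null) {
                // each line is "label,probability", as written by test() above
                String[] parts = line.split(",");
                String label = parts[0];
                String predicted = Double.parseDouble(parts[1]) > 0.5 ? "1" : "0";
                total++;
                if (label.equals(predicted)) {
                    correct++;
                }
            }
        }
        System.out.println("accuracy = " + (double) correct / total);
    }
}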

