xgboost参数说明在网上找了很多关于xgboost的文章,基本上90%都是以python在说明的,java的很少,
xgboost参数说明http://blog.csdn.net/zc02051126/article/details/46711047在这篇文章里面说明的很详细,
在java中使用的话,只要:
Map<String, Object> params = new HashMap<String, Object>(); params.put("eta", 1.0); //为了防止过拟合,更新过程中用到的收缩步长。在每次提升计算之后,算法会直接获得新特征的权重 params.put("max_depth", 15);//树最大深度 params.put("silent", 1); //为1的时候不会打印模型迭代的信息,为0可以看到打印的信息 params.put("lambda", 2);//用于逻辑回归的时候L2正则选项 params.put("min_child_weight", 6); // params.put("nthread", 6); //不使用的话系统会默认得到最大的线程数目 params.put("objective", "binary:logistic");//目标函数。关于xgboost数据训练格式,官网DMatrix提供的构造函数主要有三种:
第一种采用的是提供libsvm格式文件所在磁盘路径,官网提供的数据也是这个例子,然后把libsvm格式数据文件转化为DMatrix类,
去看看这个类的源码,也是调用c++底层代码,核心代码还是c++,无论是python、java、scala都只是一个外壳。
第二种采用的是LabeledPoint格式,这也是libsvm格式文件的变种,用这个不太方便,会把数据缓存到一个目录里面去。
第三种采用的是DMatrix.SparseType,这个我还是比较喜欢,最后转化为DMatrix。
其中预测输入都是用的DMatrix类型参数。
说了这么多,关于模型训练、保存就不上代码说明了,下面看看模型预测,用代码说明一下:在git上提供的一个例子上再加了两个方法,
这个方法作用是把一行文本转化为DMatrix类型,以提供模型预测:
package com.meituan.model.xgboost;

import java.io.*;
import java.math.BigDecimal;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.TreeMap;
import java.util.stream.Collectors;

import ml.dmlc.xgboost4j.java.Booster;
import ml.dmlc.xgboost4j.java.DMatrix;
import ml.dmlc.xgboost4j.java.XGBoostError;

import org.ansj.splitWord.analysis.ToAnalysis;
import org.apache.commons.lang.StringUtils;
import org.apache.commons.lang3.ArrayUtils;

import com.meituan.model.libsvm.TFIDF;
import com.meituan.model.libsvm.Terms;
import com.meituan.nlp.util.TextUtil;
import com.meituan.nlp.util.WordUtil;

/**
 * Utility for loading XGBoost training/prediction data.
 *
 * <p>Supports three input forms: dense CSV files ({@link #loadCSVFile}),
 * libsvm-format files parsed into CSR sparse structure ({@link #loadSVMFile}),
 * and a single free-text document turned into a one-row TF-IDF CSR matrix
 * ({@link #getSparseData}) for online prediction via {@link #getClassify}.
 */
public class DataLoader {

    /** Dense row-major matrix: {@code data} holds nrow * ncol values, one label per row. */
    public static class DenseData {
        public float[] labels;
        public float[] data;
        public int nrow;
        public int ncol;
    }

    /**
     * CSR (compressed sparse row) matrix. For row i, the non-zero entries are
     * {@code data[rowHeaders[i] .. rowHeaders[i+1]-1]} with column ids in
     * {@code colIndex} at the same positions.
     */
    public static class CSRSparseData {
        public float[] labels;
        public float[] data;
        public long[] rowHeaders;
        public int[] colIndex;
    }

    /**
     * Loads a comma-separated file where the LAST column of each row is the label
     * and all preceding columns are features.
     *
     * <p>Fixes over the original version: streams are managed with
     * try-with-resources (the original leaked them when parsing threw), and
     * blank lines are skipped (the original's {@code items.length == 0} guard
     * never fired because {@code "".split(",")} yields {@code [""]}, so blank
     * lines crashed with {@code NumberFormatException}).
     *
     * @param filePath path to the UTF-8 CSV file
     * @return parsed dense matrix; {@code ncol} is taken from the first data row
     * @throws IOException if the file cannot be read
     */
    public static DenseData loadCSVFile(String filePath) throws IOException {
        DenseData denseData = new DenseData();
        denseData.nrow = 0;
        denseData.ncol = -1;
        List<Float> tlabels = new ArrayList<>();
        List<Float> tdata = new ArrayList<>();
        try (BufferedReader reader = new BufferedReader(new InputStreamReader(
                new FileInputStream(new File(filePath)), "UTF-8"))) {
            String line;
            while ((line = reader.readLine()) != null) {
                line = line.trim();
                if (line.isEmpty()) {
                    continue; // skip blank lines instead of crashing on Float.valueOf("")
                }
                String[] items = line.split(",");
                denseData.nrow++;
                if (denseData.ncol == -1) {
                    denseData.ncol = items.length - 1;
                }
                tlabels.add(Float.valueOf(items[items.length - 1]));
                for (int i = 0; i < items.length - 1; i++) {
                    tdata.add(Float.valueOf(items[i]));
                }
            }
        }
        denseData.labels = ArrayUtils.toPrimitive(tlabels.toArray(new Float[tlabels.size()]));
        denseData.data = ArrayUtils.toPrimitive(tdata.toArray(new Float[tdata.size()]));
        return denseData;
    }

    /**
     * Loads a libsvm-format file ({@code label idx:val idx:val ...}, space
     * separated) into CSR sparse form.
     *
     * <p>Fixes over the original version: the reader is now closed via
     * try-with-resources (the original never closed it at all — a resource leak
     * on every call), and blank lines are skipped for the same reason as in
     * {@link #loadCSVFile}.
     *
     * @param filePath path to the UTF-8 libsvm file
     * @return parsed CSR matrix with one label per row
     * @throws IOException if the file cannot be read
     */
    public static CSRSparseData loadSVMFile(String filePath) throws IOException {
        CSRSparseData spData = new CSRSparseData();
        List<Float> tlabels = new ArrayList<>();
        List<Float> tdata = new ArrayList<>();
        List<Long> theaders = new ArrayList<>();
        List<Integer> tindex = new ArrayList<>();
        long rowheader = 0;
        theaders.add(rowheader);
        try (BufferedReader reader = new BufferedReader(new InputStreamReader(
                new FileInputStream(new File(filePath)), "UTF-8"))) {
            String line;
            while ((line = reader.readLine()) != null) {
                line = line.trim();
                if (line.isEmpty()) {
                    continue; // skip blank lines (original guard was ineffective)
                }
                String[] items = line.split(" ");
                rowheader += items.length - 1;
                theaders.add(rowheader);
                tlabels.add(Float.valueOf(items[0]));
                for (int i = 1; i < items.length; i++) {
                    String[] tup = items[i].split(":");
                    assert tup.length == 2;
                    tdata.add(Float.valueOf(tup[1]));
                    tindex.add(Integer.valueOf(tup[0]));
                }
            }
        }
        spData.labels = ArrayUtils.toPrimitive(tlabels.toArray(new Float[tlabels.size()]));
        spData.data = ArrayUtils.toPrimitive(tdata.toArray(new Float[tdata.size()]));
        spData.colIndex = ArrayUtils.toPrimitive(tindex.toArray(new Integer[tindex.size()]));
        spData.rowHeaders = ArrayUtils.toPrimitive(theaders.toArray(new Long[theaders.size()]));
        return spData;
    }

    /**
     * Converts one text document into a single-row CSR TF-IDF vector suitable
     * for prediction.
     *
     * <p>Pipeline: normalize text (lowercase, traditional→simplified, synonym
     * replacement), segment with ansj, drop stopwords, count term frequencies,
     * then weight each known term by TF-IDF. Terms absent from {@code termsmap}
     * are ignored. A {@link TreeMap} keeps column ids sorted, as CSR requires.
     *
     * @param content  raw document text; blank input yields {@code null}
     * @param termsmap term dictionary; must contain the special key
     *                 {@code "documentTotal"} holding the corpus document count
     *                 (NOTE(review): a missing "documentTotal" entry would NPE —
     *                 confirm the dictionary always provides it)
     * @return one-row CSR matrix (labels empty), or {@code null} if no usable terms
     */
    public static CSRSparseData getSparseData(String content, Map<String, Terms> termsmap) {
        if (StringUtils.isBlank(content)) {
            return null;
        }
        // Term -> occurrence count after normalization, segmentation and stopword removal.
        Map<String, Long> maps = ToAnalysis
                .parse(WordUtil.replaceAllSynonyms(TextUtil.fan2Jian(WordUtil
                        .replaceAll(content.toLowerCase()))))
                .getTerms()
                .stream()
                .map(x -> x.getName())
                .filter(x -> !WordUtil.isStopword(x))
                .collect(Collectors.groupingBy(p -> p, Collectors.counting()));
        if (maps == null || maps.size() == 0) {
            return null;
        }
        // Total token count of the document (denominator for term frequency).
        int sum = maps.values().stream().mapToLong(Long::longValue).sum() > Integer.MAX_VALUE
                ? Integer.MAX_VALUE
                : (int) maps.values().stream().mapToLong(Long::longValue).sum();
        Map<Integer, Double> treemap = new TreeMap<Integer, Double>();
        for (Entry<String, Long> map : maps.entrySet()) {
            Terms keyword = termsmap.get(map.getKey());
            if (keyword == null) {
                continue; // term not in dictionary: no feature id, skip
            }
            double tf = TFIDF.tf(map.getValue(), sum);
            double idf = TFIDF.idf(termsmap.get("documentTotal").getFreq(), keyword.getFreq());
            treemap.put(keyword.getId(), TFIDF.tfidf(tf, idf));
        }
        if (treemap.size() == 0) {
            return null;
        }
        CSRSparseData spData = new CSRSparseData();
        List<Float> tlabels = new ArrayList<>();
        List<Float> tdata = new ArrayList<>();
        List<Long> theaders = new ArrayList<>();
        List<Integer> tindex = new ArrayList<>();
        // Single row: header pair [0, nnz].
        theaders.add(0L);
        theaders.add((long) treemap.size());
        for (Entry<Integer, Double> map : treemap.entrySet()) {
            // Round-trip through BigDecimal(String) to avoid binary-double artifacts
            // when narrowing to float.
            BigDecimal b = new BigDecimal(Double.toString(map.getValue()));
            tdata.add(b.floatValue());
            tindex.add(Integer.valueOf(map.getKey()));
        }
        spData.labels = ArrayUtils.toPrimitive(tlabels.toArray(new Float[tlabels.size()]));
        spData.data = ArrayUtils.toPrimitive(tdata.toArray(new Float[tdata.size()]));
        spData.colIndex = ArrayUtils.toPrimitive(tindex.toArray(new Integer[tindex.size()]));
        spData.rowHeaders = ArrayUtils.toPrimitive(theaders.toArray(new Long[theaders.size()]));
        return spData;
    }

    /**
     * Scores one document with a trained booster.
     *
     * @param booster  trained XGBoost model
     * @param content  raw document text
     * @param termsmap term dictionary (see {@link #getSparseData})
     * @return predicted probability for the single row, or 0.0 when the text
     *         produced no usable features
     * @throws XGBoostError if DMatrix construction or prediction fails
     */
    public static double getClassify(Booster booster, String content,
            Map<String, Terms> termsmap) throws XGBoostError {
        CSRSparseData spData = getSparseData(content, termsmap);
        if (spData == null) {
            return 0.0;
        }
        DMatrix data = new DMatrix(spData.rowHeaders, spData.colIndex, spData.data,
                DMatrix.SparseType.CSR, 0);
        return booster.predict(data)[0][0];
    }
}
