记录一次用spark java写文件到本地(java推荐算法)

xiaoxiao 2025-05-27 20:00

import org.apache.spark.api.java.JavaRDD; import org.apache.spark.ml.evaluation.RegressionEvaluator; import org.apache.spark.ml.recommendation.ALS; import org.apache.spark.ml.recommendation.ALSModel; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; import org.apache.spark.sql.SparkSession; import java.io.Serializable; public class RecommendMovie { //创建一个得分的类并实现Serializable接口 public static class Rating implements Serializable { private int userid; private int movieid; private float rating; private long timestamp; //无参构造方法 public Rating() { } //有参构造方法 public Rating(int userid, int movieid, float rating, long timestamp) { this.userid = userid; this.movieid = movieid; this.rating = rating; this.timestamp = timestamp; } //get方法获取userid和其他 public int getUserid() { return userid; } public int getMovieid() { return movieid; } public float getRating() { return rating; } public long getTimestamp() { return timestamp; } //重写parse方法来将字符串str转换成Rating类型 public static Rating parseRating(String str) { //将传进来的数据进行切分获取其中的四个字段 String[] movieInfo = str.split(","); //如果不是四个字段就抛出错误 if (movieInfo.length != 4) { throw new IllegalArgumentException("Each line must contain 4 fields"); } //将4个字符串字段分别进行转换 int userid = Integer.parseInt(movieInfo[0]); int movieid = Integer.parseInt(movieInfo[1]); float rating = Float.parseFloat(movieInfo[2]); long timestamp = Long.parseLong(movieInfo[3]); //返回一个Rating类型的类,供调用方使用 return new Rating(userid, movieid, rating, timestamp); } } public static void main(String[] args) { //调用spark的ml包进行协同过滤推荐算法 SparkSession spark = SparkSession.builder().master("local[*]").appName("RecommendMovie").getOrCreate(); //将测试数据转换成javaRDD并用Rating进行封装 JavaRDD<Rating> javaRDD = spark.read().textFile("C:\\Users\\13373\\Desktop\\test.data").javaRDD().map(Rating::parseRating); //将类型转换成dataframe用dataFrame中的als进行计算 Dataset<Row> dataFrame = spark.createDataFrame(javaRDD, Rating.class); //进行随机切分,0.8的训练数据和0.2的测试数据 Dataset<Row>[] split = dataFrame.randomSplit(new 
double[]{0.8, 0.2}); //训练数据 Dataset<Row> training = split[0]; //测试数据 Dataset<Row> test = split[1]; /** *获取ALS的实例,设置最大的迭代次数和最小平方差,该对象用来训练已有数据得到模型 * * 即数据建模 */ ALS als = new ALS() .setMaxIter(5)//最大迭代次数 .setRegParam(0.01)//最小平方差 .setUserCol("userid") .setItemCol("movieid") .setRatingCol("rating"); ALSModel fit = als.fit(training); /** * 对模型的测试评估 */ fit.setColdStartStrategy("drop"); Dataset<Row> predictions = fit.transform(test); /** * 回归测试 * 均方根误差 */ RegressionEvaluator evaluator = new RegressionEvaluator() .setMetricName("rmse") .setLabelCol("rating") .setPredictionCol("prediction"); double rmse = evaluator.evaluate(predictions); System.out.println("Root-mean-square error = "+rmse); //得出10个相同用户 Dataset<Row> userCF = fit.recommendForAllUsers(10); //需要将dataset转换成javaRDD再进行存储工作 userCF.toJavaRDD().coalesce(1).saveAsTextFile("C:\\Users\\13373\\Desktop\\itemCF.txt"); //得出10个相同商品 Dataset<Row> itemCF = fit.recommendForAllItems(10); spark.stop(); } }

pom文件:

<?xml version="1.0" encoding="UTF-8"?> <project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd"> <modelVersion>4.0.0</modelVersion> <groupId>Aiads</groupId> <artifactId>morgan13</artifactId> <version>1.0-SNAPSHOT</version> <properties> <java.version>1.8</java.version> <junit.version>4.12</junit.version> <mysql.driver.version>5.1.38</mysql.driver.version> <slf4j.version>1.7.21</slf4j.version> <fastjson.version>1.2.11</fastjson.version> <scala.version>2.11.11</scala.version> <spark.version>2.2.0</spark.version> </properties> <dependencies> <!-- https://mvnrepository.com/artifact/org.apache.spark/spark-mllib --> <dependency> <groupId>org.apache.spark</groupId> <artifactId>spark-mllib_2.11</artifactId> <version>2.2.0</version> <!--<scope>runtime</scope>--> </dependency> <!-- https://mvnrepository.com/artifact/org.apache.spark/spark-core --> <dependency> <groupId>org.apache.spark</groupId> <artifactId>spark-core_2.11</artifactId> <version>2.2.0</version> </dependency> <!-- https://mvnrepository.com/artifact/org.apache.spark/spark-sql --> <dependency> <groupId>org.apache.spark</groupId> <artifactId>spark-sql_2.11</artifactId> <version>2.2.0</version> </dependency> <!-- https://mvnrepository.com/artifact/org.scala-lang/scala-library --> <dependency> <groupId>org.scala-lang</groupId> <artifactId>scala-library</artifactId> <version>2.11.11</version> </dependency> </dependencies> <!--maven中pom文件的java设置--> <build> <plugins> <plugin> <groupId>org.apache.maven.plugins</groupId> <artifactId>maven-compiler-plugin</artifactId> <version>3.7.0</version> <configuration> <source>1.8</source> <target>1.8</target> </configuration> </plugin> </plugins> </build> <profiles> <profile> <id>aiads</id> <properties> <maven.compiler.source>1.8</maven.compiler.source> <maven.compiler.target>1.8</maven.compiler.target> 
<maven.compiler.compilerVersion>1.8</maven.compiler.compilerVersion> <!-- <sonar.host.url>http://sonar.aiads.com</sonar.host.url> <sonar.login>***REDACTED-CREDENTIAL***</sonar.login> --> </properties> <repositories> <repository> <id>nexus</id> <name>local private nexus</name> <url>http://nexus.aiads.com/repository/maven-public</url> <releases> <enabled>true</enabled> </releases> <snapshots> <enabled>true</enabled> </snapshots> </repository> </repositories> <pluginRepositories> <pluginRepository> <id>nexus</id> <name>local private nexus</name> <url>http://nexus.aiads.com/repository/maven-public</url> <releases> <enabled>true</enabled> </releases> <snapshots> <enabled>true</enabled> </snapshots> </pluginRepository> </pluginRepositories> </profile> </profiles> </project>

 

 

转载请注明原文地址: https://www.6miu.com/read-5030811.html

最新回复(0)