import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.ml.evaluation.RegressionEvaluator;
import org.apache.spark.ml.recommendation.ALS;
import org.apache.spark.ml.recommendation.ALSModel;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import java.io.Serializable;
public class RecommendMovie {

    // A rating bean; it implements Serializable so Spark can serialize it across the cluster
    public static class Rating implements Serializable {
        private int userid;
        private int movieid;
        private float rating;
        private long timestamp;

        // No-arg constructor (required for the bean-to-DataFrame conversion)
        public Rating() {
        }

        // Full constructor
        public Rating(int userid, int movieid, float rating, long timestamp) {
            this.userid = userid;
            this.movieid = movieid;
            this.rating = rating;
            this.timestamp = timestamp;
        }

        // Getters for userid and the other fields
        public int getUserid() {
            return userid;
        }

        public int getMovieid() {
            return movieid;
        }

        public float getRating() {
            return rating;
        }

        public long getTimestamp() {
            return timestamp;
        }

        // Static factory method that converts a line of text into a Rating
        public static Rating parseRating(String str) {
            // Split the incoming line to get the four fields
            String[] movieInfo = str.split(",");
            // Reject lines that do not have exactly four fields
            if (movieInfo.length != 4) {
                throw new IllegalArgumentException("Each line must contain 4 fields");
            }
            // Convert the four string fields to their target types
            int userid = Integer.parseInt(movieInfo[0]);
            int movieid = Integer.parseInt(movieInfo[1]);
            float rating = Float.parseFloat(movieInfo[2]);
            long timestamp = Long.parseLong(movieInfo[3]);
            // Return a Rating instance to the caller
            return new Rating(userid, movieid, rating, timestamp);
        }
    }
    public static void main(String[] args) {
        // Use the collaborative-filtering (ALS) implementation from Spark's ml package
        SparkSession spark = SparkSession.builder().master("local[*]").appName("RecommendMovie").getOrCreate();
        // Read the test data as a JavaRDD and wrap each line in a Rating
        JavaRDD<Rating> javaRDD = spark.read().textFile("C:\\Users\\13373\\Desktop\\test.data").javaRDD().map(Rating::parseRating);
        // Convert the RDD into a DataFrame so the DataFrame-based ALS can train on it
        Dataset<Row> dataFrame = spark.createDataFrame(javaRDD, Rating.class);
        // Random split: 80% training data, 20% test data
        Dataset<Row>[] split = dataFrame.randomSplit(new double[]{0.8, 0.2});
        // Training data
        Dataset<Row> training = split[0];
        // Test data
        Dataset<Row> test = split[1];
        /**
         * Build an ALS instance, setting the maximum number of iterations and the
         * regularization parameter; fitting it to the training data produces the model.
         */
        ALS als = new ALS()
                .setMaxIter(5)      // maximum number of iterations
                .setRegParam(0.01)  // regularization parameter
                .setUserCol("userid")
                .setItemCol("movieid")
                .setRatingCol("rating");
        ALSModel fit = als.fit(training);
        /**
         * Evaluate the model on the held-out test data.
         */
        // "drop" removes rows for users/movies unseen during training, so RMSE is not NaN
        fit.setColdStartStrategy("drop");
        Dataset<Row> predictions = fit.transform(test);
        /**
         * Regression metric: root-mean-square error between actual and predicted ratings.
         */
        RegressionEvaluator evaluator = new RegressionEvaluator()
                .setMetricName("rmse")
                .setLabelCol("rating")
                .setPredictionCol("prediction");
        double rmse = evaluator.evaluate(predictions);
        System.out.println("Root-mean-square error = " + rmse);
        // Top 10 movie recommendations for every user
        Dataset<Row> userCF = fit.recommendForAllUsers(10);
        // Convert the Dataset to a JavaRDD before saving it as text
        userCF.toJavaRDD().coalesce(1).saveAsTextFile("C:\\Users\\13373\\Desktop\\userCF.txt");
        // Top 10 user recommendations for every movie, saved the same way
        Dataset<Row> itemCF = fit.recommendForAllItems(10);
        itemCF.toJavaRDD().coalesce(1).saveAsTextFile("C:\\Users\\13373\\Desktop\\itemCF.txt");
        spark.stop();
    }
}
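For reference, parseRating above expects each line of test.data to be a comma-separated record in the order userid,movieid,rating,timestamp. A hypothetical sample (illustrative values in the style of the MovieLens small dataset, not the actual file contents):

1,31,2.5,1260759144
1,1029,3.0,1260759179
2,10,4.0,835355493

Any line without exactly four fields makes parseRating throw an IllegalArgumentException once the RDD is materialized.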
The pom.xml file:
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
    <modelVersion>4.0.0</modelVersion>
    <groupId>Aiads</groupId>
    <artifactId>morgan13</artifactId>
    <version>1.0-SNAPSHOT</version>

    <properties>
        <java.version>1.8</java.version>
        <junit.version>4.12</junit.version>
        <mysql.driver.version>5.1.38</mysql.driver.version>
        <slf4j.version>1.7.21</slf4j.version>
        <fastjson.version>1.2.11</fastjson.version>
        <scala.version>2.11.11</scala.version>
        <spark.version>2.2.0</spark.version>
    </properties>

    <dependencies>
        <!-- https://mvnrepository.com/artifact/org.apache.spark/spark-mllib -->
        <dependency>
            <groupId>org.apache.spark</groupId>
            <artifactId>spark-mllib_2.11</artifactId>
            <version>${spark.version}</version>
            <!--<scope>runtime</scope>-->
        </dependency>
        <!-- https://mvnrepository.com/artifact/org.apache.spark/spark-core -->
        <dependency>
            <groupId>org.apache.spark</groupId>
            <artifactId>spark-core_2.11</artifactId>
            <version>${spark.version}</version>
        </dependency>
        <!-- https://mvnrepository.com/artifact/org.apache.spark/spark-sql -->
        <dependency>
            <groupId>org.apache.spark</groupId>
            <artifactId>spark-sql_2.11</artifactId>
            <version>${spark.version}</version>
        </dependency>
        <!-- https://mvnrepository.com/artifact/org.scala-lang/scala-library -->
        <dependency>
            <groupId>org.scala-lang</groupId>
            <artifactId>scala-library</artifactId>
            <version>${scala.version}</version>
        </dependency>
    </dependencies>

    <!-- Java compiler settings for this pom -->
    <build>
        <plugins>
            <plugin>
                <groupId>org.apache.maven.plugins</groupId>
                <artifactId>maven-compiler-plugin</artifactId>
                <version>3.7.0</version>
                <configuration>
                    <source>1.8</source>
                    <target>1.8</target>
                </configuration>
            </plugin>
        </plugins>
    </build>

    <profiles>
        <profile>
            <id>aiads</id>
            <properties>
                <maven.compiler.source>1.8</maven.compiler.source>
                <maven.compiler.target>1.8</maven.compiler.target>
                <maven.compiler.compilerVersion>1.8</maven.compiler.compilerVersion>
            </properties>
            <repositories>
                <repository>
                    <id>nexus</id>
                    <name>local private nexus</name>
                    <url>http://nexus.aiads.com/repository/maven-public</url>
                    <releases>
                        <enabled>true</enabled>
                    </releases>
                    <snapshots>
                        <enabled>true</enabled>
                    </snapshots>
                </repository>
            </repositories>
            <pluginRepositories>
                <pluginRepository>
                    <id>nexus</id>
                    <name>local private nexus</name>
                    <url>http://nexus.aiads.com/repository/maven-public</url>
                    <releases>
                        <enabled>true</enabled>
                    </releases>
                    <snapshots>
                        <enabled>true</enabled>
                    </snapshots>
                </pluginRepository>
            </pluginRepositories>
        </profile>
    </profiles>
</project>
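One side note on the output step: saveAsTextFile writes each Row's toString, which is hard to parse back later. A minimal alternative sketch using the Dataset writer, replacing the two saveAsTextFile calls inside main (the output paths are illustrative, not from the original code):

// Hypothetical alternative: persist the recommendations as JSON so the nested
// arrays of (movieid, rating) structs stay machine-readable.
userCF.coalesce(1).write().mode("overwrite").json("C:\\Users\\13373\\Desktop\\userCF_json");
itemCF.coalesce(1).write().mode("overwrite").json("C:\\Users\\13373\\Desktop\\itemCF_json");

Also note that all three Spark artifacts carry the _2.11 suffix, so the scala-library dependency must stay on a 2.11.x release; mixing Scala binary versions fails at runtime.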