使用 GraphFrame 的 shortestPaths API 求最短路径

xiaoxiao2021-02-27  181

GraphFrame 的 shortestPaths API 可以计算节点到节点的最短路径(按经过的边数计算),但不能计算带权重的最短路径。shortestPaths 只给出最短路径的长度,要得到路径上经过的节点,还需要再利用 bfs 方法或 find 方法求出;下面的代码使用的是 bfs 方法。

代码如下

import java.util.ArrayList; import java.util.Arrays; import java.util.List; import org.apache.spark.SparkConf; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.sql.DataFrame; import org.apache.spark.sql.Row; import org.apache.spark.sql.RowFactory; import org.apache.spark.sql.SQLContext; import org.apache.spark.sql.catalyst.expressions.GenericRow; import org.apache.spark.sql.types.DataTypes; import org.apache.spark.sql.types.StructField; import org.apache.spark.sql.types.StructType; import org.graphframes.GraphFrame; import scala.Option; import scala.collection.Map; /**  *   */ public class GraphFrameShorhPaths { public static void main( String[] args ) { SparkConf conf = new SparkConf( ).setAppName( "Short Paths" ).setMaster( "local" ); JavaSparkContext ctx = new JavaSparkContext( conf ); SQLContext sqlCtx = SQLContext.getOrCreate( ctx.sc( ) ); List<StructField> vList = new ArrayList<StructField>( ); vList.add( DataTypes.createStructField( "id", DataTypes.LongType, false ) ); vList.add( DataTypes.createStructField( "name", DataTypes.StringType, true ) ); StructType vType = DataTypes.createStructType( vList ); List<StructField> eList = new ArrayList<StructField>(); eList.add( DataTypes.createStructField( "src", DataTypes.LongType, false ) ); eList.add( DataTypes.createStructField( "dst", DataTypes.LongType, false ) ); eList.add( DataTypes.createStructField( "weight", DataTypes.DoubleType, true ) ); StructType eType = DataTypes.createStructType( eList ); JavaRDD<Row> verticeRow = ctx.parallelize( Arrays.asList(  RowFactory.create( 1L,"a" ), RowFactory.create( 2L,"b" ), RowFactory.create( 3L,"c" ), RowFactory.create( 4L,"d" ), RowFactory.create( 5L,"e" ) ) ); JavaRDD<Row> edgeRow = ctx.parallelize( Arrays.asList(  RowFactory.create( 1L,2L,10.0 ), RowFactory.create( 2L,3L,20.0 ), RowFactory.create( 2L,4L,30.0 ), RowFactory.create( 4L,5L,90.0 ), RowFactory.create( 1L,4L,15.0 )) ); GraphFrame frame = new 
GraphFrame( sqlCtx.createDataFrame( verticeRow, vType ), sqlCtx.createDataFrame( edgeRow, eType ) ); ArrayList<Object> lamd = new ArrayList<Object>(); lamd.addAll( Arrays.asList( 1L,2L,3L,4L,5L ) ); DataFrame shortPathData = frame.shortestPaths( ).landmarks( lamd ).run( ); List<Long> ids = BFS( frame, shortPathData, 1L, 5L ); System.out.println( ids ); ctx.stop( ); } private static int getShortPathLenght(DataFrame shortPathData, long from, long to) { Row row = shortPathData.filter( "id = " + from ).collectAsList( ).get( 0 ); Map map  = row.getMap( 2 ); Option option = map.get( to ); if (!option.isDefined( )) { return -1; } return (int)option.get( ); } private static List<Long> BFS(GraphFrame frame, DataFrame shortPathData, long  from, long to) { List<Long> retValue = new ArrayList<Long>(); int lenght = (int)getShortPathLenght( shortPathData, from, to ); if (lenght <= 0 ) { return retValue; } DataFrame pathData = frame.bfs( ).fromExpr( "id = " + from ).toExpr( "id = " + to ).maxPathLength( lenght ).run( ); long count = pathData.columns( ).length; Row row = pathData.collectAsList( ).get( 0 ); for (int i=0; i<count; i=i+2) { retValue.add( ((GenericRow)row.getAs( i )).getLong( 0 ) ); } return retValue; } }

转载请注明原文地址: https://www.6miu.com/read-12572.html

最新回复(0)