author    WeichenXu <weichen.xu@databricks.com>  2018-06-04 21:24:35 -0700
committer Xiangrui Meng <meng@databricks.com>    2018-06-04 21:24:35 -0700
commit    e8c1a0c2fdb09a628d9cc925676af870d5a7a946 (patch)
tree      59c77bf7481466ee6007deb7d4ce75b928a5de37
parent    b3417b731d4e323398a0d7ec6e86405f4464f4f9 (diff)
[SPARK-15784] Add Power Iteration Clustering to spark.ml
## What changes were proposed in this pull request?

According to the discussion on JIRA, I rewrote the Power Iteration Clustering API in `spark.ml`.

## How was this patch tested?

Unit tests.

Please review http://spark.apache.org/contributing.html before opening a pull request.

Author: WeichenXu <weichen.xu@databricks.com>

Closes #21493 from WeichenXu123/pic_api.
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/clustering/PowerIterationClustering.scala       157
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/clustering/PowerIterationClusteringSuite.scala  179
2 files changed, 125 insertions(+), 211 deletions(-)
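For context, here is a minimal usage sketch of the API introduced by this commit. The setters, default column names, and output schema come from the diff below; the `spark` session and the toy edge data are assumptions for illustration only.

    import org.apache.spark.ml.clustering.PowerIterationClustering

    // Toy symmetric affinity matrix: for any pair (i, j) with nonzero
    // similarity, only one of (i, j, s) or (j, i, s) needs to appear.
    val affinities = spark.createDataFrame(Seq(
      (0L, 1L, 1.0),
      (1L, 2L, 1.0),
      (3L, 4L, 1.0),
      (4L, 5L, 1.0)
    )).toDF("src", "dst", "weight")

    val assignments = new PowerIterationClustering()
      .setK(2)
      .setMaxIter(20)
      .setInitMode("degree")
      .setWeightCol("weight")
      .assignClusters(affinities)

    // Output schema per the new API: id: Long, cluster: Int.
    assignments.show()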
diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/PowerIterationClustering.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/PowerIterationClustering.scala
index 2c30a1d9aa..1b9a349994 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/PowerIterationClustering.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/PowerIterationClustering.scala
@@ -18,21 +18,20 @@
package org.apache.spark.ml.clustering
import org.apache.spark.annotation.{Experimental, Since}
-import org.apache.spark.ml.Transformer
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared._
import org.apache.spark.ml.util._
import org.apache.spark.mllib.clustering.{PowerIterationClustering => MLlibPowerIterationClustering}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Dataset, Row}
-import org.apache.spark.sql.functions.col
+import org.apache.spark.sql.functions.{col, lit}
import org.apache.spark.sql.types._
/**
* Common params for PowerIterationClustering
*/
private[clustering] trait PowerIterationClusteringParams extends Params with HasMaxIter
- with HasPredictionCol {
+ with HasWeightCol {
/**
* The number of clusters to create (k). Must be > 1. Default: 2.
@@ -66,62 +65,33 @@ private[clustering] trait PowerIterationClusteringParams extends Params with Has
def getInitMode: String = $(initMode)
/**
- * Param for the name of the input column for vertex IDs.
- * Default: "id"
+ * Param for the name of the input column for source vertex IDs.
+ * Default: "src"
* @group param
*/
@Since("2.4.0")
- val idCol = new Param[String](this, "idCol", "Name of the input column for vertex IDs.",
+ val srcCol = new Param[String](this, "srcCol", "Name of the input column for source vertex IDs.",
(value: String) => value.nonEmpty)
- setDefault(idCol, "id")
-
- /** @group getParam */
- @Since("2.4.0")
- def getIdCol: String = getOrDefault(idCol)
-
- /**
- * Param for the name of the input column for neighbors in the adjacency list representation.
- * Default: "neighbors"
- * @group param
- */
- @Since("2.4.0")
- val neighborsCol = new Param[String](this, "neighborsCol",
- "Name of the input column for neighbors in the adjacency list representation.",
- (value: String) => value.nonEmpty)
-
- setDefault(neighborsCol, "neighbors")
-
/** @group getParam */
@Since("2.4.0")
- def getNeighborsCol: String = $(neighborsCol)
+ def getSrcCol: String = getOrDefault(srcCol)
/**
- * Param for the name of the input column for neighbors in the adjacency list representation.
- * Default: "similarities"
+ * Name of the input column for destination vertex IDs.
+ * Default: "dst"
* @group param
*/
@Since("2.4.0")
- val similaritiesCol = new Param[String](this, "similaritiesCol",
- "Name of the input column for neighbors in the adjacency list representation.",
+ val dstCol = new Param[String](this, "dstCol",
+ "Name of the input column for destination vertex IDs.",
(value: String) => value.nonEmpty)
- setDefault(similaritiesCol, "similarities")
-
/** @group getParam */
@Since("2.4.0")
- def getSimilaritiesCol: String = $(similaritiesCol)
+ def getDstCol: String = $(dstCol)
- protected def validateAndTransformSchema(schema: StructType): StructType = {
- SchemaUtils.checkColumnTypes(schema, $(idCol), Seq(IntegerType, LongType))
- SchemaUtils.checkColumnTypes(schema, $(neighborsCol),
- Seq(ArrayType(IntegerType, containsNull = false),
- ArrayType(LongType, containsNull = false)))
- SchemaUtils.checkColumnTypes(schema, $(similaritiesCol),
- Seq(ArrayType(FloatType, containsNull = false),
- ArrayType(DoubleType, containsNull = false)))
- SchemaUtils.appendColumn(schema, $(predictionCol), IntegerType)
- }
+ setDefault(srcCol -> "src", dstCol -> "dst")
}
/**
@@ -131,21 +101,8 @@ private[clustering] trait PowerIterationClusteringParams extends Params with Has
* PIC finds a very low-dimensional embedding of a dataset using truncated power
* iteration on a normalized pair-wise similarity matrix of the data.
*
- * PIC takes an affinity matrix between items (or vertices) as input. An affinity matrix
- * is a symmetric matrix whose entries are non-negative similarities between items.
- * PIC takes this matrix (or graph) as an adjacency matrix. Specifically, each input row includes:
- * - `idCol`: vertex ID
- * - `neighborsCol`: neighbors of vertex in `idCol`
- * - `similaritiesCol`: non-negative weights (similarities) of edges between the vertex
- * in `idCol` and each neighbor in `neighborsCol`
- * PIC returns a cluster assignment for each input vertex. It appends a new column `predictionCol`
- * containing the cluster assignment in `[0,k)` for each row (vertex).
- *
- * Notes:
- * - [[PowerIterationClustering]] is a transformer with an expensive [[transform]] operation.
- * Transform runs the iterative PIC algorithm to cluster the whole input dataset.
- * - Input validation: This validates that similarities are non-negative but does NOT validate
- * that the input matrix is symmetric.
+ * This class is not yet an Estimator/Transformer, use `assignClusters` method to run the
+ * PowerIterationClustering algorithm.
*
* @see <a href=http://en.wikipedia.org/wiki/Spectral_clustering>
* Spectral clustering (Wikipedia)</a>
@@ -154,7 +111,7 @@ private[clustering] trait PowerIterationClusteringParams extends Params with Has
@Experimental
class PowerIterationClustering private[clustering] (
@Since("2.4.0") override val uid: String)
- extends Transformer with PowerIterationClusteringParams with DefaultParamsWritable {
+ extends PowerIterationClusteringParams with DefaultParamsWritable {
setDefault(
k -> 2,
@@ -166,10 +123,6 @@ class PowerIterationClustering private[clustering] (
/** @group setParam */
@Since("2.4.0")
- def setPredictionCol(value: String): this.type = set(predictionCol, value)
-
- /** @group setParam */
- @Since("2.4.0")
def setK(value: Int): this.type = set(k, value)
/** @group expertSetParam */
@@ -182,66 +135,56 @@ class PowerIterationClustering private[clustering] (
/** @group setParam */
@Since("2.4.0")
- def setIdCol(value: String): this.type = set(idCol, value)
+ def setSrcCol(value: String): this.type = set(srcCol, value)
/** @group setParam */
@Since("2.4.0")
- def setNeighborsCol(value: String): this.type = set(neighborsCol, value)
+ def setDstCol(value: String): this.type = set(dstCol, value)
/** @group setParam */
@Since("2.4.0")
- def setSimilaritiesCol(value: String): this.type = set(similaritiesCol, value)
+ def setWeightCol(value: String): this.type = set(weightCol, value)
+ /**
+ * Run the PIC algorithm and returns a cluster assignment for each input vertex.
+ *
+ * @param dataset A dataset with columns src, dst, weight representing the affinity matrix,
+ * which is the matrix A in the PIC paper. Suppose the src column value is i,
+ * the dst column value is j, the weight column value is similarity s,,ij,,
+ * which must be nonnegative. This is a symmetric matrix and hence
+ * s,,ij,, = s,,ji,,. For any (i, j) with nonzero similarity, there should be
+ * either (i, j, s,,ij,,) or (j, i, s,,ji,,) in the input. Rows with i = j are
+ * ignored, because we assume s,,ij,, = 0.0.
+ *
+ * @return A dataset that contains columns of vertex id and the corresponding cluster for the id.
+ * The schema of it will be:
+ * - id: Long
+ * - cluster: Int
+ */
@Since("2.4.0")
- override def transform(dataset: Dataset[_]): DataFrame = {
- transformSchema(dataset.schema, logging = true)
+ def assignClusters(dataset: Dataset[_]): DataFrame = {
+ val w = if (!isDefined(weightCol) || $(weightCol).isEmpty) {
+ lit(1.0)
+ } else {
+ col($(weightCol)).cast(DoubleType)
+ }
- val sparkSession = dataset.sparkSession
- val idColValue = $(idCol)
- val rdd: RDD[(Long, Long, Double)] =
- dataset.select(
- col($(idCol)).cast(LongType),
- col($(neighborsCol)).cast(ArrayType(LongType, containsNull = false)),
- col($(similaritiesCol)).cast(ArrayType(DoubleType, containsNull = false))
- ).rdd.flatMap {
- case Row(id: Long, nbrs: Seq[_], sims: Seq[_]) =>
- require(nbrs.size == sims.size, s"The length of the neighbor ID list must be " +
- s"equal to the the length of the neighbor similarity list. Row for ID " +
- s"$idColValue=$id has neighbor ID list of length ${nbrs.length} but similarity list " +
- s"of length ${sims.length}.")
- nbrs.asInstanceOf[Seq[Long]].zip(sims.asInstanceOf[Seq[Double]]).map {
- case (nbr, similarity) => (id, nbr, similarity)
- }
- }
+ SchemaUtils.checkColumnTypes(dataset.schema, $(srcCol), Seq(IntegerType, LongType))
+ SchemaUtils.checkColumnTypes(dataset.schema, $(dstCol), Seq(IntegerType, LongType))
+ val rdd: RDD[(Long, Long, Double)] = dataset.select(
+ col($(srcCol)).cast(LongType),
+ col($(dstCol)).cast(LongType),
+ w).rdd.map {
+ case Row(src: Long, dst: Long, weight: Double) => (src, dst, weight)
+ }
val algorithm = new MLlibPowerIterationClustering()
.setK($(k))
.setInitializationMode($(initMode))
.setMaxIterations($(maxIter))
val model = algorithm.run(rdd)
- val predictionsRDD: RDD[Row] = model.assignments.map { assignment =>
- Row(assignment.id, assignment.cluster)
- }
-
- val predictionsSchema = StructType(Seq(
- StructField($(idCol), LongType, nullable = false),
- StructField($(predictionCol), IntegerType, nullable = false)))
- val predictions = {
- val uncastPredictions = sparkSession.createDataFrame(predictionsRDD, predictionsSchema)
- dataset.schema($(idCol)).dataType match {
- case _: LongType =>
- uncastPredictions
- case otherType =>
- uncastPredictions.select(col($(idCol)).cast(otherType).alias($(idCol)))
- }
- }
-
- dataset.join(predictions, $(idCol))
- }
-
- @Since("2.4.0")
- override def transformSchema(schema: StructType): StructType = {
- validateAndTransformSchema(schema)
+ import dataset.sparkSession.implicits._
+ model.assignments.toDF
}
@Since("2.4.0")
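A practical note on the input contract documented in the `assignClusters` scaladoc above: for each pair with nonzero similarity, the input should contain either (i, j, s,,ij,,) or (j, i, s,,ji,,), not both. If an edge list already carries both directions, a canonicalization step like this sketch can reduce it to one row per undirected pair. The helper name is hypothetical; it assumes the default src/dst/weight columns and exactly equal weights in both directions.

    import org.apache.spark.sql.DataFrame
    import org.apache.spark.sql.functions.{col, greatest, least}

    // Keep one direction per undirected pair by ordering (src, dst),
    // and drop self-loops, which assignClusters ignores anyway.
    def canonicalizeEdges(edges: DataFrame): DataFrame =
      edges.select(
          least(col("src"), col("dst")).as("src"),
          greatest(col("src"), col("dst")).as("dst"),
          col("weight"))
        .where(col("src") =!= col("dst"))
        .distinct()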
diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/PowerIterationClusteringSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/PowerIterationClusteringSuite.scala
index 65328df17b..b7072728d4 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/clustering/PowerIterationClusteringSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/PowerIterationClusteringSuite.scala
@@ -17,19 +17,19 @@
package org.apache.spark.ml.clustering
-import scala.collection.mutable
-
import org.apache.spark.{SparkException, SparkFunSuite}
import org.apache.spark.ml.util.DefaultReadWriteTest
import org.apache.spark.mllib.util.MLlibTestSparkContext
-import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession}
-import org.apache.spark.sql.functions.col
+import org.apache.spark.sql.{DataFrame, Dataset, SparkSession}
+import org.apache.spark.sql.functions.{col, lit}
import org.apache.spark.sql.types._
class PowerIterationClusteringSuite extends SparkFunSuite
with MLlibTestSparkContext with DefaultReadWriteTest {
+ import testImplicits._
+
@transient var data: Dataset[_] = _
final val r1 = 1.0
final val n1 = 10
@@ -48,10 +48,9 @@ class PowerIterationClusteringSuite extends SparkFunSuite
assert(pic.getK === 2)
assert(pic.getMaxIter === 20)
assert(pic.getInitMode === "random")
- assert(pic.getPredictionCol === "prediction")
- assert(pic.getIdCol === "id")
- assert(pic.getNeighborsCol === "neighbors")
- assert(pic.getSimilaritiesCol === "similarities")
+ assert(pic.getSrcCol === "src")
+ assert(pic.getDstCol === "dst")
+ assert(!pic.isDefined(pic.weightCol))
}
test("parameter validation") {
@@ -62,125 +61,102 @@ class PowerIterationClusteringSuite extends SparkFunSuite
new PowerIterationClustering().setInitMode("no_such_a_mode")
}
intercept[IllegalArgumentException] {
- new PowerIterationClustering().setIdCol("")
+ new PowerIterationClustering().setSrcCol("")
}
intercept[IllegalArgumentException] {
- new PowerIterationClustering().setNeighborsCol("")
- }
- intercept[IllegalArgumentException] {
- new PowerIterationClustering().setSimilaritiesCol("")
+ new PowerIterationClustering().setDstCol("")
}
}
test("power iteration clustering") {
val n = n1 + n2
- val model = new PowerIterationClustering()
+ val assignments = new PowerIterationClustering()
.setK(2)
.setMaxIter(40)
- val result = model.transform(data)
-
- val predictions = Array.fill(2)(mutable.Set.empty[Long])
- result.select("id", "prediction").collect().foreach {
- case Row(id: Long, cluster: Integer) => predictions(cluster) += id
- }
- assert(predictions.toSet == Set((1 until n1).toSet, (n1 until n).toSet))
-
- val result2 = new PowerIterationClustering()
+ .setWeightCol("weight")
+ .assignClusters(data)
+ val localAssignments = assignments
+ .select('id, 'cluster)
+ .as[(Long, Int)].collect().toSet
+ val expectedResult = (0 until n1).map(x => (x, 1)).toSet ++
+ (n1 until n).map(x => (x, 0)).toSet
+ assert(localAssignments === expectedResult)
+
+ val assignments2 = new PowerIterationClustering()
.setK(2)
.setMaxIter(10)
.setInitMode("degree")
- .transform(data)
- val predictions2 = Array.fill(2)(mutable.Set.empty[Long])
- result2.select("id", "prediction").collect().foreach {
- case Row(id: Long, cluster: Integer) => predictions2(cluster) += id
- }
- assert(predictions2.toSet == Set((1 until n1).toSet, (n1 until n).toSet))
+ .setWeightCol("weight")
+ .assignClusters(data)
+ val localAssignments2 = assignments2
+ .select('id, 'cluster)
+ .as[(Long, Int)].collect().toSet
+ assert(localAssignments2 === expectedResult)
}
test("supported input types") {
- val model = new PowerIterationClustering()
+ val pic = new PowerIterationClustering()
.setK(2)
.setMaxIter(1)
+ .setWeightCol("weight")
- def runTest(idType: DataType, neighborType: DataType, similarityType: DataType): Unit = {
+ def runTest(srcType: DataType, dstType: DataType, weightType: DataType): Unit = {
val typedData = data.select(
- col("id").cast(idType).alias("id"),
- col("neighbors").cast(ArrayType(neighborType, containsNull = false)).alias("neighbors"),
- col("similarities").cast(ArrayType(similarityType, containsNull = false))
- .alias("similarities")
+ col("src").cast(srcType).alias("src"),
+ col("dst").cast(dstType).alias("dst"),
+ col("weight").cast(weightType).alias("weight")
)
- model.transform(typedData).collect()
- }
-
- for (idType <- Seq(IntegerType, LongType)) {
- runTest(idType, LongType, DoubleType)
- }
- for (neighborType <- Seq(IntegerType, LongType)) {
- runTest(LongType, neighborType, DoubleType)
- }
- for (similarityType <- Seq(FloatType, DoubleType)) {
- runTest(LongType, LongType, similarityType)
+ pic.assignClusters(typedData).collect()
}
- }
- test("invalid input: wrong types") {
- val model = new PowerIterationClustering()
- .setK(2)
- .setMaxIter(1)
- intercept[IllegalArgumentException] {
- val typedData = data.select(
- col("id").cast(DoubleType).alias("id"),
- col("neighbors"),
- col("similarities")
- )
- model.transform(typedData)
+ for (srcType <- Seq(IntegerType, LongType)) {
+ runTest(srcType, LongType, DoubleType)
}
- intercept[IllegalArgumentException] {
- val typedData = data.select(
- col("id"),
- col("neighbors").cast(ArrayType(DoubleType, containsNull = false)).alias("neighbors"),
- col("similarities")
- )
- model.transform(typedData)
+ for (dstType <- Seq(IntegerType, LongType)) {
+ runTest(LongType, dstType, DoubleType)
}
- intercept[IllegalArgumentException] {
- val typedData = data.select(
- col("id"),
- col("neighbors"),
- col("neighbors").alias("similarities")
- )
- model.transform(typedData)
+ for (weightType <- Seq(FloatType, DoubleType)) {
+ runTest(LongType, LongType, weightType)
}
}
test("invalid input: negative similarity") {
- val model = new PowerIterationClustering()
+ val pic = new PowerIterationClustering()
.setMaxIter(1)
+ .setWeightCol("weight")
val badData = spark.createDataFrame(Seq(
- (0, Array(1), Array(-1.0)),
- (1, Array(0), Array(-1.0))
- )).toDF("id", "neighbors", "similarities")
+ (0, 1, -1.0),
+ (1, 0, -1.0)
+ )).toDF("src", "dst", "weight")
val msg = intercept[SparkException] {
- model.transform(badData)
+ pic.assignClusters(badData)
}.getCause.getMessage
assert(msg.contains("Similarity must be nonnegative"))
}
- test("invalid input: mismatched lengths for neighbor and similarity arrays") {
- val model = new PowerIterationClustering()
- .setMaxIter(1)
- val badData = spark.createDataFrame(Seq(
- (0, Array(1), Array(0.5)),
- (1, Array(0, 2), Array(0.5)),
- (2, Array(1), Array(0.5))
- )).toDF("id", "neighbors", "similarities")
- val msg = intercept[SparkException] {
- model.transform(badData)
- }.getCause.getMessage
- assert(msg.contains("The length of the neighbor ID list must be equal to the the length of " +
- "the neighbor similarity list."))
- assert(msg.contains(s"Row for ID ${model.getIdCol}=1"))
+ test("test default weight") {
+ val dataWithoutWeight = data.sample(0.5, 1L).select('src, 'dst)
+
+ val assignments = new PowerIterationClustering()
+ .setK(2)
+ .setMaxIter(40)
+ .assignClusters(dataWithoutWeight)
+ val localAssignments = assignments
+ .select('id, 'cluster)
+ .as[(Long, Int)].collect().toSet
+
+ val dataWithWeightOne = dataWithoutWeight.withColumn("weight", lit(1.0))
+
+ val assignments2 = new PowerIterationClustering()
+ .setK(2)
+ .setMaxIter(40)
+ .assignClusters(dataWithWeightOne)
+ val localAssignments2 = assignments2
+ .select('id, 'cluster)
+ .as[(Long, Int)].collect().toSet
+
+ assert(localAssignments === localAssignments2)
}
test("read/write") {
@@ -188,10 +164,9 @@ class PowerIterationClusteringSuite extends SparkFunSuite
.setK(4)
.setMaxIter(100)
.setInitMode("degree")
- .setIdCol("test_id")
- .setNeighborsCol("myNeighborsCol")
- .setSimilaritiesCol("mySimilaritiesCol")
- .setPredictionCol("test_prediction")
+ .setSrcCol("src1")
+ .setDstCol("dst1")
+ .setWeightCol("weight")
testDefaultReadWrite(t)
}
}
@@ -222,17 +197,13 @@ object PowerIterationClusteringSuite {
val n = n1 + n2
val points = genCircle(r1, n1) ++ genCircle(r2, n2)
- val rows = for (i <- 1 until n) yield {
- val neighbors = for (j <- 0 until i) yield {
- j.toLong
+ val rows = (for (i <- 1 until n) yield {
+ for (j <- 0 until i) yield {
+ (i.toLong, j.toLong, sim(points(i), points(j)))
}
- val similarities = for (j <- 0 until i) yield {
- sim(points(i), points(j))
- }
- (i.toLong, neighbors.toArray, similarities.toArray)
- }
+ }).flatMap(_.iterator)
- spark.createDataFrame(rows).toDF("id", "neighbors", "similarities")
+ spark.createDataFrame(rows).toDF("src", "dst", "weight")
}
}
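The generator above calls two suite-private helpers that fall outside this hunk. As a hedged reconstruction from the call sites (the exact bodies are not shown in this diff): genCircle(r, n) lays n points evenly on a circle of radius r, and sim computes a Gaussian similarity from the squared distance between two points.

    // Hypothetical reconstruction; not part of this diff.
    private def genCircle(r: Double, n: Int): Array[(Double, Double)] =
      Array.tabulate(n) { i =>
        val theta = 2.0 * math.Pi * i / n
        (r * math.cos(theta), r * math.sin(theta))
      }

    private def sim(p1: (Double, Double), p2: (Double, Double)): Double = {
      val dist2 = (p1._1 - p2._1) * (p1._1 - p2._1) +
        (p1._2 - p2._2) * (p1._2 - p2._2)
      math.exp(-dist2 / 2.0)
    }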