aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
author: aokolnychyi <anton.okolnychyi@sap.com> 2018-06-04 13:28:16 -0700
committer: Wenchen Fan <wenchen@databricks.com> 2018-06-04 13:28:16 -0700
commit: 7297ae04d87b6e3d48b747a7c1d53687fcc3971c (patch)
tree: 2e05237759475e37754eafdd0ef2d39b55fe4cf9
parent: 0be5aa27460f87b5627f9de16ec25b09368d205a (diff)
[SPARK-21896][SQL] Fix StackOverflow caused by window functions inside aggregate functions
## What changes were proposed in this pull request?

This PR explicitly prohibits window functions inside aggregate functions. Currently, such a query causes a StackOverflowError during analysis. See PR #19193 for previous discussion.

## How was this patch tested?

This PR comes with a dedicated unit test.

Author: aokolnychyi <anton.okolnychyi@sap.com>

Closes #21473 from aokolnychyi/fix-stackoverflow-window-funcs.
-rw-r--r-- sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala | 10
-rw-r--r-- sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala | 34
2 files changed, 39 insertions, 5 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 3eaa9ecf5d..f9947d1fa6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -1744,10 +1744,10 @@ class Analyzer(
* it into the plan tree.
*/
object ExtractWindowExpressions extends Rule[LogicalPlan] {
- private def hasWindowFunction(projectList: Seq[NamedExpression]): Boolean =
- projectList.exists(hasWindowFunction)
+ private def hasWindowFunction(exprs: Seq[Expression]): Boolean =
+ exprs.exists(hasWindowFunction)
- private def hasWindowFunction(expr: NamedExpression): Boolean = {
+ private def hasWindowFunction(expr: Expression): Boolean = {
expr.find {
case window: WindowExpression => true
case _ => false
@@ -1830,6 +1830,10 @@ class Analyzer(
seenWindowAggregates += newAgg
WindowExpression(newAgg, spec)
+ case AggregateExpression(aggFunc, _, _, _) if hasWindowFunction(aggFunc.children) =>
+ failAnalysis("It is not allowed to use a window function inside an aggregate " +
+ "function. Please use the inner window function in a sub-query.")
+
// Extracts AggregateExpression. For example, for SUM(x) - Sum(y) OVER (...),
// we need to extract SUM(x).
case agg: AggregateExpression if !seenWindowAggregates.contains(agg) =>
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
index 96c28961e5..f495a949eb 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
@@ -19,8 +19,8 @@ package org.apache.spark.sql
import scala.util.Random
-import org.apache.spark.sql.catalyst.expressions.{Alias, Literal}
-import org.apache.spark.sql.catalyst.expressions.aggregate.Count
+import org.scalatest.Matchers.the
+
import org.apache.spark.sql.execution.WholeStageCodegenExec
import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec, SortAggregateExec}
import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
@@ -687,4 +687,34 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext {
}
}
}
+
+ test("SPARK-21896: Window functions inside aggregate functions") {
+ def checkWindowError(df: => DataFrame): Unit = {
+ val thrownException = the [AnalysisException] thrownBy {
+ df.queryExecution.analyzed
+ }
+ assert(thrownException.message.contains("not allowed to use a window function"))
+ }
+
+ checkWindowError(testData2.select(min(avg('b).over(Window.partitionBy('a)))))
+ checkWindowError(testData2.agg(sum('b), max(rank().over(Window.orderBy('a)))))
+ checkWindowError(testData2.groupBy('a).agg(sum('b), max(rank().over(Window.orderBy('b)))))
+ checkWindowError(testData2.groupBy('a).agg(max(sum(sum('b)).over(Window.orderBy('a)))))
+ checkWindowError(
+ testData2.groupBy('a).agg(sum('b).as("s"), max(count("*").over())).where('s === 3))
+ checkAnswer(
+ testData2.groupBy('a).agg(max('b), sum('b).as("s"), count("*").over()).where('s === 3),
+ Row(1, 2, 3, 3) :: Row(2, 2, 3, 3) :: Row(3, 2, 3, 3) :: Nil)
+
+ checkWindowError(sql("SELECT MIN(AVG(b) OVER(PARTITION BY a)) FROM testData2"))
+ checkWindowError(sql("SELECT SUM(b), MAX(RANK() OVER(ORDER BY a)) FROM testData2"))
+ checkWindowError(sql("SELECT SUM(b), MAX(RANK() OVER(ORDER BY b)) FROM testData2 GROUP BY a"))
+ checkWindowError(sql("SELECT MAX(SUM(SUM(b)) OVER(ORDER BY a)) FROM testData2 GROUP BY a"))
+ checkWindowError(
+ sql("SELECT MAX(RANK() OVER(ORDER BY b)) FROM testData2 GROUP BY a HAVING SUM(b) = 3"))
+ checkAnswer(
+ sql("SELECT a, MAX(b), RANK() OVER(ORDER BY a) FROM testData2 GROUP BY a HAVING SUM(b) = 3"),
+ Row(1, 2, 1) :: Row(2, 2, 2) :: Row(3, 2, 3) :: Nil)
+ }
+
}