author     jinxing <jinxing6042@126.com>          2018-06-05 11:32:42 -0700
committer  Wenchen Fan <wenchen@databricks.com>   2018-06-05 11:32:42 -0700
commit     93df3cd03503fca7745141fbd2676b8bf70fe92f (patch)
tree       185233b104d48b37cfacd839ad8748748d0883e1
parent     2c2a86b5d5be6f77ee72d16f990b39ae59f479b9 (diff)
[SPARK-22384][SQL] Refine partition pruning when attribute is wrapped in Cast
## What changes were proposed in this pull request?

The SQL below fetches all partitions from the metastore, which puts a heavy burden on the metastore:
```
CREATE TABLE `partition_test`(`col` int) PARTITIONED BY (`pt` byte)
SELECT * FROM partition_test WHERE CAST(pt AS INT)=1
```
The reason is that the analyzed attribute `pt` is wrapped in `Cast`, so `HiveShim` fails to generate a proper partition filter. This PR proposes to take `Cast` into consideration when generating the partition filter.

## How was this patch tested?

Test added. This PR also changes `HiveClientSuite` to build analyzed expressions directly instead of parsing filter strings.

Author: jinxing <jinxing6042@126.com>

Closes #19602 from jinxing64/SPARK-22384.
 sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala        |  23
 sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientSuite.scala | 102
 2 files changed, 86 insertions(+), 39 deletions(-)
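
For orientation before the diff: the heart of the fix is the `ExtractAttribute` extractor added to `HiveShim.scala`. The sketch below restates it as a self-contained snippet, assuming Spark 2.3-era Catalyst APIs (`Cast` carries a `timeZoneId` field and `Cast.mayTruncate(from, to)` exists, as the patch itself relies on); the imports are added here for illustration.

```scala
import org.apache.spark.sql.catalyst.expressions.{Attribute, Cast, Expression}

// Recursively unwrap an attribute from Cast nodes, but only when the cast
// cannot truncate (e.g. CAST(int_col AS BIGINT) is safe, CAST(long_col AS INT)
// is not), so the bare attribute name can still be pushed into the metastore
// partition filter instead of silently fetching every partition.
object ExtractAttribute {
  def unapply(expr: Expression): Option[Attribute] = expr match {
    case attr: Attribute => Some(attr)
    case Cast(child, dt, _) if !Cast.mayTruncate(child.dataType, dt) => unapply(child)
    case _ => None
  }
}
```

With this in place, a predicate like `CAST(pt AS INT) = 1` matches the same `convert` cases as a bare `pt = 1`, so `HiveShim` can still hand the metastore a usable filter string.
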
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala
index 948ba542b5..130e258e78 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala
@@ -24,7 +24,6 @@ import java.util.{ArrayList => JArrayList, List => JList, Locale, Map => JMap, S
import java.util.concurrent.TimeUnit
import scala.collection.JavaConverters._
-import scala.util.Try
import scala.util.control.NonFatal
import org.apache.hadoop.fs.Path
@@ -657,17 +656,31 @@ private[client] class Shim_v0_13 extends Shim_v0_12 {
val useAdvanced = SQLConf.get.advancedPartitionPredicatePushdownEnabled
+ object ExtractAttribute {
+ def unapply(expr: Expression): Option[Attribute] = {
+ expr match {
+ case attr: Attribute => Some(attr)
+ case Cast(child, dt, _) if !Cast.mayTruncate(child.dataType, dt) => unapply(child)
+ case _ => None
+ }
+ }
+ }
+
def convert(expr: Expression): Option[String] = expr match {
- case In(NonVarcharAttribute(name), ExtractableLiterals(values)) if useAdvanced =>
+ case In(ExtractAttribute(NonVarcharAttribute(name)), ExtractableLiterals(values))
+ if useAdvanced =>
Some(convertInToOr(name, values))
- case InSet(NonVarcharAttribute(name), ExtractableValues(values)) if useAdvanced =>
+ case InSet(ExtractAttribute(NonVarcharAttribute(name)), ExtractableValues(values))
+ if useAdvanced =>
Some(convertInToOr(name, values))
- case op @ SpecialBinaryComparison(NonVarcharAttribute(name), ExtractableLiteral(value)) =>
+ case op @ SpecialBinaryComparison(
+ ExtractAttribute(NonVarcharAttribute(name)), ExtractableLiteral(value)) =>
Some(s"$name ${op.symbol} $value")
- case op @ SpecialBinaryComparison(ExtractableLiteral(value), NonVarcharAttribute(name)) =>
+ case op @ SpecialBinaryComparison(
+ ExtractableLiteral(value), ExtractAttribute(NonVarcharAttribute(name))) =>
Some(s"$value ${op.symbol} $name")
case And(expr1, expr2) if useAdvanced =>
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientSuite.scala
index f991352b20..55275f6b37 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientSuite.scala
@@ -22,13 +22,13 @@ import org.apache.hadoop.hive.conf.HiveConf
import org.scalatest.BeforeAndAfterAll
import org.apache.spark.sql.catalyst.catalog._
-import org.apache.spark.sql.catalyst.expressions.{EmptyRow, Expression, In, InSet}
-import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.types.LongType
// TODO: Refactor this to `HivePartitionFilteringSuite`
class HiveClientSuite(version: String)
extends HiveVersionSuite(version) with BeforeAndAfterAll {
- import CatalystSqlParser._
private val tryDirectSqlKey = HiveConf.ConfVars.METASTORE_TRY_DIRECT_SQL.varname
@@ -46,8 +46,7 @@ class HiveClientSuite(version: String)
val hadoopConf = new Configuration()
hadoopConf.setBoolean(tryDirectSqlKey, tryDirectSql)
val client = buildClient(hadoopConf)
- client
- .runSqlHive("CREATE TABLE test (value INT) PARTITIONED BY (ds INT, h INT, chunk STRING)")
+ client.runSqlHive("CREATE TABLE test (value INT) PARTITIONED BY (ds INT, h INT, chunk STRING)")
val partitions =
for {
@@ -66,6 +65,15 @@ class HiveClientSuite(version: String)
client
}
+ private def attr(name: String): Attribute = {
+ client.getTable("default", "test").partitionSchema.fields
+ .find(field => field.name.equals(name)) match {
+ case Some(field) => AttributeReference(field.name, field.dataType)()
+ case None =>
+ fail(s"Illegal name of partition attribute: $name")
+ }
+ }
+
override def beforeAll() {
super.beforeAll()
client = init(true)
@@ -74,7 +82,7 @@ class HiveClientSuite(version: String)
test(s"getPartitionsByFilter returns all partitions when $tryDirectSqlKey=false") {
val client = init(false)
val filteredPartitions = client.getPartitionsByFilter(client.getTable("default", "test"),
- Seq(parseExpression("ds=20170101")))
+ Seq(attr("ds") === 20170101))
assert(filteredPartitions.size == testPartitionCount)
}
@@ -82,7 +90,7 @@ class HiveClientSuite(version: String)
test("getPartitionsByFilter: ds<=>20170101") {
// Should return all partitions where <=> is not supported
testMetastorePartitionFiltering(
- "ds<=>20170101",
+ attr("ds") <=> 20170101,
20170101 to 20170103,
0 to 23,
"aa" :: "ab" :: "ba" :: "bb" :: Nil)
@@ -90,7 +98,7 @@ class HiveClientSuite(version: String)
test("getPartitionsByFilter: ds=20170101") {
testMetastorePartitionFiltering(
- "ds=20170101",
+ attr("ds") === 20170101,
20170101 to 20170101,
0 to 23,
"aa" :: "ab" :: "ba" :: "bb" :: Nil)
@@ -100,7 +108,7 @@ class HiveClientSuite(version: String)
// Should return all partitions where h=0 because getPartitionsByFilter does not support
// comparisons to non-literal values
testMetastorePartitionFiltering(
- "ds=(20170101 + 1) and h=0",
+ attr("ds") === (Literal(20170101) + 1) && attr("h") === 0,
20170101 to 20170103,
0 to 0,
"aa" :: "ab" :: "ba" :: "bb" :: Nil)
@@ -108,7 +116,7 @@ class HiveClientSuite(version: String)
test("getPartitionsByFilter: chunk='aa'") {
testMetastorePartitionFiltering(
- "chunk='aa'",
+ attr("chunk") === "aa",
20170101 to 20170103,
0 to 23,
"aa" :: Nil)
@@ -116,7 +124,7 @@ class HiveClientSuite(version: String)
test("getPartitionsByFilter: 20170101=ds") {
testMetastorePartitionFiltering(
- "20170101=ds",
+ Literal(20170101) === attr("ds"),
20170101 to 20170101,
0 to 23,
"aa" :: "ab" :: "ba" :: "bb" :: Nil)
@@ -124,7 +132,15 @@ class HiveClientSuite(version: String)
test("getPartitionsByFilter: ds=20170101 and h=10") {
testMetastorePartitionFiltering(
- "ds=20170101 and h=10",
+ attr("ds") === 20170101 && attr("h") === 10,
+ 20170101 to 20170101,
+ 10 to 10,
+ "aa" :: "ab" :: "ba" :: "bb" :: Nil)
+ }
+
+ test("getPartitionsByFilter: chunk in cast(ds as long)=20170101L") {
+ testMetastorePartitionFiltering(
+ attr("ds").cast(LongType) === 20170101L && attr("h") === 10,
20170101 to 20170101,
10 to 10,
"aa" :: "ab" :: "ba" :: "bb" :: Nil)
@@ -132,7 +148,7 @@ class HiveClientSuite(version: String)
test("getPartitionsByFilter: ds=20170101 or ds=20170102") {
testMetastorePartitionFiltering(
- "ds=20170101 or ds=20170102",
+ attr("ds") === 20170101 || attr("ds") === 20170102,
20170101 to 20170102,
0 to 23,
"aa" :: "ab" :: "ba" :: "bb" :: Nil)
@@ -140,7 +156,15 @@ class HiveClientSuite(version: String)
test("getPartitionsByFilter: ds in (20170102, 20170103) (using IN expression)") {
testMetastorePartitionFiltering(
- "ds in (20170102, 20170103)",
+ attr("ds").in(20170102, 20170103),
+ 20170102 to 20170103,
+ 0 to 23,
+ "aa" :: "ab" :: "ba" :: "bb" :: Nil)
+ }
+
+ test("getPartitionsByFilter: cast(ds as long) in (20170102L, 20170103L) (using IN expression)") {
+ testMetastorePartitionFiltering(
+ attr("ds").cast(LongType).in(20170102L, 20170103L),
20170102 to 20170103,
0 to 23,
"aa" :: "ab" :: "ba" :: "bb" :: Nil)
@@ -148,7 +172,19 @@ class HiveClientSuite(version: String)
test("getPartitionsByFilter: ds in (20170102, 20170103) (using INSET expression)") {
testMetastorePartitionFiltering(
- "ds in (20170102, 20170103)",
+ attr("ds").in(20170102, 20170103),
+ 20170102 to 20170103,
+ 0 to 23,
+ "aa" :: "ab" :: "ba" :: "bb" :: Nil, {
+ case expr @ In(v, list) if expr.inSetConvertible =>
+ InSet(v, list.map(_.eval(EmptyRow)).toSet)
+ })
+ }
+
+ test("getPartitionsByFilter: cast(ds as long) in (20170102L, 20170103L) (using INSET expression)")
+ {
+ testMetastorePartitionFiltering(
+ attr("ds").cast(LongType).in(20170102L, 20170103L),
20170102 to 20170103,
0 to 23,
"aa" :: "ab" :: "ba" :: "bb" :: Nil, {
@@ -159,7 +195,7 @@ class HiveClientSuite(version: String)
test("getPartitionsByFilter: chunk in ('ab', 'ba') (using IN expression)") {
testMetastorePartitionFiltering(
- "chunk in ('ab', 'ba')",
+ attr("chunk").in("ab", "ba"),
20170101 to 20170103,
0 to 23,
"ab" :: "ba" :: Nil)
@@ -167,7 +203,7 @@ class HiveClientSuite(version: String)
test("getPartitionsByFilter: chunk in ('ab', 'ba') (using INSET expression)") {
testMetastorePartitionFiltering(
- "chunk in ('ab', 'ba')",
+ attr("chunk").in("ab", "ba"),
20170101 to 20170103,
0 to 23,
"ab" :: "ba" :: Nil, {
@@ -179,26 +215,24 @@ class HiveClientSuite(version: String)
test("getPartitionsByFilter: (ds=20170101 and h>=8) or (ds=20170102 and h<8)") {
val day1 = (20170101 to 20170101, 8 to 23, Seq("aa", "ab", "ba", "bb"))
val day2 = (20170102 to 20170102, 0 to 7, Seq("aa", "ab", "ba", "bb"))
- testMetastorePartitionFiltering(
- "(ds=20170101 and h>=8) or (ds=20170102 and h<8)",
- day1 :: day2 :: Nil)
+ testMetastorePartitionFiltering((attr("ds") === 20170101 && attr("h") >= 8) ||
+ (attr("ds") === 20170102 && attr("h") < 8), day1 :: day2 :: Nil)
}
test("getPartitionsByFilter: (ds=20170101 and h>=8) or (ds=20170102 and h<(7+1))") {
val day1 = (20170101 to 20170101, 8 to 23, Seq("aa", "ab", "ba", "bb"))
// Day 2 should include all hours because we can't build a filter for h<(7+1)
val day2 = (20170102 to 20170102, 0 to 23, Seq("aa", "ab", "ba", "bb"))
- testMetastorePartitionFiltering(
- "(ds=20170101 and h>=8) or (ds=20170102 and h<(7+1))",
- day1 :: day2 :: Nil)
+ testMetastorePartitionFiltering((attr("ds") === 20170101 && attr("h") >= 8) ||
+ (attr("ds") === 20170102 && attr("h") < (Literal(7) + 1)), day1 :: day2 :: Nil)
}
test("getPartitionsByFilter: " +
"chunk in ('ab', 'ba') and ((ds=20170101 and h>=8) or (ds=20170102 and h<8))") {
val day1 = (20170101 to 20170101, 8 to 23, Seq("ab", "ba"))
val day2 = (20170102 to 20170102, 0 to 7, Seq("ab", "ba"))
- testMetastorePartitionFiltering(
- "chunk in ('ab', 'ba') and ((ds=20170101 and h>=8) or (ds=20170102 and h<8))",
+ testMetastorePartitionFiltering(attr("chunk").in("ab", "ba") &&
+ ((attr("ds") === 20170101 && attr("h") >= 8) || (attr("ds") === 20170102 && attr("h") < 8)),
day1 :: day2 :: Nil)
}
@@ -207,41 +241,41 @@ class HiveClientSuite(version: String)
}
private def testMetastorePartitionFiltering(
- filterString: String,
+ filterExpr: Expression,
expectedDs: Seq[Int],
expectedH: Seq[Int],
expectedChunks: Seq[String]): Unit = {
testMetastorePartitionFiltering(
- filterString,
+ filterExpr,
(expectedDs, expectedH, expectedChunks) :: Nil,
identity)
}
private def testMetastorePartitionFiltering(
- filterString: String,
+ filterExpr: Expression,
expectedDs: Seq[Int],
expectedH: Seq[Int],
expectedChunks: Seq[String],
transform: Expression => Expression): Unit = {
testMetastorePartitionFiltering(
- filterString,
+ filterExpr,
(expectedDs, expectedH, expectedChunks) :: Nil,
- identity)
+ transform)
}
private def testMetastorePartitionFiltering(
- filterString: String,
+ filterExpr: Expression,
expectedPartitionCubes: Seq[(Seq[Int], Seq[Int], Seq[String])]): Unit = {
- testMetastorePartitionFiltering(filterString, expectedPartitionCubes, identity)
+ testMetastorePartitionFiltering(filterExpr, expectedPartitionCubes, identity)
}
private def testMetastorePartitionFiltering(
- filterString: String,
+ filterExpr: Expression,
expectedPartitionCubes: Seq[(Seq[Int], Seq[Int], Seq[String])],
transform: Expression => Expression): Unit = {
val filteredPartitions = client.getPartitionsByFilter(client.getTable("default", "test"),
Seq(
- transform(parseExpression(filterString))
+ transform(filterExpr)
))
val expectedPartitionCount = expectedPartitionCubes.map {