aboutsummaryrefslogtreecommitdiff
path: root/exec/java-exec/src
diff options
context:
space:
mode:
authorAman Sinha <asinha@maprtech.com>2014-08-27 13:32:03 -0700
committerJacques Nadeau <jacques@apache.org>2014-08-31 10:27:26 -0700
commita0d3906b8ed6dc598ec23b55ca9542180111e910 (patch)
treeaa54b727196d2cf7dd352adf3c0bde0af9cab077 /exec/java-exec/src
parentb9e384dc9a8889e5cb1a9ac05b48c6b99096fe73 (diff)
DRILL-1171: Create Drill's implementation of ReduceAggregatesRule, including a new CastHigh function.
DRILL-1342: Fix nullability handling of aggregate functions.
Diffstat (limited to 'exec/java-exec/src')
-rw-r--r--exec/java-exec/src/main/codegen/config.fmpp3
-rw-r--r--exec/java-exec/src/main/codegen/data/AggrTypes1.tdd30
-rw-r--r--exec/java-exec/src/main/codegen/data/AggrTypes2.tdd8
-rw-r--r--exec/java-exec/src/main/codegen/data/AggrTypes3.tdd80
-rw-r--r--exec/java-exec/src/main/codegen/data/CastHigh.tdd28
-rw-r--r--exec/java-exec/src/main/codegen/data/DecimalAggrTypes2.tdd2
-rw-r--r--exec/java-exec/src/main/codegen/data/SumZero.tdd20
-rw-r--r--exec/java-exec/src/main/codegen/templates/AggrTypeFunctions1.java33
-rw-r--r--exec/java-exec/src/main/codegen/templates/AggrTypeFunctions2.java24
-rw-r--r--exec/java-exec/src/main/codegen/templates/AggrTypeFunctions3.java33
-rw-r--r--exec/java-exec/src/main/codegen/templates/CastHigh.java64
-rw-r--r--exec/java-exec/src/main/codegen/templates/SumZeroAggr.java18
-rw-r--r--exec/java-exec/src/main/java/org/apache/drill/exec/planner/common/DrillAggregateRelBase.java19
-rw-r--r--exec/java-exec/src/main/java/org/apache/drill/exec/planner/logical/DrillReduceAggregatesRule.java602
-rw-r--r--exec/java-exec/src/main/java/org/apache/drill/exec/planner/logical/DrillRuleSets.java4
-rw-r--r--exec/java-exec/src/main/java/org/apache/drill/exec/planner/physical/AggPrelBase.java4
-rw-r--r--exec/java-exec/src/main/java/org/apache/drill/exec/planner/physical/AggPruleBase.java3
17 files changed, 873 insertions, 102 deletions
diff --git a/exec/java-exec/src/main/codegen/config.fmpp b/exec/java-exec/src/main/codegen/config.fmpp
index ff6135dff..aff6240c0 100644
--- a/exec/java-exec/src/main/codegen/config.fmpp
+++ b/exec/java-exec/src/main/codegen/config.fmpp
@@ -36,7 +36,8 @@ data: {
intervalNumericTypes: tdd(../data/IntervalNumericTypes.tdd),
extract: tdd(../data/ExtractTypes.tdd),
sumzero: tdd(../data/SumZero.tdd),
- numericTypes: tdd(../data/NumericTypes.tdd)
+ numericTypes: tdd(../data/NumericTypes.tdd),
+ casthigh: tdd(../data/CastHigh.tdd)
}
freemarkerLinks: {
includes: includes/
diff --git a/exec/java-exec/src/main/codegen/data/AggrTypes1.tdd b/exec/java-exec/src/main/codegen/data/AggrTypes1.tdd
index 1bac07ee3..812c289e6 100644
--- a/exec/java-exec/src/main/codegen/data/AggrTypes1.tdd
+++ b/exec/java-exec/src/main/codegen/data/AggrTypes1.tdd
@@ -20,13 +20,13 @@
{inputType: "Bit", outputType: "Bit", runningType: "Bit", major: "Numeric"},
{inputType: "Int", outputType: "Int", runningType: "Int", major: "Numeric"},
{inputType: "BigInt", outputType: "BigInt", runningType: "BigInt", major: "Numeric"},
- {inputType: "NullableBit", outputType: "Bit", runningType: "Bit", major: "Numeric"},
- {inputType: "NullableInt", outputType: "Int", runningType: "Int", major: "Numeric"},
- {inputType: "NullableBigInt", outputType: "BigInt", runningType: "BigInt", major: "Numeric"},
+ {inputType: "NullableBit", outputType: "NullableBit", runningType: "Bit", major: "Numeric"},
+ {inputType: "NullableInt", outputType: "NullableInt", runningType: "Int", major: "Numeric"},
+ {inputType: "NullableBigInt", outputType: "NullableBigInt", runningType: "BigInt", major: "Numeric"},
{inputType: "Float4", outputType: "Float4", runningType: "Float4", major: "Numeric"},
{inputType: "Float8", outputType: "Float8", runningType: "Float8", major: "Numeric"},
- {inputType: "NullableFloat4", outputType: "Float4", runningType: "Float4", major: "Numeric"},
- {inputType: "NullableFloat8", outputType: "Float8", runningType: "Float8", major: "Numeric"},
+ {inputType: "NullableFloat4", outputType: "NullableFloat4", runningType: "Float4", major: "Numeric"},
+ {inputType: "NullableFloat8", outputType: "NullableFloat8", runningType: "Float8", major: "Numeric"},
{inputType: "Date", outputType: "Date", runningType: "Date", major: "Date", initialValue: "Long.MAX_VALUE"},
{inputType: "NullableDate", outputType: "Date", runningType: "Date", major: "Date", initialValue: "Long.MAX_VALUE"},
{inputType: "TimeStamp", outputType: "TimeStamp", runningType: "TimeStamp", major: "Date", initialValue: "Long.MAX_VALUE"},
@@ -51,13 +51,13 @@
{inputType: "Bit", outputType: "Bit", runningType: "Bit", major: "Numeric"},
{inputType: "Int", outputType: "Int", runningType: "Int", major: "Numeric"},
{inputType: "BigInt", outputType: "BigInt", runningType: "BigInt", major: "Numeric"},
- {inputType: "NullableBit", outputType: "Bit", runningType: "Bit", major: "Numeric"},
- {inputType: "NullableInt", outputType: "Int", runningType: "Int", major: "Numeric"},
- {inputType: "NullableBigInt", outputType: "BigInt", runningType: "BigInt", major: "Numeric"},
+ {inputType: "NullableBit", outputType: "NullableBit", runningType: "Bit", major: "Numeric"},
+ {inputType: "NullableInt", outputType: "NullableInt", runningType: "Int", major: "Numeric"},
+ {inputType: "NullableBigInt", outputType: "NullableBigInt", runningType: "BigInt", major: "Numeric"},
{inputType: "Float4", outputType: "Float4", runningType: "Float4", major: "Numeric"},
{inputType: "Float8", outputType: "Float8", runningType: "Float8", major: "Numeric"},
- {inputType: "NullableFloat4", outputType: "Float4", runningType: "Float4", major: "Numeric"},
- {inputType: "NullableFloat8", outputType: "Float8", runningType: "Float8", major: "Numeric"},
+ {inputType: "NullableFloat4", outputType: "NullableFloat4", runningType: "Float4", major: "Numeric"},
+ {inputType: "NullableFloat8", outputType: "NullableFloat8", runningType: "Float8", major: "Numeric"},
{inputType: "Date", outputType: "Date", runningType: "Date", major: "Date", initialValue: "Long.MIN_VALUE"},
{inputType: "NullableDate", outputType: "Date", runningType: "Date", major: "Date", initialValue: "Long.MIN_VALUE"},
{inputType: "TimeStamp", outputType: "TimeStamp", runningType: "TimeStamp", major: "Date", initialValue: "Long.MIN_VALUE"},
@@ -82,13 +82,13 @@
{inputType: "Bit", outputType: "Bit", runningType: "Bit", major: "Numeric"},
{inputType: "Int", outputType: "BigInt", runningType: "BigInt", major: "Numeric"},
{inputType: "BigInt", outputType: "BigInt", runningType: "BigInt", major: "Numeric"},
- {inputType: "NullableBit", outputType: "Bit", runningType: "Bit", major: "Numeric"},
- {inputType: "NullableInt", outputType: "BigInt", runningType: "BigInt", major: "Numeric"},
- {inputType: "NullableBigInt", outputType: "BigInt", runningType: "BigInt", major: "Numeric"},
+ {inputType: "NullableBit", outputType: "NullableBit", runningType: "Bit", major: "Numeric"},
+ {inputType: "NullableInt", outputType: "NullableBigInt", runningType: "BigInt", major: "Numeric"},
+ {inputType: "NullableBigInt", outputType: "NullableBigInt", runningType: "BigInt", major: "Numeric"},
{inputType: "Float4", outputType: "Float8", runningType: "Float8", major: "Numeric"},
{inputType: "Float8", outputType: "Float8", runningType: "Float8", major: "Numeric"},
- {inputType: "NullableFloat4", outputType: "Float8", runningType: "Float8", major: "Numeric"},
- {inputType: "NullableFloat8", outputType: "Float8", runningType: "Float8", major: "Numeric"},
+ {inputType: "NullableFloat4", outputType: "NullableFloat8", runningType: "Float8", major: "Numeric"},
+ {inputType: "NullableFloat8", outputType: "NullableFloat8", runningType: "Float8", major: "Numeric"},
{inputType: "IntervalDay", outputType: "IntervalDay", runningType: "IntervalDay", major: "Date", initialValue: "0"},
{inputType: "NullableIntervalDay", outputType: "IntervalDay", runningType: "IntervalDay", major: "Date", initialValue: "0"},
{inputType: "IntervalYear", outputType: "IntervalYear", runningType: "IntervalYear", major: "Date", initialValue: "0"},
diff --git a/exec/java-exec/src/main/codegen/data/AggrTypes2.tdd b/exec/java-exec/src/main/codegen/data/AggrTypes2.tdd
index c6655afd2..ee64dafa8 100644
--- a/exec/java-exec/src/main/codegen/data/AggrTypes2.tdd
+++ b/exec/java-exec/src/main/codegen/data/AggrTypes2.tdd
@@ -19,12 +19,12 @@
{className: "Avg", funcName: "avg", types: [
{inputType: "Int", outputType: "Float8", sumRunningType: "BigInt", countRunningType: "BigInt", major: "Numeric"},
{inputType: "BigInt", outputType: "Float8", sumRunningType: "BigInt", countRunningType: "BigInt", major: "Numeric"},
- {inputType: "NullableInt", outputType: "Float8", sumRunningType: "BigInt", countRunningType: "BigInt", major: "Numeric"},
- {inputType: "NullableBigInt", outputType: "Float8", sumRunningType: "BigInt", countRunningType: "BigInt", major: "Numeric"},
+ {inputType: "NullableInt", outputType: "NullableFloat8", sumRunningType: "BigInt", countRunningType: "BigInt", major: "Numeric"},
+ {inputType: "NullableBigInt", outputType: "NullableFloat8", sumRunningType: "BigInt", countRunningType: "BigInt", major: "Numeric"},
{inputType: "Float4", outputType: "Float8", sumRunningType: "Float8", countRunningType: "BigInt", major: "Numeric"},
{inputType: "Float8", outputType: "Float8", sumRunningType: "Float8", countRunningType: "BigInt", major: "Numeric"},
- {inputType: "NullableFloat4", outputType: "Float8", sumRunningType: "Float8", countRunningType: "BigInt", major: "Numeric"},
- {inputType: "NullableFloat8", outputType: "Float8", sumRunningType: "Float8", countRunningType: "BigInt", major: "Numeric"},
+ {inputType: "NullableFloat4", outputType: "NullableFloat8", sumRunningType: "Float8", countRunningType: "BigInt", major: "Numeric"},
+ {inputType: "NullableFloat8", outputType: "NullableFloat8", sumRunningType: "Float8", countRunningType: "BigInt", major: "Numeric"},
{inputType: "IntervalDay", outputType: "Interval", sumRunningType: "BigInt", countRunningType: "BigInt", major: "Date"},
{inputType: "NullableIntervalDay", outputType: "Interval", sumRunningType: "BigInt", countRunningType: "BigInt", major: "Date"},
{inputType: "IntervalYear", outputType: "Interval", sumRunningType: "BigInt", countRunningType: "BigInt", major: "Date"},
diff --git a/exec/java-exec/src/main/codegen/data/AggrTypes3.tdd b/exec/java-exec/src/main/codegen/data/AggrTypes3.tdd
index 05acc0327..0c3a3588f 100644
--- a/exec/java-exec/src/main/codegen/data/AggrTypes3.tdd
+++ b/exec/java-exec/src/main/codegen/data/AggrTypes3.tdd
@@ -18,97 +18,97 @@
aggrtypes: [
{className: "StdDevPop", funcName: "stddev_pop", aliasName: "", types: [
{inputType: "BigInt", outputType: "Float8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"},
- {inputType: "NullableBigInt", outputType: "Float8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"},
+ {inputType: "NullableBigInt", outputType: "NullableFloat8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"},
{inputType: "Int", outputType: "Float8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"},
- {inputType: "NullableInt", outputType: "Float8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"},
+ {inputType: "NullableInt", outputType: "NullableFloat8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"},
{inputType: "SmallInt", outputType: "Float8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"},
- {inputType: "NullableSmallInt", outputType: "Float8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"},
+ {inputType: "NullableSmallInt", outputType: "NullableFloat8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"},
{inputType: "TinyInt", outputType: "Float8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"},
- {inputType: "NullableTinyInt", outputType: "Float8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"},
+ {inputType: "NullableTinyInt", outputType: "NullableFloat8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"},
{inputType: "UInt1", outputType: "Float8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"},
- {inputType: "NullableUInt1", outputType: "Float8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"},
+ {inputType: "NullableUInt1", outputType: "NullableFloat8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"},
{inputType: "UInt2", outputType: "Float8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"},
- {inputType: "NullableUInt2", outputType: "Float8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"},
+ {inputType: "NullableUInt2", outputType: "NullableFloat8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"},
{inputType: "UInt4", outputType: "Float8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"},
- {inputType: "NullableUInt4", outputType: "Float8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"},
+ {inputType: "NullableUInt4", outputType: "NullableFloat8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"},
{inputType: "UInt8", outputType: "Float8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"},
- {inputType: "NullableUInt8", outputType: "Float8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"},
+ {inputType: "NullableUInt8", outputType: "NullableFloat8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"},
{inputType: "Float4", outputType: "Float8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"},
- {inputType: "NullableFloat4", outputType: "Float8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"},
+ {inputType: "NullableFloat4", outputType: "NullableFloat8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"},
{inputType: "Float8", outputType: "Float8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"},
- {inputType: "NullableFloat8", outputType: "Float8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"}
+ {inputType: "NullableFloat8", outputType: "NullableFloat8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"}
]
},
{className: "VariancePop", funcName: "var_pop", aliasName: "", types: [
{inputType: "BigInt", outputType: "Float8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"},
- {inputType: "NullableBigInt", outputType: "Float8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"},
+ {inputType: "NullableBigInt", outputType: "NullableFloat8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"},
{inputType: "Int", outputType: "Float8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"},
- {inputType: "NullableInt", outputType: "Float8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"},
+ {inputType: "NullableInt", outputType: "NullableFloat8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"},
{inputType: "SmallInt", outputType: "Float8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"},
- {inputType: "NullableSmallInt", outputType: "Float8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"},
+ {inputType: "NullableSmallInt", outputType: "NullableFloat8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"},
{inputType: "TinyInt", outputType: "Float8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"},
- {inputType: "NullableTinyInt", outputType: "Float8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"},
+ {inputType: "NullableTinyInt", outputType: "NullableFloat8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"},
{inputType: "UInt1", outputType: "Float8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"},
- {inputType: "NullableUInt1", outputType: "Float8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"},
+ {inputType: "NullableUInt1", outputType: "NullableFloat8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"},
{inputType: "UInt2", outputType: "Float8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"},
- {inputType: "NullableUInt2", outputType: "Float8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"},
+ {inputType: "NullableUInt2", outputType: "NullableFloat8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"},
{inputType: "UInt4", outputType: "Float8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"},
- {inputType: "NullableUInt4", outputType: "Float8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"},
+ {inputType: "NullableUInt4", outputType: "NullableFloat8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"},
{inputType: "UInt8", outputType: "Float8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"},
- {inputType: "NullableUInt8", outputType: "Float8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"},
+ {inputType: "NullableUInt8", outputType: "NullableFloat8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"},
{inputType: "Float4", outputType: "Float8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"},
- {inputType: "NullableFloat4", outputType: "Float8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"},
+ {inputType: "NullableFloat4", outputType: "NullableFloat8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"},
{inputType: "Float8", outputType: "Float8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"},
- {inputType: "NullableFloat8", outputType: "Float8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"}
+ {inputType: "NullableFloat8", outputType: "NullableFloat8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"}
]
},
{className: "StdDevSample", funcName: "stddev_samp", aliasName: "stddev", types: [
{inputType: "BigInt", outputType: "Float8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"},
- {inputType: "NullableBigInt", outputType: "Float8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"},
+ {inputType: "NullableBigInt", outputType: "NullableFloat8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"},
{inputType: "Int", outputType: "Float8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"},
- {inputType: "NullableInt", outputType: "Float8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"},
+ {inputType: "NullableInt", outputType: "NullableFloat8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"},
{inputType: "SmallInt", outputType: "Float8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"},
- {inputType: "NullableSmallInt", outputType: "Float8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"},
+ {inputType: "NullableSmallInt", outputType: "NullableFloat8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"},
{inputType: "TinyInt", outputType: "Float8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"},
- {inputType: "NullableTinyInt", outputType: "Float8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"},
+ {inputType: "NullableTinyInt", outputType: "NullableFloat8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"},
{inputType: "UInt1", outputType: "Float8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"},
- {inputType: "NullableUInt1", outputType: "Float8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"},
+ {inputType: "NullableUInt1", outputType: "NullableFloat8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"},
{inputType: "UInt2", outputType: "Float8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"},
- {inputType: "NullableUInt2", outputType: "Float8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"},
+ {inputType: "NullableUInt2", outputType: "NullableFloat8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"},
{inputType: "UInt4", outputType: "Float8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"},
- {inputType: "NullableUInt4", outputType: "Float8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"},
+ {inputType: "NullableUInt4", outputType: "NullableFloat8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"},
{inputType: "UInt8", outputType: "Float8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"},
- {inputType: "NullableUInt8", outputType: "Float8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"},
+ {inputType: "NullableUInt8", outputType: "NullableFloat8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"},
{inputType: "Float4", outputType: "Float8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"},
- {inputType: "NullableFloat4", outputType: "Float8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"},
+ {inputType: "NullableFloat4", outputType: "NullableFloat8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"},
{inputType: "Float8", outputType: "Float8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"},
- {inputType: "NullableFloat8", outputType: "Float8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"}
+ {inputType: "NullableFloat8", outputType: "NullableFloat8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"}
]
},
{className: "VarianceSample", funcName: "var_samp", aliasName: "variance", types: [
{inputType: "BigInt", outputType: "Float8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"},
- {inputType: "NullableBigInt", outputType: "Float8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"},
+ {inputType: "NullableBigInt", outputType: "NullableFloat8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"},
{inputType: "Int", outputType: "Float8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"},
- {inputType: "NullableInt", outputType: "Float8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"},
+ {inputType: "NullableInt", outputType: "NullableFloat8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"},
{inputType: "SmallInt", outputType: "Float8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"},
- {inputType: "NullableSmallInt", outputType: "Float8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"},
+ {inputType: "NullableSmallInt", outputType: "NullableFloat8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"},
{inputType: "TinyInt", outputType: "Float8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"},
- {inputType: "NullableTinyInt", outputType: "Float8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"},
+ {inputType: "NullableTinyInt", outputType: "NullableFloat8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"},
{inputType: "UInt1", outputType: "Float8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"},
- {inputType: "NullableUInt1", outputType: "Float8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"},
+ {inputType: "NullableUInt1", outputType: "NullableFloat8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"},
{inputType: "UInt2", outputType: "Float8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"},
- {inputType: "NullableUInt2", outputType: "Float8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"},
+ {inputType: "NullableUInt2", outputType: "NullableFloat8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"},
{inputType: "UInt4", outputType: "Float8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"},
- {inputType: "NullableUInt4", outputType: "Float8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"},
+ {inputType: "NullableUInt4", outputType: "NullableFloat8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"},
{inputType: "UInt8", outputType: "Float8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"},
- {inputType: "NullableUInt8", outputType: "Float8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"},
+ {inputType: "NullableUInt8", outputType: "NullableFloat8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"},
{inputType: "Float4", outputType: "Float8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"},
- {inputType: "NullableFloat4", outputType: "Float8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"},
+ {inputType: "NullableFloat4", outputType: "NullableFloat8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"},
{inputType: "Float8", outputType: "Float8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"},
- {inputType: "NullableFloat8", outputType: "Float8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"}
+ {inputType: "NullableFloat8", outputType: "NullableFloat8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"}
]
}
]
diff --git a/exec/java-exec/src/main/codegen/data/CastHigh.tdd b/exec/java-exec/src/main/codegen/data/CastHigh.tdd
new file mode 100644
index 000000000..54c337ddf
--- /dev/null
+++ b/exec/java-exec/src/main/codegen/data/CastHigh.tdd
@@ -0,0 +1,28 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http:# www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+{
+ types: [
+ {value: true, from: "Int", to: "Float8" },
+ {value: true, from: "BigInt", to: "Float8" },
+ {value: true, from: "Float4", to: "Float8" },
+ {value: true, from: "Float8", to: "Float8" },
+ {value: false, from: "Decimal9"},
+ {value: false, from: "Decimal18"},
+ {value: false, from: "Decimal28Sparse"},
+ {value: false, from: "Decimal38Sparse"},
+ ]
+}
diff --git a/exec/java-exec/src/main/codegen/data/DecimalAggrTypes2.tdd b/exec/java-exec/src/main/codegen/data/DecimalAggrTypes2.tdd
index ed8d1334b..5aa8b7fba 100644
--- a/exec/java-exec/src/main/codegen/data/DecimalAggrTypes2.tdd
+++ b/exec/java-exec/src/main/codegen/data/DecimalAggrTypes2.tdd
@@ -28,4 +28,4 @@
]
}
]
-} \ No newline at end of file
+}
diff --git a/exec/java-exec/src/main/codegen/data/SumZero.tdd b/exec/java-exec/src/main/codegen/data/SumZero.tdd
index b270d928d..c6532b4fd 100644
--- a/exec/java-exec/src/main/codegen/data/SumZero.tdd
+++ b/exec/java-exec/src/main/codegen/data/SumZero.tdd
@@ -16,16 +16,16 @@
{
types: [
- {inputType: "Bit", outputType: "NullableBit", runningType: "Bit", major: "Numeric"},
- {inputType: "Int", outputType: "NullableBigInt", runningType: "BigInt", major: "Numeric"},
- {inputType: "BigInt", outputType: "NullableBigInt", runningType: "BigInt", major: "Numeric"},
- {inputType: "NullableBit", outputType: "NullableBit", runningType: "Bit", major: "Numeric"},
- {inputType: "NullableInt", outputType: "NullableBigInt", runningType: "BigInt", major: "Numeric"},
- {inputType: "NullableBigInt", outputType: "NullableBigInt", runningType: "BigInt", major: "Numeric"},
- {inputType: "Float4", outputType: "NullableFloat8", runningType: "Float8", major: "Numeric"},
- {inputType: "Float8", outputType: "NullableFloat8", runningType: "Float8", major: "Numeric"},
- {inputType: "NullableFloat4", outputType: "NullableFloat8", runningType: "Float8", major: "Numeric"},
- {inputType: "NullableFloat8", outputType: "NullableFloat8", runningType: "Float8", major: "Numeric"}
+ {inputType: "Bit", outputType: "Bit", runningType: "Bit", major: "Numeric"},
+ {inputType: "Int", outputType: "BigInt", runningType: "BigInt", major: "Numeric"},
+ {inputType: "BigInt", outputType: "BigInt", runningType: "BigInt", major: "Numeric"},
+ {inputType: "NullableBit", outputType: "Bit", runningType: "Bit", major: "Numeric"},
+ {inputType: "NullableInt", outputType: "BigInt", runningType: "BigInt", major: "Numeric"},
+ {inputType: "NullableBigInt", outputType: "BigInt", runningType: "BigInt", major: "Numeric"},
+ {inputType: "Float4", outputType: "Float8", runningType: "Float8", major: "Numeric"},
+ {inputType: "Float8", outputType: "Float8", runningType: "Float8", major: "Numeric"},
+ {inputType: "NullableFloat4", outputType: "Float8", runningType: "Float8", major: "Numeric"},
+ {inputType: "NullableFloat8", outputType: "Float8", runningType: "Float8", major: "Numeric"}
]
}
\ No newline at end of file
diff --git a/exec/java-exec/src/main/codegen/templates/AggrTypeFunctions1.java b/exec/java-exec/src/main/codegen/templates/AggrTypeFunctions1.java
index aa9aeab7a..e19def360 100644
--- a/exec/java-exec/src/main/codegen/templates/AggrTypeFunctions1.java
+++ b/exec/java-exec/src/main/codegen/templates/AggrTypeFunctions1.java
@@ -55,10 +55,17 @@ public static class ${type.inputType}${aggrtype.className} implements DrillAggFu
@Param ${type.inputType}Holder in;
@Workspace ${type.runningType}Holder value;
+ <#if type.inputType?starts_with("Nullable") && type.outputType?starts_with("Nullable")>
+ @Workspace BigIntHolder nonNullCount;
+ </#if>
@Output ${type.outputType}Holder out;
public void setup(RecordBatch b) {
- value = new ${type.runningType}Holder();
+ value = new ${type.runningType}Holder();
+ <#if type.inputType?starts_with("Nullable") && type.outputType?starts_with("Nullable")>
+ nonNullCount = new BigIntHolder();
+ nonNullCount.value = 0;
+ </#if>
<#if aggrtype.funcName == "sum" || aggrtype.funcName == "count">
value.value = 0;
<#elseif aggrtype.funcName == "min">
@@ -96,7 +103,12 @@ public static class ${type.inputType}${aggrtype.className} implements DrillAggFu
if (in.isSet == 0) {
// processing nullable input and the value is null, so don't do anything...
break sout;
- }
+ }
+ <#if type.outputType?starts_with("Nullable")>
+ else {
+ nonNullCount.value++;
+ }
+ </#if>
</#if>
<#if aggrtype.funcName == "min">
value.value = Math.min(value.value, in.value);
@@ -115,13 +127,24 @@ public static class ${type.inputType}${aggrtype.className} implements DrillAggFu
}
@Override
- public void output() {
- out.value = value.value;
+ public void output() {
+ <#if type.inputType?starts_with("Nullable") && type.outputType?starts_with("Nullable")>
+ if (nonNullCount.value > 0) {
+ out.value = value.value;
+ out.isSet = 1;
+ } else {
+ out.isSet = 0;
+ }
+ <#else>
+ out.value = value.value;
+ </#if>
}
@Override
public void reset() {
-
+ <#if type.inputType?starts_with("Nullable") && type.outputType?starts_with("Nullable")>
+ nonNullCount.value = 0;
+ </#if>
<#if aggrtype.funcName == "sum" || aggrtype.funcName == "count">
value.value = 0;
<#elseif aggrtype.funcName == "min">
diff --git a/exec/java-exec/src/main/codegen/templates/AggrTypeFunctions2.java b/exec/java-exec/src/main/codegen/templates/AggrTypeFunctions2.java
index 8606dd1f2..fda14571b 100644
--- a/exec/java-exec/src/main/codegen/templates/AggrTypeFunctions2.java
+++ b/exec/java-exec/src/main/codegen/templates/AggrTypeFunctions2.java
@@ -58,11 +58,16 @@ public static class ${type.inputType}${aggrtype.className} implements DrillAggFu
@Param ${type.inputType}Holder in;
@Workspace ${type.sumRunningType}Holder sum;
@Workspace ${type.countRunningType}Holder count;
+ @Workspace BigIntHolder nonNullCount;
@Output ${type.outputType}Holder out;
public void setup(RecordBatch b) {
sum = new ${type.sumRunningType}Holder();
count = new ${type.countRunningType}Holder();
+ <#if type.inputType?starts_with("Nullable") >
+ nonNullCount = new BigIntHolder();
+ nonNullCount.value = 0;
+ </#if>
sum.value = 0;
count.value = 0;
}
@@ -74,7 +79,10 @@ public static class ${type.inputType}${aggrtype.className} implements DrillAggFu
if (in.isSet == 0) {
// processing nullable input and the value is null, so don't do anything...
break sout;
- }
+ }
+ else {
+ nonNullCount.value++;
+ }
</#if>
<#if aggrtype.funcName == "avg">
sum.value += in.value;
@@ -89,11 +97,23 @@ public static class ${type.inputType}${aggrtype.className} implements DrillAggFu
@Override
public void output() {
- out.value = sum.value / ((double) count.value);
+ <#if type.inputType?starts_with("Nullable") >
+ if (nonNullCount.value > 0) {
+ out.value = sum.value / ((double) count.value);
+ out.isSet = 1;
+ } else {
+ out.isSet = 0;
+ }
+ <#else>
+ out.value = sum.value / ((double) count.value);
+ </#if>
}
@Override
public void reset() {
+ <#if type.inputType?starts_with("Nullable") >
+ nonNullCount.value = 0;
+ </#if>
sum.value = 0;
count.value = 0;
}
diff --git a/exec/java-exec/src/main/codegen/templates/AggrTypeFunctions3.java b/exec/java-exec/src/main/codegen/templates/AggrTypeFunctions3.java
index 1b276e7c8..acf877a34 100644
--- a/exec/java-exec/src/main/codegen/templates/AggrTypeFunctions3.java
+++ b/exec/java-exec/src/main/codegen/templates/AggrTypeFunctions3.java
@@ -61,13 +61,17 @@ public static class ${type.inputType}${aggrtype.className} implements DrillAggFu
@Workspace ${type.movingAverageType}Holder avg;
@Workspace ${type.movingDeviationType}Holder dev;
@Workspace ${type.countRunningType}Holder count;
+ @Workspace BigIntHolder nonNullCount;
@Output ${type.outputType}Holder out;
public void setup(RecordBatch b) {
avg = new ${type.movingAverageType}Holder();
dev = new ${type.movingDeviationType}Holder();
count = new ${type.countRunningType}Holder();
-
+ <#if type.inputType?starts_with("Nullable") >
+ nonNullCount = new BigIntHolder();
+ nonNullCount.value = 0;
+ </#if>
// Initialize the workspace variables
avg.value = 0;
dev.value = 0;
@@ -82,6 +86,9 @@ public static class ${type.inputType}${aggrtype.className} implements DrillAggFu
// processing nullable input and the value is null, so don't do anything...
break sout;
}
+ else {
+ nonNullCount.value++;
+ }
</#if>
// Welford's approach to compute standard deviation
@@ -97,6 +104,26 @@ public static class ${type.inputType}${aggrtype.className} implements DrillAggFu
@Override
public void output() {
+ <#if type.inputType?starts_with("Nullable") >
+ if (nonNullCount.value > 0) {
+ out.isSet = 1;
+ <#if aggrtype.funcName == "stddev_pop">
+ if (count.value > 1)
+ out.value = Math.sqrt((dev.value / (count.value - 1)));
+ <#elseif aggrtype.funcName == "var_pop">
+ if (count.value > 1)
+ out.value = (dev.value / (count.value - 1));
+ <#elseif aggrtype.funcName == "stddev_samp">
+ if (count.value > 2)
+ out.value = Math.sqrt((dev.value / (count.value - 2)));
+ <#elseif aggrtype.funcName == "var_samp">
+ if (count.value > 2)
+ out.value = (dev.value / (count.value - 2));
+ </#if>
+ } else {
+ out.isSet = 0;
+ }
+ <#else>
<#if aggrtype.funcName == "stddev_pop">
if (count.value > 1)
out.value = Math.sqrt((dev.value / (count.value - 1)));
@@ -110,10 +137,14 @@ public static class ${type.inputType}${aggrtype.className} implements DrillAggFu
if (count.value > 2)
out.value = (dev.value / (count.value - 2));
</#if>
+ </#if>
}
@Override
public void reset() {
+ <#if type.inputType?starts_with("Nullable") >
+ nonNullCount.value = 0;
+ </#if>
avg.value = 0;
dev.value = 0;
count.value = 1;
diff --git a/exec/java-exec/src/main/codegen/templates/CastHigh.java b/exec/java-exec/src/main/codegen/templates/CastHigh.java
new file mode 100644
index 000000000..934b60b88
--- /dev/null
+++ b/exec/java-exec/src/main/codegen/templates/CastHigh.java
@@ -0,0 +1,64 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+<@pp.dropOutputFile />
+
+<@pp.changeOutputFile name="/org/apache/drill/exec/expr/fn/impl/gcast/CastHighFunctions.java" />
+
+<#include "/@includes/license.ftl" />
+
+package org.apache.drill.exec.expr.fn.impl.gcast;
+
+import org.apache.drill.exec.expr.DrillSimpleFunc;
+import org.apache.drill.exec.expr.annotations.FunctionTemplate;
+import org.apache.drill.exec.expr.annotations.FunctionTemplate.NullHandling;
+import org.apache.drill.exec.expr.annotations.Output;
+import org.apache.drill.exec.expr.annotations.Param;
+import org.apache.drill.exec.expr.holders.*;
+import javax.inject.Inject;
+import io.netty.buffer.DrillBuf;
+import org.apache.drill.exec.record.RecordBatch;
+
+public class CastHighFunctions {
+ static final org.slf4j.Logger logger = org.slf4j.LoggerFactory.getLogger(CastHighFunctions.class);
+
+ <#list casthigh.types as type>
+
+ @SuppressWarnings("unused")
+ @FunctionTemplate(name = "casthigh", scope = FunctionTemplate.FunctionScope.SIMPLE, nulls=NullHandling.NULL_IF_NULL)
+ public static class CastHigh${type.from} implements DrillSimpleFunc {
+
+ @Param ${type.from}Holder in;
+ <#if type.from.contains("Decimal")>
+ @Output ${type.from}Holder out;
+ <#else>
+ @Output ${type.to}Holder out;
+ </#if>
+
+ public void setup(RecordBatch incoming) {}
+
+ public void eval() {
+ <#if type.value >
+ out.value = (double) in.value;
+ <#else>
+ out = in;
+ </#if>
+ }
+ }
+</#list>
+}
+
diff --git a/exec/java-exec/src/main/codegen/templates/SumZeroAggr.java b/exec/java-exec/src/main/codegen/templates/SumZeroAggr.java
index 0eab23d99..5b0c4a0cf 100644
--- a/exec/java-exec/src/main/codegen/templates/SumZeroAggr.java
+++ b/exec/java-exec/src/main/codegen/templates/SumZeroAggr.java
@@ -51,42 +51,26 @@ public class SumZeroFunctions {
@Param ${type.inputType}Holder in;
@Workspace ${type.runningType}Holder value;
- @Workspace BigIntHolder callCount;
@Output ${type.outputType}Holder out;
public void setup(RecordBatch b) {
value.value = 0;
- callCount.value = 0;
}
@Override
public void add() {
- callCount.value++;
- <#if type.inputType?starts_with("Nullable") >
- if(in.isSet == 1){
- value.value += in.value;
- }
- <#else>
value.value += in.value;
- </#if>
}
@Override
public void output() {
- if(callCount.value > 0){
- out.value = value.value;
- out.isSet = 1;
- }else{
- out.isSet = 0;
- }
-
+ out.value = value.value;
}
@Override
public void reset() {
value.value = 0;
- callCount.value = 0;
}
}
diff --git a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/common/DrillAggregateRelBase.java b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/common/DrillAggregateRelBase.java
index 4854307d3..98f6bd5cd 100644
--- a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/common/DrillAggregateRelBase.java
+++ b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/common/DrillAggregateRelBase.java
@@ -20,12 +20,14 @@ package org.apache.drill.exec.planner.common;
import java.util.BitSet;
import java.util.List;
-
+import org.apache.drill.exec.planner.cost.DrillCostBase.DrillCostFactory;
import org.eigenbase.rel.AggregateCall;
import org.eigenbase.rel.AggregateRelBase;
import org.eigenbase.rel.InvalidRelException;
import org.eigenbase.rel.RelNode;
import org.eigenbase.relopt.RelOptCluster;
+import org.eigenbase.relopt.RelOptCost;
+import org.eigenbase.relopt.RelOptPlanner;
import org.eigenbase.relopt.RelTraitSet;
@@ -38,5 +40,20 @@ public abstract class DrillAggregateRelBase extends AggregateRelBase implements
List<AggregateCall> aggCalls) throws InvalidRelException {
super(cluster, traits, child, groupSet, aggCalls);
}
+
+ @Override
+ public RelOptCost computeSelfCost(RelOptPlanner planner) {
+ for (AggregateCall aggCall : getAggCallList()) {
+ String name = aggCall.getAggregation().getName();
+ // For avg, stddev_pop, stddev_samp, var_pop and var_samp, the ReduceAggregatesRule is supposed
+ // to convert them to use sum and count. Here, we make the cost of the original functions high
+ // enough such that the planner does not choose them and instead chooses the rewritten functions.
+ if (name.equals("AVG") || name.equals("STDDEV_POP") || name.equals("STDDEV_SAMP")
+ || name.equals("VAR_POP") || name.equals("VAR_SAMP")) {
+ return ((DrillCostFactory)planner.getCostFactory()).makeHugeCost();
+ }
+ }
+ return ((DrillCostFactory)planner.getCostFactory()).makeTinyCost();
+ }
}
diff --git a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/logical/DrillReduceAggregatesRule.java b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/logical/DrillReduceAggregatesRule.java
new file mode 100644
index 000000000..8305dd8a6
--- /dev/null
+++ b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/logical/DrillReduceAggregatesRule.java
@@ -0,0 +1,602 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.drill.exec.planner.logical;
+
+
+import java.math.BigDecimal;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import org.apache.drill.exec.planner.sql.DrillSqlOperator;
+import org.eigenbase.rel.AggregateCall;
+import org.eigenbase.rel.AggregateRel;
+import org.eigenbase.rel.AggregateRelBase;
+import org.eigenbase.rel.CalcRel;
+import org.eigenbase.rel.RelNode;
+import org.eigenbase.relopt.RelOptRule;
+import org.eigenbase.relopt.RelOptRuleCall;
+import org.eigenbase.relopt.RelOptRuleOperand;
+import org.eigenbase.reltype.RelDataType;
+import org.eigenbase.reltype.RelDataTypeFactory;
+import org.eigenbase.reltype.RelDataTypeField;
+import org.eigenbase.rex.RexBuilder;
+import org.eigenbase.rex.RexCall;
+import org.eigenbase.rex.RexLiteral;
+import org.eigenbase.rex.RexNode;
+import org.eigenbase.sql.SqlAggFunction;
+import org.eigenbase.sql.fun.SqlAvgAggFunction;
+import org.eigenbase.sql.fun.SqlStdOperatorTable;
+import org.eigenbase.sql.fun.SqlSumAggFunction;
+import org.eigenbase.sql.fun.SqlSumEmptyIsZeroAggFunction;
+import org.eigenbase.sql.type.SqlTypeUtil;
+import org.eigenbase.util.CompositeList;
+import org.eigenbase.util.ImmutableIntList;
+import org.eigenbase.util.Util;
+
+import com.google.common.collect.ImmutableList;
+
+/**
+ * Rule to reduce aggregates to simpler forms. Currently only AVG(x) to
+ * SUM(x)/COUNT(x), but eventually will handle others such as STDDEV.
+ */
+public class DrillReduceAggregatesRule extends RelOptRule {
+ //~ Static fields/initializers ---------------------------------------------
+
+ /**
+ * The singleton.
+ */
+ public static final DrillReduceAggregatesRule INSTANCE =
+ new DrillReduceAggregatesRule(operand(AggregateRel.class, any()));
+
+ private static final DrillSqlOperator CastHighOp = new DrillSqlOperator("CastHigh", 1);
+
+ //~ Constructors -----------------------------------------------------------
+
+ protected DrillReduceAggregatesRule(RelOptRuleOperand operand) {
+ super(operand);
+ }
+
+ //~ Methods ----------------------------------------------------------------
+
+ @Override
+ public boolean matches(RelOptRuleCall call) {
+ if (!super.matches(call)) {
+ return false;
+ }
+ AggregateRelBase oldAggRel = (AggregateRelBase) call.rels[0];
+ return containsAvgStddevVarCall(oldAggRel.getAggCallList());
+ }
+
+ public void onMatch(RelOptRuleCall ruleCall) {
+ AggregateRelBase oldAggRel = (AggregateRelBase) ruleCall.rels[0];
+ reduceAggs(ruleCall, oldAggRel);
+ }
+
+ /**
+ * Returns whether any of the aggregates are calls to AVG, STDDEV_*, VAR_*.
+ *
+ * @param aggCallList List of aggregate calls
+ */
+ private boolean containsAvgStddevVarCall(List<AggregateCall> aggCallList) {
+ for (AggregateCall call : aggCallList) {
+ if (call.getAggregation() instanceof SqlAvgAggFunction
+ || call.getAggregation() instanceof SqlSumAggFunction) {
+ return true;
+ }
+ }
+ return false;
+ }
+
+ /*
+ private boolean isMatch(AggregateCall call) {
+ if (call.getAggregation() instanceof SqlAvgAggFunction) {
+ final SqlAvgAggFunction.Subtype subtype =
+ ((SqlAvgAggFunction) call.getAggregation()).getSubtype();
+ return (subtype == SqlAvgAggFunction.Subtype.AVG);
+ }
+ return false;
+ }
+ */
+
+ /**
+ * Reduces all calls to AVG, STDDEV_POP, STDDEV_SAMP, VAR_POP, VAR_SAMP in
+ * the aggregates list to.
+ *
+ * <p>It handles newly generated common subexpressions since this was done
+ * at the sql2rel stage.
+ */
+ private void reduceAggs(
+ RelOptRuleCall ruleCall,
+ AggregateRelBase oldAggRel) {
+ RexBuilder rexBuilder = oldAggRel.getCluster().getRexBuilder();
+
+ List<AggregateCall> oldCalls = oldAggRel.getAggCallList();
+ final int nGroups = oldAggRel.getGroupCount();
+
+ List<AggregateCall> newCalls = new ArrayList<AggregateCall>();
+ Map<AggregateCall, RexNode> aggCallMapping =
+ new HashMap<AggregateCall, RexNode>();
+
+ List<RexNode> projList = new ArrayList<RexNode>();
+
+ // pass through group key
+ for (int i = 0; i < nGroups; ++i) {
+ projList.add(
+ rexBuilder.makeInputRef(
+ getFieldType(oldAggRel, i),
+ i));
+ }
+
+ // List of input expressions. If a particular aggregate needs more, it
+ // will add an expression to the end, and we will create an extra
+ // project.
+ RelNode input = oldAggRel.getChild();
+ List<RexNode> inputExprs = new ArrayList<RexNode>();
+ for (RelDataTypeField field : input.getRowType().getFieldList()) {
+ inputExprs.add(
+ rexBuilder.makeInputRef(
+ field.getType(), inputExprs.size()));
+ }
+
+ // create new agg function calls and rest of project list together
+ for (AggregateCall oldCall : oldCalls) {
+ projList.add(
+ reduceAgg(
+ oldAggRel, oldCall, newCalls, aggCallMapping, inputExprs));
+ }
+
+ final int extraArgCount =
+ inputExprs.size() - input.getRowType().getFieldCount();
+ if (extraArgCount > 0) {
+ input =
+ CalcRel.createProject(
+ input,
+ inputExprs,
+ CompositeList.of(
+ input.getRowType().getFieldNames(),
+ Collections.<String>nCopies(
+ extraArgCount,
+ null)));
+ }
+ AggregateRelBase newAggRel =
+ newAggregateRel(
+ oldAggRel, input, newCalls);
+
+ RelNode projectRel =
+ CalcRel.createProject(
+ newAggRel,
+ projList,
+ oldAggRel.getRowType().getFieldNames());
+
+ ruleCall.transformTo(projectRel);
+ }
+
+ private RexNode reduceAgg(
+ AggregateRelBase oldAggRel,
+ AggregateCall oldCall,
+ List<AggregateCall> newCalls,
+ Map<AggregateCall, RexNode> aggCallMapping,
+ List<RexNode> inputExprs) {
+ if (oldCall.getAggregation() instanceof SqlSumAggFunction) {
+ // replace original SUM(x) with
+ // case COUNT(x) when 0 then null else SUM0(x) end
+ return reduceSum(oldAggRel, oldCall, newCalls, aggCallMapping);
+ }
+ if (oldCall.getAggregation() instanceof SqlAvgAggFunction) {
+ final SqlAvgAggFunction.Subtype subtype =
+ ((SqlAvgAggFunction) oldCall.getAggregation()).getSubtype();
+
+ switch (subtype) {
+ case AVG:
+ // replace original AVG(x) with SUM(x) / COUNT(x)
+ return reduceAvg(
+ oldAggRel, oldCall, newCalls, aggCallMapping);
+ case STDDEV_POP:
+ // replace original STDDEV_POP(x) with
+ // SQRT(
+ // (SUM(x * x) - SUM(x) * SUM(x) / COUNT(x))
+ // / COUNT(x))
+ return reduceStddev(
+ oldAggRel, oldCall, true, true, newCalls, aggCallMapping,
+ inputExprs);
+ case STDDEV_SAMP:
+ // replace original STDDEV_POP(x) with
+ // SQRT(
+ // (SUM(x * x) - SUM(x) * SUM(x) / COUNT(x))
+ // / CASE COUNT(x) WHEN 1 THEN NULL ELSE COUNT(x) - 1 END)
+ return reduceStddev(
+ oldAggRel, oldCall, false, true, newCalls, aggCallMapping,
+ inputExprs);
+ case VAR_POP:
+ // replace original VAR_POP(x) with
+ // (SUM(x * x) - SUM(x) * SUM(x) / COUNT(x))
+ // / COUNT(x)
+ return reduceStddev(
+ oldAggRel, oldCall, true, false, newCalls, aggCallMapping,
+ inputExprs);
+ case VAR_SAMP:
+ // replace original VAR_POP(x) with
+ // (SUM(x * x) - SUM(x) * SUM(x) / COUNT(x))
+ // / CASE COUNT(x) WHEN 1 THEN NULL ELSE COUNT(x) - 1 END
+ return reduceStddev(
+ oldAggRel, oldCall, false, false, newCalls, aggCallMapping,
+ inputExprs);
+ default:
+ throw Util.unexpected(subtype);
+ }
+ } else {
+ // anything else: preserve original call
+ RexBuilder rexBuilder = oldAggRel.getCluster().getRexBuilder();
+ final int nGroups = oldAggRel.getGroupCount();
+ List<RelDataType> oldArgTypes = SqlTypeUtil
+ .projectTypes(oldAggRel.getRowType(), oldCall.getArgList());
+ return rexBuilder.addAggCall(
+ oldCall,
+ nGroups,
+ newCalls,
+ aggCallMapping,
+ oldArgTypes);
+ }
+ }
+
+ private RexNode reduceAvg(
+ AggregateRelBase oldAggRel,
+ AggregateCall oldCall,
+ List<AggregateCall> newCalls,
+ Map<AggregateCall, RexNode> aggCallMapping) {
+ final int nGroups = oldAggRel.getGroupCount();
+ RelDataTypeFactory typeFactory =
+ oldAggRel.getCluster().getTypeFactory();
+ RexBuilder rexBuilder = oldAggRel.getCluster().getRexBuilder();
+ int iAvgInput = oldCall.getArgList().get(0);
+ RelDataType avgInputType =
+ getFieldType(
+ oldAggRel.getChild(),
+ iAvgInput);
+ RelDataType sumType =
+ typeFactory.createTypeWithNullability(
+ avgInputType,
+ avgInputType.isNullable() || nGroups == 0);
+ // SqlAggFunction sumAgg = new SqlSumAggFunction(sumType);
+ SqlAggFunction sumAgg = new SqlSumEmptyIsZeroAggFunction(sumType);
+ AggregateCall sumCall =
+ new AggregateCall(
+ sumAgg,
+ oldCall.isDistinct(),
+ oldCall.getArgList(),
+ sumType,
+ null);
+ SqlAggFunction countAgg = SqlStdOperatorTable.COUNT;
+ RelDataType countType = countAgg.getReturnType(typeFactory);
+ AggregateCall countCall =
+ new AggregateCall(
+ countAgg,
+ oldCall.isDistinct(),
+ oldCall.getArgList(),
+ countType,
+ null);
+
+ RexNode tmpsumRef =
+ rexBuilder.addAggCall(
+ sumCall,
+ nGroups,
+ newCalls,
+ aggCallMapping,
+ ImmutableList.of(avgInputType));
+
+ RexNode tmpcountRef =
+ rexBuilder.addAggCall(
+ countCall,
+ nGroups,
+ newCalls,
+ aggCallMapping,
+ ImmutableList.of(avgInputType));
+
+ RexNode n = rexBuilder.makeCall(SqlStdOperatorTable.CASE,
+ rexBuilder.makeCall(SqlStdOperatorTable.EQUALS,
+ tmpcountRef, rexBuilder.makeExactLiteral(BigDecimal.ZERO)),
+ rexBuilder.constantNull(),
+ tmpsumRef);
+
+ // NOTE: these references are with respect to the output
+ // of newAggRel
+ /*
+ RexNode numeratorRef =
+ rexBuilder.makeCall(CastHighOp,
+ rexBuilder.addAggCall(
+ sumCall,
+ nGroups,
+ newCalls,
+ aggCallMapping,
+ ImmutableList.of(avgInputType))
+ );
+ */
+ RexNode numeratorRef = rexBuilder.makeCall(CastHighOp, n);
+
+ RexNode denominatorRef =
+ rexBuilder.addAggCall(
+ countCall,
+ nGroups,
+ newCalls,
+ aggCallMapping,
+ ImmutableList.of(avgInputType));
+ final RexNode divideRef =
+ rexBuilder.makeCall(
+ SqlStdOperatorTable.DIVIDE,
+ numeratorRef,
+ denominatorRef);
+ return rexBuilder.makeCast(
+ oldCall.getType(), divideRef);
+ }
+
+ private RexNode reduceSum(
+ AggregateRelBase oldAggRel,
+ AggregateCall oldCall,
+ List<AggregateCall> newCalls,
+ Map<AggregateCall, RexNode> aggCallMapping) {
+ final int nGroups = oldAggRel.getGroupCount();
+ RelDataTypeFactory typeFactory =
+ oldAggRel.getCluster().getTypeFactory();
+ RexBuilder rexBuilder = oldAggRel.getCluster().getRexBuilder();
+ int arg = oldCall.getArgList().get(0);
+ RelDataType argType =
+ getFieldType(
+ oldAggRel.getChild(),
+ arg);
+ RelDataType sumType =
+ typeFactory.createTypeWithNullability(
+ argType, argType.isNullable());
+ SqlAggFunction sumZeroAgg = new SqlSumEmptyIsZeroAggFunction(sumType);
+ AggregateCall sumZeroCall =
+ new AggregateCall(
+ sumZeroAgg,
+ oldCall.isDistinct(),
+ oldCall.getArgList(),
+ sumType,
+ null);
+ SqlAggFunction countAgg = SqlStdOperatorTable.COUNT;
+ RelDataType countType = countAgg.getReturnType(typeFactory);
+ AggregateCall countCall =
+ new AggregateCall(
+ countAgg,
+ oldCall.isDistinct(),
+ oldCall.getArgList(),
+ countType,
+ null);
+
+ // NOTE: these references are with respect to the output
+ // of newAggRel
+ RexNode sumZeroRef =
+ rexBuilder.addAggCall(
+ sumZeroCall,
+ nGroups,
+ newCalls,
+ aggCallMapping,
+ ImmutableList.of(argType));
+ if (!oldCall.getType().isNullable()) {
+ // If SUM(x) is not nullable, the validator must have determined that
+ // nulls are impossible (because the group is never empty and x is never
+ // null). Therefore we translate to SUM0(x).
+ return sumZeroRef;
+ }
+ RexNode countRef =
+ rexBuilder.addAggCall(
+ countCall,
+ nGroups,
+ newCalls,
+ aggCallMapping,
+ ImmutableList.of(argType));
+ return rexBuilder.makeCall(SqlStdOperatorTable.CASE,
+ rexBuilder.makeCall(SqlStdOperatorTable.EQUALS,
+ countRef, rexBuilder.makeExactLiteral(BigDecimal.ZERO)),
+ rexBuilder.constantNull(),
+ sumZeroRef);
+ }
+
+ private RexNode reduceStddev(
+ AggregateRelBase oldAggRel,
+ AggregateCall oldCall,
+ boolean biased,
+ boolean sqrt,
+ List<AggregateCall> newCalls,
+ Map<AggregateCall, RexNode> aggCallMapping,
+ List<RexNode> inputExprs) {
+ // stddev_pop(x) ==>
+ // power(
+ // (sum(x * x) - sum(x) * sum(x) / count(x))
+ // / count(x),
+ // .5)
+ //
+ // stddev_samp(x) ==>
+ // power(
+ // (sum(x * x) - sum(x) * sum(x) / count(x))
+ // / nullif(count(x) - 1, 0),
+ // .5)
+ final int nGroups = oldAggRel.getGroupCount();
+ RelDataTypeFactory typeFactory =
+ oldAggRel.getCluster().getTypeFactory();
+ final RexBuilder rexBuilder = oldAggRel.getCluster().getRexBuilder();
+
+ assert oldCall.getArgList().size() == 1 : oldCall.getArgList();
+ final int argOrdinal = oldCall.getArgList().get(0);
+ final RelDataType argType =
+ getFieldType(
+ oldAggRel.getChild(),
+ argOrdinal);
+
+ // final RexNode argRef = inputExprs.get(argOrdinal);
+ RexNode argRef = rexBuilder.makeCall(CastHighOp, inputExprs.get(argOrdinal));
+ inputExprs.set(argOrdinal, argRef);
+
+ final RexNode argSquared =
+ rexBuilder.makeCall(
+ SqlStdOperatorTable.MULTIPLY, argRef, argRef);
+ final int argSquaredOrdinal = lookupOrAdd(inputExprs, argSquared);
+
+ final RelDataType sumType =
+ typeFactory.createTypeWithNullability(
+ argType,
+ true);
+ final AggregateCall sumArgSquaredAggCall =
+ new AggregateCall(
+ new SqlSumAggFunction(sumType),
+ oldCall.isDistinct(),
+ ImmutableIntList.of(argSquaredOrdinal),
+ sumType,
+ null);
+ final RexNode sumArgSquared =
+ rexBuilder.addAggCall(
+ sumArgSquaredAggCall,
+ nGroups,
+ newCalls,
+ aggCallMapping,
+ ImmutableList.of(argType));
+
+ final AggregateCall sumArgAggCall =
+ new AggregateCall(
+ new SqlSumAggFunction(sumType),
+ oldCall.isDistinct(),
+ ImmutableIntList.of(argOrdinal),
+ sumType,
+ null);
+ final RexNode sumArg =
+ rexBuilder.addAggCall(
+ sumArgAggCall,
+ nGroups,
+ newCalls,
+ aggCallMapping,
+ ImmutableList.of(argType));
+
+ final RexNode sumSquaredArg =
+ rexBuilder.makeCall(
+ SqlStdOperatorTable.MULTIPLY, sumArg, sumArg);
+
+ final SqlAggFunction countAgg = SqlStdOperatorTable.COUNT;
+ final RelDataType countType = countAgg.getReturnType(typeFactory);
+ final AggregateCall countArgAggCall =
+ new AggregateCall(
+ countAgg,
+ oldCall.isDistinct(),
+ oldCall.getArgList(),
+ countType,
+ null);
+ final RexNode countArg =
+ rexBuilder.addAggCall(
+ countArgAggCall,
+ nGroups,
+ newCalls,
+ aggCallMapping,
+ ImmutableList.of(argType));
+
+ final RexNode avgSumSquaredArg =
+ rexBuilder.makeCall(
+ SqlStdOperatorTable.DIVIDE,
+ sumSquaredArg, countArg);
+
+ final RexNode diff =
+ rexBuilder.makeCall(
+ SqlStdOperatorTable.MINUS,
+ sumArgSquared, avgSumSquaredArg);
+
+ final RexNode denominator;
+ if (biased) {
+ denominator = countArg;
+ } else {
+ final RexLiteral one =
+ rexBuilder.makeExactLiteral(BigDecimal.ONE);
+ final RexNode nul =
+ rexBuilder.makeNullLiteral(countArg.getType().getSqlTypeName());
+ final RexNode countMinusOne =
+ rexBuilder.makeCall(
+ SqlStdOperatorTable.MINUS, countArg, one);
+ final RexNode countEqOne =
+ rexBuilder.makeCall(
+ SqlStdOperatorTable.EQUALS, countArg, one);
+ denominator =
+ rexBuilder.makeCall(
+ SqlStdOperatorTable.CASE,
+ countEqOne, nul, countMinusOne);
+ }
+
+ final RexNode div =
+ rexBuilder.makeCall(
+ SqlStdOperatorTable.DIVIDE, diff, denominator);
+
+ RexNode result = div;
+ if (sqrt) {
+ final RexNode half =
+ rexBuilder.makeExactLiteral(new BigDecimal("0.5"));
+ result =
+ rexBuilder.makeCall(
+ SqlStdOperatorTable.POWER, div, half);
+ }
+
+ return rexBuilder.makeCast(
+ oldCall.getType(), result);
+ }
+
+ /**
+ * Finds the ordinal of an element in a list, or adds it.
+ *
+ * @param list List
+ * @param element Element to lookup or add
+ * @param <T> Element type
+ * @return Ordinal of element in list
+ */
+ private static <T> int lookupOrAdd(List<T> list, T element) {
+ int ordinal = list.indexOf(element);
+ if (ordinal == -1) {
+ ordinal = list.size();
+ list.add(element);
+ }
+ return ordinal;
+ }
+
+ /**
+ * Do a shallow clone of oldAggRel and update aggCalls. Could be refactored
+ * into AggregateRelBase and subclasses - but it's only needed for some
+ * subclasses.
+ *
+ * @param oldAggRel AggregateRel to clone.
+ * @param inputRel Input relational expression
+ * @param newCalls New list of AggregateCalls
+ * @return shallow clone with new list of AggregateCalls.
+ */
+ protected AggregateRelBase newAggregateRel(
+ AggregateRelBase oldAggRel,
+ RelNode inputRel,
+ List<AggregateCall> newCalls) {
+ return new AggregateRel(
+ oldAggRel.getCluster(),
+ inputRel,
+ oldAggRel.getGroupSet(),
+ newCalls);
+ }
+
+ private RelDataType getFieldType(RelNode relNode, int i) {
+ final RelDataTypeField inputField =
+ relNode.getRowType().getFieldList().get(i);
+ return inputField.getType();
+ }
+
+}
+
diff --git a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/logical/DrillRuleSets.java b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/logical/DrillRuleSets.java
index cf92121be..65fa2d7d1 100644
--- a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/logical/DrillRuleSets.java
+++ b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/logical/DrillRuleSets.java
@@ -46,7 +46,6 @@ import org.eigenbase.rel.rules.PushFilterPastProjectRule;
import org.eigenbase.rel.rules.PushJoinThroughJoinRule;
import org.eigenbase.rel.rules.PushProjectPastFilterRule;
import org.eigenbase.rel.rules.PushProjectPastJoinRule;
-import org.eigenbase.rel.rules.ReduceAggregatesRule;
import org.eigenbase.rel.rules.RemoveDistinctAggregateRule;
import org.eigenbase.rel.rules.RemoveDistinctRule;
import org.eigenbase.rel.rules.RemoveSortRule;
@@ -87,7 +86,7 @@ public class DrillRuleSets {
//MergeProjectRule.INSTANCE, //
DrillMergeProjectRule.getInstance(true, RelFactories.DEFAULT_PROJECT_FACTORY, context.getFunctionRegistry()),
RemoveDistinctAggregateRule.INSTANCE, //
- ReduceAggregatesRule.INSTANCE, //
+ // ReduceAggregatesRule.INSTANCE, // replaced by DrillReduceAggregatesRule
PushProjectPastJoinRule.INSTANCE,
// PushProjectPastFilterRule.INSTANCE,
DrillPushProjectPastFilterRule.INSTANCE,
@@ -108,6 +107,7 @@ public class DrillRuleSets {
DrillSortRule.INSTANCE,
DrillJoinRule.INSTANCE,
DrillUnionRule.INSTANCE
+ ,DrillReduceAggregatesRule.INSTANCE
));
}
return DRILL_BASIC_RULES;
diff --git a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/physical/AggPrelBase.java b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/physical/AggPrelBase.java
index 3a164149a..3dedb552a 100644
--- a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/physical/AggPrelBase.java
+++ b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/physical/AggPrelBase.java
@@ -73,7 +73,7 @@ public abstract class AggPrelBase extends AggregateRelBase implements Prel{
private final RelDataType type;
public SqlSumCountAggFunction(RelDataType type) {
- super("SUM",
+ super("$SUM0",
SqlKind.OTHER_FUNCTION,
ReturnTypes.BIGINT, // use the inferred return type of SqlCountAggFunction
null,
@@ -185,7 +185,7 @@ public abstract class AggPrelBase extends AggregateRelBase implements Prel{
public <T, X, E extends Throwable> T accept(PrelVisitor<T, X, E> logicalVisitor, X value) throws E {
return logicalVisitor.visitPrel(this, value);
}
-
+
@Override
public boolean needsFinalColumnReordering() {
return true;
diff --git a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/physical/AggPruleBase.java b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/physical/AggPruleBase.java
index 7b7e3b72e..a61bfd6fe 100644
--- a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/physical/AggPruleBase.java
+++ b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/physical/AggPruleBase.java
@@ -69,7 +69,8 @@ public abstract class AggPruleBase extends Prule {
for (AggregateCall aggCall : aggregate.getAggCallList()) {
String name = aggCall.getAggregation().getName();
- if ( ! (name.equals("SUM") || name.equals("MIN") || name.equals("MAX") || name.equals("COUNT"))) {
+ if ( ! (name.equals("SUM") || name.equals("MIN") || name.equals("MAX") || name.equals("COUNT")
+ || name.equals("$SUM0"))) {
return false;
}
}