diff options
author | Aman Sinha <asinha@maprtech.com> | 2014-08-27 13:32:03 -0700 |
---|---|---|
committer | Jacques Nadeau <jacques@apache.org> | 2014-08-31 10:27:26 -0700 |
commit | a0d3906b8ed6dc598ec23b55ca9542180111e910 (patch) | |
tree | aa54b727196d2cf7dd352adf3c0bde0af9cab077 /exec | |
parent | b9e384dc9a8889e5cb1a9ac05b48c6b99096fe73 (diff) |
DRILL-1171: Create Drill's implementation of ReduceAggregatesRule, including a new CastHigh function.
DRILL-1342: Fix nullability handling of aggregate functions.
Diffstat (limited to 'exec')
17 files changed, 873 insertions, 102 deletions
diff --git a/exec/java-exec/src/main/codegen/config.fmpp b/exec/java-exec/src/main/codegen/config.fmpp index ff6135dff..aff6240c0 100644 --- a/exec/java-exec/src/main/codegen/config.fmpp +++ b/exec/java-exec/src/main/codegen/config.fmpp @@ -36,7 +36,8 @@ data: { intervalNumericTypes: tdd(../data/IntervalNumericTypes.tdd), extract: tdd(../data/ExtractTypes.tdd), sumzero: tdd(../data/SumZero.tdd), - numericTypes: tdd(../data/NumericTypes.tdd) + numericTypes: tdd(../data/NumericTypes.tdd), + casthigh: tdd(../data/CastHigh.tdd) } freemarkerLinks: { includes: includes/ diff --git a/exec/java-exec/src/main/codegen/data/AggrTypes1.tdd b/exec/java-exec/src/main/codegen/data/AggrTypes1.tdd index 1bac07ee3..812c289e6 100644 --- a/exec/java-exec/src/main/codegen/data/AggrTypes1.tdd +++ b/exec/java-exec/src/main/codegen/data/AggrTypes1.tdd @@ -20,13 +20,13 @@ {inputType: "Bit", outputType: "Bit", runningType: "Bit", major: "Numeric"}, {inputType: "Int", outputType: "Int", runningType: "Int", major: "Numeric"}, {inputType: "BigInt", outputType: "BigInt", runningType: "BigInt", major: "Numeric"}, - {inputType: "NullableBit", outputType: "Bit", runningType: "Bit", major: "Numeric"}, - {inputType: "NullableInt", outputType: "Int", runningType: "Int", major: "Numeric"}, - {inputType: "NullableBigInt", outputType: "BigInt", runningType: "BigInt", major: "Numeric"}, + {inputType: "NullableBit", outputType: "NullableBit", runningType: "Bit", major: "Numeric"}, + {inputType: "NullableInt", outputType: "NullableInt", runningType: "Int", major: "Numeric"}, + {inputType: "NullableBigInt", outputType: "NullableBigInt", runningType: "BigInt", major: "Numeric"}, {inputType: "Float4", outputType: "Float4", runningType: "Float4", major: "Numeric"}, {inputType: "Float8", outputType: "Float8", runningType: "Float8", major: "Numeric"}, - {inputType: "NullableFloat4", outputType: "Float4", runningType: "Float4", major: "Numeric"}, - {inputType: "NullableFloat8", outputType: "Float8", runningType: "Float8", major: "Numeric"}, + {inputType: "NullableFloat4", outputType: "NullableFloat4", runningType: "Float4", major: "Numeric"}, + {inputType: "NullableFloat8", outputType: "NullableFloat8", runningType: "Float8", major: "Numeric"}, {inputType: "Date", outputType: "Date", runningType: "Date", major: "Date", initialValue: "Long.MAX_VALUE"}, {inputType: "NullableDate", outputType: "Date", runningType: "Date", major: "Date", initialValue: "Long.MAX_VALUE"}, {inputType: "TimeStamp", outputType: "TimeStamp", runningType: "TimeStamp", major: "Date", initialValue: "Long.MAX_VALUE"}, @@ -51,13 +51,13 @@ {inputType: "Bit", outputType: "Bit", runningType: "Bit", major: "Numeric"}, {inputType: "Int", outputType: "Int", runningType: "Int", major: "Numeric"}, {inputType: "BigInt", outputType: "BigInt", runningType: "BigInt", major: "Numeric"}, - {inputType: "NullableBit", outputType: "Bit", runningType: "Bit", major: "Numeric"}, - {inputType: "NullableInt", outputType: "Int", runningType: "Int", major: "Numeric"}, - {inputType: "NullableBigInt", outputType: "BigInt", runningType: "BigInt", major: "Numeric"}, + {inputType: "NullableBit", outputType: "NullableBit", runningType: "Bit", major: "Numeric"}, + {inputType: "NullableInt", outputType: "NullableInt", runningType: "Int", major: "Numeric"}, + {inputType: "NullableBigInt", outputType: "NullableBigInt", runningType: "BigInt", major: "Numeric"}, {inputType: "Float4", outputType: "Float4", runningType: "Float4", major: "Numeric"}, {inputType: "Float8", outputType: "Float8", runningType: "Float8", major: "Numeric"}, - {inputType: "NullableFloat4", outputType: "Float4", runningType: "Float4", major: "Numeric"}, - {inputType: "NullableFloat8", outputType: "Float8", runningType: "Float8", major: "Numeric"}, + {inputType: "NullableFloat4", outputType: "NullableFloat4", runningType: "Float4", major: "Numeric"}, + {inputType: "NullableFloat8", outputType: "NullableFloat8", runningType: "Float8", major: "Numeric"}, {inputType: "Date", outputType: "Date", runningType: "Date", major: "Date", initialValue: "Long.MIN_VALUE"}, {inputType: "NullableDate", outputType: "Date", runningType: "Date", major: "Date", initialValue: "Long.MIN_VALUE"}, {inputType: "TimeStamp", outputType: "TimeStamp", runningType: "TimeStamp", major: "Date", initialValue: "Long.MIN_VALUE"}, @@ -82,13 +82,13 @@ {inputType: "Bit", outputType: "Bit", runningType: "Bit", major: "Numeric"}, {inputType: "Int", outputType: "BigInt", runningType: "BigInt", major: "Numeric"}, {inputType: "BigInt", outputType: "BigInt", runningType: "BigInt", major: "Numeric"}, - {inputType: "NullableBit", outputType: "Bit", runningType: "Bit", major: "Numeric"}, - {inputType: "NullableInt", outputType: "BigInt", runningType: "BigInt", major: "Numeric"}, - {inputType: "NullableBigInt", outputType: "BigInt", runningType: "BigInt", major: "Numeric"}, + {inputType: "NullableBit", outputType: "NullableBit", runningType: "Bit", major: "Numeric"}, + {inputType: "NullableInt", outputType: "NullableBigInt", runningType: "BigInt", major: "Numeric"}, + {inputType: "NullableBigInt", outputType: "NullableBigInt", runningType: "BigInt", major: "Numeric"}, {inputType: "Float4", outputType: "Float8", runningType: "Float8", major: "Numeric"}, {inputType: "Float8", outputType: "Float8", runningType: "Float8", major: "Numeric"}, - {inputType: "NullableFloat4", outputType: "Float8", runningType: "Float8", major: "Numeric"}, - {inputType: "NullableFloat8", outputType: "Float8", runningType: "Float8", major: "Numeric"}, + {inputType: "NullableFloat4", outputType: "NullableFloat8", runningType: "Float8", major: "Numeric"}, + {inputType: "NullableFloat8", outputType: "NullableFloat8", runningType: "Float8", major: "Numeric"}, {inputType: "IntervalDay", outputType: "IntervalDay", runningType: "IntervalDay", major: "Date", initialValue: "0"}, {inputType: "NullableIntervalDay", outputType: "IntervalDay", runningType: "IntervalDay", major: "Date", initialValue: "0"}, {inputType: "IntervalYear", outputType: "IntervalYear", runningType: "IntervalYear", major: "Date", initialValue: "0"}, diff --git a/exec/java-exec/src/main/codegen/data/AggrTypes2.tdd b/exec/java-exec/src/main/codegen/data/AggrTypes2.tdd index c6655afd2..ee64dafa8 100644 --- a/exec/java-exec/src/main/codegen/data/AggrTypes2.tdd +++ b/exec/java-exec/src/main/codegen/data/AggrTypes2.tdd @@ -19,12 +19,12 @@ {className: "Avg", funcName: "avg", types: [ {inputType: "Int", outputType: "Float8", sumRunningType: "BigInt", countRunningType: "BigInt", major: "Numeric"}, {inputType: "BigInt", outputType: "Float8", sumRunningType: "BigInt", countRunningType: "BigInt", major: "Numeric"}, - {inputType: "NullableInt", outputType: "Float8", sumRunningType: "BigInt", countRunningType: "BigInt", major: "Numeric"}, - {inputType: "NullableBigInt", outputType: "Float8", sumRunningType: "BigInt", countRunningType: "BigInt", major: "Numeric"}, + {inputType: "NullableInt", outputType: "NullableFloat8", sumRunningType: "BigInt", countRunningType: "BigInt", major: "Numeric"}, + {inputType: "NullableBigInt", outputType: "NullableFloat8", sumRunningType: "BigInt", countRunningType: "BigInt", major: "Numeric"}, {inputType: "Float4", outputType: "Float8", sumRunningType: "Float8", countRunningType: "BigInt", major: "Numeric"}, {inputType: "Float8", outputType: "Float8", sumRunningType: "Float8", countRunningType: "BigInt", major: "Numeric"}, - {inputType: "NullableFloat4", outputType: "Float8", sumRunningType: "Float8", countRunningType: "BigInt", major: "Numeric"}, - {inputType: "NullableFloat8", outputType: "Float8", sumRunningType: "Float8", countRunningType: "BigInt", major: "Numeric"}, + {inputType: "NullableFloat4", outputType: "NullableFloat8", sumRunningType: "Float8", countRunningType: "BigInt", major: "Numeric"}, + {inputType: "NullableFloat8", outputType: "NullableFloat8", sumRunningType: "Float8", countRunningType: "BigInt", major: "Numeric"}, {inputType: "IntervalDay", outputType: "Interval", sumRunningType: "BigInt", countRunningType: "BigInt", major: "Date"}, {inputType: "NullableIntervalDay", outputType: "Interval", sumRunningType: "BigInt", countRunningType: "BigInt", major: "Date"}, {inputType: "IntervalYear", outputType: "Interval", sumRunningType: "BigInt", countRunningType: "BigInt", major: "Date"}, diff --git a/exec/java-exec/src/main/codegen/data/AggrTypes3.tdd b/exec/java-exec/src/main/codegen/data/AggrTypes3.tdd index 05acc0327..0c3a3588f 100644 --- a/exec/java-exec/src/main/codegen/data/AggrTypes3.tdd +++ b/exec/java-exec/src/main/codegen/data/AggrTypes3.tdd @@ -18,97 +18,97 @@ aggrtypes: [ {className: "StdDevPop", funcName: "stddev_pop", aliasName: "", types: [ {inputType: "BigInt", outputType: "Float8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"}, - {inputType: "NullableBigInt", outputType: "Float8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"}, + {inputType: "NullableBigInt", outputType: "NullableFloat8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"}, {inputType: "Int", outputType: "Float8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"}, - {inputType: "NullableInt", outputType: "Float8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"}, + {inputType: "NullableInt", outputType: "NullableFloat8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"}, {inputType: "SmallInt", outputType: "Float8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"}, - {inputType: "NullableSmallInt", outputType: "Float8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"}, + {inputType: "NullableSmallInt", outputType: "NullableFloat8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"}, {inputType: "TinyInt", outputType: "Float8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"}, - {inputType: "NullableTinyInt", outputType: "Float8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"}, + {inputType: "NullableTinyInt", outputType: "NullableFloat8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"}, {inputType: "UInt1", outputType: "Float8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"}, - {inputType: "NullableUInt1", outputType: "Float8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"}, + {inputType: "NullableUInt1", outputType: "NullableFloat8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"}, {inputType: "UInt2", outputType: "Float8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"}, - {inputType: "NullableUInt2", outputType: "Float8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"}, + {inputType: "NullableUInt2", outputType: "NullableFloat8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"}, {inputType: "UInt4", outputType: "Float8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"}, - {inputType: "NullableUInt4", outputType: "Float8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"}, + {inputType: "NullableUInt4", outputType: "NullableFloat8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"}, {inputType: "UInt8", outputType: "Float8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"}, - {inputType: "NullableUInt8", outputType: "Float8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"}, + {inputType: "NullableUInt8", outputType: "NullableFloat8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"}, {inputType: "Float4", outputType: "Float8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"}, - {inputType: "NullableFloat4", outputType: "Float8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"}, + {inputType: "NullableFloat4", outputType: "NullableFloat8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"}, {inputType: "Float8", outputType: "Float8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"}, - {inputType: "NullableFloat8", outputType: "Float8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"} + {inputType: "NullableFloat8", outputType: "NullableFloat8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"} ] }, {className: "VariancePop", funcName: "var_pop", aliasName: "", types: [ {inputType: "BigInt", outputType: "Float8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"}, - {inputType: "NullableBigInt", outputType: "Float8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"}, + {inputType: "NullableBigInt", outputType: "NullableFloat8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"}, {inputType: "Int", outputType: "Float8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"}, - {inputType: "NullableInt", outputType: "Float8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"}, + {inputType: "NullableInt", outputType: "NullableFloat8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"}, {inputType: "SmallInt", outputType: "Float8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"}, - {inputType: "NullableSmallInt", outputType: "Float8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"}, + {inputType: "NullableSmallInt", outputType: "NullableFloat8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"}, {inputType: "TinyInt", outputType: "Float8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"}, - {inputType: "NullableTinyInt", outputType: "Float8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"}, + {inputType: "NullableTinyInt", outputType: "NullableFloat8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"}, {inputType: "UInt1", outputType: "Float8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"}, - {inputType: "NullableUInt1", outputType: "Float8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"}, + {inputType: "NullableUInt1", outputType: "NullableFloat8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"}, {inputType: "UInt2", outputType: "Float8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"}, - {inputType: "NullableUInt2", outputType: "Float8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"}, + {inputType: "NullableUInt2", outputType: "NullableFloat8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"}, {inputType: "UInt4", outputType: "Float8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"}, - {inputType: "NullableUInt4", outputType: "Float8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"}, + {inputType: "NullableUInt4", outputType: "NullableFloat8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"}, {inputType: "UInt8", outputType: "Float8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"}, - {inputType: "NullableUInt8", outputType: "Float8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"}, + {inputType: "NullableUInt8", outputType: "NullableFloat8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"}, {inputType: "Float4", outputType: "Float8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"}, - {inputType: "NullableFloat4", outputType: "Float8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"}, + {inputType: "NullableFloat4", outputType: "NullableFloat8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"}, {inputType: "Float8", outputType: "Float8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"}, - {inputType: "NullableFloat8", outputType: "Float8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"} + {inputType: "NullableFloat8", outputType: "NullableFloat8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"} ] }, {className: "StdDevSample", funcName: "stddev_samp", aliasName: "stddev", types: [ {inputType: "BigInt", outputType: "Float8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"}, - {inputType: "NullableBigInt", outputType: "Float8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"}, + {inputType: "NullableBigInt", outputType: "NullableFloat8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"}, {inputType: "Int", outputType: "Float8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"}, - {inputType: "NullableInt", outputType: "Float8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"}, + {inputType: "NullableInt", outputType: "NullableFloat8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"}, {inputType: "SmallInt", outputType: "Float8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"}, - {inputType: "NullableSmallInt", outputType: "Float8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"}, + {inputType: "NullableSmallInt", outputType: "NullableFloat8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"}, {inputType: "TinyInt", outputType: "Float8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"}, - {inputType: "NullableTinyInt", outputType: "Float8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"}, + {inputType: "NullableTinyInt", outputType: "NullableFloat8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"}, {inputType: "UInt1", outputType: "Float8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"}, - {inputType: "NullableUInt1", outputType: "Float8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"}, + {inputType: "NullableUInt1", outputType: "NullableFloat8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"}, {inputType: "UInt2", outputType: "Float8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"}, - {inputType: "NullableUInt2", outputType: "Float8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"}, + {inputType: "NullableUInt2", outputType: "NullableFloat8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"}, {inputType: "UInt4", outputType: "Float8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"}, - {inputType: "NullableUInt4", outputType: "Float8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"}, + {inputType: "NullableUInt4", outputType: "NullableFloat8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"}, {inputType: "UInt8", outputType: "Float8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"}, - {inputType: "NullableUInt8", outputType: "Float8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"}, + {inputType: "NullableUInt8", outputType: "NullableFloat8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"}, {inputType: "Float4", outputType: "Float8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"}, - {inputType: "NullableFloat4", outputType: "Float8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"}, + {inputType: "NullableFloat4", outputType: "NullableFloat8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"}, {inputType: "Float8", outputType: "Float8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"}, - {inputType: "NullableFloat8", outputType: "Float8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"} + {inputType: "NullableFloat8", outputType: "NullableFloat8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"} ] }, {className: "VarianceSample", funcName: "var_samp", aliasName: "variance", types: [ {inputType: "BigInt", outputType: "Float8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"}, - {inputType: "NullableBigInt", outputType: "Float8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"}, + {inputType: "NullableBigInt", outputType: "NullableFloat8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"}, {inputType: "Int", outputType: "Float8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"}, - {inputType: "NullableInt", outputType: "Float8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"}, + {inputType: "NullableInt", outputType: "NullableFloat8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"}, {inputType: "SmallInt", outputType: "Float8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"}, - {inputType: "NullableSmallInt", outputType: "Float8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"}, + {inputType: "NullableSmallInt", outputType: "NullableFloat8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"}, {inputType: "TinyInt", outputType: "Float8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"}, - {inputType: "NullableTinyInt", outputType: "Float8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"}, + {inputType: "NullableTinyInt", outputType: "NullableFloat8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"}, {inputType: "UInt1", outputType: "Float8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"}, - {inputType: "NullableUInt1", outputType: "Float8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"}, + {inputType: "NullableUInt1", outputType: "NullableFloat8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"}, {inputType: "UInt2", outputType: "Float8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"}, - {inputType: "NullableUInt2", outputType: "Float8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"}, + {inputType: "NullableUInt2", outputType: "NullableFloat8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"}, {inputType: "UInt4", outputType: "Float8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"}, - {inputType: "NullableUInt4", outputType: "Float8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"}, + {inputType: "NullableUInt4", outputType: "NullableFloat8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"}, {inputType: "UInt8", outputType: "Float8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"}, - {inputType: "NullableUInt8", outputType: "Float8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"}, + {inputType: "NullableUInt8", outputType: "NullableFloat8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"}, {inputType: "Float4", outputType: "Float8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"}, - {inputType: "NullableFloat4", outputType: "Float8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"}, + {inputType: "NullableFloat4", outputType: "NullableFloat8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"}, {inputType: "Float8", outputType: "Float8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"}, - {inputType: "NullableFloat8", outputType: "Float8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"} + {inputType: "NullableFloat8", outputType: "NullableFloat8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"} ] } ] diff --git a/exec/java-exec/src/main/codegen/data/CastHigh.tdd b/exec/java-exec/src/main/codegen/data/CastHigh.tdd new file mode 100644 index 000000000..54c337ddf --- /dev/null +++ b/exec/java-exec/src/main/codegen/data/CastHigh.tdd @@ -0,0 +1,28 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http:# www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +{ + types: [ + {value: true, from: "Int", to: "Float8" }, + {value: true, from: "BigInt", to: "Float8" }, + {value: true, from: "Float4", to: "Float8" }, + {value: true, from: "Float8", to: "Float8" }, + {value: false, from: "Decimal9"}, + {value: false, from: "Decimal18"}, + {value: false, from: "Decimal28Sparse"}, + {value: false, from: "Decimal38Sparse"}, + ] +} diff --git a/exec/java-exec/src/main/codegen/data/DecimalAggrTypes2.tdd b/exec/java-exec/src/main/codegen/data/DecimalAggrTypes2.tdd index ed8d1334b..5aa8b7fba 100644 --- a/exec/java-exec/src/main/codegen/data/DecimalAggrTypes2.tdd +++ b/exec/java-exec/src/main/codegen/data/DecimalAggrTypes2.tdd @@ -28,4 +28,4 @@ ] } ] -}
\ No newline at end of file +} diff --git a/exec/java-exec/src/main/codegen/data/SumZero.tdd b/exec/java-exec/src/main/codegen/data/SumZero.tdd index b270d928d..c6532b4fd 100644 --- a/exec/java-exec/src/main/codegen/data/SumZero.tdd +++ b/exec/java-exec/src/main/codegen/data/SumZero.tdd @@ -16,16 +16,16 @@ { types: [ - {inputType: "Bit", outputType: "NullableBit", runningType: "Bit", major: "Numeric"}, - {inputType: "Int", outputType: "NullableBigInt", runningType: "BigInt", major: "Numeric"}, - {inputType: "BigInt", outputType: "NullableBigInt", runningType: "BigInt", major: "Numeric"}, - {inputType: "NullableBit", outputType: "NullableBit", runningType: "Bit", major: "Numeric"}, - {inputType: "NullableInt", outputType: "NullableBigInt", runningType: "BigInt", major: "Numeric"}, - {inputType: "NullableBigInt", outputType: "NullableBigInt", runningType: "BigInt", major: "Numeric"}, - {inputType: "Float4", outputType: "NullableFloat8", runningType: "Float8", major: "Numeric"}, - {inputType: "Float8", outputType: "NullableFloat8", runningType: "Float8", major: "Numeric"}, - {inputType: "NullableFloat4", outputType: "NullableFloat8", runningType: "Float8", major: "Numeric"}, - {inputType: "NullableFloat8", outputType: "NullableFloat8", runningType: "Float8", major: "Numeric"} + {inputType: "Bit", outputType: "Bit", runningType: "Bit", major: "Numeric"}, + {inputType: "Int", outputType: "BigInt", runningType: "BigInt", major: "Numeric"}, + {inputType: "BigInt", outputType: "BigInt", runningType: "BigInt", major: "Numeric"}, + {inputType: "NullableBit", outputType: "Bit", runningType: "Bit", major: "Numeric"}, + {inputType: "NullableInt", outputType: "BigInt", runningType: "BigInt", major: "Numeric"}, + {inputType: "NullableBigInt", outputType: "BigInt", runningType: "BigInt", major: "Numeric"}, + {inputType: "Float4", outputType: "Float8", runningType: "Float8", major: "Numeric"}, + {inputType: "Float8", outputType: "Float8", runningType: "Float8", major: "Numeric"}, + {inputType: "NullableFloat4", outputType: "Float8", runningType: "Float8", major: "Numeric"}, + {inputType: "NullableFloat8", outputType: "Float8", runningType: "Float8", major: "Numeric"} ] }
\ No newline at end of file diff --git a/exec/java-exec/src/main/codegen/templates/AggrTypeFunctions1.java b/exec/java-exec/src/main/codegen/templates/AggrTypeFunctions1.java index aa9aeab7a..e19def360 100644 --- a/exec/java-exec/src/main/codegen/templates/AggrTypeFunctions1.java +++ b/exec/java-exec/src/main/codegen/templates/AggrTypeFunctions1.java @@ -55,10 +55,17 @@ public static class ${type.inputType}${aggrtype.className} implements DrillAggFu @Param ${type.inputType}Holder in; @Workspace ${type.runningType}Holder value; + <#if type.inputType?starts_with("Nullable") && type.outputType?starts_with("Nullable")> + @Workspace BigIntHolder nonNullCount; + </#if> @Output ${type.outputType}Holder out; public void setup(RecordBatch b) { - value = new ${type.runningType}Holder(); + value = new ${type.runningType}Holder(); + <#if type.inputType?starts_with("Nullable") && type.outputType?starts_with("Nullable")> + nonNullCount = new BigIntHolder(); + nonNullCount.value = 0; + </#if> <#if aggrtype.funcName == "sum" || aggrtype.funcName == "count"> value.value = 0; <#elseif aggrtype.funcName == "min"> @@ -96,7 +103,12 @@ public static class ${type.inputType}${aggrtype.className} implements DrillAggFu if (in.isSet == 0) { // processing nullable input and the value is null, so don't do anything... break sout; - } + } + <#if type.outputType?starts_with("Nullable")> + else { + nonNullCount.value++; + } + </#if> </#if> <#if aggrtype.funcName == "min"> value.value = Math.min(value.value, in.value); @@ -115,13 +127,24 @@ public static class ${type.inputType}${aggrtype.className} implements DrillAggFu } @Override - public void output() { - out.value = value.value; + public void output() { + <#if type.inputType?starts_with("Nullable") && type.outputType?starts_with("Nullable")> + if (nonNullCount.value > 0) { + out.value = value.value; + out.isSet = 1; + } else { + out.isSet = 0; + } + <#else> + out.value = value.value; + </#if> } @Override public void reset() { - + <#if type.inputType?starts_with("Nullable") && type.outputType?starts_with("Nullable")> + nonNullCount.value = 0; + </#if> <#if aggrtype.funcName == "sum" || aggrtype.funcName == "count"> value.value = 0; <#elseif aggrtype.funcName == "min"> diff --git a/exec/java-exec/src/main/codegen/templates/AggrTypeFunctions2.java b/exec/java-exec/src/main/codegen/templates/AggrTypeFunctions2.java index 8606dd1f2..fda14571b 100644 --- a/exec/java-exec/src/main/codegen/templates/AggrTypeFunctions2.java +++ b/exec/java-exec/src/main/codegen/templates/AggrTypeFunctions2.java @@ -58,11 +58,16 @@ public static class ${type.inputType}${aggrtype.className} implements DrillAggFu @Param ${type.inputType}Holder in; @Workspace ${type.sumRunningType}Holder sum; @Workspace ${type.countRunningType}Holder count; + @Workspace BigIntHolder nonNullCount; @Output ${type.outputType}Holder out; public void setup(RecordBatch b) { sum = new ${type.sumRunningType}Holder(); count = new ${type.countRunningType}Holder(); + <#if type.inputType?starts_with("Nullable") > + nonNullCount = new BigIntHolder(); + nonNullCount.value = 0; + </#if> sum.value = 0; count.value = 0; } @@ -74,7 +79,10 @@ public static class ${type.inputType}${aggrtype.className} implements DrillAggFu if (in.isSet == 0) { // processing nullable input and the value is null, so don't do anything... break sout; - } + } + else { + nonNullCount.value++; + } </#if> <#if aggrtype.funcName == "avg"> sum.value += in.value; @@ -89,11 +97,23 @@ public static class ${type.inputType}${aggrtype.className} implements DrillAggFu @Override public void output() { - out.value = sum.value / ((double) count.value); + <#if type.inputType?starts_with("Nullable") > + if (nonNullCount.value > 0) { + out.value = sum.value / ((double) count.value); + out.isSet = 1; + } else { + out.isSet = 0; + } + <#else> + out.value = sum.value / ((double) count.value); + </#if> } @Override public void reset() { + <#if type.inputType?starts_with("Nullable") > + nonNullCount.value = 0; + </#if> sum.value = 0; count.value = 0; } diff --git a/exec/java-exec/src/main/codegen/templates/AggrTypeFunctions3.java b/exec/java-exec/src/main/codegen/templates/AggrTypeFunctions3.java index 1b276e7c8..acf877a34 100644 --- a/exec/java-exec/src/main/codegen/templates/AggrTypeFunctions3.java +++ b/exec/java-exec/src/main/codegen/templates/AggrTypeFunctions3.java @@ -61,13 +61,17 @@ public static class ${type.inputType}${aggrtype.className} implements DrillAggFu @Workspace ${type.movingAverageType}Holder avg; @Workspace ${type.movingDeviationType}Holder dev; @Workspace ${type.countRunningType}Holder count; + @Workspace BigIntHolder nonNullCount; @Output ${type.outputType}Holder out; public void setup(RecordBatch b) { avg = new ${type.movingAverageType}Holder(); dev = new ${type.movingDeviationType}Holder(); count = new ${type.countRunningType}Holder(); - + <#if type.inputType?starts_with("Nullable") > + nonNullCount = new BigIntHolder(); + nonNullCount.value = 0; + </#if> // Initialize the workspace variables avg.value = 0; dev.value = 0; @@ -82,6 +86,9 @@ public static class ${type.inputType}${aggrtype.className} implements DrillAggFu // processing nullable input and the value is null, so don't do anything... break sout; } + else { + nonNullCount.value++; + } </#if> // Welford's approach to compute standard deviation @@ -97,6 +104,26 @@ public static class ${type.inputType}${aggrtype.className} implements DrillAggFu @Override public void output() { + <#if type.inputType?starts_with("Nullable") > + if (nonNullCount.value > 0) { + out.isSet = 1; + <#if aggrtype.funcName == "stddev_pop"> + if (count.value > 1) + out.value = Math.sqrt((dev.value / (count.value - 1))); + <#elseif aggrtype.funcName == "var_pop"> + if (count.value > 1) + out.value = (dev.value / (count.value - 1)); + <#elseif aggrtype.funcName == "stddev_samp"> + if (count.value > 2) + out.value = Math.sqrt((dev.value / (count.value - 2))); + <#elseif aggrtype.funcName == "var_samp"> + if (count.value > 2) + out.value = (dev.value / (count.value - 2)); + </#if> + } else { + out.isSet = 0; + } + <#else> <#if aggrtype.funcName == "stddev_pop"> if (count.value > 1) out.value = Math.sqrt((dev.value / (count.value - 1))); @@ -110,10 +137,14 @@ public static class ${type.inputType}${aggrtype.className} implements DrillAggFu if (count.value > 2) out.value = (dev.value / (count.value - 2)); </#if> + </#if> } @Override public void reset() { + <#if type.inputType?starts_with("Nullable") > + nonNullCount.value = 0; + </#if> avg.value = 0; dev.value = 0; count.value = 1; diff --git a/exec/java-exec/src/main/codegen/templates/CastHigh.java b/exec/java-exec/src/main/codegen/templates/CastHigh.java new file mode 100644 index 000000000..934b60b88 --- /dev/null +++ b/exec/java-exec/src/main/codegen/templates/CastHigh.java @@ -0,0 +1,64 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +<@pp.dropOutputFile /> + +<@pp.changeOutputFile name="/org/apache/drill/exec/expr/fn/impl/gcast/CastHighFunctions.java" /> + +<#include "/@includes/license.ftl" /> + +package org.apache.drill.exec.expr.fn.impl.gcast; + +import org.apache.drill.exec.expr.DrillSimpleFunc; +import org.apache.drill.exec.expr.annotations.FunctionTemplate; +import org.apache.drill.exec.expr.annotations.FunctionTemplate.NullHandling; +import org.apache.drill.exec.expr.annotations.Output; +import org.apache.drill.exec.expr.annotations.Param; +import org.apache.drill.exec.expr.holders.*; +import javax.inject.Inject; +import io.netty.buffer.DrillBuf; +import org.apache.drill.exec.record.RecordBatch; + +public class CastHighFunctions { + static final org.slf4j.Logger logger = org.slf4j.LoggerFactory.getLogger(CastHighFunctions.class); + + <#list casthigh.types as type> + + @SuppressWarnings("unused") + @FunctionTemplate(name = "casthigh", scope = FunctionTemplate.FunctionScope.SIMPLE, nulls=NullHandling.NULL_IF_NULL) + public static class CastHigh${type.from} implements DrillSimpleFunc { + + @Param ${type.from}Holder in; + <#if type.from.contains("Decimal")> + @Output ${type.from}Holder out; + <#else> + @Output ${type.to}Holder out; + </#if> + + public void setup(RecordBatch incoming) {} + + public void eval() { + <#if type.value > + out.value = (double) in.value; + <#else> + out = in; + </#if> + } + } +</#list> +} + diff --git a/exec/java-exec/src/main/codegen/templates/SumZeroAggr.java b/exec/java-exec/src/main/codegen/templates/SumZeroAggr.java index 0eab23d99..5b0c4a0cf 100644 --- a/exec/java-exec/src/main/codegen/templates/SumZeroAggr.java +++ b/exec/java-exec/src/main/codegen/templates/SumZeroAggr.java @@ -51,42 +51,26 @@ public class SumZeroFunctions { @Param ${type.inputType}Holder in; @Workspace ${type.runningType}Holder value; - @Workspace BigIntHolder callCount; @Output ${type.outputType}Holder out; public void setup(RecordBatch b) { value.value = 0; - callCount.value = 0; } @Override public void add() { - callCount.value++; - <#if type.inputType?starts_with("Nullable") > - if(in.isSet == 1){ - value.value += in.value; - } - <#else> value.value += in.value; - </#if> } @Override public void output() { - if(callCount.value > 0){ - out.value = value.value; - out.isSet = 1; - }else{ - out.isSet = 0; - } - + out.value = value.value; } @Override public void reset() { value.value = 0; - callCount.value = 0; } } diff --git a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/common/DrillAggregateRelBase.java b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/common/DrillAggregateRelBase.java index 4854307d3..98f6bd5cd 100644 --- a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/common/DrillAggregateRelBase.java +++ b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/common/DrillAggregateRelBase.java @@ -20,12 +20,14 @@ package org.apache.drill.exec.planner.common; import java.util.BitSet; import java.util.List; - +import org.apache.drill.exec.planner.cost.DrillCostBase.DrillCostFactory; import org.eigenbase.rel.AggregateCall; import org.eigenbase.rel.AggregateRelBase; import org.eigenbase.rel.InvalidRelException; import org.eigenbase.rel.RelNode; import org.eigenbase.relopt.RelOptCluster; +import org.eigenbase.relopt.RelOptCost; +import org.eigenbase.relopt.RelOptPlanner; import org.eigenbase.relopt.RelTraitSet; @@ -38,5 +40,20 @@ public abstract class DrillAggregateRelBase extends AggregateRelBase implements List<AggregateCall> aggCalls) throws InvalidRelException { super(cluster, traits, child, groupSet, aggCalls); } + + @Override + public RelOptCost computeSelfCost(RelOptPlanner planner) { + for (AggregateCall aggCall : getAggCallList()) { + String name = aggCall.getAggregation().getName(); + // For avg, stddev_pop, stddev_samp, var_pop and var_samp, the ReduceAggregatesRule is supposed + // to convert them to use sum and count. Here, we make the cost of the original functions high + // enough such that the planner does not choose them and instead chooses the rewritten functions. + if (name.equals("AVG") || name.equals("STDDEV_POP") || name.equals("STDDEV_SAMP") + || name.equals("VAR_POP") || name.equals("VAR_SAMP")) { + return ((DrillCostFactory)planner.getCostFactory()).makeHugeCost(); + } + } + return ((DrillCostFactory)planner.getCostFactory()).makeTinyCost(); + } } diff --git a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/logical/DrillReduceAggregatesRule.java b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/logical/DrillReduceAggregatesRule.java new file mode 100644 index 000000000..8305dd8a6 --- /dev/null +++ b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/logical/DrillReduceAggregatesRule.java @@ -0,0 +1,602 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.drill.exec.planner.logical; + + +import java.math.BigDecimal; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import org.apache.drill.exec.planner.sql.DrillSqlOperator; +import org.eigenbase.rel.AggregateCall; +import org.eigenbase.rel.AggregateRel; +import org.eigenbase.rel.AggregateRelBase; +import org.eigenbase.rel.CalcRel; +import org.eigenbase.rel.RelNode; +import org.eigenbase.relopt.RelOptRule; +import org.eigenbase.relopt.RelOptRuleCall; +import org.eigenbase.relopt.RelOptRuleOperand; +import org.eigenbase.reltype.RelDataType; +import org.eigenbase.reltype.RelDataTypeFactory; +import org.eigenbase.reltype.RelDataTypeField; +import org.eigenbase.rex.RexBuilder; +import org.eigenbase.rex.RexCall; +import org.eigenbase.rex.RexLiteral; +import org.eigenbase.rex.RexNode; +import org.eigenbase.sql.SqlAggFunction; +import org.eigenbase.sql.fun.SqlAvgAggFunction; +import org.eigenbase.sql.fun.SqlStdOperatorTable; +import org.eigenbase.sql.fun.SqlSumAggFunction; +import org.eigenbase.sql.fun.SqlSumEmptyIsZeroAggFunction; +import org.eigenbase.sql.type.SqlTypeUtil; +import org.eigenbase.util.CompositeList; +import org.eigenbase.util.ImmutableIntList; +import org.eigenbase.util.Util; + +import com.google.common.collect.ImmutableList; + +/** + * Rule to reduce aggregates to simpler forms. Currently only AVG(x) to + * SUM(x)/COUNT(x), but eventually will handle others such as STDDEV. + */ +public class DrillReduceAggregatesRule extends RelOptRule { + //~ Static fields/initializers --------------------------------------------- + + /** + * The singleton. + */ + public static final DrillReduceAggregatesRule INSTANCE = + new DrillReduceAggregatesRule(operand(AggregateRel.class, any())); + + private static final DrillSqlOperator CastHighOp = new DrillSqlOperator("CastHigh", 1); + + //~ Constructors ----------------------------------------------------------- + + protected DrillReduceAggregatesRule(RelOptRuleOperand operand) { + super(operand); + } + + //~ Methods ---------------------------------------------------------------- + + @Override + public boolean matches(RelOptRuleCall call) { + if (!super.matches(call)) { + return false; + } + AggregateRelBase oldAggRel = (AggregateRelBase) call.rels[0]; + return containsAvgStddevVarCall(oldAggRel.getAggCallList()); + } + + public void onMatch(RelOptRuleCall ruleCall) { + AggregateRelBase oldAggRel = (AggregateRelBase) ruleCall.rels[0]; + reduceAggs(ruleCall, oldAggRel); + } + + /** + * Returns whether any of the aggregates are calls to AVG, STDDEV_*, VAR_*. + * + * @param aggCallList List of aggregate calls + */ + private boolean containsAvgStddevVarCall(List<AggregateCall> aggCallList) { + for (AggregateCall call : aggCallList) { + if (call.getAggregation() instanceof SqlAvgAggFunction + || call.getAggregation() instanceof SqlSumAggFunction) { + return true; + } + } + return false; + } + + /* + private boolean isMatch(AggregateCall call) { + if (call.getAggregation() instanceof SqlAvgAggFunction) { + final SqlAvgAggFunction.Subtype subtype = + ((SqlAvgAggFunction) call.getAggregation()).getSubtype(); + return (subtype == SqlAvgAggFunction.Subtype.AVG); + } + return false; + } + */ + + /** + * Reduces all calls to AVG, STDDEV_POP, STDDEV_SAMP, VAR_POP, VAR_SAMP in + * the aggregates list to. + * + * <p>It handles newly generated common subexpressions since this was done + * at the sql2rel stage. + */ + private void reduceAggs( + RelOptRuleCall ruleCall, + AggregateRelBase oldAggRel) { + RexBuilder rexBuilder = oldAggRel.getCluster().getRexBuilder(); + + List<AggregateCall> oldCalls = oldAggRel.getAggCallList(); + final int nGroups = oldAggRel.getGroupCount(); + + List<AggregateCall> newCalls = new ArrayList<AggregateCall>(); + Map<AggregateCall, RexNode> aggCallMapping = + new HashMap<AggregateCall, RexNode>(); + + List<RexNode> projList = new ArrayList<RexNode>(); + + // pass through group key + for (int i = 0; i < nGroups; ++i) { + projList.add( + rexBuilder.makeInputRef( + getFieldType(oldAggRel, i), + i)); + } + + // List of input expressions. If a particular aggregate needs more, it + // will add an expression to the end, and we will create an extra + // project. + RelNode input = oldAggRel.getChild(); + List<RexNode> inputExprs = new ArrayList<RexNode>(); + for (RelDataTypeField field : input.getRowType().getFieldList()) { + inputExprs.add( + rexBuilder.makeInputRef( + field.getType(), inputExprs.size())); + } + + // create new agg function calls and rest of project list together + for (AggregateCall oldCall : oldCalls) { + projList.add( + reduceAgg( + oldAggRel, oldCall, newCalls, aggCallMapping, inputExprs)); + } + + final int extraArgCount = + inputExprs.size() - input.getRowType().getFieldCount(); + if (extraArgCount > 0) { + input = + CalcRel.createProject( + input, + inputExprs, + CompositeList.of( + input.getRowType().getFieldNames(), + Collections.<String>nCopies( + extraArgCount, + null))); + } + AggregateRelBase newAggRel = + newAggregateRel( + oldAggRel, input, newCalls); + + RelNode projectRel = + CalcRel.createProject( + newAggRel, + projList, + oldAggRel.getRowType().getFieldNames()); + + ruleCall.transformTo(projectRel); + } + + private RexNode reduceAgg( + AggregateRelBase oldAggRel, + AggregateCall oldCall, + List<AggregateCall> newCalls, + Map<AggregateCall, RexNode> aggCallMapping, + List<RexNode> inputExprs) { + if (oldCall.getAggregation() instanceof SqlSumAggFunction) { + // replace original SUM(x) with + // case COUNT(x) when 0 then null else SUM0(x) end + return reduceSum(oldAggRel, oldCall, newCalls, aggCallMapping); + } + if (oldCall.getAggregation() instanceof SqlAvgAggFunction) { + final SqlAvgAggFunction.Subtype subtype = + ((SqlAvgAggFunction) oldCall.getAggregation()).getSubtype(); + + switch (subtype) { + case AVG: + // replace original AVG(x) with SUM(x) / COUNT(x) + return reduceAvg( + oldAggRel, oldCall, newCalls, aggCallMapping); + case STDDEV_POP: + // replace original STDDEV_POP(x) with + // SQRT( + // (SUM(x * x) - SUM(x) * SUM(x) / COUNT(x)) + // / COUNT(x)) + return reduceStddev( + oldAggRel, oldCall, true, true, newCalls, aggCallMapping, + inputExprs); + case STDDEV_SAMP: + // replace original STDDEV_POP(x) with + // SQRT( + // (SUM(x * x) - SUM(x) * SUM(x) / COUNT(x)) + // / CASE COUNT(x) WHEN 1 THEN NULL ELSE COUNT(x) - 1 END) + return reduceStddev( + oldAggRel, oldCall, false, true, newCalls, aggCallMapping, + inputExprs); + case VAR_POP: + // replace original VAR_POP(x) with + // (SUM(x * x) - SUM(x) * SUM(x) / COUNT(x)) + // / COUNT(x) + return reduceStddev( + oldAggRel, oldCall, true, false, newCalls, aggCallMapping, + inputExprs); + case VAR_SAMP: + // replace original VAR_POP(x) with + // (SUM(x * x) - SUM(x) * SUM(x) / COUNT(x)) + // / CASE COUNT(x) WHEN 1 THEN NULL ELSE COUNT(x) - 1 END + return reduceStddev( + oldAggRel, oldCall, false, false, newCalls, aggCallMapping, + inputExprs); + default: + throw Util.unexpected(subtype); + } + } else { + // anything else: preserve original call + RexBuilder rexBuilder = oldAggRel.getCluster().getRexBuilder(); + final int nGroups = oldAggRel.getGroupCount(); + List<RelDataType> oldArgTypes = SqlTypeUtil + .projectTypes(oldAggRel.getRowType(), oldCall.getArgList()); + return rexBuilder.addAggCall( + oldCall, + nGroups, + newCalls, + aggCallMapping, + oldArgTypes); + } + } + + private RexNode reduceAvg( + AggregateRelBase oldAggRel, + AggregateCall oldCall, + List<AggregateCall> newCalls, + Map<AggregateCall, RexNode> aggCallMapping) { + final int nGroups = oldAggRel.getGroupCount(); + RelDataTypeFactory typeFactory = + oldAggRel.getCluster().getTypeFactory(); + RexBuilder rexBuilder = oldAggRel.getCluster().getRexBuilder(); + int iAvgInput = oldCall.getArgList().get(0); + RelDataType avgInputType = + getFieldType( + oldAggRel.getChild(), + iAvgInput); + RelDataType sumType = + typeFactory.createTypeWithNullability( + avgInputType, + avgInputType.isNullable() || nGroups == 0); + // SqlAggFunction sumAgg = new SqlSumAggFunction(sumType); + SqlAggFunction sumAgg = new SqlSumEmptyIsZeroAggFunction(sumType); + AggregateCall sumCall = + new AggregateCall( + sumAgg, + oldCall.isDistinct(), + oldCall.getArgList(), + sumType, + null); + SqlAggFunction countAgg = SqlStdOperatorTable.COUNT; + RelDataType countType = countAgg.getReturnType(typeFactory); + AggregateCall countCall = + new AggregateCall( + countAgg, + oldCall.isDistinct(), + oldCall.getArgList(), + countType, + null); + + RexNode tmpsumRef = + rexBuilder.addAggCall( + sumCall, + nGroups, + newCalls, + aggCallMapping, + ImmutableList.of(avgInputType)); + + RexNode tmpcountRef = + rexBuilder.addAggCall( + countCall, + nGroups, + newCalls, + aggCallMapping, + ImmutableList.of(avgInputType)); + + RexNode n = rexBuilder.makeCall(SqlStdOperatorTable.CASE, + rexBuilder.makeCall(SqlStdOperatorTable.EQUALS, + tmpcountRef, rexBuilder.makeExactLiteral(BigDecimal.ZERO)), + rexBuilder.constantNull(), + tmpsumRef); + + // NOTE: these references are with respect to the output + // of newAggRel + /* + RexNode numeratorRef = + rexBuilder.makeCall(CastHighOp, + rexBuilder.addAggCall( + sumCall, + nGroups, + newCalls, + aggCallMapping, + ImmutableList.of(avgInputType)) + ); + */ + RexNode numeratorRef = rexBuilder.makeCall(CastHighOp, n); + + RexNode denominatorRef = + rexBuilder.addAggCall( + countCall, + nGroups, + newCalls, + aggCallMapping, + ImmutableList.of(avgInputType)); + final RexNode divideRef = + rexBuilder.makeCall( + SqlStdOperatorTable.DIVIDE, + numeratorRef, + denominatorRef); + return rexBuilder.makeCast( + oldCall.getType(), divideRef); + } + + private RexNode reduceSum( + AggregateRelBase oldAggRel, + AggregateCall oldCall, + List<AggregateCall> newCalls, + Map<AggregateCall, RexNode> aggCallMapping) { + final int nGroups = oldAggRel.getGroupCount(); + RelDataTypeFactory typeFactory = + oldAggRel.getCluster().getTypeFactory(); + RexBuilder rexBuilder = oldAggRel.getCluster().getRexBuilder(); + int arg = oldCall.getArgList().get(0); + RelDataType argType = + getFieldType( + oldAggRel.getChild(), + arg); + RelDataType sumType = + typeFactory.createTypeWithNullability( + argType, argType.isNullable()); + SqlAggFunction sumZeroAgg = new SqlSumEmptyIsZeroAggFunction(sumType); + AggregateCall sumZeroCall = + new AggregateCall( + sumZeroAgg, + oldCall.isDistinct(), + oldCall.getArgList(), + sumType, + null); + SqlAggFunction countAgg = SqlStdOperatorTable.COUNT; + RelDataType countType = countAgg.getReturnType(typeFactory); + AggregateCall countCall = + new AggregateCall( + countAgg, + oldCall.isDistinct(), + oldCall.getArgList(), + countType, + null); + + // NOTE: these references are with respect to the output + // of newAggRel + RexNode sumZeroRef = + rexBuilder.addAggCall( + sumZeroCall, + nGroups, + newCalls, + aggCallMapping, + ImmutableList.of(argType)); + if (!oldCall.getType().isNullable()) { + // If SUM(x) is not nullable, the validator must have determined that + // nulls are impossible (because the group is never empty and x is never + // null). Therefore we translate to SUM0(x). + return sumZeroRef; + } + RexNode countRef = + rexBuilder.addAggCall( + countCall, + nGroups, + newCalls, + aggCallMapping, + ImmutableList.of(argType)); + return rexBuilder.makeCall(SqlStdOperatorTable.CASE, + rexBuilder.makeCall(SqlStdOperatorTable.EQUALS, + countRef, rexBuilder.makeExactLiteral(BigDecimal.ZERO)), + rexBuilder.constantNull(), + sumZeroRef); + } + + private RexNode reduceStddev( + AggregateRelBase oldAggRel, + AggregateCall oldCall, + boolean biased, + boolean sqrt, + List<AggregateCall> newCalls, + Map<AggregateCall, RexNode> aggCallMapping, + List<RexNode> inputExprs) { + // stddev_pop(x) ==> + // power( + // (sum(x * x) - sum(x) * sum(x) / count(x)) + // / count(x), + // .5) + // + // stddev_samp(x) ==> + // power( + // (sum(x * x) - sum(x) * sum(x) / count(x)) + // / nullif(count(x) - 1, 0), + // .5) + final int nGroups = oldAggRel.getGroupCount(); + RelDataTypeFactory typeFactory = + oldAggRel.getCluster().getTypeFactory(); + final RexBuilder rexBuilder = oldAggRel.getCluster().getRexBuilder(); + + assert oldCall.getArgList().size() == 1 : oldCall.getArgList(); + final int argOrdinal = oldCall.getArgList().get(0); + final RelDataType argType = + getFieldType( + oldAggRel.getChild(), + argOrdinal); + + // final RexNode argRef = inputExprs.get(argOrdinal); + RexNode argRef = rexBuilder.makeCall(CastHighOp, inputExprs.get(argOrdinal)); + inputExprs.set(argOrdinal, argRef); + + final RexNode argSquared = + rexBuilder.makeCall( + SqlStdOperatorTable.MULTIPLY, argRef, argRef); + final int argSquaredOrdinal = lookupOrAdd(inputExprs, argSquared); + + final RelDataType sumType = + typeFactory.createTypeWithNullability( + argType, + true); + final AggregateCall sumArgSquaredAggCall = + new AggregateCall( + new SqlSumAggFunction(sumType), + oldCall.isDistinct(), + ImmutableIntList.of(argSquaredOrdinal), + sumType, + null); + final RexNode sumArgSquared = + rexBuilder.addAggCall( + sumArgSquaredAggCall, + nGroups, + newCalls, + aggCallMapping, + ImmutableList.of(argType)); + + final AggregateCall sumArgAggCall = + new AggregateCall( + new SqlSumAggFunction(sumType), + oldCall.isDistinct(), + ImmutableIntList.of(argOrdinal), + sumType, + null); + final RexNode sumArg = + rexBuilder.addAggCall( + sumArgAggCall, + nGroups, + newCalls, + aggCallMapping, + ImmutableList.of(argType)); + + final RexNode sumSquaredArg = + rexBuilder.makeCall( + SqlStdOperatorTable.MULTIPLY, sumArg, sumArg); + + final SqlAggFunction countAgg = SqlStdOperatorTable.COUNT; + final RelDataType countType = countAgg.getReturnType(typeFactory); + final AggregateCall countArgAggCall = + new AggregateCall( + countAgg, + oldCall.isDistinct(), + oldCall.getArgList(), + countType, + null); + final RexNode countArg = + rexBuilder.addAggCall( + countArgAggCall, + nGroups, + newCalls, + aggCallMapping, + ImmutableList.of(argType)); + + final RexNode avgSumSquaredArg = + rexBuilder.makeCall( + SqlStdOperatorTable.DIVIDE, + sumSquaredArg, countArg); + + final RexNode diff = + rexBuilder.makeCall( + SqlStdOperatorTable.MINUS, + sumArgSquared, avgSumSquaredArg); + + final RexNode denominator; + if (biased) { + denominator = countArg; + } else { + final RexLiteral one = + rexBuilder.makeExactLiteral(BigDecimal.ONE); + final RexNode nul = + rexBuilder.makeNullLiteral(countArg.getType().getSqlTypeName()); + final RexNode countMinusOne = + rexBuilder.makeCall( + SqlStdOperatorTable.MINUS, countArg, one); + final RexNode countEqOne = + rexBuilder.makeCall( + SqlStdOperatorTable.EQUALS, countArg, one); + denominator = + rexBuilder.makeCall( + SqlStdOperatorTable.CASE, + countEqOne, nul, countMinusOne); + } + + final RexNode div = + rexBuilder.makeCall( + SqlStdOperatorTable.DIVIDE, diff, denominator); + + RexNode result = div; + if (sqrt) { + final RexNode half = + rexBuilder.makeExactLiteral(new BigDecimal("0.5")); + result = + rexBuilder.makeCall( + SqlStdOperatorTable.POWER, div, half); + } + + return rexBuilder.makeCast( + oldCall.getType(), result); + } + + /** + * Finds the ordinal of an element in a list, or adds it. + * + * @param list List + * @param element Element to lookup or add + * @param <T> Element type + * @return Ordinal of element in list + */ + private static <T> int lookupOrAdd(List<T> list, T element) { + int ordinal = list.indexOf(element); + if (ordinal == -1) { + ordinal = list.size(); + list.add(element); + } + return ordinal; + } + + /** + * Do a shallow clone of oldAggRel and update aggCalls. Could be refactored + * into AggregateRelBase and subclasses - but it's only needed for some + * subclasses. + * + * @param oldAggRel AggregateRel to clone. + * @param inputRel Input relational expression + * @param newCalls New list of AggregateCalls + * @return shallow clone with new list of AggregateCalls. + */ + protected AggregateRelBase newAggregateRel( + AggregateRelBase oldAggRel, + RelNode inputRel, + List<AggregateCall> newCalls) { + return new AggregateRel( + oldAggRel.getCluster(), + inputRel, + oldAggRel.getGroupSet(), + newCalls); + } + + private RelDataType getFieldType(RelNode relNode, int i) { + final RelDataTypeField inputField = + relNode.getRowType().getFieldList().get(i); + return inputField.getType(); + } + +} + diff --git a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/logical/DrillRuleSets.java b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/logical/DrillRuleSets.java index cf92121be..65fa2d7d1 100644 --- a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/logical/DrillRuleSets.java +++ b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/logical/DrillRuleSets.java @@ -46,7 +46,6 @@ import org.eigenbase.rel.rules.PushFilterPastProjectRule; import org.eigenbase.rel.rules.PushJoinThroughJoinRule; import org.eigenbase.rel.rules.PushProjectPastFilterRule; import org.eigenbase.rel.rules.PushProjectPastJoinRule; -import org.eigenbase.rel.rules.ReduceAggregatesRule; import org.eigenbase.rel.rules.RemoveDistinctAggregateRule; import org.eigenbase.rel.rules.RemoveDistinctRule; import org.eigenbase.rel.rules.RemoveSortRule; @@ -87,7 +86,7 @@ public class DrillRuleSets { //MergeProjectRule.INSTANCE, // DrillMergeProjectRule.getInstance(true, RelFactories.DEFAULT_PROJECT_FACTORY, context.getFunctionRegistry()), RemoveDistinctAggregateRule.INSTANCE, // - ReduceAggregatesRule.INSTANCE, // + // ReduceAggregatesRule.INSTANCE, // replaced by DrillReduceAggregatesRule PushProjectPastJoinRule.INSTANCE, // PushProjectPastFilterRule.INSTANCE, DrillPushProjectPastFilterRule.INSTANCE, @@ -108,6 +107,7 @@ public class DrillRuleSets { DrillSortRule.INSTANCE, DrillJoinRule.INSTANCE, DrillUnionRule.INSTANCE + ,DrillReduceAggregatesRule.INSTANCE )); } return DRILL_BASIC_RULES; diff --git a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/physical/AggPrelBase.java b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/physical/AggPrelBase.java index 3a164149a..3dedb552a 100644 --- a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/physical/AggPrelBase.java +++ b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/physical/AggPrelBase.java @@ -73,7 +73,7 @@ public abstract class AggPrelBase extends AggregateRelBase implements Prel{ private final RelDataType type; public SqlSumCountAggFunction(RelDataType type) { - super("SUM", + super("$SUM0", SqlKind.OTHER_FUNCTION, ReturnTypes.BIGINT, // use the inferred return type of SqlCountAggFunction null, @@ -185,7 +185,7 @@ public abstract class AggPrelBase extends AggregateRelBase implements Prel{ public <T, X, E extends Throwable> T accept(PrelVisitor<T, X, E> logicalVisitor, X value) throws E { return logicalVisitor.visitPrel(this, value); } - + @Override public boolean needsFinalColumnReordering() { return true; diff --git a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/physical/AggPruleBase.java b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/physical/AggPruleBase.java index 7b7e3b72e..a61bfd6fe 100644 --- a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/physical/AggPruleBase.java +++ b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/physical/AggPruleBase.java @@ -69,7 +69,8 @@ public abstract class AggPruleBase extends Prule { for (AggregateCall aggCall : aggregate.getAggCallList()) { String name = aggCall.getAggregation().getName(); - if ( ! (name.equals("SUM") || name.equals("MIN") || name.equals("MAX") || name.equals("COUNT"))) { + if ( ! (name.equals("SUM") || name.equals("MIN") || name.equals("MAX") || name.equals("COUNT") + || name.equals("$SUM0"))) { return false; } } |