aboutsummaryrefslogtreecommitdiff
path: root/exec/java-exec/src/main
diff options
context:
space:
mode:
authorMehant Baid <mehantr@gmail.com>2014-07-25 19:39:30 -0700
committerMehant Baid <mehantr@gmail.com>2014-07-26 02:15:15 -0700
commit846e291d5f6966a9fdf2b590c597b79b3205bb68 (patch)
tree72da38bd2d2cf28bd757181e9a11163bba951430 /exec/java-exec/src/main
parentb6410ff846217ebffda23528bb46619248b1675a (diff)
DRILL-1180: Add casts to case expression to ensure all branches have same output type
Diffstat (limited to 'exec/java-exec/src/main')
-rw-r--r--exec/java-exec/src/main/java/org/apache/drill/exec/expr/EvaluationVisitor.java41
-rw-r--r--exec/java-exec/src/main/java/org/apache/drill/exec/expr/ExpressionTreeMaterializer.java130
-rw-r--r--exec/java-exec/src/main/java/org/apache/drill/exec/planner/logical/DrillOptiq.java2
-rw-r--r--exec/java-exec/src/main/java/org/apache/drill/exec/resolver/TypeCastRules.java45
4 files changed, 132 insertions, 86 deletions
diff --git a/exec/java-exec/src/main/java/org/apache/drill/exec/expr/EvaluationVisitor.java b/exec/java-exec/src/main/java/org/apache/drill/exec/expr/EvaluationVisitor.java
index 8bfa14959..73ce45e38 100644
--- a/exec/java-exec/src/main/java/org/apache/drill/exec/expr/EvaluationVisitor.java
+++ b/exec/java-exec/src/main/java/org/apache/drill/exec/expr/EvaluationVisitor.java
@@ -136,32 +136,31 @@ public class EvaluationVisitor {
JConditional jc = null;
JBlock conditionalBlock = new JBlock(false, false);
- for (IfCondition c : ifExpr.conditions) {
- HoldingContainer holdingContainer = c.condition.accept(this, generator);
- if (jc == null) {
- if (holdingContainer.isOptional()) {
- jc = conditionalBlock._if(holdingContainer.getIsSet().eq(JExpr.lit(1)).cand(holdingContainer.getValue().eq(JExpr.lit(1))));
- } else {
- jc = conditionalBlock._if(holdingContainer.getValue().eq(JExpr.lit(1)));
- }
+ IfCondition c = ifExpr.ifCondition;
+
+ HoldingContainer holdingContainer = c.condition.accept(this, generator);
+ if (jc == null) {
+ if (holdingContainer.isOptional()) {
+ jc = conditionalBlock._if(holdingContainer.getIsSet().eq(JExpr.lit(1)).cand(holdingContainer.getValue().eq(JExpr.lit(1))));
} else {
- if (holdingContainer.isOptional()) {
- jc = jc._else()._if(holdingContainer.getIsSet().eq(JExpr.lit(1)).cand(holdingContainer.getValue().eq(JExpr.lit(1))));
- } else {
- jc = jc._else()._if(holdingContainer.getValue().eq(JExpr.lit(1)));
- }
+ jc = conditionalBlock._if(holdingContainer.getValue().eq(JExpr.lit(1)));
}
-
- HoldingContainer thenExpr = c.expression.accept(this, generator);
- if (thenExpr.isOptional()) {
- JConditional newCond = jc._then()._if(thenExpr.getIsSet().ne(JExpr.lit(0)));
- JBlock b = newCond._then();
- b.assign(output.getHolder(), thenExpr.getHolder());
- //b.assign(output.getIsSet(), thenExpr.getIsSet());
+ } else {
+ if (holdingContainer.isOptional()) {
+ jc = jc._else()._if(holdingContainer.getIsSet().eq(JExpr.lit(1)).cand(holdingContainer.getValue().eq(JExpr.lit(1))));
} else {
- jc._then().assign(output.getHolder(), thenExpr.getHolder());
+ jc = jc._else()._if(holdingContainer.getValue().eq(JExpr.lit(1)));
}
+ }
+ HoldingContainer thenExpr = c.expression.accept(this, generator);
+ if (thenExpr.isOptional()) {
+ JConditional newCond = jc._then()._if(thenExpr.getIsSet().ne(JExpr.lit(0)));
+ JBlock b = newCond._then();
+ b.assign(output.getHolder(), thenExpr.getHolder());
+ //b.assign(output.getIsSet(), thenExpr.getIsSet());
+ } else {
+ jc._then().assign(output.getHolder(), thenExpr.getHolder());
}
HoldingContainer elseExpr = ifExpr.elseExpression.accept(this, generator);
diff --git a/exec/java-exec/src/main/java/org/apache/drill/exec/expr/ExpressionTreeMaterializer.java b/exec/java-exec/src/main/java/org/apache/drill/exec/expr/ExpressionTreeMaterializer.java
index 18cd894a8..60bc78c94 100644
--- a/exec/java-exec/src/main/java/org/apache/drill/exec/expr/ExpressionTreeMaterializer.java
+++ b/exec/java-exec/src/main/java/org/apache/drill/exec/expr/ExpressionTreeMaterializer.java
@@ -17,6 +17,9 @@
*/
package org.apache.drill.exec.expr;
+import java.lang.reflect.Array;
+import java.util.ArrayList;
+import java.util.Arrays;
import java.util.List;
import com.google.common.base.Function;
@@ -70,10 +73,12 @@ import org.apache.drill.exec.expr.fn.DrillFuncHolder;
import org.apache.drill.exec.expr.fn.FunctionImplementationRegistry;
import org.apache.drill.exec.record.TypedFieldId;
import org.apache.drill.exec.record.VectorAccessible;
+import org.apache.drill.exec.resolver.DefaultFunctionResolver;
import org.apache.drill.exec.resolver.FunctionResolver;
import org.apache.drill.exec.resolver.FunctionResolverFactory;
import com.google.common.collect.Lists;
+import org.apache.drill.exec.resolver.TypeCastRules;
public class ExpressionTreeMaterializer {
@@ -142,7 +147,38 @@ public class ExpressionTreeMaterializer {
//replace with a new function call, since its argument could be changed.
return new BooleanOperator(op.getName(), args, op.getPosition());
}
-
+
+ private LogicalExpression addCastExpression(LogicalExpression fromExpr, MajorType toType, FunctionImplementationRegistry registry) {
+ String castFuncName = CastFunctions.getCastFunc(toType.getMinorType());
+ List<LogicalExpression> castArgs = Lists.newArrayList();
+ castArgs.add(fromExpr); //input_expr
+
+ if (!Types.isFixedWidthType(toType)) {
+
+ /* We are implicitly casting to VARCHAR so we don't have a max length,
+ * using an arbitrary value. We trim down the size of the stored bytes
+ * to the actual size so this size doesn't really matter.
+ */
+ castArgs.add(new ValueExpressions.LongExpression(65536, null));
+ }
+ else if (toType.getMinorType().name().startsWith("DECIMAL")) {
+ // Add the scale and precision to the arguments of the implicit cast
+ castArgs.add(new ValueExpressions.LongExpression(fromExpr.getMajorType().getPrecision(), null));
+ castArgs.add(new ValueExpressions.LongExpression(fromExpr.getMajorType().getScale(), null));
+ }
+
+ FunctionCall castCall = new FunctionCall(castFuncName, castArgs, ExpressionPosition.UNKNOWN);
+ FunctionResolver resolver = FunctionResolverFactory.getResolver(castCall);
+ DrillFuncHolder matchedCastFuncHolder = registry.findDrillFunction(resolver, castCall);
+
+ if (matchedCastFuncHolder == null) {
+ logFunctionResolutionError(errorCollector, castCall);
+ return NullExpression.INSTANCE;
+ }
+
+ return matchedCastFuncHolder.getExpr(castFuncName, castArgs, ExpressionPosition.UNKNOWN);
+
+ }
@Override
public LogicalExpression visitFunctionCall(FunctionCall call, FunctionImplementationRegistry registry) {
List<LogicalExpression> args = Lists.newArrayList();
@@ -186,34 +222,7 @@ public class ExpressionTreeMaterializer {
argsWithCast.add(currentArg);
} else {
//Case 3: insert cast if param type is different from arg type.
- String castFuncName = CastFunctions.getCastFunc(parmType.getMinorType());
- List<LogicalExpression> castArgs = Lists.newArrayList();
- castArgs.add(call.args.get(i)); //input_expr
-
- if (!Types.isFixedWidthType(parmType)) {
-
- /* We are implicitly casting to VARCHAR so we don't have a max length,
- * using an arbitrary value. We trim down the size of the stored bytes
- * to the actual size so this size doesn't really matter.
- */
- castArgs.add(new ValueExpressions.LongExpression(65536, null));
- }
- else if (parmType.getMinorType().name().startsWith("DECIMAL")) {
- // Add the scale and precision to the arguments of the implicit cast
- castArgs.add(new ValueExpressions.LongExpression(currentArg.getMajorType().getPrecision(), null));
- castArgs.add(new ValueExpressions.LongExpression(currentArg.getMajorType().getScale(), null));
- }
-
- FunctionCall castCall = new FunctionCall(castFuncName, castArgs, ExpressionPosition.UNKNOWN);
- DrillFuncHolder matchedCastFuncHolder = registry.findDrillFunction(resolver, castCall);
-
- if (matchedCastFuncHolder == null) {
- logFunctionResolutionError(errorCollector, castCall);
- return NullExpression.INSTANCE;
- }
-
- argsWithCast.add(matchedCastFuncHolder.getExpr(call.getName(), castArgs, ExpressionPosition.UNKNOWN));
-
+ argsWithCast.add(addCastExpression(call.args.get(i), parmType, registry));
}
}
@@ -255,31 +264,40 @@ public class ExpressionTreeMaterializer {
errorCollector.addGeneralError(call.getPosition(), sb.toString());
}
+
@Override
public LogicalExpression visitIfExpression(IfExpression ifExpr, FunctionImplementationRegistry registry) {
- List<IfExpression.IfCondition> conditions = Lists.newArrayList(ifExpr.conditions);
+ IfExpression.IfCondition conditions = ifExpr.ifCondition;
LogicalExpression newElseExpr = ifExpr.elseExpression.accept(this, registry);
- for (int i = 0; i < conditions.size(); ++i) {
- IfExpression.IfCondition condition = conditions.get(i);
+ LogicalExpression newCondition = conditions.condition.accept(this, registry);
+ LogicalExpression newExpr = conditions.expression.accept(this, registry);
+ conditions = new IfExpression.IfCondition(newCondition, newExpr);
- LogicalExpression newCondition = condition.condition.accept(this, registry);
- LogicalExpression newExpr = condition.expression.accept(this, registry);
- conditions.set(i, new IfExpression.IfCondition(newCondition, newExpr));
+ MinorType thenType = conditions.expression.getMajorType().getMinorType();
+ MinorType elseType = newElseExpr.getMajorType().getMinorType();
+
+ // Check if we need a cast
+ if (thenType != elseType && !(thenType == MinorType.NULL || elseType == MinorType.NULL)) {
+
+ MinorType leastRestrictive = TypeCastRules.getLeastRestrictiveType((Arrays.asList(thenType, elseType)));
+ if (leastRestrictive != thenType) {
+ // Implicitly cast the then expression
+ conditions = new IfExpression.IfCondition(newCondition,
+ addCastExpression(conditions.expression, newElseExpr.getMajorType(), registry));
+ } else if (leastRestrictive != elseType) {
+ // Implicitly cast the else expression
+ newElseExpr = addCastExpression(newElseExpr, conditions.expression.getMajorType(), registry);
+ } else {
+ assert false: "Incorrect least restrictive type computed, leastRestrictive: " +
+ leastRestrictive.toString() + " thenType: " + thenType.toString() + " elseType: " + elseType;
+ }
}
// Resolve NullExpression into TypedNullConstant by visiting all conditions
// We need to do this because we want to give the correct MajorType to the Null constant
- Iterable<LogicalExpression> logicalExprs = Iterables.transform(conditions,
- new Function<IfCondition, LogicalExpression>() {
- @Override
- public LogicalExpression apply(IfExpression.IfCondition input) {
- return input.expression;
- }
- }
- );
-
- List<LogicalExpression> allExpressions = Lists.newArrayList(logicalExprs);
+ List<LogicalExpression> allExpressions = Lists.newArrayList();
+ allExpressions.add(conditions.expression);
allExpressions.add(newElseExpr);
boolean containsNullExpr = Iterables.any(allExpressions, new Predicate<LogicalExpression>() {
@@ -301,12 +319,7 @@ public class ExpressionTreeMaterializer {
if(nonNullExpr.isPresent()) {
MajorType type = nonNullExpr.get().getMajorType();
- for (int i = 0; i < conditions.size(); ++i) {
- IfExpression.IfCondition condition = conditions.get(i);
- conditions.set(i,
- new IfExpression.IfCondition(condition.condition, rewriteNullExpression(condition.expression, type))
- );
- }
+ conditions = new IfExpression.IfCondition(conditions.condition, rewriteNullExpression(conditions.expression, type));
newElseExpr = rewriteNullExpression(newElseExpr, type);
}
@@ -314,16 +327,13 @@ public class ExpressionTreeMaterializer {
// If the type of the IF expression is nullable, apply a convertToNullable*Holder function for "THEN"/"ELSE"
// expressions whose type is not nullable.
- if (IfExpression.newBuilder().setElse(newElseExpr).addConditions(conditions).build().getMajorType().getMode()
+ if (IfExpression.newBuilder().setElse(newElseExpr).setIfCondition(conditions).build().getMajorType().getMode()
== DataMode.OPTIONAL) {
- for (int i = 0; i < conditions.size(); ++i) {
- IfExpression.IfCondition condition = conditions.get(i);
+ IfExpression.IfCondition condition = conditions;
if (condition.expression.getMajorType().getMode() != DataMode.OPTIONAL) {
- conditions.set(i, new IfExpression.IfCondition(condition.condition,
- getConvertToNullableExpr(ImmutableList.of(condition.expression),
- condition.expression.getMajorType().getMinorType(), registry)));
- }
- }
+ conditions = new IfExpression.IfCondition(condition.condition, getConvertToNullableExpr(ImmutableList.of(condition.expression),
+ condition.expression.getMajorType().getMinorType(), registry));
+ }
if (newElseExpr.getMajorType().getMode() != DataMode.OPTIONAL) {
newElseExpr = getConvertToNullableExpr(ImmutableList.of(newElseExpr),
@@ -331,7 +341,7 @@ public class ExpressionTreeMaterializer {
}
}
- return validateNewExpr(IfExpression.newBuilder().setElse(newElseExpr).addConditions(conditions).build());
+ return validateNewExpr(IfExpression.newBuilder().setElse(newElseExpr).setIfCondition(conditions).build());
}
private LogicalExpression getConvertToNullableExpr(List<LogicalExpression> args, MinorType minorType,
diff --git a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/logical/DrillOptiq.java b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/logical/DrillOptiq.java
index 75a5ebc42..633084f03 100644
--- a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/logical/DrillOptiq.java
+++ b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/logical/DrillOptiq.java
@@ -157,7 +157,7 @@ public class DrillOptiq {
for (int i=1; i<caseArgs.size(); i=i+2) {
elseExpression = IfExpression.newBuilder()
.setElse(elseExpression)
- .addCondition(new IfCondition(caseArgs.get(i + 1), caseArgs.get(i))).build();
+ .setIfCondition(new IfCondition(caseArgs.get(i + 1), caseArgs.get(i))).build();
}
return elseExpression;
}
diff --git a/exec/java-exec/src/main/java/org/apache/drill/exec/resolver/TypeCastRules.java b/exec/java-exec/src/main/java/org/apache/drill/exec/resolver/TypeCastRules.java
index bf202c866..817656784 100644
--- a/exec/java-exec/src/main/java/org/apache/drill/exec/resolver/TypeCastRules.java
+++ b/exec/java-exec/src/main/java/org/apache/drill/exec/resolver/TypeCastRules.java
@@ -20,9 +20,11 @@ package org.apache.drill.exec.resolver;
import java.util.HashMap;
import java.util.HashSet;
+import java.util.List;
import java.util.Map;
import java.util.Set;
+import org.apache.drill.common.exceptions.DrillRuntimeException;
import org.apache.drill.common.expression.FunctionCall;
import org.apache.drill.common.types.Types;
import org.apache.drill.common.types.TypeProtos.DataMode;
@@ -768,11 +770,46 @@ public class TypeCastRules {
rules.put(MinorType.VARBINARY, rule);
}
- public static boolean isCastable(MajorType from, MajorType to, NullHandling nullHandling) {
+ public static boolean isCastableWithNullHandling(MajorType from, MajorType to, NullHandling nullHandling) {
if (nullHandling == NullHandling.INTERNAL && from.getMode() != to.getMode()) return false;
+ return isCastable(from.getMinorType(), to.getMinorType());
+ }
+
+ private static boolean isCastable(MinorType from, MinorType to) {
+ return from.equals(MinorType.NULL) || //null could be casted to any other type.
+ (rules.get(to) == null ? false : rules.get(to).contains(from));
+ }
+
+ /*
+ * Function checks if casting is allowed from the 'from' -> 'to' minor type. If its allowed
+ * we also check if the precedence map allows such a cast and return true if both cases are satisfied
+ */
+ public static MinorType getLeastRestrictiveType(List<MinorType> types) {
+ assert types.size() >= 2;
+ MinorType result = types.get(0);
+ int resultPrec = ResolverTypePrecedence.precedenceMap.get(result);
+
+ for (int i = 1; i < types.size(); i++) {
+ MinorType next = types.get(i);
+ if (next == result) {
+ // both args are of the same type; continue
+ continue;
+ }
+
+ int nextPrec = ResolverTypePrecedence.precedenceMap.get(next);
+
+ if (isCastable(next, result) && resultPrec >= nextPrec) {
+ // result is the least restrictive between the two args; nothing to do continue
+ continue;
+ } else if(isCastable(result, next) && nextPrec >= resultPrec) {
+ result = next;
+ resultPrec = nextPrec;
+ } else {
+ throw new DrillRuntimeException("Case expression branches have different output types ");
+ }
+ }
- return from.getMinorType().equals(MinorType.NULL) || //null could be casted to any other type.
- (rules.get(to.getMinorType()) == null ? false : rules.get(to.getMinorType()).contains(from.getMinorType()));
+ return result;
}
private static final int DATAMODE_CAST_COST = 1;
@@ -817,7 +854,7 @@ public class TypeCastRules {
// return -1;
}
- if (!TypeCastRules.isCastable(argType, parmType, holder.getNullHandling())) {
+ if (!TypeCastRules.isCastableWithNullHandling(argType, parmType, holder.getNullHandling())) {
return -1;
}