diff options
author | Mehant Baid <mehantr@gmail.com> | 2014-07-25 19:39:30 -0700 |
---|---|---|
committer | Mehant Baid <mehantr@gmail.com> | 2014-07-26 02:15:15 -0700 |
commit | 846e291d5f6966a9fdf2b590c597b79b3205bb68 (patch) | |
tree | 72da38bd2d2cf28bd757181e9a11163bba951430 /exec/java-exec/src/main | |
parent | b6410ff846217ebffda23528bb46619248b1675a (diff) |
DRILL-1180: Add casts to case expression to ensure all branches have same output type
Diffstat (limited to 'exec/java-exec/src/main')
4 files changed, 132 insertions, 86 deletions
diff --git a/exec/java-exec/src/main/java/org/apache/drill/exec/expr/EvaluationVisitor.java b/exec/java-exec/src/main/java/org/apache/drill/exec/expr/EvaluationVisitor.java index 8bfa14959..73ce45e38 100644 --- a/exec/java-exec/src/main/java/org/apache/drill/exec/expr/EvaluationVisitor.java +++ b/exec/java-exec/src/main/java/org/apache/drill/exec/expr/EvaluationVisitor.java @@ -136,32 +136,31 @@ public class EvaluationVisitor { JConditional jc = null; JBlock conditionalBlock = new JBlock(false, false); - for (IfCondition c : ifExpr.conditions) { - HoldingContainer holdingContainer = c.condition.accept(this, generator); - if (jc == null) { - if (holdingContainer.isOptional()) { - jc = conditionalBlock._if(holdingContainer.getIsSet().eq(JExpr.lit(1)).cand(holdingContainer.getValue().eq(JExpr.lit(1)))); - } else { - jc = conditionalBlock._if(holdingContainer.getValue().eq(JExpr.lit(1))); - } + IfCondition c = ifExpr.ifCondition; + + HoldingContainer holdingContainer = c.condition.accept(this, generator); + if (jc == null) { + if (holdingContainer.isOptional()) { + jc = conditionalBlock._if(holdingContainer.getIsSet().eq(JExpr.lit(1)).cand(holdingContainer.getValue().eq(JExpr.lit(1)))); } else { - if (holdingContainer.isOptional()) { - jc = jc._else()._if(holdingContainer.getIsSet().eq(JExpr.lit(1)).cand(holdingContainer.getValue().eq(JExpr.lit(1)))); - } else { - jc = jc._else()._if(holdingContainer.getValue().eq(JExpr.lit(1))); - } + jc = conditionalBlock._if(holdingContainer.getValue().eq(JExpr.lit(1))); } - - HoldingContainer thenExpr = c.expression.accept(this, generator); - if (thenExpr.isOptional()) { - JConditional newCond = jc._then()._if(thenExpr.getIsSet().ne(JExpr.lit(0))); - JBlock b = newCond._then(); - b.assign(output.getHolder(), thenExpr.getHolder()); - //b.assign(output.getIsSet(), thenExpr.getIsSet()); + } else { + if (holdingContainer.isOptional()) { + jc = jc._else()._if(holdingContainer.getIsSet().eq(JExpr.lit(1)).cand(holdingContainer.getValue().eq(JExpr.lit(1)))); } else { - jc._then().assign(output.getHolder(), thenExpr.getHolder()); + jc = jc._else()._if(holdingContainer.getValue().eq(JExpr.lit(1))); } + } + HoldingContainer thenExpr = c.expression.accept(this, generator); + if (thenExpr.isOptional()) { + JConditional newCond = jc._then()._if(thenExpr.getIsSet().ne(JExpr.lit(0))); + JBlock b = newCond._then(); + b.assign(output.getHolder(), thenExpr.getHolder()); + //b.assign(output.getIsSet(), thenExpr.getIsSet()); + } else { + jc._then().assign(output.getHolder(), thenExpr.getHolder()); } HoldingContainer elseExpr = ifExpr.elseExpression.accept(this, generator); diff --git a/exec/java-exec/src/main/java/org/apache/drill/exec/expr/ExpressionTreeMaterializer.java b/exec/java-exec/src/main/java/org/apache/drill/exec/expr/ExpressionTreeMaterializer.java index 18cd894a8..60bc78c94 100644 --- a/exec/java-exec/src/main/java/org/apache/drill/exec/expr/ExpressionTreeMaterializer.java +++ b/exec/java-exec/src/main/java/org/apache/drill/exec/expr/ExpressionTreeMaterializer.java @@ -17,6 +17,9 @@ */ package org.apache.drill.exec.expr; +import java.lang.reflect.Array; +import java.util.ArrayList; +import java.util.Arrays; import java.util.List; import com.google.common.base.Function; @@ -70,10 +73,12 @@ import org.apache.drill.exec.expr.fn.DrillFuncHolder; import org.apache.drill.exec.expr.fn.FunctionImplementationRegistry; import org.apache.drill.exec.record.TypedFieldId; import org.apache.drill.exec.record.VectorAccessible; +import org.apache.drill.exec.resolver.DefaultFunctionResolver; import org.apache.drill.exec.resolver.FunctionResolver; import org.apache.drill.exec.resolver.FunctionResolverFactory; import com.google.common.collect.Lists; +import org.apache.drill.exec.resolver.TypeCastRules; public class ExpressionTreeMaterializer { @@ -142,7 +147,38 @@ public class ExpressionTreeMaterializer { //replace with a new function call, since its argument could be changed. return new BooleanOperator(op.getName(), args, op.getPosition()); } - + + private LogicalExpression addCastExpression(LogicalExpression fromExpr, MajorType toType, FunctionImplementationRegistry registry) { + String castFuncName = CastFunctions.getCastFunc(toType.getMinorType()); + List<LogicalExpression> castArgs = Lists.newArrayList(); + castArgs.add(fromExpr); //input_expr + + if (!Types.isFixedWidthType(toType)) { + + /* We are implicitly casting to VARCHAR so we don't have a max length, + * using an arbitrary value. We trim down the size of the stored bytes + * to the actual size so this size doesn't really matter. + */ + castArgs.add(new ValueExpressions.LongExpression(65536, null)); + } + else if (toType.getMinorType().name().startsWith("DECIMAL")) { + // Add the scale and precision to the arguments of the implicit cast + castArgs.add(new ValueExpressions.LongExpression(fromExpr.getMajorType().getPrecision(), null)); + castArgs.add(new ValueExpressions.LongExpression(fromExpr.getMajorType().getScale(), null)); + } + + FunctionCall castCall = new FunctionCall(castFuncName, castArgs, ExpressionPosition.UNKNOWN); + FunctionResolver resolver = FunctionResolverFactory.getResolver(castCall); + DrillFuncHolder matchedCastFuncHolder = registry.findDrillFunction(resolver, castCall); + + if (matchedCastFuncHolder == null) { + logFunctionResolutionError(errorCollector, castCall); + return NullExpression.INSTANCE; + } + + return matchedCastFuncHolder.getExpr(castFuncName, castArgs, ExpressionPosition.UNKNOWN); + + } @Override public LogicalExpression visitFunctionCall(FunctionCall call, FunctionImplementationRegistry registry) { List<LogicalExpression> args = Lists.newArrayList(); @@ -186,34 +222,7 @@ public class ExpressionTreeMaterializer { argsWithCast.add(currentArg); } else { //Case 3: insert cast if param type is different from arg type. - String castFuncName = CastFunctions.getCastFunc(parmType.getMinorType()); - List<LogicalExpression> castArgs = Lists.newArrayList(); - castArgs.add(call.args.get(i)); //input_expr - - if (!Types.isFixedWidthType(parmType)) { - - /* We are implicitly casting to VARCHAR so we don't have a max length, - * using an arbitrary value. We trim down the size of the stored bytes - * to the actual size so this size doesn't really matter. - */ - castArgs.add(new ValueExpressions.LongExpression(65536, null)); - } - else if (parmType.getMinorType().name().startsWith("DECIMAL")) { - // Add the scale and precision to the arguments of the implicit cast - castArgs.add(new ValueExpressions.LongExpression(currentArg.getMajorType().getPrecision(), null)); - castArgs.add(new ValueExpressions.LongExpression(currentArg.getMajorType().getScale(), null)); - } - - FunctionCall castCall = new FunctionCall(castFuncName, castArgs, ExpressionPosition.UNKNOWN); - DrillFuncHolder matchedCastFuncHolder = registry.findDrillFunction(resolver, castCall); - - if (matchedCastFuncHolder == null) { - logFunctionResolutionError(errorCollector, castCall); - return NullExpression.INSTANCE; - } - - argsWithCast.add(matchedCastFuncHolder.getExpr(call.getName(), castArgs, ExpressionPosition.UNKNOWN)); - + argsWithCast.add(addCastExpression(call.args.get(i), parmType, registry)); } } @@ -255,31 +264,40 @@ public class ExpressionTreeMaterializer { errorCollector.addGeneralError(call.getPosition(), sb.toString()); } + @Override public LogicalExpression visitIfExpression(IfExpression ifExpr, FunctionImplementationRegistry registry) { - List<IfExpression.IfCondition> conditions = Lists.newArrayList(ifExpr.conditions); + IfExpression.IfCondition conditions = ifExpr.ifCondition; LogicalExpression newElseExpr = ifExpr.elseExpression.accept(this, registry); - for (int i = 0; i < conditions.size(); ++i) { - IfExpression.IfCondition condition = conditions.get(i); + LogicalExpression newCondition = conditions.condition.accept(this, registry); + LogicalExpression newExpr = conditions.expression.accept(this, registry); + conditions = new IfExpression.IfCondition(newCondition, newExpr); - LogicalExpression newCondition = condition.condition.accept(this, registry); - LogicalExpression newExpr = condition.expression.accept(this, registry); - conditions.set(i, new IfExpression.IfCondition(newCondition, newExpr)); + MinorType thenType = conditions.expression.getMajorType().getMinorType(); + MinorType elseType = newElseExpr.getMajorType().getMinorType(); + + // Check if we need a cast + if (thenType != elseType && !(thenType == MinorType.NULL || elseType == MinorType.NULL)) { + + MinorType leastRestrictive = TypeCastRules.getLeastRestrictiveType((Arrays.asList(thenType, elseType))); + if (leastRestrictive != thenType) { + // Implicitly cast the then expression + conditions = new IfExpression.IfCondition(newCondition, + addCastExpression(conditions.expression, newElseExpr.getMajorType(), registry)); + } else if (leastRestrictive != elseType) { + // Implicitly cast the else expression + newElseExpr = addCastExpression(newElseExpr, conditions.expression.getMajorType(), registry); + } else { + assert false: "Incorrect least restrictive type computed, leastRestrictive: " + + leastRestrictive.toString() + " thenType: " + thenType.toString() + " elseType: " + elseType; + } } // Resolve NullExpression into TypedNullConstant by visiting all conditions // We need to do this because we want to give the correct MajorType to the Null constant - Iterable<LogicalExpression> logicalExprs = Iterables.transform(conditions, - new Function<IfCondition, LogicalExpression>() { - @Override - public LogicalExpression apply(IfExpression.IfCondition input) { - return input.expression; - } - } - ); - - List<LogicalExpression> allExpressions = Lists.newArrayList(logicalExprs); + List<LogicalExpression> allExpressions = Lists.newArrayList(); + allExpressions.add(conditions.expression); allExpressions.add(newElseExpr); boolean containsNullExpr = Iterables.any(allExpressions, new Predicate<LogicalExpression>() { @@ -301,12 +319,7 @@ public class ExpressionTreeMaterializer { if(nonNullExpr.isPresent()) { MajorType type = nonNullExpr.get().getMajorType(); - for (int i = 0; i < conditions.size(); ++i) { - IfExpression.IfCondition condition = conditions.get(i); - conditions.set(i, - new IfExpression.IfCondition(condition.condition, rewriteNullExpression(condition.expression, type)) - ); - } + conditions = new IfExpression.IfCondition(conditions.condition, rewriteNullExpression(conditions.expression, type)); newElseExpr = rewriteNullExpression(newElseExpr, type); } @@ -314,16 +327,13 @@ public class ExpressionTreeMaterializer { // If the type of the IF expression is nullable, apply a convertToNullable*Holder function for "THEN"/"ELSE" // expressions whose type is not nullable. - if (IfExpression.newBuilder().setElse(newElseExpr).addConditions(conditions).build().getMajorType().getMode() + if (IfExpression.newBuilder().setElse(newElseExpr).setIfCondition(conditions).build().getMajorType().getMode() == DataMode.OPTIONAL) { - for (int i = 0; i < conditions.size(); ++i) { - IfExpression.IfCondition condition = conditions.get(i); + IfExpression.IfCondition condition = conditions; if (condition.expression.getMajorType().getMode() != DataMode.OPTIONAL) { - conditions.set(i, new IfExpression.IfCondition(condition.condition, - getConvertToNullableExpr(ImmutableList.of(condition.expression), - condition.expression.getMajorType().getMinorType(), registry))); - } - } + conditions = new IfExpression.IfCondition(condition.condition, getConvertToNullableExpr(ImmutableList.of(condition.expression), + condition.expression.getMajorType().getMinorType(), registry)); + } if (newElseExpr.getMajorType().getMode() != DataMode.OPTIONAL) { newElseExpr = getConvertToNullableExpr(ImmutableList.of(newElseExpr), @@ -331,7 +341,7 @@ public class ExpressionTreeMaterializer { } } - return validateNewExpr(IfExpression.newBuilder().setElse(newElseExpr).addConditions(conditions).build()); + return validateNewExpr(IfExpression.newBuilder().setElse(newElseExpr).setIfCondition(conditions).build()); } private LogicalExpression getConvertToNullableExpr(List<LogicalExpression> args, MinorType minorType, diff --git a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/logical/DrillOptiq.java b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/logical/DrillOptiq.java index 75a5ebc42..633084f03 100644 --- a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/logical/DrillOptiq.java +++ b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/logical/DrillOptiq.java @@ -157,7 +157,7 @@ public class DrillOptiq { for (int i=1; i<caseArgs.size(); i=i+2) { elseExpression = IfExpression.newBuilder() .setElse(elseExpression) - .addCondition(new IfCondition(caseArgs.get(i + 1), caseArgs.get(i))).build(); + .setIfCondition(new IfCondition(caseArgs.get(i + 1), caseArgs.get(i))).build(); } return elseExpression; } diff --git a/exec/java-exec/src/main/java/org/apache/drill/exec/resolver/TypeCastRules.java b/exec/java-exec/src/main/java/org/apache/drill/exec/resolver/TypeCastRules.java index bf202c866..817656784 100644 --- a/exec/java-exec/src/main/java/org/apache/drill/exec/resolver/TypeCastRules.java +++ b/exec/java-exec/src/main/java/org/apache/drill/exec/resolver/TypeCastRules.java @@ -20,9 +20,11 @@ package org.apache.drill.exec.resolver; import java.util.HashMap; import java.util.HashSet; +import java.util.List; import java.util.Map; import java.util.Set; +import org.apache.drill.common.exceptions.DrillRuntimeException; import org.apache.drill.common.expression.FunctionCall; import org.apache.drill.common.types.Types; import org.apache.drill.common.types.TypeProtos.DataMode; @@ -768,11 +770,46 @@ public class TypeCastRules { rules.put(MinorType.VARBINARY, rule); } - public static boolean isCastable(MajorType from, MajorType to, NullHandling nullHandling) { + public static boolean isCastableWithNullHandling(MajorType from, MajorType to, NullHandling nullHandling) { if (nullHandling == NullHandling.INTERNAL && from.getMode() != to.getMode()) return false; + return isCastable(from.getMinorType(), to.getMinorType()); + } + + private static boolean isCastable(MinorType from, MinorType to) { + return from.equals(MinorType.NULL) || //null could be casted to any other type. + (rules.get(to) == null ? false : rules.get(to).contains(from)); + } + + /* + * Function checks if casting is allowed from the 'from' -> 'to' minor type. If its allowed + * we also check if the precedence map allows such a cast and return true if both cases are satisfied + */ + public static MinorType getLeastRestrictiveType(List<MinorType> types) { + assert types.size() >= 2; + MinorType result = types.get(0); + int resultPrec = ResolverTypePrecedence.precedenceMap.get(result); + + for (int i = 1; i < types.size(); i++) { + MinorType next = types.get(i); + if (next == result) { + // both args are of the same type; continue + continue; + } + + int nextPrec = ResolverTypePrecedence.precedenceMap.get(next); + + if (isCastable(next, result) && resultPrec >= nextPrec) { + // result is the least restrictive between the two args; nothing to do continue + continue; + } else if(isCastable(result, next) && nextPrec >= resultPrec) { + result = next; + resultPrec = nextPrec; + } else { + throw new DrillRuntimeException("Case expression branches have different output types "); + } + } - return from.getMinorType().equals(MinorType.NULL) || //null could be casted to any other type. - (rules.get(to.getMinorType()) == null ? false : rules.get(to.getMinorType()).contains(from.getMinorType())); + return result; } private static final int DATAMODE_CAST_COST = 1; @@ -817,7 +854,7 @@ public class TypeCastRules { // return -1; } - if (!TypeCastRules.isCastable(argType, parmType, holder.getNullHandling())) { + if (!TypeCastRules.isCastableWithNullHandling(argType, parmType, holder.getNullHandling())) { return -1; } |