diff options
author | Volodymyr Vysotskyi <vvovyk@gmail.com> | 2019-03-09 23:54:06 +0200 |
---|---|---|
committer | Sorabh Hamirwasia <sorabh@apache.org> | 2019-03-14 15:36:11 -0700 |
commit | ad0418ff1e7b8fd3dd1a5ba4b0818bc3e6d3f041 (patch) | |
tree | fade31ff3d2e913a588dc357f2fdaf296d8137ed | |
parent | a99db5fe53084a18a83e589ae529ffe5edbcc0a8 (diff) |
DRILL-6524: Prevent incorrect scalar replacement for the case of assigning references inside if block
7 files changed, 391 insertions, 44 deletions
diff --git a/exec/java-exec/src/main/java/org/apache/drill/exec/compile/bytecode/MethodAnalyzer.java b/exec/java-exec/src/main/java/org/apache/drill/exec/compile/bytecode/MethodAnalyzer.java index 78cf16fc1..ae900cf71 100644 --- a/exec/java-exec/src/main/java/org/apache/drill/exec/compile/bytecode/MethodAnalyzer.java +++ b/exec/java-exec/src/main/java/org/apache/drill/exec/compile/bytecode/MethodAnalyzer.java @@ -17,16 +17,26 @@ */ package org.apache.drill.exec.compile.bytecode; +import org.objectweb.asm.tree.AbstractInsnNode; +import org.objectweb.asm.tree.InsnList; +import org.objectweb.asm.tree.LabelNode; +import org.objectweb.asm.tree.MethodNode; import org.objectweb.asm.tree.analysis.Analyzer; +import org.objectweb.asm.tree.analysis.AnalyzerException; import org.objectweb.asm.tree.analysis.Frame; import org.objectweb.asm.tree.analysis.Interpreter; import org.objectweb.asm.tree.analysis.Value; +import java.util.ArrayDeque; +import java.util.Deque; +import java.util.HashSet; +import java.util.Set; + /** * Analyzer that allows us to inject additional functionality into ASMs basic analysis. * * <p>We need to be able to keep track of local variables that are assigned to each other - * so that we can infer their replacability (for scalar replacement). In order to do that, + * so that we can infer their replaceability (for scalar replacement). In order to do that, * we need to know when local variables are assigned (with the old value being overwritten) * so that we can associate them with the new value, and hence determine whether they can * also be replaced, or not. @@ -36,22 +46,71 @@ import org.objectweb.asm.tree.analysis.Value; * as factories that will provide our own derivative of Frame<> which we use to detect */ public class MethodAnalyzer<V extends Value> extends Analyzer <V> { + + // list of method instructions which is analyzed + private InsnList insnList; + + public MethodAnalyzer(Interpreter<V> interpreter) { + super(interpreter); + } + + @Override + protected Frame<V> newFrame(int maxLocals, int maxStack) { + return new AssignmentTrackingFrame<>(maxLocals, maxStack); + } + + @Override + protected Frame<V> newFrame(Frame<? extends V> src) { + return new AssignmentTrackingFrame<>(src); + } + + @Override + protected void newControlFlowEdge(int insnIndex, int successorIndex) { + AssignmentTrackingFrame<V> oldFrame = (AssignmentTrackingFrame<V>) getFrames()[insnIndex]; + AbstractInsnNode insn = insnList.get(insnIndex); + if (insn.getType() == AbstractInsnNode.LABEL) { + // checks whether current label corresponds to the end of conditional block to restore previous + // local variables set + if (insn.equals(oldFrame.labelsStack.peekFirst())) { + oldFrame.localVariablesSet.pop(); + oldFrame.labelsStack.pop(); + } + } + } + + @Override + public Frame<V>[] analyze(String owner, MethodNode method) throws AnalyzerException { + insnList = method.instructions; + return super.analyze(owner, method); + } + /** * Custom Frame<> that captures setLocal() calls in order to associate values - * that are assigned to the same local variable slot. + * that are assigned to the same local variable slot. Also it controls stack to determine whether + * object was assigned to the value declared outside of conditional block. * * <p>Since this is almost a pass-through, the constructors' arguments match * those from Frame<>. */ private static class AssignmentTrackingFrame<V extends Value> extends Frame<V> { + + // represents stack of variable sets declared inside current code block + private final Deque<Set<Integer>> localVariablesSet; + + // stack of LabelNode instances which correspond to the end of conditional block + private final Deque<LabelNode> labelsStack; + /** * Constructor. * * @param nLocals the number of locals the frame should have * @param nStack the maximum size of the stack the frame should have */ - public AssignmentTrackingFrame(final int nLocals, final int nStack) { + public AssignmentTrackingFrame(int nLocals, int nStack) { super(nLocals, nStack); + localVariablesSet = new ArrayDeque<>(); + localVariablesSet.push(new HashSet<>()); + labelsStack = new ArrayDeque<>(); } /** @@ -59,47 +118,70 @@ public class MethodAnalyzer<V extends Value> extends Analyzer <V> { * * @param src the frame being copied */ - public AssignmentTrackingFrame(final Frame<? extends V> src) { + @SuppressWarnings("unchecked") + public AssignmentTrackingFrame(Frame<? extends V> src) { super(src); + AssignmentTrackingFrame trackingFrame = (AssignmentTrackingFrame) src; + localVariablesSet = new ArrayDeque<>(); + for (Set<Integer> integers : (Deque<Set<Integer>>) trackingFrame.localVariablesSet) { + localVariablesSet.addFirst(new HashSet<>(integers)); + } + labelsStack = new ArrayDeque<>(trackingFrame.labelsStack); } @Override - public void setLocal(final int i, final V value) { + public void setLocal(int i, V value) { /* * If we're replacing one ReplacingBasicValue with another, we need to - * associate them together so that they will have the same replacability + * associate them together so that they will have the same replaceability * attributes. We also track the local slot the new value will be stored in. */ if (value instanceof ReplacingBasicValue) { - final ReplacingBasicValue replacingValue = (ReplacingBasicValue) value; + ReplacingBasicValue replacingValue = (ReplacingBasicValue) value; replacingValue.setFrameSlot(i); - final V localValue = getLocal(i); - if ((localValue != null) && (localValue instanceof ReplacingBasicValue)) { - final ReplacingBasicValue localReplacingValue = (ReplacingBasicValue) localValue; + V localValue = getLocal(i); + Set<Integer> currentLocalVars = localVariablesSet.element(); + if (localValue instanceof ReplacingBasicValue) { + if (!currentLocalVars.contains(i)) { + // value is assigned to object declared outside of conditional block + replacingValue.setAssignedInConditionalBlock(); + } + ReplacingBasicValue localReplacingValue = (ReplacingBasicValue) localValue; localReplacingValue.associate(replacingValue); + } else { + currentLocalVars.add(i); } } super.setLocal(i, value); } - } - /** - * Constructor. - * - * @param interpreter the interpreter to use - */ - public MethodAnalyzer(final Interpreter<V> interpreter) { - super(interpreter); - } - - @Override - protected Frame<V> newFrame(final int maxLocals, final int maxStack) { - return new AssignmentTrackingFrame<V>(maxLocals, maxStack); - } - - @Override - protected Frame<V> newFrame(final Frame<? extends V> src) { - return new AssignmentTrackingFrame<V>(src); + @Override + public void initJumpTarget(int opcode, LabelNode target) { + if (target != null) { + switch (opcode) { + case IFEQ: + case IFNE: + case IFLT: + case IFGE: + case IFGT: + case IFLE: + case IF_ICMPEQ: + case IF_ICMPNE: + case IF_ICMPLT: + case IF_ICMPGE: + case IF_ICMPGT: + case IF_ICMPLE: + case IF_ACMPEQ: + case IF_ACMPNE: + case IFNONNULL: + // for the case when conditional block is handled, creates new variables set + // to store local variables declared inside current conditional block and + // stores its target LabelNode to restore previous variables set after conditional block is ended + localVariablesSet.push(new HashSet<>()); + labelsStack.push(target); + } + } + } } } diff --git a/exec/java-exec/src/main/java/org/apache/drill/exec/compile/bytecode/ReplacingBasicValue.java b/exec/java-exec/src/main/java/org/apache/drill/exec/compile/bytecode/ReplacingBasicValue.java index c05f5cbbe..42c8685ef 100644 --- a/exec/java-exec/src/main/java/org/apache/drill/exec/compile/bytecode/ReplacingBasicValue.java +++ b/exec/java-exec/src/main/java/org/apache/drill/exec/compile/bytecode/ReplacingBasicValue.java @@ -49,6 +49,7 @@ public class ReplacingBasicValue extends BasicValue { boolean isFunctionReturn = false; boolean isFunctionArgument = false; boolean isAssignedToMember = false; + boolean isAssignedInConditionalBlock = false; boolean isThis = false; /** @@ -67,6 +68,9 @@ public class ReplacingBasicValue extends BasicValue { if (other.isAssignedToMember) { isAssignedToMember = true; } + if (other.isAssignedInConditionalBlock) { + isAssignedInConditionalBlock = true; + } if (other.isThis) { isThis = true; } @@ -78,7 +82,7 @@ public class ReplacingBasicValue extends BasicValue { * @return whether or not the value is replaceable */ public boolean isReplaceable() { - return !(isFunctionReturn || isFunctionArgument || isAssignedToMember || isThis); + return !(isFunctionReturn || isFunctionArgument || isAssignedToMember || isAssignedInConditionalBlock || isThis); } /** @@ -115,6 +119,15 @@ public class ReplacingBasicValue extends BasicValue { needSpace = true; } + if (isAssignedInConditionalBlock) { + if (needSpace) { + sb.append(' '); + } + + sb.append("conditional"); + needSpace = true; + } + if (isThis) { if (needSpace) { sb.append(' '); @@ -397,6 +410,13 @@ public class ReplacingBasicValue extends BasicValue { } /** + * Mark this value as being assigned to a variable inside of conditional block. + */ + public void setAssignedInConditionalBlock() { + flagSet.isAssignedInConditionalBlock = true; + } + + /** * Indicates whether or not this value is assigned to a class member variable. * * @return whether or not this value is assigned to a class member variable @@ -406,6 +426,15 @@ public class ReplacingBasicValue extends BasicValue { } /** + * Indicates whether or not this value is assigned to a variable inside of conditional block. + * + * @return whether or not this value is assigned to a variable inside of conditional block + */ + public boolean isAssignedInConditionalBlock() { + return flagSet.isAssignedInConditionalBlock; + } + + /** * Return the ValueHolder identity for this value. * * @return the ValueHolderIden for this value diff --git a/exec/java-exec/src/main/java/org/apache/drill/exec/compile/bytecode/ScalarReplacementNode.java b/exec/java-exec/src/main/java/org/apache/drill/exec/compile/bytecode/ScalarReplacementNode.java index 3963e9b52..bd836e02e 100644 --- a/exec/java-exec/src/main/java/org/apache/drill/exec/compile/bytecode/ScalarReplacementNode.java +++ b/exec/java-exec/src/main/java/org/apache/drill/exec/compile/bytecode/ScalarReplacementNode.java @@ -56,7 +56,7 @@ public class ScalarReplacementNode extends MethodNode { final LinkedList<ReplacingBasicValue> valueList = new LinkedList<>(); final MethodAnalyzer<BasicValue> analyzer = - new MethodAnalyzer<BasicValue>(new ReplacingInterpreter(className, valueList)); + new MethodAnalyzer<>(new ReplacingInterpreter(className, valueList)); Frame<BasicValue>[] frames; try { frames = analyzer.analyze(className, this); diff --git a/exec/java-exec/src/main/java/org/apache/drill/exec/compile/bytecode/ValueHolderReplacementVisitor.java b/exec/java-exec/src/main/java/org/apache/drill/exec/compile/bytecode/ValueHolderReplacementVisitor.java index 9094b33a1..1eb1e7948 100644 --- a/exec/java-exec/src/main/java/org/apache/drill/exec/compile/bytecode/ValueHolderReplacementVisitor.java +++ b/exec/java-exec/src/main/java/org/apache/drill/exec/compile/bytecode/ValueHolderReplacementVisitor.java @@ -50,20 +50,20 @@ public class ValueHolderReplacementVisitor extends ClassVisitor { @Override public MethodVisitor visitMethod(int access, String name, String desc, String signature, String[] exceptions) { MethodVisitor innerVisitor = super.visitMethod(access, name, desc, signature, exceptions); -// innerVisitor = new Debugger(access, name, desc, signature, exceptions, innerVisitor); +// innerVisitor = new Debugger(access, name, desc, signature, exceptions, innerVisitor); if (verifyBytecode) { innerVisitor = new CheckMethodVisitorFsm(api, innerVisitor); } return new ScalarReplacementNode(className, access, name, desc, signature, - exceptions,innerVisitor, verifyBytecode); + exceptions, innerVisitor, verifyBytecode); } private static class Debugger extends MethodNode { MethodVisitor inner; public Debugger(int access, String name, String desc, String signature, String[] exceptions, MethodVisitor inner) { - super(access, name, desc, signature, exceptions); + super(CompilationConfig.ASM_API_VERSION, access, name, desc, signature, exceptions); this.inner = inner; } diff --git a/exec/java-exec/src/test/java/org/apache/drill/TestProjectWithFunctions.java b/exec/java-exec/src/test/java/org/apache/drill/TestProjectWithFunctions.java index 58cbad8cc..4f91a9103 100644 --- a/exec/java-exec/src/test/java/org/apache/drill/TestProjectWithFunctions.java +++ b/exec/java-exec/src/test/java/org/apache/drill/TestProjectWithFunctions.java @@ -18,6 +18,9 @@ package org.apache.drill; import org.apache.drill.categories.PlannerTest; +import org.apache.drill.exec.ExecConstants; +import org.apache.drill.exec.compile.ClassCompilerSelector; +import org.apache.drill.exec.compile.ClassTransformer; import org.apache.drill.test.ClusterFixture; import org.apache.drill.test.ClusterFixtureBuilder; import org.apache.drill.test.ClusterTest; @@ -25,8 +28,9 @@ import org.junit.Before; import org.junit.BeforeClass; import org.junit.Test; import org.junit.experimental.categories.Category; - import java.nio.file.Paths; +import java.util.Arrays; +import java.util.List; /** * Test the optimizer plan in terms of projecting different functions e.g. cast @@ -49,6 +53,121 @@ public class TestProjectWithFunctions extends ClusterTest { public void testCastFunctions() throws Exception { String sql = "select t1.f from dfs.`view/emp_6212.view.drill` as t inner join dfs.`view/emp_6212.view.drill` as t1 " + "on cast(t.f as int) = cast(t1.f as int) and cast(t.f as int) = 10 and cast(t1.f as int) = 10"; - client.queryBuilder().sql(sql).run(); + queryBuilder().sql(sql).run(); + } + + @Test // DRILL-6524 + public void testCaseWithColumnsInClause() throws Exception { + String sql = + "select\n" + + "case when a = 3 then a else b end as c,\n" + + "case when a = 1 then a else b end as d\n" + + "from (values(1, 2)) t(a, b)"; + + try { + client.alterSession(ExecConstants.SCALAR_REPLACEMENT_OPTION, ClassTransformer.ScalarReplacementOption.ON.name()); + + List<String> compilers = Arrays.asList(ClassCompilerSelector.CompilerPolicy.JANINO.name(), + ClassCompilerSelector.CompilerPolicy.JDK.name()); + + for (String compilerName : compilers) { + client.alterSession(ClassCompilerSelector.JAVA_COMPILER_OPTION, compilerName); + + testBuilder() + .sqlQuery(sql) + .unOrdered() + .baselineColumns("c", "d") + .baselineValues(2L, 1L) + .go(); + } + } finally { + client.resetSession(ExecConstants.SCALAR_REPLACEMENT_OPTION); + client.resetSession(ClassCompilerSelector.JAVA_COMPILER_OPTION); + } + } + + @Test // DRILL-6722 + public void testCaseWithColumnExprsInClause() throws Exception { + String sqlCreate = + "create table dfs.tmp.test as \n" + + "select 1 as a, 2 as b\n" + + "union all\n" + + "select 3 as a, 2 as b\n" + + "union all\n" + + "select 1 as a, 4 as b\n" + + "union all\n" + + "select 2 as a, 2 as b"; + try { + run(sqlCreate); + String sql = + "select\n" + + "case when s.a > s.b then s.a else s.b end as b, \n" + + "abs(s.a - s.b) as d\n" + + "from dfs.tmp.test s"; + + client.alterSession(ExecConstants.SCALAR_REPLACEMENT_OPTION, ClassTransformer.ScalarReplacementOption.ON.name()); + + List<String> compilers = Arrays.asList(ClassCompilerSelector.CompilerPolicy.JANINO.name(), + ClassCompilerSelector.CompilerPolicy.JDK.name()); + + for (String compilerName : compilers) { + client.alterSession(ClassCompilerSelector.JAVA_COMPILER_OPTION, compilerName); + + testBuilder() + .sqlQuery(sql) + .unOrdered() + .baselineColumns("b", "d") + .baselineValues(2, 1) + .baselineValues(3, 1) + .baselineValues(4, 3) + .baselineValues(2, 0) + .go(); + } + } finally { + run("drop table if exists dfs.tmp.test"); + client.resetSession(ExecConstants.SCALAR_REPLACEMENT_OPTION); + client.resetSession(ClassCompilerSelector.JAVA_COMPILER_OPTION); + } + } + + @Test // DRILL-5581 + public void testCaseWithColumnExprsOnView() throws Exception { + String sqlCreate = + "CREATE VIEW dfs.tmp.`vw_order_sample_csv` as\n" + + "SELECT\n" + + "a AS `ND`,\n" + + "CAST(b AS BIGINT) AS `col1`,\n" + + "CAST(c AS BIGINT) AS `col2`\n" + + "FROM (values('202634342',20000101,20160301)) as t(a, b, c)"; + try { + run(sqlCreate); + String sql = + "select\n" + + "case when col1 > col2 then col1 else col2 end as temp_col,\n" + + "case when col1 = 20000101 and (20170302 - col2) > 10000 then 'D'\n" + + "when col2 = 20000101 then 'P' when col1 - col2 > 10000 then '0'\n" + + "else 'A' end as status\n" + + "from dfs.tmp.`vw_order_sample_csv`"; + + client.alterSession(ExecConstants.SCALAR_REPLACEMENT_OPTION, ClassTransformer.ScalarReplacementOption.ON.name()); + + List<String> compilers = Arrays.asList(ClassCompilerSelector.CompilerPolicy.JANINO.name(), + ClassCompilerSelector.CompilerPolicy.JDK.name()); + + for (String compilerName : compilers) { + client.alterSession(ClassCompilerSelector.JAVA_COMPILER_OPTION, compilerName); + + testBuilder() + .sqlQuery(sql) + .unOrdered() + .baselineColumns("temp_col", "status") + .baselineValues(20160301L, "D") + .go(); + } + } finally { + run("drop view if exists dfs.tmp.`vw_order_sample_csv`"); + client.resetSession(ExecConstants.SCALAR_REPLACEMENT_OPTION); + client.resetSession(ClassCompilerSelector.JAVA_COMPILER_OPTION); + } } } diff --git a/exec/java-exec/src/test/java/org/apache/drill/exec/compile/TestClassTransformation.java b/exec/java-exec/src/test/java/org/apache/drill/exec/compile/TestClassTransformation.java index 97dfd6b06..bcaee1914 100644 --- a/exec/java-exec/src/test/java/org/apache/drill/exec/compile/TestClassTransformation.java +++ b/exec/java-exec/src/test/java/org/apache/drill/exec/compile/TestClassTransformation.java @@ -18,7 +18,12 @@ package org.apache.drill.exec.compile; import java.io.IOException; +import java.util.Arrays; +import java.util.List; +import org.apache.drill.common.util.DrillFileUtils; +import org.apache.drill.exec.ExecConstants; +import org.apache.drill.exec.compile.bytecode.ValueHolderReplacementVisitor; import org.apache.drill.test.BaseTestQuery; import org.apache.drill.exec.compile.ClassTransformer.ClassSet; import org.apache.drill.exec.compile.sig.GeneratorMapping; @@ -32,6 +37,9 @@ import org.codehaus.commons.compiler.CompileException; import org.junit.Assert; import org.junit.BeforeClass; import org.junit.Test; +import org.objectweb.asm.ClassReader; +import org.objectweb.asm.ClassWriter; +import org.objectweb.asm.tree.ClassNode; public class TestClassTransformation extends BaseTestQuery { private static final org.slf4j.Logger logger = org.slf4j.LoggerFactory.getLogger(TestClassTransformation.class); @@ -41,7 +49,7 @@ public class TestClassTransformation extends BaseTestQuery { private static SessionOptionManager sessionOptions; @BeforeClass - public static void beforeTestClassTransformation() throws Exception { + public static void beforeTestClassTransformation() { // Tests here require the byte-code merge technique and are meaningless // if the plain-old Java technique is selected. Force the plain-Java // technique to be off if it happened to be set on in the default @@ -50,7 +58,7 @@ public class TestClassTransformation extends BaseTestQuery { final UserSession userSession = UserSession.Builder.newBuilder() .withOptionManager(getDrillbitContext().getOptionManager()) .build(); - sessionOptions = (SessionOptionManager) userSession.getOptions(); + sessionOptions = userSession.getOptions(); } @Test @@ -79,11 +87,10 @@ public class TestClassTransformation extends BaseTestQuery { ClassSet classSet = new ClassSet(null, cg.getDefinition().getTemplateClassName(), cg.getMaterializedClassName()); String sourceCode = cg.generateAndGet(); sessionOptions.setLocalOption(ClassCompilerSelector.JAVA_COMPILER_OPTION, ClassCompilerSelector.CompilerPolicy.JDK.name()); - sessionOptions.setLocalOption(ClassCompilerSelector.JAVA_COMPILER_DEBUG_OPTION, false); - @SuppressWarnings("resource") + QueryClassLoader loader = new QueryClassLoader(config, sessionOptions); - final byte[][] codeWithoutDebug = loader.getClassByteCode(classSet.generated, sourceCode); + byte[][] codeWithoutDebug = loader.getClassByteCode(classSet.generated, sourceCode); loader.close(); int sizeWithoutDebug = 0; for (byte[] bs : codeWithoutDebug) { @@ -92,7 +99,7 @@ public class TestClassTransformation extends BaseTestQuery { sessionOptions.setLocalOption(ClassCompilerSelector.JAVA_COMPILER_DEBUG_OPTION, true); loader = new QueryClassLoader(config, sessionOptions); - final byte[][] codeWithDebug = loader.getClassByteCode(classSet.generated, sourceCode); + byte[][] codeWithDebug = loader.getClassByteCode(classSet.generated, sourceCode); loader.close(); int sizeWithDebug = 0; for (byte[] bs : codeWithDebug) { @@ -103,9 +110,77 @@ public class TestClassTransformation extends BaseTestQuery { logger.debug("Optimized code is {}% smaller than debug code.", (int)((sizeWithDebug - sizeWithoutDebug)/(double)sizeWithDebug*100)); } + @Test // DRILL-6524 + public void testScalarReplacementInCondition() throws Exception { + ClassTransformer.ClassNames classNames = new ClassTransformer.ClassNames("org.apache.drill.CompileClassWithIfs"); + String entireClass = DrillFileUtils.getResourceAsString(DrillFileUtils.SEPARATOR + classNames.slash + ".java"); + + sessionOptions.setLocalOption(ClassCompilerSelector.JAVA_COMPILER_DEBUG_OPTION, false); + + List<String> compilers = Arrays.asList(ClassCompilerSelector.CompilerPolicy.JANINO.name(), + ClassCompilerSelector.CompilerPolicy.JDK.name()); + for (String compilerName : compilers) { + sessionOptions.setLocalOption(ClassCompilerSelector.JAVA_COMPILER_OPTION, compilerName); + + QueryClassLoader queryClassLoader = new QueryClassLoader(config, sessionOptions); + + byte[][] implementationClasses = queryClassLoader.getClassByteCode(classNames, entireClass); + + ClassNode originalClass = AsmUtil.classFromBytes(implementationClasses[0], ClassReader.EXPAND_FRAMES); + + ClassNode transformedClass = new ClassNode(); + DrillCheckClassAdapter mergeGenerator = new DrillCheckClassAdapter(CompilationConfig.ASM_API_VERSION, + new CheckClassVisitorFsm(CompilationConfig.ASM_API_VERSION, transformedClass), true); + originalClass.accept(new ValueHolderReplacementVisitor(mergeGenerator, true)); + + if (!AsmUtil.isClassOk(logger, classNames.dot, transformedClass)) { + throw new IllegalStateException(String.format("Problem found after transforming %s", classNames.dot)); + } + ClassWriter writer = new ClassWriter(ClassWriter.COMPUTE_FRAMES); + transformedClass.accept(writer); + byte[] outputClass = writer.toByteArray(); + + queryClassLoader.injectByteCode(classNames.dot, outputClass); + Class<?> transformedClazz = queryClassLoader.findClass(classNames.dot); + transformedClazz.getMethod("doSomething").invoke(null); + } + } + + @Test // DRILL-5683 + public void testCaseWithColumnExprsOnView() throws Exception { + String sqlCreate = + "create table dfs.tmp.t1 as\n" + + "select r_regionkey, r_name, case when mod(r_regionkey, 3) > 0 then mod(r_regionkey, 3) else null end as flag\n" + + "from cp.`tpch/region.parquet`"; + try { + test(sqlCreate); + String sql = "select * from dfs.tmp.t1 where NOT (flag IS NOT NULL)"; + + setSessionOption(ExecConstants.SCALAR_REPLACEMENT_OPTION, ClassTransformer.ScalarReplacementOption.ON.name()); + + List<String> compilers = Arrays.asList(ClassCompilerSelector.CompilerPolicy.JANINO.name(), + ClassCompilerSelector.CompilerPolicy.JDK.name()); + + for (String compilerName : compilers) { + setSessionOption(ClassCompilerSelector.JAVA_COMPILER_OPTION, compilerName); + + testBuilder() + .sqlQuery(sql) + .unOrdered() + .baselineColumns("r_regionkey", "r_name", "flag") + .baselineValues(0, "AFRICA", null) + .baselineValues(3, "EUROPE", null) + .go(); + } + } finally { + test("drop table if exists dfs.tmp.t1"); + resetSessionOption(ExecConstants.SCALAR_REPLACEMENT_OPTION); + resetSessionOption(ClassCompilerSelector.JAVA_COMPILER_OPTION); + } + } + /** * Do a test of a three level class to ensure that nested code generators works correctly. - * @throws Exception */ private void compilationInnerClass(boolean asPoj) throws Exception{ CodeGenerator<ExampleInner> cg = newCodeGenerator(ExampleInner.class, ExampleTemplateWithInner.class); @@ -114,13 +189,13 @@ public class TestClassTransformation extends BaseTestQuery { CodeCompiler.CodeGenCompiler cc = new CodeCompiler.CodeGenCompiler(config, sessionOptions); @SuppressWarnings("unchecked") Class<? extends ExampleInner> c = (Class<? extends ExampleInner>) cc.generateAndCompile(cg); - ExampleInner t = (ExampleInner) c.newInstance(); + ExampleInner t = c.newInstance(); t.doOutside(); t.doInsideOutside(); } private <T, X extends T> CodeGenerator<T> newCodeGenerator(Class<T> iface, Class<X> impl) { - final TemplateClassDefinition<T> template = new TemplateClassDefinition<T>(iface, impl); + final TemplateClassDefinition<T> template = new TemplateClassDefinition<>(iface, impl); CodeGenerator<T> cg = CodeGenerator.get(template, getDrillbitContext().getOptionManager()); cg.plainJavaCapable(true); diff --git a/exec/java-exec/src/test/resources/org/apache/drill/CompileClassWithIfs.java b/exec/java-exec/src/test/resources/org/apache/drill/CompileClassWithIfs.java new file mode 100644 index 000000000..ed206b338 --- /dev/null +++ b/exec/java-exec/src/test/resources/org/apache/drill/CompileClassWithIfs.java @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.drill; + +import org.apache.drill.exec.expr.holders.NullableBigIntHolder; + +public class CompileClassWithIfs { + + public static void doSomething() { + int a = 2; + NullableBigIntHolder out0 = new NullableBigIntHolder(); + out0.isSet = 1; + NullableBigIntHolder out4 = new NullableBigIntHolder(); + out4.isSet = 0; + if (a == 0) { + out0 = out4; + } else { + } + + if (out4.isSet == 0) { + out0.isSet = 1; + } else { + out0.isSet = 0; + assert false : "Incorrect class transformation. This code should never be executed."; + } + } +} |