diff options
Diffstat (limited to 'exec/java-exec/src/test/java/org/apache/drill/exec/compile/TestClassTransformation.java')
-rw-r--r-- | exec/java-exec/src/test/java/org/apache/drill/exec/compile/TestClassTransformation.java | 93 |
1 files changed, 84 insertions, 9 deletions
diff --git a/exec/java-exec/src/test/java/org/apache/drill/exec/compile/TestClassTransformation.java b/exec/java-exec/src/test/java/org/apache/drill/exec/compile/TestClassTransformation.java index 97dfd6b06..bcaee1914 100644 --- a/exec/java-exec/src/test/java/org/apache/drill/exec/compile/TestClassTransformation.java +++ b/exec/java-exec/src/test/java/org/apache/drill/exec/compile/TestClassTransformation.java @@ -18,7 +18,12 @@ package org.apache.drill.exec.compile; import java.io.IOException; +import java.util.Arrays; +import java.util.List; +import org.apache.drill.common.util.DrillFileUtils; +import org.apache.drill.exec.ExecConstants; +import org.apache.drill.exec.compile.bytecode.ValueHolderReplacementVisitor; import org.apache.drill.test.BaseTestQuery; import org.apache.drill.exec.compile.ClassTransformer.ClassSet; import org.apache.drill.exec.compile.sig.GeneratorMapping; @@ -32,6 +37,9 @@ import org.codehaus.commons.compiler.CompileException; import org.junit.Assert; import org.junit.BeforeClass; import org.junit.Test; +import org.objectweb.asm.ClassReader; +import org.objectweb.asm.ClassWriter; +import org.objectweb.asm.tree.ClassNode; public class TestClassTransformation extends BaseTestQuery { private static final org.slf4j.Logger logger = org.slf4j.LoggerFactory.getLogger(TestClassTransformation.class); @@ -41,7 +49,7 @@ public class TestClassTransformation extends BaseTestQuery { private static SessionOptionManager sessionOptions; @BeforeClass - public static void beforeTestClassTransformation() throws Exception { + public static void beforeTestClassTransformation() { // Tests here require the byte-code merge technique and are meaningless // if the plain-old Java technique is selected. Force the plain-Java // technique to be off if it happened to be set on in the default @@ -50,7 +58,7 @@ public class TestClassTransformation extends BaseTestQuery { final UserSession userSession = UserSession.Builder.newBuilder() .withOptionManager(getDrillbitContext().getOptionManager()) .build(); - sessionOptions = (SessionOptionManager) userSession.getOptions(); + sessionOptions = userSession.getOptions(); } @Test @@ -79,11 +87,10 @@ public class TestClassTransformation extends BaseTestQuery { ClassSet classSet = new ClassSet(null, cg.getDefinition().getTemplateClassName(), cg.getMaterializedClassName()); String sourceCode = cg.generateAndGet(); sessionOptions.setLocalOption(ClassCompilerSelector.JAVA_COMPILER_OPTION, ClassCompilerSelector.CompilerPolicy.JDK.name()); - sessionOptions.setLocalOption(ClassCompilerSelector.JAVA_COMPILER_DEBUG_OPTION, false); - @SuppressWarnings("resource") + QueryClassLoader loader = new QueryClassLoader(config, sessionOptions); - final byte[][] codeWithoutDebug = loader.getClassByteCode(classSet.generated, sourceCode); + byte[][] codeWithoutDebug = loader.getClassByteCode(classSet.generated, sourceCode); loader.close(); int sizeWithoutDebug = 0; for (byte[] bs : codeWithoutDebug) { @@ -92,7 +99,7 @@ public class TestClassTransformation extends BaseTestQuery { sessionOptions.setLocalOption(ClassCompilerSelector.JAVA_COMPILER_DEBUG_OPTION, true); loader = new QueryClassLoader(config, sessionOptions); - final byte[][] codeWithDebug = loader.getClassByteCode(classSet.generated, sourceCode); + byte[][] codeWithDebug = loader.getClassByteCode(classSet.generated, sourceCode); loader.close(); int sizeWithDebug = 0; for (byte[] bs : codeWithDebug) { @@ -103,9 +110,77 @@ public class TestClassTransformation extends BaseTestQuery { logger.debug("Optimized code is {}% smaller than debug code.", (int)((sizeWithDebug - sizeWithoutDebug)/(double)sizeWithDebug*100)); } + @Test // DRILL-6524 + public void testScalarReplacementInCondition() throws Exception { + ClassTransformer.ClassNames classNames = new ClassTransformer.ClassNames("org.apache.drill.CompileClassWithIfs"); + String entireClass = DrillFileUtils.getResourceAsString(DrillFileUtils.SEPARATOR + classNames.slash + ".java"); + + sessionOptions.setLocalOption(ClassCompilerSelector.JAVA_COMPILER_DEBUG_OPTION, false); + + List<String> compilers = Arrays.asList(ClassCompilerSelector.CompilerPolicy.JANINO.name(), + ClassCompilerSelector.CompilerPolicy.JDK.name()); + for (String compilerName : compilers) { + sessionOptions.setLocalOption(ClassCompilerSelector.JAVA_COMPILER_OPTION, compilerName); + + QueryClassLoader queryClassLoader = new QueryClassLoader(config, sessionOptions); + + byte[][] implementationClasses = queryClassLoader.getClassByteCode(classNames, entireClass); + + ClassNode originalClass = AsmUtil.classFromBytes(implementationClasses[0], ClassReader.EXPAND_FRAMES); + + ClassNode transformedClass = new ClassNode(); + DrillCheckClassAdapter mergeGenerator = new DrillCheckClassAdapter(CompilationConfig.ASM_API_VERSION, + new CheckClassVisitorFsm(CompilationConfig.ASM_API_VERSION, transformedClass), true); + originalClass.accept(new ValueHolderReplacementVisitor(mergeGenerator, true)); + + if (!AsmUtil.isClassOk(logger, classNames.dot, transformedClass)) { + throw new IllegalStateException(String.format("Problem found after transforming %s", classNames.dot)); + } + ClassWriter writer = new ClassWriter(ClassWriter.COMPUTE_FRAMES); + transformedClass.accept(writer); + byte[] outputClass = writer.toByteArray(); + + queryClassLoader.injectByteCode(classNames.dot, outputClass); + Class<?> transformedClazz = queryClassLoader.findClass(classNames.dot); + transformedClazz.getMethod("doSomething").invoke(null); + } + } + + @Test // DRILL-5683 + public void testCaseWithColumnExprsOnView() throws Exception { + String sqlCreate = + "create table dfs.tmp.t1 as\n" + + "select r_regionkey, r_name, case when mod(r_regionkey, 3) > 0 then mod(r_regionkey, 3) else null end as flag\n" + + "from cp.`tpch/region.parquet`"; + try { + test(sqlCreate); + String sql = "select * from dfs.tmp.t1 where NOT (flag IS NOT NULL)"; + + setSessionOption(ExecConstants.SCALAR_REPLACEMENT_OPTION, ClassTransformer.ScalarReplacementOption.ON.name()); + + List<String> compilers = Arrays.asList(ClassCompilerSelector.CompilerPolicy.JANINO.name(), + ClassCompilerSelector.CompilerPolicy.JDK.name()); + + for (String compilerName : compilers) { + setSessionOption(ClassCompilerSelector.JAVA_COMPILER_OPTION, compilerName); + + testBuilder() + .sqlQuery(sql) + .unOrdered() + .baselineColumns("r_regionkey", "r_name", "flag") + .baselineValues(0, "AFRICA", null) + .baselineValues(3, "EUROPE", null) + .go(); + } + } finally { + test("drop table if exists dfs.tmp.t1"); + resetSessionOption(ExecConstants.SCALAR_REPLACEMENT_OPTION); + resetSessionOption(ClassCompilerSelector.JAVA_COMPILER_OPTION); + } + } + /** * Do a test of a three level class to ensure that nested code generators works correctly. - * @throws Exception */ private void compilationInnerClass(boolean asPoj) throws Exception{ CodeGenerator<ExampleInner> cg = newCodeGenerator(ExampleInner.class, ExampleTemplateWithInner.class); @@ -114,13 +189,13 @@ public class TestClassTransformation extends BaseTestQuery { CodeCompiler.CodeGenCompiler cc = new CodeCompiler.CodeGenCompiler(config, sessionOptions); @SuppressWarnings("unchecked") Class<? extends ExampleInner> c = (Class<? extends ExampleInner>) cc.generateAndCompile(cg); - ExampleInner t = (ExampleInner) c.newInstance(); + ExampleInner t = c.newInstance(); t.doOutside(); t.doInsideOutside(); } private <T, X extends T> CodeGenerator<T> newCodeGenerator(Class<T> iface, Class<X> impl) { - final TemplateClassDefinition<T> template = new TemplateClassDefinition<T>(iface, impl); + final TemplateClassDefinition<T> template = new TemplateClassDefinition<>(iface, impl); CodeGenerator<T> cg = CodeGenerator.get(template, getDrillbitContext().getOptionManager()); cg.plainJavaCapable(true); |