diff options
Diffstat (limited to 'exec/java-exec/src/test/java/org/apache/drill/exec/planner/rm/TestMemoryCalculator.java')
-rw-r--r-- | exec/java-exec/src/test/java/org/apache/drill/exec/planner/rm/TestMemoryCalculator.java | 227 |
1 files changed, 227 insertions, 0 deletions
diff --git a/exec/java-exec/src/test/java/org/apache/drill/exec/planner/rm/TestMemoryCalculator.java b/exec/java-exec/src/test/java/org/apache/drill/exec/planner/rm/TestMemoryCalculator.java new file mode 100644 index 000000000..4893a36fd --- /dev/null +++ b/exec/java-exec/src/test/java/org/apache/drill/exec/planner/rm/TestMemoryCalculator.java @@ -0,0 +1,227 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.drill.exec.planner.rm; + + +import org.apache.drill.PlanTestBase; +import org.apache.drill.exec.ExecConstants; +import org.apache.drill.exec.ops.QueryContext; +import org.apache.drill.exec.planner.PhysicalPlanReader; +import org.apache.drill.exec.planner.cost.NodeResource; +import org.apache.drill.exec.planner.fragment.Fragment; +import org.apache.drill.exec.planner.fragment.PlanningSet; +import org.apache.drill.exec.planner.fragment.QueueQueryParallelizer; +import org.apache.drill.exec.planner.fragment.SimpleParallelizer; +import org.apache.drill.exec.planner.fragment.Wrapper; +import org.apache.drill.exec.pop.PopUnitTestBase; +import org.apache.drill.exec.proto.CoordinationProtos.DrillbitEndpoint; +import org.apache.drill.exec.proto.UserBitShared; +import org.apache.drill.exec.proto.UserProtos; +import org.apache.drill.exec.rpc.user.UserSession; +import org.apache.drill.exec.server.DrillbitContext; +import org.apache.drill.shaded.guava.com.google.common.collect.Iterables; +import org.apache.drill.test.ClientFixture; +import org.apache.drill.test.ClusterFixture; +import org.apache.drill.test.ClusterFixtureBuilder; +import org.junit.AfterClass; +import org.junit.Test; + +import java.util.ArrayList; +import java.util.HashSet; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.HashMap; +import java.util.Set; +import java.util.stream.Collectors; + +import static org.junit.Assert.assertTrue; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +public class TestMemoryCalculator extends PlanTestBase { + + private static final long DEFAULT_SLICE_TARGET = 100000L; + private static final long DEFAULT_BATCH_SIZE = 16*1024*1024; + + private static final UserSession session = UserSession.Builder.newBuilder() + .withCredentials(UserBitShared.UserCredentials.newBuilder() + .setUserName("foo") + .build()) + .withUserProperties(UserProtos.UserProperties.getDefaultInstance()) + .withOptionManager(bits[0].getContext().getOptionManager()) + .build(); + + private static final DrillbitEndpoint N1_EP1 = newDrillbitEndpoint("node1", 30010); + private static final DrillbitEndpoint N1_EP2 = newDrillbitEndpoint("node2", 30011); + private static final DrillbitEndpoint N1_EP3 = newDrillbitEndpoint("node3", 30012); + private static final DrillbitEndpoint N1_EP4 = newDrillbitEndpoint("node4", 30013); + + private static final DrillbitEndpoint[] nodeList = {N1_EP1, N1_EP2, N1_EP3, N1_EP4}; + + private static final DrillbitEndpoint newDrillbitEndpoint(String address, int port) { + return DrillbitEndpoint.newBuilder().setAddress(address).setControlPort(port).build(); + } + private static final DrillbitContext drillbitContext = getDrillbitContext(); + private static final QueryContext queryContext = new QueryContext(session, drillbitContext, + UserBitShared.QueryId.getDefaultInstance()); + + @AfterClass + public static void close() throws Exception { + queryContext.close(); + } + + private final Wrapper mockWrapper(Wrapper rootFragment, + Map<DrillbitEndpoint, NodeResource> resourceMap, + List<DrillbitEndpoint> endpoints, + Map<Fragment, Wrapper> originalToMockWrapper ) { + final Wrapper mockWrapper = mock(Wrapper.class); + originalToMockWrapper.put(rootFragment.getNode(), mockWrapper); + List<Wrapper> mockdependencies = new ArrayList<>(); + + for (Wrapper dependency : rootFragment.getFragmentDependencies()) { + mockdependencies.add(mockWrapper(dependency, resourceMap, endpoints, originalToMockWrapper)); + } + + when(mockWrapper.getNode()).thenReturn(rootFragment.getNode()); + when(mockWrapper.getAssignedEndpoints()).thenReturn(endpoints); + when(mockWrapper.getResourceMap()).thenReturn(resourceMap); + when(mockWrapper.getWidth()).thenReturn(endpoints.size()); + when(mockWrapper.getFragmentDependencies()).thenReturn(mockdependencies); + when(mockWrapper.isEndpointsAssignmentDone()).thenReturn(true); + return mockWrapper; + } + + private final PlanningSet mockPlanningSet(PlanningSet planningSet, + Map<DrillbitEndpoint, NodeResource> resourceMap, + List<DrillbitEndpoint> endpoints) { + Map<Fragment, Wrapper> wrapperToMockWrapper = new HashMap<>(); + Wrapper rootFragment = mockWrapper( planningSet.getRootWrapper(), resourceMap, + endpoints, wrapperToMockWrapper); + PlanningSet mockPlanningSet = mock(PlanningSet.class); + when(mockPlanningSet.getRootWrapper()).thenReturn(rootFragment); + when(mockPlanningSet.get(any(Fragment.class))).thenAnswer(invocation -> { + return wrapperToMockWrapper.get(invocation.getArgument(0)); + }); + return mockPlanningSet; + } + + private String getPlanForQuery(String query) throws Exception { + return getPlanForQuery(query, DEFAULT_BATCH_SIZE); + } + + private String getPlanForQuery(String query, long outputBatchSize) throws Exception { + return getPlanForQuery(query, outputBatchSize, DEFAULT_SLICE_TARGET); + } + + private String getPlanForQuery(String query, long outputBatchSize, + long slice_target) throws Exception { + ClusterFixtureBuilder builder = ClusterFixture.builder(dirTestWatcher) + .setOptionDefault(ExecConstants.OUTPUT_BATCH_SIZE, outputBatchSize) + .setOptionDefault(ExecConstants.SLICE_TARGET, slice_target); + String plan; + + try (ClusterFixture cluster = builder.build(); + ClientFixture client = cluster.clientFixture()) { + plan = client.queryBuilder() + .sql(query) + .explainJson(); + } + return plan; + } + + private List<DrillbitEndpoint> getEndpoints(int totalMinorFragments, + Set<DrillbitEndpoint> notIn) { + List<DrillbitEndpoint> endpoints = new ArrayList<>(); + Iterator drillbits = Iterables.cycle(nodeList).iterator(); + + while(totalMinorFragments-- > 0) { + DrillbitEndpoint dbit = (DrillbitEndpoint) drillbits.next(); + if (!notIn.contains(dbit)) { + endpoints.add(dbit); + } + } + return endpoints; + } + + private Set<Wrapper> createSet(Wrapper... wrappers) { + Set<Wrapper> setOfWrappers = new HashSet<>(); + for (Wrapper wrapper : wrappers) { + setOfWrappers.add(wrapper); + } + return setOfWrappers; + } + + private Fragment getRootFragmentFromPlan(DrillbitContext context, + String plan) throws Exception { + final PhysicalPlanReader planReader = context.getPlanReader(); + return PopUnitTestBase.getRootFragmentFromPlanString(planReader, plan); + } + + private PlanningSet preparePlanningSet(List<DrillbitEndpoint> activeEndpoints, long slice_target, + Map<DrillbitEndpoint, NodeResource> resources, String sql, + SimpleParallelizer parallelizer) throws Exception { + Fragment rootFragment = getRootFragmentFromPlan(drillbitContext, getPlanForQuery(sql, 10, slice_target)); + return mockPlanningSet(parallelizer.prepareFragmentTree(rootFragment), resources, activeEndpoints); + } + + @Test + public void TestSingleMajorFragmentWithProjectAndScan() throws Exception { + List<DrillbitEndpoint> activeEndpoints = getEndpoints(2, new HashSet<>()); + Map<DrillbitEndpoint, NodeResource> resources = activeEndpoints.stream() + .collect(Collectors.toMap(x -> x, + x -> NodeResource.create())); + String sql = "SELECT * from cp.`tpch/nation.parquet`"; + + SimpleParallelizer parallelizer = new QueueQueryParallelizer(false, queryContext); + PlanningSet planningSet = preparePlanningSet(activeEndpoints, DEFAULT_SLICE_TARGET, resources, sql, parallelizer); + parallelizer.adjustMemory(planningSet, createSet(planningSet.getRootWrapper()), activeEndpoints); + assertTrue("memory requirement is different", Iterables.all(resources.entrySet(), (e) -> e.getValue().getMemory() == 30)); + } + + + @Test + public void TestSingleMajorFragmentWithGroupByProjectAndScan() throws Exception { + List<DrillbitEndpoint> activeEndpoints = getEndpoints(2, new HashSet<>()); + Map<DrillbitEndpoint, NodeResource> resources = activeEndpoints.stream() + .collect(Collectors.toMap(x -> x, + x -> NodeResource.create())); + String sql = "SELECT dept_id, count(*) from cp.`tpch/lineitem.parquet` group by dept_id"; + + SimpleParallelizer parallelizer = new QueueQueryParallelizer(false, queryContext); + PlanningSet planningSet = preparePlanningSet(activeEndpoints, DEFAULT_SLICE_TARGET, resources, sql, parallelizer); + parallelizer.adjustMemory(planningSet, createSet(planningSet.getRootWrapper()), activeEndpoints); + assertTrue("memory requirement is different", Iterables.all(resources.entrySet(), (e) -> e.getValue().getMemory() == 529570)); + } + + + @Test + public void TestTwoMajorFragmentWithSortyProjectAndScan() throws Exception { + List<DrillbitEndpoint> activeEndpoints = getEndpoints(2, new HashSet<>()); + Map<DrillbitEndpoint, NodeResource> resources = activeEndpoints.stream() + .collect(Collectors.toMap(x -> x, + x -> NodeResource.create())); + String sql = "SELECT * from cp.`tpch/lineitem.parquet` order by dept_id"; + + SimpleParallelizer parallelizer = new QueueQueryParallelizer(false, queryContext); + PlanningSet planningSet = preparePlanningSet(activeEndpoints, 2, resources, sql, parallelizer); + parallelizer.adjustMemory(planningSet, createSet(planningSet.getRootWrapper()), activeEndpoints); + assertTrue("memory requirement is different", Iterables.all(resources.entrySet(), (e) -> e.getValue().getMemory() == 481490)); + } +} |