//===- SCFToControlFlow.cpp - SCF to CF conversion ------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // This file implements a pass to convert scf.for, scf.if and loop.terminator // ops into standard CFG ops. // //===----------------------------------------------------------------------===// #include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h" #include "../PassDetail.h" #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/IR/BlockAndValueMapping.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/MLIRContext.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/Passes.h" using namespace mlir; using namespace mlir::scf; namespace { struct SCFToControlFlowPass : public SCFToControlFlowBase { void runOnOperation() override; }; // Create a CFG subgraph for the loop around its body blocks (if the body // contained other loops, they have been already lowered to a flow of blocks). // Maintain the invariants that a CFG subgraph created for any loop has a single // entry and a single exit, and that the entry/exit blocks are respectively // first/last blocks in the parent region. The original loop operation is // replaced by the initialization operations that set up the initial value of // the loop induction variable (%iv) and computes the loop bounds that are loop- // invariant for affine loops. The operations following the original scf.for // are split out into a separate continuation (exit) block. A condition block is // created before the continuation block. It checks the exit condition of the // loop and branches either to the continuation block, or to the first block of // the body. The condition block takes as arguments the values of the induction // variable followed by loop-carried values. Since it dominates both the body // blocks and the continuation block, loop-carried values are visible in all of // those blocks. Induction variable modification is appended to the last block // of the body (which is the exit block from the body subgraph thanks to the // invariant we maintain) along with a branch that loops back to the condition // block. Loop-carried values are the loop terminator operands, which are // forwarded to the branch. // // +---------------------------------+ // | | // | | // | | // | cf.br cond(%iv, %init...) | // +---------------------------------+ // | // -------| | // | v v // | +--------------------------------+ // | | cond(%iv, %init...): | // | | | // | | cf.cond_br %r, body, end | // | +--------------------------------+ // | | | // | | -------------| // | v | // | +--------------------------------+ | // | | body-first: | | // | | <%init visible by dominance> | | // | | | | // | +--------------------------------+ | // | | | // | ... | // | | | // | +--------------------------------+ | // | | body-last: | | // | | | | // | | | | // | | %new_iv = | | // | | cf.br cond(%new_iv, %yields) | | // | +--------------------------------+ | // | | | // |----------- |-------------------- // v // +--------------------------------+ // | end: | // | | // | <%init visible by dominance> | // +--------------------------------+ // struct ForLowering : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(ForOp forOp, PatternRewriter &rewriter) const override; }; // Create a CFG subgraph for the scf.if operation (including its "then" and // optional "else" operation blocks). We maintain the invariants that the // subgraph has a single entry and a single exit point, and that the entry/exit // blocks are respectively the first/last block of the enclosing region. The // operations following the scf.if are split into a continuation (subgraph // exit) block. The condition is lowered to a chain of blocks that implement the // short-circuit scheme. The "scf.if" operation is replaced with a conditional // branch to either the first block of the "then" region, or to the first block // of the "else" region. In these blocks, "scf.yield" is unconditional branches // to the post-dominating block. When the "scf.if" does not return values, the // post-dominating block is the same as the continuation block. When it returns // values, the post-dominating block is a new block with arguments that // correspond to the values returned by the "scf.if" that unconditionally // branches to the continuation block. This allows block arguments to dominate // any uses of the hitherto "scf.if" results that they replaced. (Inserting a // new block allows us to avoid modifying the argument list of an existing // block, which is illegal in a conversion pattern). When the "else" region is // empty, which is only allowed for "scf.if"s that don't return values, the // condition branches directly to the continuation block. // // CFG for a scf.if with else and without results. // // +--------------------------------+ // | | // | cf.cond_br %cond, %then, %else | // +--------------------------------+ // | | // | --------------| // v | // +--------------------------------+ | // | then: | | // | | | // | cf.br continue | | // +--------------------------------+ | // | | // |---------- |------------- // | V // | +--------------------------------+ // | | else: | // | | | // | | cf.br continue | // | +--------------------------------+ // | | // ------| | // v v // +--------------------------------+ // | continue: | // | | // +--------------------------------+ // // CFG for a scf.if with results. // // +--------------------------------+ // | | // | cf.cond_br %cond, %then, %else | // +--------------------------------+ // | | // | --------------| // v | // +--------------------------------+ | // | then: | | // | | | // | cf.br dom(%args...) | | // +--------------------------------+ | // | | // |---------- |------------- // | V // | +--------------------------------+ // | | else: | // | | | // | | cf.br dom(%args...) | // | +--------------------------------+ // | | // ------| | // v v // +--------------------------------+ // | dom(%args...): | // | cf.br continue | // +--------------------------------+ // | // v // +--------------------------------+ // | continue: | // | | // +--------------------------------+ // struct IfLowering : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(IfOp ifOp, PatternRewriter &rewriter) const override; }; struct ExecuteRegionLowering : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(ExecuteRegionOp op, PatternRewriter &rewriter) const override; }; struct ParallelLowering : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(mlir::scf::ParallelOp parallelOp, PatternRewriter &rewriter) const override; }; /// Create a CFG subgraph for this loop construct. The regions of the loop need /// not be a single block anymore (for example, if other SCF constructs that /// they contain have been already converted to CFG), but need to be single-exit /// from the last block of each region. The operations following the original /// WhileOp are split into a new continuation block. Both regions of the WhileOp /// are inlined, and their terminators are rewritten to organize the control /// flow implementing the loop as follows. /// /// +---------------------------------+ /// | | /// | cf.br ^before(%operands...) | /// +---------------------------------+ /// | /// -------| | /// | v v /// | +--------------------------------+ /// | | ^before(%bargs...): | /// | | %vals... = | /// | +--------------------------------+ /// | | /// | ... /// | | /// | +--------------------------------+ /// | | ^before-last: /// | | %cond = | /// | | cf.cond_br %cond, | /// | | ^after(%vals...), ^cont | /// | +--------------------------------+ /// | | | /// | | -------------| /// | v | /// | +--------------------------------+ | /// | | ^after(%aargs...): | | /// | | | | /// | +--------------------------------+ | /// | | | /// | ... | /// | | | /// | +--------------------------------+ | /// | | ^after-last: | | /// | | %yields... = | | /// | | cf.br ^before(%yields...) | | /// | +--------------------------------+ | /// | | | /// |----------- |-------------------- /// v /// +--------------------------------+ /// | ^cont: | /// | | /// | <%vals from 'before' region | /// | visible by dominance> | /// +--------------------------------+ /// /// Values are communicated between ex-regions (the groups of blocks that used /// to form a region before inlining) through block arguments of their /// entry blocks, which are visible in all other dominated blocks. Similarly, /// the results of the WhileOp are defined in the 'before' region, which is /// required to have a single existing block, and are therefore accessible in /// the continuation block due to dominance. struct WhileLowering : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(WhileOp whileOp, PatternRewriter &rewriter) const override; }; /// Optimized version of the above for the case of the "after" region merely /// forwarding its arguments back to the "before" region (i.e., a "do-while" /// loop). This avoid inlining the "after" region completely and branches back /// to the "before" entry instead. struct DoWhileLowering : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(WhileOp whileOp, PatternRewriter &rewriter) const override; }; } // namespace LogicalResult ForLowering::matchAndRewrite(ForOp forOp, PatternRewriter &rewriter) const { Location loc = forOp.getLoc(); // Start by splitting the block containing the 'scf.for' into two parts. // The part before will get the init code, the part after will be the end // point. auto *initBlock = rewriter.getInsertionBlock(); auto initPosition = rewriter.getInsertionPoint(); auto *endBlock = rewriter.splitBlock(initBlock, initPosition); // Use the first block of the loop body as the condition block since it is the // block that has the induction variable and loop-carried values as arguments. // Split out all operations from the first block into a new block. Move all // body blocks from the loop body region to the region containing the loop. auto *conditionBlock = &forOp.getRegion().front(); auto *firstBodyBlock = rewriter.splitBlock(conditionBlock, conditionBlock->begin()); auto *lastBodyBlock = &forOp.getRegion().back(); rewriter.inlineRegionBefore(forOp.getRegion(), endBlock); auto iv = conditionBlock->getArgument(0); // Append the induction variable stepping logic to the last body block and // branch back to the condition block. Loop-carried values are taken from // operands of the loop terminator. Operation *terminator = lastBodyBlock->getTerminator(); rewriter.setInsertionPointToEnd(lastBodyBlock); auto step = forOp.getStep(); auto stepped = rewriter.create(loc, iv, step).getResult(); if (!stepped) return failure(); SmallVector loopCarried; loopCarried.push_back(stepped); loopCarried.append(terminator->operand_begin(), terminator->operand_end()); rewriter.create(loc, conditionBlock, loopCarried); rewriter.eraseOp(terminator); // Compute loop bounds before branching to the condition. rewriter.setInsertionPointToEnd(initBlock); Value lowerBound = forOp.getLowerBound(); Value upperBound = forOp.getUpperBound(); if (!lowerBound || !upperBound) return failure(); // The initial values of loop-carried values is obtained from the operands // of the loop operation. SmallVector destOperands; destOperands.push_back(lowerBound); auto iterOperands = forOp.getIterOperands(); destOperands.append(iterOperands.begin(), iterOperands.end()); rewriter.create(loc, conditionBlock, destOperands); // With the body block done, we can fill in the condition block. rewriter.setInsertionPointToEnd(conditionBlock); auto comparison = rewriter.create( loc, arith::CmpIPredicate::slt, iv, upperBound); rewriter.create(loc, comparison, firstBodyBlock, ArrayRef(), endBlock, ArrayRef()); // The result of the loop operation is the values of the condition block // arguments except the induction variable on the last iteration. rewriter.replaceOp(forOp, conditionBlock->getArguments().drop_front()); return success(); } LogicalResult IfLowering::matchAndRewrite(IfOp ifOp, PatternRewriter &rewriter) const { auto loc = ifOp.getLoc(); // Start by splitting the block containing the 'scf.if' into two parts. // The part before will contain the condition, the part after will be the // continuation point. auto *condBlock = rewriter.getInsertionBlock(); auto opPosition = rewriter.getInsertionPoint(); auto *remainingOpsBlock = rewriter.splitBlock(condBlock, opPosition); Block *continueBlock; if (ifOp.getNumResults() == 0) { continueBlock = remainingOpsBlock; } else { continueBlock = rewriter.createBlock(remainingOpsBlock, ifOp.getResultTypes(), SmallVector(ifOp.getNumResults(), loc)); rewriter.create(loc, remainingOpsBlock); } // Move blocks from the "then" region to the region containing 'scf.if', // place it before the continuation block, and branch to it. auto &thenRegion = ifOp.getThenRegion(); auto *thenBlock = &thenRegion.front(); Operation *thenTerminator = thenRegion.back().getTerminator(); ValueRange thenTerminatorOperands = thenTerminator->getOperands(); rewriter.setInsertionPointToEnd(&thenRegion.back()); rewriter.create(loc, continueBlock, thenTerminatorOperands); rewriter.eraseOp(thenTerminator); rewriter.inlineRegionBefore(thenRegion, continueBlock); // Move blocks from the "else" region (if present) to the region containing // 'scf.if', place it before the continuation block and branch to it. It // will be placed after the "then" regions. auto *elseBlock = continueBlock; auto &elseRegion = ifOp.getElseRegion(); if (!elseRegion.empty()) { elseBlock = &elseRegion.front(); Operation *elseTerminator = elseRegion.back().getTerminator(); ValueRange elseTerminatorOperands = elseTerminator->getOperands(); rewriter.setInsertionPointToEnd(&elseRegion.back()); rewriter.create(loc, continueBlock, elseTerminatorOperands); rewriter.eraseOp(elseTerminator); rewriter.inlineRegionBefore(elseRegion, continueBlock); } rewriter.setInsertionPointToEnd(condBlock); rewriter.create(loc, ifOp.getCondition(), thenBlock, /*trueArgs=*/ArrayRef(), elseBlock, /*falseArgs=*/ArrayRef()); // Ok, we're done! rewriter.replaceOp(ifOp, continueBlock->getArguments()); return success(); } LogicalResult ExecuteRegionLowering::matchAndRewrite(ExecuteRegionOp op, PatternRewriter &rewriter) const { auto loc = op.getLoc(); auto *condBlock = rewriter.getInsertionBlock(); auto opPosition = rewriter.getInsertionPoint(); auto *remainingOpsBlock = rewriter.splitBlock(condBlock, opPosition); auto ®ion = op.getRegion(); rewriter.setInsertionPointToEnd(condBlock); rewriter.create(loc, ®ion.front()); for (Block &block : region) { if (auto terminator = dyn_cast(block.getTerminator())) { ValueRange terminatorOperands = terminator->getOperands(); rewriter.setInsertionPointToEnd(&block); rewriter.create(loc, remainingOpsBlock, terminatorOperands); rewriter.eraseOp(terminator); } } rewriter.inlineRegionBefore(region, remainingOpsBlock); SmallVector vals; SmallVector argLocs(op.getNumResults(), op->getLoc()); for (auto arg : remainingOpsBlock->addArguments(op->getResultTypes(), argLocs)) vals.push_back(arg); rewriter.replaceOp(op, vals); return success(); } LogicalResult ParallelLowering::matchAndRewrite(ParallelOp parallelOp, PatternRewriter &rewriter) const { Location loc = parallelOp.getLoc(); // For a parallel loop, we essentially need to create an n-dimensional loop // nest. We do this by translating to scf.for ops and have those lowered in // a further rewrite. If a parallel loop contains reductions (and thus returns // values), forward the initial values for the reductions down the loop // hierarchy and bubble up the results by modifying the "yield" terminator. SmallVector iterArgs = llvm::to_vector<4>(parallelOp.getInitVals()); SmallVector ivs; ivs.reserve(parallelOp.getNumLoops()); bool first = true; SmallVector loopResults(iterArgs); for (auto loopOperands : llvm::zip(parallelOp.getInductionVars(), parallelOp.getLowerBound(), parallelOp.getUpperBound(), parallelOp.getStep())) { Value iv, lower, upper, step; std::tie(iv, lower, upper, step) = loopOperands; ForOp forOp = rewriter.create(loc, lower, upper, step, iterArgs); ivs.push_back(forOp.getInductionVar()); auto iterRange = forOp.getRegionIterArgs(); iterArgs.assign(iterRange.begin(), iterRange.end()); if (first) { // Store the results of the outermost loop that will be used to replace // the results of the parallel loop when it is fully rewritten. loopResults.assign(forOp.result_begin(), forOp.result_end()); first = false; } else if (!forOp.getResults().empty()) { // A loop is constructed with an empty "yield" terminator if there are // no results. rewriter.setInsertionPointToEnd(rewriter.getInsertionBlock()); rewriter.create(loc, forOp.getResults()); } rewriter.setInsertionPointToStart(forOp.getBody()); } // First, merge reduction blocks into the main region. SmallVector yieldOperands; yieldOperands.reserve(parallelOp.getNumResults()); for (auto &op : *parallelOp.getBody()) { auto reduce = dyn_cast(op); if (!reduce) continue; Block &reduceBlock = reduce.getReductionOperator().front(); Value arg = iterArgs[yieldOperands.size()]; yieldOperands.push_back(reduceBlock.getTerminator()->getOperand(0)); rewriter.eraseOp(reduceBlock.getTerminator()); rewriter.mergeBlockBefore(&reduceBlock, &op, {arg, reduce.getOperand()}); rewriter.eraseOp(reduce); } // Then merge the loop body without the terminator. rewriter.eraseOp(parallelOp.getBody()->getTerminator()); Block *newBody = rewriter.getInsertionBlock(); if (newBody->empty()) rewriter.mergeBlocks(parallelOp.getBody(), newBody, ivs); else rewriter.mergeBlockBefore(parallelOp.getBody(), newBody->getTerminator(), ivs); // Finally, create the terminator if required (for loops with no results, it // has been already created in loop construction). if (!yieldOperands.empty()) { rewriter.setInsertionPointToEnd(rewriter.getInsertionBlock()); rewriter.create(loc, yieldOperands); } rewriter.replaceOp(parallelOp, loopResults); return success(); } LogicalResult WhileLowering::matchAndRewrite(WhileOp whileOp, PatternRewriter &rewriter) const { OpBuilder::InsertionGuard guard(rewriter); Location loc = whileOp.getLoc(); // Split the current block before the WhileOp to create the inlining point. Block *currentBlock = rewriter.getInsertionBlock(); Block *continuation = rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint()); // Inline both regions. Block *after = &whileOp.getAfter().front(); Block *afterLast = &whileOp.getAfter().back(); Block *before = &whileOp.getBefore().front(); Block *beforeLast = &whileOp.getBefore().back(); rewriter.inlineRegionBefore(whileOp.getAfter(), continuation); rewriter.inlineRegionBefore(whileOp.getBefore(), after); // Branch to the "before" region. rewriter.setInsertionPointToEnd(currentBlock); rewriter.create(loc, before, whileOp.getInits()); // Replace terminators with branches. Assuming bodies are SESE, which holds // given only the patterns from this file, we only need to look at the last // block. This should be reconsidered if we allow break/continue in SCF. rewriter.setInsertionPointToEnd(beforeLast); auto condOp = cast(beforeLast->getTerminator()); rewriter.replaceOpWithNewOp(condOp, condOp.getCondition(), after, condOp.getArgs(), continuation, ValueRange()); rewriter.setInsertionPointToEnd(afterLast); auto yieldOp = cast(afterLast->getTerminator()); rewriter.replaceOpWithNewOp(yieldOp, before, yieldOp.getResults()); // Replace the op with values "yielded" from the "before" region, which are // visible by dominance. rewriter.replaceOp(whileOp, condOp.getArgs()); return success(); } LogicalResult DoWhileLowering::matchAndRewrite(WhileOp whileOp, PatternRewriter &rewriter) const { if (!llvm::hasSingleElement(whileOp.getAfter())) return rewriter.notifyMatchFailure(whileOp, "do-while simplification applicable to " "single-block 'after' region only"); Block &afterBlock = whileOp.getAfter().front(); if (!llvm::hasSingleElement(afterBlock)) return rewriter.notifyMatchFailure(whileOp, "do-while simplification applicable " "only if 'after' region has no payload"); auto yield = dyn_cast(&afterBlock.front()); if (!yield || yield.getResults() != afterBlock.getArguments()) return rewriter.notifyMatchFailure(whileOp, "do-while simplification applicable " "only to forwarding 'after' regions"); // Split the current block before the WhileOp to create the inlining point. OpBuilder::InsertionGuard guard(rewriter); Block *currentBlock = rewriter.getInsertionBlock(); Block *continuation = rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint()); // Only the "before" region should be inlined. Block *before = &whileOp.getBefore().front(); Block *beforeLast = &whileOp.getBefore().back(); rewriter.inlineRegionBefore(whileOp.getBefore(), continuation); // Branch to the "before" region. rewriter.setInsertionPointToEnd(currentBlock); rewriter.create(whileOp.getLoc(), before, whileOp.getInits()); // Loop around the "before" region based on condition. rewriter.setInsertionPointToEnd(beforeLast); auto condOp = cast(beforeLast->getTerminator()); rewriter.replaceOpWithNewOp(condOp, condOp.getCondition(), before, condOp.getArgs(), continuation, ValueRange()); // Replace the op with values "yielded" from the "before" region, which are // visible by dominance. rewriter.replaceOp(whileOp, condOp.getArgs()); return success(); } void mlir::populateSCFToControlFlowConversionPatterns( RewritePatternSet &patterns) { patterns.add(patterns.getContext()); patterns.add(patterns.getContext(), /*benefit=*/2); } void SCFToControlFlowPass::runOnOperation() { RewritePatternSet patterns(&getContext()); populateSCFToControlFlowConversionPatterns(patterns); // Configure conversion to lower out SCF operations. ConversionTarget target(getContext()); target.addIllegalOp(); target.markUnknownOpDynamicallyLegal([](Operation *) { return true; }); if (failed( applyPartialConversion(getOperation(), target, std::move(patterns)))) signalPassFailure(); } std::unique_ptr mlir::createConvertSCFToCFPass() { return std::make_unique(); }