about summary refs log tree commit diff
path: root/mlir/lib/Dialect/Linalg/Transforms
diff options
context:
space:
mode:
authorMahesh Ravishankar <ravishankarm@google.com>2022-07-13 23:22:47 +0000
committerMahesh Ravishankar <ravishankarm@google.com>2022-07-15 23:01:18 +0000
commit3139cc766c86b09426893a7349763c347639cbdc (patch)
tree898ca921cbd27b13fcb81e3e34a199cefbde7141 /mlir/lib/Dialect/Linalg/Transforms
parenteda2bcad020d882d487cac2951d2ad9b75f59123 (diff)
[mlir][Linalg] Add a pattern to decompose `linalg.generic` ops.
This patch adds a pattern to decompose a `linalg.generic` operation that (a) has only parallel iterator types and (b) has more than 2 statements (including the yield) into multiple `linalg.generic` operations such that each operation has a single statement and a yield. The pattern added here just splits the matching `linalg.generic` into two `linalg.generic`s, one containing the first statement, and the other containing the remaining statements. The same pattern can be applied repeatedly on the second op to ultimately fully decompose the generic op. Differential Revision: https://reviews.llvm.org/D129704
Diffstat (limited to 'mlir/lib/Dialect/Linalg/Transforms')
-rw-r--r--  mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt         |   1
-rw-r--r--  mlir/lib/Dialect/Linalg/Transforms/DecomposeLinalgOps.cpp | 391
2 files changed, 392 insertions(+), 0 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
index 8015edeb59a9..5bc2740afbe0 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -4,6 +4,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
Bufferize.cpp
CodegenStrategy.cpp
ConstantFold.cpp
+ DecomposeLinalgOps.cpp
Detensorize.cpp
DropUnitDims.cpp
ElementwiseOpFusion.cpp
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DecomposeLinalgOps.cpp b/mlir/lib/Dialect/Linalg/Transforms/DecomposeLinalgOps.cpp
new file mode 100644
index 000000000000..9b6218474ed0
--- /dev/null
+++ b/mlir/lib/Dialect/Linalg/Transforms/DecomposeLinalgOps.cpp
@@ -0,0 +1,391 @@
+//===- DecomposeLinalgOps.cpp - Pattern to break up Linalg ops ------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+
+using namespace mlir;
+using namespace mlir::linalg;
+
+namespace {
+
+/// Pattern to decompose a GenericOp that has more than two statements
+/// into one GenericOp with the first statement (i.e. peeled operation), and
+/// a second GenericOp with the remaining statements (i.e. residual operations).
+///
+/// - The result of the first GenericOp has the same shape as the iteration
+/// space of the GenericOp. The body of the op yields as many values as the
+/// original op plus all the results of the peeled operation.
+/// - The second GenericOp has as many operands as the original operation plus
+/// all the results of the first GenericOp. It has the same number of yields as
+/// the original op.
+/// - If the result of the peeled operation was yielded by the original
+/// GenericOp the uses of the corresponding results will be replaced with the
+/// result of the first GenericOp created.
+///
+/// Example
+///
+/// ```mlir
+/// %result:2 = linalg.generic ... ins(%arg0, %arg1, %arg2 : ...)
+///     outs(%init0, %init1 : ...) {
+///   ^bb0(%b0: ... , %b1: ... , %b2: ... , %b3: ..., %b4: ...):
+///     %0 = <s0> %b0, %b1 : ...
+///     %1 = <s1> %0, %b2 : ...
+///     linalg.yield %0, %1 : ...
+/// } -> (..., ...)
+/// return %result#0, %result#1
+/// ```
+///
+/// gets split into
+///
+/// ```mlir
+/// %init = linalg.init_tensor ...
+/// %op0:3 = linalg.generic ... ins(%arg0, %arg1, %arg2 : ...)
+///     outs(%init0, %init1, %init : ...)
+///   ^bb0(%b0: ... , %b1: ... , %b2: ... , %b3: ..., %b4: ..., %b5: ...):
+///     %0 = <s0> %b0, %b1 : ...
+///     linalg.yield %0, %..., %0 : ...
+/// } -> (..., ..., ...)
+/// %op1:2 = linalg.generic ... ins(%arg0, %arg1, %arg2, %op0#2 : ...)
+///     outs(%init0, %init1 : ...) {
+///   ^bb0(%b0: ... , %b1: ... , %b2: ... , %b3: ..., %b4: ..., %b5: ...):
+///     %1 = <s1> %b3, %b2 : ...
+///     linalg.yield %..., %1 : ...
+/// } -> (..., ...)
+/// return %op0#0, %op1#1
+/// ```
+///
+/// After canonicalization this is expected to be
+///
+/// ```mlir
+/// %init = linalg.init_tensor ...
+/// %op0 = linalg.generic ... ins(%arg0, %arg1 : ...)
+///     outs(%init : ...)
+///   ^bb0(%b0: ... , %b1: ... , %b2: ...):
+///     %0 = <s0> %b0, %b1 : ...
+///     linalg.yield %0 : ...
+/// } -> ...
+/// %op1 = linalg.generic ... ins(%arg2, %op0 : ...)
+///     outs(%init1 : ...) {
+///   ^bb0(%b0: ... , %b1: ... , %b2: ...):
+///     %1 = <s1> %b1, %b0 : ...
+///     linalg.yield %1 : ...
+/// } -> ...
+/// return %op0, %op1
+/// ```
+struct DecomposeLinalgOp : public OpRewritePattern<GenericOp> {
+  using OpRewritePattern<GenericOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(GenericOp genericOp,
+                                PatternRewriter &rewriter) const override;
+
+private:
+  /// Helper method to create a generic op for the peeled scalar operation. The
+  /// created op has an empty region.
+  GenericOp createPeeledGenericOp(GenericOp genericOp,
+                                  PatternRewriter &rewriter) const;
+
+  /// Helper method to create a generic op for the residual scalar operation.
+  /// The created op has the same region as the original op.
+  GenericOp createResidualGenericOp(GenericOp genericOp,
+                                    GenericOp peeledGenericOp,
+                                    PatternRewriter &rewriter) const;
+};
+} // namespace
+
+/// Helper method to compute the range of a generic op.
+static SmallVector<OpFoldResult> getGenericOpLoopRange(OpBuilder &b,
+                                                       GenericOp op) {
+  // Materialize the size computations right before `op` so that they are
+  // dominated by the operands they query; restore the insertion point after.
+  OpBuilder::InsertionGuard guard(b);
+  b.setInsertionPoint(op);
+  Location loc = op.getLoc();
+  // Gather the sizes of all operand dimensions and map them onto the loop
+  // dimensions of the op through the shapes-to-loops map.
+  LinalgOp linalgOp = cast<LinalgOp>(op.getOperation());
+  auto operandDimSizes = linalgOp.createFlatListOfOperandDims(b, loc);
+  AffineMap shapesToLoops = op.getShapesToLoopsMap();
+  auto loopRangeValues = applyMapToValues(b, loc, shapesToLoops, operandDimSizes);
+  return getAsOpFoldResult(loopRangeValues);
+}
+
+/// Helper method to permute the list of `values` based on the `map`.
+///
+/// `map` must be a permutation: result i of the map picks the loop dimension
+/// whose range ends up at position i of the returned vector. Marked `static`
+/// for internal linkage, consistent with the other file-local helpers
+/// (`getGenericOpLoopRange`, `getZero`) and to avoid a missing-prototype
+/// warning for an unexported symbol.
+static SmallVector<OpFoldResult> permuteValues(ArrayRef<OpFoldResult> values,
+                                               AffineMap map) {
+  assert(map.isPermutation() && "expected map to be a permutation");
+  SmallVector<OpFoldResult> permutedValues(values.size());
+  // For result i of the map, place values[i] at the dimension position the
+  // permutation maps it to.
+  for (auto position :
+       llvm::enumerate(llvm::map_range(map.getResults(), [](AffineExpr expr) {
+         return expr.cast<AffineDimExpr>().getPosition();
+       })))
+    permutedValues[position.value()] = values[position.index()];
+  return permutedValues;
+}
+
+/// Get zero value for an element type.
+///
+/// Supports integer, index and float element types; asserts on anything else.
+static Value getZero(OpBuilder &b, Location loc, Type elementType) {
+  assert(elementType.isIntOrIndexOrFloat() &&
+         "expected scalar type while computing zero value");
+  // IndexType and IntegerType are distinct in MLIR, so the order of these two
+  // checks is immaterial.
+  if (elementType.isIndex())
+    return b.create<arith::ConstantIndexOp>(loc, 0);
+  if (elementType.isa<IntegerType>())
+    return b.create<arith::ConstantIntOp>(loc, 0, elementType);
+  // The assert above guarantees the only remaining case is a float type.
+  auto floatType = elementType.cast<FloatType>();
+  APFloat zero = APFloat::getZero(floatType.getFloatSemantics());
+  return b.create<arith::ConstantFloatOp>(loc, zero, floatType);
+}
+
+GenericOp
+DecomposeLinalgOp::createPeeledGenericOp(GenericOp genericOp,
+                                         PatternRewriter &rewriter) const {
+  // The first statement of the body is the operation being peeled off. The
+  // peeled op keeps all original operands and indexing maps; maps for the
+  // extra results (one per result of the peeled statement) are appended below.
+  Block *body = genericOp.getBody();
+  Operation *peeledScalarOperation = &(*body->begin());
+  SmallVector<AffineMap> peeledGenericOpIndexingMaps =
+      genericOp.getIndexingMaps();
+
+  // Compute the loop ranges for the operation. This is the shape of the result
+  // of the generic op for the peeled operation.
+  Location loc = genericOp.getLoc();
+  SmallVector<OpFoldResult> domain = getGenericOpLoopRange(rewriter, genericOp);
+  SmallVector<Value> newInitValues;
+  SmallVector<Type> newResultTypes;
+
+  // The indexing map to use for each new result is obtained by
+  // - checking if the result is yielded: if so, use the same indexing map as
+  //   the corresponding output;
+  // - using the identity indexing map if the result is not yielded.
+  Operation *yieldOp = body->getTerminator();
+  auto getResultIndexingMap = [&](OpResult scalarOpResult) -> AffineMap {
+    // Scan the yield for uses of this scalar result; remember the first such
+    // use, and whether any use feeds an output with an identity indexing map.
+    OpOperand *firstUseInYield = nullptr, *identityUseInYield = nullptr;
+    for (OpOperand &use : scalarOpResult.getUses()) {
+      if (use.getOwner() != yieldOp)
+        continue;
+      if (!firstUseInYield)
+        firstUseInYield = &use;
+      // An operand number in the yield is the result number of the generic op.
+      OpResult genericOpResult =
+          genericOp.getResult(use.getOperandNumber()).cast<OpResult>();
+      AffineMap indexingMap =
+          genericOp.getTiedIndexingMapForResult(genericOpResult);
+      if (indexingMap.isIdentity())
+        identityUseInYield = &use;
+    }
+    // Prefer identity; otherwise reuse the map of the first yielded use. A
+    // result that is not yielded at all also gets the identity map.
+    if (identityUseInYield || !firstUseInYield)
+      return rewriter.getMultiDimIdentityMap(domain.size());
+    OpResult genericOpResult =
+        genericOp.getResult(firstUseInYield->getOperandNumber())
+            .cast<OpResult>();
+    return genericOp.getTiedIndexingMapForResult(genericOpResult);
+  };
+
+  // Create an init_tensor for each result of the peeled scalar operation,
+  // sized by permuting the loop ranges through the chosen indexing map.
+  for (auto scalarResult : peeledScalarOperation->getResults()) {
+    AffineMap resultIndexingMap = getResultIndexingMap(scalarResult);
+    SmallVector<OpFoldResult> initSize =
+        permuteValues(domain, resultIndexingMap);
+    Value initTensor = rewriter.create<linalg::InitTensorOp>(
+        loc, initSize, scalarResult.getType());
+    newInitValues.push_back(initTensor);
+    newResultTypes.push_back(initTensor.getType());
+    peeledGenericOpIndexingMaps.push_back(resultIndexingMap);
+  }
+
+  // Create the peeled generic op with an empty body; the caller is expected to
+  // populate the region afterwards.
+  SmallVector<Value> outsOperands = genericOp.getOutputOperands();
+  outsOperands.append(newInitValues.begin(), newInitValues.end());
+  SmallVector<Type> resultTypes = llvm::to_vector(genericOp.getResultTypes());
+  resultTypes.append(newResultTypes.begin(), newResultTypes.end());
+  auto indexingMapAttr =
+      rewriter.getAffineMapArrayAttr(peeledGenericOpIndexingMaps);
+  return rewriter.create<GenericOp>(
+      loc, resultTypes, genericOp.inputs(), outsOperands, indexingMapAttr,
+      genericOp.iterator_types(), /*doc=*/nullptr, /*libraryCall=*/nullptr,
+      [](OpBuilder, Location, ValueRange) {});
+}
+
+GenericOp
+DecomposeLinalgOp::createResidualGenericOp(GenericOp genericOp,
+                                           GenericOp peeledGenericOp,
+                                           PatternRewriter &rewriter) const {
+  // The residual op reads all original inputs, plus every extra result of the
+  // peeled op (these carry the values of the peeled scalar statement).
+  SmallVector<Value> residualOpIns;
+  for (OpOperand *inOperand : genericOp.getInputOperands())
+    residualOpIns.push_back(inOperand->get());
+  unsigned numOrigResults = genericOp.getNumResults();
+  unsigned numPeeledResults = peeledGenericOp.getNumResults();
+  for (unsigned resultNum = numOrigResults; resultNum < numPeeledResults;
+       ++resultNum)
+    residualOpIns.push_back(peeledGenericOp->getResult(resultNum));
+
+  // Indexing maps, in operand order: the original input maps, then (for the
+  // newly added ins) the maps of the corresponding peeled-op results, and
+  // finally the original output maps.
+  SmallVector<AffineMap> residualOpMaps;
+  for (OpOperand *inOperand : genericOp.getInputOperands())
+    residualOpMaps.push_back(genericOp.getTiedIndexingMap(inOperand));
+  for (unsigned resultNum = numOrigResults; resultNum < numPeeledResults;
+       ++resultNum) {
+    OpResult result = peeledGenericOp.getResult(resultNum).cast<OpResult>();
+    residualOpMaps.push_back(
+        peeledGenericOp.getTiedIndexingMapForResult(result));
+  }
+  for (OpOperand *outOperand : genericOp.getOutputOperands())
+    residualOpMaps.push_back(genericOp.getTiedIndexingMap(outOperand));
+
+  // Build the residual op with an empty body; the caller moves the remaining
+  // statements of the original region into it.
+  auto indexingMapAttr = rewriter.getAffineMapArrayAttr(residualOpMaps);
+  return rewriter.create<GenericOp>(
+      genericOp->getLoc(), genericOp->getResultTypes(), residualOpIns,
+      genericOp.outputs(), indexingMapAttr, genericOp.iterator_types(),
+      /*doc=*/nullptr, /*libraryCall=*/nullptr,
+      [](OpBuilder, Location, ValueRange) {});
+}
+
+LogicalResult
+DecomposeLinalgOp::matchAndRewrite(GenericOp genericOp,
+                                   PatternRewriter &rewriter) const {
+  // For now only match on operations where the iterator types are all
+  // parallel.
+  if (genericOp.getNumParallelLoops() != genericOp.getNumLoops()) {
+    return rewriter.notifyMatchFailure(genericOp,
+                                       "unhandled decomposition of operation "
+                                       "with non-parallel iterator types");
+  }
+  // TODO: this could be generalized to handle `linalg.generic` with buffer
+  // operands too but requires allocation for intermediates. Punt on this for
+  // now.
+  if (!genericOp.hasTensorSemantics()) {
+    return rewriter.notifyMatchFailure(
+        genericOp, "only operations with tensor semantics are handled");
+  }
+
+  // TODO: For now only decompose operations where the `outs` operands values
+  // are not accessed within the payload. This might be relaxed in future, but
+  // needs a bit more reasoning to ensure that it is safe.
+  if (llvm::any_of(genericOp.getOutputOperands(), [&](OpOperand *outOperand) {
+        return genericOp.payloadUsesValueFromOperand(outOperand);
+      })) {
+    return rewriter.notifyMatchFailure(
+        genericOp, "unhandled decomposition of generic op with use of out "
+                   "operand value in payload");
+  }
+
+  // Out operands must be accessed through permutations so the peeled op can
+  // materialize its extra results over the full iteration space.
+  if (llvm::any_of(genericOp.getOutputOperands(), [&](OpOperand *outOperand) {
+        return !genericOp.getTiedIndexingMap(outOperand).isPermutation();
+      })) {
+    return rewriter.notifyMatchFailure(
+        genericOp, "unhandled decomposition of generic op with out operand not "
+                   "accessed using a permutation");
+  }
+
+  // If the op has only a single statement (apart from the yield), do nothing.
+  Block *body = genericOp.getBody();
+  if (body->getOperations().size() <= 2) {
+    return rewriter.notifyMatchFailure(genericOp,
+                                       "operation has less than 3 statements");
+  }
+
+  // Check that the peeled statement has a scalar element type.
+  if (llvm::any_of(body->getOperations().begin()->getResultTypes(),
+                   [](Type t) { return !t.isIntOrIndexOrFloat(); })) {
+    return rewriter.notifyMatchFailure(
+        &(*body->getOperations().begin()),
+        "expected return type to be only int, index or float");
+  }
+
+  // Create the two ops (both with empty regions at this point).
+  GenericOp peeledGenericOp = createPeeledGenericOp(genericOp, rewriter);
+  GenericOp residualGenericOp =
+      createResidualGenericOp(genericOp, peeledGenericOp, rewriter);
+
+  // Move the first statement of the original operation into the body of the
+  // generic op for the peeled operation, and the remaining statements
+  // (including the original yield) into the residual op.
+  Block *peeledGenericOpBody = peeledGenericOp.getBody();
+  Block *residualGenericOpBody = residualGenericOp.getBody();
+  assert(peeledGenericOpBody->empty() && residualGenericOpBody->empty() &&
+         "expected split generic ops to have empty region");
+  peeledGenericOpBody->getOperations().splice(
+      peeledGenericOpBody->begin(), body->getOperations(), body->begin());
+  residualGenericOpBody->getOperations().splice(residualGenericOpBody->begin(),
+                                                body->getOperations());
+
+  Operation *peeledScalarOperation = &(*peeledGenericOpBody->begin());
+  auto yieldOp = residualGenericOpBody->getTerminator();
+  {
+    // Build the yield of the peeled op: re-yield the original yielded values
+    // that come from the peeled statement (a zero placeholder is used for the
+    // other positions, whose results are unused), then append all results of
+    // the peeled scalar operation.
+    OpBuilder::InsertionGuard g(rewriter);
+    rewriter.setInsertionPointToEnd(peeledGenericOpBody);
+    SmallVector<Value> yieldedVals;
+    for (auto origYield : yieldOp->getOperands()) {
+      if (origYield.getDefiningOp() == peeledScalarOperation) {
+        yieldedVals.push_back(origYield);
+      } else {
+        yieldedVals.push_back(
+            getZero(rewriter, genericOp.getLoc(), origYield.getType()));
+      }
+    }
+    yieldedVals.append(llvm::to_vector(
+        llvm::map_range(peeledScalarOperation->getResults(),
+                        [](OpResult opr) -> Value { return opr; })));
+    rewriter.create<YieldOp>(genericOp.getLoc(), yieldedVals);
+  }
+
+  // In the split operations, replace block argument uses that refer to the
+  // original operation with the block arguments of the newly created
+  // operations. Both new ops start with the same leading arguments as the
+  // original, so the index carries over directly.
+  unsigned origNumInputs = genericOp.getNumInputs();
+  for (auto inputBlockArg :
+       llvm::enumerate(genericOp.getBody()->getArguments())) {
+    Value residualOpReplacementArg =
+        residualGenericOpBody->getArgument(inputBlockArg.index());
+    inputBlockArg.value().replaceUsesWithIf(
+        residualOpReplacementArg, [&](OpOperand &use) {
+          return use.getOwner()->getBlock() == residualGenericOpBody;
+        });
+
+    Value peeledOpReplacementArg =
+        peeledGenericOpBody->getArgument(inputBlockArg.index());
+    inputBlockArg.value().replaceUsesWithIf(
+        peeledOpReplacementArg, [&](OpOperand &use) {
+          return use.getOwner()->getBlock() == peeledGenericOpBody;
+        });
+  }
+
+  // Before fixing up the residual operation, track what values are yielded. If
+  // any of those are from the peeled scalar operation, the uses of the
+  // corresponding result have to be remapped to the result of the generic op
+  // for the peeled operation.
+  SmallVector<Value> replacements;
+  for (auto yieldValue : llvm::enumerate(yieldOp->getOperands())) {
+    OpResult opr = yieldValue.value().dyn_cast<OpResult>();
+    if (!opr || opr.getOwner() != peeledScalarOperation)
+      replacements.push_back(residualGenericOp.getResult(yieldValue.index()));
+    else
+      replacements.push_back(peeledGenericOp->getResult(yieldValue.index()));
+  }
+
+  // Update all uses of the peeled scalar operation results in the residual op
+  // to the newly added arguments (which sit right after the original inputs).
+  {
+    SmallVector<Value> scalarReplacements;
+    unsigned peeledScalarOpNumResults = peeledScalarOperation->getNumResults();
+    scalarReplacements.reserve(peeledScalarOpNumResults);
+    for (auto num : llvm::seq<unsigned>(0, peeledScalarOpNumResults))
+      scalarReplacements.push_back(
+          residualGenericOpBody->getArgument(num + origNumInputs));
+    bool allUsesReplaced = false;
+    rewriter.replaceOpWithinBlock(peeledScalarOperation, scalarReplacements,
+                                  residualGenericOpBody, &allUsesReplaced);
+    // The scalar op still lives in the peeled op's body; it must not have been
+    // erased by the block-scoped replacement.
+    assert(!allUsesReplaced &&
+           "peeled scalar operation is erased when it wasn't expected to be");
+  }
+
+  // Replace the original operation.
+  rewriter.replaceOp(genericOp, replacements);
+  return success();
+}
+
+void mlir::linalg::populateDecomposeLinalgOpsPattern(
+    RewritePatternSet &patterns) {
+  // Register the decomposition pattern against the set's context.
+  patterns.add<DecomposeLinalgOp>(patterns.getContext());
+}