aboutsummaryrefslogtreecommitdiff
path: root/mlir/lib/Dialect/Linalg/Transforms
diff options
context:
space:
mode:
authorAlex Zinenko <zinenko@google.com>2022-07-26 10:42:17 +0000
committerAlex Zinenko <zinenko@google.com>2022-07-27 08:52:08 +0000
commit08a1b07e7c19d14896f8be501c763ba8aff5b427 (patch)
tree96de1b6d7f5b9fe8ffa6bc873301351e3a88eeb0 /mlir/lib/Dialect/Linalg/Transforms
parent79ff02a1220518b0f98a1ed403f81c26376c76a9 (diff)
[mlir] Partially port splitting transform to TilingInterface
The structured op splitting transformation is conceptually similar to tiling in the sense that it decomposes the iteration space of the original op into several parts. Therefore, it is possible to implement it using the TilingInterface to operate on iteration spaces and their parts. However, the implementation also requires to pass updated input operands, which is not supported by the interface, so the implementation currently remains Linalg-specific. Reviewed By: mravishankar Differential Revision: https://reviews.llvm.org/D129564
Diffstat (limited to 'mlir/lib/Dialect/Linalg/Transforms')
-rw-r--r--mlir/lib/Dialect/Linalg/Transforms/Split.cpp225
1 files changed, 101 insertions, 124 deletions
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Split.cpp b/mlir/lib/Dialect/Linalg/Transforms/Split.cpp
index 8849e7f964b3..d735c671ab49 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Split.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Split.cpp
@@ -8,147 +8,124 @@
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
-#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
+#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/Interfaces/TilingInterface.h"
#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallVector.h"
using namespace mlir;
using namespace mlir::linalg;
-/// Extract the slices of `operands` supplied to the given operation `op` such
-/// that they are sufficient to execute the op for the subset of its iteration
-/// space defined by `splitIterationSpace`. The subset is a part of the original
-/// iteration space split at the given `dimension`. If `offset` is provided, it
-/// indicates the iterator value at which the dimension has been split and
-/// requires the "high" part starting at the given offset of the operands to be
-/// generated; otherwise, the "low" part with no offset is generated. Note that
-/// `operands` are not necessarily the actual operands of `op`.
-static SmallVector<Value>
-getOperandSlices(RewriterBase &b, Location loc, LinalgOp op,
- ValueRange splitIterationSpace, ValueRange operands,
- unsigned dimension, Value offset = nullptr) {
- SmallVector<Value> slices;
- slices.reserve(op.getNumInputsAndOutputs());
- for (OpOperand *opOperand : op.getInputAndOutputOperands()) {
- auto type = opOperand->get().getType().dyn_cast<ShapedType>();
- AffineMap indexing = op.getTiedIndexingMap(opOperand);
-
- // If the type is not sliceable, or the slice is requested along the
- // dimension that is not used in indexing this type, just use the entire
- // operand.
- if (!type || dimension >= indexing.getNumDims() ||
- !indexing.isFunctionOfDim(dimension)) {
- slices.push_back(opOperand->get());
- continue;
- }
-
- SmallVector<OpFoldResult> sizes;
- sizes.reserve(indexing.getNumResults());
- for (AffineExpr dimIndexing : indexing.getResults()) {
- sizes.push_back(makeComposedFoldedAffineApply(
- b, loc, dimIndexing,
- getAsOpFoldResult(llvm::to_vector(splitIterationSpace))));
- }
- SmallVector<OpFoldResult> offsets(type.getRank(), b.getIndexAttr(0));
- SmallVector<OpFoldResult> strides(type.getRank(), b.getIndexAttr(1));
-
- if (offset) {
- offsets[dimension] = offset;
- offsets = applyMapToValues(b, loc, indexing, offsets);
- }
-
- slices.push_back(createSlice(b, loc,
- operands[opOperand->getOperandNumber()],
- offsets, sizes, strides));
- }
-
- return slices;
-}
-
/// Creates a part of the given `op` split along the iteration space `dimension`
/// with the given `size` and an optional `offset` (default 0). Makes slices
/// of operands, using the input operands of the original op and the output
-/// operands provided as `resultOperands`. Expects `splitIterationSpace` to be
-/// a list of values representing the shape of the iteration space of the
-/// original op and updates it to be the iteration space of the curent part.
-/// Returns the split-out op as well as the output operand values updated with
-/// the partial results produced by this op through `results`.
-static LinalgOp
-createSplitPart(RewriterBase &b, Location loc, LinalgOp op,
- ValueRange resultOperands,
- llvm::MutableArrayRef<Value> splitIterationSpace,
- unsigned dimension, OpFoldResult size,
- SmallVectorImpl<Value> &results, Value offset = nullptr) {
- ImplicitLocOpBuilder implicit(op.getLoc(), b);
- splitIterationSpace[dimension] = materializeOpFoldResult(implicit, size);
- SmallVector<Value> operands = llvm::to_vector(
- llvm::map_range(op.getInputOperands(),
- [](OpOperand *opOperand) { return opOperand->get(); }));
- llvm::append_range(operands, resultOperands);
- operands = getOperandSlices(b, loc, op, splitIterationSpace, operands,
- dimension, offset);
- Operation *part =
- op.clone(b, loc, getTensorOutputTypes(op, operands), operands);
- results = insertSlicesBack(b, loc, op, operands, part->getResults());
- return cast<LinalgOp>(part);
+/// operands provided as `resultOperands`. Expects `offsets` and `sizes` to
+/// define the shape of the iteration space of the original op. Returns the
+/// split-out op as well as the output operand values updated with the partial
+/// results produced by this op through `results`.
+static TilingInterface
+createSplitPart(RewriterBase &b, Location loc, TilingInterface op,
+ ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
+ ValueRange resultOperands, unsigned dimension,
+ OpFoldResult size, OpFoldResult offset,
+ SmallVectorImpl<Value> &results) {
+ // Iteration space of the current part.
+ SmallVector<OpFoldResult> sizesCopy = llvm::to_vector(sizes);
+ SmallVector<OpFoldResult> offsetsCopy = llvm::to_vector(offsets);
+ sizesCopy[dimension] = size;
+ offsetsCopy[dimension] = offset;
+
+ // Create the part as it it were a single tile.
+ SmallVector<Operation *> tiled =
+ op.getTiledImplementation(b, resultOperands, offsetsCopy, sizesCopy,
+ /*tileDestOperands=*/true);
+ assert(tiled.size() == 1 && "expected a single result from tiling");
+ auto part = cast<TilingInterface>(tiled.front());
+
+ // Insert the results back and populate the `results` list.
+ for (auto i : llvm::seq<unsigned>(0, part->getNumResults())) {
+ SmallVector<OpFoldResult> resultOffsets, resultSizes;
+ if (failed(op.getResultTilePosition(b, i, offsetsCopy, sizesCopy,
+ resultOffsets, resultSizes)))
+ return nullptr;
+ SmallVector<OpFoldResult> resultStrides(resultOffsets.size(),
+ b.getIndexAttr(1));
+ Value inserted = b.create<tensor::InsertSliceOp>(
+ loc, part->getResult(i), resultOperands[i], resultOffsets, resultSizes,
+ resultStrides);
+ results.push_back(inserted);
+ }
+
+ return part;
}
-std::pair<LinalgOp, LinalgOp> linalg::splitOp(RewriterBase &rewriter,
- LinalgOp op, unsigned dimension,
- OpFoldResult splitPoint) {
+std::pair<TilingInterface, TilingInterface>
+linalg::splitOp(RewriterBase &rewriter, TilingInterface op, unsigned dimension,
+ OpFoldResult splitPoint) {
+ // Compute the iteration space.
+ SmallVector<Range> iterationSpace = op.getIterationDomain(rewriter);
+
// Bail out on dimension overflow.
- if (dimension >= op.getNumLoops())
- return std::make_pair(op, LinalgOp());
-
- // Compute the iteration space size as values.
- SmallVector<Value, 4> allShapes =
- op.createFlatListOfOperandDims(rewriter, op.getLoc());
- AffineMap shapesToLoops = op.getShapesToLoopsMap();
- SmallVector<Value, 4> iterationSpaceShapes =
- applyMapToValues(rewriter, op.getLoc(), shapesToLoops, allShapes);
-
- // Update the iteration space to have `splitPoint` as the size of `dimension`
- // and use it to slice operands and results for a new, smaller instance of the
- // `op`. Adjust the size if necessary to prevent overflows. Insert the partial
- // results back.
- OpFoldResult dimSize = getAsOpFoldResult(iterationSpaceShapes[dimension]);
+ if (dimension >= iterationSpace.size())
+ return std::make_pair(op, TilingInterface());
+
+ SmallVector<OpFoldResult> offsets =
+ getAsOpFoldResult(llvm::to_vector(llvm::map_range(
+ iterationSpace, [](const Range &range) { return range.offset; })));
+ SmallVector<OpFoldResult> sizes =
+ getAsOpFoldResult(llvm::to_vector(llvm::map_range(
+ iterationSpace, [](const Range &range) { return range.size; })));
+
+ // Adjust the split point so that it doesn't overflow the size.
+ AffineExpr d0, d1, d2;
+ bindDims(rewriter.getContext(), d0, d1, d2);
OpFoldResult minSplitPoint = makeComposedFoldedAffineMin(
- rewriter, op->getLoc(),
- AffineMap::getMultiDimIdentityMap(/*numDims=*/2, rewriter.getContext()),
- {splitPoint, dimSize});
- SmallVector<Value> splitIterationSpace =
- llvm::to_vector(iterationSpaceShapes);
- SmallVector<Value> originalResults = llvm::to_vector(
- llvm::map_range(op.getOutputOperands(),
- [](OpOperand *opOperand) { return opOperand->get(); }));
- SmallVector<Value> firstResults;
- LinalgOp first = createSplitPart(rewriter, op.getLoc(), op, originalResults,
- splitIterationSpace, dimension,
- minSplitPoint, firstResults);
-
- // Update the iteration space to cover the remaining part of the original
- // space, then create another instance of the `op` in that space. The size of
- // the remaining part may become zero, but is never negative because of the
- // adjustment above.
- AffineExpr d0 = rewriter.getAffineDimExpr(0);
- AffineExpr d1 = rewriter.getAffineDimExpr(1);
+ rewriter, op.getLoc(),
+ AffineMap::inferFromExprList(ArrayRef<AffineExpr>{d0, d1 + d2}).front(),
+ {splitPoint, offsets[dimension], sizes[dimension]});
+
+ // Compute the size of the second part. Return early if the second part would
+ // have an empty iteration space.
OpFoldResult remainingSize = makeComposedFoldedAffineApply(
- rewriter, op.getLoc(), d0 - d1, {dimSize, minSplitPoint});
+ rewriter, op.getLoc(), d0 + d1 - d2,
+ {iterationSpace[dimension].offset, iterationSpace[dimension].size,
+ minSplitPoint});
+ if (auto attr = remainingSize.dyn_cast<Attribute>()) {
+ if (attr.cast<IntegerAttr>().getValue().isZero())
+ return {op, TilingInterface()};
+ }
+
+ // Create the first part.
+ SmallVector<Value> firstResults;
+ TilingInterface firstPart = createSplitPart(
+ rewriter, op.getLoc(), op, offsets, sizes,
+ op.getDestinationOperands(rewriter), dimension, minSplitPoint,
+ getAsOpFoldResult(iterationSpace[dimension].offset), firstResults);
+
+ // Need to pretend that the original op now takes as operands firstResults,
+ // otherwise tiling interface implementation will take the wrong value to
+ // produce data tiles.
+ rewriter.updateRootInPlace(op, [&]() {
+ unsigned numTotalOperands = op->getNumOperands();
+ unsigned numOutputOperands = firstResults.size();
+ op->setOperands(numTotalOperands - numOutputOperands, numOutputOperands,
+ firstResults);
+ });
+
+ // Create the second part.
+ OpFoldResult totalOffset = makeComposedFoldedAffineApply(
+ rewriter, op.getLoc(), d0 + d1, {offsets[dimension], minSplitPoint});
SmallVector<Value> secondResults;
- ImplicitLocOpBuilder implicit(op.getLoc(), rewriter);
- Value splitPointValue = materializeOpFoldResult(implicit, minSplitPoint);
- LinalgOp second = createSplitPart(
- rewriter, op.getLoc(), op, firstResults, splitIterationSpace, dimension,
- remainingSize, secondResults, splitPointValue);
-
- // Fixup the linalg.index results in the second part.
- SmallVector<Value> ivAdditions;
- ivAdditions.resize(splitIterationSpace.size());
- ivAdditions[dimension] = splitPointValue;
- linalg::offsetIndices(rewriter, cast<LinalgOp>(second), ivAdditions);
+ TilingInterface secondPart =
+ createSplitPart(rewriter, op.getLoc(), op, offsets, sizes, firstResults,
+ dimension, remainingSize, totalOffset, secondResults);
// Replace the original op with the results of the two newly created ops.
rewriter.replaceOp(op, secondResults);
- return std::make_pair(first, second);
+ return {firstPart, secondPart};
}