aboutsummaryrefslogtreecommitdiff
path: root/mlir/lib/Dialect/Linalg/Transforms
diff options
context:
space:
mode:
authorChristopher Bate <cbate@nvidia.com>2022-07-21 09:26:46 -0600
committerChristopher Bate <cbate@nvidia.com>2022-07-21 10:32:01 -0600
commit297ba167ded073a47dd9ea7e408aa95acdfcedf1 (patch)
tree9de4da963d098b2ced7988d3d3a4c80285903483 /mlir/lib/Dialect/Linalg/Transforms
parent9e16fb72dd7456502eecdc60b2b8b2ebec362c18 (diff)
[mlir][linalg] Add tile_size option to `structured.tile_to_foreach_thread_op`
This change modifies `structured.tile_to_foreach_thread_op` so that it accepts either `tile_sizes` or `num_threads` parameters. If `tile_sizes` are specified, then the number of threads required is derived the tile sizes rather than the other way around. In both cases, more aggressive folding of loop parameters is enabled during the transformation, allowing for the potential elimination of `affine.min` and `affine.max` operations in the static shape case when calculating the final adjusted tile size. Differential Revision: https://reviews.llvm.org/D130139
Diffstat (limited to 'mlir/lib/Dialect/Linalg/Transforms')
-rw-r--r--mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt2
-rw-r--r--mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp231
2 files changed, 152 insertions, 81 deletions
diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
index a8112dbe50b8..6d97dfc6d84f 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -41,9 +41,11 @@ add_mlir_dialect_library(MLIRLinalgTransforms
MLIRAnalysis
MLIRArithmeticDialect
MLIRArithmeticTransforms
+ MLIRArithmeticUtils
MLIRBufferizationDialect
MLIRBufferizationTransforms
MLIRComplexDialect
+ MLIRDialectUtils
MLIRFuncDialect
MLIRFuncToLLVM
MLIRFuncTransforms
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
index 0571ff5432af..1dfaf69efa72 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
@@ -13,6 +13,7 @@
#include <utility>
#include "PassDetail.h"
+#include "mlir/Dialect/Arithmetic/Utils/Utils.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Passes.h"
@@ -182,23 +183,43 @@ createMatchingParallelSubsetInsertOp(OpBuilder &b, Location loc,
}
/// Build an `affine_max` of all the `vals`.
-static Value buildMax(OpBuilder &b, Location loc, ValueRange vals) {
+static OpFoldResult buildMax(OpBuilder &b, Location loc,
+ ArrayRef<OpFoldResult> vals) {
+ SmallVector<Value> args = getValueOrCreateConstantIndexOp(b, loc, vals);
return b.createOrFold<AffineMaxOp>(
loc, AffineMap::getMultiDimIdentityMap(vals.size(), loc.getContext()),
- vals);
+ args);
}
-/// Build an `affine_min` of all the `vals`.
-static Value buildMin(OpBuilder &b, Location loc, ValueRange vals) {
- return b.createOrFold<AffineMinOp>(
- loc, AffineMap::getMultiDimIdentityMap(vals.size(), loc.getContext()),
- vals);
+/// Returns true if the maximum tile offset `tileSize * numThreads-1` is less
+/// than `iterationSize`.
+static bool canOmitTileOffsetInBoundsCheck(OpFoldResult tileSize,
+ OpFoldResult numThreads,
+ OpFoldResult iterationSize) {
+ Optional<int64_t> tileSizeConst = getConstantIntValue(tileSize);
+ Optional<int64_t> numThreadsConst = getConstantIntValue(numThreads);
+ Optional<int64_t> iterSizeConst = getConstantIntValue(iterationSize);
+ if (!tileSizeConst || !numThreadsConst || !iterSizeConst)
+ return false;
+ return *tileSizeConst * (*numThreadsConst - 1) < *iterSizeConst;
}
-FailureOr<ForeachThreadTilingResult>
-linalg::tileToForeachThreadOp(OpBuilder &b, TilingInterface op,
- ArrayRef<OpFoldResult> numThreads,
- ArrayRef<int64_t> threadDimMapping) {
+/// Rewrite a TilingInterface `op` to a tiled `scf.foreach_thread`. The
+/// tiling is specified by the number of tiles/threads `numThreads` and the
+/// optional nominal tile size `nominalTileSizes`. If `nominalTilSizes` is
+/// not specified, then it is derived from `numThreads` as `ceilDiv(dimSize[i],
+/// numThreads[i])`. If non-empty, the `threadDimMapping` is added as an
+/// attribute to the resulting `scf.foreach_thread`. A zero tile sizes indicate
+/// that the dimension is not tiled, and can be thought of as tiling by the full
+/// size of data.
+/// It is the user's responsibility to ensure that `numThreads` is a valid
+/// tiling specification (i.e. that only tiles parallel dimensions, e.g. in the
+/// Linalg case). If `omitTileOffsetBoundsCheck` is true, then the function will
+/// assume that `tileSize[i] * (numThread[i] -1) <= dimSize[i]` holds.
+static FailureOr<ForeachThreadTilingResult> tileToForeachThreadOpImpl(
+ RewriterBase &b, TilingInterface op, ArrayRef<OpFoldResult> numThreads,
+ Optional<ArrayRef<OpFoldResult>> nominalTileSizes,
+ ArrayRef<int64_t> threadDimMapping, bool omitTileOffsetBoundsCheck) {
Location loc = op->getLoc();
OpBuilder::InsertionGuard g(b);
SmallVector<Range> loopRanges = op.getIterationDomain(b);
@@ -224,80 +245,128 @@ linalg::tileToForeachThreadOp(OpBuilder &b, TilingInterface op,
Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
Operation *tiledOp = nullptr;
+
+ // Create the ForeachThreadOp. We don't use the lambda body-builder
+ // version because we require the use of RewriterBase in the body, so we
+ // manually move the insertion point to the body below.
scf::ForeachThreadOp foreachThreadOp = b.create<scf::ForeachThreadOp>(
- loc, materializedNonZeroNumThreads, threadDimMapping,
- [&](OpBuilder &b, Location loc, ValueRange threadIds) {
- int64_t nLoops = loopRanges.size();
- SmallVector<OpFoldResult> tiledOffsets, tiledSizes;
- tiledOffsets.reserve(nLoops);
- tiledSizes.reserve(nLoops);
- for (unsigned loopIdx = 0, threadIdIdx = 0; loopIdx < nLoops;
- ++loopIdx) {
- bool overflow = loopIdx >= numThreads.size();
- bool isZero = !overflow && isConstantIntValue(numThreads[loopIdx], 0);
- // Degenerate case: take the whole domain.
- if (overflow || isZero) {
- tiledOffsets.push_back(loopRanges[loopIdx].offset);
- tiledSizes.push_back(loopRanges[loopIdx].size);
- continue;
- }
-
- // Tiled case: compute the offset and size.
- AffineExpr i, j, M, N, O;
- bindDims(b.getContext(), i, j);
- bindSymbols(b.getContext(), M, N, O);
- Value size = loopRanges[loopIdx].size;
- Value offset = loopRanges[loopIdx].offset;
- Value threadId = threadIds[threadIdIdx];
- // TODO: more aggressive foldings.
- // Symbolic fixed max size per thread.
- // TODO: floor + 0/1 depending on case for better load-balancing.
- Value maxSizePerThread = b.createOrFold<AffineApplyOp>(
- loc, M.ceilDiv(N),
- ValueRange{size, materializedNonZeroNumThreads[threadIdIdx]});
- // Dynamic offset shifted by threadId * maxSizePerThread.
- Value offsetPerThread = b.createOrFold<AffineApplyOp>(
- loc, i + j * M, ValueRange{offset, threadId, maxSizePerThread});
- // Dynamic upper-bound depending on the threadId.
- Value sizeMinusOffsetPerThread = b.createOrFold<AffineApplyOp>(
- loc, -i + M, ValueRange{offsetPerThread, size});
- Value tileSizePerThread = buildMin(
- b, loc, ValueRange{sizeMinusOffsetPerThread, maxSizePerThread});
- tiledOffsets.push_back(offsetPerThread);
- // TODO: if tileSizePerThread <= 0 early exit.
- tiledSizes.push_back(
- buildMax(b, loc, ValueRange{zero, tileSizePerThread}));
- ++threadIdIdx;
- }
-
- SmallVector<Operation *> tiledOps =
- op.getTiledImplementation(b, destOperands, tiledOffsets, tiledSizes,
- /*tileDestOperands=*/true);
- assert(tiledOps.size() == 1 && "expected a single produced tiled op");
- tiledOp = tiledOps.front();
-
- auto tilingInterfaceOp = dyn_cast<TilingInterface>(tiledOp);
- assert(tilingInterfaceOp &&
- "Tiled op does not implement TilingInterface");
-
- auto tiledDestOperands = tilingInterfaceOp.getDestinationOperands(b);
-
- // Create terminator with parallel subset insert operations.
- auto performConcurrentlyOp = b.create<scf::PerformConcurrentlyOp>(loc);
- OpBuilder::InsertionGuard g(b);
- b.setInsertionPointToStart(performConcurrentlyOp.getBody());
- for (auto it :
- llvm::zip(tiledDestOperands, tilingInterfaceOp->getResults(),
- destOperands)) {
- createMatchingParallelSubsetInsertOp(
- b, loc,
- cast<tensor::ExtractSliceOp>(std::get<0>(it).getDefiningOp()),
- std::get<1>(it), std::get<2>(it));
- }
- });
+ loc, op->getResultTypes(), ValueRange(materializedNonZeroNumThreads),
+ threadDimMapping);
+
+ // Fill out the ForeachThreadOp body.
+ b.setInsertionPointToStart(foreachThreadOp.getBody(0));
+ ValueRange threadIds = foreachThreadOp.getThreadIndices();
+ int64_t nLoops = loopRanges.size();
+ SmallVector<OpFoldResult> tiledOffsets, tiledSizes;
+ tiledOffsets.reserve(nLoops);
+ tiledSizes.reserve(nLoops);
+ for (unsigned loopIdx = 0, threadIdIdx = 0; loopIdx < nLoops; ++loopIdx) {
+ bool overflow = loopIdx >= numThreads.size();
+ bool isZero = !overflow && isConstantIntValue(numThreads[loopIdx], 0);
+ // Degenerate case: take the whole domain.
+ if (overflow || isZero) {
+ tiledOffsets.push_back(loopRanges[loopIdx].offset);
+ tiledSizes.push_back(loopRanges[loopIdx].size);
+ continue;
+ }
+
+ // Tiled case: compute the offset and size.
+ AffineExpr i, j, M, N, O;
+ bindDims(b.getContext(), i, j);
+ bindSymbols(b.getContext(), M, N, O);
+ Value size = loopRanges[loopIdx].size;
+ Value offset = loopRanges[loopIdx].offset;
+ Value threadId = threadIds[threadIdIdx];
+ // Symbolic fixed max size per thread.
+ // TODO: floor + 0/1 depending on case for better load-balancing.
+ OpFoldResult tileSizePerThread =
+ nominalTileSizes.hasValue()
+ ? (*nominalTileSizes)[loopIdx]
+ : makeComposedFoldedAffineApply(
+ b, loc, M.ceilDiv(N),
+ ArrayRef<OpFoldResult>{size, nonZeroNumThreads[threadIdIdx]});
+
+ // Dynamic offset shifted by threadId * maxSizePerThread.
+ OpFoldResult offsetPerThread = makeComposedFoldedAffineApply(
+ b, loc, i + j * M, {offset, threadId, tileSizePerThread});
+ // Dynamic upper-bound depending on the threadId.
+ OpFoldResult residualTileSize = makeComposedFoldedAffineApply(
+ b, loc, i + j * M - N,
+ {offset, nonZeroNumThreads[threadIdIdx], tileSizePerThread, size});
+ if (!isConstantIntValue(residualTileSize, 0)) {
+ OpFoldResult sizeMinusOffsetPerThread = makeComposedFoldedAffineApply(
+ b, loc, -i + M, {offsetPerThread, size});
+ tileSizePerThread = makeComposedFoldedAffineMin(
+ b, loc, AffineMap::getMultiDimIdentityMap(2, b.getContext()),
+ ArrayRef<OpFoldResult>{sizeMinusOffsetPerThread, tileSizePerThread});
+ }
+
+ tiledOffsets.push_back(offsetPerThread);
+ // TODO: if tileSizePerThread <= 0 early exit.
+ if (!omitTileOffsetBoundsCheck &&
+ !canOmitTileOffsetInBoundsCheck(tileSizePerThread,
+ nonZeroNumThreads[threadIdIdx], size))
+ tileSizePerThread = buildMax(b, loc, {zero, tileSizePerThread});
+
+ tiledSizes.push_back(tileSizePerThread);
+ ++threadIdIdx;
+ }
+
+ SmallVector<Operation *> tiledOps =
+ op.getTiledImplementation(b, destOperands, tiledOffsets, tiledSizes,
+ /*tileDestOperands=*/true);
+ assert(tiledOps.size() == 1 && "expected a single produced tiled op");
+ tiledOp = tiledOps.front();
+
+ auto tilingInterfaceOp = dyn_cast<TilingInterface>(tiledOp);
+ assert(tilingInterfaceOp && "Tiled op does not implement TilingInterface");
+
+ auto tiledDestOperands = tilingInterfaceOp.getDestinationOperands(b);
+
+ // Create terminator with parallel subset insert operations.
+ b.setInsertionPointToStart(foreachThreadOp.getTerminator().getBody());
+ for (auto it : llvm::zip(tiledDestOperands, tilingInterfaceOp->getResults(),
+ destOperands)) {
+ createMatchingParallelSubsetInsertOp(
+ b, loc, cast<tensor::ExtractSliceOp>(std::get<0>(it).getDefiningOp()),
+ std::get<1>(it), std::get<2>(it));
+ }
return ForeachThreadTilingResult{foreachThreadOp, tiledOp};
}
+FailureOr<ForeachThreadTilingResult>
+linalg::tileToForeachThreadOp(RewriterBase &b, TilingInterface op,
+ ArrayRef<OpFoldResult> numThreads,
+ ArrayRef<int64_t> threadDimMapping) {
+ return tileToForeachThreadOpImpl(b, op, numThreads, /*nominalTileSizes=*/None,
+ threadDimMapping,
+ /*omitTileOffsetBoundsCheck=*/false);
+}
+
+FailureOr<ForeachThreadTilingResult>
+linalg::tileToForeachThreadOpUsingTileSizes(
+ RewriterBase &b, TilingInterface op, ArrayRef<OpFoldResult> tileSizes,
+ ArrayRef<int64_t> threadDimMapping) {
+ SmallVector<Range> loopRanges = op.getIterationDomain(b);
+ unsigned nLoops = loopRanges.size();
+ SmallVector<OpFoldResult> numThreads;
+ numThreads.reserve(nLoops);
+ AffineExpr s0, s1;
+ bindSymbols(b.getContext(), s0, s1);
+ AffineExpr divExpr = s0.ceilDiv(s1);
+ for (const auto &it : llvm::zip(tileSizes, loopRanges)) {
+ OpFoldResult numTiles = std::get<0>(it);
+ if (!isConstantIntValue(numTiles, 0))
+ numTiles = makeComposedFoldedAffineApply(
+ b, op.getLoc(), divExpr, {std::get<1>(it).size, std::get<0>(it)});
+ numThreads.push_back(numTiles);
+ }
+ return tileToForeachThreadOpImpl(b, op, numThreads,
+ /*nominalTileSizes=*/tileSizes,
+ threadDimMapping,
+ /*omitTileOffsetBoundsCheck=*/true);
+}
+
// Insert a tile `source` into the destination tensor `dest`. The position at
// which the tile is inserted (as well as size of tile) is taken from a given
// ExtractSliceOp `sliceOp`.