From 297ba167ded073a47dd9ea7e408aa95acdfcedf1 Mon Sep 17 00:00:00 2001
From: Christopher Bate
Date: Thu, 21 Jul 2022 09:26:46 -0600
Subject: [mlir][linalg] Add tile_size option to `structured.tile_to_foreach_thread_op`

This change modifies `structured.tile_to_foreach_thread_op` so that it accepts
either `tile_sizes` or `num_threads` parameters. If `tile_sizes` are specified,
then the number of threads required is derived from the tile sizes rather than
the other way around. In both cases, more aggressive folding of loop parameters
is enabled during the transformation, allowing for the potential elimination of
`affine.min` and `affine.max` operations in the static shape case when
calculating the final adjusted tile size.

Differential Revision: https://reviews.llvm.org/D130139
---
 mlir/lib/Dialect/Affine/IR/AffineOps.cpp           |   9 +-
 .../Linalg/TransformOps/LinalgTransformOps.cpp     |  20 +-
 mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt  |   2 +
 mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp      | 231 +++++++++++++--------
 4 files changed, 173 insertions(+), 89 deletions(-)

(limited to 'mlir/lib/Dialect')

diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index d2dc66d61078..345e27742fbb 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -721,8 +721,13 @@ static void materializeConstants(OpBuilder &b, Location loc,
       actualValues.push_back(value);
       continue;
     }
-    constants.push_back(dialect->materializeConstant(b, ofr.get<Attribute>(),
-                                                     b.getIndexType(), loc));
+    // Since we are directly specifying `index` as the result type, we need to
+    // ensure the provided attribute is also an index type. Otherwise, the
+    // AffineDialect materializer will create invalid `arith.constant`
+    // operations if the provided Attribute is any other kind of integer.
+    constants.push_back(dialect->materializeConstant(
+        b, b.getIndexAttr(ofr.get<Attribute>().cast<IntegerAttr>().getInt()),
+        b.getIndexType(), loc));
     actualValues.push_back(constants.back()->getResult(0));
   }
 }
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index adab8da3d518..070b1fc4eb82 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -909,12 +909,20 @@ DiagnosedSilenceableFailure transform::TileToForeachThreadOp::applyToOne(
   IRRewriter rewriter(getContext());
   rewriter.setInsertionPoint(target);
   auto maybeThreadDimMappingAttr = getThreadDimMapping();
-  FailureOr<ForeachThreadTilingResult> tilingResult =
-      linalg::tileToForeachThreadOp(
-          rewriter, target, getAsOpFoldResult(getNumThreads()),
-          maybeThreadDimMappingAttr
-              ? extractFromI64ArrayAttr(*maybeThreadDimMappingAttr)
-              : ArrayRef<int64_t>{});
+  auto dimMapping =
+      llvm::to_vector(maybeThreadDimMappingAttr
+                          ? extractFromI64ArrayAttr(*maybeThreadDimMappingAttr)
+                          : ArrayRef<int64_t>{});
+
+  FailureOr<ForeachThreadTilingResult> tilingResult = failure();
+  if (Optional<ArrayAttr> numThreads = getNumThreads())
+    tilingResult = linalg::tileToForeachThreadOp(
+        rewriter, target, getAsOpFoldResult(*numThreads), dimMapping);
+
+  if (Optional<ArrayAttr> tileSizes = getTileSizes())
+    tilingResult = linalg::tileToForeachThreadOpUsingTileSizes(
+        rewriter, target, getAsOpFoldResult(*tileSizes), dimMapping);
+
   if (failed(tilingResult))
     return emitDefaultSilenceableFailure(target);
   rewriter.replaceOp(target, tilingResult->tileOp->getResults());
diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
index a8112dbe50b8..6d97dfc6d84f 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -41,9 +41,11 @@ add_mlir_dialect_library(MLIRLinalgTransforms
     MLIRAnalysis
     MLIRArithmeticDialect
     MLIRArithmeticTransforms
+    MLIRArithmeticUtils
     MLIRBufferizationDialect
     MLIRBufferizationTransforms
     MLIRComplexDialect
+    MLIRDialectUtils
     MLIRFuncDialect
     MLIRFuncToLLVM
     MLIRFuncTransforms
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
index 0571ff5432af..1dfaf69efa72 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
@@ -13,6 +13,7 @@
 #include <utility>

 #include "PassDetail.h"
+#include "mlir/Dialect/Arithmetic/Utils/Utils.h"
 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/Linalg/Passes.h"
@@ -182,23 +183,43 @@ createMatchingParallelSubsetInsertOp(OpBuilder &b, Location loc,
 }

 /// Build an `affine_max` of all the `vals`.
-static Value buildMax(OpBuilder &b, Location loc, ValueRange vals) {
+static OpFoldResult buildMax(OpBuilder &b, Location loc,
+                             ArrayRef<OpFoldResult> vals) {
+  SmallVector<Value> args = getValueOrCreateConstantIndexOp(b, loc, vals);
   return b.createOrFold<AffineMaxOp>(
       loc, AffineMap::getMultiDimIdentityMap(vals.size(), loc.getContext()),
-      vals);
+      args);
 }

-/// Build an `affine_min` of all the `vals`.
-static Value buildMin(OpBuilder &b, Location loc, ValueRange vals) {
-  return b.createOrFold<AffineMinOp>(
-      loc, AffineMap::getMultiDimIdentityMap(vals.size(), loc.getContext()),
-      vals);
+/// Returns true if the maximum tile offset `tileSize * (numThreads - 1)` is
+/// less than `iterationSize`.
+static bool canOmitTileOffsetInBoundsCheck(OpFoldResult tileSize,
+                                           OpFoldResult numThreads,
+                                           OpFoldResult iterationSize) {
+  Optional<int64_t> tileSizeConst = getConstantIntValue(tileSize);
+  Optional<int64_t> numThreadsConst = getConstantIntValue(numThreads);
+  Optional<int64_t> iterSizeConst = getConstantIntValue(iterationSize);
+  if (!tileSizeConst || !numThreadsConst || !iterSizeConst)
+    return false;
+  return *tileSizeConst * (*numThreadsConst - 1) < *iterSizeConst;
 }

-FailureOr<ForeachThreadTilingResult>
-linalg::tileToForeachThreadOp(OpBuilder &b, TilingInterface op,
-                              ArrayRef<OpFoldResult> numThreads,
-                              ArrayRef<int64_t> threadDimMapping) {
+/// Rewrite a TilingInterface `op` to a tiled `scf.foreach_thread`. The
+/// tiling is specified by the number of tiles/threads `numThreads` and the
+/// optional nominal tile size `nominalTileSizes`. If `nominalTileSizes` is
+/// not specified, then it is derived from `numThreads` as
+/// `ceilDiv(dimSize[i], numThreads[i])`. If non-empty, `threadDimMapping` is
+/// added as an attribute to the resulting `scf.foreach_thread`.
+/// A zero tile size indicates that the dimension is not tiled and can be
+/// thought of as tiling by the full size of data.
+/// It is the user's responsibility to ensure that `numThreads` is a valid
+/// tiling specification, i.e. that it only tiles parallel dimensions (e.g. in
+/// the Linalg case). If `omitTileOffsetBoundsCheck` is true, the function
+/// assumes that `tileSize[i] * (numThreads[i] - 1) <= dimSize[i]` holds.
+static FailureOr<ForeachThreadTilingResult> tileToForeachThreadOpImpl(
+    RewriterBase &b, TilingInterface op, ArrayRef<OpFoldResult> numThreads,
+    Optional<ArrayRef<OpFoldResult>> nominalTileSizes,
+    ArrayRef<int64_t> threadDimMapping, bool omitTileOffsetBoundsCheck) {
   Location loc = op->getLoc();
   OpBuilder::InsertionGuard g(b);
   SmallVector<Range> loopRanges = op.getIterationDomain(b);
@@ -224,80 +245,128 @@ linalg::tileToForeachThreadOp(OpBuilder &b, TilingInterface op,

   Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
   Operation *tiledOp = nullptr;
+
+  // Create the ForeachThreadOp. We don't use the lambda body-builder
+  // version because we require the use of RewriterBase in the body, so we
+  // manually move the insertion point to the body below.
   scf::ForeachThreadOp foreachThreadOp = b.create<scf::ForeachThreadOp>(
-      loc, materializedNonZeroNumThreads, threadDimMapping,
-      [&](OpBuilder &b, Location loc, ValueRange threadIds) {
-        int64_t nLoops = loopRanges.size();
-        SmallVector<Value> tiledOffsets, tiledSizes;
-        tiledOffsets.reserve(nLoops);
-        tiledSizes.reserve(nLoops);
-        for (unsigned loopIdx = 0, threadIdIdx = 0; loopIdx < nLoops;
-             ++loopIdx) {
-          bool overflow = loopIdx >= numThreads.size();
-          bool isZero = !overflow && isConstantIntValue(numThreads[loopIdx], 0);
-          // Degenerate case: take the whole domain.
-          if (overflow || isZero) {
-            tiledOffsets.push_back(loopRanges[loopIdx].offset);
-            tiledSizes.push_back(loopRanges[loopIdx].size);
-            continue;
-          }
-
-          // Tiled case: compute the offset and size.
-          AffineExpr i, j, M, N, O;
-          bindDims(b.getContext(), i, j);
-          bindSymbols(b.getContext(), M, N, O);
-          Value size = loopRanges[loopIdx].size;
-          Value offset = loopRanges[loopIdx].offset;
-          Value threadId = threadIds[threadIdIdx];
-          // TODO: more aggressive foldings.
-          // Symbolic fixed max size per thread.
-          // TODO: floor + 0/1 depending on case for better load-balancing.
-          Value maxSizePerThread = b.createOrFold<AffineApplyOp>(
-              loc, M.ceilDiv(N),
-              ValueRange{size, materializedNonZeroNumThreads[threadIdIdx]});
-          // Dynamic offset shifted by threadId * maxSizePerThread.
-          Value offsetPerThread = b.createOrFold<AffineApplyOp>(
-              loc, i + j * M, ValueRange{offset, threadId, maxSizePerThread});
-          // Dynamic upper-bound depending on the threadId.
-          Value sizeMinusOffsetPerThread = b.createOrFold<AffineApplyOp>(
-              loc, -i + M, ValueRange{offsetPerThread, size});
-          Value tileSizePerThread = buildMin(
-              b, loc, ValueRange{sizeMinusOffsetPerThread, maxSizePerThread});
-          tiledOffsets.push_back(offsetPerThread);
-          // TODO: if tileSizePerThread <= 0 early exit.
-          tiledSizes.push_back(
-              buildMax(b, loc, ValueRange{zero, tileSizePerThread}));
-          ++threadIdIdx;
-        }
-
-        SmallVector<Operation *> tiledOps =
-            op.getTiledImplementation(b, destOperands, tiledOffsets, tiledSizes,
-                                      /*tileDestOperands=*/true);
-        assert(tiledOps.size() == 1 && "expected a single produced tiled op");
-        tiledOp = tiledOps.front();
-
-        auto tilingInterfaceOp = dyn_cast<TilingInterface>(tiledOp);
-        assert(tilingInterfaceOp &&
-               "Tiled op does not implement TilingInterface");
-
-        auto tiledDestOperands = tilingInterfaceOp.getDestinationOperands(b);
-
-        // Create terminator with parallel subset insert operations.
-        auto performConcurrentlyOp = b.create<scf::PerformConcurrentlyOp>(loc);
-        OpBuilder::InsertionGuard g(b);
-        b.setInsertionPointToStart(performConcurrentlyOp.getBody());
-        for (auto it :
-             llvm::zip(tiledDestOperands, tilingInterfaceOp->getResults(),
-                       destOperands)) {
-          createMatchingParallelSubsetInsertOp(
-              b, loc,
-              cast<tensor::ExtractSliceOp>(std::get<0>(it).getDefiningOp()),
-              std::get<1>(it), std::get<2>(it));
-        }
-      });
+      loc, op->getResultTypes(), ValueRange(materializedNonZeroNumThreads),
+      threadDimMapping);
+
+  // Fill out the ForeachThreadOp body.
+  b.setInsertionPointToStart(foreachThreadOp.getBody(0));
+  ValueRange threadIds = foreachThreadOp.getThreadIndices();
+  int64_t nLoops = loopRanges.size();
+  SmallVector<OpFoldResult> tiledOffsets, tiledSizes;
+  tiledOffsets.reserve(nLoops);
+  tiledSizes.reserve(nLoops);
+  for (unsigned loopIdx = 0, threadIdIdx = 0; loopIdx < nLoops; ++loopIdx) {
+    bool overflow = loopIdx >= numThreads.size();
+    bool isZero = !overflow && isConstantIntValue(numThreads[loopIdx], 0);
+    // Degenerate case: take the whole domain.
+    if (overflow || isZero) {
+      tiledOffsets.push_back(loopRanges[loopIdx].offset);
+      tiledSizes.push_back(loopRanges[loopIdx].size);
+      continue;
+    }
+
+    // Tiled case: compute the offset and size.
+    AffineExpr i, j, M, N, O;
+    bindDims(b.getContext(), i, j);
+    bindSymbols(b.getContext(), M, N, O);
+    Value size = loopRanges[loopIdx].size;
+    Value offset = loopRanges[loopIdx].offset;
+    Value threadId = threadIds[threadIdIdx];
+    // Symbolic fixed max size per thread.
+    // TODO: floor + 0/1 depending on case for better load-balancing.
+    OpFoldResult tileSizePerThread =
+        nominalTileSizes.hasValue()
+            ? (*nominalTileSizes)[loopIdx]
+            : makeComposedFoldedAffineApply(
+                  b, loc, M.ceilDiv(N),
+                  ArrayRef<OpFoldResult>{size, nonZeroNumThreads[threadIdIdx]});
+
+    // Dynamic offset shifted by threadId * maxSizePerThread.
+    OpFoldResult offsetPerThread = makeComposedFoldedAffineApply(
+        b, loc, i + j * M, {offset, threadId, tileSizePerThread});
+    // Dynamic upper-bound depending on the threadId.
+    OpFoldResult residualTileSize = makeComposedFoldedAffineApply(
+        b, loc, i + j * M - N,
+        {offset, nonZeroNumThreads[threadIdIdx], tileSizePerThread, size});
+    if (!isConstantIntValue(residualTileSize, 0)) {
+      OpFoldResult sizeMinusOffsetPerThread = makeComposedFoldedAffineApply(
+          b, loc, -i + M, {offsetPerThread, size});
+      tileSizePerThread = makeComposedFoldedAffineMin(
+          b, loc, AffineMap::getMultiDimIdentityMap(2, b.getContext()),
+          ArrayRef<OpFoldResult>{sizeMinusOffsetPerThread, tileSizePerThread});
+    }
+
+    tiledOffsets.push_back(offsetPerThread);
+    // TODO: if tileSizePerThread <= 0 early exit.
+    if (!omitTileOffsetBoundsCheck &&
+        !canOmitTileOffsetInBoundsCheck(tileSizePerThread,
+                                        nonZeroNumThreads[threadIdIdx], size))
+      tileSizePerThread = buildMax(b, loc, {zero, tileSizePerThread});
+
+    tiledSizes.push_back(tileSizePerThread);
+    ++threadIdIdx;
+  }
+
+  SmallVector<Operation *> tiledOps =
+      op.getTiledImplementation(b, destOperands, tiledOffsets, tiledSizes,
+                                /*tileDestOperands=*/true);
+  assert(tiledOps.size() == 1 && "expected a single produced tiled op");
+  tiledOp = tiledOps.front();
+
+  auto tilingInterfaceOp = dyn_cast<TilingInterface>(tiledOp);
+  assert(tilingInterfaceOp && "Tiled op does not implement TilingInterface");
+
+  auto tiledDestOperands = tilingInterfaceOp.getDestinationOperands(b);
+
+  // Create terminator with parallel subset insert operations.
+  b.setInsertionPointToStart(foreachThreadOp.getTerminator().getBody());
+  for (auto it : llvm::zip(tiledDestOperands, tilingInterfaceOp->getResults(),
+                           destOperands)) {
+    createMatchingParallelSubsetInsertOp(
+        b, loc, cast<tensor::ExtractSliceOp>(std::get<0>(it).getDefiningOp()),
+        std::get<1>(it), std::get<2>(it));
+  }
   return ForeachThreadTilingResult{foreachThreadOp, tiledOp};
 }

+FailureOr<ForeachThreadTilingResult>
+linalg::tileToForeachThreadOp(RewriterBase &b, TilingInterface op,
+                              ArrayRef<OpFoldResult> numThreads,
+                              ArrayRef<int64_t> threadDimMapping) {
+  return tileToForeachThreadOpImpl(b, op, numThreads,
+                                   /*nominalTileSizes=*/None, threadDimMapping,
+                                   /*omitTileOffsetBoundsCheck=*/false);
+}
+
+FailureOr<ForeachThreadTilingResult>
+linalg::tileToForeachThreadOpUsingTileSizes(
+    RewriterBase &b, TilingInterface op, ArrayRef<OpFoldResult> tileSizes,
+    ArrayRef<int64_t> threadDimMapping) {
+  SmallVector<Range> loopRanges = op.getIterationDomain(b);
+  unsigned nLoops = loopRanges.size();
+  SmallVector<OpFoldResult> numThreads;
+  numThreads.reserve(nLoops);
+  AffineExpr s0, s1;
+  bindSymbols(b.getContext(), s0, s1);
+  AffineExpr divExpr = s0.ceilDiv(s1);
+  for (const auto &it : llvm::zip(tileSizes, loopRanges)) {
+    OpFoldResult numTiles = std::get<0>(it);
+    if (!isConstantIntValue(numTiles, 0))
+      numTiles = makeComposedFoldedAffineApply(
+          b, op.getLoc(), divExpr, {std::get<1>(it).size, std::get<0>(it)});
+    numThreads.push_back(numTiles);
+  }
+  return tileToForeachThreadOpImpl(b, op, numThreads,
+                                   /*nominalTileSizes=*/tileSizes,
+                                   threadDimMapping,
+                                   /*omitTileOffsetBoundsCheck=*/true);
+}
+
 // Insert a tile `source` into the destination tensor `dest`. The position at
 // which the tile is inserted (as well as size of tile) is taken from a given
 // ExtractSliceOp `sliceOp`.
-- 
cgit v1.2.3
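
For reference, the sketch below is not part of the patch; it only illustrates how the new `linalg::tileToForeachThreadOpUsingTileSizes` entry point added to Tiling.cpp might be called from C++. Only that function and its signature come from the diff above; the helper name `tileWithFixedTileSizes`, the chosen tile sizes, and the thread mapping are illustrative assumptions.

// Hypothetical caller (not part of the patch) exercising the new
// tileToForeachThreadOpUsingTileSizes entry point.
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/TilingInterface.h"

using namespace mlir;

static FailureOr<linalg::ForeachThreadTilingResult>
tileWithFixedTileSizes(RewriterBase &rewriter, TilingInterface op) {
  rewriter.setInsertionPoint(op.getOperation());
  // Nominal tile sizes per loop; the number of threads per dimension is
  // derived inside the transform as ceilDiv(dimSize[i], tileSizes[i]).
  SmallVector<OpFoldResult> tileSizes = {rewriter.getIndexAttr(16),
                                         rewriter.getIndexAttr(32)};
  // Optional mapping of tiled loops to thread dimensions.
  SmallVector<int64_t> threadDimMapping = {1, 0};
  return linalg::tileToForeachThreadOpUsingTileSizes(rewriter, op, tileSizes,
                                                     threadDimMapping);
}

In this mode the implementation sets `omitTileOffsetBoundsCheck` to true, so the `affine.max` clamp on the per-thread tile size is skipped, which is part of the more aggressive folding described in the commit message.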