diff options
author | Alexander Belyaev <pifon@google.com> | 2022-08-04 11:18:04 +0200 |
---|---|---|
committer | Alexander Belyaev <pifon@google.com> | 2022-08-04 11:23:58 +0200 |
commit | 56d94b3b902e21ff79b1ce9a6fb606a3f7c1c4db (patch) | |
tree | 95074dc12ea4308d94f4bc239ba939aa52bc40a3 /mlir/lib/Dialect | |
parent | 57a9bccec7dea036dbfa1a78f1ec5e73ecf7a33c (diff) |
[mlir] Extract offsets-sizes-strides computation from `makeTiledShape(s)`.
This change separates the computation of the subset parameters (offsets, sizes,
and strides) from the materialization of the subview/extract_slice op. That way,
users can still reuse the Linalg tiling logic even if they materialize the
subsets with different operations.
Differential Revision: https://reviews.llvm.org/D131053
Diffstat (limited to 'mlir/lib/Dialect')
-rw-r--r-- | mlir/lib/Dialect/Linalg/Utils/Utils.cpp | 129 |
1 file changed, 86 insertions, 43 deletions
diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp index 4f14164bf26c..0259f9a542f6 100644 --- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp +++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp @@ -802,28 +802,61 @@ void GenerateLoopNest<scf::ParallelOp>::doit( assert(ivs.size() == iteratorTypes.size() && "did not generate enough loops"); } +static Value materializeTiledShape(OpBuilder &builder, Location loc, + Value valueToTile, + const SliceParameters &sliceParams) { + auto shapedType = valueToTile.getType().dyn_cast<ShapedType>(); + auto *sliceOp = TypeSwitch<ShapedType, Operation *>(shapedType) + .Case([&](MemRefType) { + return builder.create<memref::SubViewOp>( + loc, valueToTile, sliceParams.offsets, + sliceParams.sizes, sliceParams.strides); + }) + .Case([&](RankedTensorType) { + return makeComposedExtractSliceOp( + builder, loc, valueToTile, sliceParams.offsets, + sliceParams.sizes, sliceParams.strides); + }) + .Default([](ShapedType) -> Operation * { + llvm_unreachable("Unexpected shaped type"); + }); + return sliceOp->getResult(0); +} + Value makeTiledShape(OpBuilder &builder, Location loc, Value valueToTile, ArrayRef<OpFoldResult> tileSizes, AffineMap map, ArrayRef<OpFoldResult> lbs, ArrayRef<OpFoldResult> ubs, ArrayRef<OpFoldResult> subShapeSizes, bool omitPartialTileCheck) { + SliceParameters sliceParams = + computeSliceParameters(builder, loc, valueToTile, tileSizes, map, lbs, + ubs, subShapeSizes, omitPartialTileCheck); + return materializeTiledShape(builder, loc, valueToTile, sliceParams); +} + +SliceParameters +computeSliceParameters(OpBuilder &builder, Location loc, Value valueToTile, + ArrayRef<OpFoldResult> tileSizes, AffineMap map, + ArrayRef<OpFoldResult> lbs, ArrayRef<OpFoldResult> ubs, + ArrayRef<OpFoldResult> subShapeSizes, + bool omitPartialTileCheck) { auto shapedType = valueToTile.getType().dyn_cast<ShapedType>(); assert(shapedType && "only shaped types can be tiled"); ArrayRef<int64_t> shape = 
shapedType.getShape(); int64_t rank = shapedType.getRank(); // Construct a new subview / extract_slice for the tile. - SmallVector<OpFoldResult, 4> offsets, sizes, strides; - offsets.reserve(rank); - sizes.reserve(rank); - strides.reserve(rank); + SliceParameters sliceParams; + sliceParams.offsets.reserve(rank); + sliceParams.sizes.reserve(rank); + sliceParams.strides.reserve(rank); for (unsigned r = 0; r < rank; ++r) { - LLVM_DEBUG(llvm::dbgs() << "makeTiledShape: for dim#" << r); + LLVM_DEBUG(llvm::dbgs() << "computeSliceParameters: for dim#" << r); if (!isTiled(map.getSubMap({r}), tileSizes)) { - offsets.push_back(builder.getIndexAttr(0)); + sliceParams.offsets.push_back(builder.getIndexAttr(0)); OpFoldResult dim = createFoldedDimOp(builder, loc, valueToTile, r); - sizes.push_back(dim); - strides.push_back(builder.getIndexAttr(1)); + sliceParams.sizes.push_back(dim); + sliceParams.strides.push_back(builder.getIndexAttr(1)); LLVM_DEBUG(llvm::dbgs() << ": not tiled: use size: " << dim << "\n"); continue; } @@ -832,26 +865,27 @@ Value makeTiledShape(OpBuilder &builder, Location loc, Value valueToTile, // Tiling creates a new slice at the proper index, the slice step is 1 // (i.e. the op does not subsample, stepping occurs in the loop). auto m = map.getSubMap({r}); - LLVM_DEBUG(llvm::dbgs() << "makeTiledShape: submap: " << m << "\n"); + LLVM_DEBUG(llvm::dbgs() << "computeSliceParameters: submap: " << m << "\n"); IRRewriter rewriter(builder); OpFoldResult offset = makeComposedFoldedAffineApply(rewriter, loc, m, lbs); - offsets.push_back(offset); + sliceParams.offsets.push_back(offset); OpFoldResult closedIntSize = makeComposedFoldedAffineApply(rewriter, loc, m, subShapeSizes); // Resulting size needs to be made half open interval again. 
AffineExpr s0 = getAffineSymbolExpr(0, builder.getContext()); OpFoldResult size = makeComposedFoldedAffineApply(rewriter, loc, s0 + 1, closedIntSize); - LLVM_DEBUG(llvm::dbgs() << "makeTiledShape: raw size: " << size << "\n"); LLVM_DEBUG(llvm::dbgs() - << "makeTiledShape: new offset: " << offset << "\n"); - strides.push_back(builder.getIndexAttr(1)); + << "computeSliceParameters: raw size: " << size << "\n"); + LLVM_DEBUG(llvm::dbgs() + << "computeSliceParameters: new offset: " << offset << "\n"); + sliceParams.strides.push_back(builder.getIndexAttr(1)); if (omitPartialTileCheck) { // We statically know that the partial/boundary tile condition is // unnecessary. LLVM_DEBUG(llvm::dbgs() << "makeTiledShape: new size: " << size << "\n"); - sizes.push_back(size); + sliceParams.sizes.push_back(size); continue; } @@ -903,22 +937,9 @@ Value makeTiledShape(OpBuilder &builder, Location loc, Value valueToTile, makeComposedFoldedAffineMin(rewriter, loc, minMap, {size, d, offset}); } LLVM_DEBUG(llvm::dbgs() << "makeTiledShape: new size: " << size << "\n"); - sizes.push_back(size); + sliceParams.sizes.push_back(size); } - - auto *sliceOp = TypeSwitch<ShapedType, Operation *>(shapedType) - .Case([&](MemRefType) { - return builder.create<memref::SubViewOp>( - loc, valueToTile, offsets, sizes, strides); - }) - .Case([&](RankedTensorType) { - return makeComposedExtractSliceOp( - builder, loc, valueToTile, offsets, sizes, strides); - }) - .Default([](ShapedType) -> Operation * { - llvm_unreachable("Unexpected shaped type"); - }); - return sliceOp->getResult(0); + return sliceParams; } SmallVector<OpFoldResult> computeTileOffsets(OpBuilder &b, Location loc, @@ -1003,12 +1024,12 @@ Value materializeOpFoldResult(OpBuilder &builder, Location loc, return materializeOpFoldResult(b, opFoldResult); } -SmallVector<Value> makeTiledShapes(OpBuilder &b, Location loc, - LinalgOp linalgOp, ValueRange valuesToTile, - ArrayRef<OpFoldResult> ivs, - ArrayRef<OpFoldResult> tileSizes, - 
ArrayRef<OpFoldResult> sizeBounds, - bool omitPartialTileCheck) { +SmallVector<Optional<SliceParameters>> +computeAllSliceParameters(OpBuilder &builder, Location loc, LinalgOp linalgOp, + ValueRange valuesToTile, ArrayRef<OpFoldResult> ivs, + ArrayRef<OpFoldResult> tileSizes, + ArrayRef<OpFoldResult> sizeBounds, + bool omitPartialTileCheck) { assert(ivs.size() == static_cast<size_t>(llvm::count_if( llvm::make_range(tileSizes.begin(), tileSizes.end()), [](OpFoldResult v) { return !isZero(v); })) && @@ -1016,15 +1037,16 @@ SmallVector<Value> makeTiledShapes(OpBuilder &b, Location loc, // Construct (potentially temporary) mins and maxes on which to apply maps // that define tile subshapes. - SmallVector<OpFoldResult> lbs = computeTileOffsets(b, loc, ivs, tileSizes); + SmallVector<OpFoldResult> lbs = + computeTileOffsets(builder, loc, ivs, tileSizes); SmallVector<OpFoldResult> subShapeSizes = - computeTileSizes(b, loc, tileSizes, sizeBounds); + computeTileSizes(builder, loc, tileSizes, sizeBounds); assert(static_cast<int64_t>(valuesToTile.size()) == linalgOp.getNumInputsAndOutputs() && "expected one value to tile for every operand"); - SmallVector<Value> tiledShapes; - tiledShapes.reserve(valuesToTile.size()); + SmallVector<Optional<SliceParameters>> allSliceParams; + allSliceParams.reserve(valuesToTile.size()); for (OpOperand *opOperand : linalgOp.getInputAndOutputOperands()) { Value shapedOp = valuesToTile[opOperand->getOperandNumber()]; LLVM_DEBUG(llvm::dbgs() << "makeTiledShapes: for operand " << shapedOp); @@ -1035,18 +1057,39 @@ SmallVector<Value> makeTiledShapes(OpBuilder &b, Location loc, // extract/insert slice pairs make the accessed iteration argument // subdomains explicit. 
if (!isTiled(map, tileSizes) && !linalgOp.isOutputTensor(opOperand)) { - tiledShapes.push_back(shapedOp); + allSliceParams.push_back(llvm::None); LLVM_DEBUG(llvm::dbgs() << ": not tiled: use shape: " << opOperand->get().getType() << "\n"); continue; } LLVM_DEBUG(llvm::dbgs() << ": tiled: figure out subshape...\n"); - tiledShapes.push_back(makeTiledShape(b, loc, shapedOp, tileSizes, map, lbs, - sizeBounds, subShapeSizes, - omitPartialTileCheck)); + allSliceParams.push_back(computeSliceParameters( + builder, loc, shapedOp, tileSizes, map, lbs, sizeBounds, subShapeSizes, + omitPartialTileCheck)); } + return allSliceParams; +} + +SmallVector<Value> makeTiledShapes(OpBuilder &builder, Location loc, + LinalgOp linalgOp, ValueRange valuesToTile, + ArrayRef<OpFoldResult> ivs, + ArrayRef<OpFoldResult> tileSizes, + ArrayRef<OpFoldResult> sizeBounds, + bool omitPartialTileCheck) { + SmallVector<Optional<SliceParameters>> allSliceParameter = + computeAllSliceParameters(builder, loc, linalgOp, valuesToTile, ivs, + tileSizes, sizeBounds, omitPartialTileCheck); + SmallVector<Value> tiledShapes; + for (auto item : llvm::zip(valuesToTile, allSliceParameter)) { + Value valueToTile = std::get<0>(item); + Optional<SliceParameters> sliceParams = std::get<1>(item); + tiledShapes.push_back( + sliceParams.hasValue() + ? materializeTiledShape(builder, loc, valueToTile, *sliceParams) + : valueToTile); + } return tiledShapes; } |