aboutsummaryrefslogtreecommitdiff
path: root/mlir/lib/Dialect
diff options
context:
space:
mode:
authorAlexander Belyaev <pifon@google.com>2022-08-04 11:18:04 +0200
committerAlexander Belyaev <pifon@google.com>2022-08-04 11:23:58 +0200
commit56d94b3b902e21ff79b1ce9a6fb606a3f7c1c4db (patch)
tree95074dc12ea4308d94f4bc239ba939aa52bc40a3 /mlir/lib/Dialect
parent57a9bccec7dea036dbfa1a78f1ec5e73ecf7a33c (diff)
[mlir] Extract offsets-sizes-strides computation from `makeTiledShape(s)`.
This change separates computation of the actual parameters of the subset and the materialization of subview/extract_slice. That way the users can still use Linalg tiling logic even if they use different operations to materialize the subsets. Differential Revision: https://reviews.llvm.org/D131053
Diffstat (limited to 'mlir/lib/Dialect')
-rw-r--r--mlir/lib/Dialect/Linalg/Utils/Utils.cpp129
1 files changed, 86 insertions, 43 deletions
diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
index 4f14164bf26c..0259f9a542f6 100644
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -802,28 +802,61 @@ void GenerateLoopNest<scf::ParallelOp>::doit(
assert(ivs.size() == iteratorTypes.size() && "did not generate enough loops");
}
+static Value materializeTiledShape(OpBuilder &builder, Location loc,
+ Value valueToTile,
+ const SliceParameters &sliceParams) {
+ auto shapedType = valueToTile.getType().dyn_cast<ShapedType>();
+ auto *sliceOp = TypeSwitch<ShapedType, Operation *>(shapedType)
+ .Case([&](MemRefType) {
+ return builder.create<memref::SubViewOp>(
+ loc, valueToTile, sliceParams.offsets,
+ sliceParams.sizes, sliceParams.strides);
+ })
+ .Case([&](RankedTensorType) {
+ return makeComposedExtractSliceOp(
+ builder, loc, valueToTile, sliceParams.offsets,
+ sliceParams.sizes, sliceParams.strides);
+ })
+ .Default([](ShapedType) -> Operation * {
+ llvm_unreachable("Unexpected shaped type");
+ });
+ return sliceOp->getResult(0);
+}
+
Value makeTiledShape(OpBuilder &builder, Location loc, Value valueToTile,
ArrayRef<OpFoldResult> tileSizes, AffineMap map,
ArrayRef<OpFoldResult> lbs, ArrayRef<OpFoldResult> ubs,
ArrayRef<OpFoldResult> subShapeSizes,
bool omitPartialTileCheck) {
+ SliceParameters sliceParams =
+ computeSliceParameters(builder, loc, valueToTile, tileSizes, map, lbs,
+ ubs, subShapeSizes, omitPartialTileCheck);
+ return materializeTiledShape(builder, loc, valueToTile, sliceParams);
+}
+
+SliceParameters
+computeSliceParameters(OpBuilder &builder, Location loc, Value valueToTile,
+ ArrayRef<OpFoldResult> tileSizes, AffineMap map,
+ ArrayRef<OpFoldResult> lbs, ArrayRef<OpFoldResult> ubs,
+ ArrayRef<OpFoldResult> subShapeSizes,
+ bool omitPartialTileCheck) {
auto shapedType = valueToTile.getType().dyn_cast<ShapedType>();
assert(shapedType && "only shaped types can be tiled");
ArrayRef<int64_t> shape = shapedType.getShape();
int64_t rank = shapedType.getRank();
// Construct a new subview / extract_slice for the tile.
- SmallVector<OpFoldResult, 4> offsets, sizes, strides;
- offsets.reserve(rank);
- sizes.reserve(rank);
- strides.reserve(rank);
+ SliceParameters sliceParams;
+ sliceParams.offsets.reserve(rank);
+ sliceParams.sizes.reserve(rank);
+ sliceParams.strides.reserve(rank);
for (unsigned r = 0; r < rank; ++r) {
- LLVM_DEBUG(llvm::dbgs() << "makeTiledShape: for dim#" << r);
+ LLVM_DEBUG(llvm::dbgs() << "computeSliceParameters: for dim#" << r);
if (!isTiled(map.getSubMap({r}), tileSizes)) {
- offsets.push_back(builder.getIndexAttr(0));
+ sliceParams.offsets.push_back(builder.getIndexAttr(0));
OpFoldResult dim = createFoldedDimOp(builder, loc, valueToTile, r);
- sizes.push_back(dim);
- strides.push_back(builder.getIndexAttr(1));
+ sliceParams.sizes.push_back(dim);
+ sliceParams.strides.push_back(builder.getIndexAttr(1));
LLVM_DEBUG(llvm::dbgs() << ": not tiled: use size: " << dim << "\n");
continue;
}
@@ -832,26 +865,27 @@ Value makeTiledShape(OpBuilder &builder, Location loc, Value valueToTile,
// Tiling creates a new slice at the proper index, the slice step is 1
// (i.e. the op does not subsample, stepping occurs in the loop).
auto m = map.getSubMap({r});
- LLVM_DEBUG(llvm::dbgs() << "makeTiledShape: submap: " << m << "\n");
+ LLVM_DEBUG(llvm::dbgs() << "computeSliceParameters: submap: " << m << "\n");
IRRewriter rewriter(builder);
OpFoldResult offset = makeComposedFoldedAffineApply(rewriter, loc, m, lbs);
- offsets.push_back(offset);
+ sliceParams.offsets.push_back(offset);
OpFoldResult closedIntSize =
makeComposedFoldedAffineApply(rewriter, loc, m, subShapeSizes);
// Resulting size needs to be made half open interval again.
AffineExpr s0 = getAffineSymbolExpr(0, builder.getContext());
OpFoldResult size =
makeComposedFoldedAffineApply(rewriter, loc, s0 + 1, closedIntSize);
- LLVM_DEBUG(llvm::dbgs() << "makeTiledShape: raw size: " << size << "\n");
LLVM_DEBUG(llvm::dbgs()
- << "makeTiledShape: new offset: " << offset << "\n");
- strides.push_back(builder.getIndexAttr(1));
+ << "computeSliceParameters: raw size: " << size << "\n");
+ LLVM_DEBUG(llvm::dbgs()
+ << "computeSliceParameters: new offset: " << offset << "\n");
+ sliceParams.strides.push_back(builder.getIndexAttr(1));
if (omitPartialTileCheck) {
// We statically know that the partial/boundary tile condition is
// unnecessary.
LLVM_DEBUG(llvm::dbgs() << "makeTiledShape: new size: " << size << "\n");
- sizes.push_back(size);
+ sliceParams.sizes.push_back(size);
continue;
}
@@ -903,22 +937,9 @@ Value makeTiledShape(OpBuilder &builder, Location loc, Value valueToTile,
makeComposedFoldedAffineMin(rewriter, loc, minMap, {size, d, offset});
}
LLVM_DEBUG(llvm::dbgs() << "makeTiledShape: new size: " << size << "\n");
- sizes.push_back(size);
+ sliceParams.sizes.push_back(size);
}
-
- auto *sliceOp = TypeSwitch<ShapedType, Operation *>(shapedType)
- .Case([&](MemRefType) {
- return builder.create<memref::SubViewOp>(
- loc, valueToTile, offsets, sizes, strides);
- })
- .Case([&](RankedTensorType) {
- return makeComposedExtractSliceOp(
- builder, loc, valueToTile, offsets, sizes, strides);
- })
- .Default([](ShapedType) -> Operation * {
- llvm_unreachable("Unexpected shaped type");
- });
- return sliceOp->getResult(0);
+ return sliceParams;
}
SmallVector<OpFoldResult> computeTileOffsets(OpBuilder &b, Location loc,
@@ -1003,12 +1024,12 @@ Value materializeOpFoldResult(OpBuilder &builder, Location loc,
return materializeOpFoldResult(b, opFoldResult);
}
-SmallVector<Value> makeTiledShapes(OpBuilder &b, Location loc,
- LinalgOp linalgOp, ValueRange valuesToTile,
- ArrayRef<OpFoldResult> ivs,
- ArrayRef<OpFoldResult> tileSizes,
- ArrayRef<OpFoldResult> sizeBounds,
- bool omitPartialTileCheck) {
+SmallVector<Optional<SliceParameters>>
+computeAllSliceParameters(OpBuilder &builder, Location loc, LinalgOp linalgOp,
+ ValueRange valuesToTile, ArrayRef<OpFoldResult> ivs,
+ ArrayRef<OpFoldResult> tileSizes,
+ ArrayRef<OpFoldResult> sizeBounds,
+ bool omitPartialTileCheck) {
assert(ivs.size() == static_cast<size_t>(llvm::count_if(
llvm::make_range(tileSizes.begin(), tileSizes.end()),
[](OpFoldResult v) { return !isZero(v); })) &&
@@ -1016,15 +1037,16 @@ SmallVector<Value> makeTiledShapes(OpBuilder &b, Location loc,
// Construct (potentially temporary) mins and maxes on which to apply maps
// that define tile subshapes.
- SmallVector<OpFoldResult> lbs = computeTileOffsets(b, loc, ivs, tileSizes);
+ SmallVector<OpFoldResult> lbs =
+ computeTileOffsets(builder, loc, ivs, tileSizes);
SmallVector<OpFoldResult> subShapeSizes =
- computeTileSizes(b, loc, tileSizes, sizeBounds);
+ computeTileSizes(builder, loc, tileSizes, sizeBounds);
assert(static_cast<int64_t>(valuesToTile.size()) ==
linalgOp.getNumInputsAndOutputs() &&
"expected one value to tile for every operand");
- SmallVector<Value> tiledShapes;
- tiledShapes.reserve(valuesToTile.size());
+ SmallVector<Optional<SliceParameters>> allSliceParams;
+ allSliceParams.reserve(valuesToTile.size());
for (OpOperand *opOperand : linalgOp.getInputAndOutputOperands()) {
Value shapedOp = valuesToTile[opOperand->getOperandNumber()];
LLVM_DEBUG(llvm::dbgs() << "makeTiledShapes: for operand " << shapedOp);
@@ -1035,18 +1057,39 @@ SmallVector<Value> makeTiledShapes(OpBuilder &b, Location loc,
// extract/insert slice pairs make the accessed iteration argument
// subdomains explicit.
if (!isTiled(map, tileSizes) && !linalgOp.isOutputTensor(opOperand)) {
- tiledShapes.push_back(shapedOp);
+ allSliceParams.push_back(llvm::None);
LLVM_DEBUG(llvm::dbgs() << ": not tiled: use shape: "
<< opOperand->get().getType() << "\n");
continue;
}
LLVM_DEBUG(llvm::dbgs() << ": tiled: figure out subshape...\n");
- tiledShapes.push_back(makeTiledShape(b, loc, shapedOp, tileSizes, map, lbs,
- sizeBounds, subShapeSizes,
- omitPartialTileCheck));
+ allSliceParams.push_back(computeSliceParameters(
+ builder, loc, shapedOp, tileSizes, map, lbs, sizeBounds, subShapeSizes,
+ omitPartialTileCheck));
}
+ return allSliceParams;
+}
+
+SmallVector<Value> makeTiledShapes(OpBuilder &builder, Location loc,
+ LinalgOp linalgOp, ValueRange valuesToTile,
+ ArrayRef<OpFoldResult> ivs,
+ ArrayRef<OpFoldResult> tileSizes,
+ ArrayRef<OpFoldResult> sizeBounds,
+ bool omitPartialTileCheck) {
+ SmallVector<Optional<SliceParameters>> allSliceParameter =
+ computeAllSliceParameters(builder, loc, linalgOp, valuesToTile, ivs,
+ tileSizes, sizeBounds, omitPartialTileCheck);
+ SmallVector<Value> tiledShapes;
+ for (auto item : llvm::zip(valuesToTile, allSliceParameter)) {
+ Value valueToTile = std::get<0>(item);
+ Optional<SliceParameters> sliceParams = std::get<1>(item);
+ tiledShapes.push_back(
+ sliceParams.hasValue()
+ ? materializeTiledShape(builder, loc, valueToTile, *sliceParams)
+ : valueToTile);
+ }
return tiledShapes;
}