aboutsummaryrefslogtreecommitdiff
path: root/mlir/lib/Conversion/VectorToLLVM
diff options
context:
space:
mode:
authorNicolas Vasilache <nicolas.vasilache@gmail.com>2021-10-22 09:39:07 +0000
committerNicolas Vasilache <nicolas.vasilache@gmail.com>2021-10-22 10:03:33 +0000
commiteda2ebd7807376829eb880c39623f364b438971f (patch)
tree06cd45c9b428dc3a6a80d2a426bea51fff370eed /mlir/lib/Conversion/VectorToLLVM
parentcac8808f154cef6446e507d55aba5721c3bd5352 (diff)
[mlir][Vector] NFC - Extract rewrites related to insert/extract strided slice in a separate file.
Differential Revision: https://reviews.llvm.org/D112301
Diffstat (limited to 'mlir/lib/Conversion/VectorToLLVM')
-rw-r--r--mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp231
1 files changed, 3 insertions, 228 deletions
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index a6f25332d133..77d2a4697717 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -15,6 +15,7 @@
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Vector/VectorOps.h"
+#include "mlir/Dialect/Vector/VectorRewritePatterns.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Support/MathExtras.h"
#include "mlir/Target/LLVMIR/TypeToLLVM.h"
@@ -52,17 +53,6 @@ static Value insertOne(ConversionPatternRewriter &rewriter,
rewriter.getI64ArrayAttr(pos));
}
-// Helper that picks the proper sequence for inserting.
-static Value insertOne(PatternRewriter &rewriter, Location loc, Value from,
- Value into, int64_t offset) {
- auto vectorType = into.getType().cast<VectorType>();
- if (vectorType.getRank() > 1)
- return rewriter.create<InsertOp>(loc, from, into, offset);
- return rewriter.create<vector::InsertElementOp>(
- loc, vectorType, from, into,
- rewriter.create<arith::ConstantIndexOp>(loc, offset));
-}
-
// Helper that picks the proper sequence for extracting.
static Value extractOne(ConversionPatternRewriter &rewriter,
LLVMTypeConverter &typeConverter, Location loc,
@@ -79,32 +69,6 @@ static Value extractOne(ConversionPatternRewriter &rewriter,
rewriter.getI64ArrayAttr(pos));
}
-// Helper that picks the proper sequence for extracting.
-static Value extractOne(PatternRewriter &rewriter, Location loc, Value vector,
- int64_t offset) {
- auto vectorType = vector.getType().cast<VectorType>();
- if (vectorType.getRank() > 1)
- return rewriter.create<ExtractOp>(loc, vector, offset);
- return rewriter.create<vector::ExtractElementOp>(
- loc, vectorType.getElementType(), vector,
- rewriter.create<arith::ConstantIndexOp>(loc, offset));
-}
-
-// Helper that returns a subset of `arrayAttr` as a vector of int64_t.
-// TODO: Better support for attribute subtype forwarding + slicing.
-static SmallVector<int64_t, 4> getI64SubArray(ArrayAttr arrayAttr,
- unsigned dropFront = 0,
- unsigned dropBack = 0) {
- assert(arrayAttr.size() > dropFront + dropBack && "Out of bounds");
- auto range = arrayAttr.getAsRange<IntegerAttr>();
- SmallVector<int64_t, 4> res;
- res.reserve(arrayAttr.size() - dropFront - dropBack);
- for (auto it = range.begin() + dropFront, eit = range.end() - dropBack;
- it != eit; ++it)
- res.push_back((*it).getValue().getSExtValue());
- return res;
-}
-
// Helper that returns data layout alignment of a memref.
LogicalResult getMemRefAlignment(LLVMTypeConverter &typeConverter,
MemRefType memrefType, unsigned &align) {
@@ -813,132 +777,6 @@ public:
}
};
-// When ranks are different, InsertStridedSlice needs to extract a properly
-// ranked vector from the destination vector into which to insert. This pattern
-// only takes care of this part and forwards the rest of the conversion to
-// another pattern that converts InsertStridedSlice for operands of the same
-// rank.
-//
-// RewritePattern for InsertStridedSliceOp where source and destination vectors
-// have different ranks. In this case:
-// 1. the proper subvector is extracted from the destination vector
-// 2. a new InsertStridedSlice op is created to insert the source in the
-// destination subvector
-// 3. the destination subvector is inserted back in the proper place
-// 4. the op is replaced by the result of step 3.
-// The new InsertStridedSlice from step 2. will be picked up by a
-// `VectorInsertStridedSliceOpSameRankRewritePattern`.
-class VectorInsertStridedSliceOpDifferentRankRewritePattern
- : public OpRewritePattern<InsertStridedSliceOp> {
-public:
- using OpRewritePattern<InsertStridedSliceOp>::OpRewritePattern;
-
- LogicalResult matchAndRewrite(InsertStridedSliceOp op,
- PatternRewriter &rewriter) const override {
- auto srcType = op.getSourceVectorType();
- auto dstType = op.getDestVectorType();
-
- if (op.offsets().getValue().empty())
- return failure();
-
- auto loc = op.getLoc();
- int64_t rankDiff = dstType.getRank() - srcType.getRank();
- assert(rankDiff >= 0);
- if (rankDiff == 0)
- return failure();
-
- int64_t rankRest = dstType.getRank() - rankDiff;
- // Extract / insert the subvector of matching rank and InsertStridedSlice
- // on it.
- Value extracted =
- rewriter.create<ExtractOp>(loc, op.dest(),
- getI64SubArray(op.offsets(), /*dropFront=*/0,
- /*dropBack=*/rankRest));
- // A different pattern will kick in for InsertStridedSlice with matching
- // ranks.
- auto stridedSliceInnerOp = rewriter.create<InsertStridedSliceOp>(
- loc, op.source(), extracted,
- getI64SubArray(op.offsets(), /*dropFront=*/rankDiff),
- getI64SubArray(op.strides(), /*dropFront=*/0));
- rewriter.replaceOpWithNewOp<InsertOp>(
- op, stridedSliceInnerOp.getResult(), op.dest(),
- getI64SubArray(op.offsets(), /*dropFront=*/0,
- /*dropBack=*/rankRest));
- return success();
- }
-};
-
-// RewritePattern for InsertStridedSliceOp where source and destination vectors
-// have the same rank. In this case, we reduce
-// 1. the proper subvector is extracted from the destination vector
-// 2. a new InsertStridedSlice op is created to insert the source in the
-// destination subvector
-// 3. the destination subvector is inserted back in the proper place
-// 4. the op is replaced by the result of step 3.
-// The new InsertStridedSlice from step 2. will be picked up by a
-// `VectorInsertStridedSliceOpSameRankRewritePattern`.
-class VectorInsertStridedSliceOpSameRankRewritePattern
- : public OpRewritePattern<InsertStridedSliceOp> {
-public:
- using OpRewritePattern<InsertStridedSliceOp>::OpRewritePattern;
-
- void initialize() {
- // This pattern creates recursive InsertStridedSliceOp, but the recursion is
- // bounded as the rank is strictly decreasing.
- setHasBoundedRewriteRecursion();
- }
-
- LogicalResult matchAndRewrite(InsertStridedSliceOp op,
- PatternRewriter &rewriter) const override {
- auto srcType = op.getSourceVectorType();
- auto dstType = op.getDestVectorType();
-
- if (op.offsets().getValue().empty())
- return failure();
-
- int64_t rankDiff = dstType.getRank() - srcType.getRank();
- assert(rankDiff >= 0);
- if (rankDiff != 0)
- return failure();
-
- if (srcType == dstType) {
- rewriter.replaceOp(op, op.source());
- return success();
- }
-
- int64_t offset =
- op.offsets().getValue().front().cast<IntegerAttr>().getInt();
- int64_t size = srcType.getShape().front();
- int64_t stride =
- op.strides().getValue().front().cast<IntegerAttr>().getInt();
-
- auto loc = op.getLoc();
- Value res = op.dest();
- // For each slice of the source vector along the most major dimension.
- for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e;
- off += stride, ++idx) {
- // 1. extract the proper subvector (or element) from source
- Value extractedSource = extractOne(rewriter, loc, op.source(), idx);
- if (extractedSource.getType().isa<VectorType>()) {
- // 2. If we have a vector, extract the proper subvector from destination
- // Otherwise we are at the element level and no need to recurse.
- Value extractedDest = extractOne(rewriter, loc, op.dest(), off);
- // 3. Reduce the problem to lowering a new InsertStridedSlice op with
- // smaller rank.
- extractedSource = rewriter.create<InsertStridedSliceOp>(
- loc, extractedSource, extractedDest,
- getI64SubArray(op.offsets(), /* dropFront=*/1),
- getI64SubArray(op.strides(), /* dropFront=*/1));
- }
- // 4. Insert the extractedSource into the res vector.
- res = insertOne(rewriter, loc, extractedSource, res, off);
- }
-
- rewriter.replaceOp(op, res);
- return success();
- }
-};
-
/// Returns the strides if the memory underlying `memRefType` has a contiguous
/// static layout.
static llvm::Optional<SmallVector<int64_t, 4>>
@@ -1189,67 +1027,6 @@ private:
}
};
-/// Progressive lowering of ExtractStridedSliceOp to either:
-/// 1. express single offset extract as a direct shuffle.
-/// 2. extract + lower rank strided_slice + insert for the n-D case.
-class VectorExtractStridedSliceOpConversion
- : public OpRewritePattern<ExtractStridedSliceOp> {
-public:
- using OpRewritePattern<ExtractStridedSliceOp>::OpRewritePattern;
-
- void initialize() {
- // This pattern creates recursive ExtractStridedSliceOp, but the recursion
- // is bounded as the rank is strictly decreasing.
- setHasBoundedRewriteRecursion();
- }
-
- LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
- PatternRewriter &rewriter) const override {
- auto dstType = op.getType();
-
- assert(!op.offsets().getValue().empty() && "Unexpected empty offsets");
-
- int64_t offset =
- op.offsets().getValue().front().cast<IntegerAttr>().getInt();
- int64_t size = op.sizes().getValue().front().cast<IntegerAttr>().getInt();
- int64_t stride =
- op.strides().getValue().front().cast<IntegerAttr>().getInt();
-
- auto loc = op.getLoc();
- auto elemType = dstType.getElementType();
- assert(elemType.isSignlessIntOrIndexOrFloat());
-
- // Single offset can be more efficiently shuffled.
- if (op.offsets().getValue().size() == 1) {
- SmallVector<int64_t, 4> offsets;
- offsets.reserve(size);
- for (int64_t off = offset, e = offset + size * stride; off < e;
- off += stride)
- offsets.push_back(off);
- rewriter.replaceOpWithNewOp<ShuffleOp>(op, dstType, op.vector(),
- op.vector(),
- rewriter.getI64ArrayAttr(offsets));
- return success();
- }
-
- // Extract/insert on a lower ranked extract strided slice op.
- Value zero = rewriter.create<arith::ConstantOp>(
- loc, elemType, rewriter.getZeroAttr(elemType));
- Value res = rewriter.create<SplatOp>(loc, dstType, zero);
- for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e;
- off += stride, ++idx) {
- Value one = extractOne(rewriter, loc, op.vector(), off);
- Value extracted = rewriter.create<ExtractStridedSliceOp>(
- loc, one, getI64SubArray(op.offsets(), /* dropFront=*/1),
- getI64SubArray(op.sizes(), /* dropFront=*/1),
- getI64SubArray(op.strides(), /* dropFront=*/1));
- res = insertOne(rewriter, loc, extracted, res, idx);
- }
- rewriter.replaceOp(op, res);
- return success();
- }
-};
-
} // namespace
/// Populate the given list with patterns that convert from Vector to LLVM.
@@ -1257,10 +1034,8 @@ void mlir::populateVectorToLLVMConversionPatterns(
LLVMTypeConverter &converter, RewritePatternSet &patterns,
bool reassociateFPReductions) {
MLIRContext *ctx = converter.getDialect()->getContext();
- patterns.add<VectorFMAOpNDRewritePattern,
- VectorInsertStridedSliceOpDifferentRankRewritePattern,
- VectorInsertStridedSliceOpSameRankRewritePattern,
- VectorExtractStridedSliceOpConversion>(ctx);
+ patterns.add<VectorFMAOpNDRewritePattern>(ctx);
+ populateVectorInsertExtractStridedSliceTransforms(patterns);
patterns.add<VectorReductionOpConversion>(converter, reassociateFPReductions);
patterns
.add<VectorBitCastOpConversion, VectorShuffleOpConversion,