diff options
author | Nicolas Vasilache <nicolas.vasilache@gmail.com> | 2021-10-22 09:39:07 +0000 |
---|---|---|
committer | Nicolas Vasilache <nicolas.vasilache@gmail.com> | 2021-10-22 10:03:33 +0000 |
commit | eda2ebd7807376829eb880c39623f364b438971f (patch) | |
tree | 06cd45c9b428dc3a6a80d2a426bea51fff370eed /mlir/lib/Conversion/VectorToLLVM | |
parent | cac8808f154cef6446e507d55aba5721c3bd5352 (diff) |
[mlir][Vector] NFC - Extract rewrites related to insert/extract strided slice in a separate file.
Differential Revision: https://reviews.llvm.org/D112301
Diffstat (limited to 'mlir/lib/Conversion/VectorToLLVM')
-rw-r--r-- | mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp | 231 |
1 files changed, 3 insertions, 228 deletions
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp index a6f25332d133..77d2a4697717 100644 --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp @@ -15,6 +15,7 @@ #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/Dialect/Vector/VectorOps.h" +#include "mlir/Dialect/Vector/VectorRewritePatterns.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/Support/MathExtras.h" #include "mlir/Target/LLVMIR/TypeToLLVM.h" @@ -52,17 +53,6 @@ static Value insertOne(ConversionPatternRewriter &rewriter, rewriter.getI64ArrayAttr(pos)); } -// Helper that picks the proper sequence for inserting. -static Value insertOne(PatternRewriter &rewriter, Location loc, Value from, - Value into, int64_t offset) { - auto vectorType = into.getType().cast<VectorType>(); - if (vectorType.getRank() > 1) - return rewriter.create<InsertOp>(loc, from, into, offset); - return rewriter.create<vector::InsertElementOp>( - loc, vectorType, from, into, - rewriter.create<arith::ConstantIndexOp>(loc, offset)); -} - // Helper that picks the proper sequence for extracting. static Value extractOne(ConversionPatternRewriter &rewriter, LLVMTypeConverter &typeConverter, Location loc, @@ -79,32 +69,6 @@ static Value extractOne(ConversionPatternRewriter &rewriter, rewriter.getI64ArrayAttr(pos)); } -// Helper that picks the proper sequence for extracting. -static Value extractOne(PatternRewriter &rewriter, Location loc, Value vector, - int64_t offset) { - auto vectorType = vector.getType().cast<VectorType>(); - if (vectorType.getRank() > 1) - return rewriter.create<ExtractOp>(loc, vector, offset); - return rewriter.create<vector::ExtractElementOp>( - loc, vectorType.getElementType(), vector, - rewriter.create<arith::ConstantIndexOp>(loc, offset)); -} - -// Helper that returns a subset of `arrayAttr` as a vector of int64_t. -// TODO: Better support for attribute subtype forwarding + slicing. -static SmallVector<int64_t, 4> getI64SubArray(ArrayAttr arrayAttr, - unsigned dropFront = 0, - unsigned dropBack = 0) { - assert(arrayAttr.size() > dropFront + dropBack && "Out of bounds"); - auto range = arrayAttr.getAsRange<IntegerAttr>(); - SmallVector<int64_t, 4> res; - res.reserve(arrayAttr.size() - dropFront - dropBack); - for (auto it = range.begin() + dropFront, eit = range.end() - dropBack; - it != eit; ++it) - res.push_back((*it).getValue().getSExtValue()); - return res; -} - // Helper that returns data layout alignment of a memref. LogicalResult getMemRefAlignment(LLVMTypeConverter &typeConverter, MemRefType memrefType, unsigned &align) { @@ -813,132 +777,6 @@ public: } }; -// When ranks are different, InsertStridedSlice needs to extract a properly -// ranked vector from the destination vector into which to insert. This pattern -// only takes care of this part and forwards the rest of the conversion to -// another pattern that converts InsertStridedSlice for operands of the same -// rank. -// -// RewritePattern for InsertStridedSliceOp where source and destination vectors -// have different ranks. In this case: -// 1. the proper subvector is extracted from the destination vector -// 2. a new InsertStridedSlice op is created to insert the source in the -// destination subvector -// 3. the destination subvector is inserted back in the proper place -// 4. the op is replaced by the result of step 3. -// The new InsertStridedSlice from step 2. will be picked up by a -// `VectorInsertStridedSliceOpSameRankRewritePattern`. -class VectorInsertStridedSliceOpDifferentRankRewritePattern - : public OpRewritePattern<InsertStridedSliceOp> { -public: - using OpRewritePattern<InsertStridedSliceOp>::OpRewritePattern; - - LogicalResult matchAndRewrite(InsertStridedSliceOp op, - PatternRewriter &rewriter) const override { - auto srcType = op.getSourceVectorType(); - auto dstType = op.getDestVectorType(); - - if (op.offsets().getValue().empty()) - return failure(); - - auto loc = op.getLoc(); - int64_t rankDiff = dstType.getRank() - srcType.getRank(); - assert(rankDiff >= 0); - if (rankDiff == 0) - return failure(); - - int64_t rankRest = dstType.getRank() - rankDiff; - // Extract / insert the subvector of matching rank and InsertStridedSlice - // on it. - Value extracted = - rewriter.create<ExtractOp>(loc, op.dest(), - getI64SubArray(op.offsets(), /*dropFront=*/0, - /*dropBack=*/rankRest)); - // A different pattern will kick in for InsertStridedSlice with matching - // ranks. - auto stridedSliceInnerOp = rewriter.create<InsertStridedSliceOp>( - loc, op.source(), extracted, - getI64SubArray(op.offsets(), /*dropFront=*/rankDiff), - getI64SubArray(op.strides(), /*dropFront=*/0)); - rewriter.replaceOpWithNewOp<InsertOp>( - op, stridedSliceInnerOp.getResult(), op.dest(), - getI64SubArray(op.offsets(), /*dropFront=*/0, - /*dropBack=*/rankRest)); - return success(); - } -}; - -// RewritePattern for InsertStridedSliceOp where source and destination vectors -// have the same rank. In this case, we reduce -// 1. the proper subvector is extracted from the destination vector -// 2. a new InsertStridedSlice op is created to insert the source in the -// destination subvector -// 3. the destination subvector is inserted back in the proper place -// 4. the op is replaced by the result of step 3. -// The new InsertStridedSlice from step 2. will be picked up by a -// `VectorInsertStridedSliceOpSameRankRewritePattern`. -class VectorInsertStridedSliceOpSameRankRewritePattern - : public OpRewritePattern<InsertStridedSliceOp> { -public: - using OpRewritePattern<InsertStridedSliceOp>::OpRewritePattern; - - void initialize() { - // This pattern creates recursive InsertStridedSliceOp, but the recursion is - // bounded as the rank is strictly decreasing. - setHasBoundedRewriteRecursion(); - } - - LogicalResult matchAndRewrite(InsertStridedSliceOp op, - PatternRewriter &rewriter) const override { - auto srcType = op.getSourceVectorType(); - auto dstType = op.getDestVectorType(); - - if (op.offsets().getValue().empty()) - return failure(); - - int64_t rankDiff = dstType.getRank() - srcType.getRank(); - assert(rankDiff >= 0); - if (rankDiff != 0) - return failure(); - - if (srcType == dstType) { - rewriter.replaceOp(op, op.source()); - return success(); - } - - int64_t offset = - op.offsets().getValue().front().cast<IntegerAttr>().getInt(); - int64_t size = srcType.getShape().front(); - int64_t stride = - op.strides().getValue().front().cast<IntegerAttr>().getInt(); - - auto loc = op.getLoc(); - Value res = op.dest(); - // For each slice of the source vector along the most major dimension. - for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e; - off += stride, ++idx) { - // 1. extract the proper subvector (or element) from source - Value extractedSource = extractOne(rewriter, loc, op.source(), idx); - if (extractedSource.getType().isa<VectorType>()) { - // 2. If we have a vector, extract the proper subvector from destination - // Otherwise we are at the element level and no need to recurse. - Value extractedDest = extractOne(rewriter, loc, op.dest(), off); - // 3. Reduce the problem to lowering a new InsertStridedSlice op with - // smaller rank. - extractedSource = rewriter.create<InsertStridedSliceOp>( - loc, extractedSource, extractedDest, - getI64SubArray(op.offsets(), /* dropFront=*/1), - getI64SubArray(op.strides(), /* dropFront=*/1)); - } - // 4. Insert the extractedSource into the res vector. - res = insertOne(rewriter, loc, extractedSource, res, off); - } - - rewriter.replaceOp(op, res); - return success(); - } -}; - /// Returns the strides if the memory underlying `memRefType` has a contiguous /// static layout. static llvm::Optional<SmallVector<int64_t, 4>> @@ -1189,67 +1027,6 @@ private: } }; -/// Progressive lowering of ExtractStridedSliceOp to either: -/// 1. express single offset extract as a direct shuffle. -/// 2. extract + lower rank strided_slice + insert for the n-D case. -class VectorExtractStridedSliceOpConversion - : public OpRewritePattern<ExtractStridedSliceOp> { -public: - using OpRewritePattern<ExtractStridedSliceOp>::OpRewritePattern; - - void initialize() { - // This pattern creates recursive ExtractStridedSliceOp, but the recursion - // is bounded as the rank is strictly decreasing. - setHasBoundedRewriteRecursion(); - } - - LogicalResult matchAndRewrite(ExtractStridedSliceOp op, - PatternRewriter &rewriter) const override { - auto dstType = op.getType(); - - assert(!op.offsets().getValue().empty() && "Unexpected empty offsets"); - - int64_t offset = - op.offsets().getValue().front().cast<IntegerAttr>().getInt(); - int64_t size = op.sizes().getValue().front().cast<IntegerAttr>().getInt(); - int64_t stride = - op.strides().getValue().front().cast<IntegerAttr>().getInt(); - - auto loc = op.getLoc(); - auto elemType = dstType.getElementType(); - assert(elemType.isSignlessIntOrIndexOrFloat()); - - // Single offset can be more efficiently shuffled. - if (op.offsets().getValue().size() == 1) { - SmallVector<int64_t, 4> offsets; - offsets.reserve(size); - for (int64_t off = offset, e = offset + size * stride; off < e; - off += stride) - offsets.push_back(off); - rewriter.replaceOpWithNewOp<ShuffleOp>(op, dstType, op.vector(), - op.vector(), - rewriter.getI64ArrayAttr(offsets)); - return success(); - } - - // Extract/insert on a lower ranked extract strided slice op. - Value zero = rewriter.create<arith::ConstantOp>( - loc, elemType, rewriter.getZeroAttr(elemType)); - Value res = rewriter.create<SplatOp>(loc, dstType, zero); - for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e; - off += stride, ++idx) { - Value one = extractOne(rewriter, loc, op.vector(), off); - Value extracted = rewriter.create<ExtractStridedSliceOp>( - loc, one, getI64SubArray(op.offsets(), /* dropFront=*/1), - getI64SubArray(op.sizes(), /* dropFront=*/1), - getI64SubArray(op.strides(), /* dropFront=*/1)); - res = insertOne(rewriter, loc, extracted, res, idx); - } - rewriter.replaceOp(op, res); - return success(); - } -}; - } // namespace /// Populate the given list with patterns that convert from Vector to LLVM. @@ -1257,10 +1034,8 @@ void mlir::populateVectorToLLVMConversionPatterns( LLVMTypeConverter &converter, RewritePatternSet &patterns, bool reassociateFPReductions) { MLIRContext *ctx = converter.getDialect()->getContext(); - patterns.add<VectorFMAOpNDRewritePattern, - VectorInsertStridedSliceOpDifferentRankRewritePattern, - VectorInsertStridedSliceOpSameRankRewritePattern, - VectorExtractStridedSliceOpConversion>(ctx); + patterns.add<VectorFMAOpNDRewritePattern>(ctx); + populateVectorInsertExtractStridedSliceTransforms(patterns); patterns.add<VectorReductionOpConversion>(converter, reassociateFPReductions); patterns .add<VectorBitCastOpConversion, VectorShuffleOpConversion, |