diff options
author | Javier Setoain <javier.setoain@gmail.com> | 2022-01-26 15:01:39 +0000 |
---|---|---|
committer | Javier Setoain <javier.setoain@gmail.com> | 2022-03-25 10:48:59 +0000 |
commit | a75a46db89f3fe3f3cb7d683e2b6d0227f282e18 (patch) | |
tree | 72f12ab23a8169205a77d844644c64177e7ec45f /mlir/lib/Conversion/VectorToLLVM | |
parent | 718aec209c891487294d8a6199cf12c796c6e901 (diff) |
[mlir][Vector] Enable create_mask for scalable vectors
The way vector.create_mask is currently lowered is
vector-length-dependent, and therefore incompatible with scalable vector
types. This patch adds an alternative lowering path for create_mask
operations that return a scalable vector mask.
Differential Revision: https://reviews.llvm.org/D118248
Diffstat (limited to 'mlir/lib/Conversion/VectorToLLVM')
-rw-r--r-- | mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp | 43 | ||||
-rw-r--r-- | mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp | 4 |
2 files changed, 42 insertions, 5 deletions
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp index 697b7a8d8786..20e51008c52b 100644 --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp @@ -10,6 +10,7 @@ #include "mlir/Conversion/LLVMCommon/VectorPattern.h" #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" +#include "mlir/Dialect/Arithmetic/Utils/Utils.h" #include "mlir/Dialect/LLVMIR/FunctionCallUtils.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" @@ -900,6 +901,40 @@ public: } }; +/// Conversion pattern for a `vector.create_mask` (1-D scalable vectors only). +/// Non-scalable versions of this operation are handled in Vector Transforms. +class VectorCreateMaskOpRewritePattern + : public OpRewritePattern<vector::CreateMaskOp> { +public: + explicit VectorCreateMaskOpRewritePattern(MLIRContext *context, + bool enableIndexOpt) + : OpRewritePattern<vector::CreateMaskOp>(context), + indexOptimizations(enableIndexOpt) {} + + LogicalResult matchAndRewrite(vector::CreateMaskOp op, + PatternRewriter &rewriter) const override { + auto dstType = op.getType(); + if (dstType.getRank() != 1 || !dstType.cast<VectorType>().isScalable()) + return failure(); + IntegerType idxType = + indexOptimizations ? rewriter.getI32Type() : rewriter.getI64Type(); + auto loc = op->getLoc(); + Value indices = rewriter.create<LLVM::StepVectorOp>( + loc, LLVM::getVectorType(idxType, dstType.getShape()[0], + /*isScalable=*/true)); + auto bound = getValueOrCreateCastToIndexLike(rewriter, loc, idxType, + op.getOperand(0)); + Value bounds = rewriter.create<SplatOp>(loc, indices.getType(), bound); + Value comp = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, + indices, bounds); + rewriter.replaceOp(op, comp); + return success(); + } + +private: + const bool indexOptimizations; +}; + class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> { public: using ConvertOpToLLVMPattern<vector::PrintOp>::ConvertOpToLLVMPattern; @@ -1157,13 +1192,15 @@ struct VectorSplatNdOpLowering : public ConvertOpToLLVMPattern<SplatOp> { } // namespace /// Populate the given list with patterns that convert from Vector to LLVM. -void mlir::populateVectorToLLVMConversionPatterns( - LLVMTypeConverter &converter, RewritePatternSet &patterns, - bool reassociateFPReductions) { +void mlir::populateVectorToLLVMConversionPatterns(LLVMTypeConverter &converter, + RewritePatternSet &patterns, + bool reassociateFPReductions, + bool indexOptimizations) { MLIRContext *ctx = converter.getDialect()->getContext(); patterns.add<VectorFMAOpNDRewritePattern>(ctx); populateVectorInsertExtractStridedSliceTransforms(patterns); patterns.add<VectorReductionOpConversion>(converter, reassociateFPReductions); + patterns.add<VectorCreateMaskOpRewritePattern>(ctx, indexOptimizations); patterns .add<VectorBitCastOpConversion, VectorShuffleOpConversion, VectorExtractElementOpConversion, VectorExtractOpConversion, diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp index 16d57efc5858..68edc23e8237 100644 --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp @@ -80,8 +80,8 @@ void LowerVectorToLLVMPass::runOnOperation() { populateVectorMaskMaterializationPatterns(patterns, indexOptimizations); populateVectorTransferLoweringPatterns(patterns); populateVectorToLLVMMatrixConversionPatterns(converter, patterns); - populateVectorToLLVMConversionPatterns(converter, patterns, - reassociateFPReductions); + populateVectorToLLVMConversionPatterns( + converter, patterns, reassociateFPReductions, indexOptimizations); populateVectorToLLVMMatrixConversionPatterns(converter, patterns); // Architecture specific augmentations. |