aboutsummaryrefslogtreecommitdiff
path: root/mlir/lib/Conversion/VectorToLLVM
diff options
context:
space:
mode:
authorJavier Setoain <javier.setoain@gmail.com>2022-01-26 15:01:39 +0000
committerJavier Setoain <javier.setoain@gmail.com>2022-03-25 10:48:59 +0000
commita75a46db89f3fe3f3cb7d683e2b6d0227f282e18 (patch)
tree72f12ab23a8169205a77d844644c64177e7ec45f /mlir/lib/Conversion/VectorToLLVM
parent718aec209c891487294d8a6199cf12c796c6e901 (diff)
[mlir][Vector] Enable create_mask for scalable vectors
The way vector.create_mask is currently lowered is vector-length-dependent, and therefore incompatible with scalable vector types. This patch adds an alternative lowering path for create_mask operations that return a scalable vector mask. Differential Revision: https://reviews.llvm.org/D118248
Diffstat (limited to 'mlir/lib/Conversion/VectorToLLVM')
-rw-r--r--mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp43
-rw-r--r--mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp4
2 files changed, 42 insertions, 5 deletions
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 697b7a8d8786..20e51008c52b 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -10,6 +10,7 @@
#include "mlir/Conversion/LLVMCommon/VectorPattern.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
+#include "mlir/Dialect/Arithmetic/Utils/Utils.h"
#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
@@ -900,6 +901,40 @@ public:
}
};
+/// Conversion pattern for a `vector.create_mask` (1-D scalable vectors only).
+/// Non-scalable versions of this operation are handled in Vector Transforms.
+class VectorCreateMaskOpRewritePattern
+ : public OpRewritePattern<vector::CreateMaskOp> {
+public:
+ explicit VectorCreateMaskOpRewritePattern(MLIRContext *context,
+ bool enableIndexOpt)
+ : OpRewritePattern<vector::CreateMaskOp>(context),
+ indexOptimizations(enableIndexOpt) {}
+
+ LogicalResult matchAndRewrite(vector::CreateMaskOp op,
+ PatternRewriter &rewriter) const override {
+ auto dstType = op.getType();
+ if (dstType.getRank() != 1 || !dstType.cast<VectorType>().isScalable())
+ return failure();
+ IntegerType idxType =
+ indexOptimizations ? rewriter.getI32Type() : rewriter.getI64Type();
+ auto loc = op->getLoc();
+ Value indices = rewriter.create<LLVM::StepVectorOp>(
+ loc, LLVM::getVectorType(idxType, dstType.getShape()[0],
+ /*isScalable=*/true));
+ auto bound = getValueOrCreateCastToIndexLike(rewriter, loc, idxType,
+ op.getOperand(0));
+ Value bounds = rewriter.create<SplatOp>(loc, indices.getType(), bound);
+ Value comp = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt,
+ indices, bounds);
+ rewriter.replaceOp(op, comp);
+ return success();
+ }
+
+private:
+ const bool indexOptimizations;
+};
+
class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
public:
using ConvertOpToLLVMPattern<vector::PrintOp>::ConvertOpToLLVMPattern;
@@ -1157,13 +1192,15 @@ struct VectorSplatNdOpLowering : public ConvertOpToLLVMPattern<SplatOp> {
} // namespace
/// Populate the given list with patterns that convert from Vector to LLVM.
-void mlir::populateVectorToLLVMConversionPatterns(
- LLVMTypeConverter &converter, RewritePatternSet &patterns,
- bool reassociateFPReductions) {
+void mlir::populateVectorToLLVMConversionPatterns(LLVMTypeConverter &converter,
+ RewritePatternSet &patterns,
+ bool reassociateFPReductions,
+ bool indexOptimizations) {
MLIRContext *ctx = converter.getDialect()->getContext();
patterns.add<VectorFMAOpNDRewritePattern>(ctx);
populateVectorInsertExtractStridedSliceTransforms(patterns);
patterns.add<VectorReductionOpConversion>(converter, reassociateFPReductions);
+ patterns.add<VectorCreateMaskOpRewritePattern>(ctx, indexOptimizations);
patterns
.add<VectorBitCastOpConversion, VectorShuffleOpConversion,
VectorExtractElementOpConversion, VectorExtractOpConversion,
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
index 16d57efc5858..68edc23e8237 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
@@ -80,8 +80,8 @@ void LowerVectorToLLVMPass::runOnOperation() {
populateVectorMaskMaterializationPatterns(patterns, indexOptimizations);
populateVectorTransferLoweringPatterns(patterns);
populateVectorToLLVMMatrixConversionPatterns(converter, patterns);
- populateVectorToLLVMConversionPatterns(converter, patterns,
- reassociateFPReductions);
+ populateVectorToLLVMConversionPatterns(
+ converter, patterns, reassociateFPReductions, indexOptimizations);
populateVectorToLLVMMatrixConversionPatterns(converter, patterns);
// Architecture specific augmentations.