author    Matthias Springer <springerm@google.com>  2021-04-07 21:11:55 +0900
committer Matthias Springer <springerm@google.com>  2021-04-07 21:33:13 +0900
commit    65a3f289397fd7d6cfcb4ddfdf324e37cf90cad7 (patch)
tree      5919c9ff685dd7e6d392ff7d5b61f6c604271411 /mlir/lib/Conversion/VectorToLLVM
parent    c0ef93bec85a8847b51d91d2a6470af903e1ec9a (diff)
[mlir] Add "mask" operand to vector.transfer_read/write.
Also factors out out-of-bounds mask generation from vector.transfer_read/write into a new MaterializeTransferMask pattern.

Differential Revision: https://reviews.llvm.org/D100001
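To illustrate the new operand (a rough sketch, not part of this commit; the SSA value names and the exact assembly syntax are hypothetical), a 1-D transfer can now carry its in-bounds mask explicitly rather than having the LLVM lowering synthesize one:

  // Hypothetical example: the mask is passed explicitly to the transfer ops.
  %mask = vector.create_mask %bound : vector<16xi1>
  %v = vector.transfer_read %A[%i], %pad, %mask : memref<?xf32>, vector<16xf32>
  vector.transfer_write %v, %A[%i], %mask : vector<16xf32>, memref<?xf32>

With the out-of-bounds mask materialized up front by the MaterializeTransferMask pattern, the conversion below only has to check whether a mask operand is present and emit a plain or masked load/store accordingly.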
Diffstat (limited to 'mlir/lib/Conversion/VectorToLLVM')
-rw-r--r--  mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp      | 151
-rw-r--r--  mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp  |   5
2 files changed, 23 insertions, 133 deletions
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 82e4bc2f4353..0c752c33ff16 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -104,66 +104,6 @@ static SmallVector<int64_t, 4> getI64SubArray(ArrayAttr arrayAttr,
return res;
}
-static Value createCastToIndexLike(ConversionPatternRewriter &rewriter,
- Location loc, Type targetType, Value value) {
- if (targetType == value.getType())
- return value;
-
- bool targetIsIndex = targetType.isIndex();
- bool valueIsIndex = value.getType().isIndex();
- if (targetIsIndex ^ valueIsIndex)
- return rewriter.create<IndexCastOp>(loc, targetType, value);
-
- auto targetIntegerType = targetType.dyn_cast<IntegerType>();
- auto valueIntegerType = value.getType().dyn_cast<IntegerType>();
- assert(targetIntegerType && valueIntegerType &&
- "unexpected cast between types other than integers and index");
- assert(targetIntegerType.getSignedness() == valueIntegerType.getSignedness());
-
- if (targetIntegerType.getWidth() > valueIntegerType.getWidth())
- return rewriter.create<SignExtendIOp>(loc, targetIntegerType, value);
- return rewriter.create<TruncateIOp>(loc, targetIntegerType, value);
-}
-
-// Helper that returns a vector comparison that constructs a mask:
-// mask = [0,1,..,n-1] + [o,o,..,o] < [b,b,..,b]
-//
-// NOTE: The LLVM::GetActiveLaneMaskOp intrinsic would provide an alternative,
-// much more compact, IR for this operation, but LLVM eventually
-// generates more elaborate instructions for this intrinsic since it
-// is very conservative on the boundary conditions.
-static Value buildVectorComparison(ConversionPatternRewriter &rewriter,
- Operation *op, bool enableIndexOptimizations,
- int64_t dim, Value b, Value *off = nullptr) {
- auto loc = op->getLoc();
- // If we can assume all indices fit in 32-bit, we perform the vector
- // comparison in 32-bit to get a higher degree of SIMD parallelism.
- // Otherwise we perform the vector comparison using 64-bit indices.
- Value indices;
- Type idxType;
- if (enableIndexOptimizations) {
- indices = rewriter.create<ConstantOp>(
- loc, rewriter.getI32VectorAttr(
- llvm::to_vector<4>(llvm::seq<int32_t>(0, dim))));
- idxType = rewriter.getI32Type();
- } else {
- indices = rewriter.create<ConstantOp>(
- loc, rewriter.getI64VectorAttr(
- llvm::to_vector<4>(llvm::seq<int64_t>(0, dim))));
- idxType = rewriter.getI64Type();
- }
- // Add in an offset if requested.
- if (off) {
- Value o = createCastToIndexLike(rewriter, loc, idxType, *off);
- Value ov = rewriter.create<SplatOp>(loc, indices.getType(), o);
- indices = rewriter.create<AddIOp>(loc, ov, indices);
- }
- // Construct the vector comparison.
- Value bound = createCastToIndexLike(rewriter, loc, idxType, b);
- Value bounds = rewriter.create<SplatOp>(loc, indices.getType(), bound);
- return rewriter.create<CmpIOp>(loc, CmpIPredicate::slt, indices, bounds);
-}
-
// Helper that returns data layout alignment of a memref.
LogicalResult getMemRefAlignment(LLVMTypeConverter &typeConverter,
MemRefType memrefType, unsigned &align) {
@@ -250,7 +190,7 @@ replaceTransferOpWithLoadOrStore(ConversionPatternRewriter &rewriter,
if (failed(getMemRefAlignment(
typeConverter, xferOp.getShapedType().cast<MemRefType>(), align)))
return failure();
- auto adaptor = TransferWriteOpAdaptor(operands);
+ auto adaptor = TransferWriteOpAdaptor(operands, xferOp->getAttrDictionary());
rewriter.replaceOpWithNewOp<LLVM::StoreOp>(xferOp, adaptor.vector(), dataPtr,
align);
return success();
@@ -266,7 +206,7 @@ replaceTransferOpWithMasked(ConversionPatternRewriter &rewriter,
typeConverter, xferOp.getShapedType().cast<MemRefType>(), align)))
return failure();
- auto adaptor = TransferWriteOpAdaptor(operands);
+ auto adaptor = TransferWriteOpAdaptor(operands, xferOp->getAttrDictionary());
rewriter.replaceOpWithNewOp<LLVM::MaskedStoreOp>(
xferOp, adaptor.vector(), dataPtr, mask,
rewriter.getI32IntegerAttr(align));
@@ -275,12 +215,12 @@ replaceTransferOpWithMasked(ConversionPatternRewriter &rewriter,
static TransferReadOpAdaptor getTransferOpAdapter(TransferReadOp xferOp,
ArrayRef<Value> operands) {
- return TransferReadOpAdaptor(operands);
+ return TransferReadOpAdaptor(operands, xferOp->getAttrDictionary());
}
static TransferWriteOpAdaptor getTransferOpAdapter(TransferWriteOp xferOp,
ArrayRef<Value> operands) {
- return TransferWriteOpAdaptor(operands);
+ return TransferWriteOpAdaptor(operands, xferOp->getAttrDictionary());
}
namespace {
@@ -618,33 +558,6 @@ private:
const bool reassociateFPReductions;
};
-/// Conversion pattern for a vector.create_mask (1-D only).
-class VectorCreateMaskOpConversion
- : public ConvertOpToLLVMPattern<vector::CreateMaskOp> {
-public:
- explicit VectorCreateMaskOpConversion(LLVMTypeConverter &typeConv,
- bool enableIndexOpt)
- : ConvertOpToLLVMPattern<vector::CreateMaskOp>(typeConv),
- enableIndexOptimizations(enableIndexOpt) {}
-
- LogicalResult
- matchAndRewrite(vector::CreateMaskOp op, ArrayRef<Value> operands,
- ConversionPatternRewriter &rewriter) const override {
- auto dstType = op.getType();
- int64_t rank = dstType.getRank();
- if (rank == 1) {
- rewriter.replaceOp(
- op, buildVectorComparison(rewriter, op, enableIndexOptimizations,
- dstType.getDimSize(0), operands[0]));
- return success();
- }
- return failure();
- }
-
-private:
- const bool enableIndexOptimizations;
-};
-
class VectorShuffleOpConversion
: public ConvertOpToLLVMPattern<vector::ShuffleOp> {
public:
@@ -1177,20 +1090,12 @@ public:
}
};
-/// Conversion pattern that converts a 1-D vector transfer read/write op in a
-/// sequence of:
-/// 1. Get the source/dst address as an LLVM vector pointer.
-/// 2. Create a vector with linear indices [ 0 .. vector_length - 1 ].
-/// 3. Create an offsetVector = [ offset + 0 .. offset + vector_length - 1 ].
-/// 4. Create a mask where offsetVector is compared against memref upper bound.
-/// 5. Rewrite op as a masked read or write.
+/// Conversion pattern that converts a 1-D vector transfer read/write op into a
+/// masked or unmasked read/write.
template <typename ConcreteOp>
class VectorTransferConversion : public ConvertOpToLLVMPattern<ConcreteOp> {
public:
- explicit VectorTransferConversion(LLVMTypeConverter &typeConv,
- bool enableIndexOpt)
- : ConvertOpToLLVMPattern<ConcreteOp>(typeConv),
- enableIndexOptimizations(enableIndexOpt) {}
+ using ConvertOpToLLVMPattern<ConcreteOp>::ConvertOpToLLVMPattern;
LogicalResult
matchAndRewrite(ConcreteOp xferOp, ArrayRef<Value> operands,
@@ -1212,6 +1117,9 @@ public:
auto strides = computeContiguousStrides(memRefType);
if (!strides)
return failure();
+ // Out-of-bounds dims are handled by MaterializeTransferMask.
+ if (xferOp.hasOutOfBoundsDim())
+ return failure();
auto toLLVMTy = [&](Type t) {
return this->getTypeConverter()->convertType(t);
@@ -1241,40 +1149,24 @@ public:
#endif // ifndef NDEBUG
}
- // 1. Get the source/dst address as an LLVM vector pointer.
+ // Get the source/dst address as an LLVM vector pointer.
VectorType vtp = xferOp.getVectorType();
Value dataPtr = this->getStridedElementPtr(
loc, memRefType, adaptor.source(), adaptor.indices(), rewriter);
Value vectorDataPtr =
castDataPtr(rewriter, loc, dataPtr, memRefType, toLLVMTy(vtp));
- if (xferOp.isDimInBounds(0))
+ // Rewrite as an unmasked read / write.
+ if (!xferOp.mask())
return replaceTransferOpWithLoadOrStore(rewriter,
*this->getTypeConverter(), loc,
xferOp, operands, vectorDataPtr);
- // 2. Create a vector with linear indices [ 0 .. vector_length - 1 ].
- // 3. Create offsetVector = [ offset + 0 .. offset + vector_length - 1 ].
- // 4. Let dim the memref dimension, compute the vector comparison mask
- // (in-bounds mask):
- // [ offset + 0 .. offset + vector_length - 1 ] < [ dim .. dim ]
- //
- // TODO: when the leaf transfer rank is k > 1, we need the last `k`
- // dimensions here.
- unsigned vecWidth = LLVM::getVectorNumElements(vtp).getFixedValue();
- unsigned lastIndex = llvm::size(xferOp.indices()) - 1;
- Value off = xferOp.indices()[lastIndex];
- Value dim = rewriter.create<memref::DimOp>(loc, xferOp.source(), lastIndex);
- Value mask = buildVectorComparison(
- rewriter, xferOp, enableIndexOptimizations, vecWidth, dim, &off);
-
- // 5. Rewrite as a masked read / write.
+ // Rewrite as a masked read / write.
return replaceTransferOpWithMasked(rewriter, *this->getTypeConverter(), loc,
- xferOp, operands, vectorDataPtr, mask);
+ xferOp, operands, vectorDataPtr,
+ xferOp.mask());
}
-
-private:
- const bool enableIndexOptimizations;
};
class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
@@ -1484,17 +1376,13 @@ public:
/// Populate the given list with patterns that convert from Vector to LLVM.
void mlir::populateVectorToLLVMConversionPatterns(
LLVMTypeConverter &converter, RewritePatternSet &patterns,
- bool reassociateFPReductions, bool enableIndexOptimizations) {
+ bool reassociateFPReductions) {
MLIRContext *ctx = converter.getDialect()->getContext();
patterns.add<VectorFMAOpNDRewritePattern,
VectorInsertStridedSliceOpDifferentRankRewritePattern,
VectorInsertStridedSliceOpSameRankRewritePattern,
VectorExtractStridedSliceOpConversion>(ctx);
patterns.add<VectorReductionOpConversion>(converter, reassociateFPReductions);
- patterns.add<VectorCreateMaskOpConversion,
- VectorTransferConversion<TransferReadOp>,
- VectorTransferConversion<TransferWriteOp>>(
- converter, enableIndexOptimizations);
patterns
.add<VectorBitCastOpConversion, VectorShuffleOpConversion,
VectorExtractElementOpConversion, VectorExtractOpConversion,
@@ -1508,8 +1396,9 @@ void mlir::populateVectorToLLVMConversionPatterns(
VectorLoadStoreConversion<vector::MaskedStoreOp,
vector::MaskedStoreOpAdaptor>,
VectorGatherOpConversion, VectorScatterOpConversion,
- VectorExpandLoadOpConversion, VectorCompressStoreOpConversion>(
- converter);
+ VectorExpandLoadOpConversion, VectorCompressStoreOpConversion,
+ VectorTransferConversion<TransferReadOp>,
+ VectorTransferConversion<TransferWriteOp>>(converter);
}
void mlir::populateVectorToLLVMMatrixConversionPatterns(
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
index abddcd73af1e..49ee670b2f06 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
@@ -71,9 +71,10 @@ void LowerVectorToLLVMPass::runOnOperation() {
// Convert to the LLVM IR dialect.
LLVMTypeConverter converter(&getContext());
RewritePatternSet patterns(&getContext());
+ populateVectorMaskMaterializationPatterns(patterns, enableIndexOptimizations);
populateVectorToLLVMMatrixConversionPatterns(converter, patterns);
- populateVectorToLLVMConversionPatterns(
- converter, patterns, reassociateFPReductions, enableIndexOptimizations);
+ populateVectorToLLVMConversionPatterns(converter, patterns,
+ reassociateFPReductions);
populateVectorToLLVMMatrixConversionPatterns(converter, patterns);
// Architecture specific augmentations.