diff options
author | Matthias Springer <springerm@google.com> | 2021-04-07 21:11:55 +0900 |
---|---|---|
committer | Matthias Springer <springerm@google.com> | 2021-04-07 21:33:13 +0900 |
commit | 65a3f289397fd7d6cfcb4ddfdf324e37cf90cad7 (patch) | |
tree | 5919c9ff685dd7e6d392ff7d5b61f6c604271411 /mlir/lib/Conversion/VectorToLLVM | |
parent | c0ef93bec85a8847b51d91d2a6470af903e1ec9a (diff) |
[mlir] Add "mask" operand to vector.transfer_read/write.
Also factors out out-of-bounds mask generation from vector.transfer_read/write into a new MaterializeTransferMask pattern.
Differential Revision: https://reviews.llvm.org/D100001
Diffstat (limited to 'mlir/lib/Conversion/VectorToLLVM')
-rw-r--r-- | mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp | 151 | ||||
-rw-r--r-- | mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp | 5 |
2 files changed, 23 insertions, 133 deletions
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp index 82e4bc2f4353..0c752c33ff16 100644 --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp @@ -104,66 +104,6 @@ static SmallVector<int64_t, 4> getI64SubArray(ArrayAttr arrayAttr, return res; } -static Value createCastToIndexLike(ConversionPatternRewriter &rewriter, - Location loc, Type targetType, Value value) { - if (targetType == value.getType()) - return value; - - bool targetIsIndex = targetType.isIndex(); - bool valueIsIndex = value.getType().isIndex(); - if (targetIsIndex ^ valueIsIndex) - return rewriter.create<IndexCastOp>(loc, targetType, value); - - auto targetIntegerType = targetType.dyn_cast<IntegerType>(); - auto valueIntegerType = value.getType().dyn_cast<IntegerType>(); - assert(targetIntegerType && valueIntegerType && - "unexpected cast between types other than integers and index"); - assert(targetIntegerType.getSignedness() == valueIntegerType.getSignedness()); - - if (targetIntegerType.getWidth() > valueIntegerType.getWidth()) - return rewriter.create<SignExtendIOp>(loc, targetIntegerType, value); - return rewriter.create<TruncateIOp>(loc, targetIntegerType, value); -} - -// Helper that returns a vector comparison that constructs a mask: -// mask = [0,1,..,n-1] + [o,o,..,o] < [b,b,..,b] -// -// NOTE: The LLVM::GetActiveLaneMaskOp intrinsic would provide an alternative, -// much more compact, IR for this operation, but LLVM eventually -// generates more elaborate instructions for this intrinsic since it -// is very conservative on the boundary conditions. -static Value buildVectorComparison(ConversionPatternRewriter &rewriter, - Operation *op, bool enableIndexOptimizations, - int64_t dim, Value b, Value *off = nullptr) { - auto loc = op->getLoc(); - // If we can assume all indices fit in 32-bit, we perform the vector - // comparison in 32-bit to get a higher degree of SIMD parallelism. - // Otherwise we perform the vector comparison using 64-bit indices. - Value indices; - Type idxType; - if (enableIndexOptimizations) { - indices = rewriter.create<ConstantOp>( - loc, rewriter.getI32VectorAttr( - llvm::to_vector<4>(llvm::seq<int32_t>(0, dim)))); - idxType = rewriter.getI32Type(); - } else { - indices = rewriter.create<ConstantOp>( - loc, rewriter.getI64VectorAttr( - llvm::to_vector<4>(llvm::seq<int64_t>(0, dim)))); - idxType = rewriter.getI64Type(); - } - // Add in an offset if requested. - if (off) { - Value o = createCastToIndexLike(rewriter, loc, idxType, *off); - Value ov = rewriter.create<SplatOp>(loc, indices.getType(), o); - indices = rewriter.create<AddIOp>(loc, ov, indices); - } - // Construct the vector comparison. - Value bound = createCastToIndexLike(rewriter, loc, idxType, b); - Value bounds = rewriter.create<SplatOp>(loc, indices.getType(), bound); - return rewriter.create<CmpIOp>(loc, CmpIPredicate::slt, indices, bounds); -} - // Helper that returns data layout alignment of a memref. LogicalResult getMemRefAlignment(LLVMTypeConverter &typeConverter, MemRefType memrefType, unsigned &align) { @@ -250,7 +190,7 @@ replaceTransferOpWithLoadOrStore(ConversionPatternRewriter &rewriter, if (failed(getMemRefAlignment( typeConverter, xferOp.getShapedType().cast<MemRefType>(), align))) return failure(); - auto adaptor = TransferWriteOpAdaptor(operands); + auto adaptor = TransferWriteOpAdaptor(operands, xferOp->getAttrDictionary()); rewriter.replaceOpWithNewOp<LLVM::StoreOp>(xferOp, adaptor.vector(), dataPtr, align); return success(); @@ -266,7 +206,7 @@ replaceTransferOpWithMasked(ConversionPatternRewriter &rewriter, typeConverter, xferOp.getShapedType().cast<MemRefType>(), align))) return failure(); - auto adaptor = TransferWriteOpAdaptor(operands); + auto adaptor = TransferWriteOpAdaptor(operands, xferOp->getAttrDictionary()); rewriter.replaceOpWithNewOp<LLVM::MaskedStoreOp>( xferOp, adaptor.vector(), dataPtr, mask, rewriter.getI32IntegerAttr(align)); @@ -275,12 +215,12 @@ replaceTransferOpWithMasked(ConversionPatternRewriter &rewriter, static TransferReadOpAdaptor getTransferOpAdapter(TransferReadOp xferOp, ArrayRef<Value> operands) { - return TransferReadOpAdaptor(operands); + return TransferReadOpAdaptor(operands, xferOp->getAttrDictionary()); } static TransferWriteOpAdaptor getTransferOpAdapter(TransferWriteOp xferOp, ArrayRef<Value> operands) { - return TransferWriteOpAdaptor(operands); + return TransferWriteOpAdaptor(operands, xferOp->getAttrDictionary()); } namespace { @@ -618,33 +558,6 @@ private: const bool reassociateFPReductions; }; -/// Conversion pattern for a vector.create_mask (1-D only). -class VectorCreateMaskOpConversion - : public ConvertOpToLLVMPattern<vector::CreateMaskOp> { -public: - explicit VectorCreateMaskOpConversion(LLVMTypeConverter &typeConv, - bool enableIndexOpt) - : ConvertOpToLLVMPattern<vector::CreateMaskOp>(typeConv), - enableIndexOptimizations(enableIndexOpt) {} - - LogicalResult - matchAndRewrite(vector::CreateMaskOp op, ArrayRef<Value> operands, - ConversionPatternRewriter &rewriter) const override { - auto dstType = op.getType(); - int64_t rank = dstType.getRank(); - if (rank == 1) { - rewriter.replaceOp( - op, buildVectorComparison(rewriter, op, enableIndexOptimizations, - dstType.getDimSize(0), operands[0])); - return success(); - } - return failure(); - } - -private: - const bool enableIndexOptimizations; -}; - class VectorShuffleOpConversion : public ConvertOpToLLVMPattern<vector::ShuffleOp> { public: @@ -1177,20 +1090,12 @@ public: } }; -/// Conversion pattern that converts a 1-D vector transfer read/write op in a -/// sequence of: -/// 1. Get the source/dst address as an LLVM vector pointer. -/// 2. Create a vector with linear indices [ 0 .. vector_length - 1 ]. -/// 3. Create an offsetVector = [ offset + 0 .. offset + vector_length - 1 ]. -/// 4. Create a mask where offsetVector is compared against memref upper bound. -/// 5. Rewrite op as a masked read or write. +/// Conversion pattern that converts a 1-D vector transfer read/write op into a +/// a masked or unmasked read/write. template <typename ConcreteOp> class VectorTransferConversion : public ConvertOpToLLVMPattern<ConcreteOp> { public: - explicit VectorTransferConversion(LLVMTypeConverter &typeConv, - bool enableIndexOpt) - : ConvertOpToLLVMPattern<ConcreteOp>(typeConv), - enableIndexOptimizations(enableIndexOpt) {} + using ConvertOpToLLVMPattern<ConcreteOp>::ConvertOpToLLVMPattern; LogicalResult matchAndRewrite(ConcreteOp xferOp, ArrayRef<Value> operands, @@ -1212,6 +1117,9 @@ public: auto strides = computeContiguousStrides(memRefType); if (!strides) return failure(); + // Out-of-bounds dims are handled by MaterializeTransferMask. + if (xferOp.hasOutOfBoundsDim()) + return failure(); auto toLLVMTy = [&](Type t) { return this->getTypeConverter()->convertType(t); @@ -1241,40 +1149,24 @@ public: #endif // ifndef NDEBUG } - // 1. Get the source/dst address as an LLVM vector pointer. + // Get the source/dst address as an LLVM vector pointer. VectorType vtp = xferOp.getVectorType(); Value dataPtr = this->getStridedElementPtr( loc, memRefType, adaptor.source(), adaptor.indices(), rewriter); Value vectorDataPtr = castDataPtr(rewriter, loc, dataPtr, memRefType, toLLVMTy(vtp)); - if (xferOp.isDimInBounds(0)) + // Rewrite as an unmasked masked read / write. + if (!xferOp.mask()) return replaceTransferOpWithLoadOrStore(rewriter, *this->getTypeConverter(), loc, xferOp, operands, vectorDataPtr); - // 2. Create a vector with linear indices [ 0 .. vector_length - 1 ]. - // 3. Create offsetVector = [ offset + 0 .. offset + vector_length - 1 ]. - // 4. Let dim the memref dimension, compute the vector comparison mask - // (in-bounds mask): - // [ offset + 0 .. offset + vector_length - 1 ] < [ dim .. dim ] - // - // TODO: when the leaf transfer rank is k > 1, we need the last `k` - // dimensions here. - unsigned vecWidth = LLVM::getVectorNumElements(vtp).getFixedValue(); - unsigned lastIndex = llvm::size(xferOp.indices()) - 1; - Value off = xferOp.indices()[lastIndex]; - Value dim = rewriter.create<memref::DimOp>(loc, xferOp.source(), lastIndex); - Value mask = buildVectorComparison( - rewriter, xferOp, enableIndexOptimizations, vecWidth, dim, &off); - - // 5. Rewrite as a masked read / write. + // Rewrite as a masked read / write. return replaceTransferOpWithMasked(rewriter, *this->getTypeConverter(), loc, - xferOp, operands, vectorDataPtr, mask); + xferOp, operands, vectorDataPtr, + xferOp.mask()); } - -private: - const bool enableIndexOptimizations; }; class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> { @@ -1484,17 +1376,13 @@ public: /// Populate the given list with patterns that convert from Vector to LLVM. void mlir::populateVectorToLLVMConversionPatterns( LLVMTypeConverter &converter, RewritePatternSet &patterns, - bool reassociateFPReductions, bool enableIndexOptimizations) { + bool reassociateFPReductions) { MLIRContext *ctx = converter.getDialect()->getContext(); patterns.add<VectorFMAOpNDRewritePattern, VectorInsertStridedSliceOpDifferentRankRewritePattern, VectorInsertStridedSliceOpSameRankRewritePattern, VectorExtractStridedSliceOpConversion>(ctx); patterns.add<VectorReductionOpConversion>(converter, reassociateFPReductions); - patterns.add<VectorCreateMaskOpConversion, - VectorTransferConversion<TransferReadOp>, - VectorTransferConversion<TransferWriteOp>>( - converter, enableIndexOptimizations); patterns .add<VectorBitCastOpConversion, VectorShuffleOpConversion, VectorExtractElementOpConversion, VectorExtractOpConversion, @@ -1508,8 +1396,9 @@ void mlir::populateVectorToLLVMConversionPatterns( VectorLoadStoreConversion<vector::MaskedStoreOp, vector::MaskedStoreOpAdaptor>, VectorGatherOpConversion, VectorScatterOpConversion, - VectorExpandLoadOpConversion, VectorCompressStoreOpConversion>( - converter); + VectorExpandLoadOpConversion, VectorCompressStoreOpConversion, + VectorTransferConversion<TransferReadOp>, + VectorTransferConversion<TransferWriteOp>>(converter); } void mlir::populateVectorToLLVMMatrixConversionPatterns( diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp index abddcd73af1e..49ee670b2f06 100644 --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp @@ -71,9 +71,10 @@ void LowerVectorToLLVMPass::runOnOperation() { // Convert to the LLVM IR dialect. LLVMTypeConverter converter(&getContext()); RewritePatternSet patterns(&getContext()); + populateVectorMaskMaterializationPatterns(patterns, enableIndexOptimizations); populateVectorToLLVMMatrixConversionPatterns(converter, patterns); - populateVectorToLLVMConversionPatterns( - converter, patterns, reassociateFPReductions, enableIndexOptimizations); + populateVectorToLLVMConversionPatterns(converter, patterns, + reassociateFPReductions); populateVectorToLLVMMatrixConversionPatterns(converter, patterns); // Architecture specific augmentations. |