aboutsummaryrefslogtreecommitdiff
path: root/mlir/lib/Conversion/VectorToLLVM
diff options
context:
space:
mode:
author River Riddle <riddleriver@gmail.com> 2022-01-25 15:51:05 -0800
committer River Riddle <riddleriver@gmail.com> 2022-02-02 14:45:12 -0800
commit 6a8ba3186ed561bd5ac6aa31ba483d81223bf198 (patch)
tree 6543539be1a30a0025667bba86d4e8c66b4d6df7 /mlir/lib/Conversion/VectorToLLVM
parent f7a6c341cb936991eb3ccac3be25b02fecf7a4b8 (diff)
[mlir] Split std.splat into tensor.splat and vector.splat
This is part of the larger effort to split the standard dialect. This will also allow for pruning some additional dependencies on Standard (done in a followup). Differential Revision: https://reviews.llvm.org/D118202
Diffstat (limited to 'mlir/lib/Conversion/VectorToLLVM')
-rw-r--r-- mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp 99
1 files changed, 96 insertions, 3 deletions
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 664947211995..80f50a3996e3 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -778,7 +778,7 @@ public:
auto elemType = vType.getElementType();
Value zero = rewriter.create<arith::ConstantOp>(
loc, elemType, rewriter.getZeroAttr(elemType));
- Value desc = rewriter.create<SplatOp>(loc, vType, zero);
+ Value desc = rewriter.create<vector::SplatOp>(loc, vType, zero);
for (int64_t i = 0, e = vType.getShape().front(); i != e; ++i) {
Value extrLHS = rewriter.create<ExtractOp>(loc, op.lhs(), i);
Value extrRHS = rewriter.create<ExtractOp>(loc, op.rhs(), i);
@@ -1062,6 +1062,99 @@ private:
}
};
+/// Lowers a 0-d or 1-d vector.splat to an LLVM insertelement, followed by a
+/// shufflevector in the 1-d case. Results of rank > 1 fail to match here.
+struct VectorSplatOpLowering : public ConvertOpToLLVMPattern<vector::SplatOp> {
+ using ConvertOpToLLVMPattern<vector::SplatOp>::ConvertOpToLLVMPattern;
+
+ LogicalResult
+ matchAndRewrite(vector::SplatOp splatOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ VectorType resultType = splatOp.getType().cast<VectorType>();
+ if (resultType.getRank() > 1)
+ return failure();
+
+ // Build the undef vector and the i32 zero index used to insert the input.
+ auto vectorType = typeConverter->convertType(splatOp.getType());
+ Value undef = rewriter.create<LLVM::UndefOp>(splatOp.getLoc(), vectorType);
+ auto zero = rewriter.create<LLVM::ConstantOp>(
+ splatOp.getLoc(),
+ typeConverter->convertType(rewriter.getIntegerType(32)),
+ rewriter.getZeroAttr(rewriter.getIntegerType(32)));
+
+ // A 0-d vector only needs the `insertelement`; no shuffle is required.
+ if (resultType.getRank() == 0) {
+ rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
+ splatOp, vectorType, undef, adaptor.input(), zero);
+ return success();
+ }
+
+ // For a 1-d vector, additionally broadcast lane 0 with a `shufflevector`.
+ auto v = rewriter.create<LLVM::InsertElementOp>(
+ splatOp.getLoc(), vectorType, undef, adaptor.input(), zero);
+
+ int64_t width = splatOp.getType().cast<VectorType>().getDimSize(0);
+ SmallVector<int32_t, 4> zeroValues(width, 0);
+
+ // An all-zero mask of `width` entries copies element 0 into every lane.
+ ArrayAttr zeroAttrs = rewriter.getI32ArrayAttr(zeroValues);
+ rewriter.replaceOpWithNewOp<LLVM::ShuffleVectorOp>(splatOp, v, undef,
+ zeroAttrs);
+ return success();
+ }
+};
+
+/// Lowers a splat op whose result rank is 2 or more. A splatted 1-D vector is
+/// built with insertelement + shufflevector and inserted at every position
+/// visited by nDVectorIterate; ranks 0 and 1 go to VectorSplatOpLowering.
+struct VectorSplatNdOpLowering : public ConvertOpToLLVMPattern<SplatOp> {
+ using ConvertOpToLLVMPattern<SplatOp>::ConvertOpToLLVMPattern;
+
+ LogicalResult
+ matchAndRewrite(SplatOp splatOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ VectorType resultType = splatOp.getType();
+ if (resultType.getRank() <= 1)
+ return failure();
+
+ // Resolve the converted n-D and 1-D LLVM types; bail out if either is null.
+ auto loc = splatOp.getLoc();
+ auto vectorTypeInfo =
+ LLVM::detail::extractNDVectorTypeInfo(resultType, *getTypeConverter());
+ auto llvmNDVectorTy = vectorTypeInfo.llvmNDVectorTy;
+ auto llvm1DVectorTy = vectorTypeInfo.llvm1DVectorTy;
+ if (!llvmNDVectorTy || !llvm1DVectorTy)
+ return failure();
+
+ // Start the result as an undef value of the converted n-D type.
+ Value desc = rewriter.create<LLVM::UndefOp>(loc, llvmNDVectorTy);
+
+ // Build a 1-D vector holding the input at index 0; it is broadcast below and
+ // then inserted at each position of the result.
+ Value vdesc = rewriter.create<LLVM::UndefOp>(loc, llvm1DVectorTy);
+ auto zero = rewriter.create<LLVM::ConstantOp>(
+ loc, typeConverter->convertType(rewriter.getIntegerType(32)),
+ rewriter.getZeroAttr(rewriter.getIntegerType(32)));
+ Value v = rewriter.create<LLVM::InsertElementOp>(loc, llvm1DVectorTy, vdesc,
+ adaptor.input(), zero);
+
+ // An all-zero mask as wide as the last dimension splats element 0.
+ int64_t width = resultType.getDimSize(resultType.getRank() - 1);
+ SmallVector<int32_t, 4> zeroValues(width, 0);
+ ArrayAttr zeroAttrs = rewriter.getI32ArrayAttr(zeroValues);
+ v = rewriter.create<LLVM::ShuffleVectorOp>(loc, v, v, zeroAttrs);
+
+ // Iterate over the linear index, convert it to coordinate space, and insert
+ // the splatted 1-D vector at each position.
+ nDVectorIterate(vectorTypeInfo, rewriter, [&](ArrayAttr position) {
+ desc = rewriter.create<LLVM::InsertValueOp>(loc, llvmNDVectorTy, desc, v,
+ position);
+ });
+ rewriter.replaceOp(splatOp, desc);
+ return success();
+ }
+};
+
} // namespace
/// Populate the given list with patterns that convert from Vector to LLVM.
@@ -1085,8 +1178,8 @@ void mlir::populateVectorToLLVMConversionPatterns(
VectorLoadStoreConversion<vector::MaskedStoreOp,
vector::MaskedStoreOpAdaptor>,
VectorGatherOpConversion, VectorScatterOpConversion,
- VectorExpandLoadOpConversion, VectorCompressStoreOpConversion>(
- converter);
+ VectorExpandLoadOpConversion, VectorCompressStoreOpConversion,
+ VectorSplatOpLowering, VectorSplatNdOpLowering>(converter);
// Transfer ops with rank > 1 are handled by VectorToSCF.
populateVectorTransferLoweringPatterns(patterns, /*maxTransferRank=*/1);
}