diff options
author | River Riddle <riddleriver@gmail.com> | 2022-01-25 15:51:05 -0800 |
---|---|---|
committer | River Riddle <riddleriver@gmail.com> | 2022-02-02 14:45:12 -0800 |
commit | 6a8ba3186ed561bd5ac6aa31ba483d81223bf198 (patch) | |
tree | 6543539be1a30a0025667bba86d4e8c66b4d6df7 /mlir/lib/Conversion/VectorToLLVM | |
parent | f7a6c341cb936991eb3ccac3be25b02fecf7a4b8 (diff) |
[mlir] Split std.splat into tensor.splat and vector.splat
This is part of the larger effort to split the standard dialect. This will also allow for pruning some
additional dependencies on Standard (done in a followup).
Differential Revision: https://reviews.llvm.org/D118202
Diffstat (limited to 'mlir/lib/Conversion/VectorToLLVM')
-rw-r--r-- | mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp | 99 |
1 files changed, 96 insertions, 3 deletions
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp index 664947211995..80f50a3996e3 100644 --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp @@ -778,7 +778,7 @@ public: auto elemType = vType.getElementType(); Value zero = rewriter.create<arith::ConstantOp>( loc, elemType, rewriter.getZeroAttr(elemType)); - Value desc = rewriter.create<SplatOp>(loc, vType, zero); + Value desc = rewriter.create<vector::SplatOp>(loc, vType, zero); for (int64_t i = 0, e = vType.getShape().front(); i != e; ++i) { Value extrLHS = rewriter.create<ExtractOp>(loc, op.lhs(), i); Value extrRHS = rewriter.create<ExtractOp>(loc, op.rhs(), i); @@ -1062,6 +1062,99 @@ private: } }; +/// The Splat operation is lowered to an insertelement + a shufflevector +/// operation. Splat to only 0-d and 1-d vector result types are lowered. +struct VectorSplatOpLowering : public ConvertOpToLLVMPattern<vector::SplatOp> { + using ConvertOpToLLVMPattern<vector::SplatOp>::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(vector::SplatOp splatOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + VectorType resultType = splatOp.getType().cast<VectorType>(); + if (resultType.getRank() > 1) + return failure(); + + // First insert it into an undef vector so we can shuffle it. + auto vectorType = typeConverter->convertType(splatOp.getType()); + Value undef = rewriter.create<LLVM::UndefOp>(splatOp.getLoc(), vectorType); + auto zero = rewriter.create<LLVM::ConstantOp>( + splatOp.getLoc(), + typeConverter->convertType(rewriter.getIntegerType(32)), + rewriter.getZeroAttr(rewriter.getIntegerType(32))); + + // For 0-d vector, we simply do `insertelement`. + if (resultType.getRank() == 0) { + rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>( + splatOp, vectorType, undef, adaptor.input(), zero); + return success(); + } + + // For 1-d vector, we additionally do a `vectorshuffle`. + auto v = rewriter.create<LLVM::InsertElementOp>( + splatOp.getLoc(), vectorType, undef, adaptor.input(), zero); + + int64_t width = splatOp.getType().cast<VectorType>().getDimSize(0); + SmallVector<int32_t, 4> zeroValues(width, 0); + + // Shuffle the value across the desired number of elements. + ArrayAttr zeroAttrs = rewriter.getI32ArrayAttr(zeroValues); + rewriter.replaceOpWithNewOp<LLVM::ShuffleVectorOp>(splatOp, v, undef, + zeroAttrs); + return success(); + } +}; + +/// The Splat operation is lowered to an insertelement + a shufflevector +/// operation. Splat to only 2+-d vector result types are lowered by the +/// SplatNdOpLowering, the 1-d case is handled by SplatOpLowering. +struct VectorSplatNdOpLowering : public ConvertOpToLLVMPattern<SplatOp> { + using ConvertOpToLLVMPattern<SplatOp>::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(SplatOp splatOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + VectorType resultType = splatOp.getType(); + if (resultType.getRank() <= 1) + return failure(); + + // First insert it into an undef vector so we can shuffle it. + auto loc = splatOp.getLoc(); + auto vectorTypeInfo = + LLVM::detail::extractNDVectorTypeInfo(resultType, *getTypeConverter()); + auto llvmNDVectorTy = vectorTypeInfo.llvmNDVectorTy; + auto llvm1DVectorTy = vectorTypeInfo.llvm1DVectorTy; + if (!llvmNDVectorTy || !llvm1DVectorTy) + return failure(); + + // Construct returned value. + Value desc = rewriter.create<LLVM::UndefOp>(loc, llvmNDVectorTy); + + // Construct a 1-D vector with the splatted value that we insert in all the + // places within the returned descriptor. + Value vdesc = rewriter.create<LLVM::UndefOp>(loc, llvm1DVectorTy); + auto zero = rewriter.create<LLVM::ConstantOp>( + loc, typeConverter->convertType(rewriter.getIntegerType(32)), + rewriter.getZeroAttr(rewriter.getIntegerType(32))); + Value v = rewriter.create<LLVM::InsertElementOp>(loc, llvm1DVectorTy, vdesc, + adaptor.input(), zero); + + // Shuffle the value across the desired number of elements. + int64_t width = resultType.getDimSize(resultType.getRank() - 1); + SmallVector<int32_t, 4> zeroValues(width, 0); + ArrayAttr zeroAttrs = rewriter.getI32ArrayAttr(zeroValues); + v = rewriter.create<LLVM::ShuffleVectorOp>(loc, v, v, zeroAttrs); + + // Iterate of linear index, convert to coords space and insert splatted 1-D + // vector in each position. + nDVectorIterate(vectorTypeInfo, rewriter, [&](ArrayAttr position) { + desc = rewriter.create<LLVM::InsertValueOp>(loc, llvmNDVectorTy, desc, v, + position); + }); + rewriter.replaceOp(splatOp, desc); + return success(); + } +}; + } // namespace /// Populate the given list with patterns that convert from Vector to LLVM. @@ -1085,8 +1178,8 @@ void mlir::populateVectorToLLVMConversionPatterns( VectorLoadStoreConversion<vector::MaskedStoreOp, vector::MaskedStoreOpAdaptor>, VectorGatherOpConversion, VectorScatterOpConversion, - VectorExpandLoadOpConversion, VectorCompressStoreOpConversion>( - converter); + VectorExpandLoadOpConversion, VectorCompressStoreOpConversion, + VectorSplatOpLowering, VectorSplatNdOpLowering>(converter); // Transfer ops with rank > 1 are handled by VectorToSCF. populateVectorTransferLoweringPatterns(patterns, /*maxTransferRank=*/1); } |