aboutsummaryrefslogtreecommitdiff
path: root/mlir/lib/Conversion/VectorToLLVM
diff options
context:
space:
mode:
authorJavier Setoain <javier.setoain@gmail.com>2021-10-12 14:26:01 +0100
committerJavier Setoain <javier.setoain@gmail.com>2021-12-15 09:31:37 +0000
commita4830d14edbb2a21eb35f3d79d1f64bd09db8b1c (patch)
treefe5673196655da8dad4cbb1e4210a80772ac008c /mlir/lib/Conversion/VectorToLLVM
parent7161aa06ef53d3fc0ce30be77e932f3d30c68466 (diff)
[mlir][RFC] Add scalable dimensions to VectorType
With VectorType supporting scalable dimensions, we don't need many of the operations currently present in ArmSVE, like mask generation and basic arithmetic instructions. Therefore, this patch also gets rid of those. Having built-in scalable vector support also simplifies the lowering of scalable vector dialects down to LLVMIR. Scalable dimensions are indicated with the scalable dimensions between square brackets: vector<[4]xf32> Is a scalable vector of 4 single precission floating point elements. More generally, a VectorType can have a set of fixed-length dimensions followed by a set of scalable dimensions: vector<2x[4x4]xf32> Is a vector with 2 scalable 4x4 vectors of single precission floating point elements. The scale of the scalable dimensions can be obtained with the Vector operation: %vs = vector.vscale This change is being discussed in the discourse RFC: https://llvm.discourse.group/t/rfc-add-built-in-support-for-scalable-vector-types/4484 Differential Revision: https://reviews.llvm.org/D111819
Diffstat (limited to 'mlir/lib/Conversion/VectorToLLVM')
-rw-r--r--mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp18
1 files changed, 15 insertions, 3 deletions
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 9b4dce458b7a..062a54432cea 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -26,13 +26,21 @@ using namespace mlir::vector;
// Helper to reduce vector type by one rank at front.
static VectorType reducedVectorTypeFront(VectorType tp) {
assert((tp.getRank() > 1) && "unlowerable vector type");
- return VectorType::get(tp.getShape().drop_front(), tp.getElementType());
+ unsigned numScalableDims = tp.getNumScalableDims();
+ if (tp.getShape().size() == numScalableDims)
+ --numScalableDims;
+ return VectorType::get(tp.getShape().drop_front(), tp.getElementType(),
+ numScalableDims);
}
// Helper to reduce vector type by *all* but one rank at back.
static VectorType reducedVectorTypeBack(VectorType tp) {
assert((tp.getRank() > 1) && "unlowerable vector type");
- return VectorType::get(tp.getShape().take_back(), tp.getElementType());
+ unsigned numScalableDims = tp.getNumScalableDims();
+ if (numScalableDims > 0)
+ --numScalableDims;
+ return VectorType::get(tp.getShape().take_back(), tp.getElementType(),
+ numScalableDims);
}
// Helper that picks the proper sequence for inserting.
@@ -112,6 +120,10 @@ static Value castDataPtr(ConversionPatternRewriter &rewriter, Location loc,
namespace {
+/// Trivial Vector to LLVM conversions
+using VectorScaleOpConversion =
+ OneToOneConvertToLLVMPattern<vector::VectorScaleOp, LLVM::vscale>;
+
/// Conversion pattern for a vector.bitcast.
class VectorBitCastOpConversion
: public ConvertOpToLLVMPattern<vector::BitCastOp> {
@@ -1064,7 +1076,7 @@ void mlir::populateVectorToLLVMConversionPatterns(
VectorExtractElementOpConversion, VectorExtractOpConversion,
VectorFMAOp1DConversion, VectorInsertElementOpConversion,
VectorInsertOpConversion, VectorPrintOpConversion,
- VectorTypeCastOpConversion,
+ VectorTypeCastOpConversion, VectorScaleOpConversion,
VectorLoadStoreConversion<vector::LoadOp, vector::LoadOpAdaptor>,
VectorLoadStoreConversion<vector::MaskedLoadOp,
vector::MaskedLoadOpAdaptor>,