diff options
author | Javier Setoain <javier.setoain@gmail.com> | 2021-10-12 14:26:01 +0100 |
---|---|---|
committer | Javier Setoain <javier.setoain@gmail.com> | 2021-12-15 09:31:37 +0000 |
commit | a4830d14edbb2a21eb35f3d79d1f64bd09db8b1c (patch) | |
tree | fe5673196655da8dad4cbb1e4210a80772ac008c /mlir/lib/Conversion/VectorToLLVM | |
parent | 7161aa06ef53d3fc0ce30be77e932f3d30c68466 (diff) |
[mlir][RFC] Add scalable dimensions to VectorType
With VectorType supporting scalable dimensions, we don't need many of
the operations currently present in ArmSVE, like mask generation and
basic arithmetic instructions. Therefore, this patch also gets
rid of those.
Having built-in scalable vector support also simplifies the lowering of
scalable vector dialects down to LLVMIR.
Scalable dimensions are indicated with the scalable dimensions
between square brackets:
vector<[4]xf32>
Is a scalable vector of 4 single precission floating point elements.
More generally, a VectorType can have a set of fixed-length dimensions
followed by a set of scalable dimensions:
vector<2x[4x4]xf32>
Is a vector with 2 scalable 4x4 vectors of single precission floating
point elements.
The scale of the scalable dimensions can be obtained with the Vector
operation:
%vs = vector.vscale
This change is being discussed in the discourse RFC:
https://llvm.discourse.group/t/rfc-add-built-in-support-for-scalable-vector-types/4484
Differential Revision: https://reviews.llvm.org/D111819
Diffstat (limited to 'mlir/lib/Conversion/VectorToLLVM')
-rw-r--r-- | mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp | 18 |
1 files changed, 15 insertions, 3 deletions
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp index 9b4dce458b7a..062a54432cea 100644 --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp @@ -26,13 +26,21 @@ using namespace mlir::vector; // Helper to reduce vector type by one rank at front. static VectorType reducedVectorTypeFront(VectorType tp) { assert((tp.getRank() > 1) && "unlowerable vector type"); - return VectorType::get(tp.getShape().drop_front(), tp.getElementType()); + unsigned numScalableDims = tp.getNumScalableDims(); + if (tp.getShape().size() == numScalableDims) + --numScalableDims; + return VectorType::get(tp.getShape().drop_front(), tp.getElementType(), + numScalableDims); } // Helper to reduce vector type by *all* but one rank at back. static VectorType reducedVectorTypeBack(VectorType tp) { assert((tp.getRank() > 1) && "unlowerable vector type"); - return VectorType::get(tp.getShape().take_back(), tp.getElementType()); + unsigned numScalableDims = tp.getNumScalableDims(); + if (numScalableDims > 0) + --numScalableDims; + return VectorType::get(tp.getShape().take_back(), tp.getElementType(), + numScalableDims); } // Helper that picks the proper sequence for inserting. @@ -112,6 +120,10 @@ static Value castDataPtr(ConversionPatternRewriter &rewriter, Location loc, namespace { +/// Trivial Vector to LLVM conversions +using VectorScaleOpConversion = + OneToOneConvertToLLVMPattern<vector::VectorScaleOp, LLVM::vscale>; + /// Conversion pattern for a vector.bitcast. class VectorBitCastOpConversion : public ConvertOpToLLVMPattern<vector::BitCastOp> { @@ -1064,7 +1076,7 @@ void mlir::populateVectorToLLVMConversionPatterns( VectorExtractElementOpConversion, VectorExtractOpConversion, VectorFMAOp1DConversion, VectorInsertElementOpConversion, VectorInsertOpConversion, VectorPrintOpConversion, - VectorTypeCastOpConversion, + VectorTypeCastOpConversion, VectorScaleOpConversion, VectorLoadStoreConversion<vector::LoadOp, vector::LoadOpAdaptor>, VectorLoadStoreConversion<vector::MaskedLoadOp, vector::MaskedLoadOpAdaptor>, |