aboutsummaryrefslogtreecommitdiff
path: root/mlir/lib/Conversion/VectorToLLVM
diff options
context:
space:
mode:
authorMogball <jeffniu22@gmail.com>2021-10-12 23:14:57 +0000
committerMogball <jeffniu22@gmail.com>2021-10-13 03:07:03 +0000
commita54f4eae0e1d0ef5adccdcf9f6c2b518dc1101aa (patch)
treef4478fb873d88f382b2c3dbdadf254e68faa7244 /mlir/lib/Conversion/VectorToLLVM
parent666accf283311c5110ae4e2e5e4c4b99078eed15 (diff)
[MLIR] Replace std ops with arith dialect ops
Precursor: https://reviews.llvm.org/D110200 Removed redundant ops from the standard dialect that were moved to the `arith` or `math` dialects. Renamed all instances of operations in the codebase and in tests. Reviewed By: rriddle, jpienaar Differential Revision: https://reviews.llvm.org/D110797
Diffstat (limited to 'mlir/lib/Conversion/VectorToLLVM')
-rw-r--r--mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt1
-rw-r--r--mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp17
-rw-r--r--mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp3
3 files changed, 13 insertions, 8 deletions
diff --git a/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt b/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt
index 9a7c5aabc9a7..9f3fc6e826e1 100644
--- a/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt
+++ b/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt
@@ -13,6 +13,7 @@ add_mlir_conversion_library(MLIRVectorToLLVM
Core
LINK_LIBS PUBLIC
+ MLIRArithmetic
MLIRArmNeon
MLIRArmSVE
MLIRArmSVETransforms
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index f5ba71726af7..765b58d8c3d2 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -9,6 +9,7 @@
#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
#include "mlir/Conversion/LLVMCommon/VectorPattern.h"
+#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
@@ -59,7 +60,7 @@ static Value insertOne(PatternRewriter &rewriter, Location loc, Value from,
return rewriter.create<InsertOp>(loc, from, into, offset);
return rewriter.create<vector::InsertElementOp>(
loc, vectorType, from, into,
- rewriter.create<ConstantIndexOp>(loc, offset));
+ rewriter.create<arith::ConstantIndexOp>(loc, offset));
}
// Helper that picks the proper sequence for extracting.
@@ -86,7 +87,7 @@ static Value extractOne(PatternRewriter &rewriter, Location loc, Value vector,
return rewriter.create<ExtractOp>(loc, vector, offset);
return rewriter.create<vector::ExtractElementOp>(
loc, vectorType.getElementType(), vector,
- rewriter.create<ConstantIndexOp>(loc, offset));
+ rewriter.create<arith::ConstantIndexOp>(loc, offset));
}
// Helper that returns a subset of `arrayAttr` as a vector of int64_t.
@@ -797,8 +798,8 @@ public:
auto loc = op.getLoc();
auto elemType = vType.getElementType();
- Value zero = rewriter.create<ConstantOp>(loc, elemType,
- rewriter.getZeroAttr(elemType));
+ Value zero = rewriter.create<arith::ConstantOp>(
+ loc, elemType, rewriter.getZeroAttr(elemType));
Value desc = rewriter.create<SplatOp>(loc, vType, zero);
for (int64_t i = 0, e = vType.getShape().front(); i != e; ++i) {
Value extrLHS = rewriter.create<ExtractOp>(loc, op.lhs(), i);
@@ -1146,11 +1147,11 @@ private:
if (rank == 0) {
switch (conversion) {
case PrintConversion::ZeroExt64:
- value = rewriter.create<ZeroExtendIOp>(
+ value = rewriter.create<arith::ExtUIOp>(
loc, value, IntegerType::get(rewriter.getContext(), 64));
break;
case PrintConversion::SignExt64:
- value = rewriter.create<SignExtendIOp>(
+ value = rewriter.create<arith::ExtSIOp>(
loc, value, IntegerType::get(rewriter.getContext(), 64));
break;
case PrintConversion::None:
@@ -1233,8 +1234,8 @@ public:
}
// Extract/insert on a lower ranked extract strided slice op.
- Value zero = rewriter.create<ConstantOp>(loc, elemType,
- rewriter.getZeroAttr(elemType));
+ Value zero = rewriter.create<arith::ConstantOp>(
+ loc, elemType, rewriter.getZeroAttr(elemType));
Value res = rewriter.create<SplatOp>(loc, dstType, zero);
for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e;
off += stride, ++idx) {
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
index d920bb7b0f9a..583ba4a13eb0 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
@@ -14,6 +14,7 @@
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Dialect/AMX/AMXDialect.h"
#include "mlir/Dialect/AMX/Transforms.h"
+#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/ArmNeon/ArmNeonDialect.h"
#include "mlir/Dialect/ArmSVE/ArmSVEDialect.h"
#include "mlir/Dialect/ArmSVE/Transforms.h"
@@ -42,6 +43,7 @@ struct LowerVectorToLLVMPass
// Override explicitly to allow conditional dialect dependence.
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<LLVM::LLVMDialect>();
+ registry.insert<arith::ArithmeticDialect>();
registry.insert<memref::MemRefDialect>();
if (enableArmNeon)
registry.insert<arm_neon::ArmNeonDialect>();
@@ -84,6 +86,7 @@ void LowerVectorToLLVMPass::runOnOperation() {
// Architecture specific augmentations.
LLVMConversionTarget target(getContext());
+ target.addLegalDialect<arith::ArithmeticDialect>();
target.addLegalDialect<memref::MemRefDialect>();
target.addLegalDialect<StandardOpsDialect>();
target.addLegalOp<UnrealizedConversionCastOp>();