aboutsummaryrefslogtreecommitdiff
path: root/mlir/lib
diff options
context:
space:
mode:
authorMarkus Böck <markus.boeck02@gmail.com>2022-07-29 01:00:22 +0200
committerMarkus Böck <markus.boeck02@gmail.com>2022-08-01 17:22:55 +0200
commitbd7eff1f2a7462ffbebc6beb8c7a3fecb1c39350 (patch)
tree9abce0ee8dc8f1d871af1e3063152c7d4a018c9e /mlir/lib
parentcb5d0b41baf2f137f377a8d03481d6a5574a31ec (diff)
[mlir][flang] Make use of the new `GEPArg` builder of GEP Op to simplify code
This is the follow up on https://reviews.llvm.org/D130730 which goes through upstream code and removes creating constant values in favour of using the constant indices in GEP directly. This leads to less and more readable code and more compact IR as well. Differential Revision: https://reviews.llvm.org/D130731
Diffstat (limited to 'mlir/lib')
-rw-r--r--mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp4
-rw-r--r--mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp17
-rw-r--r--mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp15
-rw-r--r--mlir/lib/Conversion/LLVMCommon/MemRefBuilder.cpp45
-rw-r--r--mlir/lib/Conversion/LLVMCommon/Pattern.cpp9
-rw-r--r--mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp22
-rw-r--r--mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp5
7 files changed, 38 insertions, 79 deletions
diff --git a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
index 159a726cd919..b923b8cec1e2 100644
--- a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
+++ b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
@@ -567,10 +567,8 @@ public:
// %Size = getelementptr %T* null, int 1
// %SizeI = ptrtoint %T* %Size to i64
auto nullPtr = rewriter.create<LLVM::NullOp>(loc, storagePtrType);
- auto one = rewriter.create<LLVM::ConstantOp>(
- loc, i64, rewriter.getI64IntegerAttr(1));
auto gep = rewriter.create<LLVM::GEPOp>(loc, storagePtrType, nullPtr,
- one.getResult());
+ ArrayRef<LLVM::GEPArg>{1});
return rewriter.create<LLVM::PtrToIntOp>(loc, i64, gep);
};
diff --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
index 85d1a5234b8f..e115d9cd71da 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
@@ -82,12 +82,7 @@ GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor,
// Rewrite workgroup memory attributions to addresses of global buffers.
rewriter.setInsertionPointToStart(&gpuFuncOp.front());
unsigned numProperArguments = gpuFuncOp.getNumArguments();
- auto i32Type = IntegerType::get(rewriter.getContext(), 32);
- Value zero = nullptr;
- if (!workgroupBuffers.empty())
- zero = rewriter.create<LLVM::ConstantOp>(loc, i32Type,
- rewriter.getI32IntegerAttr(0));
for (const auto &en : llvm::enumerate(workgroupBuffers)) {
LLVM::GlobalOp global = en.value();
Value address = rewriter.create<LLVM::AddressOfOp>(loc, global);
@@ -95,7 +90,7 @@ GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor,
global.getType().cast<LLVM::LLVMArrayType>().getElementType();
Value memory = rewriter.create<LLVM::GEPOp>(
loc, LLVM::LLVMPointerType::get(elementType, global.getAddrSpace()),
- address, ArrayRef<Value>{zero, zero});
+ address, ArrayRef<LLVM::GEPArg>{0, 0});
// Build a memref descriptor pointing to the buffer to plug with the
// existing memref infrastructure. This may use more registers than
@@ -170,7 +165,6 @@ LogicalResult GPUPrintfOpToHIPLowering::matchAndRewrite(
mlir::Type llvmI8 = typeConverter->convertType(rewriter.getI8Type());
mlir::Type i8Ptr = LLVM::LLVMPointerType::get(llvmI8);
- mlir::Type llvmIndex = typeConverter->convertType(rewriter.getIndexType());
mlir::Type llvmI32 = typeConverter->convertType(rewriter.getI32Type());
mlir::Type llvmI64 = typeConverter->convertType(rewriter.getI64Type());
// Note: this is the GPUModule op, not the ModuleOp that surrounds it
@@ -226,10 +220,8 @@ LogicalResult GPUPrintfOpToHIPLowering::matchAndRewrite(
// Get a pointer to the format string's first element and pass it to printf()
Value globalPtr = rewriter.create<LLVM::AddressOfOp>(loc, global);
- Value zero = rewriter.create<LLVM::ConstantOp>(
- loc, llvmIndex, rewriter.getIntegerAttr(llvmIndex, 0));
Value stringStart = rewriter.create<LLVM::GEPOp>(
- loc, i8Ptr, globalPtr, mlir::ValueRange({zero, zero}));
+ loc, i8Ptr, globalPtr, ArrayRef<LLVM::GEPArg>{0, 0});
Value stringLen = rewriter.create<LLVM::ConstantOp>(
loc, llvmI64, rewriter.getI64IntegerAttr(formatStringSize));
@@ -289,7 +281,6 @@ LogicalResult GPUPrintfOpToLLVMCallLowering::matchAndRewrite(
mlir::Type llvmI8 = typeConverter->convertType(rewriter.getIntegerType(8));
mlir::Type i8Ptr = LLVM::LLVMPointerType::get(llvmI8, addressSpace);
- mlir::Type llvmIndex = typeConverter->convertType(rewriter.getIndexType());
// Note: this is the GPUModule op, not the ModuleOp that surrounds it
// This ensures that global constants and declarations are placed within
@@ -325,10 +316,8 @@ LogicalResult GPUPrintfOpToLLVMCallLowering::matchAndRewrite(
// Get a pointer to the format string's first element
Value globalPtr = rewriter.create<LLVM::AddressOfOp>(loc, global);
- Value zero = rewriter.create<LLVM::ConstantOp>(
- loc, llvmIndex, rewriter.getIntegerAttr(llvmIndex, 0));
Value stringStart = rewriter.create<LLVM::GEPOp>(
- loc, i8Ptr, globalPtr, mlir::ValueRange({zero, zero}));
+ loc, i8Ptr, globalPtr, ArrayRef<LLVM::GEPArg>{0, 0});
// Construct arguments and function call
auto argsRange = adaptor.args();
diff --git a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
index 529efab55892..aaee8211ffd8 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
@@ -653,17 +653,14 @@ Value ConvertLaunchFuncOpToGpuRuntimeCallPattern::generateParamsArray(
loc, llvmInt32Type, builder.getI32IntegerAttr(numArguments));
auto arrayPtr = builder.create<LLVM::AllocaOp>(loc, llvmPointerPointerType,
arraySize, /*alignment=*/0);
- auto zero = builder.create<LLVM::ConstantOp>(loc, llvmInt32Type,
- builder.getI32IntegerAttr(0));
for (const auto &en : llvm::enumerate(arguments)) {
- auto index = builder.create<LLVM::ConstantOp>(
- loc, llvmInt32Type, builder.getI32IntegerAttr(en.index()));
auto fieldPtr = builder.create<LLVM::GEPOp>(
loc, LLVM::LLVMPointerType::get(argumentTypes[en.index()]), structPtr,
- ArrayRef<Value>{zero, index.getResult()});
+ ArrayRef<LLVM::GEPArg>{0, en.index()});
builder.create<LLVM::StoreOp>(loc, en.value(), fieldPtr);
- auto elementPtr = builder.create<LLVM::GEPOp>(loc, llvmPointerPointerType,
- arrayPtr, index.getResult());
+ auto elementPtr =
+ builder.create<LLVM::GEPOp>(loc, llvmPointerPointerType, arrayPtr,
+ ArrayRef<LLVM::GEPArg>{en.index()});
auto casted =
builder.create<LLVM::BitcastOp>(loc, llvmPointerType, fieldPtr);
builder.create<LLVM::StoreOp>(loc, casted, elementPtr);
@@ -811,8 +808,8 @@ LogicalResult ConvertMemcpyOpToGpuRuntimeCallPattern::matchAndRewrite(
Type elementPtrType = getElementPtrType(memRefType);
Value nullPtr = rewriter.create<LLVM::NullOp>(loc, elementPtrType);
- Value gepPtr = rewriter.create<LLVM::GEPOp>(loc, elementPtrType, nullPtr,
- ArrayRef<Value>{numElements});
+ Value gepPtr =
+ rewriter.create<LLVM::GEPOp>(loc, elementPtrType, nullPtr, numElements);
auto sizeBytes =
rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), gepPtr);
diff --git a/mlir/lib/Conversion/LLVMCommon/MemRefBuilder.cpp b/mlir/lib/Conversion/LLVMCommon/MemRefBuilder.cpp
index c9b25f738a00..df21c4281035 100644
--- a/mlir/lib/Conversion/LLVMCommon/MemRefBuilder.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/MemRefBuilder.cpp
@@ -138,7 +138,6 @@ Value MemRefDescriptor::size(OpBuilder &builder, Location loc, Value pos,
auto arrayPtrTy = LLVM::LLVMPointerType::get(arrayTy);
// Copy size values to stack-allocated memory.
- auto zero = createIndexAttrConstant(builder, loc, indexType, 0);
auto one = createIndexAttrConstant(builder, loc, indexType, 1);
auto sizes = builder.create<LLVM::ExtractValueOp>(
loc, arrayTy, value,
@@ -149,7 +148,7 @@ Value MemRefDescriptor::size(OpBuilder &builder, Location loc, Value pos,
// Load an return size value of interest.
auto resultPtr = builder.create<LLVM::GEPOp>(loc, indexPtrTy, sizesPtr,
- ValueRange({zero, pos}));
+ ArrayRef<LLVM::GEPArg>{0, pos});
return builder.create<LLVM::LoadOp>(loc, resultPtr);
}
@@ -402,10 +401,8 @@ Value UnrankedMemRefDescriptor::alignedPtr(OpBuilder &builder, Location loc,
Value elementPtrPtr =
builder.create<LLVM::BitcastOp>(loc, elemPtrPtrType, memRefDescPtr);
- Value one =
- createIndexAttrConstant(builder, loc, typeConverter.getIndexType(), 1);
Value alignedGep = builder.create<LLVM::GEPOp>(
- loc, elemPtrPtrType, elementPtrPtr, ValueRange({one}));
+ loc, elemPtrPtrType, elementPtrPtr, ArrayRef<LLVM::GEPArg>{1});
return builder.create<LLVM::LoadOp>(loc, alignedGep);
}
@@ -417,10 +414,8 @@ void UnrankedMemRefDescriptor::setAlignedPtr(OpBuilder &builder, Location loc,
Value elementPtrPtr =
builder.create<LLVM::BitcastOp>(loc, elemPtrPtrType, memRefDescPtr);
- Value one =
- createIndexAttrConstant(builder, loc, typeConverter.getIndexType(), 1);
Value alignedGep = builder.create<LLVM::GEPOp>(
- loc, elemPtrPtrType, elementPtrPtr, ValueRange({one}));
+ loc, elemPtrPtrType, elementPtrPtr, ArrayRef<LLVM::GEPArg>{1});
builder.create<LLVM::StoreOp>(loc, alignedPtr, alignedGep);
}
@@ -431,10 +426,8 @@ Value UnrankedMemRefDescriptor::offset(OpBuilder &builder, Location loc,
Value elementPtrPtr =
builder.create<LLVM::BitcastOp>(loc, elemPtrPtrType, memRefDescPtr);
- Value two =
- createIndexAttrConstant(builder, loc, typeConverter.getIndexType(), 2);
Value offsetGep = builder.create<LLVM::GEPOp>(
- loc, elemPtrPtrType, elementPtrPtr, ValueRange({two}));
+ loc, elemPtrPtrType, elementPtrPtr, ArrayRef<LLVM::GEPArg>{2});
offsetGep = builder.create<LLVM::BitcastOp>(
loc, LLVM::LLVMPointerType::get(typeConverter.getIndexType()), offsetGep);
return builder.create<LLVM::LoadOp>(loc, offsetGep);
@@ -447,10 +440,8 @@ void UnrankedMemRefDescriptor::setOffset(OpBuilder &builder, Location loc,
Value elementPtrPtr =
builder.create<LLVM::BitcastOp>(loc, elemPtrPtrType, memRefDescPtr);
- Value two =
- createIndexAttrConstant(builder, loc, typeConverter.getIndexType(), 2);
Value offsetGep = builder.create<LLVM::GEPOp>(
- loc, elemPtrPtrType, elementPtrPtr, ValueRange({two}));
+ loc, elemPtrPtrType, elementPtrPtr, ArrayRef<LLVM::GEPArg>{2});
offsetGep = builder.create<LLVM::BitcastOp>(
loc, LLVM::LLVMPointerType::get(typeConverter.getIndexType()), offsetGep);
builder.create<LLVM::StoreOp>(loc, offset, offsetGep);
@@ -467,21 +458,16 @@ Value UnrankedMemRefDescriptor::sizeBasePtr(
Value structPtr =
builder.create<LLVM::BitcastOp>(loc, structPtrTy, memRefDescPtr);
- Type int32Type = typeConverter.convertType(builder.getI32Type());
- Value zero =
- createIndexAttrConstant(builder, loc, typeConverter.getIndexType(), 0);
- Value three = builder.create<LLVM::ConstantOp>(loc, int32Type,
- builder.getI32IntegerAttr(3));
return builder.create<LLVM::GEPOp>(loc, LLVM::LLVMPointerType::get(indexTy),
- structPtr, ValueRange({zero, three}));
+ structPtr, ArrayRef<LLVM::GEPArg>{0, 3});
}
Value UnrankedMemRefDescriptor::size(OpBuilder &builder, Location loc,
LLVMTypeConverter &typeConverter,
Value sizeBasePtr, Value index) {
Type indexPtrTy = LLVM::LLVMPointerType::get(typeConverter.getIndexType());
- Value sizeStoreGep = builder.create<LLVM::GEPOp>(loc, indexPtrTy, sizeBasePtr,
- ValueRange({index}));
+ Value sizeStoreGep =
+ builder.create<LLVM::GEPOp>(loc, indexPtrTy, sizeBasePtr, index);
return builder.create<LLVM::LoadOp>(loc, sizeStoreGep);
}
@@ -490,8 +476,8 @@ void UnrankedMemRefDescriptor::setSize(OpBuilder &builder, Location loc,
Value sizeBasePtr, Value index,
Value size) {
Type indexPtrTy = LLVM::LLVMPointerType::get(typeConverter.getIndexType());
- Value sizeStoreGep = builder.create<LLVM::GEPOp>(loc, indexPtrTy, sizeBasePtr,
- ValueRange({index}));
+ Value sizeStoreGep =
+ builder.create<LLVM::GEPOp>(loc, indexPtrTy, sizeBasePtr, index);
builder.create<LLVM::StoreOp>(loc, size, sizeStoreGep);
}
@@ -499,8 +485,7 @@ Value UnrankedMemRefDescriptor::strideBasePtr(OpBuilder &builder, Location loc,
LLVMTypeConverter &typeConverter,
Value sizeBasePtr, Value rank) {
Type indexPtrTy = LLVM::LLVMPointerType::get(typeConverter.getIndexType());
- return builder.create<LLVM::GEPOp>(loc, indexPtrTy, sizeBasePtr,
- ValueRange({rank}));
+ return builder.create<LLVM::GEPOp>(loc, indexPtrTy, sizeBasePtr, rank);
}
Value UnrankedMemRefDescriptor::stride(OpBuilder &builder, Location loc,
@@ -508,8 +493,8 @@ Value UnrankedMemRefDescriptor::stride(OpBuilder &builder, Location loc,
Value strideBasePtr, Value index,
Value stride) {
Type indexPtrTy = LLVM::LLVMPointerType::get(typeConverter.getIndexType());
- Value strideStoreGep = builder.create<LLVM::GEPOp>(
- loc, indexPtrTy, strideBasePtr, ValueRange({index}));
+ Value strideStoreGep =
+ builder.create<LLVM::GEPOp>(loc, indexPtrTy, strideBasePtr, index);
return builder.create<LLVM::LoadOp>(loc, strideStoreGep);
}
@@ -518,7 +503,7 @@ void UnrankedMemRefDescriptor::setStride(OpBuilder &builder, Location loc,
Value strideBasePtr, Value index,
Value stride) {
Type indexPtrTy = LLVM::LLVMPointerType::get(typeConverter.getIndexType());
- Value strideStoreGep = builder.create<LLVM::GEPOp>(
- loc, indexPtrTy, strideBasePtr, ValueRange({index}));
+ Value strideStoreGep =
+ builder.create<LLVM::GEPOp>(loc, indexPtrTy, strideBasePtr, index);
builder.create<LLVM::StoreOp>(loc, stride, strideStoreGep);
}
diff --git a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
index 7c99402cc62c..b6288f5c0717 100644
--- a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
@@ -163,8 +163,8 @@ void ConvertToLLVMPattern::getMemRefDescriptorSizes(
// Buffer size in bytes.
Type elementPtrType = getElementPtrType(memRefType);
Value nullPtr = rewriter.create<LLVM::NullOp>(loc, elementPtrType);
- Value gepPtr = rewriter.create<LLVM::GEPOp>(loc, elementPtrType, nullPtr,
- ArrayRef<Value>{runningStride});
+ Value gepPtr =
+ rewriter.create<LLVM::GEPOp>(loc, elementPtrType, nullPtr, runningStride);
sizeBytes = rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), gepPtr);
}
@@ -178,9 +178,8 @@ Value ConvertToLLVMPattern::getSizeInBytes(
auto convertedPtrType =
LLVM::LLVMPointerType::get(typeConverter->convertType(type));
auto nullPtr = rewriter.create<LLVM::NullOp>(loc, convertedPtrType);
- auto gep = rewriter.create<LLVM::GEPOp>(
- loc, convertedPtrType, nullPtr,
- ArrayRef<Value>{createIndexConstant(rewriter, loc, 1)});
+ auto gep = rewriter.create<LLVM::GEPOp>(loc, convertedPtrType, nullPtr,
+ ArrayRef<LLVM::GEPArg>{1});
return rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), gep);
}
diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
index 884931770fde..18747df79e47 100644
--- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
@@ -389,19 +389,15 @@ private:
// Get pointer to offset field of memref<element_type> descriptor.
Type indexPtrTy = LLVM::LLVMPointerType::get(
getTypeConverter()->getIndexType(), addressSpace);
- Value two = rewriter.create<LLVM::ConstantOp>(
- loc, typeConverter->convertType(rewriter.getI32Type()),
- rewriter.getI32IntegerAttr(2));
Value offsetPtr = rewriter.create<LLVM::GEPOp>(
- loc, indexPtrTy, scalarMemRefDescPtr,
- ValueRange({createIndexConstant(rewriter, loc, 0), two}));
+ loc, indexPtrTy, scalarMemRefDescPtr, ArrayRef<LLVM::GEPArg>{0, 2});
// The size value that we have to extract can be obtained using GEPop with
// `dimOp.index() + 1` index argument.
Value idxPlusOne = rewriter.create<LLVM::AddOp>(
loc, createIndexConstant(rewriter, loc, 1), adaptor.getIndex());
- Value sizePtr = rewriter.create<LLVM::GEPOp>(loc, indexPtrTy, offsetPtr,
- ValueRange({idxPlusOne}));
+ Value sizePtr =
+ rewriter.create<LLVM::GEPOp>(loc, indexPtrTy, offsetPtr, idxPlusOne);
return rewriter.create<LLVM::LoadOp>(loc, sizePtr);
}
@@ -664,11 +660,9 @@ struct GetGlobalMemrefOpLowering : public AllocLikeOpLLVMLowering {
Type elementType = typeConverter->convertType(type.getElementType());
Type elementPtrType = LLVM::LLVMPointerType::get(elementType, memSpace);
- SmallVector<Value> operands;
- operands.insert(operands.end(), type.getRank() + 1,
- createIndexConstant(rewriter, loc, 0));
- auto gep =
- rewriter.create<LLVM::GEPOp>(loc, elementPtrType, addressOf, operands);
+ auto gep = rewriter.create<LLVM::GEPOp>(
+ loc, elementPtrType, addressOf,
+ SmallVector<LLVM::GEPArg>(type.getRank() + 1, 0));
// We do not expect the memref obtained using `memref.get_global` to be
// ever deallocated. Set the allocated pointer to be known bad value to
@@ -1286,8 +1280,8 @@ private:
// Copy size from shape to descriptor.
Type llvmIndexPtrType = LLVM::LLVMPointerType::get(indexType);
- Value sizeLoadGep = rewriter.create<LLVM::GEPOp>(
- loc, llvmIndexPtrType, shapeOperandPtr, ValueRange{indexArg});
+ Value sizeLoadGep = rewriter.create<LLVM::GEPOp>(loc, llvmIndexPtrType,
+ shapeOperandPtr, indexArg);
Value size = rewriter.create<LLVM::LoadOp>(loc, sizeLoadGep);
UnrankedMemRefDescriptor::setSize(rewriter, loc, *getTypeConverter(),
targetSizesBase, indexArg, size);
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index 33fe8902b977..8ff803f0fd7d 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -2981,12 +2981,9 @@ Value mlir::LLVM::createGlobalString(Location loc, OpBuilder &builder,
// Get the pointer to the first character in the global string.
Value globalPtr = builder.create<LLVM::AddressOfOp>(loc, global);
- Value cst0 = builder.create<LLVM::ConstantOp>(
- loc, IntegerType::get(ctx, 64),
- builder.getIntegerAttr(builder.getIndexType(), 0));
return builder.create<LLVM::GEPOp>(
loc, LLVM::LLVMPointerType::get(IntegerType::get(ctx, 8)), globalPtr,
- ValueRange{cst0, cst0});
+ ArrayRef<GEPArg>{0, 0});
}
bool mlir::LLVM::satisfiesLLVMModule(Operation *op) {