aboutsummaryrefslogtreecommitdiff
path: root/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp')
-rw-r--r--mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp15
1 files changed, 6 insertions, 9 deletions
diff --git a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
index 529efab55892..aaee8211ffd8 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
@@ -653,17 +653,14 @@ Value ConvertLaunchFuncOpToGpuRuntimeCallPattern::generateParamsArray(
loc, llvmInt32Type, builder.getI32IntegerAttr(numArguments));
auto arrayPtr = builder.create<LLVM::AllocaOp>(loc, llvmPointerPointerType,
arraySize, /*alignment=*/0);
- auto zero = builder.create<LLVM::ConstantOp>(loc, llvmInt32Type,
- builder.getI32IntegerAttr(0));
for (const auto &en : llvm::enumerate(arguments)) {
- auto index = builder.create<LLVM::ConstantOp>(
- loc, llvmInt32Type, builder.getI32IntegerAttr(en.index()));
auto fieldPtr = builder.create<LLVM::GEPOp>(
loc, LLVM::LLVMPointerType::get(argumentTypes[en.index()]), structPtr,
- ArrayRef<Value>{zero, index.getResult()});
+ ArrayRef<LLVM::GEPArg>{0, en.index()});
builder.create<LLVM::StoreOp>(loc, en.value(), fieldPtr);
- auto elementPtr = builder.create<LLVM::GEPOp>(loc, llvmPointerPointerType,
- arrayPtr, index.getResult());
+ auto elementPtr =
+ builder.create<LLVM::GEPOp>(loc, llvmPointerPointerType, arrayPtr,
+ ArrayRef<LLVM::GEPArg>{en.index()});
auto casted =
builder.create<LLVM::BitcastOp>(loc, llvmPointerType, fieldPtr);
builder.create<LLVM::StoreOp>(loc, casted, elementPtr);
@@ -811,8 +808,8 @@ LogicalResult ConvertMemcpyOpToGpuRuntimeCallPattern::matchAndRewrite(
Type elementPtrType = getElementPtrType(memRefType);
Value nullPtr = rewriter.create<LLVM::NullOp>(loc, elementPtrType);
- Value gepPtr = rewriter.create<LLVM::GEPOp>(loc, elementPtrType, nullPtr,
- ArrayRef<Value>{numElements});
+ Value gepPtr =
+ rewriter.create<LLVM::GEPOp>(loc, elementPtrType, nullPtr, numElements);
auto sizeBytes =
rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), gepPtr);