aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorRiver Riddle <riddleriver@gmail.com>2020-04-07 13:55:34 -0700
committerRiver Riddle <riddleriver@gmail.com>2020-04-07 14:08:52 -0700
commit722f909f7aa1d5ab21f68eb8ce1baf109cc5bb13 (patch)
treed6262e46b68a7d8cd6cf4f02c520837ac4bcffd8
parent2481f26ac3f228cc085d4d68ee72dadc07afa48f (diff)
[mlir][Pass][NFC] Replace usages of ModulePass with OperationPass<ModuleOp>
ModulePass doesn't provide any special utilities and thus doesn't give enough benefit to warrant a special pass class. This revision replaces all usages with the more general OperationPass. Differential Revision: https://reviews.llvm.org/D77339
-rw-r--r--mlir/docs/Tutorials/Toy/Ch-6.md2
-rw-r--r--mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp9
-rw-r--r--mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp9
-rw-r--r--mlir/include/mlir/Pass/Pass.h17
-rw-r--r--mlir/lib/Conversion/AVX512ToLLVM/ConvertAVX512ToLLVM.cpp11
-rw-r--r--mlir/lib/Conversion/GPUToCUDA/ConvertLaunchFuncToCudaCalls.cpp25
-rw-r--r--mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRVPass.cpp8
-rw-r--r--mlir/lib/Conversion/GPUToVulkan/ConvertGPULaunchFuncToVulkanLaunchFunc.cpp17
-rw-r--r--mlir/lib/Conversion/GPUToVulkan/ConvertLaunchFuncToVulkanCalls.cpp12
-rw-r--r--mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp9
-rw-r--r--mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRVPass.cpp8
-rw-r--r--mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp8
-rw-r--r--mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.cpp8
-rw-r--r--mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp13
-rw-r--r--mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp15
-rw-r--r--mlir/lib/Dialect/SPIRV/Transforms/DecorateSPIRVCompositeTypeLayoutPass.cpp8
-rw-r--r--mlir/lib/Transforms/OpStats.cpp8
-rw-r--r--mlir/lib/Transforms/ViewOpGraph.cpp4
-rw-r--r--mlir/test/lib/Dialect/Test/TestPatterns.cpp11
-rw-r--r--mlir/test/lib/IR/TestFunc.cpp12
-rw-r--r--mlir/test/lib/IR/TestSideEffects.cpp6
-rw-r--r--mlir/test/lib/IR/TestSymbolUses.cpp13
-rw-r--r--mlir/test/lib/Pass/TestPassManager.cpp4
-rw-r--r--mlir/test/lib/Transforms/TestAllReduceLowering.cpp6
-rw-r--r--mlir/test/lib/Transforms/TestCallGraph.cpp6
-rw-r--r--mlir/test/lib/Transforms/TestOpaqueLoc.cpp8
26 files changed, 124 insertions, 133 deletions
diff --git a/mlir/docs/Tutorials/Toy/Ch-6.md b/mlir/docs/Tutorials/Toy/Ch-6.md
index 0444d2a7690a..e1dfc0039f8e 100644
--- a/mlir/docs/Tutorials/Toy/Ch-6.md
+++ b/mlir/docs/Tutorials/Toy/Ch-6.md
@@ -105,7 +105,7 @@ We want to completely lower to LLVM, so we use a `FullConversion`. This ensures
that only legal operations will remain after the conversion.
```c++
- mlir::ModuleOp module = getModule();
+ mlir::ModuleOp module = getOperation();
if (mlir::failed(mlir::applyFullConversion(module, target, patterns,
&typeConverter)))
signalPassFailure();
diff --git a/mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp b/mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp
index f6dcba229276..99465d3201e5 100644
--- a/mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp
+++ b/mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp
@@ -153,12 +153,13 @@ private:
//===----------------------------------------------------------------------===//
namespace {
-struct ToyToLLVMLoweringPass : public ModulePass<ToyToLLVMLoweringPass> {
- void runOnModule() final;
+struct ToyToLLVMLoweringPass
+ : public OperationPass<ToyToLLVMLoweringPass, ModuleOp> {
+ void runOnOperation() final;
};
} // end anonymous namespace
-void ToyToLLVMLoweringPass::runOnModule() {
+void ToyToLLVMLoweringPass::runOnOperation() {
// The first thing to define is the conversion target. This will define the
// final target for this lowering. For this lowering, we are only targeting
// the LLVM dialect.
@@ -191,7 +192,7 @@ void ToyToLLVMLoweringPass::runOnModule() {
// We want to completely lower to LLVM, so we use a `FullConversion`. This
// ensures that only legal operations will remain after the conversion.
- auto module = getModule();
+ auto module = getOperation();
if (failed(applyFullConversion(module, target, patterns, &typeConverter)))
signalPassFailure();
}
diff --git a/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp b/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp
index f6dcba229276..99465d3201e5 100644
--- a/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp
+++ b/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp
@@ -153,12 +153,13 @@ private:
//===----------------------------------------------------------------------===//
namespace {
-struct ToyToLLVMLoweringPass : public ModulePass<ToyToLLVMLoweringPass> {
- void runOnModule() final;
+struct ToyToLLVMLoweringPass
+ : public OperationPass<ToyToLLVMLoweringPass, ModuleOp> {
+ void runOnOperation() final;
};
} // end anonymous namespace
-void ToyToLLVMLoweringPass::runOnModule() {
+void ToyToLLVMLoweringPass::runOnOperation() {
// The first thing to define is the conversion target. This will define the
// final target for this lowering. For this lowering, we are only targeting
// the LLVM dialect.
@@ -191,7 +192,7 @@ void ToyToLLVMLoweringPass::runOnModule() {
// We want to completely lower to LLVM, so we use a `FullConversion`. This
// ensures that only legal operations will remain after the conversion.
- auto module = getModule();
+ auto module = getOperation();
if (failed(applyFullConversion(module, target, patterns, &typeConverter)))
signalPassFailure();
}
diff --git a/mlir/include/mlir/Pass/Pass.h b/mlir/include/mlir/Pass/Pass.h
index 80c4ddfeae62..c1eec4f4706a 100644
--- a/mlir/include/mlir/Pass/Pass.h
+++ b/mlir/include/mlir/Pass/Pass.h
@@ -341,24 +341,9 @@ template <typename T> struct FunctionPass : public OperationPass<T, FuncOp> {
runOnFunction();
}
- /// Return the current module being transformed.
+ /// Return the current function being transformed.
FuncOp getFunction() { return this->getOperation(); }
};
-
-/// A model for providing module pass specific utilities.
-///
-/// Derived module passes are expected to provide the following:
-/// - A 'void runOnModule()' method.
-template <typename T> struct ModulePass : public OperationPass<T, ModuleOp> {
- /// The polymorphic API that runs the pass over the currently held module.
- virtual void runOnModule() = 0;
-
- /// The polymorphic API that runs the pass over the currently held operation.
- void runOnOperation() final { runOnModule(); }
-
- /// Return the current module being transformed.
- ModuleOp getModule() { return this->getOperation(); }
-};
} // end namespace mlir
#endif // MLIR_PASS_PASS_H
diff --git a/mlir/lib/Conversion/AVX512ToLLVM/ConvertAVX512ToLLVM.cpp b/mlir/lib/Conversion/AVX512ToLLVM/ConvertAVX512ToLLVM.cpp
index 91f3cc933a02..08b187fc835e 100644
--- a/mlir/lib/Conversion/AVX512ToLLVM/ConvertAVX512ToLLVM.cpp
+++ b/mlir/lib/Conversion/AVX512ToLLVM/ConvertAVX512ToLLVM.cpp
@@ -163,16 +163,17 @@ void mlir::populateAVX512ToLLVMConversionPatterns(
}
namespace {
-struct ConvertAVX512ToLLVMPass : public ModulePass<ConvertAVX512ToLLVMPass> {
+struct ConvertAVX512ToLLVMPass
+ : public OperationPass<ConvertAVX512ToLLVMPass, ModuleOp> {
/// Include the generated pass utilities.
#define GEN_PASS_ConvertAVX512ToLLVM
#include "mlir/Conversion/Passes.h.inc"
- void runOnModule() override;
+ void runOnOperation() override;
};
} // namespace
-void ConvertAVX512ToLLVMPass::runOnModule() {
+void ConvertAVX512ToLLVMPass::runOnOperation() {
// Convert to the LLVM IR dialect.
OwningRewritePatternList patterns;
LLVMTypeConverter converter(&getContext());
@@ -186,8 +187,8 @@ void ConvertAVX512ToLLVMPass::runOnModule() {
target.addIllegalDialect<avx512::AVX512Dialect>();
target.addDynamicallyLegalOp<FuncOp>(
[&](FuncOp op) { return converter.isSignatureLegal(op.getType()); });
- if (failed(
- applyPartialConversion(getModule(), target, patterns, &converter))) {
+ if (failed(applyPartialConversion(getOperation(), target, patterns,
+ &converter))) {
signalPassFailure();
}
}
diff --git a/mlir/lib/Conversion/GPUToCUDA/ConvertLaunchFuncToCudaCalls.cpp b/mlir/lib/Conversion/GPUToCUDA/ConvertLaunchFuncToCudaCalls.cpp
index 38c092a2eaf0..71fe129d3875 100644
--- a/mlir/lib/Conversion/GPUToCUDA/ConvertLaunchFuncToCudaCalls.cpp
+++ b/mlir/lib/Conversion/GPUToCUDA/ConvertLaunchFuncToCudaCalls.cpp
@@ -61,7 +61,7 @@ namespace {
///
/// Intermediate data structures are allocated on the stack.
class GpuLaunchFuncToCudaCallsPass
- : public ModulePass<GpuLaunchFuncToCudaCallsPass> {
+ : public OperationPass<GpuLaunchFuncToCudaCallsPass, ModuleOp> {
private:
/// Include the generated pass utilities.
#define GEN_PASS_ConvertGpuLaunchFuncToCudaCalls
@@ -126,20 +126,19 @@ private:
public:
// Run the dialect converter on the module.
- void runOnModule() override {
+ void runOnOperation() override {
// Cache the LLVMDialect for the current module.
llvmDialect = getContext().getRegisteredDialect<LLVM::LLVMDialect>();
// Cache the used LLVM types.
initializeCachedTypes();
- getModule().walk([this](mlir::gpu::LaunchFuncOp op) {
- translateGpuLaunchCalls(op);
- });
+ getOperation().walk(
+ [this](mlir::gpu::LaunchFuncOp op) { translateGpuLaunchCalls(op); });
// GPU kernel modules are no longer necessary since we have a global
// constant with the CUBIN data.
for (auto m :
- llvm::make_early_inc_range(getModule().getOps<gpu::GPUModuleOp>()))
+ llvm::make_early_inc_range(getOperation().getOps<gpu::GPUModuleOp>()))
m.erase();
}
@@ -160,7 +159,7 @@ private:
// The types in comments give the actual types expected/returned but the API
// uses void pointers. This is fine as they have the same linkage in C.
void GpuLaunchFuncToCudaCallsPass::declareCudaFunctions(Location loc) {
- ModuleOp module = getModule();
+ ModuleOp module = getOperation();
OpBuilder builder(module.getBody()->getTerminator());
if (!module.lookupSymbol(cuModuleLoadName)) {
builder.create<LLVM::LLVMFuncOp>(
@@ -391,7 +390,7 @@ void GpuLaunchFuncToCudaCallsPass::translateGpuLaunchCalls(
builder.getI32IntegerAttr(0));
// Create an LLVM global with CUBIN extracted from the kernel annotation and
// obtain a pointer to the first byte in it.
- auto kernelModule = getModule().lookupSymbol<gpu::GPUModuleOp>(
+ auto kernelModule = getOperation().lookupSymbol<gpu::GPUModuleOp>(
launchOp.getKernelModuleName());
assert(kernelModule && "expected a kernel module");
@@ -412,7 +411,7 @@ void GpuLaunchFuncToCudaCallsPass::translateGpuLaunchCalls(
// in the called helper function.
auto cuModule = allocatePointer(builder, loc);
auto cuModuleLoad =
- getModule().lookupSymbol<LLVM::LLVMFuncOp>(cuModuleLoadName);
+ getOperation().lookupSymbol<LLVM::LLVMFuncOp>(cuModuleLoadName);
builder.create<LLVM::CallOp>(loc, ArrayRef<Type>{getCUResultType()},
builder.getSymbolRefAttr(cuModuleLoad),
ArrayRef<Value>{cuModule, data});
@@ -423,20 +422,20 @@ void GpuLaunchFuncToCudaCallsPass::translateGpuLaunchCalls(
auto kernelName = generateKernelNameConstant(launchOp.kernel(), loc, builder);
auto cuFunction = allocatePointer(builder, loc);
auto cuModuleGetFunction =
- getModule().lookupSymbol<LLVM::LLVMFuncOp>(cuModuleGetFunctionName);
+ getOperation().lookupSymbol<LLVM::LLVMFuncOp>(cuModuleGetFunctionName);
builder.create<LLVM::CallOp>(
loc, ArrayRef<Type>{getCUResultType()},
builder.getSymbolRefAttr(cuModuleGetFunction),
ArrayRef<Value>{cuFunction, cuOwningModuleRef, kernelName});
// Grab the global stream needed for execution.
auto cuGetStreamHelper =
- getModule().lookupSymbol<LLVM::LLVMFuncOp>(cuGetStreamHelperName);
+ getOperation().lookupSymbol<LLVM::LLVMFuncOp>(cuGetStreamHelperName);
auto cuStream = builder.create<LLVM::CallOp>(
loc, ArrayRef<Type>{getPointerType()},
builder.getSymbolRefAttr(cuGetStreamHelper), ArrayRef<Value>{});
// Invoke the function with required arguments.
auto cuLaunchKernel =
- getModule().lookupSymbol<LLVM::LLVMFuncOp>(cuLaunchKernelName);
+ getOperation().lookupSymbol<LLVM::LLVMFuncOp>(cuLaunchKernelName);
auto cuFunctionRef =
builder.create<LLVM::LoadOp>(loc, getPointerType(), cuFunction);
auto paramsArray = setupParamsArray(launchOp, builder);
@@ -458,7 +457,7 @@ void GpuLaunchFuncToCudaCallsPass::translateGpuLaunchCalls(
nullpointer /* extra */});
// Sync on the stream to make it synchronous.
auto cuStreamSync =
- getModule().lookupSymbol<LLVM::LLVMFuncOp>(cuStreamSynchronizeName);
+ getOperation().lookupSymbol<LLVM::LLVMFuncOp>(cuStreamSynchronizeName);
builder.create<LLVM::CallOp>(loc, ArrayRef<Type>{getCUResultType()},
builder.getSymbolRefAttr(cuStreamSync),
ArrayRef<Value>(cuStream.getResult(0)));
diff --git a/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRVPass.cpp b/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRVPass.cpp
index 1102ef182c5f..edee5025ded9 100644
--- a/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRVPass.cpp
+++ b/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRVPass.cpp
@@ -33,18 +33,18 @@ namespace {
/// replace it).
///
/// 2) Lower the body of the spirv::ModuleOp.
-struct GPUToSPIRVPass : public ModulePass<GPUToSPIRVPass> {
+struct GPUToSPIRVPass : public OperationPass<GPUToSPIRVPass, ModuleOp> {
/// Include the generated pass utilities.
#define GEN_PASS_ConvertGpuToSPIRV
#include "mlir/Conversion/Passes.h.inc"
- void runOnModule() override;
+ void runOnOperation() override;
};
} // namespace
-void GPUToSPIRVPass::runOnModule() {
+void GPUToSPIRVPass::runOnOperation() {
MLIRContext *context = &getContext();
- ModuleOp module = getModule();
+ ModuleOp module = getOperation();
SmallVector<Operation *, 1> kernelModules;
OpBuilder builder(context);
diff --git a/mlir/lib/Conversion/GPUToVulkan/ConvertGPULaunchFuncToVulkanLaunchFunc.cpp b/mlir/lib/Conversion/GPUToVulkan/ConvertGPULaunchFuncToVulkanLaunchFunc.cpp
index 823860ba2589..cbcfd741d9f8 100644
--- a/mlir/lib/Conversion/GPUToVulkan/ConvertGPULaunchFuncToVulkanLaunchFunc.cpp
+++ b/mlir/lib/Conversion/GPUToVulkan/ConvertGPULaunchFuncToVulkanLaunchFunc.cpp
@@ -38,13 +38,13 @@ namespace {
/// function and attaching binary data and entry point name as an attributes to
/// created vulkan launch call op.
class ConvertGpuLaunchFuncToVulkanLaunchFunc
- : public ModulePass<ConvertGpuLaunchFuncToVulkanLaunchFunc> {
+ : public OperationPass<ConvertGpuLaunchFuncToVulkanLaunchFunc, ModuleOp> {
public:
/// Include the generated pass utilities.
#define GEN_PASS_ConvertGpuLaunchFuncToVulkanLaunchFunc
#include "mlir/Conversion/Passes.h.inc"
- void runOnModule() override;
+ void runOnOperation() override;
private:
/// Creates a SPIR-V binary shader from the given `module` using
@@ -68,14 +68,13 @@ private:
/// operand is unsupported by Vulkan runtime.
LogicalResult declareVulkanLaunchFunc(Location loc,
gpu::LaunchFuncOp launchOp);
-
};
} // anonymous namespace
-void ConvertGpuLaunchFuncToVulkanLaunchFunc::runOnModule() {
+void ConvertGpuLaunchFuncToVulkanLaunchFunc::runOnOperation() {
bool done = false;
- getModule().walk([this, &done](gpu::LaunchFuncOp op) {
+ getOperation().walk([this, &done](gpu::LaunchFuncOp op) {
if (done) {
op.emitError("should only contain one 'gpu::LaunchFuncOp' op");
return signalPassFailure();
@@ -86,17 +85,17 @@ void ConvertGpuLaunchFuncToVulkanLaunchFunc::runOnModule() {
// Erase `gpu::GPUModuleOp` and `spirv::Module` operations.
for (auto gpuModule :
- llvm::make_early_inc_range(getModule().getOps<gpu::GPUModuleOp>()))
+ llvm::make_early_inc_range(getOperation().getOps<gpu::GPUModuleOp>()))
gpuModule.erase();
for (auto spirvModule :
- llvm::make_early_inc_range(getModule().getOps<spirv::ModuleOp>()))
+ llvm::make_early_inc_range(getOperation().getOps<spirv::ModuleOp>()))
spirvModule.erase();
}
LogicalResult ConvertGpuLaunchFuncToVulkanLaunchFunc::declareVulkanLaunchFunc(
Location loc, gpu::LaunchFuncOp launchOp) {
- OpBuilder builder(getModule().getBody()->getTerminator());
+ OpBuilder builder(getOperation().getBody()->getTerminator());
// TODO: Workgroup size is written into the kernel. So to properly modelling
// vulkan launch, we cannot have the local workgroup size configuration here.
SmallVector<Type, 8> vulkanLaunchTypes{launchOp.getOperandTypes()};
@@ -138,7 +137,7 @@ LogicalResult ConvertGpuLaunchFuncToVulkanLaunchFunc::createBinaryShader(
void ConvertGpuLaunchFuncToVulkanLaunchFunc::convertGpuLaunchFunc(
gpu::LaunchFuncOp launchOp) {
- ModuleOp module = getModule();
+ ModuleOp module = getOperation();
OpBuilder builder(launchOp);
Location loc = launchOp.getLoc();
diff --git a/mlir/lib/Conversion/GPUToVulkan/ConvertLaunchFuncToVulkanCalls.cpp b/mlir/lib/Conversion/GPUToVulkan/ConvertLaunchFuncToVulkanCalls.cpp
index ebc8ded483ff..2daa13085bcb 100644
--- a/mlir/lib/Conversion/GPUToVulkan/ConvertLaunchFuncToVulkanCalls.cpp
+++ b/mlir/lib/Conversion/GPUToVulkan/ConvertLaunchFuncToVulkanCalls.cpp
@@ -58,7 +58,7 @@ namespace {
/// * deinitVulkan -- deinitializes vulkan runtime
///
class VulkanLaunchFuncToVulkanCallsPass
- : public ModulePass<VulkanLaunchFuncToVulkanCallsPass> {
+ : public OperationPass<VulkanLaunchFuncToVulkanCallsPass, ModuleOp> {
private:
/// Include the generated pass utilities.
#define GEN_PASS_ConvertVulkanLaunchFuncToVulkanCalls
@@ -150,7 +150,7 @@ private:
LogicalResult deduceMemRefRank(Value ptrToMemRefDescriptor, uint32_t &rank);
public:
- void runOnModule() override;
+ void runOnOperation() override;
private:
LLVM::LLVMDialect *llvmDialect;
@@ -169,18 +169,18 @@ private:
} // anonymous namespace
-void VulkanLaunchFuncToVulkanCallsPass::runOnModule() {
+void VulkanLaunchFuncToVulkanCallsPass::runOnOperation() {
initializeCachedTypes();
// Collect SPIR-V attributes such as `spirv_blob` and
// `spirv_entry_point_name`.
- getModule().walk([this](LLVM::CallOp op) {
+ getOperation().walk([this](LLVM::CallOp op) {
if (isVulkanLaunchCallOp(op))
collectSPIRVAttributes(op);
});
// Convert vulkan launch call op into a sequence of Vulkan runtime calls.
- getModule().walk([this](LLVM::CallOp op) {
+ getOperation().walk([this](LLVM::CallOp op) {
if (isCInterfaceVulkanLaunchCallOp(op))
translateVulkanLaunchCall(op);
});
@@ -278,7 +278,7 @@ VulkanLaunchFuncToVulkanCallsPass::deduceMemRefRank(Value ptrToMemRefDescriptor,
}
void VulkanLaunchFuncToVulkanCallsPass::declareVulkanFunctions(Location loc) {
- ModuleOp module = getModule();
+ ModuleOp module = getOperation();
OpBuilder builder(module.getBody()->getTerminator());
if (!module.lookupSymbol(kSetEntryPoint)) {
diff --git a/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp b/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp
index 07c8111941e4..99f106e29de7 100644
--- a/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp
+++ b/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp
@@ -561,17 +561,18 @@ void mlir::populateLinalgToLLVMConversionPatterns(
}
namespace {
-struct ConvertLinalgToLLVMPass : public ModulePass<ConvertLinalgToLLVMPass> {
+struct ConvertLinalgToLLVMPass
+ : public OperationPass<ConvertLinalgToLLVMPass, ModuleOp> {
/// Include the generated pass utilities.
#define GEN_PASS_ConvertLinalgToLLVM
#include "mlir/Conversion/Passes.h.inc"
- void runOnModule() override;
+ void runOnOperation() override;
};
} // namespace
-void ConvertLinalgToLLVMPass::runOnModule() {
- auto module = getModule();
+void ConvertLinalgToLLVMPass::runOnOperation() {
+ auto module = getOperation();
// Convert to the LLVM IR dialect using the converter defined above.
OwningRewritePatternList patterns;
diff --git a/mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRVPass.cpp b/mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRVPass.cpp
index 0962746c486a..4b66063b88eb 100644
--- a/mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRVPass.cpp
+++ b/mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRVPass.cpp
@@ -16,18 +16,18 @@ using namespace mlir;
namespace {
/// A pass converting MLIR Linalg ops into SPIR-V ops.
-class LinalgToSPIRVPass : public ModulePass<LinalgToSPIRVPass> {
+class LinalgToSPIRVPass : public OperationPass<LinalgToSPIRVPass, ModuleOp> {
/// Include the generated pass utilities.
#define GEN_PASS_ConvertLinalgToSPIRV
#include "mlir/Conversion/Passes.h.inc"
- void runOnModule() override;
+ void runOnOperation() override;
};
} // namespace
-void LinalgToSPIRVPass::runOnModule() {
+void LinalgToSPIRVPass::runOnOperation() {
MLIRContext *context = &getContext();
- ModuleOp module = getModule();
+ ModuleOp module = getOperation();
auto targetAttr = spirv::lookupTargetEnvOrDefault(module);
std::unique_ptr<ConversionTarget> target =
diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
index 1e127a0a884e..ef5dabf2ff88 100644
--- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
@@ -2847,7 +2847,7 @@ LLVMTypeConverter::promoteMemRefDescriptors(Location loc, ValueRange opOperands,
namespace {
/// A pass converting MLIR operations into the LLVM IR dialect.
-struct LLVMLoweringPass : public ModulePass<LLVMLoweringPass> {
+struct LLVMLoweringPass : public OperationPass<LLVMLoweringPass, ModuleOp> {
/// Include the generated pass utilities.
#define GEN_PASS_ConvertStandardToLLVM
#include "mlir/Conversion/Passes.h.inc"
@@ -2863,16 +2863,16 @@ struct LLVMLoweringPass : public ModulePass<LLVMLoweringPass> {
LLVMLoweringPass(const LLVMLoweringPass &pass) {}
/// Run the dialect converter on the module.
- void runOnModule() override {
+ void runOnOperation() override {
if (useBarePtrCallConv && emitCWrappers) {
- getModule().emitError()
+ getOperation().emitError()
<< "incompatible conversion options: bare-pointer calling convention "
"and C wrapper emission";
signalPassFailure();
return;
}
- ModuleOp m = getModule();
+ ModuleOp m = getOperation();
LLVMTypeConverterCustomization customs;
customs.funcArgConverter = useBarePtrCallConv ? barePtrFuncArgTypeConverter
diff --git a/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.cpp b/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.cpp
index ab7dd8546995..86c8cd17433c 100644
--- a/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.cpp
+++ b/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.cpp
@@ -22,18 +22,18 @@ using namespace mlir;
namespace {
/// A pass converting MLIR Standard operations into the SPIR-V dialect.
class ConvertStandardToSPIRVPass
- : public ModulePass<ConvertStandardToSPIRVPass> {
+ : public OperationPass<ConvertStandardToSPIRVPass, ModuleOp> {
/// Include the generated pass utilities.
#define GEN_PASS_ConvertStandardToSPIRV
#include "mlir/Conversion/Passes.h.inc"
- void runOnModule() override;
+ void runOnOperation() override;
};
} // namespace
-void ConvertStandardToSPIRVPass::runOnModule() {
+void ConvertStandardToSPIRVPass::runOnOperation() {
MLIRContext *context = &getContext();
- ModuleOp module = getModule();
+ ModuleOp module = getOperation();
auto targetAttr = spirv::lookupTargetEnvOrDefault(module);
std::unique_ptr<ConversionTarget> target =
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index d5a4f86d2ca9..b2a1c443f518 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -1118,23 +1118,24 @@ void mlir::populateVectorToLLVMMatrixConversionPatterns(
}
namespace {
-struct LowerVectorToLLVMPass : public ModulePass<LowerVectorToLLVMPass> {
+struct LowerVectorToLLVMPass
+ : public OperationPass<LowerVectorToLLVMPass, ModuleOp> {
/// Include the generated pass utilities.
#define GEN_PASS_ConvertVectorToLLVM
#include "mlir/Conversion/Passes.h.inc"
- void runOnModule() override;
+ void runOnOperation() override;
};
} // namespace
-void LowerVectorToLLVMPass::runOnModule() {
+void LowerVectorToLLVMPass::runOnOperation() {
// Perform progressive lowering of operations on slices and
// all contraction operations. Also applies folding and DCE.
{
OwningRewritePatternList patterns;
populateVectorSlicesLoweringPatterns(patterns, &getContext());
populateVectorContractLoweringPatterns(patterns, &getContext());
- applyPatternsGreedily(getModule(), patterns);
+ applyPatternsGreedily(getOperation(), patterns);
}
// Convert to the LLVM IR dialect.
@@ -1148,8 +1149,8 @@ void LowerVectorToLLVMPass::runOnModule() {
LLVMConversionTarget target(getContext());
target.addDynamicallyLegalOp<FuncOp>(
[&](FuncOp op) { return converter.isSignatureLegal(op.getType()); });
- if (failed(
- applyPartialConversion(getModule(), target, patterns, &converter))) {
+ if (failed(applyPartialConversion(getOperation(), target, patterns,
+ &converter))) {
signalPassFailure();
}
}
diff --git a/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp b/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp
index 2eadf87f038a..daf9169d242c 100644
--- a/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp
+++ b/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp
@@ -214,16 +214,17 @@ namespace {
/// The gpu.modules are intended to be compiled to a cubin blob independently in
/// a separate pass. The external functions can then be annotated with the
/// symbol of the cubin accessor function.
-class GpuKernelOutliningPass : public ModulePass<GpuKernelOutliningPass> {
+class GpuKernelOutliningPass
+ : public OperationPass<GpuKernelOutliningPass, ModuleOp> {
public:
/// Include the generated pass utilities.
#define GEN_PASS_GpuKernelOutlining
#include "mlir/Dialect/GPU/Passes.h.inc"
- void runOnModule() override {
- SymbolTable symbolTable(getModule());
+ void runOnOperation() override {
+ SymbolTable symbolTable(getOperation());
bool modified = false;
- for (auto func : getModule().getOps<FuncOp>()) {
+ for (auto func : getOperation().getOps<FuncOp>()) {
// Insert just after the function.
Block::iterator insertPt(func.getOperation()->getNextNode());
auto funcWalkResult = func.walk([&](gpu::LaunchOp op) {
@@ -255,8 +256,8 @@ public:
// If any new module was inserted in this module, annotate this module as
// a container module.
if (modified)
- getModule().setAttr(gpu::GPUDialect::getContainerModuleAttrName(),
- UnitAttr::get(&getContext()));
+ getOperation().setAttr(gpu::GPUDialect::getContainerModuleAttrName(),
+ UnitAttr::get(&getContext()));
}
private:
@@ -267,7 +268,7 @@ private:
// a SymbolTable by the caller. SymbolTable needs to be refactored to
// prevent manual building of Ops with symbols in code using SymbolTables
// and then this needs to use the OpBuilder.
- auto context = getModule().getContext();
+ auto context = getOperation().getContext();
Builder builder(context);
OperationState state(kernelFunc.getLoc(),
gpu::GPUModuleOp::getOperationName());
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/DecorateSPIRVCompositeTypeLayoutPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/DecorateSPIRVCompositeTypeLayoutPass.cpp
index 79ed81956f08..e4622741536e 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/DecorateSPIRVCompositeTypeLayoutPass.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/DecorateSPIRVCompositeTypeLayoutPass.cpp
@@ -80,14 +80,14 @@ static void populateSPIRVLayoutInfoPatterns(OwningRewritePatternList &patterns,
namespace {
class DecorateSPIRVCompositeTypeLayoutPass
- : public ModulePass<DecorateSPIRVCompositeTypeLayoutPass> {
+ : public OperationPass<DecorateSPIRVCompositeTypeLayoutPass, ModuleOp> {
private:
- void runOnModule() override;
+ void runOnOperation() override;
};
} // namespace
-void DecorateSPIRVCompositeTypeLayoutPass::runOnModule() {
- auto module = getModule();
+void DecorateSPIRVCompositeTypeLayoutPass::runOnOperation() {
+ auto module = getOperation();
OwningRewritePatternList patterns;
populateSPIRVLayoutInfoPatterns(patterns, module.getContext());
ConversionTarget target(*(module.getContext()));
diff --git a/mlir/lib/Transforms/OpStats.cpp b/mlir/lib/Transforms/OpStats.cpp
index b7832f580dd4..2b519d697020 100644
--- a/mlir/lib/Transforms/OpStats.cpp
+++ b/mlir/lib/Transforms/OpStats.cpp
@@ -18,7 +18,7 @@
using namespace mlir;
namespace {
-struct PrintOpStatsPass : public ModulePass<PrintOpStatsPass> {
+struct PrintOpStatsPass : public OperationPass<PrintOpStatsPass, ModuleOp> {
/// Include the generated pass utilities.
#define GEN_PASS_PrintOpStats
#include "mlir/Transforms/Passes.h.inc"
@@ -26,7 +26,7 @@ struct PrintOpStatsPass : public ModulePass<PrintOpStatsPass> {
explicit PrintOpStatsPass(raw_ostream &os = llvm::errs()) : os(os) {}
// Prints the resultant operation statistics post iterating over the module.
- void runOnModule() override;
+ void runOnOperation() override;
// Print summary of op stats.
void printSummary();
@@ -37,11 +37,11 @@ private:
};
} // namespace
-void PrintOpStatsPass::runOnModule() {
+void PrintOpStatsPass::runOnOperation() {
opCount.clear();
// Compute the operation statistics for each function in the module.
- for (auto &op : getModule())
+ for (auto &op : getOperation())
op.walk([&](Operation *op) { ++opCount[op->getName().getStringRef()]; });
printSummary();
}
diff --git a/mlir/lib/Transforms/ViewOpGraph.cpp b/mlir/lib/Transforms/ViewOpGraph.cpp
index fcaff9a0b069..c5d921db059e 100644
--- a/mlir/lib/Transforms/ViewOpGraph.cpp
+++ b/mlir/lib/Transforms/ViewOpGraph.cpp
@@ -100,7 +100,7 @@ namespace {
// PrintOpPass is simple pass to write graph per function.
// Note: this is a module pass only to avoid interleaving on the same ostream
// due to multi-threading over functions.
-struct PrintOpPass : public ModulePass<PrintOpPass> {
+struct PrintOpPass : public OperationPass<PrintOpPass, ModuleOp> {
/// Include the generated pass utilities.
#define GEN_PASS_PrintOpGraph
#include "mlir/Transforms/Passes.h.inc"
@@ -140,7 +140,7 @@ struct PrintOpPass : public ModulePass<PrintOpPass> {
}
}
- void runOnModule() override { processModule(getModule()); }
+ void runOnOperation() override { processModule(getOperation()); }
private:
raw_ostream &os;
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index e6cc52d29722..6ccfa04a8194 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -398,13 +398,13 @@ struct TestTypeConverter : public TypeConverter {
};
struct TestLegalizePatternDriver
- : public ModulePass<TestLegalizePatternDriver> {
+ : public OperationPass<TestLegalizePatternDriver, ModuleOp> {
/// The mode of conversion to use with the driver.
enum class ConversionMode { Analysis, Full, Partial };
TestLegalizePatternDriver(ConversionMode mode) : mode(mode) {}
- void runOnModule() override {
+ void runOnOperation() override {
TestTypeConverter converter;
mlir::OwningRewritePatternList patterns;
populateWithGenerated(&getContext(), &patterns);
@@ -450,7 +450,8 @@ struct TestLegalizePatternDriver
// Handle a partial conversion.
if (mode == ConversionMode::Partial) {
- (void)applyPartialConversion(getModule(), target, patterns, &converter);
+ (void)applyPartialConversion(getOperation(), target, patterns,
+ &converter);
return;
}
@@ -461,7 +462,7 @@ struct TestLegalizePatternDriver
return (bool)op->getAttrOfType<UnitAttr>("test.dynamically_legal");
});
- (void)applyFullConversion(getModule(), target, patterns, &converter);
+ (void)applyFullConversion(getOperation(), target, patterns, &converter);
return;
}
@@ -470,7 +471,7 @@ struct TestLegalizePatternDriver
// Analyze the convertible operations.
DenseSet<Operation *> legalizedOps;
- if (failed(applyAnalysisConversion(getModule(), target, patterns,
+ if (failed(applyAnalysisConversion(getOperation(), target, patterns,
legalizedOps, &converter)))
return signalPassFailure();
diff --git a/mlir/test/lib/IR/TestFunc.cpp b/mlir/test/lib/IR/TestFunc.cpp
index 0e885c555e38..c1b90397ec44 100644
--- a/mlir/test/lib/IR/TestFunc.cpp
+++ b/mlir/test/lib/IR/TestFunc.cpp
@@ -13,9 +13,9 @@ using namespace mlir;
namespace {
/// This is a test pass for verifying FuncOp's eraseArgument method.
-struct TestFuncEraseArg : public ModulePass<TestFuncEraseArg> {
- void runOnModule() override {
- auto module = getModule();
+struct TestFuncEraseArg : public OperationPass<TestFuncEraseArg, ModuleOp> {
+ void runOnOperation() override {
+ auto module = getOperation();
for (FuncOp func : module.getOps<FuncOp>()) {
SmallVector<unsigned, 4> indicesToErase;
@@ -36,9 +36,9 @@ struct TestFuncEraseArg : public ModulePass<TestFuncEraseArg> {
};
/// This is a test pass for verifying FuncOp's setType method.
-struct TestFuncSetType : public ModulePass<TestFuncSetType> {
- void runOnModule() override {
- auto module = getModule();
+struct TestFuncSetType : public OperationPass<TestFuncSetType, ModuleOp> {
+ void runOnOperation() override {
+ auto module = getOperation();
SymbolTable symbolTable(module);
for (FuncOp func : module.getOps<FuncOp>()) {
diff --git a/mlir/test/lib/IR/TestSideEffects.cpp b/mlir/test/lib/IR/TestSideEffects.cpp
index 9f52c42e4953..a99348537e25 100644
--- a/mlir/test/lib/IR/TestSideEffects.cpp
+++ b/mlir/test/lib/IR/TestSideEffects.cpp
@@ -12,9 +12,9 @@
using namespace mlir;
namespace {
-struct SideEffectsPass : public ModulePass<SideEffectsPass> {
- void runOnModule() override {
- auto module = getModule();
+struct SideEffectsPass : public OperationPass<SideEffectsPass, ModuleOp> {
+ void runOnOperation() override {
+ auto module = getOperation();
// Walk operations detecting side effects.
SmallVector<MemoryEffects::EffectInstance, 8> effects;
diff --git a/mlir/test/lib/IR/TestSymbolUses.cpp b/mlir/test/lib/IR/TestSymbolUses.cpp
index 6082cdcbe72b..c39615ef1352 100644
--- a/mlir/test/lib/IR/TestSymbolUses.cpp
+++ b/mlir/test/lib/IR/TestSymbolUses.cpp
@@ -15,7 +15,7 @@ using namespace mlir;
namespace {
/// This is a symbol test pass that tests the symbol uselist functionality
/// provided by the symbol table along with erasing from the symbol table.
-struct SymbolUsesPass : public ModulePass<SymbolUsesPass> {
+struct SymbolUsesPass : public OperationPass<SymbolUsesPass, ModuleOp> {
WalkResult operateOnSymbol(Operation *symbol, ModuleOp module,
SmallVectorImpl<FuncOp> &deadFunctions) {
// Test computing uses on a non symboltable op.
@@ -59,8 +59,8 @@ struct SymbolUsesPass : public ModulePass<SymbolUsesPass> {
return WalkResult::advance();
}
- void runOnModule() override {
- auto module = getModule();
+ void runOnOperation() override {
+ auto module = getOperation();
// Walk nested symbols.
SmallVector<FuncOp, 4> deadFunctions;
@@ -86,9 +86,10 @@ struct SymbolUsesPass : public ModulePass<SymbolUsesPass> {
/// This is a symbol test pass that tests the symbol use replacement
/// functionality provided by the symbol table.
-struct SymbolReplacementPass : public ModulePass<SymbolReplacementPass> {
- void runOnModule() override {
- auto module = getModule();
+struct SymbolReplacementPass
+ : public OperationPass<SymbolReplacementPass, ModuleOp> {
+ void runOnOperation() override {
+ auto module = getOperation();
// Walk nested functions and modules.
module.getBodyRegion().walk([&](Operation *nestedOp) {
diff --git a/mlir/test/lib/Pass/TestPassManager.cpp b/mlir/test/lib/Pass/TestPassManager.cpp
index 95bef9b878e2..be8a7479200c 100644
--- a/mlir/test/lib/Pass/TestPassManager.cpp
+++ b/mlir/test/lib/Pass/TestPassManager.cpp
@@ -13,8 +13,8 @@
using namespace mlir;
namespace {
-struct TestModulePass : public ModulePass<TestModulePass> {
- void runOnModule() final {}
+struct TestModulePass : public OperationPass<TestModulePass, ModuleOp> {
+ void runOnOperation() final {}
};
struct TestFunctionPass : public FunctionPass<TestFunctionPass> {
void runOnFunction() final {}
diff --git a/mlir/test/lib/Transforms/TestAllReduceLowering.cpp b/mlir/test/lib/Transforms/TestAllReduceLowering.cpp
index 508f70887350..6455dab70f45 100644
--- a/mlir/test/lib/Transforms/TestAllReduceLowering.cpp
+++ b/mlir/test/lib/Transforms/TestAllReduceLowering.cpp
@@ -18,11 +18,11 @@ using namespace mlir;
namespace {
struct TestAllReduceLoweringPass
- : public ModulePass<TestAllReduceLoweringPass> {
- void runOnModule() override {
+ : public OperationPass<TestAllReduceLoweringPass, ModuleOp> {
+ void runOnOperation() override {
OwningRewritePatternList patterns;
populateGpuRewritePatterns(&getContext(), patterns);
- applyPatternsGreedily(getModule(), patterns);
+ applyPatternsGreedily(getOperation(), patterns);
}
};
} // namespace
diff --git a/mlir/test/lib/Transforms/TestCallGraph.cpp b/mlir/test/lib/Transforms/TestCallGraph.cpp
index 89c25da9e8ed..a181d645f2af 100644
--- a/mlir/test/lib/Transforms/TestCallGraph.cpp
+++ b/mlir/test/lib/Transforms/TestCallGraph.cpp
@@ -17,9 +17,9 @@
using namespace mlir;
namespace {
-struct TestCallGraphPass : public ModulePass<TestCallGraphPass> {
- void runOnModule() {
- llvm::errs() << "Testing : " << getModule().getAttr("test.name") << "\n";
+struct TestCallGraphPass : public OperationPass<TestCallGraphPass, ModuleOp> {
+ void runOnOperation() override {
+ llvm::errs() << "Testing : " << getOperation().getAttr("test.name") << "\n";
getAnalysis<CallGraph>().print(llvm::errs());
}
};
diff --git a/mlir/test/lib/Transforms/TestOpaqueLoc.cpp b/mlir/test/lib/Transforms/TestOpaqueLoc.cpp
index baae5297306d..47152c459805 100644
--- a/mlir/test/lib/Transforms/TestOpaqueLoc.cpp
+++ b/mlir/test/lib/Transforms/TestOpaqueLoc.cpp
@@ -17,7 +17,7 @@ namespace {
/// It also takes all operations that are not function operations or
/// terminators and clones them with opaque locations which store the initial
/// locations.
-struct TestOpaqueLoc : public ModulePass<TestOpaqueLoc> {
+struct TestOpaqueLoc : public OperationPass<TestOpaqueLoc, ModuleOp> {
/// A simple structure which is used for testing as an underlying location in
/// OpaqueLoc.
@@ -29,11 +29,11 @@ struct TestOpaqueLoc : public ModulePass<TestOpaqueLoc> {
int id;
};
- void runOnModule() override {
+ void runOnOperation() override {
std::vector<std::unique_ptr<MyLocation>> myLocs;
int last_it = 0;
- getModule().walk([&](Operation *op) {
+ getOperation().walk([&](Operation *op) {
myLocs.push_back(std::make_unique<MyLocation>(last_it++));
Location loc = op->getLoc();
@@ -74,7 +74,7 @@ struct TestOpaqueLoc : public ModulePass<TestOpaqueLoc> {
os.flush();
});
- getModule().walk([&](Operation *op) { op->emitOpError(); });
+ getOperation().walk([&](Operation *op) { op->emitOpError(); });
}
};