diff options
author | River Riddle <riddleriver@gmail.com> | 2020-04-07 13:55:34 -0700 |
---|---|---|
committer | River Riddle <riddleriver@gmail.com> | 2020-04-07 14:08:52 -0700 |
commit | 722f909f7aa1d5ab21f68eb8ce1baf109cc5bb13 (patch) | |
tree | d6262e46b68a7d8cd6cf4f02c520837ac4bcffd8 | |
parent | 2481f26ac3f228cc085d4d68ee72dadc07afa48f (diff) |
[mlir][Pass][NFC] Replace usages of ModulePass with OperationPass<ModuleOp>
ModulePass doesn't provide any special utilities and thus doesn't give enough benefit to warrant a special pass class. This revision replaces all usages with the more general OperationPass.
Differential Revision: https://reviews.llvm.org/D77339
26 files changed, 124 insertions, 133 deletions
diff --git a/mlir/docs/Tutorials/Toy/Ch-6.md b/mlir/docs/Tutorials/Toy/Ch-6.md index 0444d2a7690a..e1dfc0039f8e 100644 --- a/mlir/docs/Tutorials/Toy/Ch-6.md +++ b/mlir/docs/Tutorials/Toy/Ch-6.md @@ -105,7 +105,7 @@ We want to completely lower to LLVM, so we use a `FullConversion`. This ensures that only legal operations will remain after the conversion. ```c++ - mlir::ModuleOp module = getModule(); + mlir::ModuleOp module = getOperation(); if (mlir::failed(mlir::applyFullConversion(module, target, patterns, &typeConverter))) signalPassFailure(); diff --git a/mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp b/mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp index f6dcba229276..99465d3201e5 100644 --- a/mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp +++ b/mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp @@ -153,12 +153,13 @@ private: //===----------------------------------------------------------------------===// namespace { -struct ToyToLLVMLoweringPass : public ModulePass<ToyToLLVMLoweringPass> { - void runOnModule() final; +struct ToyToLLVMLoweringPass + : public OperationPass<ToyToLLVMLoweringPass, ModuleOp> { + void runOnOperation() final; }; } // end anonymous namespace -void ToyToLLVMLoweringPass::runOnModule() { +void ToyToLLVMLoweringPass::runOnOperation() { // The first thing to define is the conversion target. This will define the // final target for this lowering. For this lowering, we are only targeting // the LLVM dialect. @@ -191,7 +192,7 @@ void ToyToLLVMLoweringPass::runOnModule() { // We want to completely lower to LLVM, so we use a `FullConversion`. This // ensures that only legal operations will remain after the conversion. - auto module = getModule(); + auto module = getOperation(); if (failed(applyFullConversion(module, target, patterns, &typeConverter))) signalPassFailure(); } diff --git a/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp b/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp index f6dcba229276..99465d3201e5 100644 --- a/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp +++ b/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp @@ -153,12 +153,13 @@ private: //===----------------------------------------------------------------------===// namespace { -struct ToyToLLVMLoweringPass : public ModulePass<ToyToLLVMLoweringPass> { - void runOnModule() final; +struct ToyToLLVMLoweringPass + : public OperationPass<ToyToLLVMLoweringPass, ModuleOp> { + void runOnOperation() final; }; } // end anonymous namespace -void ToyToLLVMLoweringPass::runOnModule() { +void ToyToLLVMLoweringPass::runOnOperation() { // The first thing to define is the conversion target. This will define the // final target for this lowering. For this lowering, we are only targeting // the LLVM dialect. @@ -191,7 +192,7 @@ void ToyToLLVMLoweringPass::runOnModule() { // We want to completely lower to LLVM, so we use a `FullConversion`. This // ensures that only legal operations will remain after the conversion. - auto module = getModule(); + auto module = getOperation(); if (failed(applyFullConversion(module, target, patterns, &typeConverter))) signalPassFailure(); } diff --git a/mlir/include/mlir/Pass/Pass.h b/mlir/include/mlir/Pass/Pass.h index 80c4ddfeae62..c1eec4f4706a 100644 --- a/mlir/include/mlir/Pass/Pass.h +++ b/mlir/include/mlir/Pass/Pass.h @@ -341,24 +341,9 @@ template <typename T> struct FunctionPass : public OperationPass<T, FuncOp> { runOnFunction(); } - /// Return the current module being transformed. + /// Return the current function being transformed. FuncOp getFunction() { return this->getOperation(); } }; - -/// A model for providing module pass specific utilities. -/// -/// Derived module passes are expected to provide the following: -/// - A 'void runOnModule()' method. -template <typename T> struct ModulePass : public OperationPass<T, ModuleOp> { - /// The polymorphic API that runs the pass over the currently held module. - virtual void runOnModule() = 0; - - /// The polymorphic API that runs the pass over the currently held operation. - void runOnOperation() final { runOnModule(); } - - /// Return the current module being transformed. - ModuleOp getModule() { return this->getOperation(); } -}; } // end namespace mlir #endif // MLIR_PASS_PASS_H diff --git a/mlir/lib/Conversion/AVX512ToLLVM/ConvertAVX512ToLLVM.cpp b/mlir/lib/Conversion/AVX512ToLLVM/ConvertAVX512ToLLVM.cpp index 91f3cc933a02..08b187fc835e 100644 --- a/mlir/lib/Conversion/AVX512ToLLVM/ConvertAVX512ToLLVM.cpp +++ b/mlir/lib/Conversion/AVX512ToLLVM/ConvertAVX512ToLLVM.cpp @@ -163,16 +163,17 @@ void mlir::populateAVX512ToLLVMConversionPatterns( } namespace { -struct ConvertAVX512ToLLVMPass : public ModulePass<ConvertAVX512ToLLVMPass> { +struct ConvertAVX512ToLLVMPass + : public OperationPass<ConvertAVX512ToLLVMPass, ModuleOp> { /// Include the generated pass utilities. #define GEN_PASS_ConvertAVX512ToLLVM #include "mlir/Conversion/Passes.h.inc" - void runOnModule() override; + void runOnOperation() override; }; } // namespace -void ConvertAVX512ToLLVMPass::runOnModule() { +void ConvertAVX512ToLLVMPass::runOnOperation() { // Convert to the LLVM IR dialect. OwningRewritePatternList patterns; LLVMTypeConverter converter(&getContext()); @@ -186,8 +187,8 @@ void ConvertAVX512ToLLVMPass::runOnModule() { target.addIllegalDialect<avx512::AVX512Dialect>(); target.addDynamicallyLegalOp<FuncOp>( [&](FuncOp op) { return converter.isSignatureLegal(op.getType()); }); - if (failed( - applyPartialConversion(getModule(), target, patterns, &converter))) { + if (failed(applyPartialConversion(getOperation(), target, patterns, + &converter))) { signalPassFailure(); } } diff --git a/mlir/lib/Conversion/GPUToCUDA/ConvertLaunchFuncToCudaCalls.cpp b/mlir/lib/Conversion/GPUToCUDA/ConvertLaunchFuncToCudaCalls.cpp index 38c092a2eaf0..71fe129d3875 100644 --- a/mlir/lib/Conversion/GPUToCUDA/ConvertLaunchFuncToCudaCalls.cpp +++ b/mlir/lib/Conversion/GPUToCUDA/ConvertLaunchFuncToCudaCalls.cpp @@ -61,7 +61,7 @@ namespace { /// /// Intermediate data structures are allocated on the stack. class GpuLaunchFuncToCudaCallsPass - : public ModulePass<GpuLaunchFuncToCudaCallsPass> { + : public OperationPass<GpuLaunchFuncToCudaCallsPass, ModuleOp> { private: /// Include the generated pass utilities. #define GEN_PASS_ConvertGpuLaunchFuncToCudaCalls @@ -126,20 +126,19 @@ private: public: // Run the dialect converter on the module. - void runOnModule() override { + void runOnOperation() override { // Cache the LLVMDialect for the current module. llvmDialect = getContext().getRegisteredDialect<LLVM::LLVMDialect>(); // Cache the used LLVM types. initializeCachedTypes(); - getModule().walk([this](mlir::gpu::LaunchFuncOp op) { - translateGpuLaunchCalls(op); - }); + getOperation().walk( + [this](mlir::gpu::LaunchFuncOp op) { translateGpuLaunchCalls(op); }); // GPU kernel modules are no longer necessary since we have a global // constant with the CUBIN data. for (auto m : - llvm::make_early_inc_range(getModule().getOps<gpu::GPUModuleOp>())) + llvm::make_early_inc_range(getOperation().getOps<gpu::GPUModuleOp>())) m.erase(); } @@ -160,7 +159,7 @@ private: // The types in comments give the actual types expected/returned but the API // uses void pointers. This is fine as they have the same linkage in C. void GpuLaunchFuncToCudaCallsPass::declareCudaFunctions(Location loc) { - ModuleOp module = getModule(); + ModuleOp module = getOperation(); OpBuilder builder(module.getBody()->getTerminator()); if (!module.lookupSymbol(cuModuleLoadName)) { builder.create<LLVM::LLVMFuncOp>( @@ -391,7 +390,7 @@ void GpuLaunchFuncToCudaCallsPass::translateGpuLaunchCalls( builder.getI32IntegerAttr(0)); // Create an LLVM global with CUBIN extracted from the kernel annotation and // obtain a pointer to the first byte in it. - auto kernelModule = getModule().lookupSymbol<gpu::GPUModuleOp>( + auto kernelModule = getOperation().lookupSymbol<gpu::GPUModuleOp>( launchOp.getKernelModuleName()); assert(kernelModule && "expected a kernel module"); @@ -412,7 +411,7 @@ void GpuLaunchFuncToCudaCallsPass::translateGpuLaunchCalls( // in the called helper function. auto cuModule = allocatePointer(builder, loc); auto cuModuleLoad = - getModule().lookupSymbol<LLVM::LLVMFuncOp>(cuModuleLoadName); + getOperation().lookupSymbol<LLVM::LLVMFuncOp>(cuModuleLoadName); builder.create<LLVM::CallOp>(loc, ArrayRef<Type>{getCUResultType()}, builder.getSymbolRefAttr(cuModuleLoad), ArrayRef<Value>{cuModule, data}); @@ -423,20 +422,20 @@ void GpuLaunchFuncToCudaCallsPass::translateGpuLaunchCalls( auto kernelName = generateKernelNameConstant(launchOp.kernel(), loc, builder); auto cuFunction = allocatePointer(builder, loc); auto cuModuleGetFunction = - getModule().lookupSymbol<LLVM::LLVMFuncOp>(cuModuleGetFunctionName); + getOperation().lookupSymbol<LLVM::LLVMFuncOp>(cuModuleGetFunctionName); builder.create<LLVM::CallOp>( loc, ArrayRef<Type>{getCUResultType()}, builder.getSymbolRefAttr(cuModuleGetFunction), ArrayRef<Value>{cuFunction, cuOwningModuleRef, kernelName}); // Grab the global stream needed for execution. auto cuGetStreamHelper = - getModule().lookupSymbol<LLVM::LLVMFuncOp>(cuGetStreamHelperName); + getOperation().lookupSymbol<LLVM::LLVMFuncOp>(cuGetStreamHelperName); auto cuStream = builder.create<LLVM::CallOp>( loc, ArrayRef<Type>{getPointerType()}, builder.getSymbolRefAttr(cuGetStreamHelper), ArrayRef<Value>{}); // Invoke the function with required arguments. auto cuLaunchKernel = - getModule().lookupSymbol<LLVM::LLVMFuncOp>(cuLaunchKernelName); + getOperation().lookupSymbol<LLVM::LLVMFuncOp>(cuLaunchKernelName); auto cuFunctionRef = builder.create<LLVM::LoadOp>(loc, getPointerType(), cuFunction); auto paramsArray = setupParamsArray(launchOp, builder); @@ -458,7 +457,7 @@ void GpuLaunchFuncToCudaCallsPass::translateGpuLaunchCalls( nullpointer /* extra */}); // Sync on the stream to make it synchronous. auto cuStreamSync = - getModule().lookupSymbol<LLVM::LLVMFuncOp>(cuStreamSynchronizeName); + getOperation().lookupSymbol<LLVM::LLVMFuncOp>(cuStreamSynchronizeName); builder.create<LLVM::CallOp>(loc, ArrayRef<Type>{getCUResultType()}, builder.getSymbolRefAttr(cuStreamSync), ArrayRef<Value>(cuStream.getResult(0))); diff --git a/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRVPass.cpp b/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRVPass.cpp index 1102ef182c5f..edee5025ded9 100644 --- a/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRVPass.cpp +++ b/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRVPass.cpp @@ -33,18 +33,18 @@ namespace { /// replace it). /// /// 2) Lower the body of the spirv::ModuleOp. -struct GPUToSPIRVPass : public ModulePass<GPUToSPIRVPass> { +struct GPUToSPIRVPass : public OperationPass<GPUToSPIRVPass, ModuleOp> { /// Include the generated pass utilities. #define GEN_PASS_ConvertGpuToSPIRV #include "mlir/Conversion/Passes.h.inc" - void runOnModule() override; + void runOnOperation() override; }; } // namespace -void GPUToSPIRVPass::runOnModule() { +void GPUToSPIRVPass::runOnOperation() { MLIRContext *context = &getContext(); - ModuleOp module = getModule(); + ModuleOp module = getOperation(); SmallVector<Operation *, 1> kernelModules; OpBuilder builder(context); diff --git a/mlir/lib/Conversion/GPUToVulkan/ConvertGPULaunchFuncToVulkanLaunchFunc.cpp b/mlir/lib/Conversion/GPUToVulkan/ConvertGPULaunchFuncToVulkanLaunchFunc.cpp index 823860ba2589..cbcfd741d9f8 100644 --- a/mlir/lib/Conversion/GPUToVulkan/ConvertGPULaunchFuncToVulkanLaunchFunc.cpp +++ b/mlir/lib/Conversion/GPUToVulkan/ConvertGPULaunchFuncToVulkanLaunchFunc.cpp @@ -38,13 +38,13 @@ namespace { /// function and attaching binary data and entry point name as an attributes to /// created vulkan launch call op. class ConvertGpuLaunchFuncToVulkanLaunchFunc - : public ModulePass<ConvertGpuLaunchFuncToVulkanLaunchFunc> { + : public OperationPass<ConvertGpuLaunchFuncToVulkanLaunchFunc, ModuleOp> { public: /// Include the generated pass utilities. #define GEN_PASS_ConvertGpuLaunchFuncToVulkanLaunchFunc #include "mlir/Conversion/Passes.h.inc" - void runOnModule() override; + void runOnOperation() override; private: /// Creates a SPIR-V binary shader from the given `module` using @@ -68,14 +68,13 @@ private: /// operand is unsupported by Vulkan runtime. LogicalResult declareVulkanLaunchFunc(Location loc, gpu::LaunchFuncOp launchOp); - }; } // anonymous namespace -void ConvertGpuLaunchFuncToVulkanLaunchFunc::runOnModule() { +void ConvertGpuLaunchFuncToVulkanLaunchFunc::runOnOperation() { bool done = false; - getModule().walk([this, &done](gpu::LaunchFuncOp op) { + getOperation().walk([this, &done](gpu::LaunchFuncOp op) { if (done) { op.emitError("should only contain one 'gpu::LaunchFuncOp' op"); return signalPassFailure(); @@ -86,17 +85,17 @@ void ConvertGpuLaunchFuncToVulkanLaunchFunc::runOnModule() { // Erase `gpu::GPUModuleOp` and `spirv::Module` operations. for (auto gpuModule : - llvm::make_early_inc_range(getModule().getOps<gpu::GPUModuleOp>())) + llvm::make_early_inc_range(getOperation().getOps<gpu::GPUModuleOp>())) gpuModule.erase(); for (auto spirvModule : - llvm::make_early_inc_range(getModule().getOps<spirv::ModuleOp>())) + llvm::make_early_inc_range(getOperation().getOps<spirv::ModuleOp>())) spirvModule.erase(); } LogicalResult ConvertGpuLaunchFuncToVulkanLaunchFunc::declareVulkanLaunchFunc( Location loc, gpu::LaunchFuncOp launchOp) { - OpBuilder builder(getModule().getBody()->getTerminator()); + OpBuilder builder(getOperation().getBody()->getTerminator()); // TODO: Workgroup size is written into the kernel. So to properly modelling // vulkan launch, we cannot have the local workgroup size configuration here. SmallVector<Type, 8> vulkanLaunchTypes{launchOp.getOperandTypes()}; @@ -138,7 +137,7 @@ LogicalResult ConvertGpuLaunchFuncToVulkanLaunchFunc::createBinaryShader( void ConvertGpuLaunchFuncToVulkanLaunchFunc::convertGpuLaunchFunc( gpu::LaunchFuncOp launchOp) { - ModuleOp module = getModule(); + ModuleOp module = getOperation(); OpBuilder builder(launchOp); Location loc = launchOp.getLoc(); diff --git a/mlir/lib/Conversion/GPUToVulkan/ConvertLaunchFuncToVulkanCalls.cpp b/mlir/lib/Conversion/GPUToVulkan/ConvertLaunchFuncToVulkanCalls.cpp index ebc8ded483ff..2daa13085bcb 100644 --- a/mlir/lib/Conversion/GPUToVulkan/ConvertLaunchFuncToVulkanCalls.cpp +++ b/mlir/lib/Conversion/GPUToVulkan/ConvertLaunchFuncToVulkanCalls.cpp @@ -58,7 +58,7 @@ namespace { /// * deinitVulkan -- deinitializes vulkan runtime /// class VulkanLaunchFuncToVulkanCallsPass - : public ModulePass<VulkanLaunchFuncToVulkanCallsPass> { + : public OperationPass<VulkanLaunchFuncToVulkanCallsPass, ModuleOp> { private: /// Include the generated pass utilities. #define GEN_PASS_ConvertVulkanLaunchFuncToVulkanCalls @@ -150,7 +150,7 @@ private: LogicalResult deduceMemRefRank(Value ptrToMemRefDescriptor, uint32_t &rank); public: - void runOnModule() override; + void runOnOperation() override; private: LLVM::LLVMDialect *llvmDialect; @@ -169,18 +169,18 @@ private: } // anonymous namespace -void VulkanLaunchFuncToVulkanCallsPass::runOnModule() { +void VulkanLaunchFuncToVulkanCallsPass::runOnOperation() { initializeCachedTypes(); // Collect SPIR-V attributes such as `spirv_blob` and // `spirv_entry_point_name`. - getModule().walk([this](LLVM::CallOp op) { + getOperation().walk([this](LLVM::CallOp op) { if (isVulkanLaunchCallOp(op)) collectSPIRVAttributes(op); }); // Convert vulkan launch call op into a sequence of Vulkan runtime calls. - getModule().walk([this](LLVM::CallOp op) { + getOperation().walk([this](LLVM::CallOp op) { if (isCInterfaceVulkanLaunchCallOp(op)) translateVulkanLaunchCall(op); }); @@ -278,7 +278,7 @@ VulkanLaunchFuncToVulkanCallsPass::deduceMemRefRank(Value ptrToMemRefDescriptor, } void VulkanLaunchFuncToVulkanCallsPass::declareVulkanFunctions(Location loc) { - ModuleOp module = getModule(); + ModuleOp module = getOperation(); OpBuilder builder(module.getBody()->getTerminator()); if (!module.lookupSymbol(kSetEntryPoint)) { diff --git a/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp b/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp index 07c8111941e4..99f106e29de7 100644 --- a/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp +++ b/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp @@ -561,17 +561,18 @@ void mlir::populateLinalgToLLVMConversionPatterns( } namespace { -struct ConvertLinalgToLLVMPass : public ModulePass<ConvertLinalgToLLVMPass> { +struct ConvertLinalgToLLVMPass + : public OperationPass<ConvertLinalgToLLVMPass, ModuleOp> { /// Include the generated pass utilities. #define GEN_PASS_ConvertLinalgToLLVM #include "mlir/Conversion/Passes.h.inc" - void runOnModule() override; + void runOnOperation() override; }; } // namespace -void ConvertLinalgToLLVMPass::runOnModule() { - auto module = getModule(); +void ConvertLinalgToLLVMPass::runOnOperation() { + auto module = getOperation(); // Convert to the LLVM IR dialect using the converter defined above. OwningRewritePatternList patterns; diff --git a/mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRVPass.cpp b/mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRVPass.cpp index 0962746c486a..4b66063b88eb 100644 --- a/mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRVPass.cpp +++ b/mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRVPass.cpp @@ -16,18 +16,18 @@ using namespace mlir; namespace { /// A pass converting MLIR Linalg ops into SPIR-V ops. -class LinalgToSPIRVPass : public ModulePass<LinalgToSPIRVPass> { +class LinalgToSPIRVPass : public OperationPass<LinalgToSPIRVPass, ModuleOp> { /// Include the generated pass utilities. #define GEN_PASS_ConvertLinalgToSPIRV #include "mlir/Conversion/Passes.h.inc" - void runOnModule() override; + void runOnOperation() override; }; } // namespace -void LinalgToSPIRVPass::runOnModule() { +void LinalgToSPIRVPass::runOnOperation() { MLIRContext *context = &getContext(); - ModuleOp module = getModule(); + ModuleOp module = getOperation(); auto targetAttr = spirv::lookupTargetEnvOrDefault(module); std::unique_ptr<ConversionTarget> target = diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp index 1e127a0a884e..ef5dabf2ff88 100644 --- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp +++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp @@ -2847,7 +2847,7 @@ LLVMTypeConverter::promoteMemRefDescriptors(Location loc, ValueRange opOperands, namespace { /// A pass converting MLIR operations into the LLVM IR dialect. -struct LLVMLoweringPass : public ModulePass<LLVMLoweringPass> { +struct LLVMLoweringPass : public OperationPass<LLVMLoweringPass, ModuleOp> { /// Include the generated pass utilities. #define GEN_PASS_ConvertStandardToLLVM #include "mlir/Conversion/Passes.h.inc" @@ -2863,16 +2863,16 @@ struct LLVMLoweringPass : public ModulePass<LLVMLoweringPass> { LLVMLoweringPass(const LLVMLoweringPass &pass) {} /// Run the dialect converter on the module. - void runOnModule() override { + void runOnOperation() override { if (useBarePtrCallConv && emitCWrappers) { - getModule().emitError() + getOperation().emitError() << "incompatible conversion options: bare-pointer calling convention " "and C wrapper emission"; signalPassFailure(); return; } - ModuleOp m = getModule(); + ModuleOp m = getOperation(); LLVMTypeConverterCustomization customs; customs.funcArgConverter = useBarePtrCallConv ? barePtrFuncArgTypeConverter diff --git a/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.cpp b/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.cpp index ab7dd8546995..86c8cd17433c 100644 --- a/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.cpp +++ b/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.cpp @@ -22,18 +22,18 @@ using namespace mlir; namespace { /// A pass converting MLIR Standard operations into the SPIR-V dialect. class ConvertStandardToSPIRVPass - : public ModulePass<ConvertStandardToSPIRVPass> { + : public OperationPass<ConvertStandardToSPIRVPass, ModuleOp> { /// Include the generated pass utilities. #define GEN_PASS_ConvertStandardToSPIRV #include "mlir/Conversion/Passes.h.inc" - void runOnModule() override; + void runOnOperation() override; }; } // namespace -void ConvertStandardToSPIRVPass::runOnModule() { +void ConvertStandardToSPIRVPass::runOnOperation() { MLIRContext *context = &getContext(); - ModuleOp module = getModule(); + ModuleOp module = getOperation(); auto targetAttr = spirv::lookupTargetEnvOrDefault(module); std::unique_ptr<ConversionTarget> target = diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp index d5a4f86d2ca9..b2a1c443f518 100644 --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp @@ -1118,23 +1118,24 @@ void mlir::populateVectorToLLVMMatrixConversionPatterns( } namespace { -struct LowerVectorToLLVMPass : public ModulePass<LowerVectorToLLVMPass> { +struct LowerVectorToLLVMPass + : public OperationPass<LowerVectorToLLVMPass, ModuleOp> { /// Include the generated pass utilities. #define GEN_PASS_ConvertVectorToLLVM #include "mlir/Conversion/Passes.h.inc" - void runOnModule() override; + void runOnOperation() override; }; } // namespace -void LowerVectorToLLVMPass::runOnModule() { +void LowerVectorToLLVMPass::runOnOperation() { // Perform progressive lowering of operations on slices and // all contraction operations. Also applies folding and DCE. { OwningRewritePatternList patterns; populateVectorSlicesLoweringPatterns(patterns, &getContext()); populateVectorContractLoweringPatterns(patterns, &getContext()); - applyPatternsGreedily(getModule(), patterns); + applyPatternsGreedily(getOperation(), patterns); } // Convert to the LLVM IR dialect. @@ -1148,8 +1149,8 @@ void LowerVectorToLLVMPass::runOnModule() { LLVMConversionTarget target(getContext()); target.addDynamicallyLegalOp<FuncOp>( [&](FuncOp op) { return converter.isSignatureLegal(op.getType()); }); - if (failed( - applyPartialConversion(getModule(), target, patterns, &converter))) { + if (failed(applyPartialConversion(getOperation(), target, patterns, + &converter))) { signalPassFailure(); } } diff --git a/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp b/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp index 2eadf87f038a..daf9169d242c 100644 --- a/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp +++ b/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp @@ -214,16 +214,17 @@ namespace { /// The gpu.modules are intended to be compiled to a cubin blob independently in /// a separate pass. The external functions can then be annotated with the /// symbol of the cubin accessor function. -class GpuKernelOutliningPass : public ModulePass<GpuKernelOutliningPass> { +class GpuKernelOutliningPass + : public OperationPass<GpuKernelOutliningPass, ModuleOp> { public: /// Include the generated pass utilities. #define GEN_PASS_GpuKernelOutlining #include "mlir/Dialect/GPU/Passes.h.inc" - void runOnModule() override { - SymbolTable symbolTable(getModule()); + void runOnOperation() override { + SymbolTable symbolTable(getOperation()); bool modified = false; - for (auto func : getModule().getOps<FuncOp>()) { + for (auto func : getOperation().getOps<FuncOp>()) { // Insert just after the function. Block::iterator insertPt(func.getOperation()->getNextNode()); auto funcWalkResult = func.walk([&](gpu::LaunchOp op) { @@ -255,8 +256,8 @@ public: // If any new module was inserted in this module, annotate this module as // a container module. if (modified) - getModule().setAttr(gpu::GPUDialect::getContainerModuleAttrName(), - UnitAttr::get(&getContext())); + getOperation().setAttr(gpu::GPUDialect::getContainerModuleAttrName(), + UnitAttr::get(&getContext())); } private: @@ -267,7 +268,7 @@ private: // a SymbolTable by the caller. SymbolTable needs to be refactored to // prevent manual building of Ops with symbols in code using SymbolTables // and then this needs to use the OpBuilder. - auto context = getModule().getContext(); + auto context = getOperation().getContext(); Builder builder(context); OperationState state(kernelFunc.getLoc(), gpu::GPUModuleOp::getOperationName()); diff --git a/mlir/lib/Dialect/SPIRV/Transforms/DecorateSPIRVCompositeTypeLayoutPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/DecorateSPIRVCompositeTypeLayoutPass.cpp index 79ed81956f08..e4622741536e 100644 --- a/mlir/lib/Dialect/SPIRV/Transforms/DecorateSPIRVCompositeTypeLayoutPass.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/DecorateSPIRVCompositeTypeLayoutPass.cpp @@ -80,14 +80,14 @@ static void populateSPIRVLayoutInfoPatterns(OwningRewritePatternList &patterns, namespace { class DecorateSPIRVCompositeTypeLayoutPass - : public ModulePass<DecorateSPIRVCompositeTypeLayoutPass> { + : public OperationPass<DecorateSPIRVCompositeTypeLayoutPass, ModuleOp> { private: - void runOnModule() override; + void runOnOperation() override; }; } // namespace -void DecorateSPIRVCompositeTypeLayoutPass::runOnModule() { - auto module = getModule(); +void DecorateSPIRVCompositeTypeLayoutPass::runOnOperation() { + auto module = getOperation(); OwningRewritePatternList patterns; populateSPIRVLayoutInfoPatterns(patterns, module.getContext()); ConversionTarget target(*(module.getContext())); diff --git a/mlir/lib/Transforms/OpStats.cpp b/mlir/lib/Transforms/OpStats.cpp index b7832f580dd4..2b519d697020 100644 --- a/mlir/lib/Transforms/OpStats.cpp +++ b/mlir/lib/Transforms/OpStats.cpp @@ -18,7 +18,7 @@ using namespace mlir; namespace { -struct PrintOpStatsPass : public ModulePass<PrintOpStatsPass> { +struct PrintOpStatsPass : public OperationPass<PrintOpStatsPass, ModuleOp> { /// Include the generated pass utilities. #define GEN_PASS_PrintOpStats #include "mlir/Transforms/Passes.h.inc" @@ -26,7 +26,7 @@ struct PrintOpStatsPass : public ModulePass<PrintOpStatsPass> { explicit PrintOpStatsPass(raw_ostream &os = llvm::errs()) : os(os) {} // Prints the resultant operation statistics post iterating over the module. - void runOnModule() override; + void runOnOperation() override; // Print summary of op stats. void printSummary(); @@ -37,11 +37,11 @@ private: }; } // namespace -void PrintOpStatsPass::runOnModule() { +void PrintOpStatsPass::runOnOperation() { opCount.clear(); // Compute the operation statistics for each function in the module. - for (auto &op : getModule()) + for (auto &op : getOperation()) op.walk([&](Operation *op) { ++opCount[op->getName().getStringRef()]; }); printSummary(); } diff --git a/mlir/lib/Transforms/ViewOpGraph.cpp b/mlir/lib/Transforms/ViewOpGraph.cpp index fcaff9a0b069..c5d921db059e 100644 --- a/mlir/lib/Transforms/ViewOpGraph.cpp +++ b/mlir/lib/Transforms/ViewOpGraph.cpp @@ -100,7 +100,7 @@ namespace { // PrintOpPass is simple pass to write graph per function. // Note: this is a module pass only to avoid interleaving on the same ostream // due to multi-threading over functions. -struct PrintOpPass : public ModulePass<PrintOpPass> { +struct PrintOpPass : public OperationPass<PrintOpPass, ModuleOp> { /// Include the generated pass utilities. #define GEN_PASS_PrintOpGraph #include "mlir/Transforms/Passes.h.inc" @@ -140,7 +140,7 @@ struct PrintOpPass : public ModulePass<PrintOpPass> { } } - void runOnModule() override { processModule(getModule()); } + void runOnOperation() override { processModule(getOperation()); } private: raw_ostream &os; diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp index e6cc52d29722..6ccfa04a8194 100644 --- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp +++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp @@ -398,13 +398,13 @@ struct TestTypeConverter : public TypeConverter { }; struct TestLegalizePatternDriver - : public ModulePass<TestLegalizePatternDriver> { + : public OperationPass<TestLegalizePatternDriver, ModuleOp> { /// The mode of conversion to use with the driver. enum class ConversionMode { Analysis, Full, Partial }; TestLegalizePatternDriver(ConversionMode mode) : mode(mode) {} - void runOnModule() override { + void runOnOperation() override { TestTypeConverter converter; mlir::OwningRewritePatternList patterns; populateWithGenerated(&getContext(), &patterns); @@ -450,7 +450,8 @@ struct TestLegalizePatternDriver // Handle a partial conversion. if (mode == ConversionMode::Partial) { - (void)applyPartialConversion(getModule(), target, patterns, &converter); + (void)applyPartialConversion(getOperation(), target, patterns, + &converter); return; } @@ -461,7 +462,7 @@ struct TestLegalizePatternDriver return (bool)op->getAttrOfType<UnitAttr>("test.dynamically_legal"); }); - (void)applyFullConversion(getModule(), target, patterns, &converter); + (void)applyFullConversion(getOperation(), target, patterns, &converter); return; } @@ -470,7 +471,7 @@ struct TestLegalizePatternDriver // Analyze the convertible operations. DenseSet<Operation *> legalizedOps; - if (failed(applyAnalysisConversion(getModule(), target, patterns, + if (failed(applyAnalysisConversion(getOperation(), target, patterns, legalizedOps, &converter))) return signalPassFailure(); diff --git a/mlir/test/lib/IR/TestFunc.cpp b/mlir/test/lib/IR/TestFunc.cpp index 0e885c555e38..c1b90397ec44 100644 --- a/mlir/test/lib/IR/TestFunc.cpp +++ b/mlir/test/lib/IR/TestFunc.cpp @@ -13,9 +13,9 @@ using namespace mlir; namespace { /// This is a test pass for verifying FuncOp's eraseArgument method. -struct TestFuncEraseArg : public ModulePass<TestFuncEraseArg> { - void runOnModule() override { - auto module = getModule(); +struct TestFuncEraseArg : public OperationPass<TestFuncEraseArg, ModuleOp> { + void runOnOperation() override { + auto module = getOperation(); for (FuncOp func : module.getOps<FuncOp>()) { SmallVector<unsigned, 4> indicesToErase; @@ -36,9 +36,9 @@ struct TestFuncEraseArg : public ModulePass<TestFuncEraseArg> { }; /// This is a test pass for verifying FuncOp's setType method. -struct TestFuncSetType : public ModulePass<TestFuncSetType> { - void runOnModule() override { - auto module = getModule(); +struct TestFuncSetType : public OperationPass<TestFuncSetType, ModuleOp> { + void runOnOperation() override { + auto module = getOperation(); SymbolTable symbolTable(module); for (FuncOp func : module.getOps<FuncOp>()) { diff --git a/mlir/test/lib/IR/TestSideEffects.cpp b/mlir/test/lib/IR/TestSideEffects.cpp index 9f52c42e4953..a99348537e25 100644 --- a/mlir/test/lib/IR/TestSideEffects.cpp +++ b/mlir/test/lib/IR/TestSideEffects.cpp @@ -12,9 +12,9 @@ using namespace mlir; namespace { -struct SideEffectsPass : public ModulePass<SideEffectsPass> { - void runOnModule() override { - auto module = getModule(); +struct SideEffectsPass : public OperationPass<SideEffectsPass, ModuleOp> { + void runOnOperation() override { + auto module = getOperation(); // Walk operations detecting side effects. SmallVector<MemoryEffects::EffectInstance, 8> effects; diff --git a/mlir/test/lib/IR/TestSymbolUses.cpp b/mlir/test/lib/IR/TestSymbolUses.cpp index 6082cdcbe72b..c39615ef1352 100644 --- a/mlir/test/lib/IR/TestSymbolUses.cpp +++ b/mlir/test/lib/IR/TestSymbolUses.cpp @@ -15,7 +15,7 @@ using namespace mlir; namespace { /// This is a symbol test pass that tests the symbol uselist functionality /// provided by the symbol table along with erasing from the symbol table. -struct SymbolUsesPass : public ModulePass<SymbolUsesPass> { +struct SymbolUsesPass : public OperationPass<SymbolUsesPass, ModuleOp> { WalkResult operateOnSymbol(Operation *symbol, ModuleOp module, SmallVectorImpl<FuncOp> &deadFunctions) { // Test computing uses on a non symboltable op. @@ -59,8 +59,8 @@ struct SymbolUsesPass : public ModulePass<SymbolUsesPass> { return WalkResult::advance(); } - void runOnModule() override { - auto module = getModule(); + void runOnOperation() override { + auto module = getOperation(); // Walk nested symbols. SmallVector<FuncOp, 4> deadFunctions; @@ -86,9 +86,10 @@ struct SymbolUsesPass : public ModulePass<SymbolUsesPass> { /// This is a symbol test pass that tests the symbol use replacement /// functionality provided by the symbol table. -struct SymbolReplacementPass : public ModulePass<SymbolReplacementPass> { - void runOnModule() override { - auto module = getModule(); +struct SymbolReplacementPass + : public OperationPass<SymbolReplacementPass, ModuleOp> { + void runOnOperation() override { + auto module = getOperation(); // Walk nested functions and modules. module.getBodyRegion().walk([&](Operation *nestedOp) { diff --git a/mlir/test/lib/Pass/TestPassManager.cpp b/mlir/test/lib/Pass/TestPassManager.cpp index 95bef9b878e2..be8a7479200c 100644 --- a/mlir/test/lib/Pass/TestPassManager.cpp +++ b/mlir/test/lib/Pass/TestPassManager.cpp @@ -13,8 +13,8 @@ using namespace mlir; namespace { -struct TestModulePass : public ModulePass<TestModulePass> { - void runOnModule() final {} +struct TestModulePass : public OperationPass<TestModulePass, ModuleOp> { + void runOnOperation() final {} }; struct TestFunctionPass : public FunctionPass<TestFunctionPass> { void runOnFunction() final {} diff --git a/mlir/test/lib/Transforms/TestAllReduceLowering.cpp b/mlir/test/lib/Transforms/TestAllReduceLowering.cpp index 508f70887350..6455dab70f45 100644 --- a/mlir/test/lib/Transforms/TestAllReduceLowering.cpp +++ b/mlir/test/lib/Transforms/TestAllReduceLowering.cpp @@ -18,11 +18,11 @@ using namespace mlir; namespace { struct TestAllReduceLoweringPass - : public ModulePass<TestAllReduceLoweringPass> { - void runOnModule() override { + : public OperationPass<TestAllReduceLoweringPass, ModuleOp> { + void runOnOperation() override { OwningRewritePatternList patterns; populateGpuRewritePatterns(&getContext(), patterns); - applyPatternsGreedily(getModule(), patterns); + applyPatternsGreedily(getOperation(), patterns); } }; } // namespace diff --git a/mlir/test/lib/Transforms/TestCallGraph.cpp b/mlir/test/lib/Transforms/TestCallGraph.cpp index 89c25da9e8ed..a181d645f2af 100644 --- a/mlir/test/lib/Transforms/TestCallGraph.cpp +++ b/mlir/test/lib/Transforms/TestCallGraph.cpp @@ -17,9 +17,9 @@ using namespace mlir; namespace { -struct TestCallGraphPass : public ModulePass<TestCallGraphPass> { - void runOnModule() { - llvm::errs() << "Testing : " << getModule().getAttr("test.name") << "\n"; +struct TestCallGraphPass : public OperationPass<TestCallGraphPass, ModuleOp> { + void runOnOperation() override { + llvm::errs() << "Testing : " << getOperation().getAttr("test.name") << "\n"; getAnalysis<CallGraph>().print(llvm::errs()); } }; diff --git a/mlir/test/lib/Transforms/TestOpaqueLoc.cpp b/mlir/test/lib/Transforms/TestOpaqueLoc.cpp index baae5297306d..47152c459805 100644 --- a/mlir/test/lib/Transforms/TestOpaqueLoc.cpp +++ b/mlir/test/lib/Transforms/TestOpaqueLoc.cpp @@ -17,7 +17,7 @@ namespace { /// It also takes all operations that are not function operations or /// terminators and clones them with opaque locations which store the initial /// locations. -struct TestOpaqueLoc : public ModulePass<TestOpaqueLoc> { +struct TestOpaqueLoc : public OperationPass<TestOpaqueLoc, ModuleOp> { /// A simple structure which is used for testing as an underlying location in /// OpaqueLoc. @@ -29,11 +29,11 @@ struct TestOpaqueLoc : public ModulePass<TestOpaqueLoc> { int id; }; - void runOnModule() override { + void runOnOperation() override { std::vector<std::unique_ptr<MyLocation>> myLocs; int last_it = 0; - getModule().walk([&](Operation *op) { + getOperation().walk([&](Operation *op) { myLocs.push_back(std::make_unique<MyLocation>(last_it++)); Location loc = op->getLoc(); @@ -74,7 +74,7 @@ struct TestOpaqueLoc : public ModulePass<TestOpaqueLoc> { os.flush(); }); - getModule().walk([&](Operation *op) { op->emitOpError(); }); + getOperation().walk([&](Operation *op) { op->emitOpError(); }); } }; |