diff options
author | River Riddle <riddleriver@gmail.com> | 2020-04-07 13:56:16 -0700 |
---|---|---|
committer | River Riddle <riddleriver@gmail.com> | 2020-04-07 14:08:52 -0700 |
commit | 80aca1eaf778a58458833591e82b74647b5b7280 (patch) | |
tree | f38add6f74980acce56446dd2db7d238f3d0efa4 | |
parent | 722f909f7aa1d5ab21f68eb8ce1baf109cc5bb13 (diff) |
[mlir][Pass] Remove the use of CRTP from the Pass classes
This revision removes all of the CRTP from the pass hierarchy in preparation for using the tablegen backend instead. This creates a much cleaner interface in the C++ code, and naturally fits with the rest of the infrastructure. A new utility class, PassWrapper, is added to replicate the existing behavior for passes not suitable for using the tablegen backend.
Differential Revision: https://reviews.llvm.org/D77350
114 files changed, 405 insertions, 342 deletions
diff --git a/mlir/docs/Tutorials/Toy/Ch-4.md b/mlir/docs/Tutorials/Toy/Ch-4.md index c9a5d8188f9d..dcdf8d83840d 100644 --- a/mlir/docs/Tutorials/Toy/Ch-4.md +++ b/mlir/docs/Tutorials/Toy/Ch-4.md @@ -318,7 +318,8 @@ Implementing such a pass is done by creating a class inheriting from `mlir::FunctionPass` and overriding the `runOnFunction()` method. ```c++ -class ShapeInferencePass : public mlir::FunctionPass<ShapeInferencePass> { +class ShapeInferencePass + : public mlir::PassWrapper<ShapeInferencePass, FunctionPass> { void runOnFunction() override { FuncOp function = getFunction(); ... diff --git a/mlir/examples/toy/Ch4/mlir/ShapeInferencePass.cpp b/mlir/examples/toy/Ch4/mlir/ShapeInferencePass.cpp index 296bec094624..f9f32c22f8d4 100644 --- a/mlir/examples/toy/Ch4/mlir/ShapeInferencePass.cpp +++ b/mlir/examples/toy/Ch4/mlir/ShapeInferencePass.cpp @@ -44,7 +44,8 @@ namespace { /// d) infer the shape of its output from the argument types. /// 3) If the worklist is empty, the algorithm succeeded. /// -class ShapeInferencePass : public mlir::FunctionPass<ShapeInferencePass> { +class ShapeInferencePass + : public mlir::PassWrapper<ShapeInferencePass, FunctionPass> { public: void runOnFunction() override { auto f = getFunction(); diff --git a/mlir/examples/toy/Ch5/mlir/LowerToAffineLoops.cpp b/mlir/examples/toy/Ch5/mlir/LowerToAffineLoops.cpp index 0614f3ac043b..0988f5fe0c41 100644 --- a/mlir/examples/toy/Ch5/mlir/LowerToAffineLoops.cpp +++ b/mlir/examples/toy/Ch5/mlir/LowerToAffineLoops.cpp @@ -260,7 +260,8 @@ struct TransposeOpLowering : public ConversionPattern { /// computationally intensive (like matmul for example...) while keeping the /// rest of the code in the Toy dialect. namespace { -struct ToyToAffineLoweringPass : public FunctionPass<ToyToAffineLoweringPass> { +struct ToyToAffineLoweringPass + : public PassWrapper<ToyToAffineLoweringPass, FunctionPass> { void runOnFunction() final; }; } // end anonymous namespace. diff --git a/mlir/examples/toy/Ch5/mlir/ShapeInferencePass.cpp b/mlir/examples/toy/Ch5/mlir/ShapeInferencePass.cpp index 296bec094624..f9f32c22f8d4 100644 --- a/mlir/examples/toy/Ch5/mlir/ShapeInferencePass.cpp +++ b/mlir/examples/toy/Ch5/mlir/ShapeInferencePass.cpp @@ -44,7 +44,8 @@ namespace { /// d) infer the shape of its output from the argument types. /// 3) If the worklist is empty, the algorithm succeeded. /// -class ShapeInferencePass : public mlir::FunctionPass<ShapeInferencePass> { +class ShapeInferencePass + : public mlir::PassWrapper<ShapeInferencePass, FunctionPass> { public: void runOnFunction() override { auto f = getFunction(); diff --git a/mlir/examples/toy/Ch6/mlir/LowerToAffineLoops.cpp b/mlir/examples/toy/Ch6/mlir/LowerToAffineLoops.cpp index 4292d14ec3ed..56629d7ae217 100644 --- a/mlir/examples/toy/Ch6/mlir/LowerToAffineLoops.cpp +++ b/mlir/examples/toy/Ch6/mlir/LowerToAffineLoops.cpp @@ -259,7 +259,8 @@ struct TransposeOpLowering : public ConversionPattern { /// computationally intensive (like matmul for example...) while keeping the /// rest of the code in the Toy dialect. namespace { -struct ToyToAffineLoweringPass : public FunctionPass<ToyToAffineLoweringPass> { +struct ToyToAffineLoweringPass + : public PassWrapper<ToyToAffineLoweringPass, FunctionPass> { void runOnFunction() final; }; } // end anonymous namespace. diff --git a/mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp b/mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp index 99465d3201e5..9c36f11c52c4 100644 --- a/mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp +++ b/mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp @@ -154,7 +154,7 @@ private: namespace { struct ToyToLLVMLoweringPass - : public OperationPass<ToyToLLVMLoweringPass, ModuleOp> { + : public PassWrapper<ToyToLLVMLoweringPass, OperationPass<ModuleOp>> { void runOnOperation() final; }; } // end anonymous namespace diff --git a/mlir/examples/toy/Ch6/mlir/ShapeInferencePass.cpp b/mlir/examples/toy/Ch6/mlir/ShapeInferencePass.cpp index 296bec094624..f9f32c22f8d4 100644 --- a/mlir/examples/toy/Ch6/mlir/ShapeInferencePass.cpp +++ b/mlir/examples/toy/Ch6/mlir/ShapeInferencePass.cpp @@ -44,7 +44,8 @@ namespace { /// d) infer the shape of its output from the argument types. /// 3) If the worklist is empty, the algorithm succeeded. /// -class ShapeInferencePass : public mlir::FunctionPass<ShapeInferencePass> { +class ShapeInferencePass + : public mlir::PassWrapper<ShapeInferencePass, FunctionPass> { public: void runOnFunction() override { auto f = getFunction(); diff --git a/mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp b/mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp index 0614f3ac043b..0988f5fe0c41 100644 --- a/mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp +++ b/mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp @@ -260,7 +260,8 @@ struct TransposeOpLowering : public ConversionPattern { /// computationally intensive (like matmul for example...) while keeping the /// rest of the code in the Toy dialect. namespace { -struct ToyToAffineLoweringPass : public FunctionPass<ToyToAffineLoweringPass> { +struct ToyToAffineLoweringPass + : public PassWrapper<ToyToAffineLoweringPass, FunctionPass> { void runOnFunction() final; }; } // end anonymous namespace. diff --git a/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp b/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp index 99465d3201e5..9c36f11c52c4 100644 --- a/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp +++ b/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp @@ -154,7 +154,7 @@ private: namespace { struct ToyToLLVMLoweringPass - : public OperationPass<ToyToLLVMLoweringPass, ModuleOp> { + : public PassWrapper<ToyToLLVMLoweringPass, OperationPass<ModuleOp>> { void runOnOperation() final; }; } // end anonymous namespace diff --git a/mlir/examples/toy/Ch7/mlir/ShapeInferencePass.cpp b/mlir/examples/toy/Ch7/mlir/ShapeInferencePass.cpp index 296bec094624..f9f32c22f8d4 100644 --- a/mlir/examples/toy/Ch7/mlir/ShapeInferencePass.cpp +++ b/mlir/examples/toy/Ch7/mlir/ShapeInferencePass.cpp @@ -44,7 +44,8 @@ namespace { /// d) infer the shape of its output from the argument types. /// 3) If the worklist is empty, the algorithm succeeded. /// -class ShapeInferencePass : public mlir::FunctionPass<ShapeInferencePass> { +class ShapeInferencePass + : public mlir::PassWrapper<ShapeInferencePass, FunctionPass> { public: void runOnFunction() override { auto f = getFunction(); diff --git a/mlir/include/mlir/Conversion/AVX512ToLLVM/ConvertAVX512ToLLVM.h b/mlir/include/mlir/Conversion/AVX512ToLLVM/ConvertAVX512ToLLVM.h index bd65970d5bf7..fdd203a6f6ef 100644 --- a/mlir/include/mlir/Conversion/AVX512ToLLVM/ConvertAVX512ToLLVM.h +++ b/mlir/include/mlir/Conversion/AVX512ToLLVM/ConvertAVX512ToLLVM.h @@ -14,7 +14,7 @@ namespace mlir { class LLVMTypeConverter; class ModuleOp; -template <typename T> class OpPassBase; +template <typename T> class OperationPass; class OwningRewritePatternList; /// Collect a set of patterns to convert from the AVX512 dialect to LLVM. @@ -22,7 +22,7 @@ void populateAVX512ToLLVMConversionPatterns(LLVMTypeConverter &converter, OwningRewritePatternList &patterns); /// Create a pass to convert AVX512 operations to the LLVMIR dialect. -std::unique_ptr<OpPassBase<ModuleOp>> createConvertAVX512ToLLVMPass(); +std::unique_ptr<OperationPass<ModuleOp>> createConvertAVX512ToLLVMPass(); } // namespace mlir diff --git a/mlir/include/mlir/Conversion/GPUToCUDA/GPUToCUDAPass.h b/mlir/include/mlir/Conversion/GPUToCUDA/GPUToCUDAPass.h index 35794bfab6dd..6e21483c6728 100644 --- a/mlir/include/mlir/Conversion/GPUToCUDA/GPUToCUDAPass.h +++ b/mlir/include/mlir/Conversion/GPUToCUDA/GPUToCUDAPass.h @@ -19,7 +19,7 @@ namespace mlir { class Location; class ModuleOp; -template <typename T> class OpPassBase; +template <typename T> class OperationPass; namespace gpu { class GPUModuleOp; @@ -42,7 +42,7 @@ using CubinGenerator = /// attached as a string attribute named 'nvvm.cubin' to the kernel function. /// After the transformation, the body of the kernel function is removed (i.e., /// it is turned into a declaration). -std::unique_ptr<OpPassBase<gpu::GPUModuleOp>> +std::unique_ptr<OperationPass<gpu::GPUModuleOp>> createConvertGPUKernelToCubinPass(CubinGenerator cubinGenerator); /// Creates a pass to convert a gpu.launch_func operation into a sequence of @@ -51,7 +51,7 @@ createConvertGPUKernelToCubinPass(CubinGenerator cubinGenerator); /// This pass does not generate code to call CUDA directly but instead uses a /// small wrapper library that exports a stable and conveniently typed ABI /// on top of CUDA. -std::unique_ptr<OpPassBase<ModuleOp>> +std::unique_ptr<OperationPass<ModuleOp>> createConvertGpuLaunchFuncToCudaCallsPass(); } // namespace mlir diff --git a/mlir/include/mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h b/mlir/include/mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h index 4a6698cfb50c..5dbfce9bd00f 100644 --- a/mlir/include/mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h +++ b/mlir/include/mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h @@ -14,7 +14,7 @@ namespace mlir { class LLVMTypeConverter; class OwningRewritePatternList; -template <typename OpT> class OpPassBase; +template <typename OpT> class OperationPass; namespace gpu { class GPUModuleOp; @@ -25,7 +25,8 @@ void populateGpuToNVVMConversionPatterns(LLVMTypeConverter &converter, OwningRewritePatternList &patterns); /// Creates a pass that lowers GPU dialect operations to NVVM counterparts. -std::unique_ptr<OpPassBase<gpu::GPUModuleOp>> createLowerGpuOpsToNVVMOpsPass(); +std::unique_ptr<OperationPass<gpu::GPUModuleOp>> +createLowerGpuOpsToNVVMOpsPass(); } // namespace mlir diff --git a/mlir/include/mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h b/mlir/include/mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h index 64fa40c02622..341526fc9964 100644 --- a/mlir/include/mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h +++ b/mlir/include/mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h @@ -15,10 +15,11 @@ namespace mlir { namespace gpu { class GPUModuleOp; } // namespace gpu -template <typename OpT> class OpPassBase; +template <typename OpT> class OperationPass; /// Creates a pass that lowers GPU dialect operations to ROCDL counterparts. -std::unique_ptr<OpPassBase<gpu::GPUModuleOp>> createLowerGpuOpsToROCDLOpsPass(); +std::unique_ptr<OperationPass<gpu::GPUModuleOp>> +createLowerGpuOpsToROCDLOpsPass(); } // namespace mlir diff --git a/mlir/include/mlir/Conversion/GPUToSPIRV/ConvertGPUToSPIRVPass.h b/mlir/include/mlir/Conversion/GPUToSPIRV/ConvertGPUToSPIRVPass.h index cf3246a55114..f22db8477d84 100644 --- a/mlir/include/mlir/Conversion/GPUToSPIRV/ConvertGPUToSPIRVPass.h +++ b/mlir/include/mlir/Conversion/GPUToSPIRV/ConvertGPUToSPIRVPass.h @@ -20,11 +20,11 @@ namespace mlir { class ModuleOp; -template <typename T> class OpPassBase; +template <typename T> class OperationPass; /// Pass to convert GPU Ops to SPIR-V ops. For a gpu.func to be converted, it /// should have a spv.entry_point_abi attribute. -std::unique_ptr<OpPassBase<ModuleOp>> createConvertGPUToSPIRVPass(); +std::unique_ptr<OperationPass<ModuleOp>> createConvertGPUToSPIRVPass(); } // namespace mlir #endif // MLIR_CONVERSION_GPUTOSPIRV_CONVERTGPUTOSPIRVPASS_H diff --git a/mlir/include/mlir/Conversion/GPUToVulkan/ConvertGPUToVulkanPass.h b/mlir/include/mlir/Conversion/GPUToVulkan/ConvertGPUToVulkanPass.h index 9a02860bfc1a..92e3f80b2be3 100644 --- a/mlir/include/mlir/Conversion/GPUToVulkan/ConvertGPUToVulkanPass.h +++ b/mlir/include/mlir/Conversion/GPUToVulkan/ConvertGPUToVulkanPass.h @@ -21,12 +21,12 @@ namespace mlir { class ModuleOp; -template <typename T> class OpPassBase; +template <typename T> class OperationPass; -std::unique_ptr<OpPassBase<ModuleOp>> +std::unique_ptr<OperationPass<ModuleOp>> createConvertVulkanLaunchFuncToVulkanCallsPass(); -std::unique_ptr<OpPassBase<mlir::ModuleOp>> +std::unique_ptr<OperationPass<mlir::ModuleOp>> createConvertGpuLaunchFuncToVulkanLaunchFuncPass(); } // namespace mlir diff --git a/mlir/include/mlir/Conversion/LinalgToLLVM/LinalgToLLVM.h b/mlir/include/mlir/Conversion/LinalgToLLVM/LinalgToLLVM.h index 4124f3f0e3b2..b7c9d0016d65 100644 --- a/mlir/include/mlir/Conversion/LinalgToLLVM/LinalgToLLVM.h +++ b/mlir/include/mlir/Conversion/LinalgToLLVM/LinalgToLLVM.h @@ -14,7 +14,7 @@ namespace mlir { class MLIRContext; class ModuleOp; -template <typename T> class OpPassBase; +template <typename T> class OperationPass; /// Populate the given list with patterns that convert from Linalg to LLVM. void populateLinalgToLLVMConversionPatterns(LLVMTypeConverter &converter, @@ -22,7 +22,7 @@ void populateLinalgToLLVMConversionPatterns(LLVMTypeConverter &converter, MLIRContext *ctx); /// Create a pass to convert Linalg operations to the LLVMIR dialect. -std::unique_ptr<OpPassBase<ModuleOp>> createConvertLinalgToLLVMPass(); +std::unique_ptr<OperationPass<ModuleOp>> createConvertLinalgToLLVMPass(); } // namespace mlir diff --git a/mlir/include/mlir/Conversion/LinalgToSPIRV/LinalgToSPIRVPass.h b/mlir/include/mlir/Conversion/LinalgToSPIRV/LinalgToSPIRVPass.h index 4ed2bddac575..9b10b4b705d5 100644 --- a/mlir/include/mlir/Conversion/LinalgToSPIRV/LinalgToSPIRVPass.h +++ b/mlir/include/mlir/Conversion/LinalgToSPIRV/LinalgToSPIRVPass.h @@ -18,7 +18,7 @@ namespace mlir { /// Creates and returns a pass to convert Linalg ops to SPIR-V ops. -std::unique_ptr<OpPassBase<ModuleOp>> createLinalgToSPIRVPass(); +std::unique_ptr<OperationPass<ModuleOp>> createLinalgToSPIRVPass(); } // namespace mlir diff --git a/mlir/include/mlir/Conversion/LoopsToGPU/LoopsToGPUPass.h b/mlir/include/mlir/Conversion/LoopsToGPU/LoopsToGPUPass.h index 049e8538d746..4488b071ea43 100644 --- a/mlir/include/mlir/Conversion/LoopsToGPU/LoopsToGPUPass.h +++ b/mlir/include/mlir/Conversion/LoopsToGPU/LoopsToGPUPass.h @@ -14,7 +14,7 @@ namespace mlir { class FuncOp; -template <typename T> class OpPassBase; +template <typename T> class OperationPass; class Pass; /// Create a pass that converts loop nests into GPU kernels. It considers @@ -25,9 +25,9 @@ class Pass; /// parallelization is performed, it is under the responsibility of the caller /// to strip-mine the loops and to perform the dependence analysis before /// calling the conversion. -std::unique_ptr<OpPassBase<FuncOp>> +std::unique_ptr<OperationPass<FuncOp>> createSimpleLoopsToGPUPass(unsigned numBlockDims, unsigned numThreadDims); -std::unique_ptr<OpPassBase<FuncOp>> createSimpleLoopsToGPUPass(); +std::unique_ptr<OperationPass<FuncOp>> createSimpleLoopsToGPUPass(); /// Create a pass that converts every loop operation within the body of the /// FuncOp into a GPU launch. The number of workgroups and workgroup size for @@ -35,10 +35,10 @@ std::unique_ptr<OpPassBase<FuncOp>> createSimpleLoopsToGPUPass(); /// method. For testing, the values are set as constants obtained from a command /// line flag. See convertLoopToGPULaunch for a description of the required /// semantics of the converted loop operation. -std::unique_ptr<OpPassBase<FuncOp>> +std::unique_ptr<OperationPass<FuncOp>> createLoopToGPUPass(ArrayRef<int64_t> numWorkGroups, ArrayRef<int64_t> workGroupSize); -std::unique_ptr<OpPassBase<FuncOp>> createLoopToGPUPass(); +std::unique_ptr<OperationPass<FuncOp>> createLoopToGPUPass(); /// Creates a pass that converts loop.parallel operations into a gpu.launch /// operation. The mapping of loop dimensions to launch dimensions is derived diff --git a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h index c4aab9f867f6..72f852e57187 100644 --- a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h +++ b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h @@ -14,7 +14,7 @@ namespace mlir { class LLVMTypeConverter; class ModuleOp; -template <typename T> class OpPassBase; +template <typename T> class OperationPass; class OwningRewritePatternList; /// Collect a set of patterns to convert memory-related operations from the @@ -61,7 +61,7 @@ struct LowerToLLVMOptions { /// Creates a pass to convert the Standard dialect into the LLVMIR dialect. /// stdlib malloc/free is used for allocating memrefs allocated with std.alloc, /// while LLVM's alloca is used for those allocated with std.alloca. -std::unique_ptr<OpPassBase<ModuleOp>> createLowerToLLVMPass( +std::unique_ptr<OperationPass<ModuleOp>> createLowerToLLVMPass( const LowerToLLVMOptions &options = { /*useBarePtrCallConv=*/false, /*emitCWrappers=*/false, /*indexBitwidth=*/kDeriveIndexBitwidthFromDataLayout}); diff --git a/mlir/include/mlir/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.h b/mlir/include/mlir/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.h index e5436d4203f0..29e1d635a00f 100644 --- a/mlir/include/mlir/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.h +++ b/mlir/include/mlir/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.h @@ -18,7 +18,7 @@ namespace mlir { /// Pass to convert StandardOps to SPIR-V ops. -std::unique_ptr<OpPassBase<ModuleOp>> createConvertStandardToSPIRVPass(); +std::unique_ptr<OperationPass<ModuleOp>> createConvertStandardToSPIRVPass(); /// Pass to legalize ops that are not directly lowered to SPIR-V. std::unique_ptr<Pass> createLegalizeStdOpsForSPIRVLoweringPass(); diff --git a/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h b/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h index a92906d0c2c3..0ef6df5c34b3 100644 --- a/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h +++ b/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h @@ -13,7 +13,7 @@ namespace mlir { class LLVMTypeConverter; class ModuleOp; -template <typename T> class OpPassBase; +template <typename T> class OperationPass; /// Collect a set of patterns to convert from Vector contractions to LLVM Matrix /// Intrinsics. To lower to assembly, the LLVM flag -lower-matrix-intrinsics @@ -26,7 +26,7 @@ void populateVectorToLLVMConversionPatterns(LLVMTypeConverter &converter, OwningRewritePatternList &patterns); /// Create a pass to convert vector operations to the LLVMIR dialect. -std::unique_ptr<OpPassBase<ModuleOp>> createConvertVectorToLLVMPass(); +std::unique_ptr<OperationPass<ModuleOp>> createConvertVectorToLLVMPass(); } // namespace mlir diff --git a/mlir/include/mlir/Dialect/Affine/Passes.h b/mlir/include/mlir/Dialect/Affine/Passes.h index 75ff4a33649d..5d3b997b9b58 100644 --- a/mlir/include/mlir/Dialect/Affine/Passes.h +++ b/mlir/include/mlir/Dialect/Affine/Passes.h @@ -24,32 +24,33 @@ class AffineForOp; class FuncOp; class ModuleOp; class Pass; -template <typename T> class OpPassBase; +template <typename T> class OperationPass; /// Creates a simplification pass for affine structures (maps and sets). In /// addition, this pass also normalizes memrefs to have the trivial (identity) /// layout map. -std::unique_ptr<OpPassBase<FuncOp>> createSimplifyAffineStructuresPass(); +std::unique_ptr<OperationPass<FuncOp>> createSimplifyAffineStructuresPass(); /// Creates a loop invariant code motion pass that hoists loop invariant /// operations out of affine loops. -std::unique_ptr<OpPassBase<FuncOp>> createAffineLoopInvariantCodeMotionPass(); +std::unique_ptr<OperationPass<FuncOp>> +createAffineLoopInvariantCodeMotionPass(); /// Performs packing (or explicit copying) of accessed memref regions into /// buffers in the specified faster memory space through either pointwise copies /// or DMA operations. -std::unique_ptr<OpPassBase<FuncOp>> createAffineDataCopyGenerationPass( +std::unique_ptr<OperationPass<FuncOp>> createAffineDataCopyGenerationPass( unsigned slowMemorySpace, unsigned fastMemorySpace, unsigned tagMemorySpace = 0, int minDmaTransferSize = 1024, uint64_t fastMemCapacityBytes = std::numeric_limits<uint64_t>::max()); /// Overload relying on pass options for initialization. -std::unique_ptr<OpPassBase<FuncOp>> createAffineDataCopyGenerationPass(); +std::unique_ptr<OperationPass<FuncOp>> createAffineDataCopyGenerationPass(); /// Creates a pass to perform tiling on loop nests. -std::unique_ptr<OpPassBase<FuncOp>> +std::unique_ptr<OperationPass<FuncOp>> createLoopTilingPass(uint64_t cacheSizeBytes); /// Overload relying on pass options for initialization. -std::unique_ptr<OpPassBase<FuncOp>> createLoopTilingPass(); +std::unique_ptr<OperationPass<FuncOp>> createLoopTilingPass(); /// Creates a loop unrolling pass with the provided parameters. /// 'getUnrollFactor' is a function callback for clients to supply a function @@ -57,22 +58,22 @@ std::unique_ptr<OpPassBase<FuncOp>> createLoopTilingPass(); /// factors supplied through other means. If -1 is passed as the unrollFactor /// and no callback is provided, anything passed from the command-line (if at /// all) or the default unroll factor is used (LoopUnroll:kDefaultUnrollFactor). -std::unique_ptr<OpPassBase<FuncOp>> createLoopUnrollPass( +std::unique_ptr<OperationPass<FuncOp>> createLoopUnrollPass( int unrollFactor = -1, int unrollFull = -1, const std::function<unsigned(AffineForOp)> &getUnrollFactor = nullptr); /// Creates a loop unroll jam pass to unroll jam by the specified factor. A /// factor of -1 lets the pass use the default factor or the one on the command /// line if provided. -std::unique_ptr<OpPassBase<FuncOp>> +std::unique_ptr<OperationPass<FuncOp>> createLoopUnrollAndJamPass(int unrollJamFactor = -1); /// Creates a pass to vectorize loops, operations and data types using a /// target-independent, n-D super-vector abstraction. -std::unique_ptr<OpPassBase<FuncOp>> +std::unique_ptr<OperationPass<FuncOp>> createSuperVectorizePass(ArrayRef<int64_t> virtualVectorSize); /// Overload relying on pass options for initialization. -std::unique_ptr<OpPassBase<FuncOp>> createSuperVectorizePass(); +std::unique_ptr<OperationPass<FuncOp>> createSuperVectorizePass(); } // end namespace mlir diff --git a/mlir/include/mlir/Dialect/GPU/Passes.h b/mlir/include/mlir/Dialect/GPU/Passes.h index 51536c4dc411..bc349061f39f 100644 --- a/mlir/include/mlir/Dialect/GPU/Passes.h +++ b/mlir/include/mlir/Dialect/GPU/Passes.h @@ -19,10 +19,10 @@ namespace mlir { class MLIRContext; class ModuleOp; -template <typename T> class OpPassBase; +template <typename T> class OperationPass; class OwningRewritePatternList; -std::unique_ptr<OpPassBase<ModuleOp>> createGpuKernelOutliningPass(); +std::unique_ptr<OperationPass<ModuleOp>> createGpuKernelOutliningPass(); /// Collect a set of patterns to rewrite ops within the GPU dialect. void populateGpuRewritePatterns(MLIRContext *context, diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.h b/mlir/include/mlir/Dialect/Linalg/Passes.h index 9f52e360c7fb..d8886acc5992 100644 --- a/mlir/include/mlir/Dialect/Linalg/Passes.h +++ b/mlir/include/mlir/Dialect/Linalg/Passes.h @@ -19,34 +19,34 @@ namespace mlir { class FuncOp; class ModuleOp; -template <typename T> class OpPassBase; +template <typename T> class OperationPass; class Pass; -std::unique_ptr<OpPassBase<FuncOp>> createLinalgFusionPass(); +std::unique_ptr<OperationPass<FuncOp>> createLinalgFusionPass(); std::unique_ptr<Pass> createLinalgFusionOfTensorOpsPass(); -std::unique_ptr<OpPassBase<FuncOp>> +std::unique_ptr<OperationPass<FuncOp>> createLinalgTilingPass(ArrayRef<int64_t> tileSizes = {}); -std::unique_ptr<OpPassBase<FuncOp>> +std::unique_ptr<OperationPass<FuncOp>> createLinalgTilingToParallelLoopsPass(ArrayRef<int64_t> tileSizes = {}); -std::unique_ptr<OpPassBase<FuncOp>> +std::unique_ptr<OperationPass<FuncOp>> createLinalgPromotionPass(bool dynamicBuffers); -std::unique_ptr<OpPassBase<FuncOp>> createLinalgPromotionPass(); +std::unique_ptr<OperationPass<FuncOp>> createLinalgPromotionPass(); /// Create a pass to convert Linalg operations to loop.for loops and /// std.load/std.store accesses. -std::unique_ptr<OpPassBase<FuncOp>> createConvertLinalgToLoopsPass(); +std::unique_ptr<OperationPass<FuncOp>> createConvertLinalgToLoopsPass(); /// Create a pass to convert Linalg operations to loop.parallel loops and /// std.load/std.store accesses. -std::unique_ptr<OpPassBase<FuncOp>> createConvertLinalgToParallelLoopsPass(); +std::unique_ptr<OperationPass<FuncOp>> createConvertLinalgToParallelLoopsPass(); /// Create a pass to convert Linalg operations to affine.for loops and /// affine_load/affine_store accesses. /// Placeholder for now, this is NYI. -std::unique_ptr<OpPassBase<FuncOp>> createConvertLinalgToAffineLoopsPass(); +std::unique_ptr<OperationPass<FuncOp>> createConvertLinalgToAffineLoopsPass(); } // namespace mlir diff --git a/mlir/include/mlir/Dialect/Quant/Passes.h b/mlir/include/mlir/Dialect/Quant/Passes.h index 34c7c1f011a7..b938c9a86b72 100644 --- a/mlir/include/mlir/Dialect/Quant/Passes.h +++ b/mlir/include/mlir/Dialect/Quant/Passes.h @@ -20,20 +20,20 @@ namespace mlir { class FuncOp; -template <typename T> class OpPassBase; +template <typename T> class OperationPass; namespace quant { /// Creates a pass that converts quantization simulation operations (i.e. /// FakeQuant and those like it) to casts into/out of supported QuantizedTypes. -std::unique_ptr<OpPassBase<FuncOp>> createConvertSimulatedQuantPass(); +std::unique_ptr<OperationPass<FuncOp>> createConvertSimulatedQuantPass(); /// Creates a pass that converts constants followed by a qbarrier to a /// constant whose value is quantized. This is typically one of the last /// passes done when lowering to express actual quantized arithmetic in a /// low level representation. Because it modifies the constant, it is /// destructive and cannot be undone. -std::unique_ptr<OpPassBase<FuncOp>> createConvertConstPass(); +std::unique_ptr<OperationPass<FuncOp>> createConvertConstPass(); } // namespace quant } // namespace mlir diff --git a/mlir/include/mlir/Dialect/SPIRV/Passes.h b/mlir/include/mlir/Dialect/SPIRV/Passes.h index fc13460b797b..afc60805f75e 100644 --- a/mlir/include/mlir/Dialect/SPIRV/Passes.h +++ b/mlir/include/mlir/Dialect/SPIRV/Passes.h @@ -23,7 +23,7 @@ class ModuleOp; /// StorageBuffer, PhysicalStorageBuffer, Uniform, and PushConstant storage /// classes with layout information. /// Right now this pass only supports Vulkan layout rules. -std::unique_ptr<OpPassBase<mlir::ModuleOp>> +std::unique_ptr<OperationPass<mlir::ModuleOp>> createDecorateSPIRVCompositeTypeLayoutPass(); /// Creates an operation pass that deduces and attaches the minimal version/ @@ -34,7 +34,7 @@ createDecorateSPIRVCompositeTypeLayoutPass(); /// to know which one to pick. `spv.target_env` gives the hard limit as for /// what the target environment can support; this pass deduces what are /// actually needed for a specific spv.module op. -std::unique_ptr<OpPassBase<spirv::ModuleOp>> +std::unique_ptr<OperationPass<spirv::ModuleOp>> createUpdateVersionCapabilityExtensionPass(); /// Creates an operation pass that lowers the ABI attributes specified during @@ -44,7 +44,7 @@ createUpdateVersionCapabilityExtensionPass(); /// argument. /// 2. Inserts the EntryPointOp and the ExecutionModeOp for entry point /// functions using the specification in the `spv.entry_point_abi` attribute. -std::unique_ptr<OpPassBase<spirv::ModuleOp>> createLowerABIAttributesPass(); +std::unique_ptr<OperationPass<spirv::ModuleOp>> createLowerABIAttributesPass(); } // namespace spirv } // namespace mlir diff --git a/mlir/include/mlir/Pass/Pass.h b/mlir/include/mlir/Pass/Pass.h index c1eec4f4706a..6d6226f83335 100644 --- a/mlir/include/mlir/Pass/Pass.h +++ b/mlir/include/mlir/Pass/Pass.h @@ -139,161 +139,119 @@ protected: virtual void runOnOperation() = 0; /// A clone method to create a copy of this pass. - virtual std::unique_ptr<Pass> clone() const = 0; + std::unique_ptr<Pass> clone() const { + auto newInst = clonePass(); + newInst->copyOptionValuesFrom(this); + return newInst; + } /// Return the current operation being transformed. Operation *getOperation() { return getPassState().irAndPassFailed.getPointer(); } - /// Returns the current analysis manager. - AnalysisManager getAnalysisManager() { - return getPassState().analysisManager; - } - - /// Copy the option values from 'other', which is another instance of this - /// pass. - void copyOptionValuesFrom(const Pass *other); - -private: - /// Forwarding function to execute this pass on the given operation. - LLVM_NODISCARD - LogicalResult run(Operation *op, AnalysisManager am); - - /// Out of line virtual method to ensure vtables and metadata are emitted to a - /// single .o file. - virtual void anchor(); - - /// Represents a unique identifier for the pass. - const PassID *passID; - - /// The name of the operation that this pass operates on, or None if this is a - /// generic OperationPass. - Optional<StringRef> opName; - - /// The current execution state for the pass. - Optional<detail::PassExecutionState> passState; - - /// The set of statistics held by this pass. - std::vector<Statistic *> statistics; - - /// The pass options registered to this pass instance. - detail::PassOptions passOptions; - - /// Allow access to 'clone' and 'run'. - friend class OpPassManager; - - /// Allow access to 'passOptions'. - friend class PassInfo; -}; - -//===----------------------------------------------------------------------===// -// Pass Model Definitions -//===----------------------------------------------------------------------===// -namespace detail { -/// The opaque CRTP model of a pass. This class provides utilities for derived -/// pass execution and handles all of the necessary polymorphic API. -template <typename PassT, typename BasePassT> -class PassModel : public BasePassT { -public: - /// Support isa/dyn_cast functionality for the derived pass class. - static bool classof(const Pass *pass) { - return pass->getPassID() == PassID::getID<PassT>(); - } - -protected: - explicit PassModel(Optional<StringRef> opName = llvm::None) - : BasePassT(PassID::getID<PassT>(), opName) {} - /// Signal that some invariant was broken when running. The IR is allowed to /// be in an invalid state. - void signalPassFailure() { - this->getPassState().irAndPassFailed.setInt(true); - } + void signalPassFailure() { getPassState().irAndPassFailed.setInt(true); } /// Query an analysis for the current ir unit. template <typename AnalysisT> AnalysisT &getAnalysis() { - return this->getAnalysisManager().template getAnalysis<AnalysisT>(); + return getAnalysisManager().getAnalysis<AnalysisT>(); } /// Query a cached instance of an analysis for the current ir unit if one /// exists. template <typename AnalysisT> Optional<std::reference_wrapper<AnalysisT>> getCachedAnalysis() { - return this->getAnalysisManager().template getCachedAnalysis<AnalysisT>(); + return getAnalysisManager().getCachedAnalysis<AnalysisT>(); } /// Mark all analyses as preserved. void markAllAnalysesPreserved() { - this->getPassState().preservedAnalyses.preserveAll(); + getPassState().preservedAnalyses.preserveAll(); } /// Mark the provided analyses as preserved. template <typename... AnalysesT> void markAnalysesPreserved() { - this->getPassState().preservedAnalyses.template preserve<AnalysesT...>(); + getPassState().preservedAnalyses.preserve<AnalysesT...>(); } void markAnalysesPreserved(const AnalysisID *id) { - this->getPassState().preservedAnalyses.preserve(id); - } - - /// Returns the derived pass name. - StringRef getName() override { - StringRef name = llvm::getTypeName<PassT>(); - if (!name.consume_front("mlir::")) - name.consume_front("(anonymous namespace)::"); - return name; - } - - /// A clone method to create a copy of this pass. - std::unique_ptr<Pass> clone() const override { - auto newInst = std::make_unique<PassT>(*static_cast<const PassT *>(this)); - newInst->copyOptionValuesFrom(this); - return newInst; + getPassState().preservedAnalyses.preserve(id); } /// Returns the analysis for the parent operation if it exists. template <typename AnalysisT> Optional<std::reference_wrapper<AnalysisT>> getCachedParentAnalysis(Operation *parent) { - return this->getAnalysisManager() - .template getCachedParentAnalysis<AnalysisT>(parent); + return getAnalysisManager().getCachedParentAnalysis<AnalysisT>(parent); } template <typename AnalysisT> Optional<std::reference_wrapper<AnalysisT>> getCachedParentAnalysis() { - return this->getAnalysisManager() - .template getCachedParentAnalysis<AnalysisT>( - this->getOperation()->getParentOp()); + return getAnalysisManager().getCachedParentAnalysis<AnalysisT>( + getOperation()->getParentOp()); } /// Returns the analysis for the given child operation if it exists. template <typename AnalysisT> Optional<std::reference_wrapper<AnalysisT>> getCachedChildAnalysis(Operation *child) { - return this->getAnalysisManager() - .template getCachedChildAnalysis<AnalysisT>(child); + return getAnalysisManager().getCachedChildAnalysis<AnalysisT>(child); } /// Returns the analysis for the given child operation, or creates it if it /// doesn't exist. template <typename AnalysisT> AnalysisT &getChildAnalysis(Operation *child) { - return this->getAnalysisManager().template getChildAnalysis<AnalysisT>( - child); + return getAnalysisManager().getChildAnalysis<AnalysisT>(child); } -}; -} // end namespace detail -/// Utility base class for OpPass below to denote an opaque pass operating on a -/// specific operation type. -template <typename OpT> class OpPassBase : public Pass { -public: - using Pass::Pass; - - /// Support isa/dyn_cast functionality. - static bool classof(const Pass *pass) { - return pass->getOpName() == OpT::getOperationName(); + /// Returns the current analysis manager. + AnalysisManager getAnalysisManager() { + return getPassState().analysisManager; } + + /// Create a copy of this pass, ignoring statistics and options. + virtual std::unique_ptr<Pass> clonePass() const = 0; + + /// Copy the option values from 'other', which is another instance of this + /// pass. + void copyOptionValuesFrom(const Pass *other); + +private: + /// Forwarding function to execute this pass on the given operation. + LLVM_NODISCARD + LogicalResult run(Operation *op, AnalysisManager am); + + /// Out of line virtual method to ensure vtables and metadata are emitted to a + /// single .o file. + virtual void anchor(); + + /// Represents a unique identifier for the pass. + const PassID *passID; + + /// The name of the operation that this pass operates on, or None if this is a + /// generic OperationPass. + Optional<StringRef> opName; + + /// The current execution state for the pass. + Optional<detail::PassExecutionState> passState; + + /// The set of statistics held by this pass. + std::vector<Statistic *> statistics; + + /// The pass options registered to this pass instance. + detail::PassOptions passOptions; + + /// Allow access to 'clone' and 'run'. + friend class OpPassManager; + + /// Allow access to 'passOptions'. + friend class PassInfo; }; +//===----------------------------------------------------------------------===// +// Pass Model Definitions +//===----------------------------------------------------------------------===// + /// Pass to transform an operation of a specific type. /// /// Operation passes must not: @@ -304,11 +262,16 @@ public: /// /// Derived function passes are expected to provide the following: /// - A 'void runOnOperation()' method. -template <typename PassT, typename OpT = void> -class OperationPass : public detail::PassModel<PassT, OpPassBase<OpT>> { +/// - A 'StringRef getName() const' method. +/// - A 'std::unique_ptr<Pass> clonePass() const' method. +template <typename OpT = void> class OperationPass : public Pass { protected: - OperationPass() - : detail::PassModel<PassT, OpPassBase<OpT>>(OpT::getOperationName()) {} + OperationPass(const PassID *passID) : Pass(passID, OpT::getOperationName()) {} + + /// Support isa/dyn_cast functionality. + static bool classof(const Pass *pass) { + return pass->getOpName() == OpT::getOperationName(); + } /// Return the current operation being transformed. OpT getOperation() { return cast<OpT>(Pass::getOperation()); } @@ -324,14 +287,23 @@ protected: /// /// Derived function passes are expected to provide the following: /// - A 'void runOnOperation()' method. -template <typename PassT> -struct OperationPass<PassT, void> : public detail::PassModel<PassT, Pass> {}; +/// - A 'StringRef getName() const' method. +/// - A 'std::unique_ptr<Pass> clonePass() const' method. +template <> class OperationPass<void> : public Pass { +protected: + OperationPass(const PassID *passID) : Pass(passID) {} +}; /// A model for providing function pass specific utilities. /// /// Derived function passes are expected to provide the following: /// - A 'void runOnFunction()' method. -template <typename T> struct FunctionPass : public OperationPass<T, FuncOp> { +/// - A 'StringRef getName() const' method. +/// - A 'std::unique_ptr<Pass> clonePass() const' method. +class FunctionPass : public OperationPass<FuncOp> { +public: + using OperationPass<FuncOp>::OperationPass; + /// The polymorphic API that runs the pass over the currently held function. virtual void runOnFunction() = 0; @@ -344,6 +316,35 @@ template <typename T> struct FunctionPass : public OperationPass<T, FuncOp> { /// Return the current function being transformed. FuncOp getFunction() { return this->getOperation(); } }; + +/// This class provides a CRTP wrapper around a base pass class to define +/// several necessary utility methods. This should only be used for passes that +/// are not suitably represented using the declarative pass specification(i.e. +/// tablegen backend). +template <typename PassT, typename BaseT> class PassWrapper : public BaseT { +public: + /// Support isa/dyn_cast functionality for the derived pass class. + static bool classof(const Pass *pass) { + return pass->getPassID() == PassID::getID<PassT>(); + } + +protected: + PassWrapper() : BaseT(PassID::getID<PassT>()) {} + + /// Returns the derived pass name. + StringRef getName() override { + StringRef name = llvm::getTypeName<PassT>(); + if (!name.consume_front("mlir::")) + name.consume_front("(anonymous namespace)::"); + return name; + } + + /// A clone method to create a copy of this pass. + std::unique_ptr<Pass> clonePass() const override { + return std::make_unique<PassT>(*static_cast<const PassT *>(this)); + } +}; + } // end namespace mlir #endif // MLIR_PASS_PASS_H diff --git a/mlir/include/mlir/Transforms/Passes.h b/mlir/include/mlir/Transforms/Passes.h index d9aee6e1f2c2..6f5601f3006b 100644 --- a/mlir/include/mlir/Transforms/Passes.h +++ b/mlir/include/mlir/Transforms/Passes.h @@ -24,7 +24,7 @@ class AffineForOp; class FuncOp; class ModuleOp; class Pass; -template <typename T> class OpPassBase; +template <typename T> class OperationPass; /// Creates an instance of the Canonicalizer pass. std::unique_ptr<Pass> createCanonicalizerPass(); @@ -35,7 +35,7 @@ std::unique_ptr<Pass> createCSEPass(); /// Creates a loop fusion pass which fuses loops. Buffers of size less than or /// equal to `localBufSizeThreshold` are promoted to memory space /// `fastMemorySpace'. -std::unique_ptr<OpPassBase<FuncOp>> +std::unique_ptr<OperationPass<FuncOp>> createLoopFusionPass(unsigned fastMemorySpace = 0, uint64_t localBufSizeThreshold = 0, bool maximalFusion = false); @@ -46,16 +46,16 @@ std::unique_ptr<Pass> createLoopInvariantCodeMotionPass(); /// Creates a pass to pipeline explicit movement of data across levels of the /// memory hierarchy. -std::unique_ptr<OpPassBase<FuncOp>> createPipelineDataTransferPass(); +std::unique_ptr<OperationPass<FuncOp>> createPipelineDataTransferPass(); /// Lowers affine control flow operations (ForStmt, IfStmt and AffineApplyOp) /// to equivalent lower-level constructs (flow of basic blocks and arithmetic /// primitives). -std::unique_ptr<OpPassBase<FuncOp>> createLowerAffinePass(); +std::unique_ptr<OperationPass<FuncOp>> createLowerAffinePass(); /// Creates a pass that transforms perfectly nested loops with independent /// bounds into a single loop. -std::unique_ptr<OpPassBase<FuncOp>> createLoopCoalescingPass(); +std::unique_ptr<OperationPass<FuncOp>> createLoopCoalescingPass(); /// Creates a pass that transforms a single ParallelLoop over N induction /// variables into another ParallelLoop over less than N induction variables. @@ -63,14 +63,14 @@ std::unique_ptr<Pass> createParallelLoopCollapsingPass(); /// Creates a pass to perform optimizations relying on memref dataflow such as /// store to load forwarding, elimination of dead stores, and dead allocs. -std::unique_ptr<OpPassBase<FuncOp>> createMemRefDataFlowOptPass(); +std::unique_ptr<OperationPass<FuncOp>> createMemRefDataFlowOptPass(); /// Creates a pass to strip debug information from a function. std::unique_ptr<Pass> createStripDebugInfoPass(); /// Creates a pass which prints the list of ops and the number of occurrences in /// the module. -std::unique_ptr<OpPassBase<ModuleOp>> createPrintOpStatsPass(); +std::unique_ptr<OperationPass<ModuleOp>> createPrintOpStatsPass(); /// Creates a pass which inlines calls and callable operations as defined by /// the CallGraph. diff --git a/mlir/include/mlir/Transforms/ViewOpGraph.h b/mlir/include/mlir/Transforms/ViewOpGraph.h index 7b4a7a4c4ecd..61f40358fec2 100644 --- a/mlir/include/mlir/Transforms/ViewOpGraph.h +++ b/mlir/include/mlir/Transforms/ViewOpGraph.h @@ -20,7 +20,7 @@ namespace mlir { class Block; class ModuleOp; -template <typename T> class OpPassBase; +template <typename T> class OperationPass; /// Displays the graph in a window. This is for use from the debugger and /// depends on Graphviz to generate the graph. @@ -32,7 +32,7 @@ raw_ostream &writeGraph(raw_ostream &os, Block &block, bool shortNames = false, const Twine &title = ""); /// Creates a pass to print op graphs. -std::unique_ptr<OpPassBase<ModuleOp>> +std::unique_ptr<OperationPass<ModuleOp>> createPrintOpGraphPass(raw_ostream &os = llvm::errs(), bool shortNames = false, const Twine &title = ""); diff --git a/mlir/include/mlir/Transforms/ViewRegionGraph.h b/mlir/include/mlir/Transforms/ViewRegionGraph.h index e8233f022ec1..950f4c349bbf 100644 --- a/mlir/include/mlir/Transforms/ViewRegionGraph.h +++ b/mlir/include/mlir/Transforms/ViewRegionGraph.h @@ -19,7 +19,7 @@ namespace mlir { class FuncOp; -template <typename T> class OpPassBase; +template <typename T> class OperationPass; class Region; /// Displays the CFG in a window. This is for use from the debugger and @@ -32,7 +32,7 @@ raw_ostream &writeGraph(raw_ostream &os, Region ®ion, bool shortNames = false, const Twine &title = ""); /// Creates a pass to print CFG graphs. -std::unique_ptr<mlir::OpPassBase<mlir::FuncOp>> +std::unique_ptr<mlir::OperationPass<mlir::FuncOp>> createPrintCFGGraphPass(raw_ostream &os = llvm::errs(), bool shortNames = false, const Twine &title = ""); diff --git a/mlir/lib/Conversion/AVX512ToLLVM/ConvertAVX512ToLLVM.cpp b/mlir/lib/Conversion/AVX512ToLLVM/ConvertAVX512ToLLVM.cpp index 08b187fc835e..6117100fec98 100644 --- a/mlir/lib/Conversion/AVX512ToLLVM/ConvertAVX512ToLLVM.cpp +++ b/mlir/lib/Conversion/AVX512ToLLVM/ConvertAVX512ToLLVM.cpp @@ -164,7 +164,7 @@ void mlir::populateAVX512ToLLVMConversionPatterns( namespace { struct ConvertAVX512ToLLVMPass - : public OperationPass<ConvertAVX512ToLLVMPass, ModuleOp> { + : public PassWrapper<ConvertAVX512ToLLVMPass, OperationPass<ModuleOp>> { /// Include the generated pass utilities. #define GEN_PASS_ConvertAVX512ToLLVM #include "mlir/Conversion/Passes.h.inc" @@ -193,6 +193,6 @@ void ConvertAVX512ToLLVMPass::runOnOperation() { } } -std::unique_ptr<OpPassBase<ModuleOp>> mlir::createConvertAVX512ToLLVMPass() { +std::unique_ptr<OperationPass<ModuleOp>> mlir::createConvertAVX512ToLLVMPass() { return std::make_unique<ConvertAVX512ToLLVMPass>(); } diff --git a/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp b/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp index 56928da2c633..c45444948e64 100644 --- a/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp +++ b/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp @@ -577,7 +577,7 @@ void mlir::populateAffineToStdConversionPatterns( } namespace { -class LowerAffinePass : public FunctionPass<LowerAffinePass> { +class LowerAffinePass : public PassWrapper<LowerAffinePass, FunctionPass> { /// Include the generated pass utilities. #define GEN_PASS_ConvertAffineToStandard #include "mlir/Conversion/Passes.h.inc" @@ -595,6 +595,6 @@ class LowerAffinePass : public FunctionPass<LowerAffinePass> { /// Lowers If and For operations within a function into their lower level CFG /// equivalent blocks. -std::unique_ptr<OpPassBase<FuncOp>> mlir::createLowerAffinePass() { +std::unique_ptr<OperationPass<FuncOp>> mlir::createLowerAffinePass() { return std::make_unique<LowerAffinePass>(); } diff --git a/mlir/lib/Conversion/GPUToCUDA/ConvertKernelFuncToCubin.cpp b/mlir/lib/Conversion/GPUToCUDA/ConvertKernelFuncToCubin.cpp index 1640978b3a18..38820f174d98 100644 --- a/mlir/lib/Conversion/GPUToCUDA/ConvertKernelFuncToCubin.cpp +++ b/mlir/lib/Conversion/GPUToCUDA/ConvertKernelFuncToCubin.cpp @@ -47,7 +47,8 @@ static constexpr const char *kCubinAnnotation = "nvvm.cubin"; /// GPU binary code, which is then attached as an attribute to the function. The /// function body is erased. class GpuKernelToCubinPass - : public OperationPass<GpuKernelToCubinPass, gpu::GPUModuleOp> { + : public PassWrapper<GpuKernelToCubinPass, + OperationPass<gpu::GPUModuleOp>> { public: GpuKernelToCubinPass(CubinGenerator cubinGenerator) : cubinGenerator(cubinGenerator) {} @@ -143,7 +144,7 @@ StringAttr GpuKernelToCubinPass::translateGPUModuleToCubinAnnotation( return StringAttr::get({cubin->data(), cubin->size()}, loc->getContext()); } -std::unique_ptr<OpPassBase<gpu::GPUModuleOp>> +std::unique_ptr<OperationPass<gpu::GPUModuleOp>> mlir::createConvertGPUKernelToCubinPass(CubinGenerator cubinGenerator) { return std::make_unique<GpuKernelToCubinPass>(cubinGenerator); } diff --git a/mlir/lib/Conversion/GPUToCUDA/ConvertLaunchFuncToCudaCalls.cpp b/mlir/lib/Conversion/GPUToCUDA/ConvertLaunchFuncToCudaCalls.cpp index 71fe129d3875..9a71ef56e309 100644 --- a/mlir/lib/Conversion/GPUToCUDA/ConvertLaunchFuncToCudaCalls.cpp +++ b/mlir/lib/Conversion/GPUToCUDA/ConvertLaunchFuncToCudaCalls.cpp @@ -61,7 +61,8 @@ namespace { /// /// Intermediate data structures are allocated on the stack. class GpuLaunchFuncToCudaCallsPass - : public OperationPass<GpuLaunchFuncToCudaCallsPass, ModuleOp> { + : public PassWrapper<GpuLaunchFuncToCudaCallsPass, + OperationPass<ModuleOp>> { private: /// Include the generated pass utilities. #define GEN_PASS_ConvertGpuLaunchFuncToCudaCalls @@ -464,7 +465,7 @@ void GpuLaunchFuncToCudaCallsPass::translateGpuLaunchCalls( launchOp.erase(); } -std::unique_ptr<mlir::OpPassBase<mlir::ModuleOp>> +std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>> mlir::createConvertGpuLaunchFuncToCudaCallsPass() { return std::make_unique<GpuLaunchFuncToCudaCallsPass>(); } diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp index d8297c76ba9a..f4161a77c6c1 100644 --- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp +++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp @@ -246,7 +246,8 @@ struct GPUReturnOpLowering : public ConvertToLLVMPattern { /// This pass only handles device code and is not meant to be run on GPU host /// code. class LowerGpuOpsToNVVMOpsPass - : public OperationPass<LowerGpuOpsToNVVMOpsPass, gpu::GPUModuleOp> { + : public PassWrapper<LowerGpuOpsToNVVMOpsPass, + OperationPass<gpu::GPUModuleOp>> { public: /// Include the generated pass utilities. #define GEN_PASS_ConvertGpuOpsToNVVMOps @@ -324,7 +325,7 @@ void mlir::populateGpuToNVVMConversionPatterns( "__nv_tanh"); } -std::unique_ptr<OpPassBase<gpu::GPUModuleOp>> +std::unique_ptr<OperationPass<gpu::GPUModuleOp>> mlir::createLowerGpuOpsToNVVMOpsPass() { return std::make_unique<LowerGpuOpsToNVVMOpsPass>(); } diff --git a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp index b89a1704e9f5..36e1b85fd39c 100644 --- a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp +++ b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp @@ -32,7 +32,8 @@ namespace { // This pass only handles device code and is not meant to be run on GPU host // code. class LowerGpuOpsToROCDLOpsPass - : public OperationPass<LowerGpuOpsToROCDLOpsPass, gpu::GPUModuleOp> { + : public PassWrapper<LowerGpuOpsToROCDLOpsPass, + OperationPass<gpu::GPUModuleOp>> { public: /// Include the generated pass utilities. #define GEN_PASS_ConvertGpuOpsToROCDLOps @@ -83,8 +84,7 @@ public: } // anonymous namespace -std::unique_ptr<OpPassBase<gpu::GPUModuleOp>> +std::unique_ptr<OperationPass<gpu::GPUModuleOp>> mlir::createLowerGpuOpsToROCDLOpsPass() { return std::make_unique<LowerGpuOpsToROCDLOpsPass>(); } - diff --git a/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRVPass.cpp b/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRVPass.cpp index edee5025ded9..173c6d0f5826 100644 --- a/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRVPass.cpp +++ b/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRVPass.cpp @@ -33,7 +33,8 @@ namespace { /// replace it). /// /// 2) Lower the body of the spirv::ModuleOp. -struct GPUToSPIRVPass : public OperationPass<GPUToSPIRVPass, ModuleOp> { +struct GPUToSPIRVPass + : public PassWrapper<GPUToSPIRVPass, OperationPass<ModuleOp>> { /// Include the generated pass utilities. #define GEN_PASS_ConvertGpuToSPIRV #include "mlir/Conversion/Passes.h.inc" @@ -71,6 +72,6 @@ void GPUToSPIRVPass::runOnOperation() { } } -std::unique_ptr<OpPassBase<ModuleOp>> mlir::createConvertGPUToSPIRVPass() { +std::unique_ptr<OperationPass<ModuleOp>> mlir::createConvertGPUToSPIRVPass() { return std::make_unique<GPUToSPIRVPass>(); } diff --git a/mlir/lib/Conversion/GPUToVulkan/ConvertGPULaunchFuncToVulkanLaunchFunc.cpp b/mlir/lib/Conversion/GPUToVulkan/ConvertGPULaunchFuncToVulkanLaunchFunc.cpp index cbcfd741d9f8..bf2c15f68ae5 100644 --- a/mlir/lib/Conversion/GPUToVulkan/ConvertGPULaunchFuncToVulkanLaunchFunc.cpp +++ b/mlir/lib/Conversion/GPUToVulkan/ConvertGPULaunchFuncToVulkanLaunchFunc.cpp @@ -38,7 +38,8 @@ namespace { /// function and attaching binary data and entry point name as an attributes to /// created vulkan launch call op. class ConvertGpuLaunchFuncToVulkanLaunchFunc - : public OperationPass<ConvertGpuLaunchFuncToVulkanLaunchFunc, ModuleOp> { + : public PassWrapper<ConvertGpuLaunchFuncToVulkanLaunchFunc, + OperationPass<ModuleOp>> { public: /// Include the generated pass utilities. #define GEN_PASS_ConvertGpuLaunchFuncToVulkanLaunchFunc @@ -168,7 +169,7 @@ void ConvertGpuLaunchFuncToVulkanLaunchFunc::convertGpuLaunchFunc( launchOp.erase(); } -std::unique_ptr<mlir::OpPassBase<mlir::ModuleOp>> +std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>> mlir::createConvertGpuLaunchFuncToVulkanLaunchFuncPass() { return std::make_unique<ConvertGpuLaunchFuncToVulkanLaunchFunc>(); } diff --git a/mlir/lib/Conversion/GPUToVulkan/ConvertLaunchFuncToVulkanCalls.cpp b/mlir/lib/Conversion/GPUToVulkan/ConvertLaunchFuncToVulkanCalls.cpp index 2daa13085bcb..03d924d74f56 100644 --- a/mlir/lib/Conversion/GPUToVulkan/ConvertLaunchFuncToVulkanCalls.cpp +++ b/mlir/lib/Conversion/GPUToVulkan/ConvertLaunchFuncToVulkanCalls.cpp @@ -58,7 +58,8 @@ namespace { /// * deinitVulkan -- deinitializes vulkan runtime /// class VulkanLaunchFuncToVulkanCallsPass - : public OperationPass<VulkanLaunchFuncToVulkanCallsPass, ModuleOp> { + : public PassWrapper<VulkanLaunchFuncToVulkanCallsPass, + OperationPass<ModuleOp>> { private: /// Include the generated pass utilities. #define GEN_PASS_ConvertVulkanLaunchFuncToVulkanCalls @@ -436,7 +437,7 @@ void VulkanLaunchFuncToVulkanCallsPass::translateVulkanLaunchCall( cInterfaceVulkanLaunchCallOp.erase(); } -std::unique_ptr<mlir::OpPassBase<mlir::ModuleOp>> +std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>> mlir::createConvertVulkanLaunchFuncToVulkanCallsPass() { return std::make_unique<VulkanLaunchFuncToVulkanCallsPass>(); } diff --git a/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp b/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp index 99f106e29de7..05aab300e622 100644 --- a/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp +++ b/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp @@ -562,7 +562,7 @@ void mlir::populateLinalgToLLVMConversionPatterns( namespace { struct ConvertLinalgToLLVMPass - : public OperationPass<ConvertLinalgToLLVMPass, ModuleOp> { + : public PassWrapper<ConvertLinalgToLLVMPass, OperationPass<ModuleOp>> { /// Include the generated pass utilities. #define GEN_PASS_ConvertLinalgToLLVM #include "mlir/Conversion/Passes.h.inc" @@ -593,6 +593,6 @@ void ConvertLinalgToLLVMPass::runOnOperation() { signalPassFailure(); } -std::unique_ptr<OpPassBase<ModuleOp>> mlir::createConvertLinalgToLLVMPass() { +std::unique_ptr<OperationPass<ModuleOp>> mlir::createConvertLinalgToLLVMPass() { return std::make_unique<ConvertLinalgToLLVMPass>(); } diff --git a/mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRVPass.cpp b/mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRVPass.cpp index 4b66063b88eb..acb87b72c3c6 100644 --- a/mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRVPass.cpp +++ b/mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRVPass.cpp @@ -16,7 +16,8 @@ using namespace mlir; namespace { /// A pass converting MLIR Linalg ops into SPIR-V ops. -class LinalgToSPIRVPass : public OperationPass<LinalgToSPIRVPass, ModuleOp> { +class LinalgToSPIRVPass + : public PassWrapper<LinalgToSPIRVPass, OperationPass<ModuleOp>> { /// Include the generated pass utilities. #define GEN_PASS_ConvertLinalgToSPIRV #include "mlir/Conversion/Passes.h.inc" @@ -47,6 +48,6 @@ void LinalgToSPIRVPass::runOnOperation() { return signalPassFailure(); } -std::unique_ptr<OpPassBase<ModuleOp>> mlir::createLinalgToSPIRVPass() { +std::unique_ptr<OperationPass<ModuleOp>> mlir::createLinalgToSPIRVPass() { return std::make_unique<LinalgToSPIRVPass>(); } diff --git a/mlir/lib/Conversion/LoopToStandard/LoopToStandard.cpp b/mlir/lib/Conversion/LoopToStandard/LoopToStandard.cpp index 7d20561ab6cf..9929b8e816f6 100644 --- a/mlir/lib/Conversion/LoopToStandard/LoopToStandard.cpp +++ b/mlir/lib/Conversion/LoopToStandard/LoopToStandard.cpp @@ -30,7 +30,8 @@ using namespace mlir::loop; namespace { -struct LoopToStandardPass : public OperationPass<LoopToStandardPass> { +struct LoopToStandardPass + : public PassWrapper<LoopToStandardPass, OperationPass<>> { /// Include the generated pass utilities. #define GEN_PASS_ConvertLoopToStandard #include "mlir/Conversion/Passes.h.inc" diff --git a/mlir/lib/Conversion/LoopsToGPU/LoopsToGPUPass.cpp b/mlir/lib/Conversion/LoopsToGPU/LoopsToGPUPass.cpp index 98a855c105c3..679d0b339691 100644 --- a/mlir/lib/Conversion/LoopsToGPU/LoopsToGPUPass.cpp +++ b/mlir/lib/Conversion/LoopsToGPU/LoopsToGPUPass.cpp @@ -28,7 +28,7 @@ namespace { // A pass that traverses top-level loops in the function and converts them to // GPU launch operations. Nested launches are not allowed, so this does not // walk the function recursively to avoid considering nested loops. -struct ForLoopMapper : public FunctionPass<ForLoopMapper> { +struct ForLoopMapper : public PassWrapper<ForLoopMapper, FunctionPass> { /// Include the generated pass utilities. #define GEN_PASS_ConvertSimpleLoopsToGPU #include "mlir/Conversion/Passes.h.inc" @@ -62,7 +62,7 @@ struct ForLoopMapper : public FunctionPass<ForLoopMapper> { // nested loops as the size of `numWorkGroups`. Within these any loop nest has // to be perfectly nested upto depth equal to size of `workGroupSize`. struct ImperfectlyNestedForLoopMapper - : public FunctionPass<ImperfectlyNestedForLoopMapper> { + : public PassWrapper<ImperfectlyNestedForLoopMapper, FunctionPass> { /// Include the generated pass utilities. #define GEN_PASS_ConvertLoopsToGPU #include "mlir/Conversion/Passes.h.inc" @@ -104,7 +104,8 @@ struct ImperfectlyNestedForLoopMapper } }; -struct ParallelLoopToGpuPass : public OperationPass<ParallelLoopToGpuPass> { +struct ParallelLoopToGpuPass + : public PassWrapper<ParallelLoopToGpuPass, OperationPass<>> { /// Include the generated pass utilities. #define GEN_PASS_ConvertParallelLoopToGpu #include "mlir/Conversion/Passes.h.inc" @@ -125,22 +126,22 @@ struct ParallelLoopToGpuPass : public OperationPass<ParallelLoopToGpuPass> { } // namespace -std::unique_ptr<OpPassBase<FuncOp>> +std::unique_ptr<OperationPass<FuncOp>> mlir::createSimpleLoopsToGPUPass(unsigned numBlockDims, unsigned numThreadDims) { return std::make_unique<ForLoopMapper>(numBlockDims, numThreadDims); } -std::unique_ptr<OpPassBase<FuncOp>> mlir::createSimpleLoopsToGPUPass() { +std::unique_ptr<OperationPass<FuncOp>> mlir::createSimpleLoopsToGPUPass() { return std::make_unique<ForLoopMapper>(); } -std::unique_ptr<OpPassBase<FuncOp>> +std::unique_ptr<OperationPass<FuncOp>> mlir::createLoopToGPUPass(ArrayRef<int64_t> numWorkGroups, ArrayRef<int64_t> workGroupSize) { return std::make_unique<ImperfectlyNestedForLoopMapper>(numWorkGroups, workGroupSize); } -std::unique_ptr<OpPassBase<FuncOp>> mlir::createLoopToGPUPass() { +std::unique_ptr<OperationPass<FuncOp>> mlir::createLoopToGPUPass() { return std::make_unique<ImperfectlyNestedForLoopMapper>(); } diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp index ef5dabf2ff88..d23883f6d624 100644 --- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp +++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp @@ -2847,7 +2847,8 @@ LLVMTypeConverter::promoteMemRefDescriptors(Location loc, ValueRange opOperands, namespace { /// A pass converting MLIR operations into the LLVM IR dialect. -struct LLVMLoweringPass : public OperationPass<LLVMLoweringPass, ModuleOp> { +struct LLVMLoweringPass + : public PassWrapper<LLVMLoweringPass, OperationPass<ModuleOp>> { /// Include the generated pass utilities. #define GEN_PASS_ConvertStandardToLLVM #include "mlir/Conversion/Passes.h.inc" @@ -2901,7 +2902,7 @@ mlir::LLVMConversionTarget::LLVMConversionTarget(MLIRContext &ctx) this->addIllegalOp<TanhOp>(); } -std::unique_ptr<OpPassBase<ModuleOp>> +std::unique_ptr<OperationPass<ModuleOp>> mlir::createLowerToLLVMPass(const LowerToLLVMOptions &options) { return std::make_unique<LLVMLoweringPass>( options.useBarePtrCallConv, options.emitCWrappers, options.indexBitwidth); diff --git a/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.cpp b/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.cpp index 86c8cd17433c..b0ce99fa837f 100644 --- a/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.cpp +++ b/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.cpp @@ -22,7 +22,7 @@ using namespace mlir; namespace { /// A pass converting MLIR Standard operations into the SPIR-V dialect. class ConvertStandardToSPIRVPass - : public OperationPass<ConvertStandardToSPIRVPass, ModuleOp> { + : public PassWrapper<ConvertStandardToSPIRVPass, OperationPass<ModuleOp>> { /// Include the generated pass utilities. #define GEN_PASS_ConvertStandardToSPIRV #include "mlir/Conversion/Passes.h.inc" @@ -49,6 +49,7 @@ void ConvertStandardToSPIRVPass::runOnOperation() { } } -std::unique_ptr<OpPassBase<ModuleOp>> mlir::createConvertStandardToSPIRVPass() { +std::unique_ptr<OperationPass<ModuleOp>> +mlir::createConvertStandardToSPIRVPass() { return std::make_unique<ConvertStandardToSPIRVPass>(); } diff --git a/mlir/lib/Conversion/StandardToSPIRV/LegalizeStandardForSPIRV.cpp b/mlir/lib/Conversion/StandardToSPIRV/LegalizeStandardForSPIRV.cpp index 381087fbdf76..9dbb76174201 100644 --- a/mlir/lib/Conversion/StandardToSPIRV/LegalizeStandardForSPIRV.cpp +++ b/mlir/lib/Conversion/StandardToSPIRV/LegalizeStandardForSPIRV.cpp @@ -160,7 +160,8 @@ void mlir::populateStdLegalizationPatternsForSPIRVLowering( //===----------------------------------------------------------------------===// namespace { -struct SPIRVLegalization final : public OperationPass<SPIRVLegalization> { +struct SPIRVLegalization final + : public PassWrapper<SPIRVLegalization, OperationPass<>> { /// Include the generated pass utilities. #define GEN_PASS_LegalizeStandardForSPIRV #include "mlir/Conversion/Passes.h.inc" diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp index b2a1c443f518..03cbb67bc5d7 100644 --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp @@ -1119,7 +1119,7 @@ void mlir::populateVectorToLLVMMatrixConversionPatterns( namespace { struct LowerVectorToLLVMPass - : public OperationPass<LowerVectorToLLVMPass, ModuleOp> { + : public PassWrapper<LowerVectorToLLVMPass, OperationPass<ModuleOp>> { /// Include the generated pass utilities. #define GEN_PASS_ConvertVectorToLLVM #include "mlir/Conversion/Passes.h.inc" @@ -1155,6 +1155,6 @@ void LowerVectorToLLVMPass::runOnOperation() { } } -std::unique_ptr<OpPassBase<ModuleOp>> mlir::createConvertVectorToLLVMPass() { +std::unique_ptr<OperationPass<ModuleOp>> mlir::createConvertVectorToLLVMPass() { return std::make_unique<LowerVectorToLLVMPass>(); } diff --git a/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp b/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp index 3ba61bcc022d..1f7d670ce473 100644 --- a/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp +++ b/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp @@ -75,7 +75,7 @@ namespace { // TODO(bondhugula): We currently can't generate copies correctly when stores // are strided. Check for strided stores. struct AffineDataCopyGeneration - : public FunctionPass<AffineDataCopyGeneration> { + : public PassWrapper<AffineDataCopyGeneration, FunctionPass> { /// Include the generated pass utilities. #define GEN_PASS_AffineDataCopyGeneration #include "mlir/Dialect/Affine/Passes.h.inc" @@ -134,14 +134,15 @@ struct AffineDataCopyGeneration /// buffers in 'fastMemorySpace', and replaces memory operations to the former /// by the latter. Only load op's handled for now. /// TODO(bondhugula): extend this to store op's. -std::unique_ptr<OpPassBase<FuncOp>> mlir::createAffineDataCopyGenerationPass( +std::unique_ptr<OperationPass<FuncOp>> mlir::createAffineDataCopyGenerationPass( unsigned slowMemorySpace, unsigned fastMemorySpace, unsigned tagMemorySpace, int minDmaTransferSize, uint64_t fastMemCapacityBytes) { return std::make_unique<AffineDataCopyGeneration>( slowMemorySpace, fastMemorySpace, tagMemorySpace, minDmaTransferSize, fastMemCapacityBytes); } -std::unique_ptr<OpPassBase<FuncOp>> mlir::createAffineDataCopyGenerationPass() { +std::unique_ptr<OperationPass<FuncOp>> +mlir::createAffineDataCopyGenerationPass() { return std::make_unique<AffineDataCopyGeneration>(); } diff --git a/mlir/lib/Dialect/Affine/Transforms/AffineLoopInvariantCodeMotion.cpp b/mlir/lib/Dialect/Affine/Transforms/AffineLoopInvariantCodeMotion.cpp index 12fdb37bf841..066a53d14e23 100644 --- a/mlir/lib/Dialect/Affine/Transforms/AffineLoopInvariantCodeMotion.cpp +++ b/mlir/lib/Dialect/Affine/Transforms/AffineLoopInvariantCodeMotion.cpp @@ -41,7 +41,8 @@ namespace { /// TODO(asabne) : Check for the presence of side effects before hoisting. /// TODO: This code should be removed once the new LICM pass can handle its /// uses. -struct LoopInvariantCodeMotion : public FunctionPass<LoopInvariantCodeMotion> { +struct LoopInvariantCodeMotion + : public PassWrapper<LoopInvariantCodeMotion, FunctionPass> { /// Include the generated pass utilities. #define GEN_PASS_AffineLoopInvariantCodeMotion #include "mlir/Dialect/Affine/Passes.h.inc" @@ -232,7 +233,7 @@ void LoopInvariantCodeMotion::runOnFunction() { }); } -std::unique_ptr<OpPassBase<FuncOp>> +std::unique_ptr<OperationPass<FuncOp>> mlir::createAffineLoopInvariantCodeMotionPass() { return std::make_unique<LoopInvariantCodeMotion>(); } diff --git a/mlir/lib/Dialect/Affine/Transforms/LoopTiling.cpp b/mlir/lib/Dialect/Affine/Transforms/LoopTiling.cpp index 2f5eea7606a9..1cfb31045c66 100644 --- a/mlir/lib/Dialect/Affine/Transforms/LoopTiling.cpp +++ b/mlir/lib/Dialect/Affine/Transforms/LoopTiling.cpp @@ -58,7 +58,7 @@ static llvm::cl::list<unsigned> clTileSizes( namespace { /// A pass to perform loop tiling on all suitable loop nests of a Function. -struct LoopTiling : public FunctionPass<LoopTiling> { +struct LoopTiling : public PassWrapper<LoopTiling, FunctionPass> { /// Include the generated pass utilities. #define GEN_PASS_AffineLoopTiling #include "mlir/Dialect/Affine/Passes.h.inc" @@ -85,11 +85,11 @@ struct LoopTiling : public FunctionPass<LoopTiling> { /// Creates a pass to perform loop tiling on all suitable loop nests of a /// Function. -std::unique_ptr<OpPassBase<FuncOp>> +std::unique_ptr<OperationPass<FuncOp>> mlir::createLoopTilingPass(uint64_t cacheSizeBytes) { return std::make_unique<LoopTiling>(cacheSizeBytes); } -std::unique_ptr<OpPassBase<FuncOp>> mlir::createLoopTilingPass() { +std::unique_ptr<OperationPass<FuncOp>> mlir::createLoopTilingPass() { return std::make_unique<LoopTiling>(); } diff --git a/mlir/lib/Dialect/Affine/Transforms/LoopUnroll.cpp b/mlir/lib/Dialect/Affine/Transforms/LoopUnroll.cpp index 7f33630d3b8a..d9a6b1cd2690 100644 --- a/mlir/lib/Dialect/Affine/Transforms/LoopUnroll.cpp +++ b/mlir/lib/Dialect/Affine/Transforms/LoopUnroll.cpp @@ -58,7 +58,7 @@ namespace { /// full unroll threshold was specified, in which case, fully unrolls all loops /// with trip count less than the specified threshold. The latter is for testing /// purposes, especially for testing outer loop unrolling. -struct LoopUnroll : public FunctionPass<LoopUnroll> { +struct LoopUnroll : public PassWrapper<LoopUnroll, FunctionPass> { /// Include the generated pass utilities. #define GEN_PASS_AffineUnroll #include "mlir/Dialect/Affine/Passes.h.inc" @@ -166,7 +166,7 @@ LogicalResult LoopUnroll::runOnAffineForOp(AffineForOp forOp) { return loopUnrollByFactor(forOp, kDefaultUnrollFactor); } -std::unique_ptr<OpPassBase<FuncOp>> mlir::createLoopUnrollPass( +std::unique_ptr<OperationPass<FuncOp>> mlir::createLoopUnrollPass( int unrollFactor, int unrollFull, const std::function<unsigned(AffineForOp)> &getUnrollFactor) { return std::make_unique<LoopUnroll>( diff --git a/mlir/lib/Dialect/Affine/Transforms/LoopUnrollAndJam.cpp b/mlir/lib/Dialect/Affine/Transforms/LoopUnrollAndJam.cpp index 5f419a83d6cf..1a2567796795 100644 --- a/mlir/lib/Dialect/Affine/Transforms/LoopUnrollAndJam.cpp +++ b/mlir/lib/Dialect/Affine/Transforms/LoopUnrollAndJam.cpp @@ -60,7 +60,7 @@ static llvm::cl::opt<unsigned> namespace { /// Loop unroll jam pass. Currently, this just unroll jams the first /// outer loop in a Function. -struct LoopUnrollAndJam : public FunctionPass<LoopUnrollAndJam> { +struct LoopUnrollAndJam : public PassWrapper<LoopUnrollAndJam, FunctionPass> { /// Include the generated pass utilities. #define GEN_PASS_AffineLoopUnrollAndJam #include "mlir/Dialect/Affine/Passes.h.inc" @@ -76,7 +76,7 @@ struct LoopUnrollAndJam : public FunctionPass<LoopUnrollAndJam> { }; } // end anonymous namespace -std::unique_ptr<OpPassBase<FuncOp>> +std::unique_ptr<OperationPass<FuncOp>> mlir::createLoopUnrollAndJamPass(int unrollJamFactor) { return std::make_unique<LoopUnrollAndJam>( unrollJamFactor == -1 ? None : Optional<unsigned>(unrollJamFactor)); diff --git a/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineStructures.cpp b/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineStructures.cpp index fc58b19656fe..af11d4e8d114 100644 --- a/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineStructures.cpp +++ b/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineStructures.cpp @@ -27,7 +27,7 @@ namespace { /// all memrefs with non-trivial layout maps are converted to ones with trivial /// identity layout ones. struct SimplifyAffineStructures - : public FunctionPass<SimplifyAffineStructures> { + : public PassWrapper<SimplifyAffineStructures, FunctionPass> { /// Include the generated pass utilities. #define GEN_PASS_SimplifyAffineStructures #include "mlir/Dialect/Affine/Passes.h.inc" @@ -73,7 +73,8 @@ struct SimplifyAffineStructures } // end anonymous namespace -std::unique_ptr<OpPassBase<FuncOp>> mlir::createSimplifyAffineStructuresPass() { +std::unique_ptr<OperationPass<FuncOp>> +mlir::createSimplifyAffineStructuresPass() { return std::make_unique<SimplifyAffineStructures>(); } diff --git a/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp b/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp index 395945982225..06de9a2c9da5 100644 --- a/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp +++ b/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp @@ -573,7 +573,7 @@ namespace { /// Base state for the vectorize pass. /// Command line arguments are preempted by non-empty pass arguments. -struct Vectorize : public FunctionPass<Vectorize> { +struct Vectorize : public PassWrapper<Vectorize, FunctionPass> { /// Include the generated pass utilities. #define GEN_PASS_AffineVectorize #include "mlir/Dialect/Affine/Passes.h.inc" @@ -1252,10 +1252,10 @@ void Vectorize::runOnFunction() { LLVM_DEBUG(dbgs() << "\n"); } -std::unique_ptr<OpPassBase<FuncOp>> +std::unique_ptr<OperationPass<FuncOp>> mlir::createSuperVectorizePass(ArrayRef<int64_t> virtualVectorSize) { return std::make_unique<Vectorize>(virtualVectorSize); } -std::unique_ptr<OpPassBase<FuncOp>> mlir::createSuperVectorizePass() { +std::unique_ptr<OperationPass<FuncOp>> mlir::createSuperVectorizePass() { return std::make_unique<Vectorize>(); } diff --git a/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp b/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp index daf9169d242c..70ace0f1f45f 100644 --- a/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp +++ b/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp @@ -215,7 +215,7 @@ namespace { /// a separate pass. The external functions can then be annotated with the /// symbol of the cubin accessor function. class GpuKernelOutliningPass - : public OperationPass<GpuKernelOutliningPass, ModuleOp> { + : public PassWrapper<GpuKernelOutliningPass, OperationPass<ModuleOp>> { public: /// Include the generated pass utilities. #define GEN_PASS_GpuKernelOutlining @@ -301,6 +301,6 @@ private: } // namespace -std::unique_ptr<OpPassBase<ModuleOp>> mlir::createGpuKernelOutliningPass() { +std::unique_ptr<OperationPass<ModuleOp>> mlir::createGpuKernelOutliningPass() { return std::make_unique<GpuKernelOutliningPass>(); } diff --git a/mlir/lib/Dialect/LLVMIR/Transforms/LegalizeForExport.cpp b/mlir/lib/Dialect/LLVMIR/Transforms/LegalizeForExport.cpp index 663f0c11432b..88f06d7ed18f 100644 --- a/mlir/lib/Dialect/LLVMIR/Transforms/LegalizeForExport.cpp +++ b/mlir/lib/Dialect/LLVMIR/Transforms/LegalizeForExport.cpp @@ -57,7 +57,8 @@ void mlir::LLVM::ensureDistinctSuccessors(Operation *op) { } namespace { -struct LegalizeForExportPass : public OperationPass<LegalizeForExportPass> { +struct LegalizeForExportPass + : public PassWrapper<LegalizeForExportPass, OperationPass<>> { /// Include the generated pass utilities. #define GEN_PASS_LLVMLegalizeForExport #include "mlir/Dialect/LLVMIR/Transforms/Passes.h.inc" diff --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp index d3c56851a385..07b4d9788a80 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp @@ -567,7 +567,8 @@ struct FuseGenericTensorOps : public OpRewritePattern<GenericOp> { }; /// Pass that fuses generic ops on tensors. Used only for testing. -struct FusionOfTensorOpsPass : public OperationPass<FusionOfTensorOpsPass> { +struct FusionOfTensorOpsPass + : public PassWrapper<FusionOfTensorOpsPass, OperationPass<>> { /// Include the generated pass utilities. #define GEN_PASS_LinalgFusionOfTensorOps #include "mlir/Dialect/Linalg/Passes.h.inc" @@ -580,7 +581,7 @@ struct FusionOfTensorOpsPass : public OperationPass<FusionOfTensorOpsPass> { }; }; -struct LinalgFusionPass : public FunctionPass<LinalgFusionPass> { +struct LinalgFusionPass : public PassWrapper<LinalgFusionPass, FunctionPass> { /// Include the generated pass utilities. #define GEN_PASS_LinalgFusion #include "mlir/Dialect/Linalg/Passes.h.inc" @@ -589,7 +590,7 @@ struct LinalgFusionPass : public FunctionPass<LinalgFusionPass> { }; } // namespace -std::unique_ptr<OpPassBase<FuncOp>> mlir::createLinalgFusionPass() { +std::unique_ptr<OperationPass<FuncOp>> mlir::createLinalgFusionPass() { return std::make_unique<LinalgFusionPass>(); } diff --git a/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp b/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp index 4e26080eec53..8a4df6414833 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp @@ -693,7 +693,8 @@ static void lowerLinalgToLoopsImpl(Operation *op, MLIRContext *context) { } namespace { -struct LowerToAffineLoops : public FunctionPass<LowerToAffineLoops> { +struct LowerToAffineLoops + : public PassWrapper<LowerToAffineLoops, FunctionPass> { /// Include the generated pass utilities. #define GEN_PASS_LinalgLowerToAffineLoops #include "mlir/Dialect/Linalg/Passes.h.inc" @@ -703,7 +704,7 @@ struct LowerToAffineLoops : public FunctionPass<LowerToAffineLoops> { &getContext()); } }; -struct LowerToLoops : public FunctionPass<LowerToLoops> { +struct LowerToLoops : public PassWrapper<LowerToLoops, FunctionPass> { /// Include the generated pass utilities. #define GEN_PASS_LinalgLowerToLoops #include "mlir/Dialect/Linalg/Passes.h.inc" @@ -713,7 +714,8 @@ struct LowerToLoops : public FunctionPass<LowerToLoops> { &getContext()); } }; -struct LowerToParallelLoops : public FunctionPass<LowerToParallelLoops> { +struct LowerToParallelLoops + : public PassWrapper<LowerToParallelLoops, FunctionPass> { /// Include the generated pass utilities. #define GEN_PASS_LinalgLowerToParallelLoops #include "mlir/Dialect/Linalg/Passes.h.inc" @@ -725,16 +727,16 @@ struct LowerToParallelLoops : public FunctionPass<LowerToParallelLoops> { }; } // namespace -std::unique_ptr<OpPassBase<FuncOp>> mlir::createConvertLinalgToLoopsPass() { +std::unique_ptr<OperationPass<FuncOp>> mlir::createConvertLinalgToLoopsPass() { return std::make_unique<LowerToLoops>(); } -std::unique_ptr<OpPassBase<FuncOp>> +std::unique_ptr<OperationPass<FuncOp>> mlir::createConvertLinalgToParallelLoopsPass() { return std::make_unique<LowerToParallelLoops>(); } -std::unique_ptr<OpPassBase<FuncOp>> +std::unique_ptr<OperationPass<FuncOp>> mlir::createConvertLinalgToAffineLoopsPass() { return std::make_unique<LowerToAffineLoops>(); } diff --git a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp index 6603507a1cdf..6eea97f954ff 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp @@ -230,7 +230,8 @@ static void promoteSubViews(FuncOp f, bool dynamicBuffers) { } namespace { -struct LinalgPromotionPass : public FunctionPass<LinalgPromotionPass> { +struct LinalgPromotionPass + : public PassWrapper<LinalgPromotionPass, FunctionPass> { /// Include the generated pass utilities. #define GEN_PASS_LinalgPromotion #include "mlir/Dialect/Linalg/Passes.h.inc" @@ -247,10 +248,10 @@ struct LinalgPromotionPass : public FunctionPass<LinalgPromotionPass> { }; } // namespace -std::unique_ptr<OpPassBase<FuncOp>> +std::unique_ptr<OperationPass<FuncOp>> mlir::createLinalgPromotionPass(bool dynamicBuffers) { return std::make_unique<LinalgPromotionPass>(dynamicBuffers); } -std::unique_ptr<OpPassBase<FuncOp>> mlir::createLinalgPromotionPass() { +std::unique_ptr<OperationPass<FuncOp>> mlir::createLinalgPromotionPass() { return std::make_unique<LinalgPromotionPass>(); } diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp index 919a7d53f479..1e528aa9a201 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp @@ -507,7 +507,7 @@ static void tileLinalgOps(FuncOp f, ArrayRef<int64_t> tileSizes) { } namespace { -struct LinalgTilingPass : public FunctionPass<LinalgTilingPass> { +struct LinalgTilingPass : public PassWrapper<LinalgTilingPass, FunctionPass> { /// Include the generated pass utilities. #define GEN_PASS_LinalgTiling #include "mlir/Dialect/Linalg/Passes.h.inc" @@ -524,7 +524,7 @@ struct LinalgTilingPass : public FunctionPass<LinalgTilingPass> { }; struct LinalgTilingToParallelLoopsPass - : public FunctionPass<LinalgTilingToParallelLoopsPass> { + : public PassWrapper<LinalgTilingToParallelLoopsPass, FunctionPass> { /// Include the generated pass utilities. #define GEN_PASS_LinalgTilingToParallelLoops #include "mlir/Dialect/Linalg/Passes.h.inc" @@ -542,12 +542,12 @@ struct LinalgTilingToParallelLoopsPass } // namespace -std::unique_ptr<OpPassBase<FuncOp>> +std::unique_ptr<OperationPass<FuncOp>> mlir::createLinalgTilingPass(ArrayRef<int64_t> tileSizes) { return std::make_unique<LinalgTilingPass>(tileSizes); } -std::unique_ptr<OpPassBase<FuncOp>> +std::unique_ptr<OperationPass<FuncOp>> mlir::createLinalgTilingToParallelLoopsPass(ArrayRef<int64_t> tileSizes) { return std::make_unique<LinalgTilingToParallelLoopsPass>(tileSizes); } diff --git a/mlir/lib/Dialect/LoopOps/Transforms/ParallelLoopFusion.cpp b/mlir/lib/Dialect/LoopOps/Transforms/ParallelLoopFusion.cpp index b59f1fc3c8c9..b031f81076cd 100644 --- a/mlir/lib/Dialect/LoopOps/Transforms/ParallelLoopFusion.cpp +++ b/mlir/lib/Dialect/LoopOps/Transforms/ParallelLoopFusion.cpp @@ -160,7 +160,8 @@ void mlir::loop::naivelyFuseParallelOps(Region ®ion) { } namespace { -struct ParallelLoopFusion : public OperationPass<ParallelLoopFusion> { +struct ParallelLoopFusion + : public PassWrapper<ParallelLoopFusion, OperationPass<>> { /// Include the generated pass utilities. #define GEN_PASS_LoopParallelLoopFusion #include "mlir/Dialect/LoopOps/Passes.h.inc" diff --git a/mlir/lib/Dialect/LoopOps/Transforms/ParallelLoopSpecialization.cpp b/mlir/lib/Dialect/LoopOps/Transforms/ParallelLoopSpecialization.cpp index 63afa5059509..98776abdb06c 100644 --- a/mlir/lib/Dialect/LoopOps/Transforms/ParallelLoopSpecialization.cpp +++ b/mlir/lib/Dialect/LoopOps/Transforms/ParallelLoopSpecialization.cpp @@ -60,7 +60,7 @@ static void specializeLoopForUnrolling(ParallelOp op) { namespace { struct ParallelLoopSpecialization - : public FunctionPass<ParallelLoopSpecialization> { + : public PassWrapper<ParallelLoopSpecialization, FunctionPass> { /// Include the generated pass utilities. #define GEN_PASS_LoopParallelLoopSpecialization #include "mlir/Dialect/LoopOps/Passes.h.inc" diff --git a/mlir/lib/Dialect/LoopOps/Transforms/ParallelLoopTiling.cpp b/mlir/lib/Dialect/LoopOps/Transforms/ParallelLoopTiling.cpp index 8b63d0090291..c10872ece9f9 100644 --- a/mlir/lib/Dialect/LoopOps/Transforms/ParallelLoopTiling.cpp +++ b/mlir/lib/Dialect/LoopOps/Transforms/ParallelLoopTiling.cpp @@ -101,7 +101,8 @@ static bool getInnermostNestedLoops(Block *block, } namespace { -struct ParallelLoopTiling : public FunctionPass<ParallelLoopTiling> { +struct ParallelLoopTiling + : public PassWrapper<ParallelLoopTiling, FunctionPass> { /// Include the generated pass utilities. #define GEN_PASS_LoopParallelLoopTiling #include "mlir/Dialect/LoopOps/Passes.h.inc" diff --git a/mlir/lib/Dialect/Quant/Transforms/ConvertConst.cpp b/mlir/lib/Dialect/Quant/Transforms/ConvertConst.cpp index 3017346a9acc..d892f67a23d7 100644 --- a/mlir/lib/Dialect/Quant/Transforms/ConvertConst.cpp +++ b/mlir/lib/Dialect/Quant/Transforms/ConvertConst.cpp @@ -21,7 +21,7 @@ using namespace mlir; using namespace mlir::quant; namespace { -struct ConvertConstPass : public FunctionPass<ConvertConstPass> { +struct ConvertConstPass : public PassWrapper<ConvertConstPass, FunctionPass> { /// Include the generated pass utilities. #define GEN_PASS_QuantConvertConst #include "mlir/Dialect/Quant/Passes.h.inc" @@ -105,6 +105,6 @@ void ConvertConstPass::runOnFunction() { applyPatternsGreedily(func, patterns); } -std::unique_ptr<OpPassBase<FuncOp>> mlir::quant::createConvertConstPass() { +std::unique_ptr<OperationPass<FuncOp>> mlir::quant::createConvertConstPass() { return std::make_unique<ConvertConstPass>(); } diff --git a/mlir/lib/Dialect/Quant/Transforms/ConvertSimQuant.cpp b/mlir/lib/Dialect/Quant/Transforms/ConvertSimQuant.cpp index b76cee6a412c..079c3ff96ad1 100644 --- a/mlir/lib/Dialect/Quant/Transforms/ConvertSimQuant.cpp +++ b/mlir/lib/Dialect/Quant/Transforms/ConvertSimQuant.cpp @@ -20,7 +20,7 @@ using namespace mlir::quant; namespace { struct ConvertSimulatedQuantPass - : public FunctionPass<ConvertSimulatedQuantPass> { + : public PassWrapper<ConvertSimulatedQuantPass, FunctionPass> { /// Include the generated pass utilities. #define GEN_PASS_QuantConvertSimulatedQuant #include "mlir/Dialect/Quant/Passes.h.inc" @@ -140,7 +140,7 @@ void ConvertSimulatedQuantPass::runOnFunction() { signalPassFailure(); } -std::unique_ptr<OpPassBase<FuncOp>> +std::unique_ptr<OperationPass<FuncOp>> mlir::quant::createConvertSimulatedQuantPass() { return std::make_unique<ConvertSimulatedQuantPass>(); } diff --git a/mlir/lib/Dialect/SPIRV/Transforms/DecorateSPIRVCompositeTypeLayoutPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/DecorateSPIRVCompositeTypeLayoutPass.cpp index e4622741536e..a0b2c168985f 100644 --- a/mlir/lib/Dialect/SPIRV/Transforms/DecorateSPIRVCompositeTypeLayoutPass.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/DecorateSPIRVCompositeTypeLayoutPass.cpp @@ -80,7 +80,8 @@ static void populateSPIRVLayoutInfoPatterns(OwningRewritePatternList &patterns, namespace { class DecorateSPIRVCompositeTypeLayoutPass - : public OperationPass<DecorateSPIRVCompositeTypeLayoutPass, ModuleOp> { + : public PassWrapper<DecorateSPIRVCompositeTypeLayoutPass, + OperationPass<ModuleOp>> { private: void runOnOperation() override; }; @@ -113,7 +114,7 @@ void DecorateSPIRVCompositeTypeLayoutPass::runOnOperation() { } } -std::unique_ptr<OpPassBase<ModuleOp>> +std::unique_ptr<OperationPass<ModuleOp>> mlir::spirv::createDecorateSPIRVCompositeTypeLayoutPass() { return std::make_unique<DecorateSPIRVCompositeTypeLayoutPass>(); } diff --git a/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp index 9cb2bfe1e1fc..d666f9697374 100644 --- a/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp @@ -148,7 +148,8 @@ public: /// Pass to implement the ABI information specified as attributes. class LowerABIAttributesPass final - : public OperationPass<LowerABIAttributesPass, spirv::ModuleOp> { + : public PassWrapper<LowerABIAttributesPass, + OperationPass<spirv::ModuleOp>> { private: void runOnOperation() override; }; @@ -260,7 +261,7 @@ void LowerABIAttributesPass::runOnOperation() { } } -std::unique_ptr<OpPassBase<spirv::ModuleOp>> +std::unique_ptr<OperationPass<spirv::ModuleOp>> mlir::spirv::createLowerABIAttributesPass() { return std::make_unique<LowerABIAttributesPass>(); } diff --git a/mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp index ebb5b6eda83a..415535b6da97 100644 --- a/mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp @@ -27,7 +27,7 @@ namespace { /// Pass to deduce minimal version/extension/capability requirements for a /// spirv::ModuleOp. class UpdateVCEPass final - : public OperationPass<UpdateVCEPass, spirv::ModuleOp> { + : public PassWrapper<UpdateVCEPass, OperationPass<spirv::ModuleOp>> { private: void runOnOperation() override; }; @@ -173,7 +173,7 @@ void UpdateVCEPass::runOnOperation() { module.setAttr(spirv::ModuleOp::getVCETripleAttrName(), triple); } -std::unique_ptr<OpPassBase<spirv::ModuleOp>> +std::unique_ptr<OperationPass<spirv::ModuleOp>> mlir::spirv::createUpdateVersionCapabilityExtensionPass() { return std::make_unique<UpdateVCEPass>(); } diff --git a/mlir/lib/Pass/PassDetail.h b/mlir/lib/Pass/PassDetail.h index bcd71ce4e2f6..59d9a7a0576f 100644 --- a/mlir/lib/Pass/PassDetail.h +++ b/mlir/lib/Pass/PassDetail.h @@ -19,7 +19,7 @@ namespace detail { //===----------------------------------------------------------------------===// /// Pass to verify an operation and signal failure if necessary. -class VerifierPass : public OperationPass<VerifierPass> { +class VerifierPass : public PassWrapper<VerifierPass, OperationPass<>> { void runOnOperation() override; }; @@ -49,8 +49,9 @@ protected: /// An adaptor pass used to run operation passes over nested operations /// synchronously on a single thread. -class OpToOpPassAdaptor : public OperationPass<OpToOpPassAdaptor>, - public OpToOpPassAdaptorBase { +class OpToOpPassAdaptor + : public PassWrapper<OpToOpPassAdaptor, OperationPass<>>, + public OpToOpPassAdaptorBase { public: OpToOpPassAdaptor(OpPassManager &&mgr); @@ -61,7 +62,7 @@ public: /// An adaptor pass used to run operation passes over nested operations /// asynchronously across multiple threads. class OpToOpPassAdaptorParallel - : public OperationPass<OpToOpPassAdaptorParallel>, + : public PassWrapper<OpToOpPassAdaptorParallel, OperationPass<>>, public OpToOpPassAdaptorBase { public: OpToOpPassAdaptorParallel(OpPassManager &&mgr); diff --git a/mlir/lib/Transforms/CSE.cpp b/mlir/lib/Transforms/CSE.cpp index 67890f63dcd9..919c957f3d9d 100644 --- a/mlir/lib/Transforms/CSE.cpp +++ b/mlir/lib/Transforms/CSE.cpp @@ -73,7 +73,7 @@ struct SimpleOperationInfo : public llvm::DenseMapInfo<Operation *> { namespace { /// Simple common sub-expression elimination. -struct CSE : public OperationPass<CSE> { +struct CSE : public PassWrapper<CSE, OperationPass<>> { /// Include the generated pass utilities. #define GEN_PASS_CSE #include "mlir/Transforms/Passes.h.inc" diff --git a/mlir/lib/Transforms/Canonicalizer.cpp b/mlir/lib/Transforms/Canonicalizer.cpp index 964fc7f66500..3f3d30296785 100644 --- a/mlir/lib/Transforms/Canonicalizer.cpp +++ b/mlir/lib/Transforms/Canonicalizer.cpp @@ -19,7 +19,7 @@ using namespace mlir; namespace { /// Canonicalize operations in nested regions. -struct Canonicalizer : public OperationPass<Canonicalizer> { +struct Canonicalizer : public PassWrapper<Canonicalizer, OperationPass<>> { /// Include the generated pass utilities. #define GEN_PASS_Canonicalizer #include "mlir/Transforms/Passes.h.inc" diff --git a/mlir/lib/Transforms/Inliner.cpp b/mlir/lib/Transforms/Inliner.cpp index 04ae30faf2aa..60382ea64f76 100644 --- a/mlir/lib/Transforms/Inliner.cpp +++ b/mlir/lib/Transforms/Inliner.cpp @@ -589,7 +589,7 @@ static void inlineSCC(Inliner &inliner, CGUseList &useList, //===----------------------------------------------------------------------===// namespace { -struct InlinerPass : public OperationPass<InlinerPass> { +struct InlinerPass : public PassWrapper<InlinerPass, OperationPass<>> { /// Include the generated pass utilities. #define GEN_PASS_Inliner #include "mlir/Transforms/Passes.h.inc" diff --git a/mlir/lib/Transforms/LocationSnapshot.cpp b/mlir/lib/Transforms/LocationSnapshot.cpp index 4f54e13e76fd..e9858bc142e8 100644 --- a/mlir/lib/Transforms/LocationSnapshot.cpp +++ b/mlir/lib/Transforms/LocationSnapshot.cpp @@ -123,7 +123,8 @@ LogicalResult mlir::generateLocationsFromIR(StringRef fileName, StringRef tag, } namespace { -struct LocationSnapshotPass : public OperationPass<LocationSnapshotPass> { +struct LocationSnapshotPass + : public PassWrapper<LocationSnapshotPass, OperationPass<>> { /// Include the generated pass utilities. #define GEN_PASS_LocationSnapshot #include "mlir/Transforms/Passes.h.inc" diff --git a/mlir/lib/Transforms/LoopCoalescing.cpp b/mlir/lib/Transforms/LoopCoalescing.cpp index 322b3b92c52c..57d8e2a26d67 100644 --- a/mlir/lib/Transforms/LoopCoalescing.cpp +++ b/mlir/lib/Transforms/LoopCoalescing.cpp @@ -19,7 +19,8 @@ using namespace mlir; namespace { -struct LoopCoalescingPass : public FunctionPass<LoopCoalescingPass> { +struct LoopCoalescingPass + : public PassWrapper<LoopCoalescingPass, FunctionPass> { /// Include the generated pass utilities. #define GEN_PASS_LoopCoalescing #include "mlir/Transforms/Passes.h.inc" @@ -89,6 +90,6 @@ struct LoopCoalescingPass : public FunctionPass<LoopCoalescingPass> { } // namespace -std::unique_ptr<OpPassBase<FuncOp>> mlir::createLoopCoalescingPass() { +std::unique_ptr<OperationPass<FuncOp>> mlir::createLoopCoalescingPass() { return std::make_unique<LoopCoalescingPass>(); } diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp index 55b45c0595a6..f802ba526b25 100644 --- a/mlir/lib/Transforms/LoopFusion.cpp +++ b/mlir/lib/Transforms/LoopFusion.cpp @@ -77,7 +77,7 @@ namespace { // TODO(andydavis) Extend this pass to check for fusion preventing dependences, // and add support for more general loop fusion algorithms. -struct LoopFusion : public FunctionPass<LoopFusion> { +struct LoopFusion : public PassWrapper<LoopFusion, FunctionPass> { /// Include the generated pass utilities. #define GEN_PASS_AffineLoopFusion #include "mlir/Transforms/Passes.h.inc" @@ -104,7 +104,7 @@ struct LoopFusion : public FunctionPass<LoopFusion> { } // end anonymous namespace -std::unique_ptr<OpPassBase<FuncOp>> +std::unique_ptr<OperationPass<FuncOp>> mlir::createLoopFusionPass(unsigned fastMemorySpace, uint64_t localBufSizeThreshold, bool maximalFusion) { return std::make_unique<LoopFusion>(fastMemorySpace, localBufSizeThreshold, diff --git a/mlir/lib/Transforms/LoopInvariantCodeMotion.cpp b/mlir/lib/Transforms/LoopInvariantCodeMotion.cpp index 7407676d5877..e7e48ac40714 100644 --- a/mlir/lib/Transforms/LoopInvariantCodeMotion.cpp +++ b/mlir/lib/Transforms/LoopInvariantCodeMotion.cpp @@ -27,7 +27,8 @@ using namespace mlir; namespace { /// Loop invariant code motion (LICM) pass. -struct LoopInvariantCodeMotion : public OperationPass<LoopInvariantCodeMotion> { +struct LoopInvariantCodeMotion + : public PassWrapper<LoopInvariantCodeMotion, OperationPass<>> { /// Include the generated pass utilities. #define GEN_PASS_LoopInvariantCodeMotion #include "mlir/Transforms/Passes.h.inc" diff --git a/mlir/lib/Transforms/MemRefDataFlowOpt.cpp b/mlir/lib/Transforms/MemRefDataFlowOpt.cpp index e251edaf38cb..5b03de923991 100644 --- a/mlir/lib/Transforms/MemRefDataFlowOpt.cpp +++ b/mlir/lib/Transforms/MemRefDataFlowOpt.cpp @@ -60,7 +60,7 @@ namespace { // currently only eliminates the stores only if no other loads/uses (other // than dealloc) remain. // -struct MemRefDataFlowOpt : public FunctionPass<MemRefDataFlowOpt> { +struct MemRefDataFlowOpt : public PassWrapper<MemRefDataFlowOpt, FunctionPass> { /// Include the generated pass utilities. #define GEN_PASS_MemRefDataFlowOpt #include "mlir/Transforms/Passes.h.inc" @@ -82,7 +82,7 @@ struct MemRefDataFlowOpt : public FunctionPass<MemRefDataFlowOpt> { /// Creates a pass to perform optimizations relying on memref dataflow such as /// store to load forwarding, elimination of dead stores, and dead allocs. -std::unique_ptr<OpPassBase<FuncOp>> mlir::createMemRefDataFlowOptPass() { +std::unique_ptr<OperationPass<FuncOp>> mlir::createMemRefDataFlowOptPass() { return std::make_unique<MemRefDataFlowOpt>(); } diff --git a/mlir/lib/Transforms/OpStats.cpp b/mlir/lib/Transforms/OpStats.cpp index 2b519d697020..667a0b4f4f57 100644 --- a/mlir/lib/Transforms/OpStats.cpp +++ b/mlir/lib/Transforms/OpStats.cpp @@ -18,7 +18,8 @@ using namespace mlir; namespace { -struct PrintOpStatsPass : public OperationPass<PrintOpStatsPass, ModuleOp> { +struct PrintOpStatsPass + : public PassWrapper<PrintOpStatsPass, OperationPass<ModuleOp>> { /// Include the generated pass utilities. #define GEN_PASS_PrintOpStats #include "mlir/Transforms/Passes.h.inc" @@ -85,6 +86,6 @@ void PrintOpStatsPass::printSummary() { } } -std::unique_ptr<OpPassBase<ModuleOp>> mlir::createPrintOpStatsPass() { +std::unique_ptr<OperationPass<ModuleOp>> mlir::createPrintOpStatsPass() { return std::make_unique<PrintOpStatsPass>(); } diff --git a/mlir/lib/Transforms/ParallelLoopCollapsing.cpp b/mlir/lib/Transforms/ParallelLoopCollapsing.cpp index 29aded6465c7..4380fe30d089 100644 --- a/mlir/lib/Transforms/ParallelLoopCollapsing.cpp +++ b/mlir/lib/Transforms/ParallelLoopCollapsing.cpp @@ -20,7 +20,8 @@ using namespace mlir; namespace { -struct ParallelLoopCollapsing : public OperationPass<ParallelLoopCollapsing> { +struct ParallelLoopCollapsing + : public PassWrapper<ParallelLoopCollapsing, OperationPass<>> { /// Include the generated pass utilities. #define GEN_PASS_ParallelLoopCollapsing #include "mlir/Transforms/Passes.h.inc" diff --git a/mlir/lib/Transforms/PipelineDataTransfer.cpp b/mlir/lib/Transforms/PipelineDataTransfer.cpp index 4c1030336540..8eeea89d73f6 100644 --- a/mlir/lib/Transforms/PipelineDataTransfer.cpp +++ b/mlir/lib/Transforms/PipelineDataTransfer.cpp @@ -28,7 +28,8 @@ using namespace mlir; namespace { -struct PipelineDataTransfer : public FunctionPass<PipelineDataTransfer> { +struct PipelineDataTransfer + : public PassWrapper<PipelineDataTransfer, FunctionPass> { /// Include the generated pass utilities. #define GEN_PASS_AffinePipelineDataTransfer #include "mlir/Transforms/Passes.h.inc" @@ -43,7 +44,7 @@ struct PipelineDataTransfer : public FunctionPass<PipelineDataTransfer> { /// Creates a pass to pipeline explicit movement of data across levels of the /// memory hierarchy. -std::unique_ptr<OpPassBase<FuncOp>> mlir::createPipelineDataTransferPass() { +std::unique_ptr<OperationPass<FuncOp>> mlir::createPipelineDataTransferPass() { return std::make_unique<PipelineDataTransfer>(); } diff --git a/mlir/lib/Transforms/StripDebugInfo.cpp b/mlir/lib/Transforms/StripDebugInfo.cpp index d3420cfc35a1..e5ba14402515 100644 --- a/mlir/lib/Transforms/StripDebugInfo.cpp +++ b/mlir/lib/Transforms/StripDebugInfo.cpp @@ -14,7 +14,7 @@ using namespace mlir; namespace { -struct StripDebugInfo : public OperationPass<StripDebugInfo> { +struct StripDebugInfo : public PassWrapper<StripDebugInfo, OperationPass<>> { /// Include the generated pass utilities. #define GEN_PASS_StripDebugInfo #include "mlir/Transforms/Passes.h.inc" diff --git a/mlir/lib/Transforms/SymbolDCE.cpp b/mlir/lib/Transforms/SymbolDCE.cpp index 7513e34af388..251a956be75d 100644 --- a/mlir/lib/Transforms/SymbolDCE.cpp +++ b/mlir/lib/Transforms/SymbolDCE.cpp @@ -17,7 +17,7 @@ using namespace mlir; namespace { -struct SymbolDCE : public OperationPass<SymbolDCE> { +struct SymbolDCE : public PassWrapper<SymbolDCE, OperationPass<>> { /// Include the generated pass utilities. #define GEN_PASS_SymbolDCE #include "mlir/Transforms/Passes.h.inc" diff --git a/mlir/lib/Transforms/ViewOpGraph.cpp b/mlir/lib/Transforms/ViewOpGraph.cpp index c5d921db059e..8ac61fc4b815 100644 --- a/mlir/lib/Transforms/ViewOpGraph.cpp +++ b/mlir/lib/Transforms/ViewOpGraph.cpp @@ -100,7 +100,7 @@ namespace { // PrintOpPass is simple pass to write graph per function. // Note: this is a module pass only to avoid interleaving on the same ostream // due to multi-threading over functions. -struct PrintOpPass : public OperationPass<PrintOpPass, ModuleOp> { +struct PrintOpPass : public PassWrapper<PrintOpPass, OperationPass<ModuleOp>> { /// Include the generated pass utilities. #define GEN_PASS_PrintOpGraph #include "mlir/Transforms/Passes.h.inc" @@ -160,7 +160,7 @@ raw_ostream &mlir::writeGraph(raw_ostream &os, Block &block, bool shortNames, return llvm::WriteGraph(os, &block, shortNames, title); } -std::unique_ptr<OpPassBase<ModuleOp>> +std::unique_ptr<OperationPass<ModuleOp>> mlir::createPrintOpGraphPass(raw_ostream &os, bool shortNames, const Twine &title) { return std::make_unique<PrintOpPass>(os, shortNames, title); diff --git a/mlir/lib/Transforms/ViewRegionGraph.cpp b/mlir/lib/Transforms/ViewRegionGraph.cpp index cf9ff6d8077e..4f31a79cd9d3 100644 --- a/mlir/lib/Transforms/ViewRegionGraph.cpp +++ b/mlir/lib/Transforms/ViewRegionGraph.cpp @@ -60,7 +60,7 @@ void mlir::Region::viewGraph(const Twine ®ionName) { void mlir::Region::viewGraph() { viewGraph("region"); } namespace { -struct PrintCFGPass : public FunctionPass<PrintCFGPass> { +struct PrintCFGPass : public PassWrapper<PrintCFGPass, FunctionPass> { /// Include the generated pass utilities. #define GEN_PASS_PrintCFG #include "mlir/Transforms/Passes.h.inc" @@ -79,7 +79,7 @@ private: }; } // namespace -std::unique_ptr<mlir::OpPassBase<mlir::FuncOp>> +std::unique_ptr<mlir::OperationPass<mlir::FuncOp>> mlir::createPrintCFGGraphPass(raw_ostream &os, bool shortNames, const Twine &title) { return std::make_unique<PrintCFGPass>(os, shortNames, title); diff --git a/mlir/test/lib/Dialect/Affine/TestAffineDataCopy.cpp b/mlir/test/lib/Dialect/Affine/TestAffineDataCopy.cpp index 9e3b3434104a..7c0052dd9e68 100644 --- a/mlir/test/lib/Dialect/Affine/TestAffineDataCopy.cpp +++ b/mlir/test/lib/Dialect/Affine/TestAffineDataCopy.cpp @@ -26,7 +26,8 @@ static llvm::cl::OptionCategory clOptionsCategory(PASS_NAME " options"); namespace { -struct TestAffineDataCopy : public FunctionPass<TestAffineDataCopy> { +struct TestAffineDataCopy + : public PassWrapper<TestAffineDataCopy, FunctionPass> { TestAffineDataCopy() = default; TestAffineDataCopy(const TestAffineDataCopy &pass){}; diff --git a/mlir/test/lib/Dialect/Affine/TestLoopPermutation.cpp b/mlir/test/lib/Dialect/Affine/TestLoopPermutation.cpp index 87b2e620f7d6..a34901685159 100644 --- a/mlir/test/lib/Dialect/Affine/TestLoopPermutation.cpp +++ b/mlir/test/lib/Dialect/Affine/TestLoopPermutation.cpp @@ -25,7 +25,8 @@ static llvm::cl::OptionCategory clOptionsCategory(PASS_NAME " options"); namespace { /// This pass applies the permutation on the first maximal perfect nest. -struct TestLoopPermutation : public FunctionPass<TestLoopPermutation> { +struct TestLoopPermutation + : public PassWrapper<TestLoopPermutation, FunctionPass> { TestLoopPermutation() = default; TestLoopPermutation(const TestLoopPermutation &pass){}; diff --git a/mlir/test/lib/Dialect/Affine/TestParallelismDetection.cpp b/mlir/test/lib/Dialect/Affine/TestParallelismDetection.cpp index 1140dab92dbb..b19e26031693 100644 --- a/mlir/test/lib/Dialect/Affine/TestParallelismDetection.cpp +++ b/mlir/test/lib/Dialect/Affine/TestParallelismDetection.cpp @@ -20,7 +20,7 @@ using namespace mlir; namespace { struct TestParallelismDetection - : public FunctionPass<TestParallelismDetection> { + : public PassWrapper<TestParallelismDetection, FunctionPass> { void runOnFunction() override; }; diff --git a/mlir/test/lib/Dialect/Affine/TestVectorizationUtils.cpp b/mlir/test/lib/Dialect/Affine/TestVectorizationUtils.cpp index 01382530fa39..ca738fde6103 100644 --- a/mlir/test/lib/Dialect/Affine/TestVectorizationUtils.cpp +++ b/mlir/test/lib/Dialect/Affine/TestVectorizationUtils.cpp @@ -72,7 +72,8 @@ static llvm::cl::opt<bool> clTestNormalizeMaps( llvm::cl::cat(clOptionsCategory)); namespace { -struct VectorizerTestPass : public FunctionPass<VectorizerTestPass> { +struct VectorizerTestPass + : public PassWrapper<VectorizerTestPass, FunctionPass> { static constexpr auto kTestAffineMapOpName = "test_affine_map"; static constexpr auto kTestAffineMapAttrName = "affine_map"; diff --git a/mlir/test/lib/Dialect/SPIRV/TestAvailability.cpp b/mlir/test/lib/Dialect/SPIRV/TestAvailability.cpp index ad77e7d05f42..8c6ca60dabaa 100644 --- a/mlir/test/lib/Dialect/SPIRV/TestAvailability.cpp +++ b/mlir/test/lib/Dialect/SPIRV/TestAvailability.cpp @@ -20,7 +20,8 @@ using namespace mlir; namespace { /// A pass for testing SPIR-V op availability. -struct PrintOpAvailability : public FunctionPass<PrintOpAvailability> { +struct PrintOpAvailability + : public PassWrapper<PrintOpAvailability, FunctionPass> { void runOnFunction() override; }; } // end anonymous namespace @@ -88,7 +89,8 @@ void registerPrintOpAvailabilityPass() { namespace { /// A pass for testing SPIR-V op availability. -struct ConvertToTargetEnv : public FunctionPass<ConvertToTargetEnv> { +struct ConvertToTargetEnv + : public PassWrapper<ConvertToTargetEnv, FunctionPass> { void runOnFunction() override; }; diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp index 6ccfa04a8194..39b3fc1e5f4b 100644 --- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp +++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp @@ -38,7 +38,7 @@ namespace { //===----------------------------------------------------------------------===// namespace { -struct TestPatternDriver : public FunctionPass<TestPatternDriver> { +struct TestPatternDriver : public PassWrapper<TestPatternDriver, FunctionPass> { void runOnFunction() override { mlir::OwningRewritePatternList patterns; populateWithGenerated(&getContext(), &patterns); @@ -96,7 +96,8 @@ static void reifyReturnShape(Operation *op) { << it.value().getDefiningOp(); } -struct TestReturnTypeDriver : public FunctionPass<TestReturnTypeDriver> { +struct TestReturnTypeDriver + : public PassWrapper<TestReturnTypeDriver, FunctionPass> { void runOnFunction() override { if (getFunction().getName() == "testCreateFunctions") { std::vector<Operation *> ops; @@ -398,7 +399,7 @@ struct TestTypeConverter : public TypeConverter { }; struct TestLegalizePatternDriver - : public OperationPass<TestLegalizePatternDriver, ModuleOp> { + : public PassWrapper<TestLegalizePatternDriver, OperationPass<ModuleOp>> { /// The mode of conversion to use with the driver. enum class ConversionMode { Analysis, Full, Partial }; @@ -534,7 +535,8 @@ struct OneVResOneVOperandOp1Converter } }; -struct TestRemappedValue : public mlir::FunctionPass<TestRemappedValue> { +struct TestRemappedValue + : public mlir::PassWrapper<TestRemappedValue, FunctionPass> { void runOnFunction() override { mlir::OwningRewritePatternList patterns; patterns.insert<OneVResOneVOperandOp1Converter>(&getContext()); diff --git a/mlir/test/lib/IR/TestFunc.cpp b/mlir/test/lib/IR/TestFunc.cpp index c1b90397ec44..637864e049fd 100644 --- a/mlir/test/lib/IR/TestFunc.cpp +++ b/mlir/test/lib/IR/TestFunc.cpp @@ -13,7 +13,8 @@ using namespace mlir; namespace { /// This is a test pass for verifying FuncOp's eraseArgument method. -struct TestFuncEraseArg : public OperationPass<TestFuncEraseArg, ModuleOp> { +struct TestFuncEraseArg + : public PassWrapper<TestFuncEraseArg, OperationPass<ModuleOp>> { void runOnOperation() override { auto module = getOperation(); @@ -36,7 +37,8 @@ struct TestFuncEraseArg : public OperationPass<TestFuncEraseArg, ModuleOp> { }; /// This is a test pass for verifying FuncOp's setType method. -struct TestFuncSetType : public OperationPass<TestFuncSetType, ModuleOp> { +struct TestFuncSetType + : public PassWrapper<TestFuncSetType, OperationPass<ModuleOp>> { void runOnOperation() override { auto module = getOperation(); SymbolTable symbolTable(module); diff --git a/mlir/test/lib/IR/TestMatchers.cpp b/mlir/test/lib/IR/TestMatchers.cpp index a5db0684c0b5..8af91506f639 100644 --- a/mlir/test/lib/IR/TestMatchers.cpp +++ b/mlir/test/lib/IR/TestMatchers.cpp @@ -15,7 +15,7 @@ using namespace mlir; namespace { /// This is a test pass for verifying matchers. -struct TestMatchers : public FunctionPass<TestMatchers> { +struct TestMatchers : public PassWrapper<TestMatchers, FunctionPass> { void runOnFunction() override; }; } // end anonymous namespace diff --git a/mlir/test/lib/IR/TestSideEffects.cpp b/mlir/test/lib/IR/TestSideEffects.cpp index a99348537e25..1ac9e82a7f7e 100644 --- a/mlir/test/lib/IR/TestSideEffects.cpp +++ b/mlir/test/lib/IR/TestSideEffects.cpp @@ -12,7 +12,8 @@ using namespace mlir; namespace { -struct SideEffectsPass : public OperationPass<SideEffectsPass, ModuleOp> { +struct SideEffectsPass + : public PassWrapper<SideEffectsPass, OperationPass<ModuleOp>> { void runOnOperation() override { auto module = getOperation(); diff --git a/mlir/test/lib/IR/TestSymbolUses.cpp b/mlir/test/lib/IR/TestSymbolUses.cpp index c39615ef1352..13188485ec41 100644 --- a/mlir/test/lib/IR/TestSymbolUses.cpp +++ b/mlir/test/lib/IR/TestSymbolUses.cpp @@ -15,7 +15,8 @@ using namespace mlir; namespace { /// This is a symbol test pass that tests the symbol uselist functionality /// provided by the symbol table along with erasing from the symbol table. -struct SymbolUsesPass : public OperationPass<SymbolUsesPass, ModuleOp> { +struct SymbolUsesPass + : public PassWrapper<SymbolUsesPass, OperationPass<ModuleOp>> { WalkResult operateOnSymbol(Operation *symbol, ModuleOp module, SmallVectorImpl<FuncOp> &deadFunctions) { // Test computing uses on a non symboltable op. @@ -87,7 +88,7 @@ struct SymbolUsesPass : public OperationPass<SymbolUsesPass, ModuleOp> { /// This is a symbol test pass that tests the symbol use replacement /// functionality provided by the symbol table. struct SymbolReplacementPass - : public OperationPass<SymbolReplacementPass, ModuleOp> { + : public PassWrapper<SymbolReplacementPass, OperationPass<ModuleOp>> { void runOnOperation() override { auto module = getOperation(); diff --git a/mlir/test/lib/Pass/TestPassManager.cpp b/mlir/test/lib/Pass/TestPassManager.cpp index be8a7479200c..ffac1b18be46 100644 --- a/mlir/test/lib/Pass/TestPassManager.cpp +++ b/mlir/test/lib/Pass/TestPassManager.cpp @@ -13,13 +13,14 @@ using namespace mlir; namespace { -struct TestModulePass : public OperationPass<TestModulePass, ModuleOp> { +struct TestModulePass + : public PassWrapper<TestModulePass, OperationPass<ModuleOp>> { void runOnOperation() final {} }; -struct TestFunctionPass : public FunctionPass<TestFunctionPass> { +struct TestFunctionPass : public PassWrapper<TestFunctionPass, FunctionPass> { void runOnFunction() final {} }; -class TestOptionsPass : public FunctionPass<TestOptionsPass> { +class TestOptionsPass : public PassWrapper<TestOptionsPass, FunctionPass> { public: struct Options : public PassPipelineOptions<Options> { ListOption<int> listOption{*this, "list", @@ -53,12 +54,14 @@ public: /// A test pass that always aborts to enable testing the crash recovery /// mechanism of the pass manager. -class TestCrashRecoveryPass : public OperationPass<TestCrashRecoveryPass> { +class TestCrashRecoveryPass + : public PassWrapper<TestCrashRecoveryPass, OperationPass<>> { void runOnOperation() final { abort(); } }; /// A test pass that contains a statistic. -struct TestStatisticPass : public OperationPass<TestStatisticPass> { +struct TestStatisticPass + : public PassWrapper<TestStatisticPass, OperationPass<>> { TestStatisticPass() = default; TestStatisticPass(const TestStatisticPass &) {} diff --git a/mlir/test/lib/Transforms/TestAllReduceLowering.cpp b/mlir/test/lib/Transforms/TestAllReduceLowering.cpp index 6455dab70f45..9a2bcc292379 100644 --- a/mlir/test/lib/Transforms/TestAllReduceLowering.cpp +++ b/mlir/test/lib/Transforms/TestAllReduceLowering.cpp @@ -18,7 +18,7 @@ using namespace mlir; namespace { struct TestAllReduceLoweringPass - : public OperationPass<TestAllReduceLoweringPass, ModuleOp> { + : public PassWrapper<TestAllReduceLoweringPass, OperationPass<ModuleOp>> { void runOnOperation() override { OwningRewritePatternList patterns; populateGpuRewritePatterns(&getContext(), patterns); diff --git a/mlir/test/lib/Transforms/TestCallGraph.cpp b/mlir/test/lib/Transforms/TestCallGraph.cpp index a181d645f2af..bd651a5dfe2c 100644 --- a/mlir/test/lib/Transforms/TestCallGraph.cpp +++ b/mlir/test/lib/Transforms/TestCallGraph.cpp @@ -17,7 +17,8 @@ using namespace mlir; namespace { -struct TestCallGraphPass : public OperationPass<TestCallGraphPass, ModuleOp> { +struct TestCallGraphPass + : public PassWrapper<TestCallGraphPass, OperationPass<ModuleOp>> { void runOnOperation() override { llvm::errs() << "Testing : " << getOperation().getAttr("test.name") << "\n"; getAnalysis<CallGraph>().print(llvm::errs()); diff --git a/mlir/test/lib/Transforms/TestConstantFold.cpp b/mlir/test/lib/Transforms/TestConstantFold.cpp index cc6ece7f7c46..089d450aaf37 100644 --- a/mlir/test/lib/Transforms/TestConstantFold.cpp +++ b/mlir/test/lib/Transforms/TestConstantFold.cpp @@ -19,7 +19,7 @@ using namespace mlir; namespace { /// Simple constant folding pass. -struct TestConstantFold : public FunctionPass<TestConstantFold> { +struct TestConstantFold : public PassWrapper<TestConstantFold, FunctionPass> { // All constants in the function post folding. SmallVector<Operation *, 8> existingConstants; diff --git a/mlir/test/lib/Transforms/TestDominance.cpp b/mlir/test/lib/Transforms/TestDominance.cpp index 784bb1f40564..97674c400f81 100644 --- a/mlir/test/lib/Transforms/TestDominance.cpp +++ b/mlir/test/lib/Transforms/TestDominance.cpp @@ -64,7 +64,7 @@ private: DenseMap<Block *, size_t> blockIds; }; -struct TestDominancePass : public FunctionPass<TestDominancePass> { +struct TestDominancePass : public PassWrapper<TestDominancePass, FunctionPass> { void runOnFunction() override { llvm::errs() << "Testing : " << getFunction().getName() << "\n"; diff --git a/mlir/test/lib/Transforms/TestGpuMemoryPromotion.cpp b/mlir/test/lib/Transforms/TestGpuMemoryPromotion.cpp index 72304244d8fc..08862dd06140 100644 --- a/mlir/test/lib/Transforms/TestGpuMemoryPromotion.cpp +++ b/mlir/test/lib/Transforms/TestGpuMemoryPromotion.cpp @@ -24,7 +24,8 @@ namespace { /// does not check whether the promotion is legal (e.g., amount of memory used) /// or beneficial (e.g., makes previously uncoalesced loads coalesced). class TestGpuMemoryPromotionPass - : public OperationPass<TestGpuMemoryPromotionPass, gpu::GPUFuncOp> { + : public PassWrapper<TestGpuMemoryPromotionPass, + OperationPass<gpu::GPUFuncOp>> { void runOnOperation() override { gpu::GPUFuncOp op = getOperation(); for (unsigned i = 0, e = op.getNumArguments(); i < e; ++i) { diff --git a/mlir/test/lib/Transforms/TestGpuParallelLoopMapping.cpp b/mlir/test/lib/Transforms/TestGpuParallelLoopMapping.cpp index d7bbdbb94cba..d877001bedd5 100644 --- a/mlir/test/lib/Transforms/TestGpuParallelLoopMapping.cpp +++ b/mlir/test/lib/Transforms/TestGpuParallelLoopMapping.cpp @@ -20,7 +20,8 @@ namespace { /// Simple pass for testing the mapping of parallel loops to hardware ids using /// a greedy mapping strategy. class TestGpuGreedyParallelLoopMappingPass - : public OperationPass<TestGpuGreedyParallelLoopMappingPass, FuncOp> { + : public PassWrapper<TestGpuGreedyParallelLoopMappingPass, + OperationPass<FuncOp>> { void runOnOperation() override { Operation *op = getOperation(); for (Region ®ion : op->getRegions()) diff --git a/mlir/test/lib/Transforms/TestInlining.cpp b/mlir/test/lib/Transforms/TestInlining.cpp index d6f7776ba5d2..21e0b76cf5cb 100644 --- a/mlir/test/lib/Transforms/TestInlining.cpp +++ b/mlir/test/lib/Transforms/TestInlining.cpp @@ -23,7 +23,7 @@ using namespace mlir; namespace { -struct Inliner : public FunctionPass<Inliner> { +struct Inliner : public PassWrapper<Inliner, FunctionPass> { void runOnFunction() override { auto function = getFunction(); diff --git a/mlir/test/lib/Transforms/TestLinalgTransforms.cpp b/mlir/test/lib/Transforms/TestLinalgTransforms.cpp index 645ce887135d..85300f981f1e 100644 --- a/mlir/test/lib/Transforms/TestLinalgTransforms.cpp +++ b/mlir/test/lib/Transforms/TestLinalgTransforms.cpp @@ -27,7 +27,8 @@ namespace { } // end namespace mlir namespace { -struct TestLinalgTransforms : public FunctionPass<TestLinalgTransforms> { +struct TestLinalgTransforms + : public PassWrapper<TestLinalgTransforms, FunctionPass> { void runOnFunction() override; }; } // end anonymous namespace diff --git a/mlir/test/lib/Transforms/TestLiveness.cpp b/mlir/test/lib/Transforms/TestLiveness.cpp index 9b98ebd49a9c..e51ee1b5f154 100644 --- a/mlir/test/lib/Transforms/TestLiveness.cpp +++ b/mlir/test/lib/Transforms/TestLiveness.cpp @@ -19,7 +19,7 @@ using namespace mlir; namespace { -struct TestLivenessPass : public FunctionPass<TestLivenessPass> { +struct TestLivenessPass : public PassWrapper<TestLivenessPass, FunctionPass> { void runOnFunction() override { llvm::errs() << "Testing : " << getFunction().getName() << "\n"; getAnalysis<Liveness>().print(llvm::errs()); diff --git a/mlir/test/lib/Transforms/TestLoopFusion.cpp b/mlir/test/lib/Transforms/TestLoopFusion.cpp index 4d63e412aab6..cf33ec307718 100644 --- a/mlir/test/lib/Transforms/TestLoopFusion.cpp +++ b/mlir/test/lib/Transforms/TestLoopFusion.cpp @@ -48,7 +48,7 @@ static llvm::cl::opt<bool> clTestLoopFusionTransformation( namespace { -struct TestLoopFusion : public FunctionPass<TestLoopFusion> { +struct TestLoopFusion : public PassWrapper<TestLoopFusion, FunctionPass> { void runOnFunction() override; }; diff --git a/mlir/test/lib/Transforms/TestLoopMapping.cpp b/mlir/test/lib/Transforms/TestLoopMapping.cpp index ee96d630ae0e..184234807130 100644 --- a/mlir/test/lib/Transforms/TestLoopMapping.cpp +++ b/mlir/test/lib/Transforms/TestLoopMapping.cpp @@ -22,7 +22,8 @@ using namespace mlir; namespace { -class TestLoopMappingPass : public FunctionPass<TestLoopMappingPass> { +class TestLoopMappingPass + : public PassWrapper<TestLoopMappingPass, FunctionPass> { public: explicit TestLoopMappingPass() {} diff --git a/mlir/test/lib/Transforms/TestLoopParametricTiling.cpp b/mlir/test/lib/Transforms/TestLoopParametricTiling.cpp index f82ea3aed284..61d34bcfcd37 100644 --- a/mlir/test/lib/Transforms/TestLoopParametricTiling.cpp +++ b/mlir/test/lib/Transforms/TestLoopParametricTiling.cpp @@ -23,7 +23,7 @@ namespace { // Extracts fixed-range loops for top-level loop nests with ranges defined in // the pass constructor. Assumes loops are permutable. class SimpleParametricLoopTilingPass - : public FunctionPass<SimpleParametricLoopTilingPass> { + : public PassWrapper<SimpleParametricLoopTilingPass, FunctionPass> { public: SimpleParametricLoopTilingPass() = default; SimpleParametricLoopTilingPass(const SimpleParametricLoopTilingPass &) {} diff --git a/mlir/test/lib/Transforms/TestMemRefBoundCheck.cpp b/mlir/test/lib/Transforms/TestMemRefBoundCheck.cpp index ef566de1391e..51339be67e68 100644 --- a/mlir/test/lib/Transforms/TestMemRefBoundCheck.cpp +++ b/mlir/test/lib/Transforms/TestMemRefBoundCheck.cpp @@ -28,7 +28,8 @@ using namespace mlir; namespace { /// Checks for out of bound memef access subscripts.. -struct TestMemRefBoundCheck : public FunctionPass<TestMemRefBoundCheck> { +struct TestMemRefBoundCheck + : public PassWrapper<TestMemRefBoundCheck, FunctionPass> { void runOnFunction() override; }; diff --git a/mlir/test/lib/Transforms/TestMemRefDependenceCheck.cpp b/mlir/test/lib/Transforms/TestMemRefDependenceCheck.cpp index 2803c1d9dccc..34db53b6ce1e 100644 --- a/mlir/test/lib/Transforms/TestMemRefDependenceCheck.cpp +++ b/mlir/test/lib/Transforms/TestMemRefDependenceCheck.cpp @@ -28,7 +28,7 @@ namespace { // TODO(andydavis) Add common surrounding loop depth-wise dependence checks. /// Checks dependences between all pairs of memref accesses in a Function. struct TestMemRefDependenceCheck - : public FunctionPass<TestMemRefDependenceCheck> { + : public PassWrapper<TestMemRefDependenceCheck, FunctionPass> { SmallVector<Operation *, 4> loadsAndStores; void runOnFunction() override; }; diff --git a/mlir/test/lib/Transforms/TestMemRefStrideCalculation.cpp b/mlir/test/lib/Transforms/TestMemRefStrideCalculation.cpp index 9e60b8da99dc..f2715bfb8e93 100644 --- a/mlir/test/lib/Transforms/TestMemRefStrideCalculation.cpp +++ b/mlir/test/lib/Transforms/TestMemRefStrideCalculation.cpp @@ -14,14 +14,13 @@ using namespace mlir; namespace { -/// Simple constant folding pass. struct TestMemRefStrideCalculation - : public FunctionPass<struct TestMemRefStrideCalculation> { + : public PassWrapper<TestMemRefStrideCalculation, FunctionPass> { void runOnFunction() override; }; } // end anonymous namespace -// Traverse AllocOp and compute strides of each MemRefType independently. +/// Traverse AllocOp and compute strides of each MemRefType independently. void TestMemRefStrideCalculation::runOnFunction() { llvm::outs() << "Testing: " << getFunction().getName() << "\n"; getFunction().walk([&](AllocOp allocOp) { diff --git a/mlir/test/lib/Transforms/TestOpaqueLoc.cpp b/mlir/test/lib/Transforms/TestOpaqueLoc.cpp index 47152c459805..55b181cc9697 100644 --- a/mlir/test/lib/Transforms/TestOpaqueLoc.cpp +++ b/mlir/test/lib/Transforms/TestOpaqueLoc.cpp @@ -17,7 +17,8 @@ namespace { /// It also takes all operations that are not function operations or /// terminators and clones them with opaque locations which store the initial /// locations. -struct TestOpaqueLoc : public OperationPass<TestOpaqueLoc, ModuleOp> { +struct TestOpaqueLoc + : public PassWrapper<TestOpaqueLoc, OperationPass<ModuleOp>> { /// A simple structure which is used for testing as an underlying location in /// OpaqueLoc. diff --git a/mlir/test/lib/Transforms/TestVectorToLoopsConversion.cpp b/mlir/test/lib/Transforms/TestVectorToLoopsConversion.cpp index 88d08ce63119..b1c02bdd0adf 100644 --- a/mlir/test/lib/Transforms/TestVectorToLoopsConversion.cpp +++ b/mlir/test/lib/Transforms/TestVectorToLoopsConversion.cpp @@ -18,7 +18,7 @@ using namespace mlir; namespace { struct TestVectorToLoopsPass - : public FunctionPass<TestVectorToLoopsPass> { + : public PassWrapper<TestVectorToLoopsPass, FunctionPass> { void runOnFunction() override { OwningRewritePatternList patterns; auto *context = &getContext(); diff --git a/mlir/test/lib/Transforms/TestVectorTransforms.cpp b/mlir/test/lib/Transforms/TestVectorTransforms.cpp index 8f2f64e5f60a..808fcd21d331 100644 --- a/mlir/test/lib/Transforms/TestVectorTransforms.cpp +++ b/mlir/test/lib/Transforms/TestVectorTransforms.cpp @@ -21,7 +21,7 @@ namespace { #include "TestVectorTransformPatterns.h.inc" struct TestVectorToVectorConversion - : public FunctionPass<TestVectorToVectorConversion> { + : public PassWrapper<TestVectorToVectorConversion, FunctionPass> { void runOnFunction() override { OwningRewritePatternList patterns; auto *context = &getContext(); @@ -33,7 +33,7 @@ struct TestVectorToVectorConversion }; struct TestVectorSlicesConversion - : public FunctionPass<TestVectorSlicesConversion> { + : public PassWrapper<TestVectorSlicesConversion, FunctionPass> { void runOnFunction() override { OwningRewritePatternList patterns; populateVectorSlicesLoweringPatterns(patterns, &getContext()); @@ -42,7 +42,7 @@ struct TestVectorSlicesConversion }; struct TestVectorContractionConversion - : public FunctionPass<TestVectorContractionConversion> { + : public PassWrapper<TestVectorContractionConversion, FunctionPass> { TestVectorContractionConversion() = default; TestVectorContractionConversion(const TestVectorContractionConversion &pass) { } |