diff options
author | Eugene Zhulenev <ezhulenev@google.com> | 2022-08-05 13:08:08 -0700 |
---|---|---|
committer | Eugene Zhulenev <ezhulenev@google.com> | 2022-08-05 13:50:31 -0700 |
commit | 51bc82d147f8205dc516a50c66a3938249116f41 (patch) | |
tree | 6498ead6a6ffa5a967455b82d5eb2c7032d73467 | |
parent | 5c16eeb7ee13ab0b5eb52571998b9494475db301 (diff) |
[mlir] Implement SymbolUserOpInterface in LLVM::CallOp
Avoid expensive calls to `SymbolTable::lookupNearestSymbolFrom` in verifier
Reviewed By: Mogball
Differential Revision: https://reviews.llvm.org/D131285
-rw-r--r-- | mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td | 5 | ||||
-rw-r--r-- | mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp | 4 | ||||
-rw-r--r-- | mlir/test/Dialect/LLVMIR/invalid.mlir | 13 |
3 files changed, 18 insertions, 4 deletions
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td index e65827195340..7dcd48f0c5e8 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td @@ -647,7 +647,8 @@ def LLVM_LandingpadOp : LLVM_Op<"landingpad"> { def LLVM_CallOp : LLVM_Op<"call", [DeclareOpInterfaceMethods<FastmathFlagsInterface>, - DeclareOpInterfaceMethods<CallOpInterface>]> { + DeclareOpInterfaceMethods<CallOpInterface>, + DeclareOpInterfaceMethods<SymbolUserOpInterface>]> { let summary = "Call to an LLVM function."; let description = [{ @@ -701,8 +702,8 @@ def LLVM_CallOp : LLVM_Op<"call", StringAttr::get($_builder.getContext(), callee), operands); }]>]; let hasCustomAssemblyFormat = 1; - let hasVerifier = 1; } + def LLVM_ExtractElementOp : LLVM_Op<"extractelement", [NoSideEffect]> { let arguments = (ins LLVM_AnyVector:$vector, AnyInteger:$position); let results = (outs LLVM_Type:$res); diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp index 3d9ec1798e50..4cb6a5658c51 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -1160,7 +1160,7 @@ Operation::operand_range CallOp::getArgOperands() { return getOperands().drop_front(getCallee().has_value() ? 0 : 1); } -LogicalResult CallOp::verify() { +LogicalResult CallOp::verifySymbolUses(SymbolTableCollection &symbolTable) { if (getNumResults() > 1) return emitOpError("must have 0 or 1 result"); @@ -1184,7 +1184,7 @@ LogicalResult CallOp::verify() { fnType = ptrType.getElementType(); } else { Operation *callee = - SymbolTable::lookupNearestSymbolFrom(*this, calleeName.getAttr()); + symbolTable.lookupNearestSymbolFrom(*this, calleeName.getAttr()); if (!callee) return emitOpError() << "'" << calleeName.getValue() diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir index 459d6b188753..4ac9724b0838 100644 --- a/mlir/test/Dialect/LLVMIR/invalid.mlir +++ b/mlir/test/Dialect/LLVMIR/invalid.mlir @@ -191,6 +191,7 @@ func.func @store_malformed_elem_type(%foo: !llvm.ptr, %bar: f32) { func.func @call_non_function_type(%callee : !llvm.func<i8 (i8)>, %arg : i8) { // expected-error@+1 {{expected function type}} llvm.call %callee(%arg) : !llvm.func<i8 (i8)> + llvm.return } // ----- @@ -198,6 +199,7 @@ func.func @call_non_function_type(%callee : !llvm.func<i8 (i8)>, %arg : i8) { func.func @invalid_call() { // expected-error@+1 {{'llvm.call' op must have either a `callee` attribute or at least an operand}} "llvm.call"() : () -> () + llvm.return } // ----- @@ -205,6 +207,7 @@ func.func @invalid_call() { func.func @call_non_function_type(%callee : !llvm.func<i8 (i8)>, %arg : i8) { // expected-error@+1 {{expected function type}} llvm.call %callee(%arg) : !llvm.func<i8 (i8)> + llvm.return } // ----- @@ -212,6 +215,7 @@ func.func @call_non_function_type(%callee : !llvm.func<i8 (i8)>, %arg : i8) { func.func @call_unknown_symbol() { // expected-error@+1 {{'llvm.call' op 'missing_callee' does not reference a symbol in the current scope}} llvm.call @missing_callee() : () -> () + llvm.return } // ----- @@ -221,6 +225,7 @@ func.func private @standard_func_callee() func.func @call_non_llvm() { // expected-error@+1 {{'llvm.call' op 'standard_func_callee' does not reference a valid LLVM function}} llvm.call @standard_func_callee() : () -> () + llvm.return } // ----- @@ -228,6 +233,7 @@ func.func @call_non_llvm() { func.func @call_non_llvm_indirect(%arg0 : tensor<*xi32>) { // expected-error@+1 {{'llvm.call' op operand #0 must be LLVM dialect-compatible type}} "llvm.call"(%arg0) : (tensor<*xi32>) -> () + llvm.return } // ----- @@ -237,6 +243,7 @@ llvm.func @callee_func(i8) -> () func.func @callee_arg_mismatch(%arg0 : i32) { // expected-error@+1 {{'llvm.call' op operand type mismatch for operand 0: 'i32' != 'i8'}} llvm.call @callee_func(%arg0) : (i32) -> () + llvm.return } // ----- @@ -244,6 +251,7 @@ func.func @callee_arg_mismatch(%arg0 : i32) { func.func @indirect_callee_arg_mismatch(%arg0 : i32, %callee : !llvm.ptr<func<void(i8)>>) { // expected-error@+1 {{'llvm.call' op operand type mismatch for operand 0: 'i32' != 'i8'}} "llvm.call"(%callee, %arg0) : (!llvm.ptr<func<void(i8)>>, i32) -> () + llvm.return } // ----- @@ -253,6 +261,7 @@ llvm.func @callee_func() -> (i8) func.func @callee_return_mismatch() { // expected-error@+1 {{'llvm.call' op result type mismatch: 'i32' != 'i8'}} %res = llvm.call @callee_func() : () -> (i32) + llvm.return } // ----- @@ -260,6 +269,7 @@ func.func @callee_return_mismatch() { func.func @indirect_callee_return_mismatch(%callee : !llvm.ptr<func<i8()>>) { // expected-error@+1 {{'llvm.call' op result type mismatch: 'i32' != 'i8'}} "llvm.call"(%callee) : (!llvm.ptr<func<i8()>>) -> (i32) + llvm.return } // ----- @@ -267,6 +277,7 @@ func.func @indirect_callee_return_mismatch(%callee : !llvm.ptr<func<i8()>>) { func.func @call_too_many_results(%callee : () -> (i32,i32)) { // expected-error@+1 {{expected function with 0 or 1 result}} llvm.call %callee() : () -> (i32, i32) + llvm.return } // ----- @@ -274,6 +285,7 @@ func.func @call_too_many_results(%callee : () -> (i32,i32)) { func.func @call_non_llvm_result(%callee : () -> (tensor<*xi32>)) { // expected-error@+1 {{expected result to have LLVM type}} llvm.call %callee() : () -> (tensor<*xi32>) + llvm.return } // ----- @@ -281,6 +293,7 @@ func.func @call_non_llvm_result(%callee : () -> (tensor<*xi32>)) { func.func @call_non_llvm_input(%callee : (tensor<*xi32>) -> (), %arg : tensor<*xi32>) { // expected-error@+1 {{expected LLVM types as inputs}} llvm.call %callee(%arg) : (tensor<*xi32>) -> () + llvm.return } // ----- |