diff options
author | Mahesh Ravishankar <ravishankarm@google.com> | 2022-06-23 21:06:45 +0000 |
---|---|---|
committer | Mahesh Ravishankar <ravishankarm@google.com> | 2022-06-28 05:26:39 +0000 |
commit | fa596c6921159af50e69cc3be189d951521a9eb9 (patch) | |
tree | b3b24b76f7c405478ac6cb097aa4dd9844282317 /mlir/lib/Conversion/VectorToLLVM | |
parent | b941857b40edd7f3f3a9ec2ec85a26db24739774 (diff) |
[mlir][Vector] Fix reordering of floating point adds during lower of `vector.contract`.
Adding the accumulator value after the `vector.contract` changes the
precision of the operation. This makes sure the accumulator is carried
through to `vector.reduce` (and down to LLVM).
Differential Revision: https://reviews.llvm.org/D128674
Diffstat (limited to 'mlir/lib/Conversion/VectorToLLVM')
-rw-r--r-- | mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp | 119 |
1 files changed, 90 insertions, 29 deletions
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp index a164c7d167dc..fa4920486aad 100644 --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp @@ -362,6 +362,37 @@ public: } }; +/// Helper method to lower a `vector.reduction` op that performs an arithmetic +/// operation like add,mul, etc.. `VectorOp` is the LLVM vector intrinsic to use +/// and `ScalarOp` is the scalar operation used to add the accumulation value if +/// non-null. +template <class VectorOp, class ScalarOp> +static Value createIntegerReductionArithmeticOpLowering( + ConversionPatternRewriter &rewriter, Location loc, Type llvmType, + Value vectorOperand, Value accumulator) { + Value result = rewriter.create<VectorOp>(loc, llvmType, vectorOperand); + if (accumulator) + result = rewriter.create<ScalarOp>(loc, accumulator, result); + return result; +} + +/// Helper method to lower a `vector.reduction` operation that performs +/// a comparison operation like `min`/`max`. `VectorOp` is the LLVM vector +/// intrinsic to use and `predicate` is the predicate to use to compare+combine +/// the accumulator value if non-null. +template <class VectorOp> +static Value createIntegerReductionComparisonOpLowering( + ConversionPatternRewriter &rewriter, Location loc, Type llvmType, + Value vectorOperand, Value accumulator, LLVM::ICmpPredicate predicate) { + Value result = rewriter.create<VectorOp>(loc, llvmType, vectorOperand); + if (accumulator) { + Value cmp = + rewriter.create<LLVM::ICmpOp>(loc, predicate, accumulator, result); + result = rewriter.create<LLVM::SelectOp>(loc, cmp, accumulator, result); + } + return result; +} + /// Conversion pattern for all vector reductions. class VectorReductionOpConversion : public ConvertOpToLLVMPattern<vector::ReductionOp> { @@ -377,38 +408,68 @@ public: auto kind = reductionOp.getKind(); Type eltType = reductionOp.getDest().getType(); Type llvmType = typeConverter->convertType(eltType); - Value operand = adaptor.getOperands()[0]; + Value operand = adaptor.getVector(); + Value acc = adaptor.getAcc(); + Location loc = reductionOp.getLoc(); if (eltType.isIntOrIndex()) { // Integer reductions: add/mul/min/max/and/or/xor. - if (kind == vector::CombiningKind::ADD) - rewriter.replaceOpWithNewOp<LLVM::vector_reduce_add>(reductionOp, - llvmType, operand); - else if (kind == vector::CombiningKind::MUL) - rewriter.replaceOpWithNewOp<LLVM::vector_reduce_mul>(reductionOp, - llvmType, operand); - else if (kind == vector::CombiningKind::MINUI) - rewriter.replaceOpWithNewOp<LLVM::vector_reduce_umin>( - reductionOp, llvmType, operand); - else if (kind == vector::CombiningKind::MINSI) - rewriter.replaceOpWithNewOp<LLVM::vector_reduce_smin>( - reductionOp, llvmType, operand); - else if (kind == vector::CombiningKind::MAXUI) - rewriter.replaceOpWithNewOp<LLVM::vector_reduce_umax>( - reductionOp, llvmType, operand); - else if (kind == vector::CombiningKind::MAXSI) - rewriter.replaceOpWithNewOp<LLVM::vector_reduce_smax>( - reductionOp, llvmType, operand); - else if (kind == vector::CombiningKind::AND) - rewriter.replaceOpWithNewOp<LLVM::vector_reduce_and>(reductionOp, - llvmType, operand); - else if (kind == vector::CombiningKind::OR) - rewriter.replaceOpWithNewOp<LLVM::vector_reduce_or>(reductionOp, - llvmType, operand); - else if (kind == vector::CombiningKind::XOR) - rewriter.replaceOpWithNewOp<LLVM::vector_reduce_xor>(reductionOp, - llvmType, operand); - else + Value result; + switch (kind) { + case vector::CombiningKind::ADD: + result = + createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_add, + LLVM::AddOp>( + rewriter, loc, llvmType, operand, acc); + break; + case vector::CombiningKind::MUL: + result = + createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_mul, + LLVM::MulOp>( + rewriter, loc, llvmType, operand, acc); + break; + case vector::CombiningKind::MINUI: + result = createIntegerReductionComparisonOpLowering< + LLVM::vector_reduce_umin>(rewriter, loc, llvmType, operand, acc, + LLVM::ICmpPredicate::ule); + break; + case vector::CombiningKind::MINSI: + result = createIntegerReductionComparisonOpLowering< + LLVM::vector_reduce_smin>(rewriter, loc, llvmType, operand, acc, + LLVM::ICmpPredicate::sle); + break; + case vector::CombiningKind::MAXUI: + result = createIntegerReductionComparisonOpLowering< + LLVM::vector_reduce_umax>(rewriter, loc, llvmType, operand, acc, + LLVM::ICmpPredicate::uge); + break; + case vector::CombiningKind::MAXSI: + result = createIntegerReductionComparisonOpLowering< + LLVM::vector_reduce_smax>(rewriter, loc, llvmType, operand, acc, + LLVM::ICmpPredicate::sge); + break; + case vector::CombiningKind::AND: + result = + createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_and, + LLVM::AndOp>( + rewriter, loc, llvmType, operand, acc); + break; + case vector::CombiningKind::OR: + result = + createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_or, + LLVM::OrOp>( + rewriter, loc, llvmType, operand, acc); + break; + case vector::CombiningKind::XOR: + result = + createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_xor, + LLVM::XOrOp>( + rewriter, loc, llvmType, operand, acc); + break; + default: return failure(); + } + rewriter.replaceOp(reductionOp, result); + return success(); } |