aboutsummaryrefslogtreecommitdiff
path: root/mlir/lib/Conversion/VectorToLLVM
diff options
context:
space:
mode:
authorMahesh Ravishankar <ravishankarm@google.com>2022-06-23 21:06:45 +0000
committerMahesh Ravishankar <ravishankarm@google.com>2022-06-28 05:26:39 +0000
commitfa596c6921159af50e69cc3be189d951521a9eb9 (patch)
treeb3b24b76f7c405478ac6cb097aa4dd9844282317 /mlir/lib/Conversion/VectorToLLVM
parentb941857b40edd7f3f3a9ec2ec85a26db24739774 (diff)
[mlir][Vector] Fix reordering of floating point adds during lowering of `vector.contract`.
Adding the accumulator value after the `vector.contract` changes the precision of the operation. This makes sure the accumulator is carried through to `vector.reduction` (and down to LLVM). Differential Revision: https://reviews.llvm.org/D128674
Diffstat (limited to 'mlir/lib/Conversion/VectorToLLVM')
-rw-r--r--mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp119
1 files changed, 90 insertions, 29 deletions
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index a164c7d167dc..fa4920486aad 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -362,6 +362,37 @@ public:
}
};
+/// Helper method to lower a `vector.reduction` op that performs an arithmetic
+/// operation like add, mul, etc. `VectorOp` is the LLVM vector intrinsic that
+/// reduces `vectorOperand`; `ScalarOp` is the scalar operation used to combine
+/// that result with `accumulator` (as the first operand) when it is non-null.
+template <class VectorOp, class ScalarOp>
+static Value createIntegerReductionArithmeticOpLowering(
+ ConversionPatternRewriter &rewriter, Location loc, Type llvmType,
+ Value vectorOperand, Value accumulator) {
+ Value result = rewriter.create<VectorOp>(loc, llvmType, vectorOperand);
+ if (accumulator)
+ result = rewriter.create<ScalarOp>(loc, accumulator, result);
+ return result;
+}
+
+/// Helper method to lower a `vector.reduction` operation that performs a
+/// comparison operation like `min`/`max`. `VectorOp` is the LLVM vector
+/// intrinsic that reduces `vectorOperand`. If `accumulator` is non-null, it is
+/// kept instead of the reduced value when `predicate(accumulator, reduced)`.
+template <class VectorOp>
+static Value createIntegerReductionComparisonOpLowering(
+ ConversionPatternRewriter &rewriter, Location loc, Type llvmType,
+ Value vectorOperand, Value accumulator, LLVM::ICmpPredicate predicate) {
+ Value result = rewriter.create<VectorOp>(loc, llvmType, vectorOperand);
+ if (accumulator) {
+ Value cmp =
+ rewriter.create<LLVM::ICmpOp>(loc, predicate, accumulator, result);
+ result = rewriter.create<LLVM::SelectOp>(loc, cmp, accumulator, result);
+ }
+ return result;
+}
+
/// Conversion pattern for all vector reductions.
class VectorReductionOpConversion
: public ConvertOpToLLVMPattern<vector::ReductionOp> {
@@ -377,38 +408,68 @@ public:
auto kind = reductionOp.getKind();
Type eltType = reductionOp.getDest().getType();
Type llvmType = typeConverter->convertType(eltType);
- Value operand = adaptor.getOperands()[0];
+ Value operand = adaptor.getVector();
+ Value acc = adaptor.getAcc();
+ Location loc = reductionOp.getLoc();
if (eltType.isIntOrIndex()) {
// Integer reductions: add/mul/min/max/and/or/xor.
- if (kind == vector::CombiningKind::ADD)
- rewriter.replaceOpWithNewOp<LLVM::vector_reduce_add>(reductionOp,
- llvmType, operand);
- else if (kind == vector::CombiningKind::MUL)
- rewriter.replaceOpWithNewOp<LLVM::vector_reduce_mul>(reductionOp,
- llvmType, operand);
- else if (kind == vector::CombiningKind::MINUI)
- rewriter.replaceOpWithNewOp<LLVM::vector_reduce_umin>(
- reductionOp, llvmType, operand);
- else if (kind == vector::CombiningKind::MINSI)
- rewriter.replaceOpWithNewOp<LLVM::vector_reduce_smin>(
- reductionOp, llvmType, operand);
- else if (kind == vector::CombiningKind::MAXUI)
- rewriter.replaceOpWithNewOp<LLVM::vector_reduce_umax>(
- reductionOp, llvmType, operand);
- else if (kind == vector::CombiningKind::MAXSI)
- rewriter.replaceOpWithNewOp<LLVM::vector_reduce_smax>(
- reductionOp, llvmType, operand);
- else if (kind == vector::CombiningKind::AND)
- rewriter.replaceOpWithNewOp<LLVM::vector_reduce_and>(reductionOp,
- llvmType, operand);
- else if (kind == vector::CombiningKind::OR)
- rewriter.replaceOpWithNewOp<LLVM::vector_reduce_or>(reductionOp,
- llvmType, operand);
- else if (kind == vector::CombiningKind::XOR)
- rewriter.replaceOpWithNewOp<LLVM::vector_reduce_xor>(reductionOp,
- llvmType, operand);
- else
+ Value result;
+ switch (kind) {
+ case vector::CombiningKind::ADD:
+ result =
+ createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_add,
+ LLVM::AddOp>(
+ rewriter, loc, llvmType, operand, acc);
+ break;
+ case vector::CombiningKind::MUL:
+ result =
+ createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_mul,
+ LLVM::MulOp>(
+ rewriter, loc, llvmType, operand, acc);
+ break;
+ case vector::CombiningKind::MINUI:
+ result = createIntegerReductionComparisonOpLowering<
+ LLVM::vector_reduce_umin>(rewriter, loc, llvmType, operand, acc,
+ LLVM::ICmpPredicate::ule);
+ break;
+ case vector::CombiningKind::MINSI:
+ result = createIntegerReductionComparisonOpLowering<
+ LLVM::vector_reduce_smin>(rewriter, loc, llvmType, operand, acc,
+ LLVM::ICmpPredicate::sle);
+ break;
+ case vector::CombiningKind::MAXUI:
+ result = createIntegerReductionComparisonOpLowering<
+ LLVM::vector_reduce_umax>(rewriter, loc, llvmType, operand, acc,
+ LLVM::ICmpPredicate::uge);
+ break;
+ case vector::CombiningKind::MAXSI:
+ result = createIntegerReductionComparisonOpLowering<
+ LLVM::vector_reduce_smax>(rewriter, loc, llvmType, operand, acc,
+ LLVM::ICmpPredicate::sge);
+ break;
+ case vector::CombiningKind::AND:
+ result =
+ createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_and,
+ LLVM::AndOp>(
+ rewriter, loc, llvmType, operand, acc);
+ break;
+ case vector::CombiningKind::OR:
+ result =
+ createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_or,
+ LLVM::OrOp>(
+ rewriter, loc, llvmType, operand, acc);
+ break;
+ case vector::CombiningKind::XOR:
+ result =
+ createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_xor,
+ LLVM::XOrOp>(
+ rewriter, loc, llvmType, operand, acc);
+ break;
+ default:
return failure();
+ }
+ rewriter.replaceOp(reductionOp, result);
+
return success();
}