aboutsummaryrefslogtreecommitdiff
path: root/mlir/lib/Conversion/VectorToLLVM
diff options
context:
space:
mode:
authorThomas Raoux <thomasraoux@google.com>2022-07-12 21:38:10 +0000
committerThomas Raoux <thomasraoux@google.com>2022-07-12 22:03:39 +0000
commit8fe076ffe09028bef761b0d0ebdd5842c595ca87 (patch)
tree0d711dc2a3fb069ef9c2cb96527c12179cc1a23f /mlir/lib/Conversion/VectorToLLVM
parent7f3000fa8b321f7fae169a615734de74a737b5d4 (diff)
[mlir][VectorToLLVM] Fix bug in lowering of vector.reduce fmax/fmin
The lowering of fmax/fmin reduce was ignoring the optional accumulator. Differential Revision: https://reviews.llvm.org/D129597
Diffstat (limited to 'mlir/lib/Conversion/VectorToLLVM')
-rw-r--r--mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp43
1 files changed, 36 insertions, 7 deletions
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index fa4920486aad..bf9967f0001d 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -16,6 +16,7 @@
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/TypeUtilities.h"
#include "mlir/Support/MathExtras.h"
#include "mlir/Target/LLVMIR/TypeToLLVM.h"
#include "mlir/Transforms/DialectConversion.h"
@@ -393,6 +394,27 @@ static Value createIntegerReductionComparisonOpLowering(
return result;
}
+/// Create lowering of minf/maxf op. We cannot use llvm.maximum/llvm.minimum
+/// with vector types.
+static Value createMinMaxF(OpBuilder &builder, Location loc, Value lhs,
+ Value rhs, bool isMin) {
+ auto floatType = getElementTypeOrSelf(lhs.getType()).cast<FloatType>();
+ Type i1Type = builder.getI1Type();
+ if (auto vecType = lhs.getType().dyn_cast<VectorType>())
+ i1Type = VectorType::get(vecType.getShape(), i1Type);
+ Value cmp = builder.create<LLVM::FCmpOp>(
+ loc, i1Type, isMin ? LLVM::FCmpPredicate::olt : LLVM::FCmpPredicate::ogt,
+ lhs, rhs);
+ Value sel = builder.create<LLVM::SelectOp>(loc, cmp, lhs, rhs);
+ Value isNan = builder.create<LLVM::FCmpOp>(
+ loc, i1Type, LLVM::FCmpPredicate::uno, lhs, rhs);
+ Value nan = builder.create<LLVM::ConstantOp>(
+ loc, lhs.getType(),
+ builder.getFloatAttr(floatType,
+ APFloat::getQNaN(floatType.getFloatSemantics())));
+ return builder.create<LLVM::SelectOp>(loc, isNan, nan, sel);
+}
+
/// Conversion pattern for all vector reductions.
class VectorReductionOpConversion
: public ConvertOpToLLVMPattern<vector::ReductionOp> {
@@ -497,18 +519,25 @@ public:
rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmul>(
reductionOp, llvmType, acc, operand,
rewriter.getBoolAttr(reassociateFPReductions));
- } else if (kind == vector::CombiningKind::MINF)
+ } else if (kind == vector::CombiningKind::MINF) {
// FIXME: MLIR's 'minf' and LLVM's 'vector_reduce_fmin' do not handle
// NaNs/-0.0/+0.0 in the same way.
- rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmin>(reductionOp,
- llvmType, operand);
- else if (kind == vector::CombiningKind::MAXF)
+ Value result =
+ rewriter.create<LLVM::vector_reduce_fmin>(loc, llvmType, operand);
+ if (acc)
+ result = createMinMaxF(rewriter, loc, result, acc, /*isMin=*/true);
+ rewriter.replaceOp(reductionOp, result);
+ } else if (kind == vector::CombiningKind::MAXF) {
// FIXME: MLIR's 'maxf' and LLVM's 'vector_reduce_fmax' do not handle
// NaNs/-0.0/+0.0 in the same way.
- rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmax>(reductionOp,
- llvmType, operand);
- else
+ Value result =
+ rewriter.create<LLVM::vector_reduce_fmax>(loc, llvmType, operand);
+ if (acc)
+ result = createMinMaxF(rewriter, loc, result, acc, /*isMin=*/false);
+ rewriter.replaceOp(reductionOp, result);
+ } else
return failure();
+
return success();
}