diff options
author | Thomas Raoux <thomasraoux@google.com> | 2022-07-12 21:38:10 +0000 |
---|---|---|
committer | Thomas Raoux <thomasraoux@google.com> | 2022-07-12 22:03:39 +0000 |
commit | 8fe076ffe09028bef761b0d0ebdd5842c595ca87 (patch) | |
tree | 0d711dc2a3fb069ef9c2cb96527c12179cc1a23f /mlir/lib/Conversion/VectorToLLVM | |
parent | 7f3000fa8b321f7fae169a615734de74a737b5d4 (diff) |
[mlir][VectorToLLVM] Fix bug in lowering of vector.reduce fmax/fmin
The lowering of fmax/fmin reduce was ignoring the optional accumulator.
Differential Revision: https://reviews.llvm.org/D129597
Diffstat (limited to 'mlir/lib/Conversion/VectorToLLVM')
-rw-r--r-- | mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp | 43 |
1 files changed, 36 insertions, 7 deletions
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp index fa4920486aad..bf9967f0001d 100644 --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp @@ -16,6 +16,7 @@ #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h" #include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/TypeUtilities.h" #include "mlir/Support/MathExtras.h" #include "mlir/Target/LLVMIR/TypeToLLVM.h" #include "mlir/Transforms/DialectConversion.h" @@ -393,6 +394,27 @@ static Value createIntegerReductionComparisonOpLowering( return result; } +/// Create lowering of minf/maxf op. We cannot use llvm.maximum/llvm.minimum +/// with vector types. +static Value createMinMaxF(OpBuilder &builder, Location loc, Value lhs, + Value rhs, bool isMin) { + auto floatType = getElementTypeOrSelf(lhs.getType()).cast<FloatType>(); + Type i1Type = builder.getI1Type(); + if (auto vecType = lhs.getType().dyn_cast<VectorType>()) + i1Type = VectorType::get(vecType.getShape(), i1Type); + Value cmp = builder.create<LLVM::FCmpOp>( + loc, i1Type, isMin ? LLVM::FCmpPredicate::olt : LLVM::FCmpPredicate::ogt, + lhs, rhs); + Value sel = builder.create<LLVM::SelectOp>(loc, cmp, lhs, rhs); + Value isNan = builder.create<LLVM::FCmpOp>( + loc, i1Type, LLVM::FCmpPredicate::uno, lhs, rhs); + Value nan = builder.create<LLVM::ConstantOp>( + loc, lhs.getType(), + builder.getFloatAttr(floatType, + APFloat::getQNaN(floatType.getFloatSemantics()))); + return builder.create<LLVM::SelectOp>(loc, isNan, nan, sel); +} + /// Conversion pattern for all vector reductions. class VectorReductionOpConversion : public ConvertOpToLLVMPattern<vector::ReductionOp> { @@ -497,18 +519,25 @@ public: rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmul>( reductionOp, llvmType, acc, operand, rewriter.getBoolAttr(reassociateFPReductions)); - } else if (kind == vector::CombiningKind::MINF) + } else if (kind == vector::CombiningKind::MINF) { // FIXME: MLIR's 'minf' and LLVM's 'vector_reduce_fmin' do not handle // NaNs/-0.0/+0.0 in the same way. - rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmin>(reductionOp, - llvmType, operand); - else if (kind == vector::CombiningKind::MAXF) + Value result = + rewriter.create<LLVM::vector_reduce_fmin>(loc, llvmType, operand); + if (acc) + result = createMinMaxF(rewriter, loc, result, acc, /*isMin=*/true); + rewriter.replaceOp(reductionOp, result); + } else if (kind == vector::CombiningKind::MAXF) { // FIXME: MLIR's 'maxf' and LLVM's 'vector_reduce_fmax' do not handle // NaNs/-0.0/+0.0 in the same way. - rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmax>(reductionOp, - llvmType, operand); - else + Value result = + rewriter.create<LLVM::vector_reduce_fmax>(loc, llvmType, operand); + if (acc) + result = createMinMaxF(rewriter, loc, result, acc, /*isMin=*/false); + rewriter.replaceOp(reductionOp, result); + } else return failure(); + return success(); } |