summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorDavid Green <david.green@arm.com>2018-10-02 09:48:34 +0000
committerDavid Green <david.green@arm.com>2018-10-02 09:48:34 +0000
commit6cb442849ec031d9495b631ea13df410d83f7c6e (patch)
treec19bfd8a9a3240f441f86c3118e01d98ebd6b19a
parent8be7cc320c308901bbaac46920187328397b58f7 (diff)
[InstCombine] Fold ~A - Min/Max(~A, O) -> Max/Min(A, ~O) - A
This is an attempt to get out of a local minimum that instcombine currently gets stuck in. We essentially combine two optimisations at once, ~a - ~b = b-a and min(~a, ~b) = ~max(a, b), only doing the transform if the result is at least neutral. This involves using IsFreeToInvert, which has been expanded a little to include selects that can be easily inverted. This is trying to fix PR35875, using the ideas from Sanjay. It is a large improvement to one of our rgb to cmy kernels. Differential Revision: https://reviews.llvm.org/D52177
-rw-r--r--llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp33
-rw-r--r--llvm/lib/Transforms/InstCombine/InstCombineInternal.h9
-rw-r--r--llvm/test/Transforms/InstCombine/sub-minmax.ll58
3 files changed, 68 insertions, 32 deletions
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
index 910ec835c5a..a2567d1039c 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
@@ -1662,6 +1662,39 @@ Instruction *InstCombiner::visitSub(BinaryOperator &I) {
}
}
+ {
+ // ~A - Min/Max(~A, O) -> Max/Min(A, ~O) - A
+ // ~A - Min/Max(O, ~A) -> Max/Min(A, ~O) - A
+ // Min/Max(~A, O) - ~A -> A - Max/Min(A, ~O)
+ // Min/Max(O, ~A) - ~A -> A - Max/Min(A, ~O)
+ // So long as O here is freely invertible, this will be neutral or a win.
+ Value *LHS, *RHS, *A;
+ Value *NotA = Op0, *MinMax = Op1;
+ SelectPatternFlavor SPF = matchSelectPattern(MinMax, LHS, RHS).Flavor;
+ if (!SelectPatternResult::isMinOrMax(SPF)) {
+ NotA = Op1;
+ MinMax = Op0;
+ SPF = matchSelectPattern(MinMax, LHS, RHS).Flavor;
+ }
+ if (SelectPatternResult::isMinOrMax(SPF) &&
+ match(NotA, m_Not(m_Value(A))) && (NotA == LHS || NotA == RHS)) {
+ if (NotA == LHS)
+ std::swap(LHS, RHS);
+ // LHS is now O above and expected to have at least 2 uses (the min/max)
+ // NotA is expected to have 2 uses from the min/max and 1 from the sub.
+ if (IsFreeToInvert(LHS, !LHS->hasNUsesOrMore(3)) &&
+ !NotA->hasNUsesOrMore(4)) {
+ // Note: We don't generate the inverse max/min, just create the not of
+ // it and let other folds do the rest.
+ Value *Not = Builder.CreateNot(MinMax);
+ if (NotA == Op0)
+ return BinaryOperator::CreateSub(Not, A);
+ else
+ return BinaryOperator::CreateSub(A, Not);
+ }
+ }
+ }
+
// Optimize pointer differences into the same array into a size. Consider:
// &A[10] - &A[0]: we should compile this to "10".
Value *LHSOp, *RHSOp;
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
index 19114a6ab37..951fc22a913 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
+++ b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
@@ -20,7 +20,6 @@
#include "llvm/Analysis/AliasAnalysis.h"
#include "llvm/Analysis/InstructionSimplify.h"
#include "llvm/Analysis/TargetFolder.h"
-#include "llvm/Transforms/Utils/Local.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/IR/Argument.h"
#include "llvm/IR/BasicBlock.h"
@@ -33,6 +32,7 @@
#include "llvm/IR/Instruction.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/Intrinsics.h"
+#include "llvm/IR/PatternMatch.h"
#include "llvm/IR/Use.h"
#include "llvm/IR/Value.h"
#include "llvm/Support/Casting.h"
@@ -41,11 +41,14 @@
#include "llvm/Support/KnownBits.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Transforms/InstCombine/InstCombineWorklist.h"
+#include "llvm/Transforms/Utils/Local.h"
#include <cassert>
#include <cstdint>
#define DEBUG_TYPE "instcombine"
+using namespace llvm::PatternMatch;
+
namespace llvm {
class APInt;
@@ -175,6 +178,10 @@ static inline bool IsFreeToInvert(Value *V, bool WillInvertAllUses) {
if (isa<Constant>(BO->getOperand(0)) || isa<Constant>(BO->getOperand(1)))
return WillInvertAllUses;
+ // Selects with invertible operands are freely invertible
+ if (match(V, m_Select(m_Value(), m_Not(m_Value()), m_Not(m_Value()))))
+ return WillInvertAllUses;
+
return false;
}
diff --git a/llvm/test/Transforms/InstCombine/sub-minmax.ll b/llvm/test/Transforms/InstCombine/sub-minmax.ll
index 43a4cf18358..ccc3483ce95 100644
--- a/llvm/test/Transforms/InstCombine/sub-minmax.ll
+++ b/llvm/test/Transforms/InstCombine/sub-minmax.ll
@@ -125,10 +125,10 @@ define i32 @na_minus_max_bi_na(i32 %A, i32 %Bi) {
define i32 @max_na_bi_minux_na_use(i32 %A, i32 %Bi) {
; CHECK-LABEL: @max_na_bi_minux_na_use(
-; CHECK-NEXT: [[NOT:%.*]] = xor i32 [[A:%.*]], -1
-; CHECK-NEXT: [[L0:%.*]] = icmp ult i32 [[NOT]], 31
-; CHECK-NEXT: [[L1:%.*]] = select i1 [[L0]], i32 [[NOT]], i32 31
-; CHECK-NEXT: [[X:%.*]] = sub i32 [[L1]], [[NOT]]
+; CHECK-NEXT: [[TMP1:%.*]] = icmp ugt i32 [[A:%.*]], -32
+; CHECK-NEXT: [[TMP2:%.*]] = select i1 [[TMP1]], i32 [[A]], i32 -32
+; CHECK-NEXT: [[L1:%.*]] = xor i32 [[TMP2]], -1
+; CHECK-NEXT: [[X:%.*]] = sub i32 [[A]], [[TMP2]]
; CHECK-NEXT: call void @use32(i32 [[L1]])
; CHECK-NEXT: ret i32 [[X]]
;
@@ -142,10 +142,10 @@ define i32 @max_na_bi_minux_na_use(i32 %A, i32 %Bi) {
define i32 @na_minus_max_na_bi_use(i32 %A, i32 %Bi) {
; CHECK-LABEL: @na_minus_max_na_bi_use(
-; CHECK-NEXT: [[NOT:%.*]] = xor i32 [[A:%.*]], -1
-; CHECK-NEXT: [[L0:%.*]] = icmp ult i32 [[NOT]], 31
-; CHECK-NEXT: [[L1:%.*]] = select i1 [[L0]], i32 [[NOT]], i32 31
-; CHECK-NEXT: [[X:%.*]] = sub i32 [[NOT]], [[L1]]
+; CHECK-NEXT: [[TMP1:%.*]] = icmp ugt i32 [[A:%.*]], -32
+; CHECK-NEXT: [[TMP2:%.*]] = select i1 [[TMP1]], i32 [[A]], i32 -32
+; CHECK-NEXT: [[L1:%.*]] = xor i32 [[TMP2]], -1
+; CHECK-NEXT: [[X:%.*]] = sub i32 [[TMP2]], [[A]]
; CHECK-NEXT: call void @use32(i32 [[L1]])
; CHECK-NEXT: ret i32 [[X]]
;
@@ -276,12 +276,11 @@ define i32 @na_minus_max_bi_na_use2(i32 %A, i32 %Bi) {
define i8 @umin_not_sub(i8 %x, i8 %y) {
; CHECK-LABEL: @umin_not_sub(
-; CHECK-NEXT: [[NX:%.*]] = xor i8 [[X:%.*]], -1
-; CHECK-NEXT: [[NY:%.*]] = xor i8 [[Y:%.*]], -1
-; CHECK-NEXT: [[CMPXY:%.*]] = icmp ult i8 [[NX]], [[NY]]
-; CHECK-NEXT: [[MINXY:%.*]] = select i1 [[CMPXY]], i8 [[NX]], i8 [[NY]]
-; CHECK-NEXT: [[SUBX:%.*]] = sub i8 [[NX]], [[MINXY]]
-; CHECK-NEXT: [[SUBY:%.*]] = sub i8 [[NY]], [[MINXY]]
+; CHECK-NEXT: [[TMP1:%.*]] = icmp ult i8 [[Y:%.*]], [[X:%.*]]
+; CHECK-NEXT: [[TMP2:%.*]] = select i1 [[TMP1]], i8 [[X]], i8 [[Y]]
+; CHECK-NEXT: [[MINXY:%.*]] = xor i8 [[TMP2]], -1
+; CHECK-NEXT: [[SUBX:%.*]] = sub i8 [[TMP2]], [[X]]
+; CHECK-NEXT: [[SUBY:%.*]] = sub i8 [[TMP2]], [[Y]]
; CHECK-NEXT: call void @use8(i8 [[SUBX]])
; CHECK-NEXT: call void @use8(i8 [[SUBY]])
; CHECK-NEXT: ret i8 [[MINXY]]
@@ -299,12 +298,11 @@ define i8 @umin_not_sub(i8 %x, i8 %y) {
define i8 @umin_not_sub_rev(i8 %x, i8 %y) {
; CHECK-LABEL: @umin_not_sub_rev(
-; CHECK-NEXT: [[NX:%.*]] = xor i8 [[X:%.*]], -1
-; CHECK-NEXT: [[NY:%.*]] = xor i8 [[Y:%.*]], -1
-; CHECK-NEXT: [[CMPXY:%.*]] = icmp ult i8 [[NX]], [[NY]]
-; CHECK-NEXT: [[MINXY:%.*]] = select i1 [[CMPXY]], i8 [[NX]], i8 [[NY]]
-; CHECK-NEXT: [[SUBX:%.*]] = sub i8 [[MINXY]], [[NX]]
-; CHECK-NEXT: [[SUBY:%.*]] = sub i8 [[MINXY]], [[NY]]
+; CHECK-NEXT: [[TMP1:%.*]] = icmp ult i8 [[Y:%.*]], [[X:%.*]]
+; CHECK-NEXT: [[TMP2:%.*]] = select i1 [[TMP1]], i8 [[X]], i8 [[Y]]
+; CHECK-NEXT: [[MINXY:%.*]] = xor i8 [[TMP2]], -1
+; CHECK-NEXT: [[SUBX:%.*]] = sub i8 [[X]], [[TMP2]]
+; CHECK-NEXT: [[SUBY:%.*]] = sub i8 [[Y]], [[TMP2]]
; CHECK-NEXT: call void @use8(i8 [[SUBX]])
; CHECK-NEXT: call void @use8(i8 [[SUBY]])
; CHECK-NEXT: ret i8 [[MINXY]]
@@ -322,17 +320,15 @@ define i8 @umin_not_sub_rev(i8 %x, i8 %y) {
define void @umin3_not_all_ops_extra_uses_invert_subs(i8 %x, i8 %y, i8 %z) {
; CHECK-LABEL: @umin3_not_all_ops_extra_uses_invert_subs(
-; CHECK-NEXT: [[XN:%.*]] = xor i8 [[X:%.*]], -1
-; CHECK-NEXT: [[YN:%.*]] = xor i8 [[Y:%.*]], -1
-; CHECK-NEXT: [[ZN:%.*]] = xor i8 [[Z:%.*]], -1
-; CHECK-NEXT: [[CMPXZ:%.*]] = icmp ult i8 [[XN]], [[ZN]]
-; CHECK-NEXT: [[MINXZ:%.*]] = select i1 [[CMPXZ]], i8 [[XN]], i8 [[ZN]]
-; CHECK-NEXT: [[CMPXYZ:%.*]] = icmp ult i8 [[MINXZ]], [[YN]]
-; CHECK-NEXT: [[MINXYZ:%.*]] = select i1 [[CMPXYZ]], i8 [[MINXZ]], i8 [[YN]]
-; CHECK-NEXT: [[XMIN:%.*]] = sub i8 [[XN]], [[MINXYZ]]
-; CHECK-NEXT: [[YMIN:%.*]] = sub i8 [[YN]], [[MINXYZ]]
-; CHECK-NEXT: [[ZMIN:%.*]] = sub i8 [[ZN]], [[MINXYZ]]
-; CHECK-NEXT: call void @use8(i8 [[MINXYZ]])
+; CHECK-NEXT: [[TMP1:%.*]] = icmp ugt i8 [[X:%.*]], [[Z:%.*]]
+; CHECK-NEXT: [[TMP2:%.*]] = select i1 [[TMP1]], i8 [[X]], i8 [[Z]]
+; CHECK-NEXT: [[TMP3:%.*]] = icmp ugt i8 [[TMP2]], [[Y:%.*]]
+; CHECK-NEXT: [[TMP4:%.*]] = select i1 [[TMP3]], i8 [[TMP2]], i8 [[Y]]
+; CHECK-NEXT: [[TMP5:%.*]] = xor i8 [[TMP4]], -1
+; CHECK-NEXT: [[XMIN:%.*]] = sub i8 [[TMP4]], [[X]]
+; CHECK-NEXT: [[YMIN:%.*]] = sub i8 [[TMP4]], [[Y]]
+; CHECK-NEXT: [[ZMIN:%.*]] = sub i8 [[TMP4]], [[Z]]
+; CHECK-NEXT: call void @use8(i8 [[TMP5]])
; CHECK-NEXT: call void @use8(i8 [[XMIN]])
; CHECK-NEXT: call void @use8(i8 [[YMIN]])
; CHECK-NEXT: call void @use8(i8 [[ZMIN]])