diff --git a/src/cmd/compile/internal/arm64/ssa.go b/src/cmd/compile/internal/arm64/ssa.go index 22b28a9308..43588511ab 100644 --- a/src/cmd/compile/internal/arm64/ssa.go +++ b/src/cmd/compile/internal/arm64/ssa.go @@ -1054,7 +1054,11 @@ func ssaGenValue(s *gc.SSAGenState, v *ssa.Value) { ssa.OpARM64LessThanF, ssa.OpARM64LessEqualF, ssa.OpARM64GreaterThanF, - ssa.OpARM64GreaterEqualF: + ssa.OpARM64GreaterEqualF, + ssa.OpARM64NotLessThanF, + ssa.OpARM64NotLessEqualF, + ssa.OpARM64NotGreaterThanF, + ssa.OpARM64NotGreaterEqualF: // generate boolean values using CSET p := s.Prog(arm64.ACSET) p.From.Type = obj.TYPE_REG // assembler encodes conditional bits in Reg @@ -1098,10 +1102,16 @@ var condBits = map[ssa.Op]int16{ ssa.OpARM64GreaterThanU: arm64.COND_HI, ssa.OpARM64GreaterEqual: arm64.COND_GE, ssa.OpARM64GreaterEqualU: arm64.COND_HS, - ssa.OpARM64LessThanF: arm64.COND_MI, - ssa.OpARM64LessEqualF: arm64.COND_LS, - ssa.OpARM64GreaterThanF: arm64.COND_GT, - ssa.OpARM64GreaterEqualF: arm64.COND_GE, + ssa.OpARM64LessThanF: arm64.COND_MI, // Less than + ssa.OpARM64LessEqualF: arm64.COND_LS, // Less than or equal to + ssa.OpARM64GreaterThanF: arm64.COND_GT, // Greater than + ssa.OpARM64GreaterEqualF: arm64.COND_GE, // Greater than or equal to + + // The following condition codes have unordered to handle comparisons related to NaN. + ssa.OpARM64NotLessThanF: arm64.COND_PL, // Greater than, equal to, or unordered + ssa.OpARM64NotLessEqualF: arm64.COND_HI, // Greater than or unordered + ssa.OpARM64NotGreaterThanF: arm64.COND_LE, // Less than, equal to or unordered + ssa.OpARM64NotGreaterEqualF: arm64.COND_LT, // Less than or unordered } var blockJump = map[ssa.BlockKind]struct { diff --git a/src/cmd/compile/internal/ssa/gen/ARM64Ops.go b/src/cmd/compile/internal/ssa/gen/ARM64Ops.go index 87db2b7c9d..b0bc9c78ff 100644 --- a/src/cmd/compile/internal/ssa/gen/ARM64Ops.go +++ b/src/cmd/compile/internal/ssa/gen/ARM64Ops.go @@ -478,20 +478,24 @@ func init() { // pseudo-ops {name: "LoweredNilCheck", argLength: 2, reg: regInfo{inputs: []regMask{gpg}}, nilCheck: true, faultOnNilArg0: true}, // panic if arg0 is nil. arg1=mem. - {name: "Equal", argLength: 1, reg: readflags}, // bool, true flags encode x==y false otherwise. - {name: "NotEqual", argLength: 1, reg: readflags}, // bool, true flags encode x!=y false otherwise. - {name: "LessThan", argLength: 1, reg: readflags}, // bool, true flags encode signed xy false otherwise. - {name: "GreaterEqual", argLength: 1, reg: readflags}, // bool, true flags encode signed x>=y false otherwise. - {name: "LessThanU", argLength: 1, reg: readflags}, // bool, true flags encode unsigned xy false otherwise. - {name: "GreaterEqualU", argLength: 1, reg: readflags}, // bool, true flags encode unsigned x>=y false otherwise. - {name: "LessThanF", argLength: 1, reg: readflags}, // bool, true flags encode floating-point xy false otherwise. - {name: "GreaterEqualF", argLength: 1, reg: readflags}, // bool, true flags encode floating-point x>=y false otherwise. + {name: "Equal", argLength: 1, reg: readflags}, // bool, true flags encode x==y false otherwise. + {name: "NotEqual", argLength: 1, reg: readflags}, // bool, true flags encode x!=y false otherwise. + {name: "LessThan", argLength: 1, reg: readflags}, // bool, true flags encode signed xy false otherwise. + {name: "GreaterEqual", argLength: 1, reg: readflags}, // bool, true flags encode signed x>=y false otherwise. + {name: "LessThanU", argLength: 1, reg: readflags}, // bool, true flags encode unsigned xy false otherwise. + {name: "GreaterEqualU", argLength: 1, reg: readflags}, // bool, true flags encode unsigned x>=y false otherwise. + {name: "LessThanF", argLength: 1, reg: readflags}, // bool, true flags encode floating-point xy false otherwise. + {name: "GreaterEqualF", argLength: 1, reg: readflags}, // bool, true flags encode floating-point x>=y false otherwise. + {name: "NotLessThanF", argLength: 1, reg: readflags}, // bool, true flags encode floating-point x>=y || x is unordered with y, false otherwise. + {name: "NotLessEqualF", argLength: 1, reg: readflags}, // bool, true flags encode floating-point x>y || x is unordered with y, false otherwise. + {name: "NotGreaterThanF", argLength: 1, reg: readflags}, // bool, true flags encode floating-point x<=y || x is unordered with y, false otherwise. + {name: "NotGreaterEqualF", argLength: 1, reg: readflags}, // bool, true flags encode floating-point x NotEqual or LessThan -> GreaterEqual +// for example !Equal -> NotEqual or !LessThan -> GreaterEqual // -// TODO: add floating-point conditions +// For floating point, it's more subtle because NaN is unordered. We do +// !LessThanF -> NotLessThanF, the latter takes care of NaNs. func arm64Negate(op Op) Op { switch op { case OpARM64LessThan: @@ -1000,13 +1001,21 @@ func arm64Negate(op Op) Op { case OpARM64NotEqual: return OpARM64Equal case OpARM64LessThanF: - return OpARM64GreaterEqualF - case OpARM64GreaterThanF: - return OpARM64LessEqualF + return OpARM64NotLessThanF + case OpARM64NotLessThanF: + return OpARM64LessThanF case OpARM64LessEqualF: + return OpARM64NotLessEqualF + case OpARM64NotLessEqualF: + return OpARM64LessEqualF + case OpARM64GreaterThanF: + return OpARM64NotGreaterThanF + case OpARM64NotGreaterThanF: return OpARM64GreaterThanF case OpARM64GreaterEqualF: - return OpARM64LessThanF + return OpARM64NotGreaterEqualF + case OpARM64NotGreaterEqualF: + return OpARM64GreaterEqualF default: panic("unreachable") } @@ -1017,8 +1026,6 @@ func arm64Negate(op Op) Op { // that the same result would be produced if the arguments // to the flag-generating instruction were reversed, e.g. // (InvertFlags (CMP x y)) -> (CMP y x) -// -// TODO: add floating-point conditions func arm64Invert(op Op) Op { switch op { case OpARM64LessThan: @@ -1047,6 +1054,14 @@ func arm64Invert(op Op) Op { return OpARM64GreaterEqualF case OpARM64GreaterEqualF: return OpARM64LessEqualF + case OpARM64NotLessThanF: + return OpARM64NotGreaterThanF + case OpARM64NotGreaterThanF: + return OpARM64NotLessThanF + case OpARM64NotLessEqualF: + return OpARM64NotGreaterEqualF + case OpARM64NotGreaterEqualF: + return OpARM64NotLessEqualF default: panic("unreachable") } diff --git a/test/fixedbugs/issue43619.go b/test/fixedbugs/issue43619.go new file mode 100644 index 0000000000..3e667851a4 --- /dev/null +++ b/test/fixedbugs/issue43619.go @@ -0,0 +1,119 @@ +// run + +// Copyright 2021 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package main + +import ( + "fmt" + "math" +) + +//go:noinline +func fcmplt(a, b float64, x uint64) uint64 { + if a < b { + x = 0 + } + return x +} + +//go:noinline +func fcmple(a, b float64, x uint64) uint64 { + if a <= b { + x = 0 + } + return x +} + +//go:noinline +func fcmpgt(a, b float64, x uint64) uint64 { + if a > b { + x = 0 + } + return x +} + +//go:noinline +func fcmpge(a, b float64, x uint64) uint64 { + if a >= b { + x = 0 + } + return x +} + +//go:noinline +func fcmpeq(a, b float64, x uint64) uint64 { + if a == b { + x = 0 + } + return x +} + +//go:noinline +func fcmpne(a, b float64, x uint64) uint64 { + if a != b { + x = 0 + } + return x +} + +func main() { + type fn func(a, b float64, x uint64) uint64 + + type testCase struct { + f fn + a, b float64 + x, want uint64 + } + NaN := math.NaN() + for _, t := range []testCase{ + {fcmplt, 1.0, 1.0, 123, 123}, + {fcmple, 1.0, 1.0, 123, 0}, + {fcmpgt, 1.0, 1.0, 123, 123}, + {fcmpge, 1.0, 1.0, 123, 0}, + {fcmpeq, 1.0, 1.0, 123, 0}, + {fcmpne, 1.0, 1.0, 123, 123}, + + {fcmplt, 1.0, 2.0, 123, 0}, + {fcmple, 1.0, 2.0, 123, 0}, + {fcmpgt, 1.0, 2.0, 123, 123}, + {fcmpge, 1.0, 2.0, 123, 123}, + {fcmpeq, 1.0, 2.0, 123, 123}, + {fcmpne, 1.0, 2.0, 123, 0}, + + {fcmplt, 2.0, 1.0, 123, 123}, + {fcmple, 2.0, 1.0, 123, 123}, + {fcmpgt, 2.0, 1.0, 123, 0}, + {fcmpge, 2.0, 1.0, 123, 0}, + {fcmpeq, 2.0, 1.0, 123, 123}, + {fcmpne, 2.0, 1.0, 123, 0}, + + {fcmplt, 1.0, NaN, 123, 123}, + {fcmple, 1.0, NaN, 123, 123}, + {fcmpgt, 1.0, NaN, 123, 123}, + {fcmpge, 1.0, NaN, 123, 123}, + {fcmpeq, 1.0, NaN, 123, 123}, + {fcmpne, 1.0, NaN, 123, 0}, + + {fcmplt, NaN, 1.0, 123, 123}, + {fcmple, NaN, 1.0, 123, 123}, + {fcmpgt, NaN, 1.0, 123, 123}, + {fcmpge, NaN, 1.0, 123, 123}, + {fcmpeq, NaN, 1.0, 123, 123}, + {fcmpne, NaN, 1.0, 123, 0}, + + {fcmplt, NaN, NaN, 123, 123}, + {fcmple, NaN, NaN, 123, 123}, + {fcmpgt, NaN, NaN, 123, 123}, + {fcmpge, NaN, NaN, 123, 123}, + {fcmpeq, NaN, NaN, 123, 123}, + {fcmpne, NaN, NaN, 123, 0}, + } { + got := t.f(t.a, t.b, t.x) + if got != t.want { + panic(fmt.Sprintf("want %v, got %v", t.want, got)) + } + } +}