diff --git a/src/cmd/compile/internal/arm64/ssa.go b/src/cmd/compile/internal/arm64/ssa.go index 736823b719..2a58738ffe 100644 --- a/src/cmd/compile/internal/arm64/ssa.go +++ b/src/cmd/compile/internal/arm64/ssa.go @@ -559,7 +559,11 @@ func ssaGenValue(s *gc.SSAGenState, v *ssa.Value) { ssa.OpARM64RBIT, ssa.OpARM64RBITW, ssa.OpARM64CLZ, - ssa.OpARM64CLZW: + ssa.OpARM64CLZW, + ssa.OpARM64FRINTAD, + ssa.OpARM64FRINTMD, + ssa.OpARM64FRINTPD, + ssa.OpARM64FRINTZD: p := s.Prog(v.Op.Asm()) p.From.Type = obj.TYPE_REG p.From.Reg = v.Args[0].Reg() diff --git a/src/cmd/compile/internal/gc/asm_test.go b/src/cmd/compile/internal/gc/asm_test.go index 0bea56ea3c..0f41f3044d 100644 --- a/src/cmd/compile/internal/gc/asm_test.go +++ b/src/cmd/compile/internal/gc/asm_test.go @@ -248,7 +248,7 @@ var allAsmTests = []*asmTests{ { arch: "arm64", os: "linux", - imports: []string{"encoding/binary", "math/bits"}, + imports: []string{"encoding/binary", "math", "math/bits"}, tests: linuxARM64Tests, }, { @@ -2849,6 +2849,47 @@ var linuxARM64Tests = []*asmTest{ pos: []string{"\tMOVHU\t\\(R[0-9]+\\)"}, neg: []string{"ORR\tR[0-9]+<<8\t"}, }, + // Intrinsic tests for math. + { + fn: ` + func sqrt(x float64) float64 { + return math.Sqrt(x) + } + `, + pos: []string{"FSQRTD"}, + }, + { + fn: ` + func ceil(x float64) float64 { + return math.Ceil(x) + } + `, + pos: []string{"FRINTPD"}, + }, + { + fn: ` + func floor(x float64) float64 { + return math.Floor(x) + } + `, + pos: []string{"FRINTMD"}, + }, + { + fn: ` + func round(x float64) float64 { + return math.Round(x) + } + `, + pos: []string{"FRINTAD"}, + }, + { + fn: ` + func trunc(x float64) float64 { + return math.Trunc(x) + } + `, + pos: []string{"FRINTZD"}, + }, } var linuxMIPSTests = []*asmTest{ diff --git a/src/cmd/compile/internal/gc/ssa.go b/src/cmd/compile/internal/gc/ssa.go index e369e0c5d0..542acb6da2 100644 --- a/src/cmd/compile/internal/gc/ssa.go +++ b/src/cmd/compile/internal/gc/ssa.go @@ -2918,22 +2918,22 @@ func init() { func(s *state, n *Node, args []*ssa.Value) *ssa.Value { return s.newValue1(ssa.OpTrunc, types.Types[TFLOAT64], args[0]) }, - sys.PPC64, sys.S390X) + sys.ARM64, sys.PPC64, sys.S390X) addF("math", "Ceil", func(s *state, n *Node, args []*ssa.Value) *ssa.Value { return s.newValue1(ssa.OpCeil, types.Types[TFLOAT64], args[0]) }, - sys.PPC64, sys.S390X) + sys.ARM64, sys.PPC64, sys.S390X) addF("math", "Floor", func(s *state, n *Node, args []*ssa.Value) *ssa.Value { return s.newValue1(ssa.OpFloor, types.Types[TFLOAT64], args[0]) }, - sys.PPC64, sys.S390X) + sys.ARM64, sys.PPC64, sys.S390X) addF("math", "Round", func(s *state, n *Node, args []*ssa.Value) *ssa.Value { return s.newValue1(ssa.OpRound, types.Types[TFLOAT64], args[0]) }, - sys.S390X) + sys.ARM64, sys.S390X) addF("math", "RoundToEven", func(s *state, n *Node, args []*ssa.Value) *ssa.Value { return s.newValue1(ssa.OpRoundToEven, types.Types[TFLOAT64], args[0]) diff --git a/src/cmd/compile/internal/ssa/gen/ARM64.rules b/src/cmd/compile/internal/ssa/gen/ARM64.rules index 44c1c84288..646b983a64 100644 --- a/src/cmd/compile/internal/ssa/gen/ARM64.rules +++ b/src/cmd/compile/internal/ssa/gen/ARM64.rules @@ -81,7 +81,12 @@ (Com16 x) -> (MVN x) (Com8 x) -> (MVN x) +// math package intrinsics (Sqrt x) -> (FSQRTD x) +(Ceil x) -> (FRINTPD x) +(Floor x) -> (FRINTMD x) +(Round x) -> (FRINTAD x) +(Trunc x) -> (FRINTZD x) (Ctz64 x) -> (CLZ (RBIT x)) (Ctz32 x) -> (CLZW (RBITW x)) diff --git a/src/cmd/compile/internal/ssa/gen/ARM64Ops.go b/src/cmd/compile/internal/ssa/gen/ARM64Ops.go index f053109fb3..bed0fb3ccf 100644 --- a/src/cmd/compile/internal/ssa/gen/ARM64Ops.go +++ b/src/cmd/compile/internal/ssa/gen/ARM64Ops.go @@ -323,6 +323,12 @@ func init() { {name: "FCVTSD", argLength: 1, reg: fp11, asm: "FCVTSD"}, // float32 -> float64 {name: "FCVTDS", argLength: 1, reg: fp11, asm: "FCVTDS"}, // float64 -> float32 + // floating-point round to integral + {name: "FRINTAD", argLength: 1, reg: fp11, asm: "FRINTAD"}, + {name: "FRINTMD", argLength: 1, reg: fp11, asm: "FRINTMD"}, + {name: "FRINTPD", argLength: 1, reg: fp11, asm: "FRINTPD"}, + {name: "FRINTZD", argLength: 1, reg: fp11, asm: "FRINTZD"}, + // conditional instructions {name: "CSELULT", argLength: 3, reg: gp2flags1, asm: "CSEL"}, // returns arg0 if flags indicates unsigned LT, arg1 otherwise, arg2=flags {name: "CSELULT0", argLength: 2, reg: gp1flags1, asm: "CSEL"}, // returns arg0 if flags indicates unsigned LT, 0 otherwise, arg1=flags diff --git a/src/cmd/compile/internal/ssa/opGen.go b/src/cmd/compile/internal/ssa/opGen.go index ee95e77d5a..fad17c2acd 100644 --- a/src/cmd/compile/internal/ssa/opGen.go +++ b/src/cmd/compile/internal/ssa/opGen.go @@ -1093,6 +1093,10 @@ const ( OpARM64FCVTZUD OpARM64FCVTSD OpARM64FCVTDS + OpARM64FRINTAD + OpARM64FRINTMD + OpARM64FRINTPD + OpARM64FRINTZD OpARM64CSELULT OpARM64CSELULT0 OpARM64CALLstatic @@ -13971,6 +13975,58 @@ var opcodeTable = [...]opInfo{ }, }, }, + { + name: "FRINTAD", + argLen: 1, + asm: arm64.AFRINTAD, + reg: regInfo{ + inputs: []inputInfo{ + {0, 9223372034707292160}, // F0 F1 F2 F3 F4 F5 F6 F7 F8 F9 F10 F11 F12 F13 F14 F15 F16 F17 F18 F19 F20 F21 F22 F23 F24 F25 F26 F27 F28 F29 F30 F31 + }, + outputs: []outputInfo{ + {0, 9223372034707292160}, // F0 F1 F2 F3 F4 F5 F6 F7 F8 F9 F10 F11 F12 F13 F14 F15 F16 F17 F18 F19 F20 F21 F22 F23 F24 F25 F26 F27 F28 F29 F30 F31 + }, + }, + }, + { + name: "FRINTMD", + argLen: 1, + asm: arm64.AFRINTMD, + reg: regInfo{ + inputs: []inputInfo{ + {0, 9223372034707292160}, // F0 F1 F2 F3 F4 F5 F6 F7 F8 F9 F10 F11 F12 F13 F14 F15 F16 F17 F18 F19 F20 F21 F22 F23 F24 F25 F26 F27 F28 F29 F30 F31 + }, + outputs: []outputInfo{ + {0, 9223372034707292160}, // F0 F1 F2 F3 F4 F5 F6 F7 F8 F9 F10 F11 F12 F13 F14 F15 F16 F17 F18 F19 F20 F21 F22 F23 F24 F25 F26 F27 F28 F29 F30 F31 + }, + }, + }, + { + name: "FRINTPD", + argLen: 1, + asm: arm64.AFRINTPD, + reg: regInfo{ + inputs: []inputInfo{ + {0, 9223372034707292160}, // F0 F1 F2 F3 F4 F5 F6 F7 F8 F9 F10 F11 F12 F13 F14 F15 F16 F17 F18 F19 F20 F21 F22 F23 F24 F25 F26 F27 F28 F29 F30 F31 + }, + outputs: []outputInfo{ + {0, 9223372034707292160}, // F0 F1 F2 F3 F4 F5 F6 F7 F8 F9 F10 F11 F12 F13 F14 F15 F16 F17 F18 F19 F20 F21 F22 F23 F24 F25 F26 F27 F28 F29 F30 F31 + }, + }, + }, + { + name: "FRINTZD", + argLen: 1, + asm: arm64.AFRINTZD, + reg: regInfo{ + inputs: []inputInfo{ + {0, 9223372034707292160}, // F0 F1 F2 F3 F4 F5 F6 F7 F8 F9 F10 F11 F12 F13 F14 F15 F16 F17 F18 F19 F20 F21 F22 F23 F24 F25 F26 F27 F28 F29 F30 F31 + }, + outputs: []outputInfo{ + {0, 9223372034707292160}, // F0 F1 F2 F3 F4 F5 F6 F7 F8 F9 F10 F11 F12 F13 F14 F15 F16 F17 F18 F19 F20 F21 F22 F23 F24 F25 F26 F27 F28 F29 F30 F31 + }, + }, + }, { name: "CSELULT", argLen: 3, diff --git a/src/cmd/compile/internal/ssa/rewriteARM64.go b/src/cmd/compile/internal/ssa/rewriteARM64.go index 546ce9067e..db30fd3ba5 100644 --- a/src/cmd/compile/internal/ssa/rewriteARM64.go +++ b/src/cmd/compile/internal/ssa/rewriteARM64.go @@ -289,6 +289,8 @@ func rewriteValueARM64(v *Value) bool { return rewriteValueARM64_OpBswap32_0(v) case OpBswap64: return rewriteValueARM64_OpBswap64_0(v) + case OpCeil: + return rewriteValueARM64_OpCeil_0(v) case OpClosureCall: return rewriteValueARM64_OpClosureCall_0(v) case OpCom16: @@ -393,6 +395,8 @@ func rewriteValueARM64(v *Value) bool { return rewriteValueARM64_OpEqB_0(v) case OpEqPtr: return rewriteValueARM64_OpEqPtr_0(v) + case OpFloor: + return rewriteValueARM64_OpFloor_0(v) case OpGeq16: return rewriteValueARM64_OpGeq16_0(v) case OpGeq16U: @@ -607,6 +611,8 @@ func rewriteValueARM64(v *Value) bool { return rewriteValueARM64_OpPopCount32_0(v) case OpPopCount64: return rewriteValueARM64_OpPopCount64_0(v) + case OpRound: + return rewriteValueARM64_OpRound_0(v) case OpRound32F: return rewriteValueARM64_OpRound32F_0(v) case OpRound64F: @@ -709,6 +715,8 @@ func rewriteValueARM64(v *Value) bool { return rewriteValueARM64_OpSub8_0(v) case OpSubPtr: return rewriteValueARM64_OpSubPtr_0(v) + case OpTrunc: + return rewriteValueARM64_OpTrunc_0(v) case OpTrunc16to8: return rewriteValueARM64_OpTrunc16to8_0(v) case OpTrunc32to16: @@ -11318,6 +11326,17 @@ func rewriteValueARM64_OpBswap64_0(v *Value) bool { return true } } +func rewriteValueARM64_OpCeil_0(v *Value) bool { + // match: (Ceil x) + // cond: + // result: (FRINTPD x) + for { + x := v.Args[0] + v.reset(OpARM64FRINTPD) + v.AddArg(x) + return true + } +} func rewriteValueARM64_OpClosureCall_0(v *Value) bool { // match: (ClosureCall [argwid] entry closure mem) // cond: @@ -12044,6 +12063,17 @@ func rewriteValueARM64_OpEqPtr_0(v *Value) bool { return true } } +func rewriteValueARM64_OpFloor_0(v *Value) bool { + // match: (Floor x) + // cond: + // result: (FRINTMD x) + for { + x := v.Args[0] + v.reset(OpARM64FRINTMD) + v.AddArg(x) + return true + } +} func rewriteValueARM64_OpGeq16_0(v *Value) bool { b := v.Block _ = b @@ -14717,6 +14747,17 @@ func rewriteValueARM64_OpPopCount64_0(v *Value) bool { return true } } +func rewriteValueARM64_OpRound_0(v *Value) bool { + // match: (Round x) + // cond: + // result: (FRINTAD x) + for { + x := v.Args[0] + v.reset(OpARM64FRINTAD) + v.AddArg(x) + return true + } +} func rewriteValueARM64_OpRound32F_0(v *Value) bool { // match: (Round32F x) // cond: @@ -16079,6 +16120,17 @@ func rewriteValueARM64_OpSubPtr_0(v *Value) bool { return true } } +func rewriteValueARM64_OpTrunc_0(v *Value) bool { + // match: (Trunc x) + // cond: + // result: (FRINTZD x) + for { + x := v.Args[0] + v.reset(OpARM64FRINTZD) + v.AddArg(x) + return true + } +} func rewriteValueARM64_OpTrunc16to8_0(v *Value) bool { // match: (Trunc16to8 x) // cond: