From c20c09251c37c60356e8457a3c7cb632c30b69b1 Mon Sep 17 00:00:00 2001 From: Evan Shaw Date: Tue, 3 Jan 2012 12:30:18 +1100 Subject: [PATCH] encoding/json: don't marshal special float values R=golang-dev, adg CC=golang-dev https://golang.org/cl/5500084 --- src/pkg/encoding/json/encode.go | 16 +++++++++++++++- src/pkg/encoding/json/encode_test.go | 19 +++++++++++++++++++ 2 files changed, 34 insertions(+), 1 deletion(-) diff --git a/src/pkg/encoding/json/encode.go b/src/pkg/encoding/json/encode.go index 3d2f4fc316..033da2d0ad 100644 --- a/src/pkg/encoding/json/encode.go +++ b/src/pkg/encoding/json/encode.go @@ -12,6 +12,7 @@ package json import ( "bytes" "encoding/base64" + "math" "reflect" "runtime" "sort" @@ -170,6 +171,15 @@ func (e *UnsupportedTypeError) Error() string { return "json: unsupported type: " + e.Type.String() } +type UnsupportedValueError struct { + Value reflect.Value + Str string +} + +func (e *UnsupportedValueError) Error() string { + return "json: unsupported value: " + e.Str +} + type InvalidUTF8Error struct { S string } @@ -290,7 +300,11 @@ func (e *encodeState) reflectValueQuoted(v reflect.Value, quoted bool) { e.Write(b) } case reflect.Float32, reflect.Float64: - b := strconv.AppendFloat(e.scratch[:0], v.Float(), 'g', -1, v.Type().Bits()) + f := v.Float() + if math.IsInf(f, 0) || math.IsNaN(f) { + e.error(&UnsupportedValueError{v, strconv.FormatFloat(f, 'g', -1, v.Type().Bits())}) + } + b := strconv.AppendFloat(e.scratch[:0], f, 'g', -1, v.Type().Bits()) if quoted { writeString(e, string(b)) } else { diff --git a/src/pkg/encoding/json/encode_test.go b/src/pkg/encoding/json/encode_test.go index 9366589f25..0e39559a46 100644 --- a/src/pkg/encoding/json/encode_test.go +++ b/src/pkg/encoding/json/encode_test.go @@ -6,6 +6,7 @@ package json import ( "bytes" + "math" "reflect" "testing" ) @@ -107,3 +108,21 @@ func TestEncodeRenamedByteSlice(t *testing.T) { t.Errorf(" got %s want %s", result, expect) } } + +var unsupportedValues = []interface{}{ + math.NaN(), + math.Inf(-1), + math.Inf(1), +} + +func TestUnsupportedValues(t *testing.T) { + for _, v := range unsupportedValues { + if _, err := Marshal(v); err != nil { + if _, ok := err.(*UnsupportedValueError); !ok { + t.Errorf("for %v, got %T want UnsupportedValueError", v, err) + } + } else { + t.Errorf("for %v, expected error", v) + } + } +}