diff --git a/src/encoding/json/encode.go b/src/encoding/json/encode.go index 0088f25ab8..d8c779869b 100644 --- a/src/encoding/json/encode.go +++ b/src/encoding/json/encode.go @@ -49,6 +49,7 @@ import ( // The angle brackets "<" and ">" are escaped to "\u003c" and "\u003e" // to keep some browsers from misinterpreting JSON output as HTML. // Ampersand "&" is also escaped to "\u0026" for the same reason. +// This escaping can be disabled using an Encoder with DisableHTMLEscaping. // // Array and slice values encode as JSON arrays, except that // []byte encodes as a base64-encoded string, and a nil slice @@ -136,7 +137,7 @@ import ( // func Marshal(v interface{}) ([]byte, error) { e := &encodeState{} - err := e.marshal(v) + err := e.marshal(v, encOpts{escapeHTML: true}) if err != nil { return nil, err } @@ -259,7 +260,7 @@ func newEncodeState() *encodeState { return new(encodeState) } -func (e *encodeState) marshal(v interface{}) (err error) { +func (e *encodeState) marshal(v interface{}, opts encOpts) (err error) { defer func() { if r := recover(); r != nil { if _, ok := r.(runtime.Error); ok { @@ -271,7 +272,7 @@ func (e *encodeState) marshal(v interface{}) (err error) { err = r.(error) } }() - e.reflectValue(reflect.ValueOf(v)) + e.reflectValue(reflect.ValueOf(v), opts) return nil } @@ -297,11 +298,18 @@ func isEmptyValue(v reflect.Value) bool { return false } -func (e *encodeState) reflectValue(v reflect.Value) { - valueEncoder(v)(e, v, false) +func (e *encodeState) reflectValue(v reflect.Value, opts encOpts) { + valueEncoder(v)(e, v, opts) } -type encoderFunc func(e *encodeState, v reflect.Value, quoted bool) +type encOpts struct { + // quoted causes primitive fields to be encoded inside JSON strings. + quoted bool + // escapeHTML causes '<', '>', and '&' to be escaped in JSON strings. + escapeHTML bool +} + +type encoderFunc func(e *encodeState, v reflect.Value, opts encOpts) var encoderCache struct { sync.RWMutex @@ -333,9 +341,9 @@ func typeEncoder(t reflect.Type) encoderFunc { } var wg sync.WaitGroup wg.Add(1) - encoderCache.m[t] = func(e *encodeState, v reflect.Value, quoted bool) { + encoderCache.m[t] = func(e *encodeState, v reflect.Value, opts encOpts) { wg.Wait() - f(e, v, quoted) + f(e, v, opts) } encoderCache.Unlock() @@ -405,11 +413,11 @@ func newTypeEncoder(t reflect.Type, allowAddr bool) encoderFunc { } } -func invalidValueEncoder(e *encodeState, v reflect.Value, quoted bool) { +func invalidValueEncoder(e *encodeState, v reflect.Value, _ encOpts) { e.WriteString("null") } -func marshalerEncoder(e *encodeState, v reflect.Value, quoted bool) { +func marshalerEncoder(e *encodeState, v reflect.Value, opts encOpts) { if v.Kind() == reflect.Ptr && v.IsNil() { e.WriteString("null") return @@ -418,14 +426,14 @@ func marshalerEncoder(e *encodeState, v reflect.Value, quoted bool) { b, err := m.MarshalJSON() if err == nil { // copy JSON into buffer, checking validity. - err = compact(&e.Buffer, b, true) + err = compact(&e.Buffer, b, opts.escapeHTML) } if err != nil { e.error(&MarshalerError{v.Type(), err}) } } -func addrMarshalerEncoder(e *encodeState, v reflect.Value, quoted bool) { +func addrMarshalerEncoder(e *encodeState, v reflect.Value, _ encOpts) { va := v.Addr() if va.IsNil() { e.WriteString("null") @@ -442,7 +450,7 @@ func addrMarshalerEncoder(e *encodeState, v reflect.Value, quoted bool) { } } -func textMarshalerEncoder(e *encodeState, v reflect.Value, quoted bool) { +func textMarshalerEncoder(e *encodeState, v reflect.Value, opts encOpts) { if v.Kind() == reflect.Ptr && v.IsNil() { e.WriteString("null") return @@ -452,10 +460,10 @@ func textMarshalerEncoder(e *encodeState, v reflect.Value, quoted bool) { if err != nil { e.error(&MarshalerError{v.Type(), err}) } - e.stringBytes(b) + e.stringBytes(b, opts.escapeHTML) } -func addrTextMarshalerEncoder(e *encodeState, v reflect.Value, quoted bool) { +func addrTextMarshalerEncoder(e *encodeState, v reflect.Value, opts encOpts) { va := v.Addr() if va.IsNil() { e.WriteString("null") @@ -466,11 +474,11 @@ func addrTextMarshalerEncoder(e *encodeState, v reflect.Value, quoted bool) { if err != nil { e.error(&MarshalerError{v.Type(), err}) } - e.stringBytes(b) + e.stringBytes(b, opts.escapeHTML) } -func boolEncoder(e *encodeState, v reflect.Value, quoted bool) { - if quoted { +func boolEncoder(e *encodeState, v reflect.Value, opts encOpts) { + if opts.quoted { e.WriteByte('"') } if v.Bool() { @@ -478,46 +486,46 @@ func boolEncoder(e *encodeState, v reflect.Value, quoted bool) { } else { e.WriteString("false") } - if quoted { + if opts.quoted { e.WriteByte('"') } } -func intEncoder(e *encodeState, v reflect.Value, quoted bool) { +func intEncoder(e *encodeState, v reflect.Value, opts encOpts) { b := strconv.AppendInt(e.scratch[:0], v.Int(), 10) - if quoted { + if opts.quoted { e.WriteByte('"') } e.Write(b) - if quoted { + if opts.quoted { e.WriteByte('"') } } -func uintEncoder(e *encodeState, v reflect.Value, quoted bool) { +func uintEncoder(e *encodeState, v reflect.Value, opts encOpts) { b := strconv.AppendUint(e.scratch[:0], v.Uint(), 10) - if quoted { + if opts.quoted { e.WriteByte('"') } e.Write(b) - if quoted { + if opts.quoted { e.WriteByte('"') } } type floatEncoder int // number of bits -func (bits floatEncoder) encode(e *encodeState, v reflect.Value, quoted bool) { +func (bits floatEncoder) encode(e *encodeState, v reflect.Value, opts encOpts) { f := v.Float() if math.IsInf(f, 0) || math.IsNaN(f) { e.error(&UnsupportedValueError{v, strconv.FormatFloat(f, 'g', -1, int(bits))}) } b := strconv.AppendFloat(e.scratch[:0], f, 'g', -1, int(bits)) - if quoted { + if opts.quoted { e.WriteByte('"') } e.Write(b) - if quoted { + if opts.quoted { e.WriteByte('"') } } @@ -527,7 +535,7 @@ var ( float64Encoder = (floatEncoder(64)).encode ) -func stringEncoder(e *encodeState, v reflect.Value, quoted bool) { +func stringEncoder(e *encodeState, v reflect.Value, opts encOpts) { if v.Type() == numberType { numStr := v.String() // In Go1.5 the empty string encodes to "0", while this is not a valid number literal @@ -541,26 +549,26 @@ func stringEncoder(e *encodeState, v reflect.Value, quoted bool) { e.WriteString(numStr) return } - if quoted { + if opts.quoted { sb, err := Marshal(v.String()) if err != nil { e.error(err) } - e.string(string(sb)) + e.string(string(sb), opts.escapeHTML) } else { - e.string(v.String()) + e.string(v.String(), opts.escapeHTML) } } -func interfaceEncoder(e *encodeState, v reflect.Value, quoted bool) { +func interfaceEncoder(e *encodeState, v reflect.Value, opts encOpts) { if v.IsNil() { e.WriteString("null") return } - e.reflectValue(v.Elem()) + e.reflectValue(v.Elem(), opts) } -func unsupportedTypeEncoder(e *encodeState, v reflect.Value, quoted bool) { +func unsupportedTypeEncoder(e *encodeState, v reflect.Value, _ encOpts) { e.error(&UnsupportedTypeError{v.Type()}) } @@ -569,7 +577,7 @@ type structEncoder struct { fieldEncs []encoderFunc } -func (se *structEncoder) encode(e *encodeState, v reflect.Value, quoted bool) { +func (se *structEncoder) encode(e *encodeState, v reflect.Value, opts encOpts) { e.WriteByte('{') first := true for i, f := range se.fields { @@ -582,9 +590,10 @@ func (se *structEncoder) encode(e *encodeState, v reflect.Value, quoted bool) { } else { e.WriteByte(',') } - e.string(f.name) + e.string(f.name, opts.escapeHTML) e.WriteByte(':') - se.fieldEncs[i](e, fv, f.quoted) + opts.quoted = f.quoted + se.fieldEncs[i](e, fv, opts) } e.WriteByte('}') } @@ -605,7 +614,7 @@ type mapEncoder struct { elemEnc encoderFunc } -func (me *mapEncoder) encode(e *encodeState, v reflect.Value, _ bool) { +func (me *mapEncoder) encode(e *encodeState, v reflect.Value, opts encOpts) { if v.IsNil() { e.WriteString("null") return @@ -627,9 +636,9 @@ func (me *mapEncoder) encode(e *encodeState, v reflect.Value, _ bool) { if i > 0 { e.WriteByte(',') } - e.string(kv.s) + e.string(kv.s, opts.escapeHTML) e.WriteByte(':') - me.elemEnc(e, v.MapIndex(kv.v), false) + me.elemEnc(e, v.MapIndex(kv.v), opts) } e.WriteByte('}') } @@ -642,7 +651,7 @@ func newMapEncoder(t reflect.Type) encoderFunc { return me.encode } -func encodeByteSlice(e *encodeState, v reflect.Value, _ bool) { +func encodeByteSlice(e *encodeState, v reflect.Value, _ encOpts) { if v.IsNil() { e.WriteString("null") return @@ -669,12 +678,12 @@ type sliceEncoder struct { arrayEnc encoderFunc } -func (se *sliceEncoder) encode(e *encodeState, v reflect.Value, _ bool) { +func (se *sliceEncoder) encode(e *encodeState, v reflect.Value, opts encOpts) { if v.IsNil() { e.WriteString("null") return } - se.arrayEnc(e, v, false) + se.arrayEnc(e, v, opts) } func newSliceEncoder(t reflect.Type) encoderFunc { @@ -692,14 +701,14 @@ type arrayEncoder struct { elemEnc encoderFunc } -func (ae *arrayEncoder) encode(e *encodeState, v reflect.Value, _ bool) { +func (ae *arrayEncoder) encode(e *encodeState, v reflect.Value, opts encOpts) { e.WriteByte('[') n := v.Len() for i := 0; i < n; i++ { if i > 0 { e.WriteByte(',') } - ae.elemEnc(e, v.Index(i), false) + ae.elemEnc(e, v.Index(i), opts) } e.WriteByte(']') } @@ -713,12 +722,12 @@ type ptrEncoder struct { elemEnc encoderFunc } -func (pe *ptrEncoder) encode(e *encodeState, v reflect.Value, quoted bool) { +func (pe *ptrEncoder) encode(e *encodeState, v reflect.Value, opts encOpts) { if v.IsNil() { e.WriteString("null") return } - pe.elemEnc(e, v.Elem(), quoted) + pe.elemEnc(e, v.Elem(), opts) } func newPtrEncoder(t reflect.Type) encoderFunc { @@ -730,11 +739,11 @@ type condAddrEncoder struct { canAddrEnc, elseEnc encoderFunc } -func (ce *condAddrEncoder) encode(e *encodeState, v reflect.Value, quoted bool) { +func (ce *condAddrEncoder) encode(e *encodeState, v reflect.Value, opts encOpts) { if v.CanAddr() { - ce.canAddrEnc(e, v, quoted) + ce.canAddrEnc(e, v, opts) } else { - ce.elseEnc(e, v, quoted) + ce.elseEnc(e, v, opts) } } @@ -812,13 +821,14 @@ func (sv byString) Swap(i, j int) { sv[i], sv[j] = sv[j], sv[i] } func (sv byString) Less(i, j int) bool { return sv[i].s < sv[j].s } // NOTE: keep in sync with stringBytes below. -func (e *encodeState) string(s string) int { +func (e *encodeState) string(s string, escapeHTML bool) int { len0 := e.Len() e.WriteByte('"') start := 0 for i := 0; i < len(s); { if b := s[i]; b < utf8.RuneSelf { - if 0x20 <= b && b != '\\' && b != '"' && b != '<' && b != '>' && b != '&' { + if 0x20 <= b && b != '\\' && b != '"' && + (!escapeHTML || b != '<' && b != '>' && b != '&') { i++ continue } @@ -839,10 +849,11 @@ func (e *encodeState) string(s string) int { e.WriteByte('\\') e.WriteByte('t') default: - // This encodes bytes < 0x20 except for \n and \r, - // as well as <, > and &. The latter are escaped because they - // can lead to security holes when user-controlled strings - // are rendered into JSON and served to some browsers. + // This encodes bytes < 0x20 except for \t, \n and \r. + // If escapeHTML is set, it also escapes <, >, and & + // because they can lead to security holes when + // user-controlled strings are rendered into JSON + // and served to some browsers. e.WriteString(`\u00`) e.WriteByte(hex[b>>4]) e.WriteByte(hex[b&0xF]) @@ -888,13 +899,14 @@ func (e *encodeState) string(s string) int { } // NOTE: keep in sync with string above. -func (e *encodeState) stringBytes(s []byte) int { +func (e *encodeState) stringBytes(s []byte, escapeHTML bool) int { len0 := e.Len() e.WriteByte('"') start := 0 for i := 0; i < len(s); { if b := s[i]; b < utf8.RuneSelf { - if 0x20 <= b && b != '\\' && b != '"' && b != '<' && b != '>' && b != '&' { + if 0x20 <= b && b != '\\' && b != '"' && + (!escapeHTML || b != '<' && b != '>' && b != '&') { i++ continue } @@ -915,10 +927,11 @@ func (e *encodeState) stringBytes(s []byte) int { e.WriteByte('\\') e.WriteByte('t') default: - // This encodes bytes < 0x20 except for \n and \r, - // as well as <, >, and &. The latter are escaped because they - // can lead to security holes when user-controlled strings - // are rendered into JSON and served to some browsers. + // This encodes bytes < 0x20 except for \t, \n and \r. + // If escapeHTML is set, it also escapes <, >, and & + // because they can lead to security holes when + // user-controlled strings are rendered into JSON + // and served to some browsers. e.WriteString(`\u00`) e.WriteByte(hex[b>>4]) e.WriteByte(hex[b&0xF]) diff --git a/src/encoding/json/encode_test.go b/src/encoding/json/encode_test.go index eee59ccb49..b484022a70 100644 --- a/src/encoding/json/encode_test.go +++ b/src/encoding/json/encode_test.go @@ -376,41 +376,45 @@ func TestDuplicatedFieldDisappears(t *testing.T) { func TestStringBytes(t *testing.T) { // Test that encodeState.stringBytes and encodeState.string use the same encoding. - es := &encodeState{} var r []rune for i := '\u0000'; i <= unicode.MaxRune; i++ { r = append(r, i) } s := string(r) + "\xff\xff\xffhello" // some invalid UTF-8 too - es.string(s) - esBytes := &encodeState{} - esBytes.stringBytes([]byte(s)) + for _, escapeHTML := range []bool{true, false} { + es := &encodeState{} + es.string(s, escapeHTML) - enc := es.Buffer.String() - encBytes := esBytes.Buffer.String() - if enc != encBytes { - i := 0 - for i < len(enc) && i < len(encBytes) && enc[i] == encBytes[i] { - i++ - } - enc = enc[i:] - encBytes = encBytes[i:] - i = 0 - for i < len(enc) && i < len(encBytes) && enc[len(enc)-i-1] == encBytes[len(encBytes)-i-1] { - i++ - } - enc = enc[:len(enc)-i] - encBytes = encBytes[:len(encBytes)-i] + esBytes := &encodeState{} + esBytes.stringBytes([]byte(s), escapeHTML) - if len(enc) > 20 { - enc = enc[:20] + "..." - } - if len(encBytes) > 20 { - encBytes = encBytes[:20] + "..." - } + enc := es.Buffer.String() + encBytes := esBytes.Buffer.String() + if enc != encBytes { + i := 0 + for i < len(enc) && i < len(encBytes) && enc[i] == encBytes[i] { + i++ + } + enc = enc[i:] + encBytes = encBytes[i:] + i = 0 + for i < len(enc) && i < len(encBytes) && enc[len(enc)-i-1] == encBytes[len(encBytes)-i-1] { + i++ + } + enc = enc[:len(enc)-i] + encBytes = encBytes[:len(encBytes)-i] - t.Errorf("encodings differ at %#q vs %#q", enc, encBytes) + if len(enc) > 20 { + enc = enc[:20] + "..." + } + if len(encBytes) > 20 { + encBytes = encBytes[:20] + "..." + } + + t.Errorf("with escapeHTML=%t, encodings differ at %#q vs %#q", + escapeHTML, enc, encBytes) + } } } diff --git a/src/encoding/json/stream.go b/src/encoding/json/stream.go index 422837bb63..d6b2992e9b 100644 --- a/src/encoding/json/stream.go +++ b/src/encoding/json/stream.go @@ -166,8 +166,9 @@ func nonSpace(b []byte) bool { // An Encoder writes JSON values to an output stream. type Encoder struct { - w io.Writer - err error + w io.Writer + err error + escapeHTML bool indentBuf *bytes.Buffer indentPrefix string @@ -176,7 +177,7 @@ type Encoder struct { // NewEncoder returns a new encoder that writes to w. func NewEncoder(w io.Writer) *Encoder { - return &Encoder{w: w} + return &Encoder{w: w, escapeHTML: true} } // Encode writes the JSON encoding of v to the stream, @@ -189,7 +190,7 @@ func (enc *Encoder) Encode(v interface{}) error { return enc.err } e := newEncodeState() - err := e.marshal(v) + err := e.marshal(v, encOpts{escapeHTML: enc.escapeHTML}) if err != nil { return err } @@ -225,6 +226,12 @@ func (enc *Encoder) Indent(prefix, indent string) { enc.indentValue = indent } +// DisableHTMLEscaping causes the encoder not to escape angle brackets +// ("<" and ">") or ampersands ("&") in JSON strings. +func (enc *Encoder) DisableHTMLEscaping() { + enc.escapeHTML = false +} + // RawMessage is a raw encoded JSON value. // It implements Marshaler and Unmarshaler and can // be used to delay JSON decoding or precompute a JSON encoding. diff --git a/src/encoding/json/stream_test.go b/src/encoding/json/stream_test.go index db25708f4c..3516ac3b83 100644 --- a/src/encoding/json/stream_test.go +++ b/src/encoding/json/stream_test.go @@ -87,6 +87,39 @@ func TestEncoderIndent(t *testing.T) { } } +func TestEncoderDisableHTMLEscaping(t *testing.T) { + var c C + var ct CText + for _, tt := range []struct { + name string + v interface{} + wantEscape string + want string + }{ + {"c", c, `"\u003c\u0026\u003e"`, `"<&>"`}, + {"ct", ct, `"\"\u003c\u0026\u003e\""`, `"\"<&>\""`}, + {`"<&>"`, "<&>", `"\u003c\u0026\u003e"`, `"<&>"`}, + } { + var buf bytes.Buffer + enc := NewEncoder(&buf) + if err := enc.Encode(tt.v); err != nil { + t.Fatalf("Encode(%s): %s", tt.name, err) + } + if got := strings.TrimSpace(buf.String()); got != tt.wantEscape { + t.Errorf("Encode(%s) = %#q, want %#q", tt.name, got, tt.wantEscape) + } + buf.Reset() + enc.DisableHTMLEscaping() + if err := enc.Encode(tt.v); err != nil { + t.Fatalf("DisableHTMLEscaping Encode(%s): %s", tt.name, err) + } + if got := strings.TrimSpace(buf.String()); got != tt.want { + t.Errorf("DisableHTMLEscaping Encode(%s) = %#q, want %#q", + tt.name, got, tt.want) + } + } +} + func TestDecoder(t *testing.T) { for i := 0; i <= len(streamTest); i++ { // Use stream without newlines as input,