Commit 7e886740 authored by Russ Cox's avatar Russ Cox

encoding/json: support encoding.TextMarshaler, encoding.TextUnmarshaler

R=golang-dev, bradfitz
CC=golang-dev
https://golang.org/cl/12703043
parent 5822e784
...@@ -8,6 +8,7 @@ ...@@ -8,6 +8,7 @@
package json package json
import ( import (
"encoding"
"encoding/base64" "encoding/base64"
"errors" "errors"
"fmt" "fmt"
...@@ -293,7 +294,7 @@ func (d *decodeState) value(v reflect.Value) { ...@@ -293,7 +294,7 @@ func (d *decodeState) value(v reflect.Value) {
// until it gets to a non-pointer. // until it gets to a non-pointer.
// if it encounters an Unmarshaler, indirect stops and returns that. // if it encounters an Unmarshaler, indirect stops and returns that.
// if decodingNull is true, indirect stops at the last pointer so it can be set to nil. // if decodingNull is true, indirect stops at the last pointer so it can be set to nil.
func (d *decodeState) indirect(v reflect.Value, decodingNull bool) (Unmarshaler, reflect.Value) { func (d *decodeState) indirect(v reflect.Value, decodingNull bool) (Unmarshaler, encoding.TextUnmarshaler, reflect.Value) {
// If v is a named type and is addressable, // If v is a named type and is addressable,
// start with its address, so that if the type has pointer methods, // start with its address, so that if the type has pointer methods,
// we find them. // we find them.
...@@ -322,28 +323,38 @@ func (d *decodeState) indirect(v reflect.Value, decodingNull bool) (Unmarshaler, ...@@ -322,28 +323,38 @@ func (d *decodeState) indirect(v reflect.Value, decodingNull bool) (Unmarshaler,
v.Set(reflect.New(v.Type().Elem())) v.Set(reflect.New(v.Type().Elem()))
} }
if v.Type().NumMethod() > 0 { if v.Type().NumMethod() > 0 {
if unmarshaler, ok := v.Interface().(Unmarshaler); ok { if u, ok := v.Interface().(Unmarshaler); ok {
return unmarshaler, reflect.Value{} return u, nil, reflect.Value{}
}
if u, ok := v.Interface().(encoding.TextUnmarshaler); ok {
return nil, u, reflect.Value{}
} }
} }
v = v.Elem() v = v.Elem()
} }
return nil, v return nil, nil, v
} }
// array consumes an array from d.data[d.off-1:], decoding into the value v. // array consumes an array from d.data[d.off-1:], decoding into the value v.
// the first byte of the array ('[') has been read already. // the first byte of the array ('[') has been read already.
func (d *decodeState) array(v reflect.Value) { func (d *decodeState) array(v reflect.Value) {
// Check for unmarshaler. // Check for unmarshaler.
unmarshaler, pv := d.indirect(v, false) u, ut, pv := d.indirect(v, false)
if unmarshaler != nil { if u != nil {
d.off-- d.off--
err := unmarshaler.UnmarshalJSON(d.next()) err := u.UnmarshalJSON(d.next())
if err != nil { if err != nil {
d.error(err) d.error(err)
} }
return return
} }
if ut != nil {
d.saveError(&UnmarshalTypeError{"array", v.Type()})
d.off--
d.next()
return
}
v = pv v = pv
// Check type of target. // Check type of target.
...@@ -434,15 +445,21 @@ func (d *decodeState) array(v reflect.Value) { ...@@ -434,15 +445,21 @@ func (d *decodeState) array(v reflect.Value) {
// the first byte of the object ('{') has been read already. // the first byte of the object ('{') has been read already.
func (d *decodeState) object(v reflect.Value) { func (d *decodeState) object(v reflect.Value) {
// Check for unmarshaler. // Check for unmarshaler.
unmarshaler, pv := d.indirect(v, false) u, ut, pv := d.indirect(v, false)
if unmarshaler != nil { if u != nil {
d.off-- d.off--
err := unmarshaler.UnmarshalJSON(d.next()) err := u.UnmarshalJSON(d.next())
if err != nil { if err != nil {
d.error(err) d.error(err)
} }
return return
} }
if ut != nil {
d.saveError(&UnmarshalTypeError{"object", v.Type()})
d.off--
d.next() // skip over { } in input
return
}
v = pv v = pv
// Decoding into nil interface? Switch to non-reflect code. // Decoding into nil interface? Switch to non-reflect code.
...@@ -611,14 +628,37 @@ func (d *decodeState) literalStore(item []byte, v reflect.Value, fromQuoted bool ...@@ -611,14 +628,37 @@ func (d *decodeState) literalStore(item []byte, v reflect.Value, fromQuoted bool
return return
} }
wantptr := item[0] == 'n' // null wantptr := item[0] == 'n' // null
unmarshaler, pv := d.indirect(v, wantptr) u, ut, pv := d.indirect(v, wantptr)
if unmarshaler != nil { if u != nil {
err := unmarshaler.UnmarshalJSON(item) err := u.UnmarshalJSON(item)
if err != nil {
d.error(err)
}
return
}
if ut != nil {
if item[0] != '"' {
if fromQuoted {
d.saveError(fmt.Errorf("json: invalid use of ,string struct tag, trying to unmarshal %q into %v", item, v.Type()))
} else {
d.saveError(&UnmarshalTypeError{"string", v.Type()})
}
}
s, ok := unquoteBytes(item)
if !ok {
if fromQuoted {
d.error(fmt.Errorf("json: invalid use of ,string struct tag, trying to unmarshal %q into %v", item, v.Type()))
} else {
d.error(errPhase)
}
}
err := ut.UnmarshalText(s)
if err != nil { if err != nil {
d.error(err) d.error(err)
} }
return return
} }
v = pv v = pv
switch c := item[0]; c { switch c := item[0]; c {
......
...@@ -6,6 +6,7 @@ package json ...@@ -6,6 +6,7 @@ package json
import ( import (
"bytes" "bytes"
"encoding"
"fmt" "fmt"
"image" "image"
"reflect" "reflect"
...@@ -57,7 +58,7 @@ type unmarshaler struct { ...@@ -57,7 +58,7 @@ type unmarshaler struct {
} }
func (u *unmarshaler) UnmarshalJSON(b []byte) error { func (u *unmarshaler) UnmarshalJSON(b []byte) error {
*u = unmarshaler{true} // All we need to see that UnmarshalJson is called. *u = unmarshaler{true} // All we need to see that UnmarshalJSON is called.
return nil return nil
} }
...@@ -65,6 +66,26 @@ type ustruct struct { ...@@ -65,6 +66,26 @@ type ustruct struct {
M unmarshaler M unmarshaler
} }
type unmarshalerText struct {
T bool
}
// needed for re-marshaling tests
func (u *unmarshalerText) MarshalText() ([]byte, error) {
return []byte(""), nil
}
func (u *unmarshalerText) UnmarshalText(b []byte) error {
*u = unmarshalerText{true} // All we need to see that UnmarshalText is called.
return nil
}
var _ encoding.TextUnmarshaler = (*unmarshalerText)(nil)
type ustructText struct {
M unmarshalerText
}
var ( var (
um0, um1 unmarshaler // target2 of unmarshaling um0, um1 unmarshaler // target2 of unmarshaling
ump = &um1 ump = &um1
...@@ -72,6 +93,13 @@ var ( ...@@ -72,6 +93,13 @@ var (
umslice = []unmarshaler{{true}} umslice = []unmarshaler{{true}}
umslicep = new([]unmarshaler) umslicep = new([]unmarshaler)
umstruct = ustruct{unmarshaler{true}} umstruct = ustruct{unmarshaler{true}}
um0T, um1T unmarshalerText // target2 of unmarshaling
umpT = &um1T
umtrueT = unmarshalerText{true}
umsliceT = []unmarshalerText{{true}}
umslicepT = new([]unmarshalerText)
umstructT = ustructText{unmarshalerText{true}}
) )
// Test data structures for anonymous fields. // Test data structures for anonymous fields.
...@@ -261,6 +289,13 @@ var unmarshalTests = []unmarshalTest{ ...@@ -261,6 +289,13 @@ var unmarshalTests = []unmarshalTest{
{in: `[{"T":false}]`, ptr: &umslicep, out: &umslice}, {in: `[{"T":false}]`, ptr: &umslicep, out: &umslice},
{in: `{"M":{"T":false}}`, ptr: &umstruct, out: umstruct}, {in: `{"M":{"T":false}}`, ptr: &umstruct, out: umstruct},
// UnmarshalText interface test
{in: `"X"`, ptr: &um0T, out: umtrueT}, // use "false" so test will fail if custom unmarshaler is not called
{in: `"X"`, ptr: &umpT, out: &umtrueT},
{in: `["X"]`, ptr: &umsliceT, out: umsliceT},
{in: `["X"]`, ptr: &umslicepT, out: &umsliceT},
{in: `{"M":"X"}`, ptr: &umstructT, out: umstructT},
{ {
in: `{ in: `{
"Level0": 1, "Level0": 1,
...@@ -505,7 +540,7 @@ func TestUnmarshal(t *testing.T) { ...@@ -505,7 +540,7 @@ func TestUnmarshal(t *testing.T) {
dec.UseNumber() dec.UseNumber()
} }
if err := dec.Decode(vv.Interface()); err != nil { if err := dec.Decode(vv.Interface()); err != nil {
t.Errorf("#%d: error re-unmarshaling: %v", i, err) t.Errorf("#%d: error re-unmarshaling %#q: %v", i, enc, err)
continue continue
} }
if !reflect.DeepEqual(v.Elem().Interface(), vv.Elem().Interface()) { if !reflect.DeepEqual(v.Elem().Interface(), vv.Elem().Interface()) {
...@@ -979,15 +1014,20 @@ func TestRefUnmarshal(t *testing.T) { ...@@ -979,15 +1014,20 @@ func TestRefUnmarshal(t *testing.T) {
// Ref is defined in encode_test.go. // Ref is defined in encode_test.go.
R0 Ref R0 Ref
R1 *Ref R1 *Ref
R2 RefText
R3 *RefText
} }
want := S{ want := S{
R0: 12, R0: 12,
R1: new(Ref), R1: new(Ref),
R2: 13,
R3: new(RefText),
} }
*want.R1 = 12 *want.R1 = 12
*want.R3 = 13
var got S var got S
if err := Unmarshal([]byte(`{"R0":"ref","R1":"ref"}`), &got); err != nil { if err := Unmarshal([]byte(`{"R0":"ref","R1":"ref","R2":"ref","R3":"ref"}`), &got); err != nil {
t.Fatalf("Unmarshal: %v", err) t.Fatalf("Unmarshal: %v", err)
} }
if !reflect.DeepEqual(got, want) { if !reflect.DeepEqual(got, want) {
......
...@@ -12,6 +12,7 @@ package json ...@@ -12,6 +12,7 @@ package json
import ( import (
"bytes" "bytes"
"encoding"
"encoding/base64" "encoding/base64"
"math" "math"
"reflect" "reflect"
...@@ -361,17 +362,29 @@ func newTypeEncoder(t reflect.Type, vx reflect.Value) encoderFunc { ...@@ -361,17 +362,29 @@ func newTypeEncoder(t reflect.Type, vx reflect.Value) encoderFunc {
if !vx.IsValid() { if !vx.IsValid() {
vx = reflect.New(t).Elem() vx = reflect.New(t).Elem()
} }
_, ok := vx.Interface().(Marshaler) _, ok := vx.Interface().(Marshaler)
if ok { if ok {
return valueIsMarshallerEncoder return marshalerEncoder
} }
// T doesn't match the interface. Check against *T too.
if vx.Kind() != reflect.Ptr && vx.CanAddr() { if vx.Kind() != reflect.Ptr && vx.CanAddr() {
_, ok = vx.Addr().Interface().(Marshaler) _, ok = vx.Addr().Interface().(Marshaler)
if ok { if ok {
return valueAddrIsMarshallerEncoder return addrMarshalerEncoder
} }
} }
_, ok = vx.Interface().(encoding.TextMarshaler)
if ok {
return textMarshalerEncoder
}
if vx.Kind() != reflect.Ptr && vx.CanAddr() {
_, ok = vx.Addr().Interface().(encoding.TextMarshaler)
if ok {
return addrTextMarshalerEncoder
}
}
switch vx.Kind() { switch vx.Kind() {
case reflect.Bool: case reflect.Bool:
return boolEncoder return boolEncoder
...@@ -406,7 +419,7 @@ func invalidValueEncoder(e *encodeState, v reflect.Value, quoted bool) { ...@@ -406,7 +419,7 @@ func invalidValueEncoder(e *encodeState, v reflect.Value, quoted bool) {
e.WriteString("null") e.WriteString("null")
} }
func valueIsMarshallerEncoder(e *encodeState, v reflect.Value, quoted bool) { func marshalerEncoder(e *encodeState, v reflect.Value, quoted bool) {
if v.Kind() == reflect.Ptr && v.IsNil() { if v.Kind() == reflect.Ptr && v.IsNil() {
e.WriteString("null") e.WriteString("null")
return return
...@@ -422,9 +435,9 @@ func valueIsMarshallerEncoder(e *encodeState, v reflect.Value, quoted bool) { ...@@ -422,9 +435,9 @@ func valueIsMarshallerEncoder(e *encodeState, v reflect.Value, quoted bool) {
} }
} }
func valueAddrIsMarshallerEncoder(e *encodeState, v reflect.Value, quoted bool) { func addrMarshalerEncoder(e *encodeState, v reflect.Value, quoted bool) {
va := v.Addr() va := v.Addr()
if va.Kind() == reflect.Ptr && va.IsNil() { if va.IsNil() {
e.WriteString("null") e.WriteString("null")
return return
} }
...@@ -439,6 +452,37 @@ func valueAddrIsMarshallerEncoder(e *encodeState, v reflect.Value, quoted bool) ...@@ -439,6 +452,37 @@ func valueAddrIsMarshallerEncoder(e *encodeState, v reflect.Value, quoted bool)
} }
} }
func textMarshalerEncoder(e *encodeState, v reflect.Value, quoted bool) {
if v.Kind() == reflect.Ptr && v.IsNil() {
e.WriteString("null")
return
}
m := v.Interface().(encoding.TextMarshaler)
b, err := m.MarshalText()
if err == nil {
_, err = e.stringBytes(b)
}
if err != nil {
e.error(&MarshalerError{v.Type(), err})
}
}
func addrTextMarshalerEncoder(e *encodeState, v reflect.Value, quoted bool) {
va := v.Addr()
if va.IsNil() {
e.WriteString("null")
return
}
m := va.Interface().(encoding.TextMarshaler)
b, err := m.MarshalText()
if err == nil {
_, err = e.stringBytes(b)
}
if err != nil {
e.error(&MarshalerError{v.Type(), err})
}
}
func boolEncoder(e *encodeState, v reflect.Value, quoted bool) { func boolEncoder(e *encodeState, v reflect.Value, quoted bool) {
if quoted { if quoted {
e.WriteByte('"') e.WriteByte('"')
...@@ -728,6 +772,7 @@ func (sv stringValues) Swap(i, j int) { sv[i], sv[j] = sv[j], sv[i] } ...@@ -728,6 +772,7 @@ func (sv stringValues) Swap(i, j int) { sv[i], sv[j] = sv[j], sv[i] }
func (sv stringValues) Less(i, j int) bool { return sv.get(i) < sv.get(j) } func (sv stringValues) Less(i, j int) bool { return sv.get(i) < sv.get(j) }
func (sv stringValues) get(i int) string { return sv[i].String() } func (sv stringValues) get(i int) string { return sv[i].String() }
// NOTE: keep in sync with stringBytes below.
func (e *encodeState) string(s string) (int, error) { func (e *encodeState) string(s string) (int, error) {
len0 := e.Len() len0 := e.Len()
e.WriteByte('"') e.WriteByte('"')
...@@ -800,6 +845,79 @@ func (e *encodeState) string(s string) (int, error) { ...@@ -800,6 +845,79 @@ func (e *encodeState) string(s string) (int, error) {
return e.Len() - len0, nil return e.Len() - len0, nil
} }
// NOTE: keep in sync with string above.
func (e *encodeState) stringBytes(s []byte) (int, error) {
len0 := e.Len()
e.WriteByte('"')
start := 0
for i := 0; i < len(s); {
if b := s[i]; b < utf8.RuneSelf {
if 0x20 <= b && b != '\\' && b != '"' && b != '<' && b != '>' && b != '&' {
i++
continue
}
if start < i {
e.Write(s[start:i])
}
switch b {
case '\\', '"':
e.WriteByte('\\')
e.WriteByte(b)
case '\n':
e.WriteByte('\\')
e.WriteByte('n')
case '\r':
e.WriteByte('\\')
e.WriteByte('r')
default:
// This encodes bytes < 0x20 except for \n and \r,
// as well as < and >. The latter are escaped because they
// can lead to security holes when user-controlled strings
// are rendered into JSON and served to some browsers.
e.WriteString(`\u00`)
e.WriteByte(hex[b>>4])
e.WriteByte(hex[b&0xF])
}
i++
start = i
continue
}
c, size := utf8.DecodeRune(s[i:])
if c == utf8.RuneError && size == 1 {
if start < i {
e.Write(s[start:i])
}
e.WriteString(`\ufffd`)
i += size
start = i
continue
}
// U+2028 is LINE SEPARATOR.
// U+2029 is PARAGRAPH SEPARATOR.
// They are both technically valid characters in JSON strings,
// but don't work in JSONP, which has to be evaluated as JavaScript,
// and can lead to security holes there. It is valid JSON to
// escape them, so we do so unconditionally.
// See http://timelessrepo.com/json-isnt-a-javascript-subset for discussion.
if c == '\u2028' || c == '\u2029' {
if start < i {
e.Write(s[start:i])
}
e.WriteString(`\u202`)
e.WriteByte(hex[c&0xF])
i += size
start = i
continue
}
i += size
}
if start < len(s) {
e.Write(s[start:])
}
e.WriteByte('"')
return e.Len() - len0, nil
}
// A field represents a single field found in a struct. // A field represents a single field found in a struct.
type field struct { type field struct {
name string name string
......
...@@ -9,6 +9,7 @@ import ( ...@@ -9,6 +9,7 @@ import (
"math" "math"
"reflect" "reflect"
"testing" "testing"
"unicode"
) )
type Optionals struct { type Optionals struct {
...@@ -146,19 +147,46 @@ func (Val) MarshalJSON() ([]byte, error) { ...@@ -146,19 +147,46 @@ func (Val) MarshalJSON() ([]byte, error) {
return []byte(`"val"`), nil return []byte(`"val"`), nil
} }
// RefText has Marshaler and Unmarshaler methods with pointer receiver.
type RefText int
func (*RefText) MarshalText() ([]byte, error) {
return []byte(`"ref"`), nil
}
func (r *RefText) UnmarshalText([]byte) error {
*r = 13
return nil
}
// ValText has Marshaler methods with value receiver.
type ValText int
func (ValText) MarshalText() ([]byte, error) {
return []byte(`"val"`), nil
}
func TestRefValMarshal(t *testing.T) { func TestRefValMarshal(t *testing.T) {
var s = struct { var s = struct {
R0 Ref R0 Ref
R1 *Ref R1 *Ref
R2 RefText
R3 *RefText
V0 Val V0 Val
V1 *Val V1 *Val
V2 ValText
V3 *ValText
}{ }{
R0: 12, R0: 12,
R1: new(Ref), R1: new(Ref),
R2: 14,
R3: new(RefText),
V0: 13, V0: 13,
V1: new(Val), V1: new(Val),
V2: 15,
V3: new(ValText),
} }
const want = `{"R0":"ref","R1":"ref","V0":"val","V1":"val"}` const want = `{"R0":"ref","R1":"ref","R2":"\"ref\"","R3":"\"ref\"","V0":"val","V1":"val","V2":"\"val\"","V3":"\"val\""}`
b, err := Marshal(&s) b, err := Marshal(&s)
if err != nil { if err != nil {
t.Fatalf("Marshal: %v", err) t.Fatalf("Marshal: %v", err)
...@@ -175,15 +203,32 @@ func (C) MarshalJSON() ([]byte, error) { ...@@ -175,15 +203,32 @@ func (C) MarshalJSON() ([]byte, error) {
return []byte(`"<&>"`), nil return []byte(`"<&>"`), nil
} }
// CText implements Marshaler and returns unescaped text.
type CText int
func (CText) MarshalText() ([]byte, error) {
return []byte(`"<&>"`), nil
}
func TestMarshalerEscaping(t *testing.T) { func TestMarshalerEscaping(t *testing.T) {
var c C var c C
const want = `"\u003c\u0026\u003e"` want := `"\u003c\u0026\u003e"`
b, err := Marshal(c) b, err := Marshal(c)
if err != nil { if err != nil {
t.Fatalf("Marshal: %v", err) t.Fatalf("Marshal(c): %v", err)
} }
if got := string(b); got != want { if got := string(b); got != want {
t.Errorf("got %q, want %q", got, want) t.Errorf("Marshal(c) = %#q, want %#q", got, want)
}
var ct CText
want = `"\"\u003c\u0026\u003e\""`
b, err = Marshal(ct)
if err != nil {
t.Fatalf("Marshal(ct): %v", err)
}
if got := string(b); got != want {
t.Errorf("Marshal(ct) = %#q, want %#q", got, want)
} }
} }
...@@ -310,3 +355,49 @@ func TestDuplicatedFieldDisappears(t *testing.T) { ...@@ -310,3 +355,49 @@ func TestDuplicatedFieldDisappears(t *testing.T) {
t.Fatalf("Marshal: got %s want %s", got, want) t.Fatalf("Marshal: got %s want %s", got, want)
} }
} }
func TestStringBytes(t *testing.T) {
// Test that encodeState.stringBytes and encodeState.string use the same encoding.
es := &encodeState{}
var r []rune
for i := '\u0000'; i <= unicode.MaxRune; i++ {
r = append(r, i)
}
s := string(r) + "\xff\xff\xffhello" // some invalid UTF-8 too
_, err := es.string(s)
if err != nil {
t.Fatal(err)
}
esBytes := &encodeState{}
_, err = esBytes.stringBytes([]byte(s))
if err != nil {
t.Fatal(err)
}
enc := es.Buffer.String()
encBytes := esBytes.Buffer.String()
if enc != encBytes {
i := 0
for i < len(enc) && i < len(encBytes) && enc[i] == encBytes[i] {
i++
}
enc = enc[i:]
encBytes = encBytes[i:]
i = 0
for i < len(enc) && i < len(encBytes) && enc[len(enc)-i-1] == encBytes[len(encBytes)-i-1] {
i++
}
enc = enc[:len(enc)-i]
encBytes = encBytes[:len(encBytes)-i]
if len(enc) > 20 {
enc = enc[:20] + "..."
}
if len(encBytes) > 20 {
encBytes = encBytes[:20] + "..."
}
t.Errorf("encodings differ at %#q vs %#q", enc, encBytes)
}
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment