Commit 7e886740 authored by Russ Cox's avatar Russ Cox

encoding/json: support encoding.TextMarshaler, encoding.TextUnmarshaler

R=golang-dev, bradfitz
CC=golang-dev
https://golang.org/cl/12703043
parent 5822e784
......@@ -8,6 +8,7 @@
package json
import (
"encoding"
"encoding/base64"
"errors"
"fmt"
......@@ -293,7 +294,7 @@ func (d *decodeState) value(v reflect.Value) {
// until it gets to a non-pointer.
// if it encounters an Unmarshaler, indirect stops and returns that.
// if decodingNull is true, indirect stops at the last pointer so it can be set to nil.
func (d *decodeState) indirect(v reflect.Value, decodingNull bool) (Unmarshaler, reflect.Value) {
func (d *decodeState) indirect(v reflect.Value, decodingNull bool) (Unmarshaler, encoding.TextUnmarshaler, reflect.Value) {
// If v is a named type and is addressable,
// start with its address, so that if the type has pointer methods,
// we find them.
......@@ -322,28 +323,38 @@ func (d *decodeState) indirect(v reflect.Value, decodingNull bool) (Unmarshaler,
v.Set(reflect.New(v.Type().Elem()))
}
if v.Type().NumMethod() > 0 {
if unmarshaler, ok := v.Interface().(Unmarshaler); ok {
return unmarshaler, reflect.Value{}
if u, ok := v.Interface().(Unmarshaler); ok {
return u, nil, reflect.Value{}
}
if u, ok := v.Interface().(encoding.TextUnmarshaler); ok {
return nil, u, reflect.Value{}
}
}
v = v.Elem()
}
return nil, v
return nil, nil, v
}
// array consumes an array from d.data[d.off-1:], decoding into the value v.
// the first byte of the array ('[') has been read already.
func (d *decodeState) array(v reflect.Value) {
// Check for unmarshaler.
unmarshaler, pv := d.indirect(v, false)
if unmarshaler != nil {
u, ut, pv := d.indirect(v, false)
if u != nil {
d.off--
err := unmarshaler.UnmarshalJSON(d.next())
err := u.UnmarshalJSON(d.next())
if err != nil {
d.error(err)
}
return
}
if ut != nil {
d.saveError(&UnmarshalTypeError{"array", v.Type()})
d.off--
d.next()
return
}
v = pv
// Check type of target.
......@@ -434,15 +445,21 @@ func (d *decodeState) array(v reflect.Value) {
// the first byte of the object ('{') has been read already.
func (d *decodeState) object(v reflect.Value) {
// Check for unmarshaler.
unmarshaler, pv := d.indirect(v, false)
if unmarshaler != nil {
u, ut, pv := d.indirect(v, false)
if u != nil {
d.off--
err := unmarshaler.UnmarshalJSON(d.next())
err := u.UnmarshalJSON(d.next())
if err != nil {
d.error(err)
}
return
}
if ut != nil {
d.saveError(&UnmarshalTypeError{"object", v.Type()})
d.off--
d.next() // skip over { } in input
return
}
v = pv
// Decoding into nil interface? Switch to non-reflect code.
......@@ -611,14 +628,37 @@ func (d *decodeState) literalStore(item []byte, v reflect.Value, fromQuoted bool
return
}
wantptr := item[0] == 'n' // null
unmarshaler, pv := d.indirect(v, wantptr)
if unmarshaler != nil {
err := unmarshaler.UnmarshalJSON(item)
u, ut, pv := d.indirect(v, wantptr)
if u != nil {
err := u.UnmarshalJSON(item)
if err != nil {
d.error(err)
}
return
}
if ut != nil {
if item[0] != '"' {
if fromQuoted {
d.saveError(fmt.Errorf("json: invalid use of ,string struct tag, trying to unmarshal %q into %v", item, v.Type()))
} else {
d.saveError(&UnmarshalTypeError{"string", v.Type()})
}
}
s, ok := unquoteBytes(item)
if !ok {
if fromQuoted {
d.error(fmt.Errorf("json: invalid use of ,string struct tag, trying to unmarshal %q into %v", item, v.Type()))
} else {
d.error(errPhase)
}
}
err := ut.UnmarshalText(s)
if err != nil {
d.error(err)
}
return
}
v = pv
switch c := item[0]; c {
......
......@@ -6,6 +6,7 @@ package json
import (
"bytes"
"encoding"
"fmt"
"image"
"reflect"
......@@ -57,7 +58,7 @@ type unmarshaler struct {
}
func (u *unmarshaler) UnmarshalJSON(b []byte) error {
*u = unmarshaler{true} // All we need to see that UnmarshalJson is called.
*u = unmarshaler{true} // All we need to see that UnmarshalJSON is called.
return nil
}
......@@ -65,6 +66,26 @@ type ustruct struct {
M unmarshaler
}
type unmarshalerText struct {
T bool
}
// needed for re-marshaling tests
func (u *unmarshalerText) MarshalText() ([]byte, error) {
return []byte(""), nil
}
func (u *unmarshalerText) UnmarshalText(b []byte) error {
*u = unmarshalerText{true} // All we need to see that UnmarshalText is called.
return nil
}
var _ encoding.TextUnmarshaler = (*unmarshalerText)(nil)
type ustructText struct {
M unmarshalerText
}
var (
um0, um1 unmarshaler // target2 of unmarshaling
ump = &um1
......@@ -72,6 +93,13 @@ var (
umslice = []unmarshaler{{true}}
umslicep = new([]unmarshaler)
umstruct = ustruct{unmarshaler{true}}
um0T, um1T unmarshalerText // target2 of unmarshaling
umpT = &um1T
umtrueT = unmarshalerText{true}
umsliceT = []unmarshalerText{{true}}
umslicepT = new([]unmarshalerText)
umstructT = ustructText{unmarshalerText{true}}
)
// Test data structures for anonymous fields.
......@@ -261,6 +289,13 @@ var unmarshalTests = []unmarshalTest{
{in: `[{"T":false}]`, ptr: &umslicep, out: &umslice},
{in: `{"M":{"T":false}}`, ptr: &umstruct, out: umstruct},
// UnmarshalText interface test
{in: `"X"`, ptr: &um0T, out: umtrueT}, // use "false" so test will fail if custom unmarshaler is not called
{in: `"X"`, ptr: &umpT, out: &umtrueT},
{in: `["X"]`, ptr: &umsliceT, out: umsliceT},
{in: `["X"]`, ptr: &umslicepT, out: &umsliceT},
{in: `{"M":"X"}`, ptr: &umstructT, out: umstructT},
{
in: `{
"Level0": 1,
......@@ -505,7 +540,7 @@ func TestUnmarshal(t *testing.T) {
dec.UseNumber()
}
if err := dec.Decode(vv.Interface()); err != nil {
t.Errorf("#%d: error re-unmarshaling: %v", i, err)
t.Errorf("#%d: error re-unmarshaling %#q: %v", i, enc, err)
continue
}
if !reflect.DeepEqual(v.Elem().Interface(), vv.Elem().Interface()) {
......@@ -979,15 +1014,20 @@ func TestRefUnmarshal(t *testing.T) {
// Ref is defined in encode_test.go.
R0 Ref
R1 *Ref
R2 RefText
R3 *RefText
}
want := S{
R0: 12,
R1: new(Ref),
R2: 13,
R3: new(RefText),
}
*want.R1 = 12
*want.R3 = 13
var got S
if err := Unmarshal([]byte(`{"R0":"ref","R1":"ref"}`), &got); err != nil {
if err := Unmarshal([]byte(`{"R0":"ref","R1":"ref","R2":"ref","R3":"ref"}`), &got); err != nil {
t.Fatalf("Unmarshal: %v", err)
}
if !reflect.DeepEqual(got, want) {
......
......@@ -12,6 +12,7 @@ package json
import (
"bytes"
"encoding"
"encoding/base64"
"math"
"reflect"
......@@ -361,17 +362,29 @@ func newTypeEncoder(t reflect.Type, vx reflect.Value) encoderFunc {
if !vx.IsValid() {
vx = reflect.New(t).Elem()
}
_, ok := vx.Interface().(Marshaler)
if ok {
return valueIsMarshallerEncoder
return marshalerEncoder
}
// T doesn't match the interface. Check against *T too.
if vx.Kind() != reflect.Ptr && vx.CanAddr() {
_, ok = vx.Addr().Interface().(Marshaler)
if ok {
return valueAddrIsMarshallerEncoder
return addrMarshalerEncoder
}
}
_, ok = vx.Interface().(encoding.TextMarshaler)
if ok {
return textMarshalerEncoder
}
if vx.Kind() != reflect.Ptr && vx.CanAddr() {
_, ok = vx.Addr().Interface().(encoding.TextMarshaler)
if ok {
return addrTextMarshalerEncoder
}
}
switch vx.Kind() {
case reflect.Bool:
return boolEncoder
......@@ -406,7 +419,7 @@ func invalidValueEncoder(e *encodeState, v reflect.Value, quoted bool) {
e.WriteString("null")
}
func valueIsMarshallerEncoder(e *encodeState, v reflect.Value, quoted bool) {
func marshalerEncoder(e *encodeState, v reflect.Value, quoted bool) {
if v.Kind() == reflect.Ptr && v.IsNil() {
e.WriteString("null")
return
......@@ -422,9 +435,9 @@ func valueIsMarshallerEncoder(e *encodeState, v reflect.Value, quoted bool) {
}
}
func valueAddrIsMarshallerEncoder(e *encodeState, v reflect.Value, quoted bool) {
func addrMarshalerEncoder(e *encodeState, v reflect.Value, quoted bool) {
va := v.Addr()
if va.Kind() == reflect.Ptr && va.IsNil() {
if va.IsNil() {
e.WriteString("null")
return
}
......@@ -439,6 +452,37 @@ func valueAddrIsMarshallerEncoder(e *encodeState, v reflect.Value, quoted bool)
}
}
func textMarshalerEncoder(e *encodeState, v reflect.Value, quoted bool) {
if v.Kind() == reflect.Ptr && v.IsNil() {
e.WriteString("null")
return
}
m := v.Interface().(encoding.TextMarshaler)
b, err := m.MarshalText()
if err == nil {
_, err = e.stringBytes(b)
}
if err != nil {
e.error(&MarshalerError{v.Type(), err})
}
}
func addrTextMarshalerEncoder(e *encodeState, v reflect.Value, quoted bool) {
va := v.Addr()
if va.IsNil() {
e.WriteString("null")
return
}
m := va.Interface().(encoding.TextMarshaler)
b, err := m.MarshalText()
if err == nil {
_, err = e.stringBytes(b)
}
if err != nil {
e.error(&MarshalerError{v.Type(), err})
}
}
func boolEncoder(e *encodeState, v reflect.Value, quoted bool) {
if quoted {
e.WriteByte('"')
......@@ -728,6 +772,7 @@ func (sv stringValues) Swap(i, j int) { sv[i], sv[j] = sv[j], sv[i] }
func (sv stringValues) Less(i, j int) bool { return sv.get(i) < sv.get(j) }
func (sv stringValues) get(i int) string { return sv[i].String() }
// NOTE: keep in sync with stringBytes below.
func (e *encodeState) string(s string) (int, error) {
len0 := e.Len()
e.WriteByte('"')
......@@ -800,6 +845,79 @@ func (e *encodeState) string(s string) (int, error) {
return e.Len() - len0, nil
}
// NOTE: keep in sync with string above.
func (e *encodeState) stringBytes(s []byte) (int, error) {
len0 := e.Len()
e.WriteByte('"')
start := 0
for i := 0; i < len(s); {
if b := s[i]; b < utf8.RuneSelf {
if 0x20 <= b && b != '\\' && b != '"' && b != '<' && b != '>' && b != '&' {
i++
continue
}
if start < i {
e.Write(s[start:i])
}
switch b {
case '\\', '"':
e.WriteByte('\\')
e.WriteByte(b)
case '\n':
e.WriteByte('\\')
e.WriteByte('n')
case '\r':
e.WriteByte('\\')
e.WriteByte('r')
default:
// This encodes bytes < 0x20 except for \n and \r,
// as well as < and >. The latter are escaped because they
// can lead to security holes when user-controlled strings
// are rendered into JSON and served to some browsers.
e.WriteString(`\u00`)
e.WriteByte(hex[b>>4])
e.WriteByte(hex[b&0xF])
}
i++
start = i
continue
}
c, size := utf8.DecodeRune(s[i:])
if c == utf8.RuneError && size == 1 {
if start < i {
e.Write(s[start:i])
}
e.WriteString(`\ufffd`)
i += size
start = i
continue
}
// U+2028 is LINE SEPARATOR.
// U+2029 is PARAGRAPH SEPARATOR.
// They are both technically valid characters in JSON strings,
// but don't work in JSONP, which has to be evaluated as JavaScript,
// and can lead to security holes there. It is valid JSON to
// escape them, so we do so unconditionally.
// See http://timelessrepo.com/json-isnt-a-javascript-subset for discussion.
if c == '\u2028' || c == '\u2029' {
if start < i {
e.Write(s[start:i])
}
e.WriteString(`\u202`)
e.WriteByte(hex[c&0xF])
i += size
start = i
continue
}
i += size
}
if start < len(s) {
e.Write(s[start:])
}
e.WriteByte('"')
return e.Len() - len0, nil
}
// A field represents a single field found in a struct.
type field struct {
name string
......
......@@ -9,6 +9,7 @@ import (
"math"
"reflect"
"testing"
"unicode"
)
type Optionals struct {
......@@ -146,19 +147,46 @@ func (Val) MarshalJSON() ([]byte, error) {
return []byte(`"val"`), nil
}
// RefText has Marshaler and Unmarshaler methods with pointer receiver.
type RefText int
func (*RefText) MarshalText() ([]byte, error) {
return []byte(`"ref"`), nil
}
func (r *RefText) UnmarshalText([]byte) error {
*r = 13
return nil
}
// ValText has Marshaler methods with value receiver.
type ValText int
func (ValText) MarshalText() ([]byte, error) {
return []byte(`"val"`), nil
}
func TestRefValMarshal(t *testing.T) {
var s = struct {
R0 Ref
R1 *Ref
R2 RefText
R3 *RefText
V0 Val
V1 *Val
V2 ValText
V3 *ValText
}{
R0: 12,
R1: new(Ref),
R2: 14,
R3: new(RefText),
V0: 13,
V1: new(Val),
V2: 15,
V3: new(ValText),
}
const want = `{"R0":"ref","R1":"ref","V0":"val","V1":"val"}`
const want = `{"R0":"ref","R1":"ref","R2":"\"ref\"","R3":"\"ref\"","V0":"val","V1":"val","V2":"\"val\"","V3":"\"val\""}`
b, err := Marshal(&s)
if err != nil {
t.Fatalf("Marshal: %v", err)
......@@ -175,15 +203,32 @@ func (C) MarshalJSON() ([]byte, error) {
return []byte(`"<&>"`), nil
}
// CText implements Marshaler and returns unescaped text.
type CText int
func (CText) MarshalText() ([]byte, error) {
return []byte(`"<&>"`), nil
}
func TestMarshalerEscaping(t *testing.T) {
var c C
const want = `"\u003c\u0026\u003e"`
want := `"\u003c\u0026\u003e"`
b, err := Marshal(c)
if err != nil {
t.Fatalf("Marshal: %v", err)
t.Fatalf("Marshal(c): %v", err)
}
if got := string(b); got != want {
t.Errorf("got %q, want %q", got, want)
t.Errorf("Marshal(c) = %#q, want %#q", got, want)
}
var ct CText
want = `"\"\u003c\u0026\u003e\""`
b, err = Marshal(ct)
if err != nil {
t.Fatalf("Marshal(ct): %v", err)
}
if got := string(b); got != want {
t.Errorf("Marshal(ct) = %#q, want %#q", got, want)
}
}
......@@ -310,3 +355,49 @@ func TestDuplicatedFieldDisappears(t *testing.T) {
t.Fatalf("Marshal: got %s want %s", got, want)
}
}
func TestStringBytes(t *testing.T) {
// Test that encodeState.stringBytes and encodeState.string use the same encoding.
es := &encodeState{}
var r []rune
for i := '\u0000'; i <= unicode.MaxRune; i++ {
r = append(r, i)
}
s := string(r) + "\xff\xff\xffhello" // some invalid UTF-8 too
_, err := es.string(s)
if err != nil {
t.Fatal(err)
}
esBytes := &encodeState{}
_, err = esBytes.stringBytes([]byte(s))
if err != nil {
t.Fatal(err)
}
enc := es.Buffer.String()
encBytes := esBytes.Buffer.String()
if enc != encBytes {
i := 0
for i < len(enc) && i < len(encBytes) && enc[i] == encBytes[i] {
i++
}
enc = enc[i:]
encBytes = encBytes[i:]
i = 0
for i < len(enc) && i < len(encBytes) && enc[len(enc)-i-1] == encBytes[len(encBytes)-i-1] {
i++
}
enc = enc[:len(enc)-i]
encBytes = encBytes[:len(encBytes)-i]
if len(enc) > 20 {
enc = enc[:20] + "..."
}
if len(encBytes) > 20 {
encBytes = encBytes[:20] + "..."
}
t.Errorf("encodings differ at %#q vs %#q", enc, encBytes)
}
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment