Commit ce3ba126 authored by Russ Cox's avatar Russ Cox

encoding/gob: support new generic interfaces in package encoding

R=r
CC=golang-dev
https://golang.org/cl/12681044
parent 48b90bbc
...@@ -415,6 +415,16 @@ func (deb *debugger) typeDefinition(indent tab, id typeId) { ...@@ -415,6 +415,16 @@ func (deb *debugger) typeDefinition(indent tab, id typeId) {
deb.delta(1) deb.delta(1)
com := deb.common() com := deb.common()
wire.GobEncoderT = &gobEncoderType{com} wire.GobEncoderT = &gobEncoderType{com}
case 5: // BinaryMarshaler type, one field of {{Common}}
// Field number 0 is CommonType
deb.delta(1)
com := deb.common()
wire.BinaryMarshalerT = &gobEncoderType{com}
case 6: // TextMarshaler type, one field of {{Common}}
// Field number 0 is CommonType
deb.delta(1)
com := deb.common()
wire.TextMarshalerT = &gobEncoderType{com}
default: default:
errorf("bad field in type %d", fieldNum) errorf("bad field in type %d", fieldNum)
} }
......
...@@ -9,6 +9,7 @@ package gob ...@@ -9,6 +9,7 @@ package gob
import ( import (
"bytes" "bytes"
"encoding"
"errors" "errors"
"io" "io"
"math" "math"
...@@ -767,15 +768,22 @@ func (dec *Decoder) ignoreInterface(state *decoderState) { ...@@ -767,15 +768,22 @@ func (dec *Decoder) ignoreInterface(state *decoderState) {
// decodeGobDecoder decodes something implementing the GobDecoder interface. // decodeGobDecoder decodes something implementing the GobDecoder interface.
// The data is encoded as a byte slice. // The data is encoded as a byte slice.
func (dec *Decoder) decodeGobDecoder(state *decoderState, v reflect.Value) { func (dec *Decoder) decodeGobDecoder(ut *userTypeInfo, state *decoderState, v reflect.Value) {
// Read the bytes for the value. // Read the bytes for the value.
b := make([]byte, state.decodeUint()) b := make([]byte, state.decodeUint())
_, err := state.b.Read(b) _, err := state.b.Read(b)
if err != nil { if err != nil {
error_(err) error_(err)
} }
// We know it's a GobDecoder, so just call the method directly. // We know it's one of these.
switch ut.externalDec {
case xGob:
err = v.Interface().(GobDecoder).GobDecode(b) err = v.Interface().(GobDecoder).GobDecode(b)
case xBinary:
err = v.Interface().(encoding.BinaryUnmarshaler).UnmarshalBinary(b)
case xText:
err = v.Interface().(encoding.TextUnmarshaler).UnmarshalText(b)
}
if err != nil { if err != nil {
error_(err) error_(err)
} }
...@@ -825,9 +833,10 @@ var decIgnoreOpMap = map[typeId]decOp{ ...@@ -825,9 +833,10 @@ var decIgnoreOpMap = map[typeId]decOp{
func (dec *Decoder) decOpFor(wireId typeId, rt reflect.Type, name string, inProgress map[reflect.Type]*decOp) (*decOp, int) { func (dec *Decoder) decOpFor(wireId typeId, rt reflect.Type, name string, inProgress map[reflect.Type]*decOp) (*decOp, int) {
ut := userType(rt) ut := userType(rt)
// If the type implements GobEncoder, we handle it without further processing. // If the type implements GobEncoder, we handle it without further processing.
if ut.isGobDecoder { if ut.externalDec != 0 {
return dec.gobDecodeOpFor(ut) return dec.gobDecodeOpFor(ut)
} }
// If this type is already in progress, it's a recursive type (e.g. map[string]*T). // If this type is already in progress, it's a recursive type (e.g. map[string]*T).
// Return the pointer to the op we're already building. // Return the pointer to the op we're already building.
if opPtr := inProgress[rt]; opPtr != nil { if opPtr := inProgress[rt]; opPtr != nil {
...@@ -954,7 +963,7 @@ func (dec *Decoder) decIgnoreOpFor(wireId typeId) decOp { ...@@ -954,7 +963,7 @@ func (dec *Decoder) decIgnoreOpFor(wireId typeId) decOp {
state.dec.ignoreStruct(*enginePtr) state.dec.ignoreStruct(*enginePtr)
} }
case wire.GobEncoderT != nil: case wire.GobEncoderT != nil, wire.BinaryMarshalerT != nil, wire.TextMarshalerT != nil:
op = func(i *decInstr, state *decoderState, p unsafe.Pointer) { op = func(i *decInstr, state *decoderState, p unsafe.Pointer) {
state.dec.ignoreGobDecoder(state) state.dec.ignoreGobDecoder(state)
} }
...@@ -993,7 +1002,7 @@ func (dec *Decoder) gobDecodeOpFor(ut *userTypeInfo) (*decOp, int) { ...@@ -993,7 +1002,7 @@ func (dec *Decoder) gobDecodeOpFor(ut *userTypeInfo) (*decOp, int) {
} else { } else {
v = reflect.NewAt(rcvrType, p).Elem() v = reflect.NewAt(rcvrType, p).Elem()
} }
state.dec.decodeGobDecoder(state, v) state.dec.decodeGobDecoder(ut, state, v)
} }
return &op, int(ut.indir) return &op, int(ut.indir)
...@@ -1010,12 +1019,18 @@ func (dec *Decoder) compatibleType(fr reflect.Type, fw typeId, inProgress map[re ...@@ -1010,12 +1019,18 @@ func (dec *Decoder) compatibleType(fr reflect.Type, fw typeId, inProgress map[re
inProgress[fr] = fw inProgress[fr] = fw
ut := userType(fr) ut := userType(fr)
wire, ok := dec.wireType[fw] wire, ok := dec.wireType[fw]
// If fr is a GobDecoder, the wire type must be GobEncoder. // If wire was encoded with an encoding method, fr must have that method.
// And if fr is not a GobDecoder, the wire type must not be either. // And if not, it must not.
if ut.isGobDecoder != (ok && wire.GobEncoderT != nil) { // the parentheses look odd but are correct. // At most one of the booleans in ut is set.
// We could possibly relax this constraint in the future in order to
// choose the decoding method using the data in the wireType.
// The parentheses look odd but are correct.
if (ut.externalDec == xGob) != (ok && wire.GobEncoderT != nil) ||
(ut.externalDec == xBinary) != (ok && wire.BinaryMarshalerT != nil) ||
(ut.externalDec == xText) != (ok && wire.TextMarshalerT != nil) {
return false return false
} }
if ut.isGobDecoder { // This test trumps all others. if ut.externalDec != 0 { // This test trumps all others.
return true return true
} }
switch t := ut.base; t.Kind() { switch t := ut.base; t.Kind() {
...@@ -1114,8 +1129,7 @@ func (dec *Decoder) compileIgnoreSingle(remoteId typeId) (engine *decEngine, err ...@@ -1114,8 +1129,7 @@ func (dec *Decoder) compileIgnoreSingle(remoteId typeId) (engine *decEngine, err
func (dec *Decoder) compileDec(remoteId typeId, ut *userTypeInfo) (engine *decEngine, err error) { func (dec *Decoder) compileDec(remoteId typeId, ut *userTypeInfo) (engine *decEngine, err error) {
rt := ut.base rt := ut.base
srt := rt srt := rt
if srt.Kind() != reflect.Struct || if srt.Kind() != reflect.Struct || ut.externalDec != 0 {
ut.isGobDecoder {
return dec.compileSingle(remoteId, ut) return dec.compileSingle(remoteId, ut)
} }
var wireStruct *structType var wireStruct *structType
...@@ -1223,7 +1237,7 @@ func (dec *Decoder) decodeValue(wireId typeId, val reflect.Value) { ...@@ -1223,7 +1237,7 @@ func (dec *Decoder) decodeValue(wireId typeId, val reflect.Value) {
return return
} }
engine := *enginePtr engine := *enginePtr
if st := base; st.Kind() == reflect.Struct && !ut.isGobDecoder { if st := base; st.Kind() == reflect.Struct && ut.externalDec == 0 {
if engine.numInstr == 0 && st.NumField() > 0 && len(dec.wireType[wireId].StructT.Field) > 0 { if engine.numInstr == 0 && st.NumField() > 0 && len(dec.wireType[wireId].StructT.Field) > 0 {
name := base.Name() name := base.Name()
errorf("type mismatch: no fields matched compiling decoder for %s", name) errorf("type mismatch: no fields matched compiling decoder for %s", name)
......
...@@ -67,19 +67,28 @@ point values may be received into any floating point variable. However, ...@@ -67,19 +67,28 @@ point values may be received into any floating point variable. However,
the destination variable must be able to represent the value or the decode the destination variable must be able to represent the value or the decode
operation will fail. operation will fail.
Structs, arrays and slices are also supported. Structs encode and Structs, arrays and slices are also supported. Structs encode and decode only
decode only exported fields. Strings and arrays of bytes are supported exported fields. Strings and arrays of bytes are supported with a special,
with a special, efficient representation (see below). When a slice efficient representation (see below). When a slice is decoded, if the existing
is decoded, if the existing slice has capacity the slice will be slice has capacity the slice will be extended in place; if not, a new array is
extended in place; if not, a new array is allocated. Regardless, allocated. Regardless, the length of the resulting slice reports the number of
the length of the resulting slice reports the number of elements elements decoded.
decoded.
Functions and channels cannot be sent in a gob. Attempting to encode a value
Functions and channels cannot be sent in a gob. Attempting that contains one will fail.
to encode a value that contains one will fail.
Gob can encode a value of any type implementing the GobEncoder,
The rest of this comment documents the encoding, details that are not important encoding.BinaryMarshaler, or encoding.TextMarshaler interfaces by calling the
for most users. Details are presented bottom-up. corresponding method, in that order of preference.
Gob can decode a value of any type implementing the GobDecoder,
encoding.BinaryUnmarshaler, or encoding.TextUnmarshaler interfaces by calling
the corresponding method, again in that order of preference.
Encoding Details
This section documents the encoding, details that are not important for most
users. Details are presented bottom-up.
An unsigned integer is sent one of two ways. If it is less than 128, it is sent An unsigned integer is sent one of two ways. If it is less than 128, it is sent
as a byte with that value. Otherwise it is sent as a minimal-length big-endian as a byte with that value. Otherwise it is sent as a minimal-length big-endian
......
...@@ -6,6 +6,7 @@ package gob ...@@ -6,6 +6,7 @@ package gob
import ( import (
"bytes" "bytes"
"encoding"
"math" "math"
"reflect" "reflect"
"unsafe" "unsafe"
...@@ -511,10 +512,20 @@ func isZero(val reflect.Value) bool { ...@@ -511,10 +512,20 @@ func isZero(val reflect.Value) bool {
// encGobEncoder encodes a value that implements the GobEncoder interface. // encGobEncoder encodes a value that implements the GobEncoder interface.
// The data is sent as a byte array. // The data is sent as a byte array.
func (enc *Encoder) encodeGobEncoder(b *bytes.Buffer, v reflect.Value) { func (enc *Encoder) encodeGobEncoder(b *bytes.Buffer, ut *userTypeInfo, v reflect.Value) {
// TODO: should we catch panics from the called method? // TODO: should we catch panics from the called method?
// We know it's a GobEncoder, so just call the method directly.
data, err := v.Interface().(GobEncoder).GobEncode() var data []byte
var err error
// We know it's one of these.
switch ut.externalEnc {
case xGob:
data, err = v.Interface().(GobEncoder).GobEncode()
case xBinary:
data, err = v.Interface().(encoding.BinaryMarshaler).MarshalBinary()
case xText:
data, err = v.Interface().(encoding.TextMarshaler).MarshalText()
}
if err != nil { if err != nil {
error_(err) error_(err)
} }
...@@ -550,7 +561,7 @@ var encOpTable = [...]encOp{ ...@@ -550,7 +561,7 @@ var encOpTable = [...]encOp{
func (enc *Encoder) encOpFor(rt reflect.Type, inProgress map[reflect.Type]*encOp) (*encOp, int) { func (enc *Encoder) encOpFor(rt reflect.Type, inProgress map[reflect.Type]*encOp) (*encOp, int) {
ut := userType(rt) ut := userType(rt)
// If the type implements GobEncoder, we handle it without further processing. // If the type implements GobEncoder, we handle it without further processing.
if ut.isGobEncoder { if ut.externalEnc != 0 {
return enc.gobEncodeOpFor(ut) return enc.gobEncodeOpFor(ut)
} }
// If this type is already in progress, it's a recursive type (e.g. map[string]*T). // If this type is already in progress, it's a recursive type (e.g. map[string]*T).
...@@ -661,7 +672,7 @@ func (enc *Encoder) gobEncodeOpFor(ut *userTypeInfo) (*encOp, int) { ...@@ -661,7 +672,7 @@ func (enc *Encoder) gobEncodeOpFor(ut *userTypeInfo) (*encOp, int) {
return return
} }
state.update(i) state.update(i)
state.enc.encodeGobEncoder(state.b, v) state.enc.encodeGobEncoder(state.b, ut, v)
} }
return &op, int(ut.encIndir) // encIndir: op will get called with p == address of receiver. return &op, int(ut.encIndir) // encIndir: op will get called with p == address of receiver.
} }
...@@ -672,10 +683,10 @@ func (enc *Encoder) compileEnc(ut *userTypeInfo) *encEngine { ...@@ -672,10 +683,10 @@ func (enc *Encoder) compileEnc(ut *userTypeInfo) *encEngine {
engine := new(encEngine) engine := new(encEngine)
seen := make(map[reflect.Type]*encOp) seen := make(map[reflect.Type]*encOp)
rt := ut.base rt := ut.base
if ut.isGobEncoder { if ut.externalEnc != 0 {
rt = ut.user rt = ut.user
} }
if !ut.isGobEncoder && if ut.externalEnc == 0 &&
srt.Kind() == reflect.Struct { srt.Kind() == reflect.Struct {
for fieldNum, wireFieldNum := 0, 0; fieldNum < srt.NumField(); fieldNum++ { for fieldNum, wireFieldNum := 0, 0; fieldNum < srt.NumField(); fieldNum++ {
f := srt.Field(fieldNum) f := srt.Field(fieldNum)
...@@ -736,13 +747,13 @@ func (enc *Encoder) encode(b *bytes.Buffer, value reflect.Value, ut *userTypeInf ...@@ -736,13 +747,13 @@ func (enc *Encoder) encode(b *bytes.Buffer, value reflect.Value, ut *userTypeInf
defer catchError(&enc.err) defer catchError(&enc.err)
engine := enc.lockAndGetEncEngine(ut) engine := enc.lockAndGetEncEngine(ut)
indir := ut.indir indir := ut.indir
if ut.isGobEncoder { if ut.externalEnc != 0 {
indir = int(ut.encIndir) indir = int(ut.encIndir)
} }
for i := 0; i < indir; i++ { for i := 0; i < indir; i++ {
value = reflect.Indirect(value) value = reflect.Indirect(value)
} }
if !ut.isGobEncoder && value.Type().Kind() == reflect.Struct { if ut.externalEnc == 0 && value.Type().Kind() == reflect.Struct {
enc.encodeStruct(b, engine, unsafeAddr(value)) enc.encodeStruct(b, engine, unsafeAddr(value))
} else { } else {
enc.encodeSingle(b, engine, unsafeAddr(value)) enc.encodeSingle(b, engine, unsafeAddr(value))
......
...@@ -135,7 +135,7 @@ func (enc *Encoder) sendActualType(w io.Writer, state *encoderState, ut *userTyp ...@@ -135,7 +135,7 @@ func (enc *Encoder) sendActualType(w io.Writer, state *encoderState, ut *userTyp
// sendType sends the type info to the other side, if necessary. // sendType sends the type info to the other side, if necessary.
func (enc *Encoder) sendType(w io.Writer, state *encoderState, origt reflect.Type) (sent bool) { func (enc *Encoder) sendType(w io.Writer, state *encoderState, origt reflect.Type) (sent bool) {
ut := userType(origt) ut := userType(origt)
if ut.isGobEncoder { if ut.externalEnc != 0 {
// The rules are different: regardless of the underlying type's representation, // The rules are different: regardless of the underlying type's representation,
// we need to tell the other side that the base type is a GobEncoder. // we need to tell the other side that the base type is a GobEncoder.
return enc.sendActualType(w, state, ut, ut.base) return enc.sendActualType(w, state, ut, ut.base)
...@@ -184,7 +184,7 @@ func (enc *Encoder) sendTypeDescriptor(w io.Writer, state *encoderState, ut *use ...@@ -184,7 +184,7 @@ func (enc *Encoder) sendTypeDescriptor(w io.Writer, state *encoderState, ut *use
// Make sure the type is known to the other side. // Make sure the type is known to the other side.
// First, have we already sent this type? // First, have we already sent this type?
rt := ut.base rt := ut.base
if ut.isGobEncoder { if ut.externalEnc != 0 {
rt = ut.user rt = ut.user
} }
if _, alreadySent := enc.sent[rt]; !alreadySent { if _, alreadySent := enc.sent[rt]; !alreadySent {
......
...@@ -34,6 +34,14 @@ type Gobber int ...@@ -34,6 +34,14 @@ type Gobber int
type ValueGobber string // encodes with a value, decodes with a pointer. type ValueGobber string // encodes with a value, decodes with a pointer.
type BinaryGobber int
type BinaryValueGobber string
type TextGobber int
type TextValueGobber string
// The relevant methods // The relevant methods
func (g *ByteStruct) GobEncode() ([]byte, error) { func (g *ByteStruct) GobEncode() ([]byte, error) {
...@@ -101,6 +109,24 @@ func (g *Gobber) GobDecode(data []byte) error { ...@@ -101,6 +109,24 @@ func (g *Gobber) GobDecode(data []byte) error {
return err return err
} }
func (g *BinaryGobber) MarshalBinary() ([]byte, error) {
return []byte(fmt.Sprintf("VALUE=%d", *g)), nil
}
func (g *BinaryGobber) UnmarshalBinary(data []byte) error {
_, err := fmt.Sscanf(string(data), "VALUE=%d", (*int)(g))
return err
}
func (g *TextGobber) MarshalText() ([]byte, error) {
return []byte(fmt.Sprintf("VALUE=%d", *g)), nil
}
func (g *TextGobber) UnmarshalText(data []byte) error {
_, err := fmt.Sscanf(string(data), "VALUE=%d", (*int)(g))
return err
}
func (v ValueGobber) GobEncode() ([]byte, error) { func (v ValueGobber) GobEncode() ([]byte, error) {
return []byte(fmt.Sprintf("VALUE=%s", v)), nil return []byte(fmt.Sprintf("VALUE=%s", v)), nil
} }
...@@ -110,6 +136,24 @@ func (v *ValueGobber) GobDecode(data []byte) error { ...@@ -110,6 +136,24 @@ func (v *ValueGobber) GobDecode(data []byte) error {
return err return err
} }
func (v BinaryValueGobber) MarshalBinary() ([]byte, error) {
return []byte(fmt.Sprintf("VALUE=%s", v)), nil
}
func (v *BinaryValueGobber) UnmarshalBinary(data []byte) error {
_, err := fmt.Sscanf(string(data), "VALUE=%s", (*string)(v))
return err
}
func (v TextValueGobber) MarshalText() ([]byte, error) {
return []byte(fmt.Sprintf("VALUE=%s", v)), nil
}
func (v *TextValueGobber) UnmarshalText(data []byte) error {
_, err := fmt.Sscanf(string(data), "VALUE=%s", (*string)(v))
return err
}
// Structs that include GobEncodable fields. // Structs that include GobEncodable fields.
type GobTest0 struct { type GobTest0 struct {
...@@ -130,28 +174,42 @@ type GobTest2 struct { ...@@ -130,28 +174,42 @@ type GobTest2 struct {
type GobTest3 struct { type GobTest3 struct {
X int // guarantee we have something in common with GobTest* X int // guarantee we have something in common with GobTest*
G *Gobber G *Gobber
B *BinaryGobber
T *TextGobber
} }
type GobTest4 struct { type GobTest4 struct {
X int // guarantee we have something in common with GobTest* X int // guarantee we have something in common with GobTest*
V ValueGobber V ValueGobber
BV BinaryValueGobber
TV TextValueGobber
} }
type GobTest5 struct { type GobTest5 struct {
X int // guarantee we have something in common with GobTest* X int // guarantee we have something in common with GobTest*
V *ValueGobber V *ValueGobber
BV *BinaryValueGobber
TV *TextValueGobber
} }
type GobTest6 struct { type GobTest6 struct {
X int // guarantee we have something in common with GobTest* X int // guarantee we have something in common with GobTest*
V ValueGobber V ValueGobber
W *ValueGobber W *ValueGobber
BV BinaryValueGobber
BW *BinaryValueGobber
TV TextValueGobber
TW *TextValueGobber
} }
type GobTest7 struct { type GobTest7 struct {
X int // guarantee we have something in common with GobTest* X int // guarantee we have something in common with GobTest*
V *ValueGobber V *ValueGobber
W ValueGobber W ValueGobber
BV *BinaryValueGobber
BW BinaryValueGobber
TV *TextValueGobber
TW TextValueGobber
} }
type GobTestIgnoreEncoder struct { type GobTestIgnoreEncoder struct {
...@@ -198,7 +256,9 @@ func TestGobEncoderField(t *testing.T) { ...@@ -198,7 +256,9 @@ func TestGobEncoderField(t *testing.T) {
// Now a field that's not a structure. // Now a field that's not a structure.
b.Reset() b.Reset()
gobber := Gobber(23) gobber := Gobber(23)
err = enc.Encode(GobTest3{17, &gobber}) bgobber := BinaryGobber(24)
tgobber := TextGobber(25)
err = enc.Encode(GobTest3{17, &gobber, &bgobber, &tgobber})
if err != nil { if err != nil {
t.Fatal("encode error:", err) t.Fatal("encode error:", err)
} }
...@@ -207,7 +267,7 @@ func TestGobEncoderField(t *testing.T) { ...@@ -207,7 +267,7 @@ func TestGobEncoderField(t *testing.T) {
if err != nil { if err != nil {
t.Fatal("decode error:", err) t.Fatal("decode error:", err)
} }
if *y.G != 23 { if *y.G != 23 || *y.B != 24 || *y.T != 25 {
t.Errorf("expected '23 got %d", *y.G) t.Errorf("expected '23 got %d", *y.G)
} }
} }
...@@ -357,7 +417,7 @@ func TestGobEncoderValueEncoder(t *testing.T) { ...@@ -357,7 +417,7 @@ func TestGobEncoderValueEncoder(t *testing.T) {
// first, string in field to byte in field // first, string in field to byte in field
b := new(bytes.Buffer) b := new(bytes.Buffer)
enc := NewEncoder(b) enc := NewEncoder(b)
err := enc.Encode(GobTest4{17, ValueGobber("hello")}) err := enc.Encode(GobTest4{17, ValueGobber("hello"), BinaryValueGobber("Καλημέρα"), TextValueGobber("こんにちは")})
if err != nil { if err != nil {
t.Fatal("encode error:", err) t.Fatal("encode error:", err)
} }
...@@ -367,7 +427,7 @@ func TestGobEncoderValueEncoder(t *testing.T) { ...@@ -367,7 +427,7 @@ func TestGobEncoderValueEncoder(t *testing.T) {
if err != nil { if err != nil {
t.Fatal("decode error:", err) t.Fatal("decode error:", err)
} }
if *x.V != "hello" { if *x.V != "hello" || *x.BV != "Καλημέρα" || *x.TV != "こんにちは" {
t.Errorf("expected `hello` got %s", x.V) t.Errorf("expected `hello` got %s", x.V)
} }
} }
...@@ -377,13 +437,17 @@ func TestGobEncoderValueEncoder(t *testing.T) { ...@@ -377,13 +437,17 @@ func TestGobEncoderValueEncoder(t *testing.T) {
func TestGobEncoderValueThenPointer(t *testing.T) { func TestGobEncoderValueThenPointer(t *testing.T) {
v := ValueGobber("forty-two") v := ValueGobber("forty-two")
w := ValueGobber("six-by-nine") w := ValueGobber("six-by-nine")
bv := BinaryValueGobber("1nanocentury")
bw := BinaryValueGobber("πseconds")
tv := TextValueGobber("gravitationalacceleration")
tw := TextValueGobber("π²ft/s²")
// this was a bug: encoding a GobEncoder by value before a GobEncoder // this was a bug: encoding a GobEncoder by value before a GobEncoder
// pointer would cause duplicate type definitions to be sent. // pointer would cause duplicate type definitions to be sent.
b := new(bytes.Buffer) b := new(bytes.Buffer)
enc := NewEncoder(b) enc := NewEncoder(b)
if err := enc.Encode(GobTest6{42, v, &w}); err != nil { if err := enc.Encode(GobTest6{42, v, &w, bv, &bw, tv, &tw}); err != nil {
t.Fatal("encode error:", err) t.Fatal("encode error:", err)
} }
dec := NewDecoder(b) dec := NewDecoder(b)
...@@ -391,6 +455,7 @@ func TestGobEncoderValueThenPointer(t *testing.T) { ...@@ -391,6 +455,7 @@ func TestGobEncoderValueThenPointer(t *testing.T) {
if err := dec.Decode(x); err != nil { if err := dec.Decode(x); err != nil {
t.Fatal("decode error:", err) t.Fatal("decode error:", err)
} }
if got, want := x.V, v; got != want { if got, want := x.V, v; got != want {
t.Errorf("v = %q, want %q", got, want) t.Errorf("v = %q, want %q", got, want)
} }
...@@ -399,6 +464,24 @@ func TestGobEncoderValueThenPointer(t *testing.T) { ...@@ -399,6 +464,24 @@ func TestGobEncoderValueThenPointer(t *testing.T) {
} else if *got != want { } else if *got != want {
t.Errorf("w = %q, want %q", *got, want) t.Errorf("w = %q, want %q", *got, want)
} }
if got, want := x.BV, bv; got != want {
t.Errorf("bv = %q, want %q", got, want)
}
if got, want := x.BW, bw; got == nil {
t.Errorf("bw = nil, want %q", want)
} else if *got != want {
t.Errorf("bw = %q, want %q", *got, want)
}
if got, want := x.TV, tv; got != want {
t.Errorf("tv = %q, want %q", got, want)
}
if got, want := x.TW, tw; got == nil {
t.Errorf("tw = nil, want %q", want)
} else if *got != want {
t.Errorf("tw = %q, want %q", *got, want)
}
} }
// Test that we can use a pointer then a value type of a GobEncoder // Test that we can use a pointer then a value type of a GobEncoder
...@@ -406,10 +489,14 @@ func TestGobEncoderValueThenPointer(t *testing.T) { ...@@ -406,10 +489,14 @@ func TestGobEncoderValueThenPointer(t *testing.T) {
func TestGobEncoderPointerThenValue(t *testing.T) { func TestGobEncoderPointerThenValue(t *testing.T) {
v := ValueGobber("forty-two") v := ValueGobber("forty-two")
w := ValueGobber("six-by-nine") w := ValueGobber("six-by-nine")
bv := BinaryValueGobber("1nanocentury")
bw := BinaryValueGobber("πseconds")
tv := TextValueGobber("gravitationalacceleration")
tw := TextValueGobber("π²ft/s²")
b := new(bytes.Buffer) b := new(bytes.Buffer)
enc := NewEncoder(b) enc := NewEncoder(b)
if err := enc.Encode(GobTest7{42, &v, w}); err != nil { if err := enc.Encode(GobTest7{42, &v, w, &bv, bw, &tv, tw}); err != nil {
t.Fatal("encode error:", err) t.Fatal("encode error:", err)
} }
dec := NewDecoder(b) dec := NewDecoder(b)
...@@ -417,14 +504,33 @@ func TestGobEncoderPointerThenValue(t *testing.T) { ...@@ -417,14 +504,33 @@ func TestGobEncoderPointerThenValue(t *testing.T) {
if err := dec.Decode(x); err != nil { if err := dec.Decode(x); err != nil {
t.Fatal("decode error:", err) t.Fatal("decode error:", err)
} }
if got, want := x.V, v; got == nil { if got, want := x.V, v; got == nil {
t.Errorf("v = nil, want %q", want) t.Errorf("v = nil, want %q", want)
} else if *got != want { } else if *got != want {
t.Errorf("v = %q, want %q", got, want) t.Errorf("v = %q, want %q", *got, want)
} }
if got, want := x.W, w; got != want { if got, want := x.W, w; got != want {
t.Errorf("w = %q, want %q", got, want) t.Errorf("w = %q, want %q", got, want)
} }
if got, want := x.BV, bv; got == nil {
t.Errorf("bv = nil, want %q", want)
} else if *got != want {
t.Errorf("bv = %q, want %q", *got, want)
}
if got, want := x.BW, bw; got != want {
t.Errorf("bw = %q, want %q", got, want)
}
if got, want := x.TV, tv; got == nil {
t.Errorf("tv = nil, want %q", want)
} else if *got != want {
t.Errorf("tv = %q, want %q", *got, want)
}
if got, want := x.TW, tw; got != want {
t.Errorf("tw = %q, want %q", got, want)
}
} }
func TestGobEncoderFieldTypeError(t *testing.T) { func TestGobEncoderFieldTypeError(t *testing.T) {
...@@ -521,7 +627,9 @@ func TestGobEncoderIgnoreNonStructField(t *testing.T) { ...@@ -521,7 +627,9 @@ func TestGobEncoderIgnoreNonStructField(t *testing.T) {
// First a field that's a structure. // First a field that's a structure.
enc := NewEncoder(b) enc := NewEncoder(b)
gobber := Gobber(23) gobber := Gobber(23)
err := enc.Encode(GobTest3{17, &gobber}) bgobber := BinaryGobber(24)
tgobber := TextGobber(25)
err := enc.Encode(GobTest3{17, &gobber, &bgobber, &tgobber})
if err != nil { if err != nil {
t.Fatal("encode error:", err) t.Fatal("encode error:", err)
} }
......
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
package gob package gob
import ( import (
"encoding"
"errors" "errors"
"fmt" "fmt"
"os" "os"
...@@ -21,12 +22,19 @@ type userTypeInfo struct { ...@@ -21,12 +22,19 @@ type userTypeInfo struct {
user reflect.Type // the type the user handed us user reflect.Type // the type the user handed us
base reflect.Type // the base type after all indirections base reflect.Type // the base type after all indirections
indir int // number of indirections to reach the base type indir int // number of indirections to reach the base type
isGobEncoder bool // does the type implement GobEncoder? externalEnc int // xGob, xBinary, or xText
isGobDecoder bool // does the type implement GobDecoder? externalDec int // xGob, xBinary or xText
encIndir int8 // number of indirections to reach the receiver type; may be negative encIndir int8 // number of indirections to reach the receiver type; may be negative
decIndir int8 // number of indirections to reach the receiver type; may be negative decIndir int8 // number of indirections to reach the receiver type; may be negative
} }
// externalEncoding bits
const (
xGob = 1 + iota // GobEncoder or GobDecoder
xBinary // encoding.BinaryMarshaler or encoding.BinaryUnmarshaler
xText // encoding.TextMarshaler or encoding.TextUnmarshaler
)
var ( var (
// Protected by an RWMutex because we read it a lot and write // Protected by an RWMutex because we read it a lot and write
// it only when we see a new type, typically when compiling. // it only when we see a new type, typically when compiling.
...@@ -75,8 +83,23 @@ func validUserType(rt reflect.Type) (ut *userTypeInfo, err error) { ...@@ -75,8 +83,23 @@ func validUserType(rt reflect.Type) (ut *userTypeInfo, err error) {
} }
ut.indir++ ut.indir++
} }
ut.isGobEncoder, ut.encIndir = implementsInterface(ut.user, gobEncoderInterfaceType)
ut.isGobDecoder, ut.decIndir = implementsInterface(ut.user, gobDecoderInterfaceType) if ok, indir := implementsInterface(ut.user, gobEncoderInterfaceType); ok {
ut.externalEnc, ut.encIndir = xGob, indir
} else if ok, indir := implementsInterface(ut.user, binaryMarshalerInterfaceType); ok {
ut.externalEnc, ut.encIndir = xBinary, indir
} else if ok, indir := implementsInterface(ut.user, textMarshalerInterfaceType); ok {
ut.externalEnc, ut.encIndir = xText, indir
}
if ok, indir := implementsInterface(ut.user, gobDecoderInterfaceType); ok {
ut.externalDec, ut.decIndir = xGob, indir
} else if ok, indir := implementsInterface(ut.user, binaryUnmarshalerInterfaceType); ok {
ut.externalDec, ut.decIndir = xBinary, indir
} else if ok, indir := implementsInterface(ut.user, textUnmarshalerInterfaceType); ok {
ut.externalDec, ut.decIndir = xText, indir
}
userTypeCache[rt] = ut userTypeCache[rt] = ut
return return
} }
...@@ -84,6 +107,10 @@ func validUserType(rt reflect.Type) (ut *userTypeInfo, err error) { ...@@ -84,6 +107,10 @@ func validUserType(rt reflect.Type) (ut *userTypeInfo, err error) {
var ( var (
gobEncoderInterfaceType = reflect.TypeOf((*GobEncoder)(nil)).Elem() gobEncoderInterfaceType = reflect.TypeOf((*GobEncoder)(nil)).Elem()
gobDecoderInterfaceType = reflect.TypeOf((*GobDecoder)(nil)).Elem() gobDecoderInterfaceType = reflect.TypeOf((*GobDecoder)(nil)).Elem()
binaryMarshalerInterfaceType = reflect.TypeOf((*encoding.BinaryMarshaler)(nil)).Elem()
binaryUnmarshalerInterfaceType = reflect.TypeOf((*encoding.BinaryUnmarshaler)(nil)).Elem()
textMarshalerInterfaceType = reflect.TypeOf((*encoding.TextMarshaler)(nil)).Elem()
textUnmarshalerInterfaceType = reflect.TypeOf((*encoding.TextUnmarshaler)(nil)).Elem()
) )
// implementsInterface reports whether the type implements the // implementsInterface reports whether the type implements the
...@@ -412,7 +439,7 @@ func newStructType(name string) *structType { ...@@ -412,7 +439,7 @@ func newStructType(name string) *structType {
// works through typeIds and userTypeInfos alone. // works through typeIds and userTypeInfos alone.
func newTypeObject(name string, ut *userTypeInfo, rt reflect.Type) (gobType, error) { func newTypeObject(name string, ut *userTypeInfo, rt reflect.Type) (gobType, error) {
// Does this type implement GobEncoder? // Does this type implement GobEncoder?
if ut.isGobEncoder { if ut.externalEnc != 0 {
return newGobEncoderType(name), nil return newGobEncoderType(name), nil
} }
var err error var err error
...@@ -598,6 +625,8 @@ type wireType struct { ...@@ -598,6 +625,8 @@ type wireType struct {
StructT *structType StructT *structType
MapT *mapType MapT *mapType
GobEncoderT *gobEncoderType GobEncoderT *gobEncoderType
BinaryMarshalerT *gobEncoderType
TextMarshalerT *gobEncoderType
} }
func (w *wireType) string() string { func (w *wireType) string() string {
...@@ -616,6 +645,10 @@ func (w *wireType) string() string { ...@@ -616,6 +645,10 @@ func (w *wireType) string() string {
return w.MapT.Name return w.MapT.Name
case w.GobEncoderT != nil: case w.GobEncoderT != nil:
return w.GobEncoderT.Name return w.GobEncoderT.Name
case w.BinaryMarshalerT != nil:
return w.BinaryMarshalerT.Name
case w.TextMarshalerT != nil:
return w.TextMarshalerT.Name
} }
return unknown return unknown
} }
...@@ -631,7 +664,7 @@ var typeInfoMap = make(map[reflect.Type]*typeInfo) // protected by typeLock ...@@ -631,7 +664,7 @@ var typeInfoMap = make(map[reflect.Type]*typeInfo) // protected by typeLock
// typeLock must be held. // typeLock must be held.
func getTypeInfo(ut *userTypeInfo) (*typeInfo, error) { func getTypeInfo(ut *userTypeInfo) (*typeInfo, error) {
rt := ut.base rt := ut.base
if ut.isGobEncoder { if ut.externalEnc != 0 {
// We want the user type, not the base type. // We want the user type, not the base type.
rt = ut.user rt = ut.user
} }
...@@ -646,12 +679,20 @@ func getTypeInfo(ut *userTypeInfo) (*typeInfo, error) { ...@@ -646,12 +679,20 @@ func getTypeInfo(ut *userTypeInfo) (*typeInfo, error) {
} }
info.id = gt.id() info.id = gt.id()
if ut.isGobEncoder { if ut.externalEnc != 0 {
userType, err := getType(rt.Name(), ut, rt) userType, err := getType(rt.Name(), ut, rt)
if err != nil { if err != nil {
return nil, err return nil, err
} }
info.wire = &wireType{GobEncoderT: userType.id().gobType().(*gobEncoderType)} gt := userType.id().gobType().(*gobEncoderType)
switch ut.externalEnc {
case xGob:
info.wire = &wireType{GobEncoderT: gt}
case xBinary:
info.wire = &wireType{BinaryMarshalerT: gt}
case xText:
info.wire = &wireType{TextMarshalerT: gt}
}
typeInfoMap[ut.user] = info typeInfoMap[ut.user] = info
return info, nil return info, nil
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment