Commit 5d906469 authored by Rob Pike's avatar Rob Pike

gob: allow exchange of interface values

The implementation describes each value as a string identifying the
concrete type of the value, followed by the usual encoding of that
value.  All types to be exchanged as contents of interface values
must be registered ahead of time with the new Register function.
Although this would not seem strictly necessary, the linker garbage
collects unused types so without some mechanism to guarantee
the type exists in the binary, there could be unpleasant surprises.
Moreover, the receiver needs a reflect.Type of the value to be
written in order to be able to save the data. A Register function
seems necessary.

The implementation may require defining types in the middle of
sending a value.  The old code never did this. Therefore there
has been some refactoring to make the encoder and decoder
work recursively.

This change changes the internal type IDs. Existing gob archives
will break with this change.  Apologies for that. If this is a deal
breaker it should be possible to create a conversion tool.

Error handling is too complicated in this code. A subsequent
change should clean it up.

R=rsc
CC=golang-dev
https://golang.org/cl/2618042
parent 3478891d
......@@ -40,8 +40,7 @@ var encodeT = []EncodeT{
// Test basic encode/decode routines for unsigned integers
func TestUintCodec(t *testing.T) {
b := new(bytes.Buffer)
encState := new(encoderState)
encState.b = b
encState := newEncoderState(b)
for _, tt := range encodeT {
b.Reset()
encodeUint(encState, tt.x)
......@@ -52,7 +51,7 @@ func TestUintCodec(t *testing.T) {
t.Errorf("encodeUint: %#x encode: expected % x got % x", tt.x, tt.b, b.Bytes())
}
}
decState := newDecodeState(b)
decState := newDecodeState(&b)
for u := uint64(0); ; u = (u + 1) * 7 {
b.Reset()
encodeUint(encState, u)
......@@ -74,13 +73,12 @@ func TestUintCodec(t *testing.T) {
func verifyInt(i int64, t *testing.T) {
var b = new(bytes.Buffer)
encState := new(encoderState)
encState.b = b
encState := newEncoderState(b)
encodeInt(encState, i)
if encState.err != nil {
t.Error("encodeInt:", i, encState.err)
}
decState := newDecodeState(b)
decState := newDecodeState(&b)
decState.buf = make([]byte, 8)
j := decodeInt(decState)
if decState.err != nil {
......@@ -119,8 +117,7 @@ var bytesResult = []byte{0x07, 0x05, 'h', 'e', 'l', 'l', 'o'}
func newencoderState(b *bytes.Buffer) *encoderState {
b.Reset()
state := new(encoderState)
state.b = b
state := newEncoderState(b)
state.fieldnum = -1
return state
}
......@@ -335,7 +332,8 @@ func execDec(typ string, instr *decInstr, state *decodeState, t *testing.T, p un
}
func newDecodeStateFromData(data []byte) *decodeState {
state := newDecodeState(bytes.NewBuffer(data))
b := bytes.NewBuffer(data)
state := newDecodeState(&b)
state.fieldnum = -1
return state
}
......@@ -1020,18 +1018,20 @@ func TestIgnoredFields(t *testing.T) {
}
type Bad0 struct {
inter interface{}
c float
ch chan int
c float
}
var nilEncoder *Encoder
func TestInvalidField(t *testing.T) {
var bad0 Bad0
bad0.inter = 17
bad0.ch = make(chan int)
b := new(bytes.Buffer)
err := encode(b, reflect.NewValue(&bad0))
err := nilEncoder.encode(b, reflect.NewValue(&bad0))
if err == nil {
t.Error("expected error; got none")
} else if strings.Index(err.String(), "interface") < 0 {
} else if strings.Index(err.String(), "type") < 0 {
t.Error("expected type error; got", err)
}
}
......@@ -1103,3 +1103,130 @@ func TestIndirectSliceMapArray(t *testing.T) {
t.Errorf("direct to indirect: ****i.m is %v not %v", ****i.m, d.m)
}
}
// Squarer is an interface with several implementations in this file
// (Int, Float and Vector). The tests store values of those types in
// Squarer-typed fields to exercise sending interface values.
type Squarer interface {
// Square reduces the value to a single int that the tests compare
// before and after a round trip.
Square() int
}
// Int is an integer implementation of Squarer.
type Int int

// Square returns the receiver multiplied by itself.
func (i Int) Square() int {
	n := int(i)
	return n * n
}
// Float is a floating-point implementation of Squarer.
type Float float
// Square returns f*f converted to int.
func (f Float) Square() int {
return int(f * f)
}
// Vector is a slice implementation of Squarer.
type Vector []int

// Square returns the sum of the squares of the elements; an empty or nil
// Vector yields 0.
func (v Vector) Square() int {
	total := 0
	for idx := 0; idx < len(v); idx++ {
		total += v[idx] * v[idx]
	}
	return total
}
// InterfaceItem is a struct with interfaces in it. It mixes plain fields
// with interface-typed fields so the tests can check that both kinds
// survive an encode/decode round trip.
type InterfaceItem struct {
i int // plain field before the interface fields
sq1, sq2, sq3 Squarer // individual interface values
f float // plain field after the interface fields
sq []Squarer // slice of interface values; the tests include a nil element
}
// NoInterfaceItem is the same struct as InterfaceItem without the
// interface fields. TestIgnoreInterface decodes an encoded InterfaceItem
// into it.
type NoInterfaceItem struct {
i int
f float
}
// TestInterface encodes a struct whose fields hold interface values of
// several concrete types (including a slice with a nil element), decodes
// it again, and verifies that every field arrives intact.
func TestInterface(t *testing.T) {
	iVal := Int(3)
	fVal := Float(5)
	// Sending a Vector will require that the receiver define a type in the middle of
	// receiving the value for item2.
	vVal := Vector{1, 2, 3}
	b := new(bytes.Buffer)
	item1 := &InterfaceItem{1, iVal, fVal, vVal, 11.5, []Squarer{iVal, fVal, nil, vVal}}
	// Register the types.
	Register(Int(0))
	Register(Float(0))
	Register(Vector{})
	err := NewEncoder(b).Encode(item1)
	if err != nil {
		t.Error("expected no encode error; got", err)
	}

	item2 := InterfaceItem{}
	err = NewDecoder(b).Decode(&item2)
	if err != nil {
		t.Fatal("decode:", err)
	}
	if item2.i != item1.i {
		t.Error("normal int did not decode correctly")
	}
	if item2.sq1 == nil || item2.sq1.Square() != iVal.Square() {
		t.Error("Int did not decode correctly")
	}
	if item2.sq2 == nil || item2.sq2.Square() != fVal.Square() {
		t.Error("Float did not decode correctly")
	}
	if item2.sq3 == nil || item2.sq3.Square() != vVal.Square() {
		t.Error("Vector did not decode correctly")
	}
	if item2.f != item1.f {
		t.Error("normal float did not decode correctly")
	}
	// Now check that we received a slice of Squarers correctly, including a nil element
	if len(item1.sq) != len(item2.sq) {
		t.Fatalf("[]Squarer length wrong: got %d; expected %d", len(item2.sq), len(item1.sq))
	}
	for i, v1 := range item1.sq {
		v2 := item2.sq[i]
		if v1 == nil || v2 == nil {
			// At least one side is nil: both must be, and there is
			// nothing further to compare for this element.
			if v1 != nil || v2 != nil {
				t.Errorf("item %d inconsistent nils", i)
			}
			continue
		}
		// Both sides are non-nil, so it is safe to call Square on each.
		if v1.Square() != v2.Square() {
			t.Errorf("item %d inconsistent values: %v %v", i, v1, v2)
		}
	}
}
// TestIgnoreInterface encodes a struct containing interface values and
// decodes it into a struct that lacks the interface fields, verifying
// that the remaining plain fields still arrive intact.
func TestIgnoreInterface(t *testing.T) {
	iVal := Int(3)
	fVal := Float(5)
	// Sending a Vector will require that the receiver define a type in the middle of
	// receiving the value for item2.
	vVal := Vector{1, 2, 3}
	b := new(bytes.Buffer)
	item1 := &InterfaceItem{1, iVal, fVal, vVal, 11.5, nil}
	// Register the types.
	Register(Int(0))
	Register(Float(0))
	Register(Vector{})
	err := NewEncoder(b).Encode(item1)
	if err != nil {
		t.Error("expected no encode error; got", err)
	}

	item2 := NoInterfaceItem{}
	err = NewDecoder(b).Decode(&item2)
	if err != nil {
		t.Fatal("decode:", err)
	}
	if item2.i != item1.i {
		t.Error("normal int did not decode correctly")
	}
	// Compare the received value with the one that was sent; comparing
	// item2.f with itself could never detect a decoding problem.
	if item2.f != item1.f {
		t.Error("normal float did not decode correctly")
	}
}
......@@ -22,15 +22,18 @@ var (
errRange = os.ErrorString("gob: internal error: field numbers out of bounds")
)
// The global execution state of an instance of the decoder.
// The execution state of an instance of the decoder. A new state
// is created for nested objects.
type decodeState struct {
b *bytes.Buffer
// The buffer is stored with an extra indirection because it may be replaced
// if we load a type during decode (when reading an interface value).
b **bytes.Buffer
err os.Error
fieldnum int // the last field number read.
buf []byte
}
func newDecodeState(b *bytes.Buffer) *decodeState {
func newDecodeState(b **bytes.Buffer) *decodeState {
d := new(decodeState)
d.b = b
d.buf = make([]byte, uint64Size)
......@@ -404,7 +407,7 @@ func allocate(rtyp reflect.Type, p uintptr, indir int) uintptr {
return *(*uintptr)(up)
}
func decodeSingle(engine *decEngine, rtyp reflect.Type, b *bytes.Buffer, p uintptr, indir int) os.Error {
func decodeSingle(engine *decEngine, rtyp reflect.Type, b **bytes.Buffer, p uintptr, indir int) os.Error {
p = allocate(rtyp, p, indir)
state := newDecodeState(b)
state.fieldnum = singletonField
......@@ -423,12 +426,12 @@ func decodeSingle(engine *decEngine, rtyp reflect.Type, b *bytes.Buffer, p uintp
return state.err
}
func decodeStruct(engine *decEngine, rtyp *reflect.StructType, b *bytes.Buffer, p uintptr, indir int) os.Error {
func (dec *Decoder) decodeStruct(engine *decEngine, rtyp *reflect.StructType, b **bytes.Buffer, p uintptr, indir int) os.Error {
p = allocate(rtyp, p, indir)
state := newDecodeState(b)
state.fieldnum = -1
basep := p
for state.err == nil {
for state.b.Len() > 0 && state.err == nil {
delta := int(decodeUint(state))
if delta < 0 {
state.err = os.ErrorString("gob decode: corrupted data: negative delta")
......@@ -453,10 +456,10 @@ func decodeStruct(engine *decEngine, rtyp *reflect.StructType, b *bytes.Buffer,
return state.err
}
func ignoreStruct(engine *decEngine, b *bytes.Buffer) os.Error {
func ignoreStruct(engine *decEngine, b **bytes.Buffer) os.Error {
state := newDecodeState(b)
state.fieldnum = -1
for state.err == nil {
for state.b.Len() > 0 && state.err == nil {
delta := int(decodeUint(state))
if delta < 0 {
state.err = os.ErrorString("gob ignore decode: corrupted data: negative delta")
......@@ -564,7 +567,6 @@ func ignoreMap(state *decodeState, keyOp, elemOp decOp) os.Error {
return state.err
}
func decodeSlice(atyp *reflect.SliceType, state *decodeState, p uintptr, elemOp decOp, elemWid uintptr, indir, elemIndir int, ovfl os.ErrorString) os.Error {
n := int(uintptr(decodeUint(state)))
if indir > 0 {
......@@ -588,6 +590,78 @@ func ignoreSlice(state *decodeState, elemOp decOp) os.Error {
return ignoreArrayHelper(state, elemOp, int(decodeUint(state)))
}
// setInterfaceValue stores value into the interface value ivalue using
// reflection. The reflective Set panics when the concrete value does not
// implement the interface; rather than checking satisfaction by hand, the
// panic is intercepted and handed back to the caller as the error result.
// TODO(rsc): avoid panic+recover after fixing issue 327.
func setInterfaceValue(ivalue *reflect.InterfaceValue, value reflect.Value) (err os.Error) {
	// recoverAsError is the deferred function itself, so its call to
	// recover stops a panic raised by Set below.
	recoverAsError := func() {
		panicked := recover()
		if panicked == nil {
			return
		}
		err = panicked.(os.Error)
	}
	defer recoverAsError()

	ivalue.Set(value)
	return nil
}
// decodeInterface receives the name of a concrete type followed by its value.
// If the name is empty, the value is nil and no value is sent.
//
// ityp is the static interface type of the destination, state is the decode
// state of the enclosing value (the name is read from its buffer), p is the
// address to store into and indir is the destination's indirection count.
// Any error is recorded in state.err and returned.
func (dec *Decoder) decodeInterface(ityp *reflect.InterfaceType, state *decodeState, p uintptr, indir int) os.Error {
// Create an interface reflect.Value. We need one even for the nil case.
ivalue := reflect.MakeZero(ityp).(*reflect.InterfaceValue)
// Read the name of the concrete type: a length followed by that many bytes.
b := make([]byte, decodeUint(state))
// NOTE(review): the result of Read is not checked here.
state.b.Read(b)
name := string(b)
if name == "" {
// Copy the representation of the nil interface value to the target.
// This is horribly unsafe and special.
// NOTE(review): unlike the non-nil path below, this stores through p
// without the indir > 0 allocation step - confirm that is intended.
*(*[2]uintptr)(unsafe.Pointer(p)) = ivalue.Get()
return state.err
}
// The concrete type must be registered.
typ, ok := nameToConcreteType[name]
if !ok {
state.err = os.ErrorString("gob: name not registered for interface: " + name)
return state.err
}
// Read the concrete value. This goes through the Decoder rather than the
// local state, so any failure shows up in dec.state.err and is copied
// into state.err for the caller.
value := reflect.MakeZero(typ)
dec.decodeValueFromBuffer(value, false)
if dec.state.err != nil {
state.err = dec.state.err
return state.err
}
// Allocate the destination interface value.
if indir > 0 {
p = allocate(ityp, p, 1) // All but the last level has been allocated by dec.Indirect
}
// Assign the concrete value to the interface.
// Tread carefully; it might not satisfy the interface.
dec.state.err = setInterfaceValue(ivalue, value)
if dec.state.err != nil {
state.err = dec.state.err
return state.err
}
// Copy the representation of the interface value to the target.
// This is horribly unsafe and special.
*(*[2]uintptr)(unsafe.Pointer(p)) = ivalue.Get()
return nil
}
// ignoreInterface skips over an interface value that the receiver has no
// field for. It consumes the concrete type name from state's buffer and,
// when a value follows, has the Decoder process it in ignore mode so that
// any type definitions embedded in it are still recorded.
// It returns the first error encountered, if any.
func (dec *Decoder) ignoreInterface(state *decodeState) os.Error {
	// Read the name of the concrete type.
	b := make([]byte, decodeUint(state))
	if len(b) == 0 {
		// An empty name denotes a nil interface value; no value follows
		// it, so there is nothing more to skip.
		return state.err
	}
	if _, err := state.b.Read(b); err != nil {
		// The name could not be read; report that rather than trying to
		// interpret whatever follows.
		return err
	}
	// The name was read successfully, so the value follows. Skip it.
	dec.decodeValueFromBuffer(nil, true)
	return dec.state.err
}
// Index by Go types.
var decOpMap = []decOp{
reflect.Bool: decBool,
......@@ -608,12 +682,13 @@ var decOpMap = []decOp{
// Indexed by gob types. tComplex will be added during type.init().
var decIgnoreOpMap = map[typeId]decOp{
tBool: ignoreUint,
tInt: ignoreUint,
tUint: ignoreUint,
tFloat: ignoreUint,
tBytes: ignoreUint8Array,
tString: ignoreUint8Array,
tBool: ignoreUint,
tInt: ignoreUint,
tUint: ignoreUint,
tFloat: ignoreUint,
tBytes: ignoreUint8Array,
tString: ignoreUint8Array,
tComplex: ignoreTwoUints,
}
// Return the decoding op for the base type under rt and
......@@ -687,7 +762,11 @@ func (dec *Decoder) decOpFor(wireId typeId, rt reflect.Type, name string) (decOp
}
op = func(i *decInstr, state *decodeState, p unsafe.Pointer) {
// indirect through enginePtr to delay evaluation for recursive structs
state.err = decodeStruct(*enginePtr, t, state.b, uintptr(p), i.indir)
state.err = dec.decodeStruct(*enginePtr, t, state.b, uintptr(p), i.indir)
}
case *reflect.InterfaceType:
op = func(i *decInstr, state *decodeState, p unsafe.Pointer) {
state.err = dec.decodeInterface(t, state, uintptr(p), i.indir)
}
}
}
......@@ -701,6 +780,14 @@ func (dec *Decoder) decOpFor(wireId typeId, rt reflect.Type, name string) (decOp
func (dec *Decoder) decIgnoreOpFor(wireId typeId) (decOp, os.Error) {
op, ok := decIgnoreOpMap[wireId]
if !ok {
if wireId == tInterface {
// Special case because it's a method: the ignored item might
// define types and we need to record their state in the decoder.
op = func(i *decInstr, state *decodeState, p unsafe.Pointer) {
state.err = dec.ignoreInterface(state)
}
return op, nil
}
// Special cases
wire := dec.wireType[wireId]
switch {
......@@ -763,16 +850,10 @@ func (dec *Decoder) decIgnoreOpFor(wireId typeId) (decOp, os.Error) {
// Answers the question for basic types, arrays, and slices.
// Structs are considered ok; fields will be checked later.
func (dec *Decoder) compatibleType(fr reflect.Type, fw typeId) bool {
for {
if pt, ok := fr.(*reflect.PtrType); ok {
fr = pt.Elem()
continue
}
break
}
fr, _ = indirect(fr)
switch t := fr.(type) {
default:
// interface, map, chan, etc: cannot handle.
// map, chan, etc: cannot handle.
return false
case *reflect.BoolType:
return fw == tBool
......@@ -786,6 +867,8 @@ func (dec *Decoder) compatibleType(fr reflect.Type, fw typeId) bool {
return fw == tComplex
case *reflect.StringType:
return fw == tString
case *reflect.InterfaceType:
return fw == tInterface
case *reflect.ArrayType:
wire, ok := dec.wireType[fw]
if !ok || wire.arrayT == nil {
......@@ -936,7 +1019,7 @@ func (dec *Decoder) decode(wireId typeId, val reflect.Value) os.Error {
name := rt.Name()
return os.ErrorString("gob: type mismatch: no fields matched compiling decoder for " + name)
}
return decodeStruct(engine, st, dec.state.b, uintptr(val.Addr()), indir)
return dec.decodeStruct(engine, st, dec.state.b, uintptr(val.Addr()), indir)
}
return decodeSingle(engine, rt, dec.state.b, uintptr(val.Addr()), indir)
}
......
......@@ -24,6 +24,7 @@ type Decoder struct {
countState *decodeState // reads counts from wire
buf []byte
countBuf [9]byte // counts may be uint64s (unlikely!), require 9 bytes
byteBuffer *bytes.Buffer
}
// NewDecoder returns a new decoder that reads from the io.Reader.
......@@ -31,13 +32,14 @@ func NewDecoder(r io.Reader) *Decoder {
dec := new(Decoder)
dec.r = r
dec.wireType = make(map[typeId]*wireType)
dec.state = newDecodeState(nil) // buffer set in Decode(); rest is unimportant
dec.state = newDecodeState(&dec.byteBuffer) // buffer set in Decode()
dec.decoderCache = make(map[reflect.Type]map[typeId]**decEngine)
dec.ignorerCache = make(map[typeId]**decEngine)
return dec
}
// recvType loads the definition of a type and reloads the Decoder's buffer.
func (dec *Decoder) recvType(id typeId) {
// Have we already seen this type? That's an error
if dec.wireType[id] != nil {
......@@ -50,6 +52,9 @@ func (dec *Decoder) recvType(id typeId) {
dec.state.err = dec.decode(tWireType, reflect.NewValue(wire))
// Remember we've seen this type.
dec.wireType[id] = wire
// Load the next parcel.
dec.recv()
}
// Decode reads the next value from the connection and stores
......@@ -67,38 +72,36 @@ func (dec *Decoder) Decode(e interface{}) os.Error {
return dec.DecodeValue(value)
}
// DecodeValue reads the next value from the connection and stores
// it in the data represented by the reflection value.
// The value must be the correct type for the next
// data item received.
func (dec *Decoder) DecodeValue(value reflect.Value) os.Error {
// Make sure we're single-threaded through here.
dec.mutex.Lock()
defer dec.mutex.Unlock()
dec.state.err = nil
for {
// Read a count.
var nbytes uint64
nbytes, dec.state.err = decodeUintReader(dec.r, dec.countBuf[0:])
if dec.state.err != nil {
break
}
// Allocate the buffer.
if nbytes > uint64(len(dec.buf)) {
dec.buf = make([]byte, nbytes+1000)
}
dec.state.b = bytes.NewBuffer(dec.buf[0:nbytes])
// recv reads the next count-delimited item from the input. It is the converse
// of Encoder.send.
func (dec *Decoder) recv() {
// Read a count.
var nbytes uint64
nbytes, dec.state.err = decodeUintReader(dec.r, dec.countBuf[0:])
if dec.state.err != nil {
return
}
// Allocate the buffer.
if nbytes > uint64(len(dec.buf)) {
dec.buf = make([]byte, nbytes+1000)
}
dec.byteBuffer = bytes.NewBuffer(dec.buf[0:nbytes])
// Read the data
_, dec.state.err = io.ReadFull(dec.r, dec.buf[0:nbytes])
if dec.state.err != nil {
if dec.state.err == os.EOF {
dec.state.err = io.ErrUnexpectedEOF
}
break
// Read the data
_, dec.state.err = io.ReadFull(dec.r, dec.buf[0:nbytes])
if dec.state.err != nil {
if dec.state.err == os.EOF {
dec.state.err = io.ErrUnexpectedEOF
}
return
}
}
// decodeValueFromBuffer grabs the next value from the input. The Decoder's
// buffer already contains data. If the next item in the buffer is a type
// descriptor, it may be necessary to reload the buffer, but recvType does that.
func (dec *Decoder) decodeValueFromBuffer(value reflect.Value, ignore bool) {
for dec.state.b.Len() > 0 {
// Receive a type id.
id := typeId(decodeInt(dec.state))
if dec.state.err != nil {
......@@ -116,6 +119,10 @@ func (dec *Decoder) DecodeValue(value reflect.Value) os.Error {
}
// No, it's a value.
if ignore {
dec.byteBuffer.Reset()
break
}
// Make sure the type has been defined already or is a builtin type (for
// top-level singleton values).
if dec.wireType[id] == nil && builtinIdToType[id] == nil {
......@@ -125,5 +132,22 @@ func (dec *Decoder) DecodeValue(value reflect.Value) os.Error {
dec.state.err = dec.decode(id, value)
break
}
}
// DecodeValue reads the next value from the connection and stores
// it in the data represented by the reflection value.
// The value must be the correct type for the next
// data item received.
func (dec *Decoder) DecodeValue(value reflect.Value) os.Error {
	// Hold the mutex for the whole call so only one goroutine at a time
	// runs through here.
	dec.mutex.Lock()
	defer dec.mutex.Unlock()

	dec.state.err = nil
	// Load the next item; decode from it only if loading succeeded.
	if dec.recv(); dec.state.err == nil {
		dec.decodeValueFromBuffer(value, false)
	}
	return dec.state.err
}
......@@ -126,6 +126,16 @@ integer field 0 with value 7 is transmitted as unsigned delta = 1, unsigned valu
denotes the end of the struct. That mark is a delta=0 value, which has
representation (00).
Interface types are not checked for compatibility; all interface types are
treated, for transmission, as members of a single "interface" type, analogous to
int or []byte - in effect they're all treated as interface{}. Interface values
are transmitted as a string identifying the concrete type being sent (a name
that must be pre-defined by calling Register()), followed by the usual encoding
of the concrete (dynamic) value stored in the interface value. (A nil interface
value is identified by the empty string and transmits no value.) Upon receipt,
the decoder verifies that the unpacked concrete item satisfies the interface of
the receiving variable.
The representation of types is described below. When a type is defined on a given
connection between an Encoder and Decoder, it is assigned a signed integer type
id. When Encoder.Encode(v) is called, it makes sure there is an id assigned for
......@@ -140,18 +150,32 @@ description, constructed from these types:
type wireType struct {
s structType;
}
type fieldType struct {
name string; // the name of the field.
id int; // the type id of the field, which must be already defined
type arrayType struct {
commonType
Elem typeId
Len int
}
type commonType struct {
name string; // the name of the struct type
id int; // the id of the type, repeated so it's inside the type
}
type sliceType struct {
commonType
Elem typeId
}
type structType struct {
commonType;
field []fieldType; // the fields of the struct.
}
type fieldType struct {
name string; // the name of the field.
id int; // the type id of the field, which must be already defined
}
type mapType struct {
commonType
Key typeId
Elem typeId
}
If there are nested type ids, the types for all inner type ids must be defined
before the top-level type id is used to describe an encoded-v.
......@@ -159,16 +183,23 @@ before the top-level type id is used to describe an encoded-v.
For simplicity in setup, the connection is defined to understand these types a
priori, as well as the basic gob types int, uint, etc. Their ids are:
bool 1
int 2
uint 3
float 4
[]byte 5
string 6
wireType 7
structType 8
commonType 9
fieldType 10
bool 1
int 2
uint 3
float 4
[]byte 5
string 6
complex 7
interface 8
// gap for reserved ids.
wireType 16
arrayType 17
commonType 18
sliceType 19
structType 20
fieldType 21
// 22 is slice of fieldType.
mapType 23
In summary, a gob stream looks like
......
......@@ -27,6 +27,10 @@ type encoderState struct {
buf [1 + uint64Size]byte // buffer used by the encoder; here to avoid allocation.
}
// newEncoderState returns a fresh encoderState that writes to b; all
// other fields are left at their zero values.
func newEncoderState(b *bytes.Buffer) *encoderState {
	state := new(encoderState)
	state.b = b
	return state
}
// Unsigned integers have a two-state encoding. If the number is less
// than 128 (0 through 0x7F), its value is written directly.
// Otherwise the value is written in big-endian byte order preceded
......@@ -314,8 +318,7 @@ type encEngine struct {
const singletonField = 0
func encodeSingle(engine *encEngine, b *bytes.Buffer, basep uintptr) os.Error {
state := new(encoderState)
state.b = b
state := newEncoderState(b)
state.fieldnum = singletonField
// There is no surrounding struct to frame the transmission, so we must
// generate data even if the item is zero. To do this, set sendZero.
......@@ -332,8 +335,7 @@ func encodeSingle(engine *encEngine, b *bytes.Buffer, basep uintptr) os.Error {
}
func encodeStruct(engine *encEngine, b *bytes.Buffer, basep uintptr) os.Error {
state := new(encoderState)
state.b = b
state := newEncoderState(b)
state.fieldnum = -1
for i := 0; i < len(engine.instr); i++ {
instr := &engine.instr[i]
......@@ -352,8 +354,7 @@ func encodeStruct(engine *encEngine, b *bytes.Buffer, basep uintptr) os.Error {
}
func encodeArray(b *bytes.Buffer, p uintptr, op encOp, elemWid uintptr, elemIndir int, length int) os.Error {
state := new(encoderState)
state.b = b
state := newEncoderState(b)
state.fieldnum = -1
state.sendZero = true
encodeUint(state, uint64(length))
......@@ -385,8 +386,7 @@ func encodeReflectValue(state *encoderState, v reflect.Value, op encOp, indir in
}
func encodeMap(b *bytes.Buffer, mv *reflect.MapValue, keyOp, elemOp encOp, keyIndir, elemIndir int) os.Error {
state := new(encoderState)
state.b = b
state := newEncoderState(b)
state.fieldnum = -1
state.sendZero = true
keys := mv.Keys()
......@@ -401,6 +401,41 @@ func encodeMap(b *bytes.Buffer, mv *reflect.MapValue, keyOp, elemOp encOp, keyIn
return state.err
}
// encodeInterface writes the interface value iv to b.
//
// To send an interface, we send a string identifying the concrete type, followed
// by the type identifier (which might require defining that type right now), followed
// by the concrete value. A nil value gets sent as the empty string for the name,
// followed by no value.
//
// The concrete type must have been registered; otherwise an error is
// returned. Any error is also recorded in the local encoder state.
func (enc *Encoder) encodeInterface(b *bytes.Buffer, iv *reflect.InterfaceValue) os.Error {
	state := newEncoderState(b)
	state.fieldnum = -1
	state.sendZero = true
	if iv.IsNil() {
		// Nil interface: a zero-length name and nothing else.
		encodeUint(state, 0)
		return state.err
	}

	typ := iv.Elem().Type()
	name, ok := concreteTypeToName[typ]
	if !ok {
		state.err = os.ErrorString("gob: type not registered for interface: " + typ.String())
		return state.err
	}
	// Send the name.
	encodeUint(state, uint64(len(name)))
	_, state.err = io.WriteString(state.b, name)
	if state.err != nil {
		return state.err
	}
	// Send (and maybe first define) the type id.
	// sendTypeDescriptor reports failures through enc.state.err, not the
	// local state, so that is the field to inspect here.
	enc.sendTypeDescriptor(typ)
	if enc.state.err != nil {
		state.err = enc.state.err
		return state.err
	}
	// Send the value.
	state.err = enc.encode(state.b, iv.Elem())
	return state.err
}
var encOpMap = []encOp{
reflect.Bool: encBool,
reflect.Int: encInt,
......@@ -425,7 +460,7 @@ var encOpMap = []encOp{
// Return the encoding op for the base type under rt and
// the indirection count to reach it.
func encOpFor(rt reflect.Type) (encOp, int, os.Error) {
func (enc *Encoder) encOpFor(rt reflect.Type) (encOp, int, os.Error) {
typ, indir := indirect(rt)
var op encOp
k := typ.Kind()
......@@ -441,7 +476,7 @@ func encOpFor(rt reflect.Type) (encOp, int, os.Error) {
break
}
// Slices have a header; we decode it to find the underlying array.
elemOp, indir, err := encOpFor(t.Elem())
elemOp, indir, err := enc.encOpFor(t.Elem())
if err != nil {
return nil, 0, err
}
......@@ -455,7 +490,7 @@ func encOpFor(rt reflect.Type) (encOp, int, os.Error) {
}
case *reflect.ArrayType:
// True arrays have size in the type.
elemOp, indir, err := encOpFor(t.Elem())
elemOp, indir, err := enc.encOpFor(t.Elem())
if err != nil {
return nil, 0, err
}
......@@ -464,11 +499,11 @@ func encOpFor(rt reflect.Type) (encOp, int, os.Error) {
state.err = encodeArray(state.b, uintptr(p), elemOp, t.Elem().Size(), indir, t.Len())
}
case *reflect.MapType:
keyOp, keyIndir, err := encOpFor(t.Key())
keyOp, keyIndir, err := enc.encOpFor(t.Key())
if err != nil {
return nil, 0, err
}
elemOp, elemIndir, err := encOpFor(t.Elem())
elemOp, elemIndir, err := enc.encOpFor(t.Elem())
if err != nil {
return nil, 0, err
}
......@@ -486,7 +521,7 @@ func encOpFor(rt reflect.Type) (encOp, int, os.Error) {
}
case *reflect.StructType:
// Generate a closure that calls out to the engine for the nested type.
_, err := getEncEngine(typ)
_, err := enc.getEncEngine(typ)
if err != nil {
return nil, 0, err
}
......@@ -496,6 +531,18 @@ func encOpFor(rt reflect.Type) (encOp, int, os.Error) {
// indirect through info to delay evaluation for recursive structs
state.err = encodeStruct(info.encoder, state.b, uintptr(p))
}
case *reflect.InterfaceType:
op = func(i *encInstr, state *encoderState, p unsafe.Pointer) {
// Interfaces transmit the name and contents of the concrete
// value they contain.
v := reflect.NewValue(unsafe.Unreflect(t, unsafe.Pointer((p))))
iv := reflect.Indirect(v).(*reflect.InterfaceValue)
if !state.sendZero && (iv == nil || iv.IsNil()) {
return
}
state.update(i)
state.err = enc.encodeInterface(state.b, iv)
}
}
}
if op == nil {
......@@ -505,14 +552,14 @@ func encOpFor(rt reflect.Type) (encOp, int, os.Error) {
}
// The local Type was compiled from the actual value, so we know it's compatible.
func compileEnc(rt reflect.Type) (*encEngine, os.Error) {
func (enc *Encoder) compileEnc(rt reflect.Type) (*encEngine, os.Error) {
srt, isStruct := rt.(*reflect.StructType)
engine := new(encEngine)
if isStruct {
engine.instr = make([]encInstr, srt.NumField()+1) // +1 for terminator
for fieldnum := 0; fieldnum < srt.NumField(); fieldnum++ {
f := srt.Field(fieldnum)
op, indir, err := encOpFor(f.Type)
op, indir, err := enc.encOpFor(f.Type)
if err != nil {
return nil, err
}
......@@ -521,7 +568,7 @@ func compileEnc(rt reflect.Type) (*encEngine, os.Error) {
engine.instr[srt.NumField()] = encInstr{encStructTerminator, 0, 0, 0}
} else {
engine.instr = make([]encInstr, 1)
op, indir, err := encOpFor(rt)
op, indir, err := enc.encOpFor(rt)
if err != nil {
return nil, err
}
......@@ -532,7 +579,7 @@ func compileEnc(rt reflect.Type) (*encEngine, os.Error) {
// typeLock must be held (or we're in initialization and guaranteed single-threaded).
// The reflection type must have all its indirections processed out.
func getEncEngine(rt reflect.Type) (*encEngine, os.Error) {
func (enc *Encoder) getEncEngine(rt reflect.Type) (*encEngine, os.Error) {
info, err := getTypeInfo(rt)
if err != nil {
return nil, err
......@@ -540,19 +587,19 @@ func getEncEngine(rt reflect.Type) (*encEngine, os.Error) {
if info.encoder == nil {
// mark this engine as underway before compiling to handle recursive types.
info.encoder = new(encEngine)
info.encoder, err = compileEnc(rt)
info.encoder, err = enc.compileEnc(rt)
}
return info.encoder, err
}
func encode(b *bytes.Buffer, value reflect.Value) os.Error {
func (enc *Encoder) encode(b *bytes.Buffer, value reflect.Value) os.Error {
// Dereference down to the underlying object.
rt, indir := indirect(value.Type())
for i := 0; i < indir; i++ {
value = reflect.Indirect(value)
}
typeLock.Lock()
engine, err := getEncEngine(rt)
engine, err := enc.getEncEngine(rt)
typeLock.Unlock()
if err != nil {
return err
......
......@@ -28,10 +28,8 @@ func NewEncoder(w io.Writer) *Encoder {
enc := new(Encoder)
enc.w = w
enc.sent = make(map[reflect.Type]typeId)
enc.state = new(encoderState)
enc.state.b = new(bytes.Buffer) // the rest isn't important; all we need is buffer and writer
enc.countState = new(encoderState)
enc.countState.b = new(bytes.Buffer) // the rest isn't important; all we need is buffer and writer
enc.state = newEncoderState(new(bytes.Buffer))
enc.countState = newEncoderState(new(bytes.Buffer))
return enc
}
......@@ -74,7 +72,7 @@ func (enc *Encoder) sendType(origt reflect.Type) (sent bool) {
switch rt := rt.(type) {
default:
// Basic types do not need to be described.
// Basic types and interfaces do not need to be described.
return
case *reflect.SliceType:
// If it's []uint8, don't send; it's considered basic.
......@@ -92,7 +90,7 @@ func (enc *Encoder) sendType(origt reflect.Type) (sent bool) {
case *reflect.StructType:
// structs must be sent so we know their fields.
break
case *reflect.ChanType, *reflect.FuncType, *reflect.InterfaceType:
case *reflect.ChanType, *reflect.FuncType:
// Probably a bad field in a struct.
enc.badType(rt)
return
......@@ -115,7 +113,7 @@ func (enc *Encoder) sendType(origt reflect.Type) (sent bool) {
// Id:
encodeInt(enc.state, -int64(info.id))
// Type:
encode(enc.state.b, reflect.NewValue(info.wire))
enc.encode(enc.state.b, reflect.NewValue(info.wire))
enc.send()
if enc.state.err != nil {
return
......@@ -134,7 +132,7 @@ func (enc *Encoder) sendType(origt reflect.Type) (sent bool) {
case reflect.ArrayOrSliceType:
enc.sendType(st.Elem())
}
return
return true
}
// Encode transmits the data item represented by the empty interface value,
......@@ -143,30 +141,17 @@ func (enc *Encoder) Encode(e interface{}) os.Error {
return enc.EncodeValue(reflect.NewValue(e))
}
// EncodeValue transmits the data item represented by the reflection value,
// guaranteeing that all necessary type information has been transmitted first.
func (enc *Encoder) EncodeValue(value reflect.Value) os.Error {
// Make sure we're single-threaded through here, so multiple
// goroutines can share an encoder.
enc.mutex.Lock()
defer enc.mutex.Unlock()
enc.state.err = nil
rt, _ := indirect(value.Type())
// Sanity check only: encoder should never come in with data present.
if enc.state.b.Len() > 0 || enc.countState.b.Len() > 0 {
enc.state.err = os.ErrorString("encoder: buffer not empty")
return enc.state.err
}
// sendTypeDescriptor makes sure the remote side knows about this type.
// It will send a descriptor if this is the first time the type has been
// sent. Regardless, it sends the id.
func (enc *Encoder) sendTypeDescriptor(rt reflect.Type) {
// Make sure the type is known to the other side.
// First, have we already sent this type?
if _, alreadySent := enc.sent[rt]; !alreadySent {
// No, so send it.
sent := enc.sendType(rt)
if enc.state.err != nil {
return enc.state.err
return
}
// If the type info has still not been transmitted, it means we have
// a singleton basic type (int, []byte etc.) at top level. We don't
......@@ -177,7 +162,7 @@ func (enc *Encoder) EncodeValue(value reflect.Value) os.Error {
typeLock.Unlock()
if err != nil {
enc.setError(err)
return err
return
}
enc.sent[rt] = info.id
}
......@@ -185,9 +170,32 @@ func (enc *Encoder) EncodeValue(value reflect.Value) os.Error {
// Identify the type of this top-level value.
encodeInt(enc.state, int64(enc.sent[rt]))
}
// EncodeValue transmits the data item represented by the reflection value,
// guaranteeing that all necessary type information has been transmitted first.
func (enc *Encoder) EncodeValue(value reflect.Value) os.Error {
// Make sure we're single-threaded through here, so multiple
// goroutines can share an encoder.
enc.mutex.Lock()
defer enc.mutex.Unlock()
enc.state.err = nil
rt, _ := indirect(value.Type())
// Sanity check only: encoder should never come in with data present.
if enc.state.b.Len() > 0 || enc.countState.b.Len() > 0 {
enc.state.err = os.ErrorString("encoder: buffer not empty")
return enc.state.err
}
enc.sendTypeDescriptor(rt)
if enc.state.err != nil {
return enc.state.err
}
// Encode the object.
err := encode(enc.state.b, value)
err := enc.encode(enc.state.b, value)
if err != nil {
enc.setError(err)
} else {
......
......@@ -135,7 +135,6 @@ func TestBadData(t *testing.T) {
// unsupportedValues is a list of sample values, one per element, stored as
// empty interfaces so values of different types can share one slice.
// NOTE(review): presumably these are values gob is expected to refuse to
// encode, consumed by TestUnsupported — confirm against that test's body.
var unsupportedValues = []interface{}{
make(chan int),
func(a int) bool { return true },
new(interface{}),
}
func TestUnsupported(t *testing.T) {
......
......@@ -90,14 +90,22 @@ func (t *commonType) Name() string { return t.name }
var (
// Primordial types, needed during initialization.
tBool = bootstrapType("bool", false, 1)
tInt = bootstrapType("int", int(0), 2)
tUint = bootstrapType("uint", uint(0), 3)
tFloat = bootstrapType("float", float64(0), 4)
tBytes = bootstrapType("bytes", make([]byte, 0), 5)
tString = bootstrapType("string", "", 6)
// Types added to the language later, not needed during initialization.
tComplex typeId
tBool = bootstrapType("bool", false, 1)
tInt = bootstrapType("int", int(0), 2)
tUint = bootstrapType("uint", uint(0), 3)
tFloat = bootstrapType("float", float64(0), 4)
tBytes = bootstrapType("bytes", make([]byte, 0), 5)
tString = bootstrapType("string", "", 6)
tComplex = bootstrapType("complex", 0+0i, 7)
tInterface = bootstrapType("interface", interface{}(nil), 8)
// Reserve some Ids for compatible expansion
tReserved7 = bootstrapType("_reserved1", struct{ r7 int }{}, 9)
tReserved6 = bootstrapType("_reserved1", struct{ r6 int }{}, 10)
tReserved5 = bootstrapType("_reserved1", struct{ r5 int }{}, 11)
tReserved4 = bootstrapType("_reserved1", struct{ r4 int }{}, 12)
tReserved3 = bootstrapType("_reserved1", struct{ r3 int }{}, 13)
tReserved2 = bootstrapType("_reserved1", struct{ r2 int }{}, 14)
tReserved1 = bootstrapType("_reserved1", struct{ r1 int }{}, 15)
)
// Predefined because it's needed by the Decoder
......@@ -105,15 +113,13 @@ var tWireType = mustGetTypeInfo(reflect.Typeof(wireType{})).id
func init() {
// Some magic numbers to make sure there are no surprises.
checkId(7, tWireType)
checkId(9, mustGetTypeInfo(reflect.Typeof(commonType{})).id)
checkId(11, mustGetTypeInfo(reflect.Typeof(structType{})).id)
checkId(12, mustGetTypeInfo(reflect.Typeof(fieldType{})).id)
// Complex was added after gob was written, so appears after the
// fundamental types are built.
tComplex = bootstrapType("complex", 0+0i, 15)
decIgnoreOpMap[tComplex] = ignoreTwoUints
checkId(16, tWireType)
checkId(17, mustGetTypeInfo(reflect.Typeof(arrayType{})).id)
checkId(18, mustGetTypeInfo(reflect.Typeof(commonType{})).id)
checkId(19, mustGetTypeInfo(reflect.Typeof(sliceType{})).id)
checkId(20, mustGetTypeInfo(reflect.Typeof(structType{})).id)
checkId(21, mustGetTypeInfo(reflect.Typeof(fieldType{})).id)
checkId(23, mustGetTypeInfo(reflect.Typeof(mapType{})).id)
builtinIdToType = make(map[typeId]gobType)
for k, v := range idToType {
......@@ -234,7 +240,7 @@ func newStructType(name string) *structType {
}
// Step through the indirections on a type to discover the base type.
// Return the number of indirections.
// Return the base type and the number of indirections.
func indirect(t reflect.Type) (rt reflect.Type, count int) {
rt = t
for {
......@@ -269,6 +275,9 @@ func newTypeObject(name string, rt reflect.Type) (gobType, os.Error) {
case *reflect.StringType:
return tString.gobType(), nil
case *reflect.InterfaceType:
return tInterface.gobType(), nil
case *reflect.ArrayType:
gt, err := getType("", t.Elem())
if err != nil {
......@@ -330,14 +339,7 @@ func newTypeObject(name string, rt reflect.Type) (gobType, os.Error) {
// getType returns the Gob type describing the given reflect.Type.
// typeLock must be held.
func getType(name string, rt reflect.Type) (gobType, os.Error) {
// Flatten the data structure by collapsing out pointers
for {
pt, ok := rt.(*reflect.PtrType)
if !ok {
break
}
rt = pt.Elem()
}
rt, _ = indirect(rt)
typ, present := types[rt]
if present {
return typ, nil
......@@ -351,6 +353,7 @@ func getType(name string, rt reflect.Type) (gobType, os.Error) {
// checkId verifies that a type received the id it is required to have.
// want is the expected id and got is the id actually assigned. If they
// differ, it writes both numeric ids to standard error and then panics with
// a message naming the offending type.
func checkId(want, got typeId) {
	if want != got {
		// Report the id that was actually assigned first, then the one it
		// should have been, so the message reads correctly.
		fmt.Fprintf(os.Stderr, "checkId: %d should be %d\n", int(got), int(want))
		panic("bootstrap type wrong id: " + got.Name() + " " + got.string() + " not " + want.string())
	}
}
......@@ -444,3 +447,54 @@ func mustGetTypeInfo(rt reflect.Type) *typeInfo {
}
return t
}
// nameToConcreteType maps each registered name to the type recorded for it
// by RegisterName.
var nameToConcreteType = make(map[string]reflect.Type)

// concreteTypeToName is the inverse mapping, from a registered type back to
// the name it was registered under.
var concreteTypeToName = make(map[reflect.Type]string)
// RegisterName records the type of value under the caller-supplied name
// instead of the default name Register would compute. Pointers are stripped
// from the value's type before it is recorded. It panics if name is empty,
// if name is already bound to a different type, or if the type is already
// bound to a different name.
func RegisterName(name string, value interface{}) {
	// The empty name is reserved for nil.
	if len(name) == 0 {
		panic("attempt to register empty name")
	}
	base, _ := indirect(reflect.Typeof(value))

	// Re-registering an identical (name, type) pair is harmless; anything
	// else that collides with an existing entry is an error.
	prevType, nameTaken := nameToConcreteType[name]
	if nameTaken && prevType != base {
		panic("gob: registering duplicate types for " + name)
	}
	prevName, typeTaken := concreteTypeToName[base]
	if typeTaken && prevName != name {
		panic("gob: registering duplicate names for " + base.String())
	}

	// Record the pair in both directions.
	nameToConcreteType[name] = base
	concreteTypeToName[base] = name
}
// Register records a type, identified by a value for that type, under its
// internal type name. That name will identify the concrete type of a value
// sent or received as an interface variable. Only types that will be
// transferred as implementations of interface values need to be registered.
// Expecting to be used only during initialization, it panics if the mapping
// between types and names is not a bijection.
//
// The name is derived from the value's type as follows:
//   - a named type is registered as its import path, a dot, and its name;
//   - a pointer to a named type is registered the same way with a leading "*";
//   - a named type with no import path is registered under its bare name;
//   - any other type is registered under its printed representation.
func Register(value interface{}) {
	// Default to printed representation for unnamed types.
	rt := reflect.Typeof(value)
	name := rt.String()

	// But for named types (or pointers to them), qualify with import path.
	// Dereference one pointer looking for a named type.
	star := ""
	if rt.Name() == "" {
		if pt, ok := rt.(*reflect.PtrType); ok {
			star = "*"
			// Step to the pointed-to type. Assigning pt itself would leave
			// rt unnamed, so the qualified name below would never be built.
			rt = pt.Elem()
		}
	}
	if rt.Name() != "" {
		if rt.PkgPath() == "" {
			// No import path to qualify with; avoid producing a name
			// that starts with a stray ".".
			name = star + rt.Name()
		} else {
			name = star + rt.PkgPath() + "." + rt.Name()
		}
	}
	RegisterName(name, value)
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment