Commit 30b1b9a3 authored by Rob Pike's avatar Rob Pike

Rework gobs to fix bad bug related to sharing of id's between encoder and decoder side.

Fix is to move all decoder state into the decoder object.

Fixes #215.

R=rsc
CC=golang-dev
https://golang.org/cl/155077
parent 50c04132
......@@ -37,7 +37,6 @@ var encodeT = []EncodeT{
EncodeT{1 << 63, []byte{0xF8, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}},
}
// Test basic encode/decode routines for unsigned integers
func TestUintCodec(t *testing.T) {
b := new(bytes.Buffer);
......@@ -592,9 +591,15 @@ func TestEndToEnd(t *testing.T) {
t: &T2{"this is T2"},
};
b := new(bytes.Buffer);
encode(b, t1);
err := NewEncoder(b).Encode(t1);
if err != nil {
t.Error("encode:", err)
}
var _t1 T1;
decode(b, getTypeInfoNoError(reflect.Typeof(_t1)).id, &_t1);
err = NewDecoder(b).Decode(&_t1);
if err != nil {
t.Fatal("decode:", err)
}
if !reflect.DeepEqual(t1, &_t1) {
t.Errorf("encode expected %v got %v", *t1, _t1)
}
......@@ -610,8 +615,9 @@ func TestOverflow(t *testing.T) {
}
var it inputT;
var err os.Error;
id := getTypeInfoNoError(reflect.Typeof(it)).id;
b := new(bytes.Buffer);
enc := NewEncoder(b);
dec := NewDecoder(b);
// int8
b.Reset();
......@@ -623,8 +629,8 @@ func TestOverflow(t *testing.T) {
mini int8;
}
var o1 outi8;
encode(b, it);
err = decode(b, id, &o1);
enc.Encode(it);
err = dec.Decode(&o1);
if err == nil || err.String() != `value for "maxi" out of range` {
t.Error("wrong overflow error for int8:", err)
}
......@@ -632,8 +638,8 @@ func TestOverflow(t *testing.T) {
mini: math.MinInt8 - 1,
};
b.Reset();
encode(b, it);
err = decode(b, id, &o1);
enc.Encode(it);
err = dec.Decode(&o1);
if err == nil || err.String() != `value for "mini" out of range` {
t.Error("wrong underflow error for int8:", err)
}
......@@ -648,8 +654,8 @@ func TestOverflow(t *testing.T) {
mini int16;
}
var o2 outi16;
encode(b, it);
err = decode(b, id, &o2);
enc.Encode(it);
err = dec.Decode(&o2);
if err == nil || err.String() != `value for "maxi" out of range` {
t.Error("wrong overflow error for int16:", err)
}
......@@ -657,8 +663,8 @@ func TestOverflow(t *testing.T) {
mini: math.MinInt16 - 1,
};
b.Reset();
encode(b, it);
err = decode(b, id, &o2);
enc.Encode(it);
err = dec.Decode(&o2);
if err == nil || err.String() != `value for "mini" out of range` {
t.Error("wrong underflow error for int16:", err)
}
......@@ -673,8 +679,8 @@ func TestOverflow(t *testing.T) {
mini int32;
}
var o3 outi32;
encode(b, it);
err = decode(b, id, &o3);
enc.Encode(it);
err = dec.Decode(&o3);
if err == nil || err.String() != `value for "maxi" out of range` {
t.Error("wrong overflow error for int32:", err)
}
......@@ -682,8 +688,8 @@ func TestOverflow(t *testing.T) {
mini: math.MinInt32 - 1,
};
b.Reset();
encode(b, it);
err = decode(b, id, &o3);
enc.Encode(it);
err = dec.Decode(&o3);
if err == nil || err.String() != `value for "mini" out of range` {
t.Error("wrong underflow error for int32:", err)
}
......@@ -697,8 +703,8 @@ func TestOverflow(t *testing.T) {
maxu uint8;
}
var o4 outu8;
encode(b, it);
err = decode(b, id, &o4);
enc.Encode(it);
err = dec.Decode(&o4);
if err == nil || err.String() != `value for "maxu" out of range` {
t.Error("wrong overflow error for uint8:", err)
}
......@@ -712,8 +718,8 @@ func TestOverflow(t *testing.T) {
maxu uint16;
}
var o5 outu16;
encode(b, it);
err = decode(b, id, &o5);
enc.Encode(it);
err = dec.Decode(&o5);
if err == nil || err.String() != `value for "maxu" out of range` {
t.Error("wrong overflow error for uint16:", err)
}
......@@ -727,8 +733,8 @@ func TestOverflow(t *testing.T) {
maxu uint32;
}
var o6 outu32;
encode(b, it);
err = decode(b, id, &o6);
enc.Encode(it);
err = dec.Decode(&o6);
if err == nil || err.String() != `value for "maxu" out of range` {
t.Error("wrong overflow error for uint32:", err)
}
......@@ -743,8 +749,8 @@ func TestOverflow(t *testing.T) {
minf float32;
}
var o7 outf32;
encode(b, it);
err = decode(b, id, &o7);
enc.Encode(it);
err = dec.Decode(&o7);
if err == nil || err.String() != `value for "maxf" out of range` {
t.Error("wrong overflow error for float32:", err)
}
......@@ -761,9 +767,13 @@ func TestNesting(t *testing.T) {
rt.next = new(RT);
rt.next.a = "level2";
b := new(bytes.Buffer);
encode(b, rt);
NewEncoder(b).Encode(rt);
var drt RT;
decode(b, getTypeInfoNoError(reflect.Typeof(drt)).id, &drt);
dec := NewDecoder(b);
err := dec.Decode(&drt);
if err != nil {
t.Errorf("decoder error:", err)
}
if drt.a != rt.a {
t.Errorf("nesting: encode expected %v got %v", *rt, drt)
}
......@@ -809,10 +819,11 @@ func TestAutoIndirection(t *testing.T) {
**t1.d = new(int);
***t1.d = 17777;
b := new(bytes.Buffer);
encode(b, t1);
enc := NewEncoder(b);
enc.Encode(t1);
dec := NewDecoder(b);
var t0 T0;
t0Id := getTypeInfoNoError(reflect.Typeof(t0)).id;
decode(b, t0Id, &t0);
dec.Decode(&t0);
if t0.a != 17 || t0.b != 177 || t0.c != 1777 || t0.d != 17777 {
t.Errorf("t1->t0: expected {17 177 1777 17777}; got %v", t0)
}
......@@ -830,9 +841,9 @@ func TestAutoIndirection(t *testing.T) {
**t2.a = new(int);
***t2.a = 17;
b.Reset();
encode(b, t2);
enc.Encode(t2);
t0 = T0{};
decode(b, t0Id, &t0);
dec.Decode(&t0);
if t0.a != 17 || t0.b != 177 || t0.c != 1777 || t0.d != 17777 {
t.Errorf("t2->t0 expected {17 177 1777 17777}; got %v", t0)
}
......@@ -840,32 +851,30 @@ func TestAutoIndirection(t *testing.T) {
// Now transfer t0 into t1
t0 = T0{17, 177, 1777, 17777};
b.Reset();
encode(b, t0);
enc.Encode(t0);
t1 = T1{};
t1Id := getTypeInfoNoError(reflect.Typeof(t1)).id;
decode(b, t1Id, &t1);
dec.Decode(&t1);
if t1.a != 17 || *t1.b != 177 || **t1.c != 1777 || ***t1.d != 17777 {
t.Errorf("t0->t1 expected {17 177 1777 17777}; got {%d %d %d %d}", t1.a, *t1.b, **t1.c, ***t1.d)
}
// Now transfer t0 into t2
b.Reset();
encode(b, t0);
enc.Encode(t0);
t2 = T2{};
t2Id := getTypeInfoNoError(reflect.Typeof(t2)).id;
decode(b, t2Id, &t2);
dec.Decode(&t2);
if ***t2.a != 17 || **t2.b != 177 || *t2.c != 1777 || t2.d != 17777 {
t.Errorf("t0->t2 expected {17 177 1777 17777}; got {%d %d %d %d}", ***t2.a, **t2.b, *t2.c, t2.d)
}
// Now do t2 again but without pre-allocated pointers.
b.Reset();
encode(b, t0);
enc.Encode(t0);
***t2.a = 0;
**t2.b = 0;
*t2.c = 0;
t2.d = 0;
decode(b, t2Id, &t2);
dec.Decode(&t2);
if ***t2.a != 17 || **t2.b != 177 || *t2.c != 1777 || t2.d != 17777 {
t.Errorf("t0->t2 expected {17 177 1777 17777}; got {%d %d %d %d}", ***t2.a, **t2.b, *t2.c, t2.d)
}
......@@ -889,11 +898,14 @@ func TestReorderedFields(t *testing.T) {
rt0.b = "hello";
rt0.c = 3.14159;
b := new(bytes.Buffer);
encode(b, rt0);
rt0Id := getTypeInfoNoError(reflect.Typeof(rt0)).id;
NewEncoder(b).Encode(rt0);
dec := NewDecoder(b);
var rt1 RT1;
// Wire type is RT0, local type is RT1.
decode(b, rt0Id, &rt1);
err := dec.Decode(&rt1);
if err != nil {
t.Error("decode error:", err)
}
if rt0.a != rt1.a || rt0.b != rt1.b || rt0.c != rt1.c {
t.Errorf("rt1->rt0: expected %v; got %v", rt0, rt1)
}
......@@ -927,11 +939,11 @@ func TestIgnoredFields(t *testing.T) {
it0.ignore_i = &RT1{3.1, "hi", 7, "hello"};
b := new(bytes.Buffer);
encode(b, it0);
rt0Id := getTypeInfoNoError(reflect.Typeof(it0)).id;
NewEncoder(b).Encode(it0);
dec := NewDecoder(b);
var rt1 RT1;
// Wire type is IT0, local type is RT1.
err := decode(b, rt0Id, &rt1);
err := dec.Decode(&rt1);
if err != nil {
t.Error("error: ", err)
}
......
......@@ -348,7 +348,7 @@ func ignoreUint8Array(i *decInstr, state *decodeState, p unsafe.Pointer) {
// Execution engine
// The encoder engine is an array of instructions indexed by field number of the incoming
// data. It is executed with random access according to field number.
// decoder. It is executed with random access according to field number.
type decEngine struct {
instr []decInstr;
numInstr int; // the number of active instructions
......@@ -515,7 +515,7 @@ var decIgnoreOpMap = map[typeId]decOp{
// Return the decoding op for the base type under rt and
// the indirection count to reach it.
func decOpFor(wireId typeId, rt reflect.Type, name string) (decOp, int, os.Error) {
func (dec *Decoder) decOpFor(wireId typeId, rt reflect.Type, name string) (decOp, int, os.Error) {
typ, indir := indirect(rt);
op, ok := decOpMap[reflect.Typeof(typ)];
if !ok {
......@@ -528,7 +528,7 @@ func decOpFor(wireId typeId, rt reflect.Type, name string) (decOp, int, os.Error
break;
}
elemId := wireId.gobType().(*sliceType).Elem;
elemOp, elemIndir, err := decOpFor(elemId, t.Elem(), name);
elemOp, elemIndir, err := dec.decOpFor(elemId, t.Elem(), name);
if err != nil {
return nil, 0, err
}
......@@ -540,7 +540,7 @@ func decOpFor(wireId typeId, rt reflect.Type, name string) (decOp, int, os.Error
case *reflect.ArrayType:
name = "element of " + name;
elemId := wireId.gobType().(*arrayType).Elem;
elemOp, elemIndir, err := decOpFor(elemId, t.Elem(), name);
elemOp, elemIndir, err := dec.decOpFor(elemId, t.Elem(), name);
if err != nil {
return nil, 0, err
}
......@@ -551,7 +551,7 @@ func decOpFor(wireId typeId, rt reflect.Type, name string) (decOp, int, os.Error
case *reflect.StructType:
// Generate a closure that calls out to the engine for the nested type.
enginePtr, err := getDecEnginePtr(wireId, typ);
enginePtr, err := dec.getDecEnginePtr(wireId, typ);
if err != nil {
return nil, 0, err
}
......@@ -568,14 +568,14 @@ func decOpFor(wireId typeId, rt reflect.Type, name string) (decOp, int, os.Error
}
// Return the decoding op for a field that has no destination.
func decIgnoreOpFor(wireId typeId) (decOp, os.Error) {
func (dec *Decoder) decIgnoreOpFor(wireId typeId) (decOp, os.Error) {
op, ok := decIgnoreOpMap[wireId];
if !ok {
// Special cases
switch t := wireId.gobType().(type) {
case *sliceType:
elemId := wireId.gobType().(*sliceType).Elem;
elemOp, err := decIgnoreOpFor(elemId);
elemOp, err := dec.decIgnoreOpFor(elemId);
if err != nil {
return nil, err
}
......@@ -585,7 +585,7 @@ func decIgnoreOpFor(wireId typeId) (decOp, os.Error) {
case *arrayType:
elemId := wireId.gobType().(*arrayType).Elem;
elemOp, err := decIgnoreOpFor(elemId);
elemOp, err := dec.decIgnoreOpFor(elemId);
if err != nil {
return nil, err
}
......@@ -595,7 +595,7 @@ func decIgnoreOpFor(wireId typeId) (decOp, os.Error) {
case *structType:
// Generate a closure that calls out to the engine for the nested type.
enginePtr, err := getIgnoreEnginePtr(wireId);
enginePtr, err := dec.getIgnoreEnginePtr(wireId);
if err != nil {
return nil, err
}
......@@ -676,11 +676,18 @@ func compatibleType(fr reflect.Type, fw typeId) bool {
return true;
}
func compileDec(wireId typeId, rt reflect.Type) (engine *decEngine, err os.Error) {
func (dec *Decoder) compileDec(remoteId typeId, rt reflect.Type) (engine *decEngine, err os.Error) {
srt, ok1 := rt.(*reflect.StructType);
wireStruct, ok2 := wireId.gobType().(*structType);
if !ok1 || !ok2 {
return nil, errNotStruct
var wireStruct *structType;
// Builtin types can come from global pool; the rest must be defined by the decoder
if t, ok := builtinIdToType[remoteId]; ok {
wireStruct = t.(*structType)
} else {
w, ok2 := dec.wireType[remoteId];
if !ok1 || !ok2 {
return nil, errNotStruct
}
wireStruct = w.s;
}
engine = new(decEngine);
engine.instr = make([]decInstr, len(wireStruct.field));
......@@ -692,7 +699,7 @@ func compileDec(wireId typeId, rt reflect.Type) (engine *decEngine, err os.Error
ovfl := overflow(wireField.name);
// TODO(r): anonymous names
if !present {
op, err := decIgnoreOpFor(wireField.id);
op, err := dec.decIgnoreOpFor(wireField.id);
if err != nil {
return nil, err
}
......@@ -700,10 +707,10 @@ func compileDec(wireId typeId, rt reflect.Type) (engine *decEngine, err os.Error
continue;
}
if !compatibleType(localField.Type, wireField.id) {
details := " (" + wireField.id.String() + " incompatible with " + localField.Type.String() + ") in type " + wireId.Name();
details := " (" + wireField.id.String() + " incompatible with " + localField.Type.String() + ") in type " + remoteId.Name();
return nil, os.ErrorString("gob: wrong type for field " + wireField.name + details);
}
op, indir, err := decOpFor(wireField.id, localField.Type, localField.Name);
op, indir, err := dec.decOpFor(wireField.id, localField.Type, localField.Name);
if err != nil {
return nil, err
}
......@@ -713,23 +720,19 @@ func compileDec(wireId typeId, rt reflect.Type) (engine *decEngine, err os.Error
return;
}
var decoderCache = make(map[reflect.Type]map[typeId]**decEngine)
var ignorerCache = make(map[typeId]**decEngine)
// typeLock must be held.
func getDecEnginePtr(wireId typeId, rt reflect.Type) (enginePtr **decEngine, err os.Error) {
decoderMap, ok := decoderCache[rt];
func (dec *Decoder) getDecEnginePtr(remoteId typeId, rt reflect.Type) (enginePtr **decEngine, err os.Error) {
decoderMap, ok := dec.decoderCache[rt];
if !ok {
decoderMap = make(map[typeId]**decEngine);
decoderCache[rt] = decoderMap;
dec.decoderCache[rt] = decoderMap;
}
if enginePtr, ok = decoderMap[wireId]; !ok {
if enginePtr, ok = decoderMap[remoteId]; !ok {
// To handle recursive types, mark this engine as underway before compiling.
enginePtr = new(*decEngine);
decoderMap[wireId] = enginePtr;
*enginePtr, err = compileDec(wireId, rt);
decoderMap[remoteId] = enginePtr;
*enginePtr, err = dec.compileDec(remoteId, rt);
if err != nil {
decoderMap[wireId] = nil, false
decoderMap[remoteId] = nil, false
}
}
return;
......@@ -740,35 +743,28 @@ type emptyStruct struct{}
var emptyStructType = reflect.Typeof(emptyStruct{})
// typeLock must be held.
func getIgnoreEnginePtr(wireId typeId) (enginePtr **decEngine, err os.Error) {
func (dec *Decoder) getIgnoreEnginePtr(wireId typeId) (enginePtr **decEngine, err os.Error) {
var ok bool;
if enginePtr, ok = ignorerCache[wireId]; !ok {
if enginePtr, ok = dec.ignorerCache[wireId]; !ok {
// To handle recursive types, mark this engine as underway before compiling.
enginePtr = new(*decEngine);
ignorerCache[wireId] = enginePtr;
*enginePtr, err = compileDec(wireId, emptyStructType);
dec.ignorerCache[wireId] = enginePtr;
*enginePtr, err = dec.compileDec(wireId, emptyStructType);
if err != nil {
ignorerCache[wireId] = nil, false
dec.ignorerCache[wireId] = nil, false
}
}
return;
}
func decode(b *bytes.Buffer, wireId typeId, e interface{}) os.Error {
func (dec *Decoder) decode(wireId typeId, e interface{}) os.Error {
// Dereference down to the underlying struct type.
rt, indir := indirect(reflect.Typeof(e));
st, ok := rt.(*reflect.StructType);
if !ok {
return os.ErrorString("gob: decode can't handle " + rt.String())
}
typeLock.Lock();
if _, ok := idToType[wireId]; !ok {
typeLock.Unlock();
return errBadType;
}
enginePtr, err := getDecEnginePtr(wireId, rt);
typeLock.Unlock();
enginePtr, err := dec.getDecEnginePtr(wireId, rt);
if err != nil {
return err
}
......@@ -777,7 +773,7 @@ func decode(b *bytes.Buffer, wireId typeId, e interface{}) os.Error {
name := rt.Name();
return os.ErrorString("gob: type mismatch: no fields matched compiling decoder for " + name);
}
return decodeStruct(engine, st, b, uintptr(reflect.NewValue(e).Addr()), indir);
return decodeStruct(engine, st, dec.state.b, uintptr(reflect.NewValue(e).Addr()), indir);
}
func init() {
......
......@@ -8,17 +8,20 @@ import (
"bytes";
"io";
"os";
"reflect";
"sync";
)
// A Decoder manages the receipt of type and data information read from the
// remote side of a connection.
type Decoder struct {
mutex sync.Mutex; // each item must be received atomically
r io.Reader; // source of the data
seen map[typeId]*wireType; // which types we've already seen described
state *decodeState; // reads data from in-memory buffer
countState *decodeState; // reads counts from wire
mutex sync.Mutex; // each item must be received atomically
r io.Reader; // source of the data
wireType map[typeId]*wireType; // map from remote ID to local description
decoderCache map[reflect.Type]map[typeId]**decEngine; // cache of compiled engines
ignorerCache map[typeId]**decEngine; // ditto for ignored objects
state *decodeState; // reads data from in-memory buffer
countState *decodeState; // reads counts from wire
buf []byte;
oneByte []byte;
}
......@@ -27,8 +30,10 @@ type Decoder struct {
func NewDecoder(r io.Reader) *Decoder {
dec := new(Decoder);
dec.r = r;
dec.seen = make(map[typeId]*wireType);
dec.wireType = make(map[typeId]*wireType);
dec.state = newDecodeState(nil); // buffer set in Decode(); rest is unimportant
dec.decoderCache = make(map[reflect.Type]map[typeId]**decEngine);
dec.ignorerCache = make(map[typeId]**decEngine);
dec.oneByte = make([]byte, 1);
return dec;
......@@ -36,16 +41,16 @@ func NewDecoder(r io.Reader) *Decoder {
func (dec *Decoder) recvType(id typeId) {
// Have we already seen this type? That's an error
if _, alreadySeen := dec.seen[id]; alreadySeen {
if _, alreadySeen := dec.wireType[id]; alreadySeen {
dec.state.err = os.ErrorString("gob: duplicate type received");
return;
}
// Type:
wire := new(wireType);
decode(dec.state.b, tWireType, wire);
dec.state.err = dec.decode(tWireType, wire);
// Remember we've seen this type.
dec.seen[id] = wire;
dec.wireType[id] = wire;
}
// Decode reads the next value from the connection and stores
......@@ -97,7 +102,13 @@ func (dec *Decoder) Decode(e interface{}) os.Error {
}
// No, it's a value.
dec.state.err = decode(dec.state.b, id, e);
// Make sure the type has been defined already.
_, ok := dec.wireType[id];
if !ok {
dec.state.err = errBadType;
break;
}
dec.state.err = dec.decode(id, e);
break;
}
return dec.state.err;
......
......@@ -235,19 +235,15 @@ func (enc *Encoder) send() {
enc.w.Write(enc.buf[0:total]);
}
func (enc *Encoder) sendType(origt reflect.Type, topLevel bool) {
func (enc *Encoder) sendType(origt reflect.Type) {
// Drill down to the base type.
rt, _ := indirect(origt);
// We only send structs - everything else is basic or an error
switch rt.(type) {
switch rt := rt.(type) {
default:
// Basic types do not need to be described, but if this is a top-level
// type, it's a user error, at least for now.
if topLevel {
enc.badType(rt)
}
return;
// Basic types do not need to be described.
return
case *reflect.StructType:
// Structs do need to be described.
break
......@@ -255,10 +251,9 @@ func (enc *Encoder) sendType(origt reflect.Type, topLevel bool) {
// Probably a bad field in a struct.
enc.badType(rt);
return;
case *reflect.ArrayType, *reflect.SliceType:
// Array and slice types are not sent, only their element types.
// If we see one here it's user error; probably a bad top-level value.
enc.badType(rt);
// Array and slice types are not sent, only their element types.
case reflect.ArrayOrSliceType:
enc.sendType(rt.Elem());
return;
}
......@@ -289,7 +284,7 @@ func (enc *Encoder) sendType(origt reflect.Type, topLevel bool) {
// Now send the inner types
st := rt.(*reflect.StructType);
for i := 0; i < st.NumField(); i++ {
enc.sendType(st.Field(i).Type, false)
enc.sendType(st.Field(i).Type)
}
return;
}
......@@ -301,6 +296,12 @@ func (enc *Encoder) Encode(e interface{}) os.Error {
panicln("Encoder: buffer not empty")
}
rt, _ := indirect(reflect.Typeof(e));
// Must be a struct
if _, ok := rt.(*reflect.StructType); !ok {
enc.badType(rt);
return enc.state.err;
}
// Make sure we're single-threaded through here.
enc.mutex.Lock();
......@@ -310,7 +311,7 @@ func (enc *Encoder) Encode(e interface{}) os.Error {
// First, have we already sent this type?
if _, alreadySent := enc.sent[rt]; !alreadySent {
// No, so send it.
enc.sendType(rt, true);
enc.sendType(rt);
if enc.state.err != nil {
enc.state.b.Reset();
enc.countState.b.Reset();
......
......@@ -32,122 +32,10 @@ type ET3 struct {
// Like ET1 but with a different type for a field
type ET4 struct {
a int;
et2 *ET1;
et2 float;
next int;
}
func TestBasicEncoder(t *testing.T) {
b := new(bytes.Buffer);
enc := NewEncoder(b);
et1 := new(ET1);
et1.a = 7;
et1.et2 = new(ET2);
enc.Encode(et1);
if enc.state.err != nil {
t.Error("encoder fail:", enc.state.err)
}
// Decode the result by hand to verify;
state := newDecodeState(b);
// The output should be:
// 0) The length, 38.
length := decodeUint(state);
if length != 38 {
t.Fatal("0. expected length 38; got", length)
}
// 1) -7: the type id of ET1
id1 := decodeInt(state);
if id1 >= 0 {
t.Fatal("expected ET1 negative id; got", id1)
}
// 2) The wireType for ET1
wire1 := new(wireType);
err := decode(b, tWireType, wire1);
if err != nil {
t.Fatal("error decoding ET1 type:", err)
}
info := getTypeInfoNoError(reflect.Typeof(ET1{}));
trueWire1 := &wireType{s: info.id.gobType().(*structType)};
if !reflect.DeepEqual(wire1, trueWire1) {
t.Fatalf("invalid wireType for ET1: expected %+v; got %+v\n", *trueWire1, *wire1)
}
// 3) The length, 21.
length = decodeUint(state);
if length != 21 {
t.Fatal("3. expected length 21; got", length)
}
// 4) -8: the type id of ET2
id2 := decodeInt(state);
if id2 >= 0 {
t.Fatal("expected ET2 negative id; got", id2)
}
// 5) The wireType for ET2
wire2 := new(wireType);
err = decode(b, tWireType, wire2);
if err != nil {
t.Fatal("error decoding ET2 type:", err)
}
info = getTypeInfoNoError(reflect.Typeof(ET2{}));
trueWire2 := &wireType{s: info.id.gobType().(*structType)};
if !reflect.DeepEqual(wire2, trueWire2) {
t.Fatalf("invalid wireType for ET2: expected %+v; got %+v\n", *trueWire2, *wire2)
}
// 6) The length, 6.
length = decodeUint(state);
if length != 6 {
t.Fatal("6. expected length 6; got", length)
}
// 7) The type id for the et1 value
newId1 := decodeInt(state);
if newId1 != -id1 {
t.Fatal("expected Et1 id", -id1, "got", newId1)
}
// 8) The value of et1
newEt1 := new(ET1);
et1Id := getTypeInfoNoError(reflect.Typeof(*newEt1)).id;
err = decode(b, et1Id, newEt1);
if err != nil {
t.Fatal("error decoding ET1 value:", err)
}
if !reflect.DeepEqual(et1, newEt1) {
t.Fatalf("invalid data for et1: expected %+v; got %+v\n", *et1, *newEt1)
}
// 9) EOF
if b.Len() != 0 {
t.Error("not at eof;", b.Len(), "bytes left")
}
// Now do it again. This time we should see only the type id and value.
b.Reset();
enc.Encode(et1);
if enc.state.err != nil {
t.Error("2nd round: encoder fail:", enc.state.err)
}
// The length.
length = decodeUint(state);
if length != 6 {
t.Fatal("6. expected length 6; got", length)
}
// 5a) The type id for the et1 value
newId1 = decodeInt(state);
if newId1 != -id1 {
t.Fatal("2nd round: expected Et1 id", -id1, "got", newId1)
}
// 6a) The value of et1
newEt1 = new(ET1);
err = decode(b, et1Id, newEt1);
if err != nil {
t.Fatal("2nd round: error decoding ET1 value:", err)
}
if !reflect.DeepEqual(et1, newEt1) {
t.Fatalf("2nd round: invalid data for et1: expected %+v; got %+v\n", *et1, *newEt1)
}
// 7a) EOF
if b.Len() != 0 {
t.Error("2nd round: not at eof;", b.Len(), "bytes left")
}
}
func TestEncoderDecoder(t *testing.T) {
b := new(bytes.Buffer);
enc := NewEncoder(b);
......@@ -215,7 +103,7 @@ func badTypeCheck(e interface{}, shouldFail bool, msg string, t *testing.T) {
t.Error("expected error for", msg)
}
if !shouldFail && (dec.state.err != nil) {
t.Error("unexpected error for", msg)
t.Error("unexpected error for", msg, dec.state.err)
}
}
......
......@@ -48,6 +48,7 @@ type gobType interface {
var types = make(map[reflect.Type]gobType)
var idToType = make(map[typeId]gobType)
var builtinIdToType map[typeId]gobType // set in init() after builtins are established
func setTypeId(typ gobType) {
nextId++;
......@@ -104,6 +105,10 @@ func init() {
checkId(8, getTypeInfoNoError(reflect.Typeof(structType{})).id);
checkId(9, getTypeInfoNoError(reflect.Typeof(commonType{})).id);
checkId(10, getTypeInfoNoError(reflect.Typeof(fieldType{})).id);
builtinIdToType = make(map[typeId]gobType);
for k, v := range idToType {
builtinIdToType[k] = v
}
}
// Array type
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment