Commit 14b6a477 authored by Rob Pike's avatar Rob Pike

gob: compute information about a user's type once.

Other than maybe cleaning the code up a bit, this has
little practical effect for now, but lays the foundation
for remembering the method set of a type, which can
be expensive.

R=rsc
CC=golang-dev
https://golang.org/cl/4193041
parent 556506e8
...@@ -984,7 +984,7 @@ func TestInvalidField(t *testing.T) { ...@@ -984,7 +984,7 @@ func TestInvalidField(t *testing.T) {
var bad0 Bad0 var bad0 Bad0
bad0.ch = make(chan int) bad0.ch = make(chan int)
b := new(bytes.Buffer) b := new(bytes.Buffer)
err := nilEncoder.encode(b, reflect.NewValue(&bad0)) err := nilEncoder.encode(b, reflect.NewValue(&bad0), userType(reflect.Typeof(&bad0)))
if err == nil { if err == nil {
t.Error("expected error; got none") t.Error("expected error; got none")
} else if strings.Index(err.String(), "type") < 0 { } else if strings.Index(err.String(), "type") < 0 {
......
...@@ -409,9 +409,9 @@ func allocate(rtyp reflect.Type, p uintptr, indir int) uintptr { ...@@ -409,9 +409,9 @@ func allocate(rtyp reflect.Type, p uintptr, indir int) uintptr {
return *(*uintptr)(up) return *(*uintptr)(up)
} }
func (dec *Decoder) decodeSingle(engine *decEngine, rtyp reflect.Type, p uintptr, indir int) (err os.Error) { func (dec *Decoder) decodeSingle(engine *decEngine, ut *userTypeInfo, p uintptr) (err os.Error) {
defer catchError(&err) defer catchError(&err)
p = allocate(rtyp, p, indir) p = allocate(ut.base, p, ut.indir)
state := newDecodeState(dec, &dec.buf) state := newDecodeState(dec, &dec.buf)
state.fieldnum = singletonField state.fieldnum = singletonField
basep := p basep := p
...@@ -428,9 +428,13 @@ func (dec *Decoder) decodeSingle(engine *decEngine, rtyp reflect.Type, p uintptr ...@@ -428,9 +428,13 @@ func (dec *Decoder) decodeSingle(engine *decEngine, rtyp reflect.Type, p uintptr
return nil return nil
} }
func (dec *Decoder) decodeStruct(engine *decEngine, rtyp *reflect.StructType, p uintptr, indir int) (err os.Error) { // Indir is for the value, not the type. At the time of the call it may
// differ from ut.indir, which was computed when the engine was built.
// This state cannot arise for decodeSingle, which is called directly
// from the user's value, not from the innards of an engine.
func (dec *Decoder) decodeStruct(engine *decEngine, ut *userTypeInfo, p uintptr, indir int) (err os.Error) {
defer catchError(&err) defer catchError(&err)
p = allocate(rtyp, p, indir) p = allocate(ut.base.(*reflect.StructType), p, indir)
state := newDecodeState(dec, &dec.buf) state := newDecodeState(dec, &dec.buf)
state.fieldnum = -1 state.fieldnum = -1
basep := p basep := p
...@@ -702,7 +706,9 @@ var decIgnoreOpMap = map[typeId]decOp{ ...@@ -702,7 +706,9 @@ var decIgnoreOpMap = map[typeId]decOp{
// Return the decoding op for the base type under rt and // Return the decoding op for the base type under rt and
// the indirection count to reach it. // the indirection count to reach it.
func (dec *Decoder) decOpFor(wireId typeId, rt reflect.Type, name string) (decOp, int) { func (dec *Decoder) decOpFor(wireId typeId, rt reflect.Type, name string) (decOp, int) {
typ, indir := indirect(rt) ut := userType(rt)
typ := ut.base
indir := ut.indir
var op decOp var op decOp
k := typ.Kind() k := typ.Kind()
if int(k) < len(decOpMap) { if int(k) < len(decOpMap) {
...@@ -757,8 +763,8 @@ func (dec *Decoder) decOpFor(wireId typeId, rt reflect.Type, name string) (decOp ...@@ -757,8 +763,8 @@ func (dec *Decoder) decOpFor(wireId typeId, rt reflect.Type, name string) (decOp
error(err) error(err)
} }
op = func(i *decInstr, state *decodeState, p unsafe.Pointer) { op = func(i *decInstr, state *decodeState, p unsafe.Pointer) {
// indirect through enginePtr to delay evaluation for recursive structs // indirect through enginePtr to delay evaluation for recursive structs.
err = dec.decodeStruct(*enginePtr, t, uintptr(p), i.indir) err = dec.decodeStruct(*enginePtr, userType(typ), uintptr(p), i.indir)
if err != nil { if err != nil {
error(err) error(err)
} }
...@@ -837,7 +843,7 @@ func (dec *Decoder) decIgnoreOpFor(wireId typeId) decOp { ...@@ -837,7 +843,7 @@ func (dec *Decoder) decIgnoreOpFor(wireId typeId) decOp {
// Answers the question for basic types, arrays, and slices. // Answers the question for basic types, arrays, and slices.
// Structs are considered ok; fields will be checked later. // Structs are considered ok; fields will be checked later.
func (dec *Decoder) compatibleType(fr reflect.Type, fw typeId) bool { func (dec *Decoder) compatibleType(fr reflect.Type, fw typeId) bool {
fr, _ = indirect(fr) fr = userType(fr).base
switch t := fr.(type) { switch t := fr.(type) {
default: default:
// map, chan, etc: cannot handle. // map, chan, etc: cannot handle.
...@@ -882,7 +888,7 @@ func (dec *Decoder) compatibleType(fr reflect.Type, fw typeId) bool { ...@@ -882,7 +888,7 @@ func (dec *Decoder) compatibleType(fr reflect.Type, fw typeId) bool {
} else { } else {
sw = dec.wireType[fw].SliceT sw = dec.wireType[fw].SliceT
} }
elem, _ := indirect(t.Elem()) elem := userType(t.Elem()).base
return sw != nil && dec.compatibleType(elem, sw.Elem) return sw != nil && dec.compatibleType(elem, sw.Elem)
case *reflect.StructType: case *reflect.StructType:
return true return true
...@@ -1026,20 +1032,22 @@ func (dec *Decoder) decodeValue(wireId typeId, val reflect.Value) os.Error { ...@@ -1026,20 +1032,22 @@ func (dec *Decoder) decodeValue(wireId typeId, val reflect.Value) os.Error {
return dec.decodeIgnoredValue(wireId) return dec.decodeIgnoredValue(wireId)
} }
// Dereference down to the underlying struct type. // Dereference down to the underlying struct type.
rt, indir := indirect(val.Type()) ut := userType(val.Type())
enginePtr, err := dec.getDecEnginePtr(wireId, rt) base := ut.base
indir := ut.indir
enginePtr, err := dec.getDecEnginePtr(wireId, base)
if err != nil { if err != nil {
return err return err
} }
engine := *enginePtr engine := *enginePtr
if st, ok := rt.(*reflect.StructType); ok { if st, ok := base.(*reflect.StructType); ok {
if engine.numInstr == 0 && st.NumField() > 0 && len(dec.wireType[wireId].StructT.Field) > 0 { if engine.numInstr == 0 && st.NumField() > 0 && len(dec.wireType[wireId].StructT.Field) > 0 {
name := rt.Name() name := base.Name()
return os.ErrorString("gob: type mismatch: no fields matched compiling decoder for " + name) return os.ErrorString("gob: type mismatch: no fields matched compiling decoder for " + name)
} }
return dec.decodeStruct(engine, st, uintptr(val.Addr()), indir) return dec.decodeStruct(engine, ut, uintptr(val.Addr()), indir)
} }
return dec.decodeSingle(engine, rt, uintptr(val.Addr()), indir) return dec.decodeSingle(engine, ut, uintptr(val.Addr()))
} }
func (dec *Decoder) decodeIgnoredValue(wireId typeId) os.Error { func (dec *Decoder) decodeIgnoredValue(wireId typeId) os.Error {
......
...@@ -384,10 +384,10 @@ func (enc *Encoder) encodeInterface(b *bytes.Buffer, iv *reflect.InterfaceValue) ...@@ -384,10 +384,10 @@ func (enc *Encoder) encodeInterface(b *bytes.Buffer, iv *reflect.InterfaceValue)
return return
} }
typ, _ := indirect(iv.Elem().Type()) ut := userType(iv.Elem().Type())
name, ok := concreteTypeToName[typ] name, ok := concreteTypeToName[ut.base]
if !ok { if !ok {
errorf("gob: type not registered for interface: %s", typ) errorf("gob: type not registered for interface: %s", ut.base)
} }
// Send the name. // Send the name.
state.encodeUint(uint64(len(name))) state.encodeUint(uint64(len(name)))
...@@ -396,14 +396,14 @@ func (enc *Encoder) encodeInterface(b *bytes.Buffer, iv *reflect.InterfaceValue) ...@@ -396,14 +396,14 @@ func (enc *Encoder) encodeInterface(b *bytes.Buffer, iv *reflect.InterfaceValue)
error(err) error(err)
} }
// Define the type id if necessary. // Define the type id if necessary.
enc.sendTypeDescriptor(enc.writer(), state, typ) enc.sendTypeDescriptor(enc.writer(), state, ut)
// Send the type id. // Send the type id.
enc.sendTypeId(state, typ) enc.sendTypeId(state, ut)
// Encode the value into a new buffer. Any nested type definitions // Encode the value into a new buffer. Any nested type definitions
// should be written to b, before the encoded value. // should be written to b, before the encoded value.
enc.pushWriter(b) enc.pushWriter(b)
data := new(bytes.Buffer) data := new(bytes.Buffer)
err = enc.encode(data, iv.Elem()) err = enc.encode(data, iv.Elem(), ut)
if err != nil { if err != nil {
error(err) error(err)
} }
...@@ -437,7 +437,9 @@ var encOpMap = []encOp{ ...@@ -437,7 +437,9 @@ var encOpMap = []encOp{
// Return the encoding op for the base type under rt and // Return the encoding op for the base type under rt and
// the indirection count to reach it. // the indirection count to reach it.
func (enc *Encoder) encOpFor(rt reflect.Type) (encOp, int) { func (enc *Encoder) encOpFor(rt reflect.Type) (encOp, int) {
typ, indir := indirect(rt) ut := userType(rt)
typ := ut.base
indir := ut.indir
var op encOp var op encOp
k := typ.Kind() k := typ.Kind()
if int(k) < len(encOpMap) { if int(k) < len(encOpMap) {
...@@ -559,14 +561,12 @@ func (enc *Encoder) lockAndGetEncEngine(rt reflect.Type) *encEngine { ...@@ -559,14 +561,12 @@ func (enc *Encoder) lockAndGetEncEngine(rt reflect.Type) *encEngine {
return enc.getEncEngine(rt) return enc.getEncEngine(rt)
} }
func (enc *Encoder) encode(b *bytes.Buffer, value reflect.Value) (err os.Error) { func (enc *Encoder) encode(b *bytes.Buffer, value reflect.Value, ut *userTypeInfo) (err os.Error) {
defer catchError(&err) defer catchError(&err)
// Dereference down to the underlying object. for i := 0; i < ut.indir; i++ {
rt, indir := indirect(value.Type())
for i := 0; i < indir; i++ {
value = reflect.Indirect(value) value = reflect.Indirect(value)
} }
engine := enc.lockAndGetEncEngine(rt) engine := enc.lockAndGetEncEngine(ut.base)
if value.Type().Kind() == reflect.Struct { if value.Type().Kind() == reflect.Struct {
enc.encodeStruct(b, engine, value.Addr()) enc.encodeStruct(b, engine, value.Addr())
} else { } else {
......
...@@ -80,7 +80,8 @@ func (enc *Encoder) writeMessage(w io.Writer, b *bytes.Buffer) { ...@@ -80,7 +80,8 @@ func (enc *Encoder) writeMessage(w io.Writer, b *bytes.Buffer) {
func (enc *Encoder) sendType(w io.Writer, state *encoderState, origt reflect.Type) (sent bool) { func (enc *Encoder) sendType(w io.Writer, state *encoderState, origt reflect.Type) (sent bool) {
// Drill down to the base type. // Drill down to the base type.
rt, _ := indirect(origt) ut := userType(origt)
rt := ut.base
switch rt := rt.(type) { switch rt := rt.(type) {
default: default:
...@@ -125,7 +126,7 @@ func (enc *Encoder) sendType(w io.Writer, state *encoderState, origt reflect.Typ ...@@ -125,7 +126,7 @@ func (enc *Encoder) sendType(w io.Writer, state *encoderState, origt reflect.Typ
// Id: // Id:
state.encodeInt(-int64(info.id)) state.encodeInt(-int64(info.id))
// Type: // Type:
enc.encode(state.b, reflect.NewValue(info.wire)) enc.encode(state.b, reflect.NewValue(info.wire), wireTypeUserInfo)
enc.writeMessage(w, state.b) enc.writeMessage(w, state.b)
if enc.err != nil { if enc.err != nil {
return return
...@@ -153,15 +154,16 @@ func (enc *Encoder) Encode(e interface{}) os.Error { ...@@ -153,15 +154,16 @@ func (enc *Encoder) Encode(e interface{}) os.Error {
return enc.EncodeValue(reflect.NewValue(e)) return enc.EncodeValue(reflect.NewValue(e))
} }
// sendTypeId makes sure the remote side knows about this type. // sendTypeDescriptor makes sure the remote side knows about this type.
// It will send a descriptor if this is the first time the type has been // It will send a descriptor if this is the first time the type has been
// sent. // sent.
func (enc *Encoder) sendTypeDescriptor(w io.Writer, state *encoderState, rt reflect.Type) { func (enc *Encoder) sendTypeDescriptor(w io.Writer, state *encoderState, ut *userTypeInfo) {
// Make sure the type is known to the other side. // Make sure the type is known to the other side.
// First, have we already sent this type? // First, have we already sent this (base) type?
if _, alreadySent := enc.sent[rt]; !alreadySent { base := ut.base
if _, alreadySent := enc.sent[base]; !alreadySent {
// No, so send it. // No, so send it.
sent := enc.sendType(w, state, rt) sent := enc.sendType(w, state, base)
if enc.err != nil { if enc.err != nil {
return return
} }
...@@ -170,21 +172,21 @@ func (enc *Encoder) sendTypeDescriptor(w io.Writer, state *encoderState, rt refl ...@@ -170,21 +172,21 @@ func (enc *Encoder) sendTypeDescriptor(w io.Writer, state *encoderState, rt refl
// need to send the type info but we do need to update enc.sent. // need to send the type info but we do need to update enc.sent.
if !sent { if !sent {
typeLock.Lock() typeLock.Lock()
info, err := getTypeInfo(rt) info, err := getTypeInfo(base)
typeLock.Unlock() typeLock.Unlock()
if err != nil { if err != nil {
enc.setError(err) enc.setError(err)
return return
} }
enc.sent[rt] = info.id enc.sent[base] = info.id
} }
} }
} }
// sendTypeId sends the id, which must have already been defined. // sendTypeId sends the id, which must have already been defined.
func (enc *Encoder) sendTypeId(state *encoderState, rt reflect.Type) { func (enc *Encoder) sendTypeId(state *encoderState, ut *userTypeInfo) {
// Identify the type of this top-level value. // Identify the type of this top-level value.
state.encodeInt(int64(enc.sent[rt])) state.encodeInt(int64(enc.sent[ut.base]))
} }
// EncodeValue transmits the data item represented by the reflection value, // EncodeValue transmits the data item represented by the reflection value,
...@@ -199,18 +201,18 @@ func (enc *Encoder) EncodeValue(value reflect.Value) os.Error { ...@@ -199,18 +201,18 @@ func (enc *Encoder) EncodeValue(value reflect.Value) os.Error {
enc.w = enc.w[0:1] enc.w = enc.w[0:1]
enc.err = nil enc.err = nil
rt, _ := indirect(value.Type()) ut := userType(value.Type())
state := newEncoderState(enc, new(bytes.Buffer)) state := newEncoderState(enc, new(bytes.Buffer))
enc.sendTypeDescriptor(enc.writer(), state, rt) enc.sendTypeDescriptor(enc.writer(), state, ut)
enc.sendTypeId(state, rt) enc.sendTypeId(state, ut)
if enc.err != nil { if enc.err != nil {
return enc.err return enc.err
} }
// Encode the object. // Encode the object.
err := enc.encode(state.b, value) err := enc.encode(state.b, value, ut)
if err != nil { if err != nil {
enc.setError(err) enc.setError(err)
} else { } else {
......
...@@ -11,12 +11,43 @@ import ( ...@@ -11,12 +11,43 @@ import (
"sync" "sync"
) )
// Reflection types are themselves interface values holding structs // userTypeInfo stores the information associated with a type the user has handed
// describing the type. Each type has a different struct so that struct can // to the package. It's computed once and stored in a map keyed by reflection
// be the kind. For example, if typ is the reflect type for an int8, typ is // type.
// a pointer to a reflect.Int8Type struct; if typ is the reflect type for a type userTypeInfo struct {
// function, typ is a pointer to a reflect.FuncType struct; we use the type user reflect.Type // the type the user handed us
// of that pointer as the kind. base reflect.Type // the base type after all indirections
indir int // number of indirections to reach the base type
}
var (
// Protected by an RWMutex because we read it a lot and write
// it only when we see a new type, typically when compiling.
userTypeLock sync.RWMutex
userTypeCache = make(map[reflect.Type]*userTypeInfo)
)
// userType returns, and saves, the information associated with user-provided type rt
func userType(rt reflect.Type) *userTypeInfo {
userTypeLock.RLock()
ut := userTypeCache[rt]
userTypeLock.RUnlock()
if ut != nil {
return ut
}
// Now set the value under the write lock.
userTypeLock.Lock()
defer userTypeLock.Unlock()
if ut = userTypeCache[rt]; ut != nil {
// Lost the race; not a problem.
return ut
}
ut = new(userTypeInfo)
ut.user = rt
ut.base, ut.indir = indirect(rt)
userTypeCache[rt] = ut
return ut
}
// A typeId represents a gob Type as an integer that can be passed on the wire. // A typeId represents a gob Type as an integer that can be passed on the wire.
// Internally, typeIds are used as keys to a map to recover the underlying type info. // Internally, typeIds are used as keys to a map to recover the underlying type info.
...@@ -110,6 +141,7 @@ var ( ...@@ -110,6 +141,7 @@ var (
// Predefined because it's needed by the Decoder // Predefined because it's needed by the Decoder
var tWireType = mustGetTypeInfo(reflect.Typeof(wireType{})).id var tWireType = mustGetTypeInfo(reflect.Typeof(wireType{})).id
var wireTypeUserInfo *userTypeInfo // userTypeInfo of (*wireType)
func init() { func init() {
// Some magic numbers to make sure there are no surprises. // Some magic numbers to make sure there are no surprises.
...@@ -133,6 +165,7 @@ func init() { ...@@ -133,6 +165,7 @@ func init() {
} }
nextId = firstUserId nextId = firstUserId
registerBasics() registerBasics()
wireTypeUserInfo = userType(reflect.Typeof((*wireType)(nil)))
} }
// Array type // Array type
...@@ -317,10 +350,10 @@ func newTypeObject(name string, rt reflect.Type) (gobType, os.Error) { ...@@ -317,10 +350,10 @@ func newTypeObject(name string, rt reflect.Type) (gobType, os.Error) {
field := make([]*fieldType, t.NumField()) field := make([]*fieldType, t.NumField())
for i := 0; i < t.NumField(); i++ { for i := 0; i < t.NumField(); i++ {
f := t.Field(i) f := t.Field(i)
typ, _ := indirect(f.Type) typ := userType(f.Type).base
tname := typ.Name() tname := typ.Name()
if tname == "" { if tname == "" {
t, _ := indirect(f.Type) t := userType(f.Type).base
tname = t.String() tname = t.String()
} }
gt, err := getType(tname, f.Type) gt, err := getType(tname, f.Type)
...@@ -341,7 +374,7 @@ func newTypeObject(name string, rt reflect.Type) (gobType, os.Error) { ...@@ -341,7 +374,7 @@ func newTypeObject(name string, rt reflect.Type) (gobType, os.Error) {
// getType returns the Gob type describing the given reflect.Type. // getType returns the Gob type describing the given reflect.Type.
// typeLock must be held. // typeLock must be held.
func getType(name string, rt reflect.Type) (gobType, os.Error) { func getType(name string, rt reflect.Type) (gobType, os.Error) {
rt, _ = indirect(rt) rt = userType(rt).base
typ, present := types[rt] typ, present := types[rt]
if present { if present {
return typ, nil return typ, nil
...@@ -371,6 +404,7 @@ func bootstrapType(name string, e interface{}, expect typeId) typeId { ...@@ -371,6 +404,7 @@ func bootstrapType(name string, e interface{}, expect typeId) typeId {
types[rt] = typ types[rt] = typ
setTypeId(typ) setTypeId(typ)
checkId(expect, nextId) checkId(expect, nextId)
userType(rt) // might as well cache it now
return nextId return nextId
} }
...@@ -473,18 +507,18 @@ func RegisterName(name string, value interface{}) { ...@@ -473,18 +507,18 @@ func RegisterName(name string, value interface{}) {
// reserved for nil // reserved for nil
panic("attempt to register empty name") panic("attempt to register empty name")
} }
rt, _ := indirect(reflect.Typeof(value)) base := userType(reflect.Typeof(value)).base
// Check for incompatible duplicates. // Check for incompatible duplicates.
if t, ok := nameToConcreteType[name]; ok && t != rt { if t, ok := nameToConcreteType[name]; ok && t != base {
panic("gob: registering duplicate types for " + name) panic("gob: registering duplicate types for " + name)
} }
if n, ok := concreteTypeToName[rt]; ok && n != name { if n, ok := concreteTypeToName[base]; ok && n != name {
panic("gob: registering duplicate names for " + rt.String()) panic("gob: registering duplicate names for " + base.String())
} }
// Store the name and type provided by the user.... // Store the name and type provided by the user....
nameToConcreteType[name] = reflect.Typeof(value) nameToConcreteType[name] = reflect.Typeof(value)
// but the flattened type in the type table, since that's what decode needs. // but the flattened type in the type table, since that's what decode needs.
concreteTypeToName[rt] = name concreteTypeToName[base] = name
} }
// Register records a type, identified by a value for that type, under its // Register records a type, identified by a value for that type, under its
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment