Commit b3a5f9e5 authored by Robert Griesemer's avatar Robert Griesemer

go/doc: don't show methods of exported anonymous fields

Added flag AllMethods: if not set (future default), embedded
methods of exported (and thus visible) embedded fields are not
shown in the final package documentation

The actual change for AllMethods is just in sortedFuncs. All
other changes are simplifications of the existing logic (mostly
deletion of code): Because method conflicts due to embedding
must always be detected, remove any premature elimination of
types and methods. Instead collect all named types and all
methods and do the filtering at the end.

Miscellaneous:
- renamed baseType -> namedType
- streamline logic for recording embedded types
- record embedded types via a map (simpler data structures)

AllMethods is set by default; so the output is unchanged and
the tests pass. The next CL will enable the AllMethods flag
and have adjusted tests (and fix issue 2791).

R=rsc
CC=golang-dev
https://golang.org/cl/5572076
parent a0d0ed20
...@@ -68,12 +68,17 @@ const ( ...@@ -68,12 +68,17 @@ const (
// extract documentation for all package-level declarations, // extract documentation for all package-level declarations,
// not just exported ones // not just exported ones
AllDecls Mode = 1 << iota AllDecls Mode = 1 << iota
// show all embedded methods, not just the ones of
// invisible (unexported) anonymous fields
AllMethods
) )
// New computes the package documentation for the given package AST. // New computes the package documentation for the given package AST.
// New takes ownership of the AST pkg and may edit or overwrite it. // New takes ownership of the AST pkg and may edit or overwrite it.
// //
func New(pkg *ast.Package, importPath string, mode Mode) *Package { func New(pkg *ast.Package, importPath string, mode Mode) *Package {
mode |= AllMethods // TODO(gri) remove this to enable flag
var r reader var r reader
r.readPackage(pkg, mode) r.readPackage(pkg, mode)
r.computeMethodSets() r.computeMethodSets()
...@@ -86,8 +91,8 @@ func New(pkg *ast.Package, importPath string, mode Mode) *Package { ...@@ -86,8 +91,8 @@ func New(pkg *ast.Package, importPath string, mode Mode) *Package {
Filenames: r.filenames, Filenames: r.filenames,
Bugs: r.bugs, Bugs: r.bugs,
Consts: sortedValues(r.values, token.CONST), Consts: sortedValues(r.values, token.CONST),
Types: sortedTypes(r.types), Types: sortedTypes(r.types, mode&AllMethods != 0),
Vars: sortedValues(r.values, token.VAR), Vars: sortedValues(r.values, token.VAR),
Funcs: sortedFuncs(r.funcs), Funcs: sortedFuncs(r.funcs, true),
} }
} }
...@@ -23,11 +23,11 @@ func filterIdentList(list []*ast.Ident) []*ast.Ident { ...@@ -23,11 +23,11 @@ func filterIdentList(list []*ast.Ident) []*ast.Ident {
} }
// filterFieldList removes unexported fields (field names) from the field list // filterFieldList removes unexported fields (field names) from the field list
// in place and returns true if fields were removed. Removed fields that are // in place and returns true if fields were removed. Anonymous fields are
// anonymous (embedded) fields are added as embedded types to base. filterType // recorded with the parent type. filterType is called with the types of
// is called with the types of all remaining fields. // all remaining fields.
// //
func (r *reader) filterFieldList(base *baseType, fields *ast.FieldList) (removedFields bool) { func (r *reader) filterFieldList(parent *namedType, fields *ast.FieldList) (removedFields bool) {
if fields == nil { if fields == nil {
return return
} }
...@@ -37,18 +37,9 @@ func (r *reader) filterFieldList(base *baseType, fields *ast.FieldList) (removed ...@@ -37,18 +37,9 @@ func (r *reader) filterFieldList(base *baseType, fields *ast.FieldList) (removed
keepField := false keepField := false
if n := len(field.Names); n == 0 { if n := len(field.Names); n == 0 {
// anonymous field // anonymous field
name, imp := baseTypeName(field.Type) name := r.recordAnonymousField(parent, field.Type)
if ast.IsExported(name) { if ast.IsExported(name) {
// we keep the field - in this case r.readDecl
// will take care of adding the embedded type
keepField = true keepField = true
} else if base != nil && !imp {
// we don't keep the field - add it as an embedded
// type so we won't loose its methods, if any
if embedded := r.lookupType(name); embedded != nil {
_, ptr := field.Type.(*ast.StarExpr)
base.addEmbeddedType(embedded, ptr)
}
} }
} else { } else {
field.Names = filterIdentList(field.Names) field.Names = filterIdentList(field.Names)
...@@ -86,7 +77,7 @@ func (r *reader) filterParamList(fields *ast.FieldList) { ...@@ -86,7 +77,7 @@ func (r *reader) filterParamList(fields *ast.FieldList) {
// in place. If fields (or methods) have been removed, the corresponding // in place. If fields (or methods) have been removed, the corresponding
// struct or interface type has the Incomplete field set to true. // struct or interface type has the Incomplete field set to true.
// //
func (r *reader) filterType(base *baseType, typ ast.Expr) { func (r *reader) filterType(parent *namedType, typ ast.Expr) {
switch t := typ.(type) { switch t := typ.(type) {
case *ast.Ident: case *ast.Ident:
// nothing to do // nothing to do
...@@ -95,14 +86,14 @@ func (r *reader) filterType(base *baseType, typ ast.Expr) { ...@@ -95,14 +86,14 @@ func (r *reader) filterType(base *baseType, typ ast.Expr) {
case *ast.ArrayType: case *ast.ArrayType:
r.filterType(nil, t.Elt) r.filterType(nil, t.Elt)
case *ast.StructType: case *ast.StructType:
if r.filterFieldList(base, t.Fields) { if r.filterFieldList(parent, t.Fields) {
t.Incomplete = true t.Incomplete = true
} }
case *ast.FuncType: case *ast.FuncType:
r.filterParamList(t.Params) r.filterParamList(t.Params)
r.filterParamList(t.Results) r.filterParamList(t.Results)
case *ast.InterfaceType: case *ast.InterfaceType:
if r.filterFieldList(base, t.Methods) { if r.filterFieldList(parent, t.Methods) {
t.Incomplete = true t.Incomplete = true
} }
case *ast.MapType: case *ast.MapType:
......
...@@ -89,7 +89,7 @@ func (mset methodSet) add(m *Func) { ...@@ -89,7 +89,7 @@ func (mset methodSet) add(m *Func) {
} }
// ---------------------------------------------------------------------------- // ----------------------------------------------------------------------------
// Base types // Named types
// baseTypeName returns the name of the base type of x (or "") // baseTypeName returns the name of the base type of x (or "")
// and whether the type is imported or not. // and whether the type is imported or not.
...@@ -110,31 +110,23 @@ func baseTypeName(x ast.Expr) (name string, imported bool) { ...@@ -110,31 +110,23 @@ func baseTypeName(x ast.Expr) (name string, imported bool) {
return return
} }
// embeddedType describes the type of an anonymous field. // A namedType represents a named unqualified (package local, or possibly
// predeclared) type. The namedType for a type name is always found via
// reader.lookupType.
// //
type embeddedType struct { type namedType struct {
typ *baseType // the corresponding base type
ptr bool // if set, the anonymous field type is a pointer
}
type baseType struct {
doc string // doc comment for type doc string // doc comment for type
name string // local type name (excluding package qualifier) name string // type name
decl *ast.GenDecl // nil if declaration hasn't been seen yet decl *ast.GenDecl // nil if declaration hasn't been seen yet
isEmbedded bool // true if this type is embedded
isStruct bool // true if this type is a struct
embedded map[*namedType]bool // true if the embedded type is a pointer
// associated declarations // associated declarations
values []*Value // consts and vars values []*Value // consts and vars
funcs methodSet funcs methodSet
methods methodSet methods methodSet
isEmbedded bool // true if this type is embedded
isStruct bool // true if this type is a struct
embedded []embeddedType // list of embedded types
}
func (typ *baseType) addEmbeddedType(e *baseType, isPtr bool) {
e.isEmbedded = true
typ.embedded = append(typ.embedded, embeddedType{e, isPtr})
} }
// ---------------------------------------------------------------------------- // ----------------------------------------------------------------------------
...@@ -158,22 +150,16 @@ type reader struct { ...@@ -158,22 +150,16 @@ type reader struct {
// declarations // declarations
imports map[string]int imports map[string]int
values []*Value // consts and vars values []*Value // consts and vars
types map[string]*baseType types map[string]*namedType
funcs methodSet funcs methodSet
} }
// isVisible reports whether name is visible in the documentation.
//
func (r *reader) isVisible(name string) bool {
return r.mode&AllDecls != 0 || ast.IsExported(name)
}
// lookupType returns the base type with the given name. // lookupType returns the base type with the given name.
// If the base type has not been encountered yet, a new // If the base type has not been encountered yet, a new
// type with the given name but no associated declaration // type with the given name but no associated declaration
// is added to the type map. // is added to the type map.
// //
func (r *reader) lookupType(name string) *baseType { func (r *reader) lookupType(name string) *namedType {
if name == "" || name == "_" { if name == "" || name == "_" {
return nil // no type docs for anonymous types return nil // no type docs for anonymous types
} }
...@@ -181,8 +167,9 @@ func (r *reader) lookupType(name string) *baseType { ...@@ -181,8 +167,9 @@ func (r *reader) lookupType(name string) *baseType {
return typ return typ
} }
// type not found - add one without declaration // type not found - add one without declaration
typ := &baseType{ typ := &namedType{
name: name, name: name,
embedded: make(map[*namedType]bool),
funcs: make(methodSet), funcs: make(methodSet),
methods: make(methodSet), methods: make(methodSet),
} }
...@@ -190,6 +177,24 @@ func (r *reader) lookupType(name string) *baseType { ...@@ -190,6 +177,24 @@ func (r *reader) lookupType(name string) *baseType {
return typ return typ
} }
// recordAnonymousField registers fieldType as the type of an
// anonymous field in the parent type. If the field is imported
// (qualified name) or the parent is nil, the field is ignored.
// The function returns the field name.
//
func (r *reader) recordAnonymousField(parent *namedType, fieldType ast.Expr) (fname string) {
fname, imp := baseTypeName(fieldType)
if parent == nil || imp {
return
}
if ftype := r.lookupType(fname); ftype != nil {
ftype.isEmbedded = true
_, ptr := fieldType.(*ast.StarExpr)
parent.embedded[ftype] = ptr
}
return
}
func (r *reader) readDoc(comment *ast.CommentGroup) { func (r *reader) readDoc(comment *ast.CommentGroup) {
// By convention there should be only one package comment // By convention there should be only one package comment
// but collect all of them if there are more then one. // but collect all of them if there are more then one.
...@@ -232,7 +237,7 @@ func (r *reader) readValue(decl *ast.GenDecl) { ...@@ -232,7 +237,7 @@ func (r *reader) readValue(decl *ast.GenDecl) {
switch { switch {
case s.Type != nil: case s.Type != nil:
// a type is present; determine its name // a type is present; determine its name
if n, imp := baseTypeName(s.Type); !imp && r.isVisible(n) { if n, imp := baseTypeName(s.Type); !imp {
name = n name = n
} }
case decl.Tok == token.CONST: case decl.Tok == token.CONST:
...@@ -267,8 +272,7 @@ func (r *reader) readValue(decl *ast.GenDecl) { ...@@ -267,8 +272,7 @@ func (r *reader) readValue(decl *ast.GenDecl) {
const threshold = 0.75 const threshold = 0.75
if domName != "" && domFreq >= int(float64(len(decl.Specs))*threshold) { if domName != "" && domFreq >= int(float64(len(decl.Specs))*threshold) {
// typed entries are sufficiently frequent // typed entries are sufficiently frequent
typ := r.lookupType(domName) if typ := r.lookupType(domName); typ != nil {
if typ != nil {
values = &typ.values // associate with that type values = &typ.values // associate with that type
} }
} }
...@@ -321,22 +325,14 @@ func (r *reader) readType(decl *ast.GenDecl, spec *ast.TypeSpec) { ...@@ -321,22 +325,14 @@ func (r *reader) readType(decl *ast.GenDecl, spec *ast.TypeSpec) {
decl.Doc = nil // doc consumed - remove from AST decl.Doc = nil // doc consumed - remove from AST
typ.doc = doc.Text() typ.doc = doc.Text()
// look for anonymous fields that might contribute methods // record anonymous fields (they may contribute methods)
// (some fields may have been recorded already when filtering
// exports, but that's ok)
var list []*ast.Field var list []*ast.Field
list, typ.isStruct = fields(spec.Type) list, typ.isStruct = fields(spec.Type)
for _, field := range list { for _, field := range list {
if len(field.Names) == 0 { if len(field.Names) == 0 {
// anonymous field - add corresponding field type to typ r.recordAnonymousField(typ, field.Type)
n, imp := baseTypeName(field.Type)
if imp {
// imported type - we don't handle this case
// at the moment
return
}
if embedded := r.lookupType(n); embedded != nil {
_, ptr := field.Type.(*ast.StarExpr)
typ.addEmbeddedType(embedded, ptr)
}
} }
} }
} }
...@@ -356,24 +352,10 @@ func (r *reader) readFunc(fun *ast.FuncDecl) { ...@@ -356,24 +352,10 @@ func (r *reader) readFunc(fun *ast.FuncDecl) {
// don't show this method // don't show this method
return return
} }
var typ *baseType if typ := r.lookupType(recvTypeName); typ != nil {
if r.isVisible(recvTypeName) {
// visible recv type: if not found, add it to r.types
typ = r.lookupType(recvTypeName)
} else {
// invisible recv type: if not found, do not add it
// (invisible embedded types are added before this
// phase, so if the type doesn't exist yet, we don't
// care about this method)
typ = r.types[recvTypeName]
}
if typ != nil {
// associate method with the type
// (if the type is not exported, it may be embedded
// somewhere so we need to collect the method anyway)
typ.methods.set(fun) typ.methods.set(fun)
} }
// otherwise don't show the method // otherwise ignore the method
// TODO(gri): There may be exported methods of non-exported types // TODO(gri): There may be exported methods of non-exported types
// that can be called because of exported values (consts, vars, or // that can be called because of exported values (consts, vars, or
// function results) of that type. Could determine if that is the // function results) of that type. Could determine if that is the
...@@ -389,7 +371,7 @@ func (r *reader) readFunc(fun *ast.FuncDecl) { ...@@ -389,7 +371,7 @@ func (r *reader) readFunc(fun *ast.FuncDecl) {
// exactly one (named or anonymous) result associated // exactly one (named or anonymous) result associated
// with the first type in result signature (there may // with the first type in result signature (there may
// be more than one result) // be more than one result)
if n, imp := baseTypeName(res.Type); !imp && r.isVisible(n) { if n, imp := baseTypeName(res.Type); !imp {
if typ := r.lookupType(n); typ != nil { if typ := r.lookupType(n); typ != nil {
// associate Func with typ // associate Func with typ
typ.funcs.set(fun) typ.funcs.set(fun)
...@@ -476,7 +458,7 @@ func (r *reader) readPackage(pkg *ast.Package, mode Mode) { ...@@ -476,7 +458,7 @@ func (r *reader) readPackage(pkg *ast.Package, mode Mode) {
r.filenames = make([]string, len(pkg.Files)) r.filenames = make([]string, len(pkg.Files))
r.imports = make(map[string]int) r.imports = make(map[string]int)
r.mode = mode r.mode = mode
r.types = make(map[string]*baseType) r.types = make(map[string]*namedType)
r.funcs = make(methodSet) r.funcs = make(methodSet)
// sort package files before reading them so that the // sort package files before reading them so that the
...@@ -554,24 +536,23 @@ func customizeRecv(f *Func, recvTypeName string, embeddedIsPtr bool, level int) ...@@ -554,24 +536,23 @@ func customizeRecv(f *Func, recvTypeName string, embeddedIsPtr bool, level int)
return &newF return &newF
} }
// collectEmbeddedMethods collects the embedded methods from // collectEmbeddedMethods collects the embedded methods of typ in mset.
// all processed embedded types found in info in mset.
// //
func collectEmbeddedMethods(mset methodSet, typ *baseType, recvTypeName string, embeddedIsPtr bool, level int) { func (r *reader) collectEmbeddedMethods(mset methodSet, typ *namedType, recvTypeName string, embeddedIsPtr bool, level int) {
for _, e := range typ.embedded { for embedded, isPtr := range typ.embedded {
// Once an embedded type is embedded as a pointer type // Once an embedded type is embedded as a pointer type
// all embedded types in those types are treated like // all embedded types in those types are treated like
// pointer types for the purpose of the receiver type // pointer types for the purpose of the receiver type
// computation; i.e., embeddedIsPtr is sticky for this // computation; i.e., embeddedIsPtr is sticky for this
// embedding hierarchy. // embedding hierarchy.
thisEmbeddedIsPtr := embeddedIsPtr || e.ptr thisEmbeddedIsPtr := embeddedIsPtr || isPtr
for _, m := range e.typ.methods { for _, m := range embedded.methods {
// only top-level methods are embedded // only top-level methods are embedded
if m.Level == 0 { if m.Level == 0 {
mset.add(customizeRecv(m, recvTypeName, thisEmbeddedIsPtr, level)) mset.add(customizeRecv(m, recvTypeName, thisEmbeddedIsPtr, level))
} }
} }
collectEmbeddedMethods(mset, e.typ, recvTypeName, thisEmbeddedIsPtr, level+1) r.collectEmbeddedMethods(mset, embedded, recvTypeName, thisEmbeddedIsPtr, level+1)
} }
} }
...@@ -582,7 +563,7 @@ func (r *reader) computeMethodSets() { ...@@ -582,7 +563,7 @@ func (r *reader) computeMethodSets() {
// collect embedded methods for t // collect embedded methods for t
if t.isStruct { if t.isStruct {
// struct // struct
collectEmbeddedMethods(t.methods, t, t.name, false, 1) r.collectEmbeddedMethods(t.methods, t, t.name, false, 1)
} else { } else {
// interface // interface
// TODO(gri) fix this // TODO(gri) fix this
...@@ -597,7 +578,7 @@ func (r *reader) computeMethodSets() { ...@@ -597,7 +578,7 @@ func (r *reader) computeMethodSets() {
// //
func (r *reader) cleanupTypes() { func (r *reader) cleanupTypes() {
for _, t := range r.types { for _, t := range r.types {
visible := r.isVisible(t.name) visible := r.mode&AllDecls != 0 || ast.IsExported(t.name)
if t.decl == nil && (predeclaredTypes[t.name] || t.isEmbedded && visible) { if t.decl == nil && (predeclaredTypes[t.name] || t.isEmbedded && visible) {
// t.name is a predeclared type (and was not redeclared in this package), // t.name is a predeclared type (and was not redeclared in this package),
// or it was embedded somewhere but its declaration is missing (because // or it was embedded somewhere but its declaration is missing (because
...@@ -607,6 +588,8 @@ func (r *reader) cleanupTypes() { ...@@ -607,6 +588,8 @@ func (r *reader) cleanupTypes() {
r.values = append(r.values, t.values...) r.values = append(r.values, t.values...)
// 2) move factory functions // 2) move factory functions
for name, f := range t.funcs { for name, f := range t.funcs {
// in a correct AST, package-level function names
// are all different - no need to check for conflicts
r.funcs[name] = f r.funcs[name] = f
} }
// 3) move methods // 3) move methods
...@@ -689,7 +672,7 @@ func sortedValues(m []*Value, tok token.Token) []*Value { ...@@ -689,7 +672,7 @@ func sortedValues(m []*Value, tok token.Token) []*Value {
return list return list
} }
func sortedTypes(m map[string]*baseType) []*Type { func sortedTypes(m map[string]*namedType, allMethods bool) []*Type {
list := make([]*Type, len(m)) list := make([]*Type, len(m))
i := 0 i := 0
for _, t := range m { for _, t := range m {
...@@ -699,8 +682,8 @@ func sortedTypes(m map[string]*baseType) []*Type { ...@@ -699,8 +682,8 @@ func sortedTypes(m map[string]*baseType) []*Type {
Decl: t.decl, Decl: t.decl,
Consts: sortedValues(t.values, token.CONST), Consts: sortedValues(t.values, token.CONST),
Vars: sortedValues(t.values, token.VAR), Vars: sortedValues(t.values, token.VAR),
Funcs: sortedFuncs(t.funcs), Funcs: sortedFuncs(t.funcs, true),
Methods: sortedFuncs(t.methods), Methods: sortedFuncs(t.methods, allMethods),
} }
i++ i++
} }
...@@ -714,12 +697,19 @@ func sortedTypes(m map[string]*baseType) []*Type { ...@@ -714,12 +697,19 @@ func sortedTypes(m map[string]*baseType) []*Type {
return list return list
} }
func sortedFuncs(m methodSet) []*Func { func removeStar(s string) string {
if len(s) > 0 && s[0] == '*' {
return s[1:]
}
return s
}
func sortedFuncs(m methodSet, allMethods bool) []*Func {
list := make([]*Func, len(m)) list := make([]*Func, len(m))
i := 0 i := 0
for _, m := range m { for _, m := range m {
// exclude conflict entries // exclude conflict entries
if m.Decl != nil { if m.Decl != nil && (allMethods || ast.IsExported(removeStar(m.Orig))) {
list[i] = m list[i] = m
i++ i++
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment