Commit 1d18f66d authored by Ian Lance Taylor's avatar Ian Lance Taylor

cmd/cgo: write a string rather than building an AST

This generates the same code as before, but does so directly rather
than building an AST and printing that. This is in preparation for
later changes.

Change-Id: Ifec141120bcc74847f0bff8d3d47306bfe69b454
Reviewed-on: https://go-review.googlesource.com/c/142883
Run-TryBot: Ian Lance Taylor <iant@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: 's avatarBrad Fitzpatrick <bradfitz@golang.org>
parent af951994
......@@ -744,16 +744,19 @@ func (p *Package) rewriteCalls(f *File) bool {
// argument and then calls the original function.
// This returns whether the package needs to import unsafe as _cgo_unsafe.
func (p *Package) rewriteCall(f *File, call *Call, name *Name) bool {
params := name.FuncType.Params
args := call.Call.Args
// Avoid a crash if the number of arguments is
// less than the number of parameters.
// This will be caught when the generated file is compiled.
if len(call.Call.Args) < len(name.FuncType.Params) {
if len(args) < len(params) {
return false
}
any := false
for i, param := range name.FuncType.Params {
if p.needsPointerCheck(f, param.Go, call.Call.Args[i]) {
for i, param := range params {
if p.needsPointerCheck(f, param.Go, args[i]) {
any = true
break
}
......@@ -772,127 +775,108 @@ func (p *Package) rewriteCall(f *File, call *Call, name *Name) bool {
// Using a function literal like this lets us do correct
// argument type checking, and works correctly if the call is
// deferred.
var sb bytes.Buffer
sb.WriteString("func(")
needsUnsafe := false
params := make([]*ast.Field, len(name.FuncType.Params))
nargs := make([]ast.Expr, len(name.FuncType.Params))
var stmts []ast.Stmt
for i, param := range name.FuncType.Params {
// params is going to become the parameters of the
// function literal.
// nargs is going to become the list of arguments made
// by the call within the function literal.
// nparam is the parameter of the function literal that
// corresponds to param.
origArg := call.Call.Args[i]
nparam := ast.NewIdent(fmt.Sprintf("_cgo%d", i))
nargs[i] = nparam
// The Go version of the C type might use unsafe.Pointer,
// but the file might not import unsafe.
// Rewrite the Go type if necessary to use _cgo_unsafe.
for i, param := range params {
if i > 0 {
sb.WriteString(", ")
}
fmt.Fprintf(&sb, "_cgo%d ", i)
ptype := p.rewriteUnsafe(param.Go)
if ptype != param.Go {
needsUnsafe = true
}
sb.WriteString(gofmtLine(ptype))
}
params[i] = &ast.Field{
Names: []*ast.Ident{nparam},
Type: ptype,
}
if !p.needsPointerCheck(f, param.Go, origArg) {
continue
}
sb.WriteString(")")
// Run the cgo pointer checks on nparam.
result := false
twoResults := false
// Change the function literal to call the real function
// with the parameter passed through _cgoCheckPointer.
c := &ast.CallExpr{
Fun: ast.NewIdent("_cgoCheckPointer"),
Args: []ast.Expr{
nparam,
},
// Check whether this call expects two results.
for _, ref := range f.Ref {
if ref.Expr != &call.Call.Fun {
continue
}
// Add optional additional arguments for an address
// expression.
c.Args = p.checkAddrArgs(f, c.Args, origArg)
stmt := &ast.ExprStmt{
X: c,
if ref.Context == ctxCall2 {
sb.WriteString(" (")
result = true
twoResults = true
}
stmts = append(stmts, stmt)
break
}
const cgoMarker = "__cgo__###__marker__"
fcall := &ast.CallExpr{
Fun: ast.NewIdent(cgoMarker),
Args: nargs,
}
ftype := &ast.FuncType{
Params: &ast.FieldList{
List: params,
},
}
// Add the result type, if any.
if name.FuncType.Result != nil {
rtype := p.rewriteUnsafe(name.FuncType.Result.Go)
if rtype != name.FuncType.Result.Go {
needsUnsafe = true
}
ftype.Results = &ast.FieldList{
List: []*ast.Field{
&ast.Field{
Type: rtype,
},
},
if !twoResults {
sb.WriteString(" ")
}
sb.WriteString(gofmtLine(rtype))
result = true
}
// If this call expects two results, we have to
// adjust the results of the function we generated.
for _, ref := range f.Ref {
if ref.Expr == &call.Call.Fun && ref.Context == ctxCall2 {
if ftype.Results == nil {
// An explicit void argument
// looks odd but it seems to
// be how cgo has worked historically.
ftype.Results = &ast.FieldList{
List: []*ast.Field{
&ast.Field{
Type: ast.NewIdent("_Ctype_void"),
},
},
}
}
ftype.Results.List = append(ftype.Results.List,
&ast.Field{
Type: ast.NewIdent("error"),
})
// Add the second result type, if any.
if twoResults {
if name.FuncType.Result == nil {
// An explicit void result looks odd but it
// seems to be how cgo has worked historically.
sb.WriteString("_Ctype_void")
}
sb.WriteString(", error)")
}
var fbody ast.Stmt
if ftype.Results == nil {
fbody = &ast.ExprStmt{
X: fcall,
sb.WriteString(" { ")
for i, param := range params {
arg := args[i]
if !p.needsPointerCheck(f, param.Go, arg) {
continue
}
} else {
fbody = &ast.ReturnStmt{
Results: []ast.Expr{fcall},
// Check for &a[i].
if p.checkIndex(&sb, f, arg, i) {
continue
}
// Check for &x.
if p.checkAddr(&sb, arg, i) {
continue
}
fmt.Fprintf(&sb, "_cgoCheckPointer(_cgo%d); ", i)
}
lit := &ast.FuncLit{
Type: ftype,
Body: &ast.BlockStmt{
List: append(stmts, fbody),
},
if result {
sb.WriteString("return ")
}
text := strings.Replace(gofmt(lit), "\n", ";", -1)
repl := strings.Split(text, cgoMarker)
f.Edit.Insert(f.offset(call.Call.Fun.Pos()), repl[0])
f.Edit.Insert(f.offset(call.Call.Fun.End()), repl[1])
// Now we are ready to call the C function.
// To work smoothly with rewriteRef we leave the call in place
// and just insert our new arguments between the function
// and the old arguments.
f.Edit.Insert(f.offset(call.Call.Fun.Pos()), sb.String())
sb.Reset()
sb.WriteString("(")
for i := range params {
if i > 0 {
sb.WriteString(", ")
}
fmt.Fprintf(&sb, "_cgo%d", i)
}
sb.WriteString("); }")
f.Edit.Insert(f.offset(call.Call.Lparen), sb.String())
return needsUnsafe
}
......@@ -1001,19 +985,13 @@ func (p *Package) hasPointer(f *File, t ast.Expr, top bool) bool {
}
}
// checkAddrArgs tries to add arguments to the call of
// _cgoCheckPointer when the argument is an address expression. We
// pass true to mean that the argument is an address operation of
// something other than a slice index, which means that it's only
// necessary to check the specific element pointed to, not the entire
// object. This is for &s.f, where f is a field in a struct. We can
// pass a slice or array, meaning that we should check the entire
// slice or array but need not check any other part of the object.
// This is for &s.a[i], where we need to check all of a. However, we
// only pass the slice or array if we can refer to it without side
// effects.
func (p *Package) checkAddrArgs(f *File, args []ast.Expr, x ast.Expr) []ast.Expr {
// checkIndex checks whether arg the form &a[i], possibly inside type
// conversions. If so, and if a has no side effects, it writes
// _cgoCheckPointer(_cgoNN, a) to sb and returns true. This tells
// _cgoCheckPointer to check the complete contents of the slice.
func (p *Package) checkIndex(sb *bytes.Buffer, f *File, arg ast.Expr, i int) bool {
// Strip type conversions.
x := arg
for {
c, ok := x.(*ast.CallExpr)
if !ok || len(c.Args) != 1 || !p.isType(c.Fun) {
......@@ -1023,22 +1001,46 @@ func (p *Package) checkAddrArgs(f *File, args []ast.Expr, x ast.Expr) []ast.Expr
}
u, ok := x.(*ast.UnaryExpr)
if !ok || u.Op != token.AND {
return args
return false
}
index, ok := u.X.(*ast.IndexExpr)
if !ok {
// This is the address of something that is not an
// index expression. We only need to examine the
// single value to which it points.
// TODO: what if true is shadowed?
return append(args, ast.NewIdent("true"))
}
if !p.hasSideEffects(f, index.X) {
// Examine the entire slice.
return append(args, index.X)
}
// Treat the pointer as unknown.
return args
return false
}
if p.hasSideEffects(f, index.X) {
return false
}
fmt.Fprintf(sb, "_cgoCheckPointer(_cgo%d, %s); ", i, gofmtLine(index.X))
return true
}
// checkAddr checks whether arg has the form &x, possibly inside type
// conversions. If so it writes _cgoCheckPointer(_cgoNN, true) to sb
// and returns true. This tells _cgoCheckPointer to check just the
// contents of the pointer being passed, not any other part of the
// memory allocation. This is run after checkIndex, which looks for
// the special case of &a[i], which requires different checks.
func (p *Package) checkAddr(sb *bytes.Buffer, arg ast.Expr, i int) bool {
// Strip type conversions.
px := &arg
for {
c, ok := (*px).(*ast.CallExpr)
if !ok || len(c.Args) != 1 || !p.isType(c.Fun) {
break
}
px = &c.Args[0]
}
if u, ok := (*px).(*ast.UnaryExpr); !ok || u.Op != token.AND {
return false
}
// Use "0 == 0" to do the right thing in the unlikely event
// that "true" is shadowed.
fmt.Fprintf(sb, "_cgoCheckPointer(_cgo%d, 0 == 0); ", i)
return true
}
// hasSideEffects returns whether the expression x has any side
......
......@@ -126,3 +126,9 @@ func gofmt(n interface{}) string {
}
return gofmtBuf.String()
}
// gofmtLine returns the gofmt-formatted string for an AST node,
// ensuring that it is on a single line.
func gofmtLine(n interface{}) string {
return strings.Replace(gofmt(n), "\n", ";", -1)
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment