Commit d91ec5bb authored by Ian Lance Taylor's avatar Ian Lance Taylor

cmd/cgo, runtime: recognize unsafe.Pointer(&s[0]) in cgo pointer checks

It's fairly common to call cgo functions with conversions to
unsafe.Pointer or other C types.  Apply the simpler checking of address
expressions when possible when the address expression occurs within a
type conversion.

Change-Id: I5187d4eb4d27a6542621c396cad9ee4b8647d1cd
Reviewed-on: https://go-review.googlesource.com/18391
Run-TryBot: Ian Lance Taylor <iant@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: 's avatarRuss Cox <rsc@golang.org>
parent e84dad3f
...@@ -106,6 +106,16 @@ var ptrTests = []ptrTest{ ...@@ -106,6 +106,16 @@ var ptrTests = []ptrTest{
body: `i := 0; p := &S{p:&i, s:[]unsafe.Pointer{nil}}; C.f(&p.s[0])`, body: `i := 0; p := &S{p:&i, s:[]unsafe.Pointer{nil}}; C.f(&p.s[0])`,
fail: false, fail: false,
}, },
{
// Passing the address of a slice of an array that is
// an element in a struct, with a type conversion.
name: "slice-ok-3",
c: `void f(void* p) {}`,
imports: []string{"unsafe"},
support: `type S struct { p *int; a [4]byte }`,
body: `i := 0; p := &S{p:&i}; s := p.a[:]; C.f(unsafe.Pointer(&s[0]))`,
fail: false,
},
{ {
// Passing the address of a static variable with no // Passing the address of a static variable with no
// pointers doesn't matter. // pointers doesn't matter.
......
...@@ -626,9 +626,7 @@ func (p *Package) rewriteCall(f *File, call *ast.CallExpr, name *Name) { ...@@ -626,9 +626,7 @@ func (p *Package) rewriteCall(f *File, call *ast.CallExpr, name *Name) {
// Add optional additional arguments for an address // Add optional additional arguments for an address
// expression. // expression.
if u, ok := call.Args[i].(*ast.UnaryExpr); ok && u.Op == token.AND { c.Args = p.checkAddrArgs(f, c.Args, call.Args[i])
c.Args = p.checkAddrArgs(f, c.Args, u.X)
}
// _cgoCheckPointer returns interface{}. // _cgoCheckPointer returns interface{}.
// We need to type assert that to the type we want. // We need to type assert that to the type we want.
...@@ -773,7 +771,19 @@ func (p *Package) hasPointer(f *File, t ast.Expr, top bool) bool { ...@@ -773,7 +771,19 @@ func (p *Package) hasPointer(f *File, t ast.Expr, top bool) bool {
// only pass the slice or array if we can refer to it without side // only pass the slice or array if we can refer to it without side
// effects. // effects.
func (p *Package) checkAddrArgs(f *File, args []ast.Expr, x ast.Expr) []ast.Expr { func (p *Package) checkAddrArgs(f *File, args []ast.Expr, x ast.Expr) []ast.Expr {
index, ok := x.(*ast.IndexExpr) // Strip type conversions.
for {
c, ok := x.(*ast.CallExpr)
if !ok || len(c.Args) != 1 || !p.isType(c.Fun) {
break
}
x = c.Args[0]
}
u, ok := x.(*ast.UnaryExpr)
if !ok || u.Op != token.AND {
return args
}
index, ok := u.X.(*ast.IndexExpr)
if !ok { if !ok {
// This is the address of something that is not an // This is the address of something that is not an
// index expression. We only need to examine the // index expression. We only need to examine the
...@@ -804,6 +814,42 @@ func (p *Package) hasSideEffects(f *File, x ast.Expr) bool { ...@@ -804,6 +814,42 @@ func (p *Package) hasSideEffects(f *File, x ast.Expr) bool {
return found return found
} }
// isType returns whether the expression is definitely a type.
// This is conservative--it returns false for an unknown identifier.
func (p *Package) isType(t ast.Expr) bool {
switch t := t.(type) {
case *ast.SelectorExpr:
if t.Sel.Name != "Pointer" {
return false
}
id, ok := t.X.(*ast.Ident)
if !ok {
return false
}
return id.Name == "unsafe"
case *ast.Ident:
// TODO: This ignores shadowing.
switch t.Name {
case "unsafe.Pointer", "bool", "byte",
"complex64", "complex128",
"error",
"float32", "float64",
"int", "int8", "int16", "int32", "int64",
"rune", "string",
"uint", "uint8", "uint16", "uint32", "uint64", "uintptr":
return true
}
case *ast.StarExpr:
return p.isType(t.X)
case *ast.ArrayType, *ast.StructType, *ast.FuncType, *ast.InterfaceType,
*ast.MapType, *ast.ChanType:
return true
}
return false
}
// unsafeCheckPointerName is given the Go version of a C type. If the // unsafeCheckPointerName is given the Go version of a C type. If the
// type uses unsafe.Pointer, we arrange to build a version of // type uses unsafe.Pointer, we arrange to build a version of
// _cgoCheckPointer that returns that type. This avoids using a type // _cgoCheckPointer that returns that type. This avoids using a type
...@@ -832,6 +878,8 @@ func (p *Package) unsafeCheckPointerName(t ast.Expr) string { ...@@ -832,6 +878,8 @@ func (p *Package) unsafeCheckPointerName(t ast.Expr) string {
func (p *Package) hasUnsafePointer(t ast.Expr) bool { func (p *Package) hasUnsafePointer(t ast.Expr) bool {
switch t := t.(type) { switch t := t.(type) {
case *ast.Ident: case *ast.Ident:
// We don't see a SelectorExpr for unsafe.Pointer;
// this is created by code in this file.
return t.Name == "unsafe.Pointer" return t.Name == "unsafe.Pointer"
case *ast.ArrayType: case *ast.ArrayType:
return p.hasUnsafePointer(t.Elt) return p.hasUnsafePointer(t.Elt)
......
...@@ -354,7 +354,7 @@ func cgoCheckPointer(ptr interface{}, args ...interface{}) interface{} { ...@@ -354,7 +354,7 @@ func cgoCheckPointer(ptr interface{}, args ...interface{}) interface{} {
t := ep._type t := ep._type
top := true top := true
if len(args) > 0 && t.kind&kindMask == kindPtr { if len(args) > 0 && (t.kind&kindMask == kindPtr || t.kind&kindMask == kindUnsafePointer) {
p := ep.data p := ep.data
if t.kind&kindDirectIface == 0 { if t.kind&kindDirectIface == 0 {
p = *(*unsafe.Pointer)(p) p = *(*unsafe.Pointer)(p)
...@@ -365,6 +365,10 @@ func cgoCheckPointer(ptr interface{}, args ...interface{}) interface{} { ...@@ -365,6 +365,10 @@ func cgoCheckPointer(ptr interface{}, args ...interface{}) interface{} {
aep := (*eface)(unsafe.Pointer(&args[0])) aep := (*eface)(unsafe.Pointer(&args[0]))
switch aep._type.kind & kindMask { switch aep._type.kind & kindMask {
case kindBool: case kindBool:
if t.kind&kindMask == kindUnsafePointer {
// We don't know the type of the element.
break
}
pt := (*ptrtype)(unsafe.Pointer(t)) pt := (*ptrtype)(unsafe.Pointer(t))
cgoCheckArg(pt.elem, p, true, false, cgoCheckPointerFail) cgoCheckArg(pt.elem, p, true, false, cgoCheckPointerFail)
return ptr return ptr
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment