Commit 48e207d5 authored by Keith Randall's avatar Keith Randall Committed by Keith Randall

cmd/compile: fix mapassign_fast* routines for pointer keys

The signature of the mapassign_fast* routines need to distinguish
the pointerness of their key argument.  If the affected routines
suspend part way through, the object pointed to by the key might
get garbage collected because the key is typed as a uint{32,64}.

This is not a problem for mapaccess or mapdelete because the key
in those situations do not live beyond the call involved.  If the
object referenced by the key is garbage collected prematurely, the
code still works fine.  Even if that object is subsequently reallocated,
it can't be written to the map in time to affect the lookup/delete.

Fixes #22781

Change-Id: I0bbbc5e9883d5ce702faf4e655348be1191ee439
Reviewed-on: https://go-review.googlesource.com/79018
Run-TryBot: Keith Randall <khr@golang.org>
Reviewed-by: 's avatarAustin Clements <austin@google.com>
Reviewed-by: 's avatarMartin Möhrmann <moehrmann@google.com>
parent a2a1c173
...@@ -88,7 +88,9 @@ var runtimeDecls = [...]struct { ...@@ -88,7 +88,9 @@ var runtimeDecls = [...]struct {
{"mapaccess2_fat", funcTag, 70}, {"mapaccess2_fat", funcTag, 70},
{"mapassign", funcTag, 65}, {"mapassign", funcTag, 65},
{"mapassign_fast32", funcTag, 66}, {"mapassign_fast32", funcTag, 66},
{"mapassign_fast32ptr", funcTag, 66},
{"mapassign_fast64", funcTag, 66}, {"mapassign_fast64", funcTag, 66},
{"mapassign_fast64ptr", funcTag, 66},
{"mapassign_faststr", funcTag, 66}, {"mapassign_faststr", funcTag, 66},
{"mapiterinit", funcTag, 71}, {"mapiterinit", funcTag, 71},
{"mapdelete", funcTag, 71}, {"mapdelete", funcTag, 71},
......
...@@ -109,7 +109,9 @@ func mapaccess2_faststr(mapType *byte, hmap map[any]any, key any) (val *any, pre ...@@ -109,7 +109,9 @@ func mapaccess2_faststr(mapType *byte, hmap map[any]any, key any) (val *any, pre
func mapaccess2_fat(mapType *byte, hmap map[any]any, key *any, zero *byte) (val *any, pres bool) func mapaccess2_fat(mapType *byte, hmap map[any]any, key *any, zero *byte) (val *any, pres bool)
func mapassign(mapType *byte, hmap map[any]any, key *any) (val *any) func mapassign(mapType *byte, hmap map[any]any, key *any) (val *any)
func mapassign_fast32(mapType *byte, hmap map[any]any, key any) (val *any) func mapassign_fast32(mapType *byte, hmap map[any]any, key any) (val *any)
func mapassign_fast32ptr(mapType *byte, hmap map[any]any, key any) (val *any)
func mapassign_fast64(mapType *byte, hmap map[any]any, key any) (val *any) func mapassign_fast64(mapType *byte, hmap map[any]any, key any) (val *any)
func mapassign_fast64ptr(mapType *byte, hmap map[any]any, key any) (val *any)
func mapassign_faststr(mapType *byte, hmap map[any]any, key any) (val *any) func mapassign_faststr(mapType *byte, hmap map[any]any, key any) (val *any)
func mapiterinit(mapType *byte, hmap map[any]any, hiter *any) func mapiterinit(mapType *byte, hmap map[any]any, hiter *any)
func mapdelete(mapType *byte, hmap map[any]any, key *any) func mapdelete(mapType *byte, hmap map[any]any, key *any)
......
...@@ -2826,21 +2826,23 @@ func mapfndel(name string, t *types.Type) *Node { ...@@ -2826,21 +2826,23 @@ func mapfndel(name string, t *types.Type) *Node {
const ( const (
mapslow = iota mapslow = iota
mapfast32 mapfast32
mapfast32ptr
mapfast64 mapfast64
mapfast64ptr
mapfaststr mapfaststr
nmapfast nmapfast
) )
type mapnames [nmapfast]string type mapnames [nmapfast]string
func mkmapnames(base string) mapnames { func mkmapnames(base string, ptr string) mapnames {
return mapnames{base, base + "_fast32", base + "_fast64", base + "_faststr"} return mapnames{base, base + "_fast32", base + "_fast32" + ptr, base + "_fast64", base + "_fast64" + ptr, base + "_faststr"}
} }
var mapaccess1 = mkmapnames("mapaccess1") var mapaccess1 = mkmapnames("mapaccess1", "")
var mapaccess2 = mkmapnames("mapaccess2") var mapaccess2 = mkmapnames("mapaccess2", "")
var mapassign = mkmapnames("mapassign") var mapassign = mkmapnames("mapassign", "ptr")
var mapdelete = mkmapnames("mapdelete") var mapdelete = mkmapnames("mapdelete", "")
func mapfast(t *types.Type) int { func mapfast(t *types.Type) int {
// Check ../../runtime/hashmap.go:maxValueSize before changing. // Check ../../runtime/hashmap.go:maxValueSize before changing.
...@@ -2849,9 +2851,22 @@ func mapfast(t *types.Type) int { ...@@ -2849,9 +2851,22 @@ func mapfast(t *types.Type) int {
} }
switch algtype(t.Key()) { switch algtype(t.Key()) {
case AMEM32: case AMEM32:
return mapfast32 if !t.Key().HasHeapPointer() {
return mapfast32
}
if Widthptr == 4 {
return mapfast32ptr
}
Fatalf("small pointer %v", t.Key())
case AMEM64: case AMEM64:
return mapfast64 if !t.Key().HasHeapPointer() {
return mapfast64
}
if Widthptr == 8 {
return mapfast64ptr
}
// Two-word object, at least one of which is a pointer.
// Use the slow path.
case ASTRING: case ASTRING:
return mapfaststr return mapfaststr
} }
......
...@@ -420,11 +420,93 @@ again: ...@@ -420,11 +420,93 @@ again:
insertk = add(unsafe.Pointer(insertb), dataOffset+inserti*4) insertk = add(unsafe.Pointer(insertb), dataOffset+inserti*4)
// store new key at insert position // store new key at insert position
if sys.PtrSize == 4 && t.key.kind&kindNoPointers == 0 && writeBarrier.enabled { *(*uint32)(insertk) = key
writebarrierptr((*uintptr)(insertk), uintptr(key))
} else { h.count++
*(*uint32)(insertk) = key
done:
val := add(unsafe.Pointer(insertb), dataOffset+bucketCnt*4+inserti*uintptr(t.valuesize))
if h.flags&hashWriting == 0 {
throw("concurrent map writes")
}
h.flags &^= hashWriting
return val
}
func mapassign_fast32ptr(t *maptype, h *hmap, key unsafe.Pointer) unsafe.Pointer {
if h == nil {
panic(plainError("assignment to entry in nil map"))
}
if raceenabled {
callerpc := getcallerpc()
racewritepc(unsafe.Pointer(h), callerpc, funcPC(mapassign_fast32))
}
if h.flags&hashWriting != 0 {
throw("concurrent map writes")
} }
hash := t.key.alg.hash(noescape(unsafe.Pointer(&key)), uintptr(h.hash0))
// Set hashWriting after calling alg.hash for consistency with mapassign.
h.flags |= hashWriting
if h.buckets == nil {
h.buckets = newobject(t.bucket) // newarray(t.bucket, 1)
}
again:
bucket := hash & bucketMask(h.B)
if h.growing() {
growWork_fast32(t, h, bucket)
}
b := (*bmap)(unsafe.Pointer(uintptr(h.buckets) + bucket*uintptr(t.bucketsize)))
var insertb *bmap
var inserti uintptr
var insertk unsafe.Pointer
for {
for i := uintptr(0); i < bucketCnt; i++ {
if b.tophash[i] == empty {
if insertb == nil {
inserti = i
insertb = b
}
continue
}
k := *((*unsafe.Pointer)(add(unsafe.Pointer(b), dataOffset+i*4)))
if k != key {
continue
}
inserti = i
insertb = b
goto done
}
ovf := b.overflow(t)
if ovf == nil {
break
}
b = ovf
}
// Did not find mapping for key. Allocate new cell & add entry.
// If we hit the max load factor or we have too many overflow buckets,
// and we're not already in the middle of growing, start growing.
if !h.growing() && (overLoadFactor(h.count+1, h.B) || tooManyOverflowBuckets(h.noverflow, h.B)) {
hashGrow(t, h)
goto again // Growing the table invalidates everything, so try again
}
if insertb == nil {
// all current buckets are full, allocate a new one.
insertb = h.newoverflow(t, b)
inserti = 0 // not necessary, but avoids needlessly spilling inserti
}
insertb.tophash[inserti&(bucketCnt-1)] = tophash(hash) // mask inserti to avoid bounds checks
insertk = add(unsafe.Pointer(insertb), dataOffset+inserti*4)
// store new key at insert position
*(*unsafe.Pointer)(insertk) = key
h.count++ h.count++
...@@ -510,18 +592,94 @@ again: ...@@ -510,18 +592,94 @@ again:
insertk = add(unsafe.Pointer(insertb), dataOffset+inserti*8) insertk = add(unsafe.Pointer(insertb), dataOffset+inserti*8)
// store new key at insert position // store new key at insert position
if t.key.kind&kindNoPointers == 0 && writeBarrier.enabled { *(*uint64)(insertk) = key
if sys.PtrSize == 8 {
writebarrierptr((*uintptr)(insertk), uintptr(key)) h.count++
} else {
// There are three ways to squeeze at least one 32 bit pointer into 64 bits. done:
// Give up and call typedmemmove. val := add(unsafe.Pointer(insertb), dataOffset+bucketCnt*8+inserti*uintptr(t.valuesize))
typedmemmove(t.key, insertk, unsafe.Pointer(&key)) if h.flags&hashWriting == 0 {
throw("concurrent map writes")
}
h.flags &^= hashWriting
return val
}
func mapassign_fast64ptr(t *maptype, h *hmap, key unsafe.Pointer) unsafe.Pointer {
if h == nil {
panic(plainError("assignment to entry in nil map"))
}
if raceenabled {
callerpc := getcallerpc()
racewritepc(unsafe.Pointer(h), callerpc, funcPC(mapassign_fast64))
}
if h.flags&hashWriting != 0 {
throw("concurrent map writes")
}
hash := t.key.alg.hash(noescape(unsafe.Pointer(&key)), uintptr(h.hash0))
// Set hashWriting after calling alg.hash for consistency with mapassign.
h.flags |= hashWriting
if h.buckets == nil {
h.buckets = newobject(t.bucket) // newarray(t.bucket, 1)
}
again:
bucket := hash & bucketMask(h.B)
if h.growing() {
growWork_fast64(t, h, bucket)
}
b := (*bmap)(unsafe.Pointer(uintptr(h.buckets) + bucket*uintptr(t.bucketsize)))
var insertb *bmap
var inserti uintptr
var insertk unsafe.Pointer
for {
for i := uintptr(0); i < bucketCnt; i++ {
if b.tophash[i] == empty {
if insertb == nil {
insertb = b
inserti = i
}
continue
}
k := *((*unsafe.Pointer)(add(unsafe.Pointer(b), dataOffset+i*8)))
if k != key {
continue
}
insertb = b
inserti = i
goto done
} }
} else { ovf := b.overflow(t)
*(*uint64)(insertk) = key if ovf == nil {
break
}
b = ovf
}
// Did not find mapping for key. Allocate new cell & add entry.
// If we hit the max load factor or we have too many overflow buckets,
// and we're not already in the middle of growing, start growing.
if !h.growing() && (overLoadFactor(h.count+1, h.B) || tooManyOverflowBuckets(h.noverflow, h.B)) {
hashGrow(t, h)
goto again // Growing the table invalidates everything, so try again
} }
if insertb == nil {
// all current buckets are full, allocate a new one.
insertb = h.newoverflow(t, b)
inserti = 0 // not necessary, but avoids needlessly spilling inserti
}
insertb.tophash[inserti&(bucketCnt-1)] = tophash(hash) // mask inserti to avoid bounds checks
insertk = add(unsafe.Pointer(insertb), dataOffset+inserti*8)
// store new key at insert position
*(*unsafe.Pointer)(insertk) = key
h.count++ h.count++
done: done:
......
// run
// Copyright 2017 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package main
import "runtime/debug"
type T struct {
// >= 16 bytes to avoid tiny alloc.
a, b int
}
func main() {
debug.SetGCPercent(1)
for i := 0; i < 100000; i++ {
m := make(map[*T]struct{}, 0)
for j := 0; j < 20; j++ {
// During the call to mapassign_fast64, the key argument
// was incorrectly treated as a uint64. If the stack was
// scanned during that call, the only pointer to k was
// missed, leading to *k being collected prematurely.
k := new(T)
m[k] = struct{}{}
}
}
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment