Commit 42f07ff2 authored by Ross Light's avatar Ross Light Committed by Brad Fitzpatrick

os/user: add LookupGroup, LookupGroupId, and User.GroupIds functions

As part of local testing with a large group member list, I discovered
that the lookup functions don't resize their buffer if they receive
ERANGE.  I fixed this as a side-effect of this CL.

Thanks to @andrenth for the original CL.

Fixes #2617

Change-Id: Ie6aae2fe0a89eae5cce85786869a8acaa665ffe9
Reviewed-on: https://go-review.googlesource.com/19235Reviewed-by: 's avatarIan Lance Taylor <iant@golang.org>
Run-TryBot: Ian Lance Taylor <iant@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: 's avatarBrad Fitzpatrick <bradfitz@golang.org>
parent ff555f11
......@@ -12,11 +12,28 @@ func Current() (*User, error) {
// Lookup looks up a user by username. If the user cannot be found, the
// returned error is of type UnknownUserError.
func Lookup(username string) (*User, error) {
return lookup(username)
return lookupUser(username)
}
// LookupId looks up a user by userid. If the user cannot be found, the
// returned error is of type UnknownUserIdError.
func LookupId(uid string) (*User, error) {
return lookupId(uid)
return lookupUserId(uid)
}
// LookupGroup looks up a group by name. If the group cannot be found, the
// returned error is of type UnknownGroupError.
func LookupGroup(name string) (*Group, error) {
return lookupGroup(name)
}
// LookupGroupId looks up a group by groupid. If the group cannot be found, the
// returned error is of type UnknownGroupIdError.
func LookupGroupId(gid string) (*Group, error) {
return lookupGroupId(gid)
}
// GroupIds returns the list of group IDs that the user is a member of.
func (u *User) GroupIds() ([]string, error) {
return listGroups(u)
}
// Copyright 2016 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build android
package user
import "errors"
func init() {
userImplemented = false
groupImplemented = false
}
func current() (*User, error) {
return nil, errors.New("user: Current not implemented on android")
}
func lookupUser(string) (*User, error) {
return nil, errors.New("user: Lookup not implemented on android")
}
func lookupUserId(string) (*User, error) {
return nil, errors.New("user: LookupId not implemented on android")
}
func lookupGroup(string) (*Group, error) {
return nil, errors.New("user: LookupGroup not implemented on android")
}
func lookupGroupId(string) (*Group, error) {
return nil, errors.New("user: LookupGroupId not implemented on android")
}
func listGroups(*User) ([]string, error) {
return nil, errors.New("user: GroupIds not implemented on android")
}
......@@ -18,6 +18,10 @@ const (
userFile = "/dev/user"
)
func init() {
groupImplemented = false
}
func current() (*User, error) {
ubytes, err := ioutil.ReadFile(userFile)
if err != nil {
......@@ -37,10 +41,22 @@ func current() (*User, error) {
return u, nil
}
func lookup(username string) (*User, error) {
func lookupUser(username string) (*User, error) {
return nil, syscall.EPLAN9
}
func lookupUserId(uid string) (*User, error) {
return nil, syscall.EPLAN9
}
func lookupGroup(groupname string) (*Group, error) {
return nil, syscall.EPLAN9
}
func lookupGroupId(string) (*Group, error) {
return nil, syscall.EPLAN9
}
func lookupId(uid string) (*User, error) {
func listGroups(*User) ([]string, error) {
return nil, syscall.EPLAN9
}
......@@ -2,27 +2,37 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build !cgo,!windows,!plan9 android
// +build !cgo,!windows,!plan9,!android
package user
import (
"fmt"
"runtime"
)
import "errors"
func init() {
implemented = false
userImplemented = false
groupImplemented = false
}
func current() (*User, error) {
return nil, fmt.Errorf("user: Current not implemented on %s/%s", runtime.GOOS, runtime.GOARCH)
return nil, errors.New("user: Current requires cgo")
}
func lookup(username string) (*User, error) {
return nil, fmt.Errorf("user: Lookup not implemented on %s/%s", runtime.GOOS, runtime.GOARCH)
func lookupUser(username string) (*User, error) {
return nil, errors.New("user: Lookup requires cgo")
}
func lookupId(uid string) (*User, error) {
return nil, fmt.Errorf("user: LookupId not implemented on %s/%s", runtime.GOOS, runtime.GOARCH)
func lookupUserId(uid string) (*User, error) {
return nil, errors.New("user: LookupId requires cgo")
}
func lookupGroup(groupname string) (*Group, error) {
return nil, errors.New("user: LookupGroup requires cgo")
}
func lookupGroupId(string) (*Group, error) {
return nil, errors.New("user: LookupGroupId requires cgo")
}
func listGroups(*User) ([]string, error) {
return nil, errors.New("user: GroupIds requires cgo")
}
......@@ -20,6 +20,7 @@ import (
#include <unistd.h>
#include <sys/types.h>
#include <pwd.h>
#include <grp.h>
#include <stdlib.h>
static int mygetpwuid_r(int uid, struct passwd *pwd,
......@@ -31,76 +32,119 @@ static int mygetpwnam_r(const char *name, struct passwd *pwd,
char *buf, size_t buflen, struct passwd **result) {
return getpwnam_r(name, pwd, buf, buflen, result);
}
static int mygetgrgid_r(int gid, struct group *grp,
char *buf, size_t buflen, struct group **result) {
return getgrgid_r(gid, grp, buf, buflen, result);
}
static int mygetgrouplist(const char *user, gid_t group, gid_t *groups,
int *ngroups) {
return getgrouplist(user, group, groups, ngroups);
}
*/
import "C"
func current() (*User, error) {
return lookupUnix(syscall.Getuid(), "", false)
return lookupUnixUid(syscall.Getuid())
}
func lookup(username string) (*User, error) {
return lookupUnix(-1, username, true)
func lookupUser(username string) (*User, error) {
var pwd C.struct_passwd
var result *C.struct_passwd
nameC := C.CString(username)
defer C.free(unsafe.Pointer(nameC))
buf := alloc(userBuffer)
defer buf.free()
err := retryWithBuffer(buf, func() syscall.Errno {
// mygetpwnam_r is a wrapper around getpwnam_r to avoid
// passing a size_t to getpwnam_r, because for unknown
// reasons passing a size_t to getpwnam_r doesn't work on
// Solaris.
return syscall.Errno(C.mygetpwnam_r(nameC,
&pwd,
(*C.char)(buf.ptr),
C.size_t(buf.size),
&result))
})
if err != nil {
return nil, fmt.Errorf("user: lookup username %s: %v", username, err)
}
if result == nil {
return nil, UnknownUserError(username)
}
return buildUser(&pwd), err
}
func lookupId(uid string) (*User, error) {
func lookupUserId(uid string) (*User, error) {
i, e := strconv.Atoi(uid)
if e != nil {
return nil, e
}
return lookupUnix(i, "", false)
return lookupUnixUid(i)
}
func lookupUnix(uid int, username string, lookupByName bool) (*User, error) {
func lookupUnixUid(uid int) (*User, error) {
var pwd C.struct_passwd
var result *C.struct_passwd
bufSize := C.sysconf(C._SC_GETPW_R_SIZE_MAX)
if bufSize == -1 {
// DragonFly and FreeBSD do not have _SC_GETPW_R_SIZE_MAX.
// Additionally, not all Linux systems have it, either. For
// example, the musl libc returns -1.
bufSize = 1024
}
if bufSize <= 0 || bufSize > 1<<20 {
return nil, fmt.Errorf("user: unreasonable _SC_GETPW_R_SIZE_MAX of %d", bufSize)
}
buf := C.malloc(C.size_t(bufSize))
defer C.free(buf)
var rv C.int
if lookupByName {
nameC := C.CString(username)
defer C.free(unsafe.Pointer(nameC))
// mygetpwnam_r is a wrapper around getpwnam_r to avoid
// passing a size_t to getpwnam_r, because for unknown
// reasons passing a size_t to getpwnam_r doesn't work on
// Solaris.
rv = C.mygetpwnam_r(nameC,
&pwd,
(*C.char)(buf),
C.size_t(bufSize),
&result)
if rv != 0 {
return nil, fmt.Errorf("user: lookup username %s: %s", username, syscall.Errno(rv))
}
if result == nil {
return nil, UnknownUserError(username)
}
} else {
buf := alloc(userBuffer)
defer buf.free()
err := retryWithBuffer(buf, func() syscall.Errno {
// mygetpwuid_r is a wrapper around getpwuid_r to
// to avoid using uid_t because C.uid_t(uid) for
// unknown reasons doesn't work on linux.
rv = C.mygetpwuid_r(C.int(uid),
return syscall.Errno(C.mygetpwuid_r(C.int(uid),
&pwd,
(*C.char)(buf),
C.size_t(bufSize),
&result)
if rv != 0 {
return nil, fmt.Errorf("user: lookup userid %d: %s", uid, syscall.Errno(rv))
(*C.char)(buf.ptr),
C.size_t(buf.size),
&result))
})
if err != nil {
return nil, fmt.Errorf("user: lookup userid %d: %v", uid, err)
}
if result == nil {
return nil, UnknownUserIdError(uid)
}
return buildUser(&pwd), nil
}
func listGroups(u *User) ([]string, error) {
ug, err := strconv.Atoi(u.Gid)
if err != nil {
return nil, fmt.Errorf("user: list groups for %s: invalid gid %q", u.Username, u.Gid)
}
userGID := C.gid_t(ug)
nameC := C.CString(u.Username)
defer C.free(unsafe.Pointer(nameC))
n := C.int(256)
gidsC := make([]C.gid_t, n)
rv := C.mygetgrouplist(nameC, userGID, &gidsC[0], &n)
if rv == -1 {
// More than initial buffer, but now n contains the correct size.
const maxGroups = 2048
if n > maxGroups {
return nil, fmt.Errorf("user: list groups for %s: member of more than %d groups", u.Username, maxGroups)
}
if result == nil {
return nil, UnknownUserIdError(uid)
gidsC = make([]C.gid_t, n)
rv := C.mygetgrouplist(nameC, userGID, &gidsC[0], &n)
if rv == -1 {
return nil, fmt.Errorf("user: list groups for %s failed (changed groups?)", u.Username)
}
}
gidsC = gidsC[:n]
gids := make([]string, 0, n)
for _, g := range gidsC[:n] {
gids = append(gids, strconv.Itoa(int(g)))
}
return gids, nil
}
func buildUser(pwd *C.struct_passwd) *User {
u := &User{
Uid: strconv.Itoa(int(pwd.pw_uid)),
Gid: strconv.Itoa(int(pwd.pw_gid)),
......@@ -115,5 +159,145 @@ func lookupUnix(uid int, username string, lookupByName bool) (*User, error) {
if i := strings.Index(u.Name, ","); i >= 0 {
u.Name = u.Name[:i]
}
return u, nil
return u
}
func currentGroup() (*Group, error) {
return lookupUnixGid(syscall.Getgid())
}
func lookupGroup(groupname string) (*Group, error) {
var grp C.struct_group
var result *C.struct_group
buf := alloc(groupBuffer)
defer buf.free()
cname := C.CString(groupname)
defer C.free(unsafe.Pointer(cname))
err := retryWithBuffer(buf, func() syscall.Errno {
return syscall.Errno(C.getgrnam_r(cname,
&grp,
(*C.char)(buf.ptr),
C.size_t(buf.size),
&result))
})
if err != nil {
return nil, fmt.Errorf("user: lookup groupname %s: %v", groupname, err)
}
if result == nil {
return nil, UnknownGroupError(groupname)
}
return buildGroup(&grp), nil
}
func lookupGroupId(gid string) (*Group, error) {
i, e := strconv.Atoi(gid)
if e != nil {
return nil, e
}
return lookupUnixGid(i)
}
func lookupUnixGid(gid int) (*Group, error) {
var grp C.struct_group
var result *C.struct_group
buf := alloc(groupBuffer)
defer buf.free()
err := retryWithBuffer(buf, func() syscall.Errno {
// mygetgrgid_r is a wrapper around getgrgid_r to
// to avoid using gid_t because C.gid_t(gid) for
// unknown reasons doesn't work on linux.
return syscall.Errno(C.mygetgrgid_r(C.int(gid),
&grp,
(*C.char)(buf.ptr),
C.size_t(buf.size),
&result))
})
if err != nil {
return nil, fmt.Errorf("user: lookup groupid %d: %v", gid, err)
}
if result == nil {
return nil, UnknownGroupIdError(gid)
}
return buildGroup(&grp), nil
}
func buildGroup(grp *C.struct_group) *Group {
g := &Group{
Gid: strconv.Itoa(int(grp.gr_gid)),
Name: C.GoString(grp.gr_name),
}
return g
}
type bufferKind C.int
const (
userBuffer = bufferKind(C._SC_GETPW_R_SIZE_MAX)
groupBuffer = bufferKind(C._SC_GETGR_R_SIZE_MAX)
)
func (k bufferKind) initialSize() C.size_t {
sz := C.sysconf(C.int(k))
if sz == -1 {
// DragonFly and FreeBSD do not have _SC_GETPW_R_SIZE_MAX.
// Additionally, not all Linux systems have it, either. For
// example, the musl libc returns -1.
return 1024
}
if !isSizeReasonable(int64(sz)) {
// Truncate. If this truly isn't enough, retryWithBuffer will error on the first run.
return maxBufferSize
}
return C.size_t(sz)
}
type memBuffer struct {
ptr unsafe.Pointer
size C.size_t
}
func alloc(kind bufferKind) *memBuffer {
sz := kind.initialSize()
return &memBuffer{
ptr: C.malloc(sz),
size: sz,
}
}
func (mb *memBuffer) resize(newSize C.size_t) {
mb.ptr = C.realloc(mb.ptr, newSize)
mb.size = newSize
}
func (mb *memBuffer) free() {
C.free(mb.ptr)
}
// retryWithBuffer repeatedly calls f(), increasing the size of the
// buffer each time, until f succeeds, fails with a non-ERANGE error,
// or the buffer exceeds a reasonable limit.
func retryWithBuffer(buf *memBuffer, f func() syscall.Errno) error {
for {
errno := f()
if errno == 0 {
return nil
} else if errno != syscall.ERANGE {
return errno
}
newSize := buf.size * 2
if !isSizeReasonable(int64(newSize)) {
return fmt.Errorf("internal buffer exceeds %d bytes", maxBufferSize)
}
buf.resize(newSize)
}
}
const maxBufferSize = 1 << 20
func isSizeReasonable(sz int64) bool {
return sz > 0 && sz <= maxBufferSize
}
......@@ -5,11 +5,16 @@
package user
import (
"errors"
"fmt"
"syscall"
"unsafe"
)
func init() {
groupImplemented = false
}
func isDomainJoined() (bool, error) {
var domain *uint16
var status uint32
......@@ -129,7 +134,7 @@ func newUserFromSid(usid *syscall.SID) (*User, error) {
return newUser(usid, gid, dir)
}
func lookup(username string) (*User, error) {
func lookupUser(username string) (*User, error) {
sid, _, t, e := syscall.LookupSID("", username)
if e != nil {
return nil, e
......@@ -140,10 +145,22 @@ func lookup(username string) (*User, error) {
return newUserFromSid(sid)
}
func lookupId(uid string) (*User, error) {
func lookupUserId(uid string) (*User, error) {
sid, e := syscall.StringToSid(uid)
if e != nil {
return nil, e
}
return newUserFromSid(sid)
}
func lookupGroup(groupname string) (*Group, error) {
return nil, errors.New("user: LookupGroup not implemented on windows")
}
func lookupGroupId(string) (*Group, error) {
return nil, errors.New("user: LookupGroupId not implemented on windows")
}
func listGroups(*User) ([]string, error) {
return nil, errors.New("user: GroupIds not implemented on windows")
}
......@@ -9,23 +9,35 @@ import (
"strconv"
)
var implemented = true // set to false by lookup_stubs.go's init
var (
userImplemented = true // set to false by lookup_stubs.go's init
groupImplemented = true // set to false by lookup_stubs.go's init
)
// User represents a user account.
//
// On posix systems Uid and Gid contain a decimal number
// On POSIX systems Uid and Gid contain a decimal number
// representing uid and gid. On windows Uid and Gid
// contain security identifier (SID) in a string format.
// On Plan 9, Uid, Gid, Username, and Name will be the
// contents of /dev/user.
type User struct {
Uid string // user id
Gid string // primary group id
Uid string // user ID
Gid string // primary group ID
Username string
Name string
HomeDir string
}
// Group represents a grouping of users.
//
// On POSIX systems Gid contains a decimal number
// representing the group ID.
type Group struct {
Gid string // group ID
Name string // group name
}
// UnknownUserIdError is returned by LookupId when
// a user cannot be found.
type UnknownUserIdError int
......@@ -41,3 +53,19 @@ type UnknownUserError string
func (e UnknownUserError) Error() string {
return "user: unknown user " + string(e)
}
// UnknownGroupIdError is returned by LookupGroupId when
// a group cannot be found.
type UnknownGroupIdError string
func (e UnknownGroupIdError) Error() string {
return "group: unknown groupid " + string(e)
}
// UnknownGroupError is returned by LookupGroup when
// a group cannot be found.
type UnknownGroupError string
func (e UnknownGroupError) Error() string {
return "group: unknown group " + string(e)
}
......@@ -9,14 +9,14 @@ import (
"testing"
)
func check(t *testing.T) {
if !implemented {
func checkUser(t *testing.T) {
if !userImplemented {
t.Skip("user: not implemented; skipping tests")
}
}
func TestCurrent(t *testing.T) {
check(t)
checkUser(t)
u, err := Current()
if err != nil {
......@@ -53,7 +53,7 @@ func compare(t *testing.T, want, got *User) {
}
func TestLookup(t *testing.T) {
check(t)
checkUser(t)
if runtime.GOOS == "plan9" {
t.Skipf("Lookup not implemented on %q", runtime.GOOS)
......@@ -71,7 +71,7 @@ func TestLookup(t *testing.T) {
}
func TestLookupId(t *testing.T) {
check(t)
checkUser(t)
if runtime.GOOS == "plan9" {
t.Skipf("LookupId not implemented on %q", runtime.GOOS)
......@@ -87,3 +87,57 @@ func TestLookupId(t *testing.T) {
}
compare(t, want, got)
}
func checkGroup(t *testing.T) {
if !groupImplemented {
t.Skip("user: group not implemented; skipping test")
}
}
func TestLookupGroup(t *testing.T) {
checkGroup(t)
user, err := Current()
if err != nil {
t.Fatalf("Current(): %v", err)
}
g1, err := LookupGroupId(user.Gid)
if err != nil {
t.Fatalf("LookupGroupId(%q): %v", user.Gid, err)
}
if g1.Gid != user.Gid {
t.Errorf("LookupGroupId(%q).Gid = %s; want %s", user.Gid, g1.Gid, user.Gid)
}
g2, err := LookupGroup(g1.Name)
if err != nil {
t.Fatalf("LookupGroup(%q): %v", g1.Name, err)
}
if g1.Gid != g2.Gid || g1.Name != g2.Name {
t.Errorf("LookupGroup(%q) = %+v; want %+v", g1.Name, g2, g1)
}
}
func TestGroupIds(t *testing.T) {
checkGroup(t)
user, err := Current()
if err != nil {
t.Fatalf("Current(): %v", err)
}
gids, err := user.GroupIds()
if err != nil {
t.Fatalf("%+v.GroupIds(): %v", user, err)
}
if !containsID(gids, user.Gid) {
t.Errorf("%+v.GroupIds() = %v; does not contain user GID %s", user, gids, user.Gid)
}
}
func containsID(ids []string, id string) bool {
for _, x := range ids {
if x == id {
return true
}
}
return false
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment