Commit 99be9cc0 authored by Tim Cooper's avatar Tim Cooper Committed by Ian Lance Taylor

flag: add (*FlagSet).Name, (*FlagSet).ErrorHandling, export (*FlagSet).Output

Allows code that operates on a FlagSet to know the name and error
handling behavior of the FlagSet without having to call FlagSet.Init.

Fixes #17628
Fixes #21888

Change-Id: Ib0fe4c8885f9ccdacf5a7fb761d5ecb23f3bb055
Reviewed-on: https://go-review.googlesource.com/70391
Run-TryBot: Ian Lance Taylor <iant@golang.org>
Reviewed-by: 's avatarIan Lance Taylor <iant@golang.org>
parent 26e49e69
...@@ -308,13 +308,25 @@ func sortFlags(flags map[string]*Flag) []*Flag { ...@@ -308,13 +308,25 @@ func sortFlags(flags map[string]*Flag) []*Flag {
return result return result
} }
func (f *FlagSet) out() io.Writer { // Output returns the destination for usage and error messages. os.Stderr is returned if
// output was not set or was set to nil.
func (f *FlagSet) Output() io.Writer {
if f.output == nil { if f.output == nil {
return os.Stderr return os.Stderr
} }
return f.output return f.output
} }
// Name returns the name of the flag set.
func (f *FlagSet) Name() string {
return f.name
}
// ErrorHandling returns the error handling behavior of the flag set.
func (f *FlagSet) ErrorHandling() ErrorHandling {
return f.errorHandling
}
// SetOutput sets the destination for usage and error messages. // SetOutput sets the destination for usage and error messages.
// If output is nil, os.Stderr is used. // If output is nil, os.Stderr is used.
func (f *FlagSet) SetOutput(output io.Writer) { func (f *FlagSet) SetOutput(output io.Writer) {
...@@ -474,7 +486,7 @@ func (f *FlagSet) PrintDefaults() { ...@@ -474,7 +486,7 @@ func (f *FlagSet) PrintDefaults() {
s += fmt.Sprintf(" (default %v)", flag.DefValue) s += fmt.Sprintf(" (default %v)", flag.DefValue)
} }
} }
fmt.Fprint(f.out(), s, "\n") fmt.Fprint(f.Output(), s, "\n")
}) })
} }
...@@ -504,9 +516,9 @@ func PrintDefaults() { ...@@ -504,9 +516,9 @@ func PrintDefaults() {
// defaultUsage is the default function to print a usage message. // defaultUsage is the default function to print a usage message.
func (f *FlagSet) defaultUsage() { func (f *FlagSet) defaultUsage() {
if f.name == "" { if f.name == "" {
fmt.Fprintf(f.out(), "Usage:\n") fmt.Fprintf(f.Output(), "Usage:\n")
} else { } else {
fmt.Fprintf(f.out(), "Usage of %s:\n", f.name) fmt.Fprintf(f.Output(), "Usage of %s:\n", f.name)
} }
f.PrintDefaults() f.PrintDefaults()
} }
...@@ -525,7 +537,7 @@ func (f *FlagSet) defaultUsage() { ...@@ -525,7 +537,7 @@ func (f *FlagSet) defaultUsage() {
// happens anyway as the command line's error handling strategy is set to // happens anyway as the command line's error handling strategy is set to
// ExitOnError. // ExitOnError.
var Usage = func() { var Usage = func() {
fmt.Fprintf(CommandLine.out(), "Usage of %s:\n", os.Args[0]) fmt.Fprintf(CommandLine.Output(), "Usage of %s:\n", os.Args[0])
PrintDefaults() PrintDefaults()
} }
...@@ -793,7 +805,7 @@ func (f *FlagSet) Var(value Value, name string, usage string) { ...@@ -793,7 +805,7 @@ func (f *FlagSet) Var(value Value, name string, usage string) {
} else { } else {
msg = fmt.Sprintf("%s flag redefined: %s", f.name, name) msg = fmt.Sprintf("%s flag redefined: %s", f.name, name)
} }
fmt.Fprintln(f.out(), msg) fmt.Fprintln(f.Output(), msg)
panic(msg) // Happens only if flags are declared with identical names panic(msg) // Happens only if flags are declared with identical names
} }
if f.formal == nil { if f.formal == nil {
...@@ -816,7 +828,7 @@ func Var(value Value, name string, usage string) { ...@@ -816,7 +828,7 @@ func Var(value Value, name string, usage string) {
// returns the error. // returns the error.
func (f *FlagSet) failf(format string, a ...interface{}) error { func (f *FlagSet) failf(format string, a ...interface{}) error {
err := fmt.Errorf(format, a...) err := fmt.Errorf(format, a...)
fmt.Fprintln(f.out(), err) fmt.Fprintln(f.Output(), err)
f.usage() f.usage()
return err return err
} }
......
...@@ -8,6 +8,7 @@ import ( ...@@ -8,6 +8,7 @@ import (
"bytes" "bytes"
. "flag" . "flag"
"fmt" "fmt"
"io"
"os" "os"
"sort" "sort"
"strconv" "strconv"
...@@ -454,3 +455,36 @@ func TestUsageOutput(t *testing.T) { ...@@ -454,3 +455,36 @@ func TestUsageOutput(t *testing.T) {
t.Errorf("output = %q; want %q", got, want) t.Errorf("output = %q; want %q", got, want)
} }
} }
func TestGetters(t *testing.T) {
expectedName := "flag set"
expectedErrorHandling := ContinueOnError
expectedOutput := io.Writer(os.Stderr)
fs := NewFlagSet(expectedName, expectedErrorHandling)
if fs.Name() != expectedName {
t.Errorf("unexpected name: got %s, expected %s", fs.Name(), expectedName)
}
if fs.ErrorHandling() != expectedErrorHandling {
t.Errorf("unexpected ErrorHandling: got %d, expected %d", fs.ErrorHandling(), expectedErrorHandling)
}
if fs.Output() != expectedOutput {
t.Errorf("unexpected output: got %#v, expected %#v", fs.Output(), expectedOutput)
}
expectedName = "gopher"
expectedErrorHandling = ExitOnError
expectedOutput = os.Stdout
fs.Init(expectedName, expectedErrorHandling)
fs.SetOutput(expectedOutput)
if fs.Name() != expectedName {
t.Errorf("unexpected name: got %s, expected %s", fs.Name(), expectedName)
}
if fs.ErrorHandling() != expectedErrorHandling {
t.Errorf("unexpected ErrorHandling: got %d, expected %d", fs.ErrorHandling(), expectedErrorHandling)
}
if fs.Output() != expectedOutput {
t.Errorf("unexpected output: got %v, expected %v", fs.Output(), expectedOutput)
}
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment