Commit 48e9c771 authored by Russ Cox's avatar Russ Cox

gofmt: accept program fragments on standard input

This makes it possible to grab a block of code
in an editor and pipe it through gofmt, instead of
having to pipe in the entire file.

R=gri
CC=golang-dev
https://golang.org/cl/4973074
parent 7944bbf2
...@@ -53,6 +53,12 @@ In the pattern, single-character lowercase identifiers serve as ...@@ -53,6 +53,12 @@ In the pattern, single-character lowercase identifiers serve as
wildcards matching arbitrary sub-expressions; those expressions wildcards matching arbitrary sub-expressions; those expressions
will be substituted for the same identifiers in the replacement. will be substituted for the same identifiers in the replacement.
When gofmt reads from standard input, it accepts either a full Go program
or a program fragment. A program fragment must be a syntactically
valid declaration list, statement list, or expression. When formatting
such a fragment, gofmt preserves leading indentation as well as leading
and trailing spaces, so that individual sections of a Go program can be
formatted by piping them through gofmt.
Examples Examples
......
...@@ -86,7 +86,7 @@ func isGoFile(f *os.FileInfo) bool { ...@@ -86,7 +86,7 @@ func isGoFile(f *os.FileInfo) bool {
} }
// If in == nil, the source is the contents of the file with the given filename. // If in == nil, the source is the contents of the file with the given filename.
func processFile(filename string, in io.Reader, out io.Writer) os.Error { func processFile(filename string, in io.Reader, out io.Writer, stdin bool) os.Error {
if in == nil { if in == nil {
f, err := os.Open(filename) f, err := os.Open(filename)
if err != nil { if err != nil {
...@@ -101,7 +101,7 @@ func processFile(filename string, in io.Reader, out io.Writer) os.Error { ...@@ -101,7 +101,7 @@ func processFile(filename string, in io.Reader, out io.Writer) os.Error {
return err return err
} }
file, err := parser.ParseFile(fset, filename, src, parserMode) file, adjust, err := parse(filename, src, stdin)
if err != nil { if err != nil {
return err return err
} }
...@@ -119,7 +119,7 @@ func processFile(filename string, in io.Reader, out io.Writer) os.Error { ...@@ -119,7 +119,7 @@ func processFile(filename string, in io.Reader, out io.Writer) os.Error {
if err != nil { if err != nil {
return err return err
} }
res := buf.Bytes() res := adjust(src, buf.Bytes())
if !bytes.Equal(src, res) { if !bytes.Equal(src, res) {
// formatting has changed // formatting has changed
...@@ -158,7 +158,7 @@ func (v fileVisitor) VisitDir(path string, f *os.FileInfo) bool { ...@@ -158,7 +158,7 @@ func (v fileVisitor) VisitDir(path string, f *os.FileInfo) bool {
func (v fileVisitor) VisitFile(path string, f *os.FileInfo) { func (v fileVisitor) VisitFile(path string, f *os.FileInfo) {
if isGoFile(f) { if isGoFile(f) {
v <- nil // synchronize error handler v <- nil // synchronize error handler
if err := processFile(path, nil, os.Stdout); err != nil { if err := processFile(path, nil, os.Stdout, false); err != nil {
v <- err v <- err
} }
} }
...@@ -211,7 +211,7 @@ func gofmtMain() { ...@@ -211,7 +211,7 @@ func gofmtMain() {
initRewrite() initRewrite()
if flag.NArg() == 0 { if flag.NArg() == 0 {
if err := processFile("<standard input>", os.Stdin, os.Stdout); err != nil { if err := processFile("<standard input>", os.Stdin, os.Stdout, true); err != nil {
report(err) report(err)
} }
return return
...@@ -223,7 +223,7 @@ func gofmtMain() { ...@@ -223,7 +223,7 @@ func gofmtMain() {
case err != nil: case err != nil:
report(err) report(err)
case dir.IsRegular(): case dir.IsRegular():
if err := processFile(path, nil, os.Stdout); err != nil { if err := processFile(path, nil, os.Stdout, false); err != nil {
report(err) report(err)
} }
case dir.IsDirectory(): case dir.IsDirectory():
...@@ -259,3 +259,109 @@ func diff(b1, b2 []byte) (data []byte, err os.Error) { ...@@ -259,3 +259,109 @@ func diff(b1, b2 []byte) (data []byte, err os.Error) {
return return
} }
// parse parses src, which was read from filename,
// as a Go source file or statement list.
func parse(filename string, src []byte, stdin bool) (*ast.File, func(orig, src []byte) []byte, os.Error) {
// Try as whole source file.
file, err := parser.ParseFile(fset, filename, src, parserMode)
if err == nil {
adjust := func(orig, src []byte) []byte { return src }
return file, adjust, nil
}
// If the error is that the source file didn't begin with a
// package line and this is standard input, fall through to
// try as a source fragment. Stop and return on any other error.
if !stdin || !strings.Contains(err.String(), "expected 'package'") {
return nil, nil, err
}
// If this is a declaration list, make it a source file
// by inserting a package clause.
// Insert using a ;, not a newline, so that the line numbers
// in psrc match the ones in src.
psrc := append([]byte("package p;"), src...)
file, err = parser.ParseFile(fset, filename, psrc, parserMode)
if err == nil {
adjust := func(orig, src []byte) []byte {
// Remove the package clause.
// Gofmt has turned the ; into a \n.
src = src[len("package p\n"):]
return matchSpace(orig, src)
}
return file, adjust, nil
}
// If the error is that the source file didn't begin with a
// declaration, fall through to try as a statement list.
// Stop and return on any other error.
if !strings.Contains(err.String(), "expected declaration") {
return nil, nil, err
}
// If this is a statement list, make it a source file
// by inserting a package clause and turning the list
// into a function body. This handles expressions too.
// Insert using a ;, not a newline, so that the line numbers
// in fsrc match the ones in src.
fsrc := append(append([]byte("package p; func _() {"), src...), '}')
file, err = parser.ParseFile(fset, filename, fsrc, parserMode)
if err == nil {
adjust := func(orig, src []byte) []byte {
// Remove the wrapping.
// Gofmt has turned the ; into a \n\n.
src = src[len("package p\n\nfunc _() {"):]
src = src[:len(src)-len("}\n")]
// Gofmt has also indented the function body one level.
// Remove that indent.
src = bytes.Replace(src, []byte("\n\t"), []byte("\n"), -1)
return matchSpace(orig, src)
}
return file, adjust, nil
}
// Failed, and out of options.
return nil, nil, err
}
func cutSpace(b []byte) (before, middle, after []byte) {
i := 0
for i < len(b) && (b[i] == ' ' || b[i] == '\t' || b[i] == '\n') {
i++
}
j := len(b)
for j > 0 && (b[j-1] == ' ' || b[j-1] == '\t' || b[j-1] == '\n') {
j--
}
return b[:i], b[i:j], b[j:]
}
// matchSpace reformats src to use the same space context as orig.
// 1) If orig begins with blank lines, matchSpace inserts them at the beginning of src.
// 2) matchSpace copies the indentation of the first non-blank line in orig
// to every non-blank line in src.
// 3) matchSpace copies the trailing space from orig and uses it in place
// of src's trailing space.
func matchSpace(orig []byte, src []byte) []byte {
before, _, after := cutSpace(orig)
i := bytes.LastIndex(before, []byte{'\n'})
before, indent := before[:i+1], before[i+1:]
_, src, _ = cutSpace(src)
var b bytes.Buffer
b.Write(before)
for len(src) > 0 {
line := src
if i := bytes.IndexByte(line, '\n'); i >= 0 {
line, src = line[:i+1], line[i+1:]
} else {
src = nil
}
if len(line) > 0 && line[0] != '\n' { // not blank
b.Write(indent)
}
b.Write(line)
}
b.Write(after)
return b.Bytes()
}
...@@ -12,13 +12,11 @@ import ( ...@@ -12,13 +12,11 @@ import (
"testing" "testing"
) )
func runTest(t *testing.T, dirname, in, out, flags string) { func runTest(t *testing.T, in, out, flags string) {
in = filepath.Join(dirname, in)
out = filepath.Join(dirname, out)
// process flags // process flags
*simplifyAST = false *simplifyAST = false
*rewriteRule = "" *rewriteRule = ""
stdin := false
for _, flag := range strings.Split(flags, " ") { for _, flag := range strings.Split(flags, " ") {
elts := strings.SplitN(flag, "=", 2) elts := strings.SplitN(flag, "=", 2)
name := elts[0] name := elts[0]
...@@ -33,6 +31,9 @@ func runTest(t *testing.T, dirname, in, out, flags string) { ...@@ -33,6 +31,9 @@ func runTest(t *testing.T, dirname, in, out, flags string) {
*rewriteRule = value *rewriteRule = value
case "-s": case "-s":
*simplifyAST = true *simplifyAST = true
case "-stdin":
// fake flag - pretend input is from stdin
stdin = true
default: default:
t.Errorf("unrecognized flag name: %s", name) t.Errorf("unrecognized flag name: %s", name)
} }
...@@ -43,7 +44,7 @@ func runTest(t *testing.T, dirname, in, out, flags string) { ...@@ -43,7 +44,7 @@ func runTest(t *testing.T, dirname, in, out, flags string) {
initRewrite() initRewrite()
var buf bytes.Buffer var buf bytes.Buffer
err := processFile(in, nil, &buf) err := processFile(in, nil, &buf, stdin)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
return return
...@@ -57,23 +58,43 @@ func runTest(t *testing.T, dirname, in, out, flags string) { ...@@ -57,23 +58,43 @@ func runTest(t *testing.T, dirname, in, out, flags string) {
if got := buf.Bytes(); bytes.Compare(got, expected) != 0 { if got := buf.Bytes(); bytes.Compare(got, expected) != 0 {
t.Errorf("(gofmt %s) != %s (see %s.gofmt)", in, out, in) t.Errorf("(gofmt %s) != %s (see %s.gofmt)", in, out, in)
d, err := diff(expected, got)
if err == nil {
t.Errorf("%s", d)
}
ioutil.WriteFile(in+".gofmt", got, 0666) ioutil.WriteFile(in+".gofmt", got, 0666)
} }
} }
// TODO(gri) Add more test cases! // TODO(gri) Add more test cases!
var tests = []struct { var tests = []struct {
dirname, in, out, flags string in, flags string
}{ }{
{".", "gofmt.go", "gofmt.go", ""}, {"gofmt.go", ""},
{".", "gofmt_test.go", "gofmt_test.go", ""}, {"gofmt_test.go", ""},
{"testdata", "composites.input", "composites.golden", "-s"}, {"testdata/composites.input", "-s"},
{"testdata", "rewrite1.input", "rewrite1.golden", "-r=Foo->Bar"}, {"testdata/rewrite1.input", "-r=Foo->Bar"},
{"testdata", "rewrite2.input", "rewrite2.golden", "-r=int->bool"}, {"testdata/rewrite2.input", "-r=int->bool"},
{"testdata/stdin*.input", "-stdin"},
} }
func TestRewrite(t *testing.T) { func TestRewrite(t *testing.T) {
for _, test := range tests { for _, test := range tests {
runTest(t, test.dirname, test.in, test.out, test.flags) match, err := filepath.Glob(test.in)
if err != nil {
t.Error(err)
continue
}
for _, in := range match {
out := in
if strings.HasSuffix(in, ".input") {
out = in[:len(in)-len(".input")] + ".golden"
}
runTest(t, in, out, test.flags)
if in != out {
// Check idempotence.
runTest(t, out, out, test.flags)
}
}
} }
} }
var x int
func f() {
y := z
/* this is a comment */
// this is a comment too
}
var x int
func f() { y := z
/* this is a comment */
// this is a comment too
}
var x int
func f() {
y := z
/* this is a comment */
// this is a comment too
}
/* note: no newline at end of file */
for i := 0; i < 10; i++ {
s += i
}
\ No newline at end of file
/* note: no newline at end of file */
for i := 0; i < 10; i++ {
s += i
}
\ No newline at end of file
/* note: no newline at end of file */
for i := 0; i < 10; i++ { s += i }
\ No newline at end of file
/* note: no newline at end of file */
for i := 0; i < 10; i++ {
s += i
}
\ No newline at end of file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment