Commit c4d0ac0e authored by Dave Cheney's avatar Dave Cheney Committed by Gustavo Niemeyer

exp/ssh: add Std{in,out,err}Pipe methods to Session

R=gustav.paul, cw, agl, rsc, n13m3y3r
CC=golang-dev
https://golang.org/cl/5433080
parent c0a53bbc
...@@ -54,9 +54,10 @@ type Session struct { ...@@ -54,9 +54,10 @@ type Session struct {
*clientChan // the channel backing this session *clientChan // the channel backing this session
started bool // true once a Shell or Run is invoked. started bool // true once Start, Run or Shell is invoked.
copyFuncs []func() error closeAfterWait []io.Closer
errch chan error // one send per copyFunc copyFuncs []func() error
errch chan error // one send per copyFunc
} }
// RFC 4254 Section 6.4. // RFC 4254 Section 6.4.
...@@ -231,7 +232,7 @@ func (s *Session) start() error { ...@@ -231,7 +232,7 @@ func (s *Session) start() error {
return nil return nil
} }
// Wait waits for the remote command to exit. // Wait waits for the remote command to exit.
func (s *Session) Wait() error { func (s *Session) Wait() error {
if !s.started { if !s.started {
return errors.New("ssh: session not started") return errors.New("ssh: session not started")
...@@ -244,11 +245,12 @@ func (s *Session) Wait() error { ...@@ -244,11 +245,12 @@ func (s *Session) Wait() error {
copyError = err copyError = err
} }
} }
for _, fd := range s.closeAfterWait {
fd.Close()
}
if waitErr != nil { if waitErr != nil {
return waitErr return waitErr
} }
return copyError return copyError
} }
...@@ -283,11 +285,15 @@ func (s *Session) stdin() error { ...@@ -283,11 +285,15 @@ func (s *Session) stdin() error {
s.Stdin = new(bytes.Buffer) s.Stdin = new(bytes.Buffer)
} }
s.copyFuncs = append(s.copyFuncs, func() error { s.copyFuncs = append(s.copyFuncs, func() error {
_, err := io.Copy(&chanWriter{ w := &chanWriter{
packetWriter: s, packetWriter: s,
peersId: s.peersId, peersId: s.peersId,
win: s.win, win: s.win,
}, s.Stdin) }
_, err := io.Copy(w, s.Stdin)
if err1 := w.Close(); err == nil {
err = err1
}
return err return err
}) })
return nil return nil
...@@ -298,11 +304,12 @@ func (s *Session) stdout() error { ...@@ -298,11 +304,12 @@ func (s *Session) stdout() error {
s.Stdout = ioutil.Discard s.Stdout = ioutil.Discard
} }
s.copyFuncs = append(s.copyFuncs, func() error { s.copyFuncs = append(s.copyFuncs, func() error {
_, err := io.Copy(s.Stdout, &chanReader{ r := &chanReader{
packetWriter: s, packetWriter: s,
peersId: s.peersId, peersId: s.peersId,
data: s.data, data: s.data,
}) }
_, err := io.Copy(s.Stdout, r)
return err return err
}) })
return nil return nil
...@@ -313,16 +320,72 @@ func (s *Session) stderr() error { ...@@ -313,16 +320,72 @@ func (s *Session) stderr() error {
s.Stderr = ioutil.Discard s.Stderr = ioutil.Discard
} }
s.copyFuncs = append(s.copyFuncs, func() error { s.copyFuncs = append(s.copyFuncs, func() error {
_, err := io.Copy(s.Stderr, &chanReader{ r := &chanReader{
packetWriter: s, packetWriter: s,
peersId: s.peersId, peersId: s.peersId,
data: s.dataExt, data: s.dataExt,
}) }
_, err := io.Copy(s.Stderr, r)
return err return err
}) })
return nil return nil
} }
// StdinPipe returns a pipe that will be connected to the
// remote command's standard input when the command starts.
func (s *Session) StdinPipe() (io.WriteCloser, error) {
if s.Stdin != nil {
return nil, errors.New("ssh: Stdin already set")
}
if s.started {
return nil, errors.New("ssh: StdinPipe after process started")
}
pr, pw := io.Pipe()
s.Stdin = pr
s.closeAfterWait = append(s.closeAfterWait, pr)
return pw, nil
}
// StdoutPipe returns a pipe that will be connected to the
// remote command's standard output when the command starts.
// There is a fixed amount of buffering that is shared between
// stdout and stderr streams. If the StdoutPipe reader is
// not serviced fast enought it may eventually cause the
// remote command to block.
func (s *Session) StdoutPipe() (io.ReadCloser, error) {
if s.Stdout != nil {
return nil, errors.New("ssh: Stdout already set")
}
if s.started {
return nil, errors.New("ssh: StdoutPipe after process started")
}
pr, pw := io.Pipe()
s.Stdout = pw
s.closeAfterWait = append(s.closeAfterWait, pw)
return pr, nil
}
// StderrPipe returns a pipe that will be connected to the
// remote command's standard error when the command starts.
// There is a fixed amount of buffering that is shared between
// stdout and stderr streams. If the StderrPipe reader is
// not serviced fast enought it may eventually cause the
// remote command to block.
func (s *Session) StderrPipe() (io.ReadCloser, error) {
if s.Stderr != nil {
return nil, errors.New("ssh: Stderr already set")
}
if s.started {
return nil, errors.New("ssh: StderrPipe after process started")
}
pr, pw := io.Pipe()
s.Stderr = pw
s.closeAfterWait = append(s.closeAfterWait, pw)
return pr, nil
}
// TODO(dfc) add Output and CombinedOutput helpers
// NewSession returns a new interactive session on the remote host. // NewSession returns a new interactive session on the remote host.
func (c *ClientConn) NewSession() (*Session, error) { func (c *ClientConn) NewSession() (*Session, error) {
ch := c.newChan(c.transport) ch := c.newChan(c.transport)
......
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package ssh
// Session tests.
import (
"bytes"
"io"
"testing"
)
// dial constructs a new test server and returns a *ClientConn.
func dial(t *testing.T) *ClientConn {
pw := password("tiger")
serverConfig.PasswordCallback = func(user, pass string) bool {
return user == "testuser" && pass == string(pw)
}
serverConfig.PubKeyCallback = nil
l, err := Listen("tcp", "127.0.0.1:0", serverConfig)
if err != nil {
t.Fatalf("unable to listen: %s", err)
}
go func() {
defer l.Close()
conn, err := l.Accept()
if err != nil {
t.Errorf("Unable to accept: %v", err)
return
}
defer conn.Close()
if err := conn.Handshake(); err != nil {
t.Errorf("Unable to handshake: %v", err)
return
}
for {
ch, err := conn.Accept()
if err == io.EOF {
return
}
if err != nil {
t.Errorf("Unable to accept incoming channel request: %v", err)
return
}
if ch.ChannelType() != "session" {
ch.Reject(UnknownChannelType, "unknown channel type")
continue
}
ch.Accept()
go func() {
defer ch.Close()
// this string is returned to stdout
shell := NewServerShell(ch, "golang")
shell.ReadLine()
type exitMsg struct {
PeersId uint32
Request string
WantReply bool
Status uint32
}
// TODO(dfc) casting to the concrete type should not be
// necessary to send a packet.
msg := exitMsg{
PeersId: ch.(*channel).theirId,
Request: "exit-status",
WantReply: false,
Status: 0,
}
ch.(*channel).serverConn.writePacket(marshal(msgChannelRequest, msg))
}()
}
t.Log("done")
}()
config := &ClientConfig{
User: "testuser",
Auth: []ClientAuth{
ClientAuthPassword(pw),
},
}
c, err := Dial("tcp", l.Addr().String(), config)
if err != nil {
t.Fatalf("unable to dial remote side: %s", err)
}
return c
}
// Test a simple string is returned to session.Stdout.
func TestSessionShell(t *testing.T) {
conn := dial(t)
defer conn.Close()
session, err := conn.NewSession()
if err != nil {
t.Fatalf("Unable to request new session: %s", err)
}
defer session.Close()
stdout := new(bytes.Buffer)
session.Stdout = stdout
if err := session.Shell(); err != nil {
t.Fatalf("Unable to execute command: %s", err)
}
if err := session.Wait(); err != nil {
t.Fatalf("Remote command did not exit cleanly: %s", err)
}
actual := stdout.String()
if actual != "golang" {
t.Fatalf("Remote shell did not return expected string: expected=golang, actual=%s", actual)
}
}
// TODO(dfc) add support for Std{in,err}Pipe when the Server supports it.
// Test a simple string is returned via StdoutPipe.
func TestSessionStdoutPipe(t *testing.T) {
conn := dial(t)
defer conn.Close()
session, err := conn.NewSession()
if err != nil {
t.Fatalf("Unable to request new session: %s", err)
}
defer session.Close()
stdout, err := session.StdoutPipe()
if err != nil {
t.Fatalf("Unable to request StdoutPipe(): %v", err)
}
var buf bytes.Buffer
if err := session.Shell(); err != nil {
t.Fatalf("Unable to execute command: %s", err)
}
done := make(chan bool, 1)
go func() {
if _, err := io.Copy(&buf, stdout); err != nil {
t.Errorf("Copy of stdout failed: %v", err)
}
done <- true
}()
if err := session.Wait(); err != nil {
t.Fatalf("Remote command did not exit cleanly: %s", err)
}
<-done
actual := buf.String()
if actual != "golang" {
t.Fatalf("Remote shell did not return expected string: expected=golang, actual=%s", actual)
}
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment