Commit 8bfb2171 authored by Dave Cheney's avatar Dave Cheney Committed by Adam Langley

exp/ssh: server cleanups

server.go/channel.go:
* rename Server to ServerConfig to match Client.
* rename ServerConnection to ServeConn to match Client.
* add Listen/Listener.
* ServerConn.Handshake(), general cleanups.

client.go:
* fix bug where fmt.Error was not assigned to err

R=rsc, agl
CC=golang-dev
https://golang.org/cl/5265049
parent 792a55f5
...@@ -68,7 +68,7 @@ type channel struct { ...@@ -68,7 +68,7 @@ type channel struct {
weClosed bool weClosed bool
dead bool dead bool
serverConn *ServerConnection serverConn *ServerConn
myId, theirId uint32 myId, theirId uint32
myWindow, theirWindow uint32 myWindow, theirWindow uint32
maxPacketSize uint32 maxPacketSize uint32
......
...@@ -115,7 +115,7 @@ func (c *ClientConn) handshake() os.Error { ...@@ -115,7 +115,7 @@ func (c *ClientConn) handshake() os.Error {
dhGroup14Once.Do(initDHGroup14) dhGroup14Once.Do(initDHGroup14)
H, K, err = c.kexDH(dhGroup14, hashFunc, &magics, hostKeyAlgo) H, K, err = c.kexDH(dhGroup14, hashFunc, &magics, hostKeyAlgo)
default: default:
fmt.Errorf("ssh: unexpected key exchange algorithm %v", kexAlgo) err = fmt.Errorf("ssh: unexpected key exchange algorithm %v", kexAlgo)
} }
if err != nil { if err != nil {
return err return err
......
...@@ -10,19 +10,23 @@ import ( ...@@ -10,19 +10,23 @@ import (
"crypto" "crypto"
"crypto/rand" "crypto/rand"
"crypto/rsa" "crypto/rsa"
_ "crypto/sha1"
"crypto/x509" "crypto/x509"
"encoding/pem" "encoding/pem"
"io"
"net" "net"
"os" "os"
"sync" "sync"
) )
// Server represents an SSH server. A Server may have several ServerConnections. type ServerConfig struct {
type Server struct {
rsa *rsa.PrivateKey rsa *rsa.PrivateKey
rsaSerialized []byte rsaSerialized []byte
// Rand provides the source of entropy for key exchange. If Rand is
// nil, the cryptographic random reader in package crypto/rand will
// be used.
Rand io.Reader
// NoClientAuth is true if clients are allowed to connect without // NoClientAuth is true if clients are allowed to connect without
// authenticating. // authenticating.
NoClientAuth bool NoClientAuth bool
...@@ -38,11 +42,18 @@ type Server struct { ...@@ -38,11 +42,18 @@ type Server struct {
PubKeyCallback func(user, algo string, pubkey []byte) bool PubKeyCallback func(user, algo string, pubkey []byte) bool
} }
func (c *ServerConfig) rand() io.Reader {
if c.Rand == nil {
return rand.Reader
}
return c.Rand
}
// SetRSAPrivateKey sets the private key for a Server. A Server must have a // SetRSAPrivateKey sets the private key for a Server. A Server must have a
// private key configured in order to accept connections. The private key must // private key configured in order to accept connections. The private key must
// be in the form of a PEM encoded, PKCS#1, RSA private key. The file "id_rsa" // be in the form of a PEM encoded, PKCS#1, RSA private key. The file "id_rsa"
// typically contains such a key. // typically contains such a key.
func (s *Server) SetRSAPrivateKey(pemBytes []byte) os.Error { func (s *ServerConfig) SetRSAPrivateKey(pemBytes []byte) os.Error {
block, _ := pem.Decode(pemBytes) block, _ := pem.Decode(pemBytes)
if block == nil { if block == nil {
return os.NewError("ssh: no key found") return os.NewError("ssh: no key found")
...@@ -109,7 +120,7 @@ func parseRSASig(in []byte) (sig []byte, ok bool) { ...@@ -109,7 +120,7 @@ func parseRSASig(in []byte) (sig []byte, ok bool) {
} }
// cachedPubKey contains the results of querying whether a public key is // cachedPubKey contains the results of querying whether a public key is
// acceptable for a user. The cache only applies to a single ServerConnection. // acceptable for a user. The cache only applies to a single ServerConn.
type cachedPubKey struct { type cachedPubKey struct {
user, algo string user, algo string
pubKey []byte pubKey []byte
...@@ -118,11 +129,10 @@ type cachedPubKey struct { ...@@ -118,11 +129,10 @@ type cachedPubKey struct {
const maxCachedPubKeys = 16 const maxCachedPubKeys = 16
// ServerConnection represents an incomming connection to a Server. // A ServerConn represents an incomming connection.
type ServerConnection struct { type ServerConn struct {
Server *Server
*transport *transport
config *ServerConfig
channels map[uint32]*channel channels map[uint32]*channel
nextChanId uint32 nextChanId uint32
...@@ -139,9 +149,20 @@ type ServerConnection struct { ...@@ -139,9 +149,20 @@ type ServerConnection struct {
cachedPubKeys []cachedPubKey cachedPubKeys []cachedPubKey
} }
// Server returns a new SSH server connection
// using c as the underlying transport.
func Server(c net.Conn, config *ServerConfig) *ServerConn {
conn := &ServerConn{
transport: newTransport(c, config.rand()),
channels: make(map[uint32]*channel),
config: config,
}
return conn
}
// kexDH performs Diffie-Hellman key agreement on a ServerConnection. The // kexDH performs Diffie-Hellman key agreement on a ServerConnection. The
// returned values are given the same names as in RFC 4253, section 8. // returned values are given the same names as in RFC 4253, section 8.
func (s *ServerConnection) kexDH(group *dhGroup, hashFunc crypto.Hash, magics *handshakeMagics, hostKeyAlgo string) (H, K []byte, err os.Error) { func (s *ServerConn) kexDH(group *dhGroup, hashFunc crypto.Hash, magics *handshakeMagics, hostKeyAlgo string) (H, K []byte, err os.Error) {
packet, err := s.readPacket() packet, err := s.readPacket()
if err != nil { if err != nil {
return return
...@@ -155,7 +176,7 @@ func (s *ServerConnection) kexDH(group *dhGroup, hashFunc crypto.Hash, magics *h ...@@ -155,7 +176,7 @@ func (s *ServerConnection) kexDH(group *dhGroup, hashFunc crypto.Hash, magics *h
return nil, nil, os.NewError("client DH parameter out of bounds") return nil, nil, os.NewError("client DH parameter out of bounds")
} }
y, err := rand.Int(rand.Reader, group.p) y, err := rand.Int(s.config.rand(), group.p)
if err != nil { if err != nil {
return return
} }
...@@ -166,7 +187,7 @@ func (s *ServerConnection) kexDH(group *dhGroup, hashFunc crypto.Hash, magics *h ...@@ -166,7 +187,7 @@ func (s *ServerConnection) kexDH(group *dhGroup, hashFunc crypto.Hash, magics *h
var serializedHostKey []byte var serializedHostKey []byte
switch hostKeyAlgo { switch hostKeyAlgo {
case hostAlgoRSA: case hostAlgoRSA:
serializedHostKey = s.Server.rsaSerialized serializedHostKey = s.config.rsaSerialized
default: default:
return nil, nil, os.NewError("internal error") return nil, nil, os.NewError("internal error")
} }
...@@ -192,7 +213,7 @@ func (s *ServerConnection) kexDH(group *dhGroup, hashFunc crypto.Hash, magics *h ...@@ -192,7 +213,7 @@ func (s *ServerConnection) kexDH(group *dhGroup, hashFunc crypto.Hash, magics *h
var sig []byte var sig []byte
switch hostKeyAlgo { switch hostKeyAlgo {
case hostAlgoRSA: case hostAlgoRSA:
sig, err = rsa.SignPKCS1v15(rand.Reader, s.Server.rsa, hashFunc, hh) sig, err = rsa.SignPKCS1v15(s.config.rand(), s.config.rsa, hashFunc, hh)
if err != nil { if err != nil {
return return
} }
...@@ -257,17 +278,18 @@ func buildDataSignedForAuth(sessionId []byte, req userAuthRequestMsg, algo, pubK ...@@ -257,17 +278,18 @@ func buildDataSignedForAuth(sessionId []byte, req userAuthRequestMsg, algo, pubK
return ret return ret
} }
// Handshake performs an SSH transport and client authentication on the given ServerConnection. // Handshake performs an SSH transport and client authentication on the given ServerConn.
func (s *ServerConnection) Handshake(conn net.Conn) os.Error { func (s *ServerConn) Handshake() os.Error {
var magics handshakeMagics var magics handshakeMagics
s.transport = newTransport(conn, rand.Reader) if _, err := s.Write(serverVersion); err != nil {
return err
if _, err := conn.Write(serverVersion); err != nil { }
if err := s.Flush(); err != nil {
return err return err
} }
magics.serverVersion = serverVersion[:len(serverVersion)-2] magics.serverVersion = serverVersion[:len(serverVersion)-2]
version, err := readVersion(s.transport) version, err := readVersion(s)
if err != nil { if err != nil {
return err return err
} }
...@@ -310,8 +332,7 @@ func (s *ServerConnection) Handshake(conn net.Conn) os.Error { ...@@ -310,8 +332,7 @@ func (s *ServerConnection) Handshake(conn net.Conn) os.Error {
if clientKexInit.FirstKexFollows && kexAlgo != clientKexInit.KexAlgos[0] { if clientKexInit.FirstKexFollows && kexAlgo != clientKexInit.KexAlgos[0] {
// The client sent a Kex message for the wrong algorithm, // The client sent a Kex message for the wrong algorithm,
// which we have to ignore. // which we have to ignore.
_, err := s.readPacket() if _, err := s.readPacket(); err != nil {
if err != nil {
return err return err
} }
} }
...@@ -324,32 +345,27 @@ func (s *ServerConnection) Handshake(conn net.Conn) os.Error { ...@@ -324,32 +345,27 @@ func (s *ServerConnection) Handshake(conn net.Conn) os.Error {
dhGroup14Once.Do(initDHGroup14) dhGroup14Once.Do(initDHGroup14)
H, K, err = s.kexDH(dhGroup14, hashFunc, &magics, hostKeyAlgo) H, K, err = s.kexDH(dhGroup14, hashFunc, &magics, hostKeyAlgo)
default: default:
err = os.NewError("ssh: internal error") err = os.NewError("ssh: unexpected key exchange algorithm " + kexAlgo)
} }
if err != nil { if err != nil {
return err return err
} }
packet = []byte{msgNewKeys} if err = s.writePacket([]byte{msgNewKeys}); err != nil {
if err = s.writePacket(packet); err != nil {
return err return err
} }
if err = s.transport.writer.setupKeys(serverKeys, K, H, H, hashFunc); err != nil { if err = s.transport.writer.setupKeys(serverKeys, K, H, H, hashFunc); err != nil {
return err return err
} }
if packet, err = s.readPacket(); err != nil { if packet, err = s.readPacket(); err != nil {
return err return err
} }
if packet[0] != msgNewKeys { if packet[0] != msgNewKeys {
return UnexpectedMessageError{msgNewKeys, packet[0]} return UnexpectedMessageError{msgNewKeys, packet[0]}
} }
s.transport.reader.setupKeys(clientKeys, K, H, H, hashFunc) s.transport.reader.setupKeys(clientKeys, K, H, H, hashFunc)
if packet, err = s.readPacket(); err != nil {
packet, err = s.readPacket()
if err != nil {
return err return err
} }
...@@ -360,20 +376,16 @@ func (s *ServerConnection) Handshake(conn net.Conn) os.Error { ...@@ -360,20 +376,16 @@ func (s *ServerConnection) Handshake(conn net.Conn) os.Error {
if serviceRequest.Service != serviceUserAuth { if serviceRequest.Service != serviceUserAuth {
return os.NewError("ssh: requested service '" + serviceRequest.Service + "' before authenticating") return os.NewError("ssh: requested service '" + serviceRequest.Service + "' before authenticating")
} }
serviceAccept := serviceAcceptMsg{ serviceAccept := serviceAcceptMsg{
Service: serviceUserAuth, Service: serviceUserAuth,
} }
packet = marshal(msgServiceAccept, serviceAccept) if err = s.writePacket(marshal(msgServiceAccept, serviceAccept)); err != nil {
if err = s.writePacket(packet); err != nil {
return err return err
} }
if err = s.authenticate(H); err != nil { if err = s.authenticate(H); err != nil {
return err return err
} }
s.channels = make(map[uint32]*channel)
return nil return nil
} }
...@@ -382,8 +394,8 @@ func isAcceptableAlgo(algo string) bool { ...@@ -382,8 +394,8 @@ func isAcceptableAlgo(algo string) bool {
} }
// testPubKey returns true if the given public key is acceptable for the user. // testPubKey returns true if the given public key is acceptable for the user.
func (s *ServerConnection) testPubKey(user, algo string, pubKey []byte) bool { func (s *ServerConn) testPubKey(user, algo string, pubKey []byte) bool {
if s.Server.PubKeyCallback == nil || !isAcceptableAlgo(algo) { if s.config.PubKeyCallback == nil || !isAcceptableAlgo(algo) {
return false return false
} }
...@@ -393,7 +405,7 @@ func (s *ServerConnection) testPubKey(user, algo string, pubKey []byte) bool { ...@@ -393,7 +405,7 @@ func (s *ServerConnection) testPubKey(user, algo string, pubKey []byte) bool {
} }
} }
result := s.Server.PubKeyCallback(user, algo, pubKey) result := s.config.PubKeyCallback(user, algo, pubKey)
if len(s.cachedPubKeys) < maxCachedPubKeys { if len(s.cachedPubKeys) < maxCachedPubKeys {
c := cachedPubKey{ c := cachedPubKey{
user: user, user: user,
...@@ -408,7 +420,7 @@ func (s *ServerConnection) testPubKey(user, algo string, pubKey []byte) bool { ...@@ -408,7 +420,7 @@ func (s *ServerConnection) testPubKey(user, algo string, pubKey []byte) bool {
return result return result
} }
func (s *ServerConnection) authenticate(H []byte) os.Error { func (s *ServerConn) authenticate(H []byte) os.Error {
var userAuthReq userAuthRequestMsg var userAuthReq userAuthRequestMsg
var err os.Error var err os.Error
var packet []byte var packet []byte
...@@ -428,11 +440,11 @@ userAuthLoop: ...@@ -428,11 +440,11 @@ userAuthLoop:
switch userAuthReq.Method { switch userAuthReq.Method {
case "none": case "none":
if s.Server.NoClientAuth { if s.config.NoClientAuth {
break userAuthLoop break userAuthLoop
} }
case "password": case "password":
if s.Server.PasswordCallback == nil { if s.config.PasswordCallback == nil {
break break
} }
payload := userAuthReq.Payload payload := userAuthReq.Payload
...@@ -445,11 +457,11 @@ userAuthLoop: ...@@ -445,11 +457,11 @@ userAuthLoop:
return ParseError{msgUserAuthRequest} return ParseError{msgUserAuthRequest}
} }
if s.Server.PasswordCallback(userAuthReq.User, string(password)) { if s.config.PasswordCallback(userAuthReq.User, string(password)) {
break userAuthLoop break userAuthLoop
} }
case "publickey": case "publickey":
if s.Server.PubKeyCallback == nil { if s.config.PubKeyCallback == nil {
break break
} }
payload := userAuthReq.Payload payload := userAuthReq.Payload
...@@ -520,10 +532,10 @@ userAuthLoop: ...@@ -520,10 +532,10 @@ userAuthLoop:
} }
var failureMsg userAuthFailureMsg var failureMsg userAuthFailureMsg
if s.Server.PasswordCallback != nil { if s.config.PasswordCallback != nil {
failureMsg.Methods = append(failureMsg.Methods, "password") failureMsg.Methods = append(failureMsg.Methods, "password")
} }
if s.Server.PubKeyCallback != nil { if s.config.PubKeyCallback != nil {
failureMsg.Methods = append(failureMsg.Methods, "publickey") failureMsg.Methods = append(failureMsg.Methods, "publickey")
} }
...@@ -546,9 +558,9 @@ userAuthLoop: ...@@ -546,9 +558,9 @@ userAuthLoop:
const defaultWindowSize = 32768 const defaultWindowSize = 32768
// Accept reads and processes messages on a ServerConnection. It must be called // Accept reads and processes messages on a ServerConn. It must be called
// in order to demultiplex messages to any resulting Channels. // in order to demultiplex messages to any resulting Channels.
func (s *ServerConnection) Accept() (Channel, os.Error) { func (s *ServerConn) Accept() (Channel, os.Error) {
if s.err != nil { if s.err != nil {
return nil, s.err return nil, s.err
} }
...@@ -643,3 +655,44 @@ func (s *ServerConnection) Accept() (Channel, os.Error) { ...@@ -643,3 +655,44 @@ func (s *ServerConnection) Accept() (Channel, os.Error) {
panic("unreachable") panic("unreachable")
} }
// A Listener implements a network listener (net.Listener) for SSH connections.
type Listener struct {
listener net.Listener
config *ServerConfig
}
// Accept waits for and returns the next incoming SSH connection.
// The receiver should call Handshake() in another goroutine
// to avoid blocking the accepter.
func (l *Listener) Accept() (*ServerConn, os.Error) {
c, err := l.listener.Accept()
if err != nil {
return nil, err
}
conn := Server(c, l.config)
return conn, nil
}
// Addr returns the listener's network address.
func (l *Listener) Addr() net.Addr {
return l.listener.Addr()
}
// Close closes the listener.
func (l *Listener) Close() os.Error {
return l.listener.Close()
}
// Listen creates an SSH listener accepting connections on
// the given network address using net.Listen.
func Listen(network, addr string, config *ServerConfig) (*Listener, os.Error) {
l, err := net.Listen(network, addr)
if err != nil {
return nil, err
}
return &Listener{
l,
config,
}, nil
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment