Commit 250ac873 authored by Andrew Gerrand's avatar Andrew Gerrand

rpc: expose Server type to allow multiple RPC Server instances

R=r, rsc, msolo, sougou
CC=golang-dev
https://golang.org/cl/2696041
parent 904adfdc
...@@ -162,14 +162,21 @@ func (c *gobClientCodec) Close() os.Error { ...@@ -162,14 +162,21 @@ func (c *gobClientCodec) Close() os.Error {
} }
// DialHTTP connects to an HTTP RPC server at the specified network address. // DialHTTP connects to an HTTP RPC server at the specified network address
// listening on the default HTTP RPC path.
func DialHTTP(network, address string) (*Client, os.Error) { func DialHTTP(network, address string) (*Client, os.Error) {
return DialHTTPPath(network, address, DefaultRPCPath)
}
// DialHTTPPath connects to an HTTP RPC server
// at the specified network address and path.
func DialHTTPPath(network, address, path string) (*Client, os.Error) {
var err os.Error var err os.Error
conn, err := net.Dial(network, "", address) conn, err := net.Dial(network, "", address)
if err != nil { if err != nil {
return nil, err return nil, err
} }
io.WriteString(conn, "CONNECT "+rpcPath+" HTTP/1.0\n\n") io.WriteString(conn, "CONNECT "+path+" HTTP/1.0\n\n")
// Require successful HTTP response // Require successful HTTP response
// before switching to RPC protocol. // before switching to RPC protocol.
......
...@@ -61,8 +61,12 @@ func (m methodArray) Len() int { return len(m) } ...@@ -61,8 +61,12 @@ func (m methodArray) Len() int { return len(m) }
func (m methodArray) Less(i, j int) bool { return m[i].name < m[j].name } func (m methodArray) Less(i, j int) bool { return m[i].name < m[j].name }
func (m methodArray) Swap(i, j int) { m[i], m[j] = m[j], m[i] } func (m methodArray) Swap(i, j int) { m[i], m[j] = m[j], m[i] }
type debugHTTP struct {
*Server
}
// Runs at /debug/rpc // Runs at /debug/rpc
func debugHTTP(w http.ResponseWriter, req *http.Request) { func (server debugHTTP) ServeHTTP(w http.ResponseWriter, req *http.Request) {
// Build a sorted version of the data. // Build a sorted version of the data.
var services = make(serviceArray, len(server.serviceMap)) var services = make(serviceArray, len(server.serviceMap))
i := 0 i := 0
......
...@@ -3,9 +3,9 @@ ...@@ -3,9 +3,9 @@
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
/* /*
The rpc package provides access to the public methods of an object across a The rpc package provides access to the exported methods of an object across a
network or other I/O connection. A server registers an object, making it visible network or other I/O connection. A server registers an object, making it visible
as a service with the name of the type of the object. After registration, public as a service with the name of the type of the object. After registration, exported
methods of the object will be accessible remotely. A server may register multiple methods of the object will be accessible remotely. A server may register multiple
objects (services) of different types but it is an error to register multiple objects (services) of different types but it is an error to register multiple
objects of the same type. objects of the same type.
...@@ -13,8 +13,8 @@ ...@@ -13,8 +13,8 @@
Only methods that satisfy these criteria will be made available for remote access; Only methods that satisfy these criteria will be made available for remote access;
other methods will be ignored: other methods will be ignored:
- the method receiver and name are publicly visible, that is, begin with an upper case letter. - the method receiver and name are exported, that is, begin with an upper case letter.
- the method has two arguments, both pointers to publicly visible types. - the method has two arguments, both pointers to exported types.
- the method has return type os.Error. - the method has return type os.Error.
The method's first argument represents the arguments provided by the caller; the The method's first argument represents the arguments provided by the caller; the
...@@ -123,6 +123,12 @@ import ( ...@@ -123,6 +123,12 @@ import (
"utf8" "utf8"
) )
const (
// Defaults used by HandleHTTP
DefaultRPCPath = "/_goRPC_"
DefaultDebugPath = "/debug/rpc"
)
// Precompute the reflect type for os.Error. Can't use os.Error directly // Precompute the reflect type for os.Error. Can't use os.Error directly
// because Typeof takes an empty interface value. This is annoying. // because Typeof takes an empty interface value. This is annoying.
var unusedError *os.Error var unusedError *os.Error
...@@ -166,23 +172,34 @@ type ClientInfo struct { ...@@ -166,23 +172,34 @@ type ClientInfo struct {
RemoteAddr string RemoteAddr string
} }
type serverType struct { // Server represents an RPC Server.
type Server struct {
sync.Mutex // protects the serviceMap sync.Mutex // protects the serviceMap
serviceMap map[string]*service serviceMap map[string]*service
} }
// This variable is a global whose "public" methods are really private methods // NewServer returns a new Server.
// called from the global functions of this package: rpc.Register, rpc.ServeConn, etc. func NewServer() *Server {
// For example, rpc.Register() calls server.add(). return &Server{serviceMap: make(map[string]*service)}
var server = &serverType{serviceMap: make(map[string]*service)} }
// DefaultServer is the default instance of *Server.
var DefaultServer = NewServer()
// Is this a publicly visible - upper case - name? // Is this an exported - upper case - name?
func isPublic(name string) bool { func isExported(name string) bool {
rune, _ := utf8.DecodeRuneInString(name) rune, _ := utf8.DecodeRuneInString(name)
return unicode.IsUpper(rune) return unicode.IsUpper(rune)
} }
func (server *serverType) register(rcvr interface{}) os.Error { // Register publishes in the server the set of methods of the
// receiver value that satisfy the following conditions:
// - exported method
// - two arguments, both pointers to exported structs
// - one return value, of type os.Error
// It returns an error if the receiver is not an exported type or has no
// suitable methods.
func (server *Server) Register(rcvr interface{}) os.Error {
server.Lock() server.Lock()
defer server.Unlock() defer server.Unlock()
if server.serviceMap == nil { if server.serviceMap == nil {
...@@ -195,8 +212,8 @@ func (server *serverType) register(rcvr interface{}) os.Error { ...@@ -195,8 +212,8 @@ func (server *serverType) register(rcvr interface{}) os.Error {
if sname == "" { if sname == "" {
log.Exit("rpc: no service name for type", s.typ.String()) log.Exit("rpc: no service name for type", s.typ.String())
} }
if s.typ.PkgPath() != "" && !isPublic(sname) { if s.typ.PkgPath() != "" && !isExported(sname) {
s := "rpc Register: type " + sname + " is not public" s := "rpc Register: type " + sname + " is not exported"
log.Print(s) log.Print(s)
return os.ErrorString(s) return os.ErrorString(s)
} }
...@@ -211,7 +228,7 @@ func (server *serverType) register(rcvr interface{}) os.Error { ...@@ -211,7 +228,7 @@ func (server *serverType) register(rcvr interface{}) os.Error {
method := s.typ.Method(m) method := s.typ.Method(m)
mtype := method.Type mtype := method.Type
mname := method.Name mname := method.Name
if mtype.PkgPath() != "" || !isPublic(mname) { if mtype.PkgPath() != "" || !isExported(mname) {
continue continue
} }
// Method needs three ins: receiver, *args, *reply. // Method needs three ins: receiver, *args, *reply.
...@@ -229,12 +246,12 @@ func (server *serverType) register(rcvr interface{}) os.Error { ...@@ -229,12 +246,12 @@ func (server *serverType) register(rcvr interface{}) os.Error {
log.Println(mname, "reply type not a pointer:", mtype.In(2)) log.Println(mname, "reply type not a pointer:", mtype.In(2))
continue continue
} }
if argType.Elem().PkgPath() != "" && !isPublic(argType.Elem().Name()) { if argType.Elem().PkgPath() != "" && !isExported(argType.Elem().Name()) {
log.Println(mname, "argument type not public:", argType) log.Println(mname, "argument type not exported:", argType)
continue continue
} }
if replyType.Elem().PkgPath() != "" && !isPublic(replyType.Elem().Name()) { if replyType.Elem().PkgPath() != "" && !isExported(replyType.Elem().Name()) {
log.Println(mname, "reply type not public:", replyType) log.Println(mname, "reply type not exported:", replyType)
continue continue
} }
if mtype.NumIn() == 4 { if mtype.NumIn() == 4 {
...@@ -257,7 +274,7 @@ func (server *serverType) register(rcvr interface{}) os.Error { ...@@ -257,7 +274,7 @@ func (server *serverType) register(rcvr interface{}) os.Error {
} }
if len(s.method) == 0 { if len(s.method) == 0 {
s := "rpc Register: type " + sname + " has no public methods of suitable type" s := "rpc Register: type " + sname + " has no exported methods of suitable type"
log.Print(s) log.Print(s)
return os.ErrorString(s) return os.ErrorString(s)
} }
...@@ -335,7 +352,19 @@ func (c *gobServerCodec) Close() os.Error { ...@@ -335,7 +352,19 @@ func (c *gobServerCodec) Close() os.Error {
return c.rwc.Close() return c.rwc.Close()
} }
func (server *serverType) input(codec ServerCodec) {
// ServeConn runs the server on a single connection.
// ServeConn blocks, serving the connection until the client hangs up.
// The caller typically invokes ServeConn in a go statement.
// ServeConn uses the gob wire format (see package gob) on the
// connection. To use an alternate codec, use ServeCodec.
func (server *Server) ServeConn(conn io.ReadWriteCloser) {
server.ServeCodec(&gobServerCodec{conn, gob.NewDecoder(conn), gob.NewEncoder(conn)})
}
// ServeCodec is like ServeConn but uses the specified codec to
// decode requests and encode responses.
func (server *Server) ServeCodec(codec ServerCodec) {
sending := new(sync.Mutex) sending := new(sync.Mutex)
for { for {
// Grab the request header. // Grab the request header.
...@@ -387,24 +416,27 @@ func (server *serverType) input(codec ServerCodec) { ...@@ -387,24 +416,27 @@ func (server *serverType) input(codec ServerCodec) {
codec.Close() codec.Close()
} }
func (server *serverType) accept(lis net.Listener) { // Accept accepts connections on the listener and serves requests
// for each incoming connection. Accept blocks; the caller typically
// invokes it in a go statement.
func (server *Server) Accept(lis net.Listener) {
for { for {
conn, err := lis.Accept() conn, err := lis.Accept()
if err != nil { if err != nil {
log.Exit("rpc.Serve: accept:", err.String()) // TODO(r): exit? log.Exit("rpc.Serve: accept:", err.String()) // TODO(r): exit?
} }
go ServeConn(conn) go server.ServeConn(conn)
} }
} }
// Register publishes in the server the set of methods of the // Register publishes in the DefaultServer the set of methods
// receiver value that satisfy the following conditions: // of the receiver value that satisfy the following conditions:
// - public method // - exported method
// - two arguments, both pointers to public structs // - two arguments, both pointers to exported structs
// - one return value of type os.Error // - one return value, of type os.Error
// It returns an error if the receiver is not public or has no // It returns an error if the receiver is not an exported type or has no
// suitable methods. // suitable methods.
func Register(rcvr interface{}) os.Error { return server.register(rcvr) } func Register(rcvr interface{}) os.Error { return DefaultServer.Register(rcvr) }
// A ServerCodec implements reading of RPC requests and writing of // A ServerCodec implements reading of RPC requests and writing of
// RPC responses for the server side of an RPC session. // RPC responses for the server side of an RPC session.
...@@ -420,36 +452,35 @@ type ServerCodec interface { ...@@ -420,36 +452,35 @@ type ServerCodec interface {
Close() os.Error Close() os.Error
} }
// ServeConn runs the server on a single connection. // ServeConn runs the DefaultServer on a single connection.
// ServeConn blocks, serving the connection until the client hangs up. // ServeConn blocks, serving the connection until the client hangs up.
// The caller typically invokes ServeConn in a go statement. // The caller typically invokes ServeConn in a go statement.
// ServeConn uses the gob wire format (see package gob) on the // ServeConn uses the gob wire format (see package gob) on the
// connection. To use an alternate codec, use ServeCodec. // connection. To use an alternate codec, use ServeCodec.
func ServeConn(conn io.ReadWriteCloser) { func ServeConn(conn io.ReadWriteCloser) {
ServeCodec(&gobServerCodec{conn, gob.NewDecoder(conn), gob.NewEncoder(conn)}) DefaultServer.ServeConn(conn)
} }
// ServeCodec is like ServeConn but uses the specified codec to // ServeCodec is like ServeConn but uses the specified codec to
// decode requests and encode responses. // decode requests and encode responses.
func ServeCodec(codec ServerCodec) { func ServeCodec(codec ServerCodec) {
server.input(codec) DefaultServer.ServeCodec(codec)
} }
// Accept accepts connections on the listener and serves requests // Accept accepts connections on the listener and serves requests
// for each incoming connection. Accept blocks; the caller typically // to DefaultServer for each incoming connection.
// invokes it in a go statement. // Accept blocks; the caller typically invokes it in a go statement.
func Accept(lis net.Listener) { server.accept(lis) } func Accept(lis net.Listener) { DefaultServer.Accept(lis) }
// Can connect to RPC service using HTTP CONNECT to rpcPath. // Can connect to RPC service using HTTP CONNECT to rpcPath.
var rpcPath string = "/_goRPC_"
var debugPath string = "/debug/rpc"
var connected = "200 Connected to Go RPC" var connected = "200 Connected to Go RPC"
func serveHTTP(w http.ResponseWriter, req *http.Request) { // ServeHTTP implements an http.Handler that answers RPC requests.
func (server *Server) ServeHTTP(w http.ResponseWriter, req *http.Request) {
if req.Method != "CONNECT" { if req.Method != "CONNECT" {
w.SetHeader("Content-Type", "text/plain; charset=utf-8") w.SetHeader("Content-Type", "text/plain; charset=utf-8")
w.WriteHeader(http.StatusMethodNotAllowed) w.WriteHeader(http.StatusMethodNotAllowed)
io.WriteString(w, "405 must CONNECT to "+rpcPath+"\n") io.WriteString(w, "405 must CONNECT\n")
return return
} }
conn, _, err := w.Hijack() conn, _, err := w.Hijack()
...@@ -458,12 +489,20 @@ func serveHTTP(w http.ResponseWriter, req *http.Request) { ...@@ -458,12 +489,20 @@ func serveHTTP(w http.ResponseWriter, req *http.Request) {
return return
} }
io.WriteString(conn, "HTTP/1.0 "+connected+"\n\n") io.WriteString(conn, "HTTP/1.0 "+connected+"\n\n")
ServeConn(conn) server.ServeConn(conn)
}
// HandleHTTP registers an HTTP handler for RPC messages on rpcPath,
// and a debugging handler on debugPath.
// It is still necessary to invoke http.Serve(), typically in a go statement.
func (server *Server) HandleHTTP(rpcPath, debugPath string) {
http.Handle(rpcPath, server)
http.Handle(debugPath, debugHTTP{server})
} }
// HandleHTTP registers an HTTP handler for RPC messages. // HandleHTTP registers an HTTP handler for RPC messages to DefaultServer
// on DefaultRPCPath and a debugging handler on DefaultDebugPath.
// It is still necessary to invoke http.Serve(), typically in a go statement. // It is still necessary to invoke http.Serve(), typically in a go statement.
func HandleHTTP() { func HandleHTTP() {
http.Handle(rpcPath, http.HandlerFunc(serveHTTP)) DefaultServer.HandleHTTP(DefaultRPCPath, DefaultDebugPath)
http.Handle(debugPath, http.HandlerFunc(debugHTTP))
} }
...@@ -15,12 +15,16 @@ import ( ...@@ -15,12 +15,16 @@ import (
"testing" "testing"
) )
var serverAddr string var (
var httpServerAddr string serverAddr, newServerAddr string
var once sync.Once httpServerAddr string
once, newOnce, httpOnce sync.Once
const second = 1e9 )
const (
second = 1e9
newHttpPath = "/foo"
)
type Args struct { type Args struct {
A, B int A, B int
...@@ -64,23 +68,42 @@ func (t *Arith) Error(args *Args, reply *Reply) os.Error { ...@@ -64,23 +68,42 @@ func (t *Arith) Error(args *Args, reply *Reply) os.Error {
panic("ERROR") panic("ERROR")
} }
func startServer() { func listenTCP() (net.Listener, string) {
Register(new(Arith))
l, e := net.Listen("tcp", "127.0.0.1:0") // any available address l, e := net.Listen("tcp", "127.0.0.1:0") // any available address
if e != nil { if e != nil {
log.Exitf("net.Listen tcp :0: %v", e) log.Exitf("net.Listen tcp :0: %v", e)
} }
serverAddr = l.Addr().String() return l, l.Addr().String()
}
func startServer() {
Register(new(Arith))
var l net.Listener
l, serverAddr = listenTCP()
log.Println("Test RPC server listening on", serverAddr) log.Println("Test RPC server listening on", serverAddr)
go Accept(l) go Accept(l)
HandleHTTP() HandleHTTP()
l, e = net.Listen("tcp", "127.0.0.1:0") // any available address httpOnce.Do(startHttpServer)
if e != nil { }
log.Printf("net.Listen tcp :0: %v", e)
os.Exit(1) func startNewServer() {
} s := NewServer()
s.Register(new(Arith))
var l net.Listener
l, newServerAddr = listenTCP()
log.Println("NewServer test RPC server listening on", newServerAddr)
go Accept(l)
s.HandleHTTP(newHttpPath, "/bar")
httpOnce.Do(startHttpServer)
}
func startHttpServer() {
var l net.Listener
l, httpServerAddr = listenTCP()
httpServerAddr = l.Addr().String() httpServerAddr = l.Addr().String()
log.Println("Test HTTP RPC server listening on", httpServerAddr) log.Println("Test HTTP RPC server listening on", httpServerAddr)
go http.Serve(l, nil) go http.Serve(l, nil)
...@@ -88,8 +111,13 @@ func startServer() { ...@@ -88,8 +111,13 @@ func startServer() {
func TestRPC(t *testing.T) { func TestRPC(t *testing.T) {
once.Do(startServer) once.Do(startServer)
testRPC(t, serverAddr)
newOnce.Do(startNewServer)
testRPC(t, newServerAddr)
}
client, err := Dial("tcp", serverAddr) func testRPC(t *testing.T, addr string) {
client, err := Dial("tcp", addr)
if err != nil { if err != nil {
t.Fatal("dialing", err) t.Fatal("dialing", err)
} }
...@@ -175,8 +203,19 @@ func TestRPC(t *testing.T) { ...@@ -175,8 +203,19 @@ func TestRPC(t *testing.T) {
func TestHTTPRPC(t *testing.T) { func TestHTTPRPC(t *testing.T) {
once.Do(startServer) once.Do(startServer)
testHTTPRPC(t, "")
newOnce.Do(startNewServer)
testHTTPRPC(t, newHttpPath)
}
client, err := DialHTTP("tcp", httpServerAddr) func testHTTPRPC(t *testing.T, path string) {
var client *Client
var err os.Error
if path == "" {
client, err = DialHTTP("tcp", httpServerAddr)
} else {
client, err = DialHTTPPath("tcp", httpServerAddr, path)
}
if err != nil { if err != nil {
t.Fatal("dialing", err) t.Fatal("dialing", err)
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment