Commit dcff8905 authored by Russ Cox's avatar Russ Cox

rpc: abstract client and server encodings

R=r
CC=golang-dev, rog
https://golang.org/cl/811046
parent 72f9b2eb
...@@ -33,13 +33,25 @@ type Client struct { ...@@ -33,13 +33,25 @@ type Client struct {
shutdown os.Error // non-nil if the client is shut down shutdown os.Error // non-nil if the client is shut down
sending sync.Mutex sending sync.Mutex
seq uint64 seq uint64
conn io.ReadWriteCloser codec ClientCodec
enc *gob.Encoder
dec *gob.Decoder
pending map[uint64]*Call pending map[uint64]*Call
closing bool closing bool
} }
// A ClientCodec implements writing of RPC requests and
// reading of RPC responses for the client side of an RPC session.
// The client calls WriteRequest to write a request to the connection
// and calls ReadResponseHeader and ReadResponseBody in pairs
// to read responses. The client calls Close when finished with the
// connection.
type ClientCodec interface {
WriteRequest(*Request, interface{}) os.Error
ReadResponseHeader(*Response) os.Error
ReadResponseBody(interface{}) os.Error
Close() os.Error
}
func (client *Client) send(c *Call) { func (client *Client) send(c *Call) {
// Register this call. // Register this call.
client.mutex.Lock() client.mutex.Lock()
...@@ -59,9 +71,7 @@ func (client *Client) send(c *Call) { ...@@ -59,9 +71,7 @@ func (client *Client) send(c *Call) {
client.sending.Lock() client.sending.Lock()
request.Seq = c.seq request.Seq = c.seq
request.ServiceMethod = c.ServiceMethod request.ServiceMethod = c.ServiceMethod
client.enc.Encode(request) if err := client.codec.WriteRequest(request, c.Args); err != nil {
err := client.enc.Encode(c.Args)
if err != nil {
panic("rpc: client encode error: " + err.String()) panic("rpc: client encode error: " + err.String())
} }
client.sending.Unlock() client.sending.Unlock()
...@@ -71,7 +81,7 @@ func (client *Client) input() { ...@@ -71,7 +81,7 @@ func (client *Client) input() {
var err os.Error var err os.Error
for err == nil { for err == nil {
response := new(Response) response := new(Response)
err = client.dec.Decode(response) err = client.codec.ReadResponseHeader(response)
if err != nil { if err != nil {
if err == os.EOF && !client.closing { if err == os.EOF && !client.closing {
err = io.ErrUnexpectedEOF err = io.ErrUnexpectedEOF
...@@ -83,7 +93,7 @@ func (client *Client) input() { ...@@ -83,7 +93,7 @@ func (client *Client) input() {
c := client.pending[seq] c := client.pending[seq]
client.pending[seq] = c, false client.pending[seq] = c, false
client.mutex.Unlock() client.mutex.Unlock()
err = client.dec.Decode(c.Reply) err = client.codec.ReadResponseBody(c.Reply)
// Empty strings should turn into nil os.Errors // Empty strings should turn into nil os.Errors
if response.Error != "" { if response.Error != "" {
c.Error = os.ErrorString(response.Error) c.Error = os.ErrorString(response.Error)
...@@ -110,17 +120,49 @@ func (client *Client) input() { ...@@ -110,17 +120,49 @@ func (client *Client) input() {
// NewClient returns a new Client to handle requests to the // NewClient returns a new Client to handle requests to the
// set of services at the other end of the connection. // set of services at the other end of the connection.
func NewClient(conn io.ReadWriteCloser) *Client { func NewClient(conn io.ReadWriteCloser) *Client {
client := new(Client) return NewClientWithCodec(&gobClientCodec{conn, gob.NewDecoder(conn), gob.NewEncoder(conn)})
client.conn = conn }
client.enc = gob.NewEncoder(conn)
client.dec = gob.NewDecoder(conn) // NewClientWithCodec is like NewClient but uses the specified
client.pending = make(map[uint64]*Call) // codec to encode requests and decode responses.
func NewClientWithCodec(codec ClientCodec) *Client {
client := &Client{
codec: codec,
pending: make(map[uint64]*Call),
}
go client.input() go client.input()
return client return client
} }
type gobClientCodec struct {
rwc io.ReadWriteCloser
dec *gob.Decoder
enc *gob.Encoder
}
func (c *gobClientCodec) WriteRequest(r *Request, body interface{}) os.Error {
if err := c.enc.Encode(r); err != nil {
return err
}
return c.enc.Encode(body)
}
func (c *gobClientCodec) ReadResponseHeader(r *Response) os.Error {
return c.dec.Decode(r)
}
func (c *gobClientCodec) ReadResponseBody(body interface{}) os.Error {
return c.dec.Decode(body)
}
func (c *gobClientCodec) Close() os.Error {
return c.rwc.Close()
}
// DialHTTP connects to an HTTP RPC server at the specified network address. // DialHTTP connects to an HTTP RPC server at the specified network address.
func DialHTTP(network, address string) (*Client, os.Error) { func DialHTTP(network, address string) (*Client, os.Error) {
var err os.Error
conn, err := net.Dial(network, "", address) conn, err := net.Dial(network, "", address)
if err != nil { if err != nil {
return nil, err return nil, err
...@@ -156,7 +198,7 @@ func (client *Client) Close() os.Error { ...@@ -156,7 +198,7 @@ func (client *Client) Close() os.Error {
client.mutex.Lock() client.mutex.Lock()
client.closing = true client.closing = true
client.mutex.Unlock() client.mutex.Unlock()
return client.conn.Close() return client.codec.Close()
} }
// Go invokes the function asynchronously. It returns the Call structure representing // Go invokes the function asynchronously. It returns the Call structure representing
......
...@@ -272,7 +272,7 @@ func _new(t *reflect.PtrType) *reflect.PtrValue { ...@@ -272,7 +272,7 @@ func _new(t *reflect.PtrType) *reflect.PtrValue {
return v return v
} }
func sendResponse(sending *sync.Mutex, req *Request, reply interface{}, enc *gob.Encoder, errmsg string) { func sendResponse(sending *sync.Mutex, req *Request, reply interface{}, codec ServerCodec, errmsg string) {
resp := new(Response) resp := new(Response)
// Encode the response header // Encode the response header
resp.ServiceMethod = req.ServiceMethod resp.ServiceMethod = req.ServiceMethod
...@@ -281,13 +281,14 @@ func sendResponse(sending *sync.Mutex, req *Request, reply interface{}, enc *gob ...@@ -281,13 +281,14 @@ func sendResponse(sending *sync.Mutex, req *Request, reply interface{}, enc *gob
} }
resp.Seq = req.Seq resp.Seq = req.Seq
sending.Lock() sending.Lock()
enc.Encode(resp) err := codec.WriteResponse(resp, reply)
// Encode the reply value. if err != nil {
enc.Encode(reply) log.Stderr("rpc: writing response: ", err)
}
sending.Unlock() sending.Unlock()
} }
func (s *service) call(sending *sync.Mutex, mtype *methodType, req *Request, argv, replyv reflect.Value, enc *gob.Encoder) { func (s *service) call(sending *sync.Mutex, mtype *methodType, req *Request, argv, replyv reflect.Value, codec ServerCodec) {
mtype.Lock() mtype.Lock()
mtype.numCalls++ mtype.numCalls++
mtype.Unlock() mtype.Unlock()
...@@ -300,17 +301,40 @@ func (s *service) call(sending *sync.Mutex, mtype *methodType, req *Request, arg ...@@ -300,17 +301,40 @@ func (s *service) call(sending *sync.Mutex, mtype *methodType, req *Request, arg
if errInter != nil { if errInter != nil {
errmsg = errInter.(os.Error).String() errmsg = errInter.(os.Error).String()
} }
sendResponse(sending, req, replyv.Interface(), enc, errmsg) sendResponse(sending, req, replyv.Interface(), codec, errmsg)
}
type gobServerCodec struct {
rwc io.ReadWriteCloser
dec *gob.Decoder
enc *gob.Encoder
}
func (c *gobServerCodec) ReadRequestHeader(r *Request) os.Error {
return c.dec.Decode(r)
}
func (c *gobServerCodec) ReadRequestBody(body interface{}) os.Error {
return c.dec.Decode(body)
}
func (c *gobServerCodec) WriteResponse(r *Response, body interface{}) os.Error {
if err := c.enc.Encode(r); err != nil {
return err
}
return c.enc.Encode(body)
} }
func (server *serverType) input(conn io.ReadWriteCloser) { func (c *gobServerCodec) Close() os.Error {
dec := gob.NewDecoder(conn) return c.rwc.Close()
enc := gob.NewEncoder(conn) }
func (server *serverType) input(codec ServerCodec) {
sending := new(sync.Mutex) sending := new(sync.Mutex)
for { for {
// Grab the request header. // Grab the request header.
req := new(Request) req := new(Request)
err := dec.Decode(req) err := codec.ReadRequestHeader(req)
if err != nil { if err != nil {
if err == os.EOF || err == io.ErrUnexpectedEOF { if err == os.EOF || err == io.ErrUnexpectedEOF {
if err == io.ErrUnexpectedEOF { if err == io.ErrUnexpectedEOF {
...@@ -319,13 +343,13 @@ func (server *serverType) input(conn io.ReadWriteCloser) { ...@@ -319,13 +343,13 @@ func (server *serverType) input(conn io.ReadWriteCloser) {
break break
} }
s := "rpc: server cannot decode request: " + err.String() s := "rpc: server cannot decode request: " + err.String()
sendResponse(sending, req, invalidRequest, enc, s) sendResponse(sending, req, invalidRequest, codec, s)
continue break
} }
serviceMethod := strings.Split(req.ServiceMethod, ".", 0) serviceMethod := strings.Split(req.ServiceMethod, ".", 0)
if len(serviceMethod) != 2 { if len(serviceMethod) != 2 {
s := "rpc: service/method request ill:formed: " + req.ServiceMethod s := "rpc: service/method request ill-formed: " + req.ServiceMethod
sendResponse(sending, req, invalidRequest, enc, s) sendResponse(sending, req, invalidRequest, codec, s)
continue continue
} }
// Look up the request. // Look up the request.
...@@ -334,27 +358,27 @@ func (server *serverType) input(conn io.ReadWriteCloser) { ...@@ -334,27 +358,27 @@ func (server *serverType) input(conn io.ReadWriteCloser) {
server.Unlock() server.Unlock()
if !ok { if !ok {
s := "rpc: can't find service " + req.ServiceMethod s := "rpc: can't find service " + req.ServiceMethod
sendResponse(sending, req, invalidRequest, enc, s) sendResponse(sending, req, invalidRequest, codec, s)
continue continue
} }
mtype, ok := service.method[serviceMethod[1]] mtype, ok := service.method[serviceMethod[1]]
if !ok { if !ok {
s := "rpc: can't find method " + req.ServiceMethod s := "rpc: can't find method " + req.ServiceMethod
sendResponse(sending, req, invalidRequest, enc, s) sendResponse(sending, req, invalidRequest, codec, s)
continue continue
} }
// Decode the argument value. // Decode the argument value.
argv := _new(mtype.argType) argv := _new(mtype.argType)
replyv := _new(mtype.replyType) replyv := _new(mtype.replyType)
err = dec.Decode(argv.Interface()) err = codec.ReadRequestBody(argv.Interface())
if err != nil { if err != nil {
log.Stderr("rpc: tearing down", serviceMethod[0], "connection:", err) log.Stderr("rpc: tearing down", serviceMethod[0], "connection:", err)
sendResponse(sending, req, replyv.Interface(), enc, err.String()) sendResponse(sending, req, replyv.Interface(), codec, err.String())
continue break
} }
go service.call(sending, mtype, req, argv, replyv, enc) go service.call(sending, mtype, req, argv, replyv, codec)
} }
conn.Close() codec.Close()
} }
func (server *serverType) accept(lis net.Listener) { func (server *serverType) accept(lis net.Listener) {
...@@ -363,7 +387,7 @@ func (server *serverType) accept(lis net.Listener) { ...@@ -363,7 +387,7 @@ func (server *serverType) accept(lis net.Listener) {
if err != nil { if err != nil {
log.Exit("rpc.Serve: accept:", err.String()) // TODO(r): exit? log.Exit("rpc.Serve: accept:", err.String()) // TODO(r): exit?
} }
go server.input(conn) go ServeConn(conn)
} }
} }
...@@ -376,10 +400,34 @@ func (server *serverType) accept(lis net.Listener) { ...@@ -376,10 +400,34 @@ func (server *serverType) accept(lis net.Listener) {
// suitable methods. // suitable methods.
func Register(rcvr interface{}) os.Error { return server.register(rcvr) } func Register(rcvr interface{}) os.Error { return server.register(rcvr) }
// ServeConn runs the server on a single connection. When the connection // A ServerCodec implements reading of RPC requests and writing of
// completes, service terminates. ServeConn blocks; the caller typically // RPC responses for the server side of an RPC session.
// invokes it in a go statement. // The server calls ReadRequestHeader and ReadRequestBody in pairs
func ServeConn(conn io.ReadWriteCloser) { server.input(conn) } // to read requests from the connection, and it calls WriteResponse to
// write a response back. The server calls Close when finished with the
// connection.
type ServerCodec interface {
ReadRequestHeader(*Request) os.Error
ReadRequestBody(interface{}) os.Error
WriteResponse(*Response, interface{}) os.Error
Close() os.Error
}
// ServeConn runs the server on a single connection.
// ServeConn blocks, serving the connection until the client hangs up.
// The caller typically invokes ServeConn in a go statement.
// ServeConn uses the gob wire format (see package gob) on the
// connection. To use an alternate codec, use ServeCodec.
func ServeConn(conn io.ReadWriteCloser) {
ServeCodec(&gobServerCodec{conn, gob.NewDecoder(conn), gob.NewEncoder(conn)})
}
// ServeCodec is like ServeConn but uses the specified codec to
// decode requests and encode responses.
func ServeCodec(codec ServerCodec) {
server.input(codec)
}
// Accept accepts connections on the listener and serves requests // Accept accepts connections on the listener and serves requests
// for each incoming connection. Accept blocks; the caller typically // for each incoming connection. Accept blocks; the caller typically
...@@ -404,7 +452,7 @@ func serveHTTP(c *http.Conn, req *http.Request) { ...@@ -404,7 +452,7 @@ func serveHTTP(c *http.Conn, req *http.Request) {
return return
} }
io.WriteString(conn, "HTTP/1.0 "+connected+"\n\n") io.WriteString(conn, "HTTP/1.0 "+connected+"\n\n")
server.input(conn) ServeConn(conn)
} }
// HandleHTTP registers an HTTP handler for RPC messages. // HandleHTTP registers an HTTP handler for RPC messages.
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment