Commit 75397e65 authored by Dmitriy Vyukov's avatar Dmitriy Vyukov

net/rpc: fix data race on Call.Error

+eliminates a possibility of sending a call to Done several times.
+fixes memory leak in case of temporal Write errors.
+fixes data race on Client.shutdown.
+fixes data race on Client.closing.
+fixes comments.
Fixes #2780.

R=r, rsc
CC=golang-dev, mpimenov
https://golang.org/cl/5571063
parent d5b7c515
...@@ -31,8 +31,7 @@ type Call struct { ...@@ -31,8 +31,7 @@ type Call struct {
Args interface{} // The argument to the function (*struct). Args interface{} // The argument to the function (*struct).
Reply interface{} // The reply from the function (*struct). Reply interface{} // The reply from the function (*struct).
Error error // After completion, the error status. Error error // After completion, the error status.
Done chan *Call // Strobes when call is complete; value is the error status. Done chan *Call // Strobes when call is complete.
seq uint64
} }
// Client represents an RPC Client. // Client represents an RPC Client.
...@@ -65,28 +64,33 @@ type ClientCodec interface { ...@@ -65,28 +64,33 @@ type ClientCodec interface {
Close() error Close() error
} }
func (client *Client) send(c *Call) { func (client *Client) send(call *Call) {
client.sending.Lock()
defer client.sending.Unlock()
// Register this call. // Register this call.
client.mutex.Lock() client.mutex.Lock()
if client.shutdown { if client.shutdown {
c.Error = ErrShutdown call.Error = ErrShutdown
client.mutex.Unlock() client.mutex.Unlock()
c.done() call.done()
return return
} }
c.seq = client.seq seq := client.seq
client.seq++ client.seq++
client.pending[c.seq] = c client.pending[seq] = call
client.mutex.Unlock() client.mutex.Unlock()
// Encode and send the request. // Encode and send the request.
client.sending.Lock() client.request.Seq = seq
defer client.sending.Unlock() client.request.ServiceMethod = call.ServiceMethod
client.request.Seq = c.seq err := client.codec.WriteRequest(&client.request, call.Args)
client.request.ServiceMethod = c.ServiceMethod if err != nil {
if err := client.codec.WriteRequest(&client.request, c.Args); err != nil { client.mutex.Lock()
c.Error = err delete(client.pending, seq)
c.done() client.mutex.Unlock()
call.Error = err
call.done()
} }
} }
...@@ -104,36 +108,39 @@ func (client *Client) input() { ...@@ -104,36 +108,39 @@ func (client *Client) input() {
} }
seq := response.Seq seq := response.Seq
client.mutex.Lock() client.mutex.Lock()
c := client.pending[seq] call := client.pending[seq]
delete(client.pending, seq) delete(client.pending, seq)
client.mutex.Unlock() client.mutex.Unlock()
if response.Error == "" { if response.Error == "" {
err = client.codec.ReadResponseBody(c.Reply) err = client.codec.ReadResponseBody(call.Reply)
if err != nil { if err != nil {
c.Error = errors.New("reading body " + err.Error()) call.Error = errors.New("reading body " + err.Error())
} }
} else { } else {
// We've got an error response. Give this to the request; // We've got an error response. Give this to the request;
// any subsequent requests will get the ReadResponseBody // any subsequent requests will get the ReadResponseBody
// error if there is one. // error if there is one.
c.Error = ServerError(response.Error) call.Error = ServerError(response.Error)
err = client.codec.ReadResponseBody(nil) err = client.codec.ReadResponseBody(nil)
if err != nil { if err != nil {
err = errors.New("reading error body: " + err.Error()) err = errors.New("reading error body: " + err.Error())
} }
} }
c.done() call.done()
} }
// Terminate pending calls. // Terminate pending calls.
client.sending.Lock()
client.mutex.Lock() client.mutex.Lock()
client.shutdown = true client.shutdown = true
closing := client.closing
for _, call := range client.pending { for _, call := range client.pending {
call.Error = err call.Error = err
call.done() call.done()
} }
client.mutex.Unlock() client.mutex.Unlock()
if err != io.EOF || !client.closing { client.sending.Unlock()
if err != io.EOF || !closing {
log.Println("rpc: client protocol error:", err) log.Println("rpc: client protocol error:", err)
} }
} }
...@@ -269,20 +276,12 @@ func (client *Client) Go(serviceMethod string, args interface{}, reply interface ...@@ -269,20 +276,12 @@ func (client *Client) Go(serviceMethod string, args interface{}, reply interface
} }
} }
call.Done = done call.Done = done
if client.shutdown {
call.Error = ErrShutdown
call.done()
return call
}
client.send(call) client.send(call)
return call return call
} }
// Call invokes the named function, waits for it to complete, and returns its error status. // Call invokes the named function, waits for it to complete, and returns its error status.
func (client *Client) Call(serviceMethod string, args interface{}, reply interface{}) error { func (client *Client) Call(serviceMethod string, args interface{}, reply interface{}) error {
if client.shutdown {
return ErrShutdown
}
call := <-client.Go(serviceMethod, args, reply, make(chan *Call, 1)).Done call := <-client.Go(serviceMethod, args, reply, make(chan *Call, 1)).Done
return call.Error return call.Error
} }
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment