Commit 6fd1680b authored by Adam Reese's avatar Adam Reese

ref(client): refactor url parsing

parent 361b4519
...@@ -22,7 +22,6 @@ func listCmd() cli.Command { ...@@ -22,7 +22,6 @@ func listCmd() cli.Command {
} }
func list(host string) error { func list(host string) error {
client := dm.NewClient(host) client := dm.NewClient(host).SetDebug(true)
client.Protocol = "http"
return client.ListDeployments() return client.ListDeployments()
} }
...@@ -6,8 +6,10 @@ import ( ...@@ -6,8 +6,10 @@ import (
"io" "io"
"io/ioutil" "io/ioutil"
"net/http" "net/http"
"net/url"
"os" "os"
"path/filepath" "path/filepath"
"strings"
"time" "time"
"github.com/ghodss/yaml" "github.com/ghodss/yaml"
...@@ -16,6 +18,9 @@ import ( ...@@ -16,6 +18,9 @@ import (
// The default HTTP timeout // The default HTTP timeout
var DefaultHTTPTimeout = time.Second * 10 var DefaultHTTPTimeout = time.Second * 10
// The default HTTP Protocol
var DefaultHTTPProtocol = "http"
// Client is a DM client. // Client is a DM client.
type Client struct { type Client struct {
// Timeout on HTTP connections. // Timeout on HTTP connections.
...@@ -26,29 +31,61 @@ type Client struct { ...@@ -26,29 +31,61 @@ type Client struct {
Protocol string Protocol string
// Transport // Transport
Transport http.RoundTripper Transport http.RoundTripper
// Debug enables http logging
Debug bool
// Base URL for remote service
baseURL *url.URL
} }
// NewClient creates a new DM client. Host name is required. // NewClient creates a new DM client. Host name is required.
func NewClient(host string) *Client { func NewClient(host string) *Client {
url, _ := DefaultServerURL(host)
return &Client{ return &Client{
HTTPTimeout: DefaultHTTPTimeout, HTTPTimeout: DefaultHTTPTimeout,
Protocol: "https", baseURL: url,
Host: host, Transport: http.DefaultTransport,
Transport: NewDebugTransport(nil), }
}
// SetDebug enables debug mode which logs http
func (c *Client) SetDebug(enable bool) *Client {
c.Debug = enable
return c
}
// transport wraps client transport if debug is enabled
func (c *Client) transport() http.RoundTripper {
if c.Debug {
return NewDebugTransport(c.Transport)
} }
return c.Transport
}
// SetTransport sets a custom Transport. Defaults to http.DefaultTransport
func (c *Client) SetTransport(tr http.RoundTripper) *Client {
c.Transport = tr
return c
} }
// url constructs the URL. // url constructs the URL.
func (c *Client) url(path string) string { func (c *Client) url(rawurl string) (string, error) {
// TODO: Switch to net.URL u, err := url.Parse(rawurl)
return c.Protocol + "://" + c.Host + "/" + path if err != nil {
return "", err
}
return c.baseURL.ResolveReference(u).String(), nil
} }
// CallService is a low-level function for making an API call. // CallService is a low-level function for making an API call.
// //
// This calls the service and then unmarshals the returned data into dest. // This calls the service and then unmarshals the returned data into dest.
func (c *Client) CallService(path, method, action string, dest interface{}, reader io.ReadCloser) error { func (c *Client) CallService(path, method, action string, dest interface{}, reader io.ReadCloser) error {
u := c.url(path) u, err := c.url(path)
if err != nil {
return err
}
resp, err := c.callHTTP(u, method, action, reader) resp, err := c.callHTTP(u, method, action, reader)
if err != nil { if err != nil {
...@@ -76,7 +113,7 @@ func (c *Client) callHTTP(path, method, action string, reader io.ReadCloser) (st ...@@ -76,7 +113,7 @@ func (c *Client) callHTTP(path, method, action string, reader io.ReadCloser) (st
client := http.Client{ client := http.Client{
Timeout: time.Duration(time.Duration(DefaultHTTPTimeout) * time.Second), Timeout: time.Duration(time.Duration(DefaultHTTPTimeout) * time.Second),
Transport: c.Transport, Transport: c.transport(),
} }
response, err := client.Do(request) response, err := client.Do(request)
...@@ -155,3 +192,27 @@ func (c *Client) DeployChart(filename, deployname string) error { ...@@ -155,3 +192,27 @@ func (c *Client) DeployChart(filename, deployname string) error {
return nil return nil
} }
// DefaultServerURL converts a host, host:port, or URL string to the default base server API path
// to use with a Client
func DefaultServerURL(host string) (*url.URL, error) {
if host == "" {
return nil, fmt.Errorf("host must be a URL or a host:port pair")
}
base := host
hostURL, err := url.Parse(base)
if err != nil {
return nil, err
}
if hostURL.Scheme == "" {
hostURL, err = url.Parse(DefaultHTTPProtocol + "://" + base)
if err != nil {
return nil, err
}
}
if len(hostURL.Path) > 0 && !strings.HasSuffix(hostURL.Path, "/") {
hostURL.Path = hostURL.Path + "/"
}
return hostURL, nil
}
package dm
import (
"testing"
)
func TestDefaultServerURL(t *testing.T) {
tt := []struct {
host string
url string
}{
{"127.0.0.1", "http://127.0.0.1"},
{"127.0.0.1:8080", "http://127.0.0.1:8080"},
{"foo.bar.com", "http://foo.bar.com"},
{"foo.bar.com/prefix", "http://foo.bar.com/prefix/"},
{"http://host/prefix", "http://host/prefix/"},
{"https://host/prefix", "https://host/prefix/"},
{"http://host", "http://host"},
{"http://host/other", "http://host/other/"},
}
for _, tc := range tt {
u, err := DefaultServerURL(tc.host)
if err != nil {
t.Fatal(err)
}
if tc.url != u.String() {
t.Errorf("%s, expected host %s, got %s", tc.host, tc.url, u.String())
}
}
}
func TestURL(t *testing.T) {
tt := []struct {
host string
path string
url string
}{
{"127.0.0.1", "foo", "http://127.0.0.1/foo"},
{"127.0.0.1:8080", "foo", "http://127.0.0.1:8080/foo"},
{"foo.bar.com", "foo", "http://foo.bar.com/foo"},
{"foo.bar.com/prefix", "foo", "http://foo.bar.com/prefix/foo"},
{"http://host/prefix", "foo", "http://host/prefix/foo"},
{"http://host", "foo", "http://host/foo"},
{"http://host/other", "/foo", "http://host/foo"},
}
for _, tc := range tt {
c := NewClient(tc.host)
p, err := c.url(tc.path)
if err != nil {
t.Fatal(err)
}
if tc.url != p {
t.Errorf("expected %s, got %s", tc.url, p)
}
}
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment