refactor(client): Clean up client dns resolver

This commit is contained in:
TwiN 2022-06-13 19:16:34 -04:00
parent fea95b8479
commit 326ea1c3d1
3 changed files with 81 additions and 71 deletions

View File

@ -56,52 +56,6 @@ func TestPing(t *testing.T) {
} }
} }
func TestDNSResolverConfig(t *testing.T) {
type args struct {
dnsResolver string
}
tests := []struct {
name string
args args
wantErr bool
}{
{
name: "valid resolver",
args: args{
dnsResolver: "tcp://1.1.1.1:53",
},
wantErr: false,
},
{
name: "invalid resolver port",
args: args{
dnsResolver: "tcp://127.0.0.1:99999",
},
wantErr: true,
},
{
name: "invalid resolver format",
args: args{
dnsResolver: "foobar",
},
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cfg := &Config{
DNSResolver: tt.args.dnsResolver,
}
client := GetHTTPClient(cfg)
_, err := client.Get("https://example.org")
if (err != nil) != tt.wantErr {
t.Errorf("TestDNSResolverConfig err=%v, wantErr=%v", err, tt.wantErr)
return
}
})
}
}
func TestCanPerformStartTLS(t *testing.T) { func TestCanPerformStartTLS(t *testing.T) {
type args struct { type args struct {
address string address string
@ -221,7 +175,6 @@ func TestHttpClientProvidesOAuth2BearerToken(t *testing.T) {
} }
mockHttpClient := &http.Client{ mockHttpClient := &http.Client{
Transport: test.MockRoundTripper(func(r *http.Request) *http.Response { Transport: test.MockRoundTripper(func(r *http.Request) *http.Response {
// if the mock HTTP client tries to get a token from the `token-server` // if the mock HTTP client tries to get a token from the `token-server`
// we provide the expected token response // we provide the expected token response
if r.Host == "token-server.local" { if r.Host == "token-server.local" {

View File

@ -4,10 +4,11 @@ import (
"context" "context"
"crypto/tls" "crypto/tls"
"errors" "errors"
"fmt" "log"
"net" "net"
"net/http" "net/http"
"regexp" "regexp"
"strconv"
"time" "time"
"golang.org/x/oauth2" "golang.org/x/oauth2"
@ -20,6 +21,7 @@ const (
var ( var (
ErrInvalidDNSResolver = errors.New("invalid DNS resolver specified. Required format is {proto}://{ip}:{port}") ErrInvalidDNSResolver = errors.New("invalid DNS resolver specified. Required format is {proto}://{ip}:{port}")
ErrInvalidDNSResolverPort = errors.New("invalid DNS resolver port")
ErrInvalidClientOAuth2Config = errors.New("invalid OAuth2 configuration, all fields are required") ErrInvalidClientOAuth2Config = errors.New("invalid OAuth2 configuration, all fields are required")
defaultConfig = Config{ defaultConfig = Config{
@ -46,8 +48,8 @@ type Config struct {
// Timeout for the client // Timeout for the client
Timeout time.Duration `yaml:"timeout"` Timeout time.Duration `yaml:"timeout"`
// DNSResolver override for the HTTPClient // DNSResolver override for the HTTP client
// Expected format is {protocol}://{host}:{port} // Expected format is {protocol}://{host}:{port}, e.g. tcp://1.1.1.1:53
DNSResolver string `yaml:"dns-resolver,omitempty"` DNSResolver string `yaml:"dns-resolver,omitempty"`
// OAuth2Config is the OAuth2 configuration used for the client. // OAuth2Config is the OAuth2 configuration used for the client.
@ -80,9 +82,9 @@ func (c *Config) ValidateAndSetDefaults() error {
c.Timeout = 10 * time.Second c.Timeout = 10 * time.Second
} }
if c.HasCustomDNSResolver() { if c.HasCustomDNSResolver() {
_, err := c.ParseDNSResolver() // Validate the DNS resolver now to make sure it will not return an error later.
if err != nil { if _, err := c.parseDNSResolver(); err != nil {
return ErrInvalidDNSResolver return err
} }
} }
if c.HasOAuth2Config() && !c.OAuth2Config.isValid() { if c.HasOAuth2Config() && !c.OAuth2Config.isValid() {
@ -91,17 +93,17 @@ func (c *Config) ValidateAndSetDefaults() error {
return nil return nil
} }
// Returns true if the DNSResolver is set in the configuration // HasCustomDNSResolver returns whether a custom DNSResolver is configured
func (c *Config) HasCustomDNSResolver() bool { func (c *Config) HasCustomDNSResolver() bool {
return len(c.DNSResolver) > 0 return len(c.DNSResolver) > 0
} }
// Parses the DNSResolver configuration string into the DNSResolverConfig struct // parseDNSResolver parses the DNS resolver into the DNSResolverConfig struct
func (c *Config) ParseDNSResolver() (DNSResolverConfig, error) { func (c *Config) parseDNSResolver() (*DNSResolverConfig, error) {
re := regexp.MustCompile(`^(?P<proto>(.*))://(?P<host>[A-Za-z0-9\-\.]+):(?P<port>[0-9]+)?(.*)$`) re := regexp.MustCompile(`^(?P<proto>(.*))://(?P<host>[A-Za-z0-9\-\.]+):(?P<port>[0-9]+)?(.*)$`)
matches := re.FindStringSubmatch(c.DNSResolver) matches := re.FindStringSubmatch(c.DNSResolver)
if len(matches) == 0 { if len(matches) == 0 {
return DNSResolverConfig{}, errors.New("ParseError") return nil, ErrInvalidDNSResolver
} }
r := make(map[string]string) r := make(map[string]string)
for i, k := range re.SubexpNames() { for i, k := range re.SubexpNames() {
@ -109,8 +111,14 @@ func (c *Config) ParseDNSResolver() (DNSResolverConfig, error) {
r[k] = matches[i] r[k] = matches[i]
} }
} }
port, err := strconv.Atoi(r["port"])
return DNSResolverConfig{ if err != nil {
return nil, err
}
if port < 1 || port > 65535 {
return nil, ErrInvalidDNSResolverPort
}
return &DNSResolverConfig{
Protocol: r["proto"], Protocol: r["proto"],
Host: r["host"], Host: r["host"],
Port: r["port"], Port: r["port"],
@ -150,20 +158,25 @@ func (c *Config) getHTTPClient() *http.Client {
}, },
} }
if c.HasCustomDNSResolver() { if c.HasCustomDNSResolver() {
dnsResolver, _ := c.ParseDNSResolver() dnsResolver, err := c.parseDNSResolver()
if err != nil {
// We're ignoring the error, because it should have been validated on startup ValidateAndSetDefaults.
// It shouldn't happen, but if it does, we'll log it... Better safe than sorry ;)
log.Println("[client][getHTTPClient] THIS SHOULD NOT HAPPEN. Silently ignoring invalid DNS resolver due to error:", err.Error())
} else {
dialer := &net.Dialer{ dialer := &net.Dialer{
Resolver: &net.Resolver{ Resolver: &net.Resolver{
PreferGo: true, PreferGo: true,
Dial: func(ctx context.Context, network, address string) (net.Conn, error) { Dial: func(ctx context.Context, network, address string) (net.Conn, error) {
d := net.Dialer{} d := net.Dialer{}
return d.DialContext(ctx, dnsResolver.Protocol, fmt.Sprintf("%s:%s", dnsResolver.Host, dnsResolver.Port)) return d.DialContext(ctx, dnsResolver.Protocol, dnsResolver.Host+":"+dnsResolver.Port)
}, },
}, },
} }
dialCtx := func(ctx context.Context, network, addr string) (net.Conn, error) { c.httpClient.Transport.(*http.Transport).DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
return dialer.DialContext(ctx, network, addr) return dialer.DialContext(ctx, network, addr)
} }
c.httpClient.Transport.(*http.Transport).DialContext = dialCtx }
} }
if c.HasOAuth2Config() { if c.HasOAuth2Config() {
c.httpClient = configureOAuth2(c.httpClient, *c.OAuth2Config) c.httpClient = configureOAuth2(c.httpClient, *c.OAuth2Config)

View File

@ -35,3 +35,47 @@ func TestConfig_getHTTPClient(t *testing.T) {
t.Error("expected Config.IgnoreRedirect set to true to cause the HTTP client's CheckRedirect to return http.ErrUseLastResponse") t.Error("expected Config.IgnoreRedirect set to true to cause the HTTP client's CheckRedirect to return http.ErrUseLastResponse")
} }
} }
func TestConfig_ValidateAndSetDefaults_withCustomDNSResolver(t *testing.T) {
type args struct {
dnsResolver string
}
tests := []struct {
name string
args args
wantErr bool
}{
{
name: "with-valid-resolver",
args: args{
dnsResolver: "tcp://1.1.1.1:53",
},
wantErr: false,
},
{
name: "with-invalid-resolver-port",
args: args{
dnsResolver: "tcp://127.0.0.1:99999",
},
wantErr: true,
},
{
name: "with-invalid-resolver-format",
args: args{
dnsResolver: "foobar",
},
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cfg := &Config{
DNSResolver: tt.args.dnsResolver,
}
err := cfg.ValidateAndSetDefaults()
if (err != nil) != tt.wantErr {
t.Errorf("ValidateAndSetDefaults() error=%v, wantErr=%v", err, tt.wantErr)
}
})
}
}