refactor(client): Clean up client dns resolver
This commit is contained in:
parent
fea95b8479
commit
326ea1c3d1
@ -56,52 +56,6 @@ func TestPing(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestDNSResolverConfig(t *testing.T) {
|
|
||||||
type args struct {
|
|
||||||
dnsResolver string
|
|
||||||
}
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
args args
|
|
||||||
wantErr bool
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "valid resolver",
|
|
||||||
args: args{
|
|
||||||
dnsResolver: "tcp://1.1.1.1:53",
|
|
||||||
},
|
|
||||||
wantErr: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "invalid resolver port",
|
|
||||||
args: args{
|
|
||||||
dnsResolver: "tcp://127.0.0.1:99999",
|
|
||||||
},
|
|
||||||
wantErr: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "invalid resolver format",
|
|
||||||
args: args{
|
|
||||||
dnsResolver: "foobar",
|
|
||||||
},
|
|
||||||
wantErr: true,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
cfg := &Config{
|
|
||||||
DNSResolver: tt.args.dnsResolver,
|
|
||||||
}
|
|
||||||
client := GetHTTPClient(cfg)
|
|
||||||
_, err := client.Get("https://example.org")
|
|
||||||
if (err != nil) != tt.wantErr {
|
|
||||||
t.Errorf("TestDNSResolverConfig err=%v, wantErr=%v", err, tt.wantErr)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestCanPerformStartTLS(t *testing.T) {
|
func TestCanPerformStartTLS(t *testing.T) {
|
||||||
type args struct {
|
type args struct {
|
||||||
address string
|
address string
|
||||||
@ -221,7 +175,6 @@ func TestHttpClientProvidesOAuth2BearerToken(t *testing.T) {
|
|||||||
}
|
}
|
||||||
mockHttpClient := &http.Client{
|
mockHttpClient := &http.Client{
|
||||||
Transport: test.MockRoundTripper(func(r *http.Request) *http.Response {
|
Transport: test.MockRoundTripper(func(r *http.Request) *http.Response {
|
||||||
|
|
||||||
// if the mock HTTP client tries to get a token from the `token-server`
|
// if the mock HTTP client tries to get a token from the `token-server`
|
||||||
// we provide the expected token response
|
// we provide the expected token response
|
||||||
if r.Host == "token-server.local" {
|
if r.Host == "token-server.local" {
|
||||||
|
@ -4,10 +4,11 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"log"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"regexp"
|
"regexp"
|
||||||
|
"strconv"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"golang.org/x/oauth2"
|
"golang.org/x/oauth2"
|
||||||
@ -20,6 +21,7 @@ const (
|
|||||||
|
|
||||||
var (
|
var (
|
||||||
ErrInvalidDNSResolver = errors.New("invalid DNS resolver specified. Required format is {proto}://{ip}:{port}")
|
ErrInvalidDNSResolver = errors.New("invalid DNS resolver specified. Required format is {proto}://{ip}:{port}")
|
||||||
|
ErrInvalidDNSResolverPort = errors.New("invalid DNS resolver port")
|
||||||
ErrInvalidClientOAuth2Config = errors.New("invalid OAuth2 configuration, all fields are required")
|
ErrInvalidClientOAuth2Config = errors.New("invalid OAuth2 configuration, all fields are required")
|
||||||
|
|
||||||
defaultConfig = Config{
|
defaultConfig = Config{
|
||||||
@ -46,8 +48,8 @@ type Config struct {
|
|||||||
// Timeout for the client
|
// Timeout for the client
|
||||||
Timeout time.Duration `yaml:"timeout"`
|
Timeout time.Duration `yaml:"timeout"`
|
||||||
|
|
||||||
// DNSResolver override for the HTTPClient
|
// DNSResolver override for the HTTP client
|
||||||
// Expected format is {protocol}://{host}:{port}
|
// Expected format is {protocol}://{host}:{port}, e.g. tcp://1.1.1.1:53
|
||||||
DNSResolver string `yaml:"dns-resolver,omitempty"`
|
DNSResolver string `yaml:"dns-resolver,omitempty"`
|
||||||
|
|
||||||
// OAuth2Config is the OAuth2 configuration used for the client.
|
// OAuth2Config is the OAuth2 configuration used for the client.
|
||||||
@ -80,9 +82,9 @@ func (c *Config) ValidateAndSetDefaults() error {
|
|||||||
c.Timeout = 10 * time.Second
|
c.Timeout = 10 * time.Second
|
||||||
}
|
}
|
||||||
if c.HasCustomDNSResolver() {
|
if c.HasCustomDNSResolver() {
|
||||||
_, err := c.ParseDNSResolver()
|
// Validate the DNS resolver now to make sure it will not return an error later.
|
||||||
if err != nil {
|
if _, err := c.parseDNSResolver(); err != nil {
|
||||||
return ErrInvalidDNSResolver
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if c.HasOAuth2Config() && !c.OAuth2Config.isValid() {
|
if c.HasOAuth2Config() && !c.OAuth2Config.isValid() {
|
||||||
@ -91,17 +93,17 @@ func (c *Config) ValidateAndSetDefaults() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Returns true if the DNSResolver is set in the configuration
|
// HasCustomDNSResolver returns whether a custom DNSResolver is configured
|
||||||
func (c *Config) HasCustomDNSResolver() bool {
|
func (c *Config) HasCustomDNSResolver() bool {
|
||||||
return len(c.DNSResolver) > 0
|
return len(c.DNSResolver) > 0
|
||||||
}
|
}
|
||||||
|
|
||||||
// Parses the DNSResolver configuration string into the DNSResolverConfig struct
|
// parseDNSResolver parses the DNS resolver into the DNSResolverConfig struct
|
||||||
func (c *Config) ParseDNSResolver() (DNSResolverConfig, error) {
|
func (c *Config) parseDNSResolver() (*DNSResolverConfig, error) {
|
||||||
re := regexp.MustCompile(`^(?P<proto>(.*))://(?P<host>[A-Za-z0-9\-\.]+):(?P<port>[0-9]+)?(.*)$`)
|
re := regexp.MustCompile(`^(?P<proto>(.*))://(?P<host>[A-Za-z0-9\-\.]+):(?P<port>[0-9]+)?(.*)$`)
|
||||||
matches := re.FindStringSubmatch(c.DNSResolver)
|
matches := re.FindStringSubmatch(c.DNSResolver)
|
||||||
if len(matches) == 0 {
|
if len(matches) == 0 {
|
||||||
return DNSResolverConfig{}, errors.New("ParseError")
|
return nil, ErrInvalidDNSResolver
|
||||||
}
|
}
|
||||||
r := make(map[string]string)
|
r := make(map[string]string)
|
||||||
for i, k := range re.SubexpNames() {
|
for i, k := range re.SubexpNames() {
|
||||||
@ -109,8 +111,14 @@ func (c *Config) ParseDNSResolver() (DNSResolverConfig, error) {
|
|||||||
r[k] = matches[i]
|
r[k] = matches[i]
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
port, err := strconv.Atoi(r["port"])
|
||||||
return DNSResolverConfig{
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if port < 1 || port > 65535 {
|
||||||
|
return nil, ErrInvalidDNSResolverPort
|
||||||
|
}
|
||||||
|
return &DNSResolverConfig{
|
||||||
Protocol: r["proto"],
|
Protocol: r["proto"],
|
||||||
Host: r["host"],
|
Host: r["host"],
|
||||||
Port: r["port"],
|
Port: r["port"],
|
||||||
@ -150,20 +158,25 @@ func (c *Config) getHTTPClient() *http.Client {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
if c.HasCustomDNSResolver() {
|
if c.HasCustomDNSResolver() {
|
||||||
dnsResolver, _ := c.ParseDNSResolver()
|
dnsResolver, err := c.parseDNSResolver()
|
||||||
|
if err != nil {
|
||||||
|
// We're ignoring the error, because it should have been validated on startup ValidateAndSetDefaults.
|
||||||
|
// It shouldn't happen, but if it does, we'll log it... Better safe than sorry ;)
|
||||||
|
log.Println("[client][getHTTPClient] THIS SHOULD NOT HAPPEN. Silently ignoring invalid DNS resolver due to error:", err.Error())
|
||||||
|
} else {
|
||||||
dialer := &net.Dialer{
|
dialer := &net.Dialer{
|
||||||
Resolver: &net.Resolver{
|
Resolver: &net.Resolver{
|
||||||
PreferGo: true,
|
PreferGo: true,
|
||||||
Dial: func(ctx context.Context, network, address string) (net.Conn, error) {
|
Dial: func(ctx context.Context, network, address string) (net.Conn, error) {
|
||||||
d := net.Dialer{}
|
d := net.Dialer{}
|
||||||
return d.DialContext(ctx, dnsResolver.Protocol, fmt.Sprintf("%s:%s", dnsResolver.Host, dnsResolver.Port))
|
return d.DialContext(ctx, dnsResolver.Protocol, dnsResolver.Host+":"+dnsResolver.Port)
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
dialCtx := func(ctx context.Context, network, addr string) (net.Conn, error) {
|
c.httpClient.Transport.(*http.Transport).DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||||
return dialer.DialContext(ctx, network, addr)
|
return dialer.DialContext(ctx, network, addr)
|
||||||
}
|
}
|
||||||
c.httpClient.Transport.(*http.Transport).DialContext = dialCtx
|
}
|
||||||
}
|
}
|
||||||
if c.HasOAuth2Config() {
|
if c.HasOAuth2Config() {
|
||||||
c.httpClient = configureOAuth2(c.httpClient, *c.OAuth2Config)
|
c.httpClient = configureOAuth2(c.httpClient, *c.OAuth2Config)
|
||||||
|
@ -35,3 +35,47 @@ func TestConfig_getHTTPClient(t *testing.T) {
|
|||||||
t.Error("expected Config.IgnoreRedirect set to true to cause the HTTP client's CheckRedirect to return http.ErrUseLastResponse")
|
t.Error("expected Config.IgnoreRedirect set to true to cause the HTTP client's CheckRedirect to return http.ErrUseLastResponse")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestConfig_ValidateAndSetDefaults_withCustomDNSResolver(t *testing.T) {
|
||||||
|
type args struct {
|
||||||
|
dnsResolver string
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
args args
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "with-valid-resolver",
|
||||||
|
args: args{
|
||||||
|
dnsResolver: "tcp://1.1.1.1:53",
|
||||||
|
},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "with-invalid-resolver-port",
|
||||||
|
args: args{
|
||||||
|
dnsResolver: "tcp://127.0.0.1:99999",
|
||||||
|
},
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "with-invalid-resolver-format",
|
||||||
|
args: args{
|
||||||
|
dnsResolver: "foobar",
|
||||||
|
},
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
cfg := &Config{
|
||||||
|
DNSResolver: tt.args.dnsResolver,
|
||||||
|
}
|
||||||
|
err := cfg.ValidateAndSetDefaults()
|
||||||
|
if (err != nil) != tt.wantErr {
|
||||||
|
t.Errorf("ValidateAndSetDefaults() error=%v, wantErr=%v", err, tt.wantErr)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user