feat(client): Added client configuration option for using a custom DNS resolver (#284)

This commit is contained in:
Andre Bindewald
2022-06-13 00:45:08 +02:00
committed by GitHub
parent f23fcbedb8
commit 2cbb35fe3b
3 changed files with 128 additions and 1 deletions

View File

@ -4,7 +4,10 @@ import (
"context"
"crypto/tls"
"errors"
"fmt"
"net"
"net/http"
"regexp"
"time"
"golang.org/x/oauth2"
@ -16,6 +19,7 @@ const (
)
var (
ErrInvalidDNSResolver = errors.New("invalid DNS resolver specified. Required format is {proto}://{ip}:{port}")
ErrInvalidClientOAuth2Config = errors.New("invalid OAuth2 configuration, all fields are required")
defaultConfig = Config{
@ -42,6 +46,10 @@ type Config struct {
// Timeout for the client
Timeout time.Duration `yaml:"timeout"`
// DNSResolver override for the HTTPClient
// Expected format is {protocol}://{host}:{port}
DNSResolver string `yaml:"dns-resolver,omitempty"`
// OAuth2Config is the OAuth2 configuration used for the client.
//
// If non-nil, the http.Client returned by getHTTPClient will automatically retrieve a token if necessary.
@ -51,6 +59,13 @@ type Config struct {
httpClient *http.Client
}
// DNSResolverConfig is the parsed configuration from the DNSResolver config string.
type DNSResolverConfig struct {
Protocol string
Host string
Port string
}
// OAuth2Config is the configuration for the OAuth2 client credentials flow
type OAuth2Config struct {
TokenURL string `yaml:"token-url"` // e.g. https://dev-12345678.okta.com/token
@ -64,12 +79,44 @@ func (c *Config) ValidateAndSetDefaults() error {
if c.Timeout < time.Millisecond {
c.Timeout = 10 * time.Second
}
if c.HasCustomDNSResolver() {
_, err := c.ParseDNSResolver()
if err != nil {
return ErrInvalidDNSResolver
}
}
if c.HasOAuth2Config() && !c.OAuth2Config.isValid() {
return ErrInvalidClientOAuth2Config
}
return nil
}
// Returns true if the DNSResolver is set in the configuration
func (c *Config) HasCustomDNSResolver() bool {
return len(c.DNSResolver) > 0
}
// Parses the DNSResolver configuration string into the DNSResolverConfig struct
func (c *Config) ParseDNSResolver() (DNSResolverConfig, error) {
re := regexp.MustCompile(`^(?P<proto>(.*))://(?P<host>[A-Za-z0-9\-\.]+):(?P<port>[0-9]+)?(.*)$`)
matches := re.FindStringSubmatch(c.DNSResolver)
if len(matches) == 0 {
return DNSResolverConfig{}, errors.New("ParseError")
}
r := make(map[string]string)
for i, k := range re.SubexpNames() {
if i != 0 && k != "" {
r[k] = matches[i]
}
}
return DNSResolverConfig{
Protocol: r["proto"],
Host: r["host"],
Port: r["port"],
}, nil
}
// HasOAuth2Config returns true if the client has OAuth2 configuration parameters
func (c *Config) HasOAuth2Config() bool {
return c.OAuth2Config != nil
@ -102,6 +149,22 @@ func (c *Config) getHTTPClient() *http.Client {
return nil
},
}
if c.HasCustomDNSResolver() {
dnsResolver, _ := c.ParseDNSResolver()
dialer := &net.Dialer{
Resolver: &net.Resolver{
PreferGo: true,
Dial: func(ctx context.Context, network, address string) (net.Conn, error) {
d := net.Dialer{}
return d.DialContext(ctx, dnsResolver.Protocol, fmt.Sprintf("%s:%s", dnsResolver.Host, dnsResolver.Port))
},
},
}
dialCtx := func(ctx context.Context, network, addr string) (net.Conn, error) {
return dialer.DialContext(ctx, network, addr)
}
c.httpClient.Transport.(*http.Transport).DialContext = dialCtx
}
if c.HasOAuth2Config() {
c.httpClient = configureOAuth2(c.httpClient, *c.OAuth2Config)
}