diff --git a/client/client.go b/client/client.go index 931d5e2e..a07b675a 100644 --- a/client/client.go +++ b/client/client.go @@ -2,10 +2,14 @@ package client import ( "crypto/tls" + "crypto/x509" + "fmt" "net" "net/http" + "net/smtp" "os" "strconv" + "strings" "time" "github.com/go-ping/ping" @@ -74,6 +78,36 @@ func CanCreateTCPConnection(address string) bool { return true } +func CanPerformStartTls(address string, insecure bool) (connected bool, certificate *x509.Certificate, err error) { + tokens := strings.Split(address, ":") + if len(tokens) != 2 { + err = fmt.Errorf("invalid address for starttls, must HOST:PORT") + return + } + tlsconfig := &tls.Config{ + InsecureSkipVerify: insecure, + ServerName: tokens[0], + } + + c, err := smtp.Dial(address) + if err != nil { + return + } + + err = c.StartTLS(tlsconfig) + if err != nil { + return + } + if state, ok := c.TLSConnectionState(); ok { + certificate = state.PeerCertificates[0] + } else { + err = fmt.Errorf("could not get TLS connection state") + return + } + connected = true + return +} + // Ping checks if an address can be pinged and returns the round-trip time if the address can be pinged // // Note that this function takes at least 100ms, even if the address is 127.0.0.1 diff --git a/client/client_test.go b/client/client_test.go index d05705ac..ecbc7150 100644 --- a/client/client_test.go +++ b/client/client_test.go @@ -1,6 +1,7 @@ package client import ( + "crypto/x509" "testing" "time" ) @@ -49,3 +50,56 @@ func TestPing(t *testing.T) { } } } + +func TestCanPerformStartTls(t *testing.T) { + type args struct { + address string + insecure bool + } + tests := []struct { + name string + args args + wantConnected bool + wantCertificate *x509.Certificate + wantErr bool + }{ + { + name: "invalid address", + args: args{ + address: "test", + }, + wantConnected: false, + wantCertificate: nil, + wantErr: true, + }, + { + name: "error dial", + args: args{ + address: "test:1234", + }, + wantConnected: false, + wantCertificate: nil, + wantErr: true, + }, + { + name: "valid starttls", + args: args{ + address: "smtp.gmail.com:587", + }, + wantConnected: true, + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gotConnected, _, err := CanPerformStartTls(tt.args.address, tt.args.insecure) + if (err != nil) != tt.wantErr { + t.Errorf("CanPerformStartTls() error = %v, wantErr %v", err, tt.wantErr) + return + } + if gotConnected != tt.wantConnected { + t.Errorf("CanPerformStartTls() gotConnected = %v, want %v", gotConnected, tt.wantConnected) + } + }) + } +} diff --git a/core/service.go b/core/service.go index b974ee14..5e7fda20 100644 --- a/core/service.go +++ b/core/service.go @@ -2,6 +2,7 @@ package core import ( "bytes" + "crypto/x509" "encoding/json" "errors" "io/ioutil" @@ -178,10 +179,12 @@ func (service *Service) call(result *Result) { var request *http.Request var response *http.Response var err error + var certificate *x509.Certificate isServiceDNS := service.DNS != nil isServiceTCP := strings.HasPrefix(service.URL, "tcp://") isServiceICMP := strings.HasPrefix(service.URL, "icmp://") - isServiceHTTP := !isServiceDNS && !isServiceTCP && !isServiceICMP + isServiceStartTLS := strings.HasPrefix(service.URL, "starttls://") + isServiceHTTP := !isServiceDNS && !isServiceTCP && !isServiceICMP && !isServiceStartTLS if isServiceHTTP { request = service.buildHTTPRequest() } @@ -189,6 +192,14 @@ func (service *Service) call(result *Result) { if isServiceDNS { service.DNS.query(service.URL, result) result.Duration = time.Since(startTime) + } else if isServiceStartTLS { + result.Connected, certificate, err = client.CanPerformStartTls(strings.TrimPrefix(service.URL, "starttls://"), service.Insecure) + if err != nil { + result.Errors = append(result.Errors, err.Error()) + return + } + result.Duration = time.Since(startTime) + result.CertificateExpiration = time.Until(certificate.NotAfter) } else if isServiceTCP { result.Connected = client.CanCreateTCPConnection(strings.TrimPrefix(service.URL, "tcp://")) result.Duration = time.Since(startTime) @@ -203,7 +214,7 @@ func (service *Service) call(result *Result) { } defer response.Body.Close() if response.TLS != nil && len(response.TLS.PeerCertificates) > 0 { - certificate := response.TLS.PeerCertificates[0] + certificate = response.TLS.PeerCertificates[0] result.CertificateExpiration = time.Until(certificate.NotAfter) } result.HTTPStatus = response.StatusCode