#120: Add support for StartTLS protocol
* add starttls * remove starttls from default config Co-authored-by: Gopher Johns <gopher.johns28@gmail.com>
This commit is contained in:
		| @ -2,10 +2,14 @@ package client | |||||||
|  |  | ||||||
| import ( | import ( | ||||||
| 	"crypto/tls" | 	"crypto/tls" | ||||||
|  | 	"crypto/x509" | ||||||
|  | 	"fmt" | ||||||
| 	"net" | 	"net" | ||||||
| 	"net/http" | 	"net/http" | ||||||
|  | 	"net/smtp" | ||||||
| 	"os" | 	"os" | ||||||
| 	"strconv" | 	"strconv" | ||||||
|  | 	"strings" | ||||||
| 	"time" | 	"time" | ||||||
|  |  | ||||||
| 	"github.com/go-ping/ping" | 	"github.com/go-ping/ping" | ||||||
| @ -74,6 +78,36 @@ func CanCreateTCPConnection(address string) bool { | |||||||
| 	return true | 	return true | ||||||
| } | } | ||||||
|  |  | ||||||
|  | func CanPerformStartTls(address string, insecure bool) (connected bool, certificate *x509.Certificate, err error) { | ||||||
|  | 	tokens := strings.Split(address, ":") | ||||||
|  | 	if len(tokens) != 2 { | ||||||
|  | 		err = fmt.Errorf("invalid address for starttls, must HOST:PORT") | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 	tlsconfig := &tls.Config{ | ||||||
|  | 		InsecureSkipVerify: insecure, | ||||||
|  | 		ServerName:         tokens[0], | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	c, err := smtp.Dial(address) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	err = c.StartTLS(tlsconfig) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 	if state, ok := c.TLSConnectionState(); ok { | ||||||
|  | 		certificate = state.PeerCertificates[0] | ||||||
|  | 	} else { | ||||||
|  | 		err = fmt.Errorf("could not get TLS connection state") | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 	connected = true | ||||||
|  | 	return | ||||||
|  | } | ||||||
|  |  | ||||||
| // Ping checks if an address can be pinged and returns the round-trip time if the address can be pinged | // Ping checks if an address can be pinged and returns the round-trip time if the address can be pinged | ||||||
| // | // | ||||||
| // Note that this function takes at least 100ms, even if the address is 127.0.0.1 | // Note that this function takes at least 100ms, even if the address is 127.0.0.1 | ||||||
|  | |||||||
| @ -1,6 +1,7 @@ | |||||||
| package client | package client | ||||||
|  |  | ||||||
| import ( | import ( | ||||||
|  | 	"crypto/x509" | ||||||
| 	"testing" | 	"testing" | ||||||
| 	"time" | 	"time" | ||||||
| ) | ) | ||||||
| @ -49,3 +50,56 @@ func TestPing(t *testing.T) { | |||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
|  | func TestCanPerformStartTls(t *testing.T) { | ||||||
|  | 	type args struct { | ||||||
|  | 		address  string | ||||||
|  | 		insecure bool | ||||||
|  | 	} | ||||||
|  | 	tests := []struct { | ||||||
|  | 		name            string | ||||||
|  | 		args            args | ||||||
|  | 		wantConnected   bool | ||||||
|  | 		wantCertificate *x509.Certificate | ||||||
|  | 		wantErr         bool | ||||||
|  | 	}{ | ||||||
|  | 		{ | ||||||
|  | 			name: "invalid address", | ||||||
|  | 			args: args{ | ||||||
|  | 				address: "test", | ||||||
|  | 			}, | ||||||
|  | 			wantConnected:   false, | ||||||
|  | 			wantCertificate: nil, | ||||||
|  | 			wantErr:         true, | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			name: "error dial", | ||||||
|  | 			args: args{ | ||||||
|  | 				address: "test:1234", | ||||||
|  | 			}, | ||||||
|  | 			wantConnected:   false, | ||||||
|  | 			wantCertificate: nil, | ||||||
|  | 			wantErr:         true, | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			name: "valid starttls", | ||||||
|  | 			args: args{ | ||||||
|  | 				address: "smtp.gmail.com:587", | ||||||
|  | 			}, | ||||||
|  | 			wantConnected: true, | ||||||
|  | 			wantErr:       false, | ||||||
|  | 		}, | ||||||
|  | 	} | ||||||
|  | 	for _, tt := range tests { | ||||||
|  | 		t.Run(tt.name, func(t *testing.T) { | ||||||
|  | 			gotConnected, _, err := CanPerformStartTls(tt.args.address, tt.args.insecure) | ||||||
|  | 			if (err != nil) != tt.wantErr { | ||||||
|  | 				t.Errorf("CanPerformStartTls() error = %v, wantErr %v", err, tt.wantErr) | ||||||
|  | 				return | ||||||
|  | 			} | ||||||
|  | 			if gotConnected != tt.wantConnected { | ||||||
|  | 				t.Errorf("CanPerformStartTls() gotConnected = %v, want %v", gotConnected, tt.wantConnected) | ||||||
|  | 			} | ||||||
|  | 		}) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | |||||||
| @ -2,6 +2,7 @@ package core | |||||||
|  |  | ||||||
| import ( | import ( | ||||||
| 	"bytes" | 	"bytes" | ||||||
|  | 	"crypto/x509" | ||||||
| 	"encoding/json" | 	"encoding/json" | ||||||
| 	"errors" | 	"errors" | ||||||
| 	"io/ioutil" | 	"io/ioutil" | ||||||
| @ -178,10 +179,12 @@ func (service *Service) call(result *Result) { | |||||||
| 	var request *http.Request | 	var request *http.Request | ||||||
| 	var response *http.Response | 	var response *http.Response | ||||||
| 	var err error | 	var err error | ||||||
|  | 	var certificate *x509.Certificate | ||||||
| 	isServiceDNS := service.DNS != nil | 	isServiceDNS := service.DNS != nil | ||||||
| 	isServiceTCP := strings.HasPrefix(service.URL, "tcp://") | 	isServiceTCP := strings.HasPrefix(service.URL, "tcp://") | ||||||
| 	isServiceICMP := strings.HasPrefix(service.URL, "icmp://") | 	isServiceICMP := strings.HasPrefix(service.URL, "icmp://") | ||||||
| 	isServiceHTTP := !isServiceDNS && !isServiceTCP && !isServiceICMP | 	isServiceStartTLS := strings.HasPrefix(service.URL, "starttls://") | ||||||
|  | 	isServiceHTTP := !isServiceDNS && !isServiceTCP && !isServiceICMP && !isServiceStartTLS | ||||||
| 	if isServiceHTTP { | 	if isServiceHTTP { | ||||||
| 		request = service.buildHTTPRequest() | 		request = service.buildHTTPRequest() | ||||||
| 	} | 	} | ||||||
| @ -189,6 +192,14 @@ func (service *Service) call(result *Result) { | |||||||
| 	if isServiceDNS { | 	if isServiceDNS { | ||||||
| 		service.DNS.query(service.URL, result) | 		service.DNS.query(service.URL, result) | ||||||
| 		result.Duration = time.Since(startTime) | 		result.Duration = time.Since(startTime) | ||||||
|  | 	} else if isServiceStartTLS { | ||||||
|  | 		result.Connected, certificate, err = client.CanPerformStartTls(strings.TrimPrefix(service.URL, "starttls://"), service.Insecure) | ||||||
|  | 		if err != nil { | ||||||
|  | 			result.Errors = append(result.Errors, err.Error()) | ||||||
|  | 			return | ||||||
|  | 		} | ||||||
|  | 		result.Duration = time.Since(startTime) | ||||||
|  | 		result.CertificateExpiration = time.Until(certificate.NotAfter) | ||||||
| 	} else if isServiceTCP { | 	} else if isServiceTCP { | ||||||
| 		result.Connected = client.CanCreateTCPConnection(strings.TrimPrefix(service.URL, "tcp://")) | 		result.Connected = client.CanCreateTCPConnection(strings.TrimPrefix(service.URL, "tcp://")) | ||||||
| 		result.Duration = time.Since(startTime) | 		result.Duration = time.Since(startTime) | ||||||
| @ -203,7 +214,7 @@ func (service *Service) call(result *Result) { | |||||||
| 		} | 		} | ||||||
| 		defer response.Body.Close() | 		defer response.Body.Close() | ||||||
| 		if response.TLS != nil && len(response.TLS.PeerCertificates) > 0 { | 		if response.TLS != nil && len(response.TLS.PeerCertificates) > 0 { | ||||||
| 			certificate := response.TLS.PeerCertificates[0] | 			certificate = response.TLS.PeerCertificates[0] | ||||||
| 			result.CertificateExpiration = time.Until(certificate.NotAfter) | 			result.CertificateExpiration = time.Until(certificate.NotAfter) | ||||||
| 		} | 		} | ||||||
| 		result.HTTPStatus = response.StatusCode | 		result.HTTPStatus = response.StatusCode | ||||||
|  | |||||||
		Reference in New Issue
	
	Block a user