@ -2,10 +2,13 @@ package security
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"log"
|
||||
"net/http"
|
||||
|
||||
g8 "github.com/TwiN/g8/v2"
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/gofiber/fiber/v2/middleware/adaptor"
|
||||
"github.com/gofiber/fiber/v2/middleware/basicauth"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
)
|
||||
|
||||
@ -29,20 +32,20 @@ func (c *Config) IsValid() bool {
|
||||
}
|
||||
|
||||
// RegisterHandlers registers all handlers required based on the security configuration
|
||||
func (c *Config) RegisterHandlers(router *mux.Router) error {
|
||||
func (c *Config) RegisterHandlers(router fiber.Router) error {
|
||||
if c.OIDC != nil {
|
||||
if err := c.OIDC.initialize(); err != nil {
|
||||
return err
|
||||
}
|
||||
router.HandleFunc("/oidc/login", c.OIDC.loginHandler)
|
||||
router.HandleFunc("/authorization-code/callback", c.OIDC.callbackHandler)
|
||||
router.All("/oidc/login", c.OIDC.loginHandler)
|
||||
router.All("/authorization-code/callback", adaptor.HTTPHandlerFunc(c.OIDC.callbackHandler))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ApplySecurityMiddleware applies an authentication middleware to the router passed.
|
||||
// The router passed should be a subrouter in charge of handlers that require authentication.
|
||||
func (c *Config) ApplySecurityMiddleware(api *mux.Router) error {
|
||||
// The router passed should be a sub-router in charge of handlers that require authentication.
|
||||
func (c *Config) ApplySecurityMiddleware(router fiber.Router) error {
|
||||
if c.OIDC != nil {
|
||||
// We're going to use g8 for session handling
|
||||
clientProvider := g8.NewClientProvider(func(token string) *g8.Client {
|
||||
@ -61,7 +64,7 @@ func (c *Config) ApplySecurityMiddleware(api *mux.Router) error {
|
||||
// TODO: g8: Add a way to update cookie after? would need the writer
|
||||
authorizationService := g8.NewAuthorizationService().WithClientProvider(clientProvider)
|
||||
c.gate = g8.New().WithAuthorizationService(authorizationService).WithCustomTokenExtractor(customTokenExtractorFunc)
|
||||
api.Use(c.gate.Protect)
|
||||
router.Use(adaptor.HTTPMiddleware(c.gate.Protect))
|
||||
} else if c.Basic != nil {
|
||||
var decodedBcryptHash []byte
|
||||
if len(c.Basic.PasswordBcryptHashBase64Encoded) > 0 {
|
||||
@ -71,29 +74,35 @@ func (c *Config) ApplySecurityMiddleware(api *mux.Router) error {
|
||||
return err
|
||||
}
|
||||
}
|
||||
api.Use(func(handler http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
usernameEntered, passwordEntered, ok := r.BasicAuth()
|
||||
router.Use(basicauth.New(basicauth.Config{
|
||||
Authorizer: func(username, password string) bool {
|
||||
if len(c.Basic.PasswordBcryptHashBase64Encoded) > 0 {
|
||||
if !ok || usernameEntered != c.Basic.Username || bcrypt.CompareHashAndPassword(decodedBcryptHash, []byte(passwordEntered)) != nil {
|
||||
w.Header().Set("WWW-Authenticate", "Basic")
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
_, _ = w.Write([]byte("Unauthorized"))
|
||||
return
|
||||
if username != c.Basic.Username || bcrypt.CompareHashAndPassword(decodedBcryptHash, []byte(password)) != nil {
|
||||
return false
|
||||
}
|
||||
}
|
||||
handler.ServeHTTP(w, r)
|
||||
})
|
||||
})
|
||||
return true
|
||||
},
|
||||
Unauthorized: func(ctx *fiber.Ctx) error {
|
||||
ctx.Set("WWW-Authenticate", "Basic")
|
||||
return ctx.Status(401).SendString("Unauthorized")
|
||||
},
|
||||
}))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// IsAuthenticated checks whether the user is authenticated
|
||||
// If the Config does not warrant authentication, it will always return true.
|
||||
func (c *Config) IsAuthenticated(r *http.Request) bool {
|
||||
func (c *Config) IsAuthenticated(ctx *fiber.Ctx) bool {
|
||||
if c.gate != nil {
|
||||
token := c.gate.ExtractTokenFromRequest(r)
|
||||
// TODO: Update g8 to support fasthttp natively? (see g8's fasthttp branch)
|
||||
request, err := adaptor.ConvertRequest(ctx, false)
|
||||
if err != nil {
|
||||
log.Printf("[IsAuthenticated] Unexpected error converting request: %v", err)
|
||||
return false
|
||||
}
|
||||
token := c.gate.ExtractTokenFromRequest(request)
|
||||
_, hasSession := sessions.Get(token)
|
||||
return hasSession
|
||||
}
|
||||
|
@ -5,7 +5,7 @@ import (
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"golang.org/x/oauth2"
|
||||
)
|
||||
|
||||
@ -23,83 +23,96 @@ func TestConfig_ApplySecurityMiddleware(t *testing.T) {
|
||||
///////////
|
||||
// BASIC //
|
||||
///////////
|
||||
// Bcrypt
|
||||
c := &Config{Basic: &BasicConfig{
|
||||
Username: "john.doe",
|
||||
PasswordBcryptHashBase64Encoded: "JDJhJDA4JDFoRnpPY1hnaFl1OC9ISlFsa21VS09wOGlPU1ZOTDlHZG1qeTFvb3dIckRBUnlHUmNIRWlT",
|
||||
}}
|
||||
api := mux.NewRouter()
|
||||
api.HandleFunc("/test", func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
t.Run("basic", func(t *testing.T) {
|
||||
// Bcrypt
|
||||
c := &Config{Basic: &BasicConfig{
|
||||
Username: "john.doe",
|
||||
PasswordBcryptHashBase64Encoded: "JDJhJDA4JDFoRnpPY1hnaFl1OC9ISlFsa21VS09wOGlPU1ZOTDlHZG1qeTFvb3dIckRBUnlHUmNIRWlT",
|
||||
}}
|
||||
app := fiber.New()
|
||||
if err := c.ApplySecurityMiddleware(app); err != nil {
|
||||
t.Error("expected no error, got", err)
|
||||
}
|
||||
app.Get("/test", func(c *fiber.Ctx) error {
|
||||
return c.SendStatus(200)
|
||||
})
|
||||
// Try to access the route without basic auth
|
||||
request := httptest.NewRequest("GET", "/test", http.NoBody)
|
||||
response, err := app.Test(request)
|
||||
if err != nil {
|
||||
t.Fatal("expected no error, got", err)
|
||||
}
|
||||
if response.StatusCode != 401 {
|
||||
t.Error("expected code to be 401, but was", response.StatusCode)
|
||||
}
|
||||
// Try again, but with basic auth
|
||||
request = httptest.NewRequest("GET", "/test", http.NoBody)
|
||||
request.SetBasicAuth("john.doe", "hunter2")
|
||||
response, err = app.Test(request)
|
||||
if err != nil {
|
||||
t.Fatal("expected no error, got", err)
|
||||
}
|
||||
if response.StatusCode != 200 {
|
||||
t.Error("expected code to be 200, but was", response.StatusCode)
|
||||
}
|
||||
})
|
||||
if err := c.ApplySecurityMiddleware(api); err != nil {
|
||||
t.Error("expected no error, but was", err)
|
||||
}
|
||||
// Try to access the route without basic auth
|
||||
request, _ := http.NewRequest("GET", "/test", http.NoBody)
|
||||
responseRecorder := httptest.NewRecorder()
|
||||
api.ServeHTTP(responseRecorder, request)
|
||||
if responseRecorder.Code != http.StatusUnauthorized {
|
||||
t.Error("expected code to be 401, but was", responseRecorder.Code)
|
||||
}
|
||||
// Try again, but with basic auth
|
||||
request, _ = http.NewRequest("GET", "/test", http.NoBody)
|
||||
responseRecorder = httptest.NewRecorder()
|
||||
request.SetBasicAuth("john.doe", "hunter2")
|
||||
api.ServeHTTP(responseRecorder, request)
|
||||
if responseRecorder.Code != http.StatusOK {
|
||||
t.Error("expected code to be 200, but was", responseRecorder.Code)
|
||||
}
|
||||
//////////
|
||||
// OIDC //
|
||||
//////////
|
||||
api = mux.NewRouter()
|
||||
api.HandleFunc("/test", func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
t.Run("oidc", func(t *testing.T) {
|
||||
c := &Config{OIDC: &OIDCConfig{
|
||||
IssuerURL: "https://sso.gatus.io/",
|
||||
RedirectURL: "http://localhost:80/authorization-code/callback",
|
||||
Scopes: []string{"openid"},
|
||||
AllowedSubjects: []string{"user1@example.com"},
|
||||
oauth2Config: oauth2.Config{},
|
||||
verifier: nil,
|
||||
}}
|
||||
app := fiber.New()
|
||||
if err := c.ApplySecurityMiddleware(app); err != nil {
|
||||
t.Error("expected no error, got", err)
|
||||
}
|
||||
app.Get("/test", func(c *fiber.Ctx) error {
|
||||
return c.SendStatus(200)
|
||||
})
|
||||
// Try without any session cookie
|
||||
request := httptest.NewRequest("GET", "/test", http.NoBody)
|
||||
response, err := app.Test(request)
|
||||
if err != nil {
|
||||
t.Fatal("expected no error, got", err)
|
||||
}
|
||||
if response.StatusCode != 401 {
|
||||
t.Error("expected code to be 401, but was", response.StatusCode)
|
||||
}
|
||||
// Try with a session cookie
|
||||
request = httptest.NewRequest("GET", "/test", http.NoBody)
|
||||
request.AddCookie(&http.Cookie{Name: "session", Value: "123"})
|
||||
response, err = app.Test(request)
|
||||
if err != nil {
|
||||
t.Fatal("expected no error, got", err)
|
||||
}
|
||||
if response.StatusCode != 401 {
|
||||
t.Error("expected code to be 401, but was", response.StatusCode)
|
||||
}
|
||||
})
|
||||
c.OIDC = &OIDCConfig{
|
||||
IssuerURL: "https://sso.gatus.io/",
|
||||
RedirectURL: "http://localhost:80/authorization-code/callback",
|
||||
Scopes: []string{"openid"},
|
||||
AllowedSubjects: []string{"user1@example.com"},
|
||||
oauth2Config: oauth2.Config{},
|
||||
verifier: nil,
|
||||
}
|
||||
c.Basic = nil
|
||||
if err := c.ApplySecurityMiddleware(api); err != nil {
|
||||
t.Error("expected no error, but was", err)
|
||||
}
|
||||
// Try without any session cookie
|
||||
request, _ = http.NewRequest("GET", "/test", http.NoBody)
|
||||
responseRecorder = httptest.NewRecorder()
|
||||
api.ServeHTTP(responseRecorder, request)
|
||||
if responseRecorder.Code != http.StatusUnauthorized {
|
||||
t.Error("expected code to be 401, but was", responseRecorder.Code)
|
||||
}
|
||||
// Try with a session cookie
|
||||
request, _ = http.NewRequest("GET", "/test", http.NoBody)
|
||||
request.AddCookie(&http.Cookie{Name: "session", Value: "123"})
|
||||
responseRecorder = httptest.NewRecorder()
|
||||
api.ServeHTTP(responseRecorder, request)
|
||||
if responseRecorder.Code != http.StatusUnauthorized {
|
||||
t.Error("expected code to be 401, but was", responseRecorder.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfig_RegisterHandlers(t *testing.T) {
|
||||
c := &Config{}
|
||||
router := mux.NewRouter()
|
||||
c.RegisterHandlers(router)
|
||||
app := fiber.New()
|
||||
c.RegisterHandlers(app)
|
||||
// Try to access the OIDC handler. This should fail, because the security config doesn't have OIDC
|
||||
request, _ := http.NewRequest("GET", "/oidc/login", http.NoBody)
|
||||
responseRecorder := httptest.NewRecorder()
|
||||
router.ServeHTTP(responseRecorder, request)
|
||||
if responseRecorder.Code != http.StatusNotFound {
|
||||
t.Error("expected code to be 404, but was", responseRecorder.Code)
|
||||
request := httptest.NewRequest("GET", "/oidc/login", http.NoBody)
|
||||
response, err := app.Test(request)
|
||||
if err != nil {
|
||||
t.Fatal("expected no error, got", err)
|
||||
}
|
||||
if response.StatusCode != 404 {
|
||||
t.Error("expected code to be 404, but was", response.StatusCode)
|
||||
}
|
||||
// Set an empty OIDC config. This should fail, because the IssuerURL is required.
|
||||
c.OIDC = &OIDCConfig{}
|
||||
if err := c.RegisterHandlers(router); err == nil {
|
||||
if err := c.RegisterHandlers(app); err == nil {
|
||||
t.Fatal("expected an error, but got none")
|
||||
}
|
||||
// Set the OIDC config and try again
|
||||
@ -109,13 +122,15 @@ func TestConfig_RegisterHandlers(t *testing.T) {
|
||||
Scopes: []string{"openid"},
|
||||
AllowedSubjects: []string{"user1@example.com"},
|
||||
}
|
||||
if err := c.RegisterHandlers(router); err != nil {
|
||||
if err := c.RegisterHandlers(app); err != nil {
|
||||
t.Fatal("expected no error, but got", err)
|
||||
}
|
||||
request, _ = http.NewRequest("GET", "/oidc/login", http.NoBody)
|
||||
responseRecorder = httptest.NewRecorder()
|
||||
router.ServeHTTP(responseRecorder, request)
|
||||
if responseRecorder.Code != http.StatusFound {
|
||||
t.Error("expected code to be 302, but was", responseRecorder.Code)
|
||||
request = httptest.NewRequest("GET", "/oidc/login", http.NoBody)
|
||||
response, err = app.Test(request)
|
||||
if err != nil {
|
||||
t.Fatal("expected no error, got", err)
|
||||
}
|
||||
if response.StatusCode != 302 {
|
||||
t.Error("expected code to be 302, but was", response.StatusCode)
|
||||
}
|
||||
}
|
||||
|
@ -8,6 +8,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/coreos/go-oidc/v3/oidc"
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/google/uuid"
|
||||
"golang.org/x/oauth2"
|
||||
)
|
||||
@ -47,28 +48,28 @@ func (c *OIDCConfig) initialize() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *OIDCConfig) loginHandler(w http.ResponseWriter, r *http.Request) {
|
||||
func (c *OIDCConfig) loginHandler(ctx *fiber.Ctx) error {
|
||||
state, nonce := uuid.NewString(), uuid.NewString()
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
ctx.Cookie(&fiber.Cookie{
|
||||
Name: cookieNameState,
|
||||
Value: state,
|
||||
Path: "/",
|
||||
MaxAge: int(time.Hour.Seconds()),
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
HttpOnly: true,
|
||||
SameSite: "lax",
|
||||
HTTPOnly: true,
|
||||
})
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
ctx.Cookie(&fiber.Cookie{
|
||||
Name: cookieNameNonce,
|
||||
Value: nonce,
|
||||
Path: "/",
|
||||
MaxAge: int(time.Hour.Seconds()),
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
HttpOnly: true,
|
||||
SameSite: "lax",
|
||||
HTTPOnly: true,
|
||||
})
|
||||
http.Redirect(w, r, c.oauth2Config.AuthCodeURL(state, oidc.Nonce(nonce)), http.StatusFound)
|
||||
return ctx.Redirect(c.oauth2Config.AuthCodeURL(state, oidc.Nonce(nonce)), http.StatusFound)
|
||||
}
|
||||
|
||||
func (c *OIDCConfig) callbackHandler(w http.ResponseWriter, r *http.Request) {
|
||||
func (c *OIDCConfig) callbackHandler(w http.ResponseWriter, r *http.Request) { // TODO: Migrate to a native fiber handler
|
||||
// Check if there's an error
|
||||
if len(r.URL.Query().Get("error")) > 0 {
|
||||
http.Error(w, r.URL.Query().Get("error")+": "+r.URL.Query().Get("error_description"), http.StatusBadRequest)
|
||||
|
Reference in New Issue
Block a user