feat: add oauth middleware

This commit is contained in:
2024-11-07 09:45:09 +01:00
parent 59c0e59529
commit 946c40fda0
9 changed files with 246 additions and 4 deletions

2
go.mod
View File

@@ -20,8 +20,10 @@ require (
github.com/go-redis/redis v6.15.9+incompatible // indirect
github.com/go-redis/redis_rate v6.5.0+incompatible // indirect
github.com/inconshreveable/mousetrap v1.1.0 // indirect
github.com/jinzhu/copier v0.4.0 // indirect
github.com/sirupsen/logrus v1.9.3 // indirect
github.com/spf13/pflag v1.0.5 // indirect
golang.org/x/oauth2 v0.23.0 // indirect
golang.org/x/text v0.14.0 // indirect
golang.org/x/time v0.7.0 // indirect

6
go.sum
View File

@@ -17,8 +17,12 @@ github.com/jedib0t/go-pretty v4.3.0+incompatible h1:CGs8AVhEKg/n9YbUenWmNStRW2PH
github.com/jedib0t/go-pretty v4.3.0+incompatible/go.mod h1:XemHduiw8R651AF9Pt4FwCTKeG3oo7hrHJAoznj9nag=
github.com/jedib0t/go-pretty/v6 v6.6.1 h1:iJ65Xjb680rHcikRj6DSIbzCex2huitmc7bDtxYVWyc=
github.com/jedib0t/go-pretty/v6 v6.6.1/go.mod h1:zbn98qrYlh95FIhwwsbIip0LYpwSG8SUOScs+v9/t0E=
github.com/jinzhu/copier v0.4.0 h1:w3ciUoD19shMCRargcpm0cm91ytaBhDvuRpz1ODO/U8=
github.com/jinzhu/copier v0.4.0/go.mod h1:DfbEm0FYsaqBcKcFuvmOZb218JkPGtvSHsKg8S8hyyg=
github.com/mattn/go-runewidth v0.0.15 h1:UNAjwbU9l54TA3KzvqLGxwWjHmMgBUVhBiTjelZgg3U=
github.com/mattn/go-runewidth v0.0.15/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w=
github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY=
github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/rivo/uniseg v0.2.0 h1:S1pD9weZBuJdFmowNwbpi7BJ8TNftyUImj/0WQi72jY=
github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
@@ -31,6 +35,8 @@ github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA=
github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
golang.org/x/oauth2 v0.23.0 h1:PbgcYx2W7i4LvjJWEbf0ngHV6qJYr86PkAV3bXdLEbs=
golang.org/x/oauth2 v0.23.0/go.mod h1:XYTD2NtWslqkgxebSiOHnXEap4TF09sJSc7H1sXbhtI=
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.17.0 h1:25cE3gD+tdBA7lp7QfhuV+rJiE9YXTcS3VG1SqssI/Y=
golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=

View File

@@ -17,9 +17,11 @@ limitations under the License.
*/
import (
"fmt"
"github.com/jkaninda/goma-gateway/internal/middleware"
"github.com/jkaninda/goma-gateway/pkg/logger"
"github.com/jkaninda/goma-gateway/util"
"github.com/spf13/cobra"
"golang.org/x/oauth2"
"gopkg.in/yaml.v3"
"os"
)
@@ -175,6 +177,25 @@ func initConfig(configFile string) {
"/actuator/*",
},
},
{
Name: "oauth",
Type: OAuth,
Paths: []string{
"/protected",
"/example-of-oauth",
},
Rule: OauthRulerMiddleware{
ClientID: "",
ClientSecret: "",
RedirectURL: "",
Scopes: []string{"user"},
Endpoint: OauthEndpoint{
AuthURL: "https://accounts.google.com/o/oauth2/auth",
TokenURL: "https://oauth2.googleapis.com/token",
},
State: "randomStateString",
},
},
},
}
yamlData, err := yaml.Marshal(&conf)
@@ -250,3 +271,51 @@ func getBasicAuthMiddleware(input interface{}) (BasicRuleMiddleware, error) {
}
return *basicAuth, nil
}
// oAuthMiddleware returns OauthRulerMiddleware, error
func oAuthMiddleware(input interface{}) (OauthRulerMiddleware, error) {
oauthRuler := new(OauthRulerMiddleware)
var bytes []byte
bytes, err := yaml.Marshal(input)
if err != nil {
return OauthRulerMiddleware{}, fmt.Errorf("error parsing yaml: %v", err)
}
err = yaml.Unmarshal(bytes, oauthRuler)
if err != nil {
return OauthRulerMiddleware{}, fmt.Errorf("error parsing yaml: %v", err)
}
if oauthRuler.ClientID == "" || oauthRuler.ClientSecret == "" || oauthRuler.RedirectURL == "" {
return OauthRulerMiddleware{}, fmt.Errorf("error parsing yaml: empty clientId/secretId in %s middleware", oauthRuler)
}
return *oauthRuler, nil
}
func oauth2Config(oauth OauthRulerMiddleware) *oauth2.Config {
return &oauth2.Config{
ClientID: oauth.ClientID,
ClientSecret: oauth.ClientSecret,
RedirectURL: oauth.RedirectURL,
Scopes: oauth.Scopes,
Endpoint: oauth2.Endpoint{
AuthURL: oauth.Endpoint.AuthURL,
TokenURL: oauth.Endpoint.TokenURL,
DeviceAuthURL: oauth.Endpoint.DeviceAuthURL,
},
}
}
func oauthRulerMiddleware(oauth middleware.Oauth) *OauthRulerMiddleware {
return &OauthRulerMiddleware{
ClientID: oauth.ClientID,
ClientSecret: oauth.ClientSecret,
RedirectURL: oauth.RedirectURL,
State: oauth.State,
Scopes: oauth.Scopes,
Endpoint: OauthEndpoint{
AuthURL: oauth.Endpoint.AuthURL,
TokenURL: oauth.Endpoint.TokenURL,
DeviceAuthURL: oauth.Endpoint.DeviceAuthURL,
},
}
}

View File

@@ -16,6 +16,7 @@ See the License for the specific language governing permissions and
limitations under the License.
*/
import (
"context"
"encoding/json"
"github.com/gorilla/mux"
"github.com/jkaninda/goma-gateway/pkg/logger"
@@ -130,3 +131,30 @@ func allowedOrigin(origins []string, origin string) bool {
return false
}
func (oauth OauthRulerMiddleware) callbackHandler(w http.ResponseWriter, r *http.Request) {
oauthConfig := oauth2Config(oauth)
logger.Info("URL State: %s", r.URL.Query().Get("state"))
// Verify the state to protect against CSRF
if r.URL.Query().Get("state") != oauth.State {
http.Error(w, "Invalid state", http.StatusBadRequest)
return
}
// Exchange the authorization code for an access token
code := r.URL.Query().Get("code")
token, err := oauthConfig.Exchange(context.Background(), code)
if err != nil {
http.Error(w, "Failed to exchange token: "+err.Error(), http.StatusInternalServerError)
return
}
// Save token to a cookie for simplicity
http.SetCookie(w, &http.Cookie{
Name: "oauth-token",
Value: token.AccessToken,
Path: oauth.CookiePath,
})
// Redirect to the home page or another protected route
http.Redirect(w, r, oauth.RedirectPath, http.StatusSeeOther)
}

View File

@@ -0,0 +1,56 @@
/*
* Copyright 2024 Jonas Kaninda
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package middleware
import (
"github.com/jkaninda/goma-gateway/pkg/logger"
"golang.org/x/oauth2"
"net/http"
)
func oauth2Config(oauth Oauth) *oauth2.Config {
return &oauth2.Config{
ClientID: oauth.ClientID,
ClientSecret: oauth.ClientSecret,
RedirectURL: oauth.RedirectURL,
Scopes: oauth.Scopes,
Endpoint: oauth2.Endpoint{
AuthURL: oauth.Endpoint.AuthURL,
TokenURL: oauth.Endpoint.TokenURL,
DeviceAuthURL: oauth.Endpoint.DeviceAuthURL,
},
}
}
func (oauth Oauth) AuthMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
logger.Info("%s: %s Oauth", getRealIP(r), r.URL.Path)
oauthConfig := oauth2Config(oauth)
// Check if the user is authenticated
_, err := r.Cookie("oauth-token")
if err != nil {
// If no token, redirect to OAuth provider
url := oauthConfig.AuthCodeURL(oauth.State)
http.Redirect(w, r, url, http.StatusTemporaryRedirect)
return
}
//TODO: Check if the token stored in the cookie is valid
// Token exists, proceed with request
next.ServeHTTP(w, r)
})
}

View File

@@ -107,3 +107,24 @@ type responseRecorder struct {
statusCode int
body *bytes.Buffer
}
type Oauth struct {
// ClientID is the application's ID.
ClientID string
// ClientSecret is the application's secret.
ClientSecret string
// Endpoint contains the resource server's token endpoint
Endpoint OauthEndpoint
// RedirectURL is the URL to redirect users going through
// the OAuth flow, after the resource owner's URLs.
RedirectURL string
// Scope specifies optional requested permissions.
Scopes []string
// contains filtered or unexported fields
State string
Origins []string
}
type OauthEndpoint struct {
AuthURL string
TokenURL string
DeviceAuthURL string
}

View File

@@ -88,6 +88,7 @@ func (gatewayServer GatewayServer) Initialize() *mux.Router {
cors: route.Cors,
}
secureRouter := r.PathPrefix(util.ParseRoutePath(route.Path, midPath)).Subrouter()
//callBackRouter := r.PathPrefix(util.ParseRoutePath(route.Path, "/callback")).Subrouter()
//Check Authentication middleware
switch rMiddleware.Type {
case BasicAuth:
@@ -126,9 +127,40 @@ func (gatewayServer GatewayServer) Initialize() *mux.Router {
secureRouter.PathPrefix("").Handler(proxyRoute.ProxyHandler()) // Proxy handler
}
case "OAuth":
logger.Error("OAuth is not yet implemented")
logger.Info("Auth middleware ignored")
case OAuth, "openid":
oauth, err := oAuthMiddleware(rMiddleware.Rule)
if err != nil {
logger.Error("Error: %s", err.Error())
} else {
amw := middleware.Oauth{
ClientID: oauth.ClientID,
ClientSecret: oauth.ClientSecret,
RedirectURL: oauth.RedirectURL + route.Path,
Scopes: oauth.Scopes,
Endpoint: middleware.OauthEndpoint{
AuthURL: oauth.Endpoint.AuthURL,
TokenURL: oauth.Endpoint.TokenURL,
DeviceAuthURL: oauth.Endpoint.DeviceAuthURL,
},
State: oauth.State,
Origins: gateway.Cors.Origins,
}
oauthRuler := oauthRulerMiddleware(amw)
// Check if a cookie path is defined
if oauthRuler.CookiePath == "" {
oauthRuler.CookiePath = route.Path
}
// Check if a RedirectPath is defined
if oauthRuler.RedirectPath == "" {
oauthRuler.RedirectPath = util.ParseRoutePath(route.Path, midPath)
}
secureRouter.Use(amw.AuthMiddleware)
secureRouter.Use(CORSHandler(route.Cors))
secureRouter.PathPrefix("/").Handler(proxyRoute.ProxyHandler()) // Proxy handler
secureRouter.PathPrefix("").Handler(proxyRoute.ProxyHandler()) // Proxy handler
// Callback route
r.HandleFunc("/callback"+route.Path, oauthRuler.callbackHandler).Methods("GET")
}
default:
if !doesExist(rMiddleware.Type) {
logger.Error("Unknown middleware type %s", rMiddleware.Type)

View File

@@ -72,6 +72,34 @@ type JWTRuleMiddleware struct {
//e.g: Header X-Auth-UserId to query userId
Params map[string]string `yaml:"params"`
}
type OauthRulerMiddleware struct {
// ClientID is the application's ID.
ClientID string `yaml:"clientId"`
// ClientSecret is the application's secret.
ClientSecret string `yaml:"clientSecret"`
// Endpoint contains the resource server's token endpoint
Endpoint OauthEndpoint `yaml:"endpoint"`
// RedirectURL is the URL to redirect users going through
// the OAuth flow, after the resource owner's URLs.
RedirectURL string `yaml:"redirectUrl"`
// RedirectPath is the PATH to redirect users after authentication, e.g: /my-protected-path/dashboard
RedirectPath string `yaml:"redirectPath"`
//CookiePath e.g: /my-protected-path or / || by default is applied on a route path
CookiePath string `yaml:"cookiePath"`
// Scope specifies optional requested permissions.
Scopes []string `yaml:"scopes"`
// contains filtered or unexported fields
State string `yaml:"state"`
}
type OauthEndpoint struct {
AuthURL string `yaml:"authUrl"`
TokenURL string `yaml:"tokenUrl"`
DeviceAuthURL string `yaml:"deviceAuthUrl"`
}
type RateLimiter struct {
// ipBased, tokenBased
Type string `yaml:"type"`

View File

@@ -7,4 +7,4 @@ const gatewayName = "Goma Gateway"
const AccessMiddleware = "access" // access middleware
const BasicAuth = "basic" // basic authentication middleware
const JWTAuth = "jwt" // JWT authentication middleware
const OAuth = "OAuth" // OAuth authentication middleware
const OAuth = "oauth" // OAuth authentication middleware