From bd2089530668f85f8a9bdee22a191b7c8fa06e81 Mon Sep 17 00:00:00 2001 From: Jonas Kaninda Date: Fri, 8 Nov 2024 12:03:52 +0100 Subject: [PATCH] feat: add oauth token validity verification --- go.mod | 20 ++--- go.sum | 43 ++++------ internal/config.go | 104 ++++++++++++++++++------ internal/handler.go | 34 +++++--- internal/helpers.go | 42 ++++++++++ internal/middleware/config.go | 59 ++++++++++++++ internal/middleware/oauth-middleware.go | 70 +++++++++++----- internal/middleware/types.go | 12 +-- internal/proxy.go | 9 +- internal/route.go | 23 ++++-- internal/types.go | 20 +++-- util/helpers.go | 9 ++ 12 files changed, 326 insertions(+), 119 deletions(-) create mode 100644 internal/middleware/config.go diff --git a/go.mod b/go.mod index b188526..f53fad8 100644 --- a/go.mod +++ b/go.mod @@ -3,28 +3,24 @@ module github.com/jkaninda/goma-gateway go 1.23.2 require ( + github.com/common-nighthawk/go-figure v0.0.0-20210622060536-734e95fb86be + github.com/golang-jwt/jwt v3.2.2+incompatible github.com/gorilla/mux v1.8.1 github.com/spf13/cobra v1.8.1 + golang.org/x/oauth2 v0.24.0 gopkg.in/yaml.v3 v3.0.1 ) require ( - github.com/jedib0t/go-pretty/v6 v6.6.1 // indirect - github.com/mattn/go-runewidth v0.0.15 // indirect - github.com/rivo/uniseg v0.2.0 // indirect - golang.org/x/sys v0.17.0 // indirect + github.com/jedib0t/go-pretty/v6 v6.6.1 + github.com/mattn/go-runewidth v0.0.16 // indirect + github.com/rivo/uniseg v0.4.7 // indirect + golang.org/x/sys v0.27.0 // indirect ) require ( - github.com/common-nighthawk/go-figure v0.0.0-20210622060536-734e95fb86be // indirect - github.com/go-redis/redis v6.15.9+incompatible // indirect - github.com/go-redis/redis_rate v6.5.0+incompatible // indirect + cloud.google.com/go/compute/metadata v0.3.0 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect - github.com/jinzhu/copier v0.4.0 // indirect - github.com/sirupsen/logrus v1.9.3 // indirect github.com/spf13/pflag v1.0.5 // indirect - golang.org/x/oauth2 v0.23.0 // indirect - golang.org/x/text v0.14.0 // indirect - golang.org/x/time v0.7.0 // indirect ) diff --git a/go.sum b/go.sum index 15c56be..0721c90 100644 --- a/go.sum +++ b/go.sum @@ -1,53 +1,46 @@ +cloud.google.com/go/compute/metadata v0.3.0 h1:Tz+eQXMEqDIKRsmY3cHTL6FVaynIjX2QxYC4trgAKZc= +cloud.google.com/go/compute/metadata v0.3.0/go.mod h1:zFmK7XCadkQkj6TtorcaGlCW1hT1fIilQDwofLpJ20k= github.com/common-nighthawk/go-figure v0.0.0-20210622060536-734e95fb86be h1:J5BL2kskAlV9ckgEsNQXscjIaLiOYiZ75d4e94E6dcQ= github.com/common-nighthawk/go-figure v0.0.0-20210622060536-734e95fb86be/go.mod h1:mk5IQ+Y0ZeO87b858TlA645sVcEcbiX6YqP98kt+7+w= github.com/cpuguy83/go-md2man/v2 v2.0.4/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o= -github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/getsentry/sentry-go v0.29.1 h1:DyZuChN8Hz3ARxGVV8ePaNXh1dQ7d76AiB117xcREwA= -github.com/getsentry/sentry-go v0.29.1/go.mod h1:x3AtIzN01d6SiWkderzaH28Tm0lgkafpJ5Bm3li39O0= -github.com/go-redis/redis v6.15.9+incompatible h1:K0pv1D7EQUjfyoMql+r/jZqCLizCGKFlFgcHWWmHQjg= -github.com/go-redis/redis v6.15.9+incompatible/go.mod h1:NAIEuMOZ/fxfXJIrKDQDz8wamY7mA7PouImQ2Jvg6kA= -github.com/go-redis/redis_rate v6.5.0+incompatible h1:K/G+KaoJgO3kbkLLbfdg0kzJsHhhk0gVGTMgstKgbsM= -github.com/go-redis/redis_rate v6.5.0+incompatible/go.mod h1:Jxe7BhQuVncH6fUQ2rwoAkc8SesjCGIWkm6fNRQo4Qg= +github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY= +github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I= +github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= +github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/gorilla/mux v1.8.1 h1:TuBL49tXwgrFYWhqrNgrUNEY92u81SPhu7sTdzQEiWY= github.com/gorilla/mux v1.8.1/go.mod h1:AKf9I4AEqPTmMytcMc0KkNouC66V3BtZ4qD5fmWSiMQ= github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= -github.com/jedib0t/go-pretty v4.3.0+incompatible h1:CGs8AVhEKg/n9YbUenWmNStRW2PHJzaeDodcfvRAbIo= -github.com/jedib0t/go-pretty v4.3.0+incompatible/go.mod h1:XemHduiw8R651AF9Pt4FwCTKeG3oo7hrHJAoznj9nag= github.com/jedib0t/go-pretty/v6 v6.6.1 h1:iJ65Xjb680rHcikRj6DSIbzCex2huitmc7bDtxYVWyc= github.com/jedib0t/go-pretty/v6 v6.6.1/go.mod h1:zbn98qrYlh95FIhwwsbIip0LYpwSG8SUOScs+v9/t0E= -github.com/jinzhu/copier v0.4.0 h1:w3ciUoD19shMCRargcpm0cm91ytaBhDvuRpz1ODO/U8= -github.com/jinzhu/copier v0.4.0/go.mod h1:DfbEm0FYsaqBcKcFuvmOZb218JkPGtvSHsKg8S8hyyg= github.com/mattn/go-runewidth v0.0.15 h1:UNAjwbU9l54TA3KzvqLGxwWjHmMgBUVhBiTjelZgg3U= github.com/mattn/go-runewidth v0.0.15/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w= -github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY= -github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo= +github.com/mattn/go-runewidth v0.0.16 h1:E5ScNMtiwvlvB5paMFdw9p4kSQzbXFikJ5SQO6TULQc= +github.com/mattn/go-runewidth v0.0.16/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/rivo/uniseg v0.2.0 h1:S1pD9weZBuJdFmowNwbpi7BJ8TNftyUImj/0WQi72jY= github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc= +github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ= +github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88= github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= -github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ= -github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= github.com/spf13/cobra v1.8.1 h1:e5/vxKd/rZsfSJMUX1agtjeTDf+qv1/JdBF8gg5k9ZM= github.com/spf13/cobra v1.8.1/go.mod h1:wHxEcudfqmLYa8iTfL+OuZPbBZkmvliBWKIezN3kD9Y= github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= -github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= -github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= +github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= golang.org/x/oauth2 v0.23.0 h1:PbgcYx2W7i4LvjJWEbf0ngHV6qJYr86PkAV3bXdLEbs= golang.org/x/oauth2 v0.23.0/go.mod h1:XYTD2NtWslqkgxebSiOHnXEap4TF09sJSc7H1sXbhtI= -golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/oauth2 v0.24.0 h1:KTBBxWqUa0ykRPLtV69rRto9TLXcqYkeswu48x/gvNE= +golang.org/x/oauth2 v0.24.0/go.mod h1:XYTD2NtWslqkgxebSiOHnXEap4TF09sJSc7H1sXbhtI= golang.org/x/sys v0.17.0 h1:25cE3gD+tdBA7lp7QfhuV+rJiE9YXTcS3VG1SqssI/Y= golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= -golang.org/x/sys v0.18.0 h1:DBdB3niSjOA/O0blCZBqDefyWNYveAYMNF1Wum0DYQ4= -golang.org/x/sys v0.18.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= -golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ= -golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= -golang.org/x/time v0.7.0 h1:ntUhktv3OPE6TgYxXWv9vKvUSJyIFJlyohwbkEwPrKQ= -golang.org/x/time v0.7.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= +golang.org/x/sys v0.27.0 h1:wBqf8DvsY9Y/2P8gAfPDEYNuS30J4lPHJxXSb/nJZ+s= +golang.org/x/sys v0.27.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/internal/config.go b/internal/config.go index 9d63d56..5de1c37 100644 --- a/internal/config.go +++ b/internal/config.go @@ -22,6 +22,11 @@ import ( "github.com/jkaninda/goma-gateway/util" "github.com/spf13/cobra" "golang.org/x/oauth2" + "golang.org/x/oauth2/amazon" + "golang.org/x/oauth2/facebook" + "golang.org/x/oauth2/github" + "golang.org/x/oauth2/gitlab" + "golang.org/x/oauth2/google" "gopkg.in/yaml.v3" "os" ) @@ -179,7 +184,7 @@ func initConfig(configFile string) { "/example-of-jwt", }, Rule: JWTRuleMiddleware{ - URL: "https://www.googleapis.com/auth/userinfo.email", + URL: "https://example.com/auth/userinfo", RequiredHeaders: []string{ "Authorization", }, @@ -199,20 +204,41 @@ func initConfig(configFile string) { }, }, { - Name: "oauth", + Name: "oauth-google", Type: OAuth, Paths: []string{ "/protected", "/example-of-oauth", }, Rule: OauthRulerMiddleware{ - ClientID: "", - ClientSecret: "", - RedirectURL: "", - Scopes: []string{"user"}, + ClientID: "xxx", + ClientSecret: "xxx", + Provider: "google", + JWTSecret: "your-strong-jwt-secret | It's optional", + RedirectURL: "http://localhost:8080/callback", + Scopes: []string{"https://www.googleapis.com/auth/userinfo.email", + "https://www.googleapis.com/auth/userinfo.profile"}, + Endpoint: OauthEndpoint{}, + State: "randomStateString", + }, + }, + { + Name: "oauth-authentik", + Type: OAuth, + Paths: []string{ + "/protected", + "/example-of-oauth", + }, + Rule: OauthRulerMiddleware{ + ClientID: "xxx", + ClientSecret: "xxx", + RedirectURL: "http://localhost:8080/callback", + Scopes: []string{"email", "openid"}, + JWTSecret: "your-strong-jwt-secret | It's optional", Endpoint: OauthEndpoint{ - AuthURL: "https://accounts.google.com/o/oauth2/auth", - TokenURL: "https://oauth2.googleapis.com/token", + AuthURL: "https://authentik.example.com/application/o/authorize/", + TokenURL: "https://authentik.example.com/application/o/token/", + UserInfoURL: "https://authentik.example.com/application/o/userinfo/", }, State: "randomStateString", }, @@ -311,21 +337,6 @@ func oAuthMiddleware(input interface{}) (OauthRulerMiddleware, error) { } return *oauthRuler, nil } - -func oauth2Config(oauth OauthRulerMiddleware) *oauth2.Config { - return &oauth2.Config{ - ClientID: oauth.ClientID, - ClientSecret: oauth.ClientSecret, - RedirectURL: oauth.RedirectURL, - Scopes: oauth.Scopes, - Endpoint: oauth2.Endpoint{ - AuthURL: oauth.Endpoint.AuthURL, - TokenURL: oauth.Endpoint.TokenURL, - DeviceAuthURL: oauth.Endpoint.DeviceAuthURL, - }, - } -} - func oauthRulerMiddleware(oauth middleware.Oauth) *OauthRulerMiddleware { return &OauthRulerMiddleware{ ClientID: oauth.ClientID, @@ -333,10 +344,51 @@ func oauthRulerMiddleware(oauth middleware.Oauth) *OauthRulerMiddleware { RedirectURL: oauth.RedirectURL, State: oauth.State, Scopes: oauth.Scopes, + JWTSecret: oauth.JWTSecret, + Provider: oauth.Provider, Endpoint: OauthEndpoint{ - AuthURL: oauth.Endpoint.AuthURL, - TokenURL: oauth.Endpoint.TokenURL, - DeviceAuthURL: oauth.Endpoint.DeviceAuthURL, + AuthURL: oauth.Endpoint.AuthURL, + TokenURL: oauth.Endpoint.TokenURL, + UserInfoURL: oauth.Endpoint.UserInfoURL, }, } } +func oauth2Config(oauth *OauthRulerMiddleware) *oauth2.Config { + conf := &oauth2.Config{ + ClientID: oauth.ClientID, + ClientSecret: oauth.ClientSecret, + RedirectURL: oauth.RedirectURL, + Scopes: oauth.Scopes, + Endpoint: oauth2.Endpoint{ + AuthURL: oauth.Endpoint.AuthURL, + TokenURL: oauth.Endpoint.TokenURL, + }, + } + switch oauth.Provider { + case "google": + conf.Endpoint = google.Endpoint + if oauth.Endpoint.UserInfoURL == "" { + oauth.Endpoint.UserInfoURL = "https://www.googleapis.com/oauth2/v2/userinfo" + } + case "amazon": + conf.Endpoint = amazon.Endpoint + case "facebook": + conf.Endpoint = facebook.Endpoint + if oauth.Endpoint.UserInfoURL == "" { + oauth.Endpoint.UserInfoURL = "https://graph.facebook.com/me" + } + case "github": + conf.Endpoint = github.Endpoint + if oauth.Endpoint.UserInfoURL == "" { + oauth.Endpoint.UserInfoURL = "https://api.github.com/user/repo" + } + case "gitlab": + conf.Endpoint = gitlab.Endpoint + default: + if oauth.Provider != "custom" { + logger.Error("Unknown provider: %s", oauth.Provider) + } + + } + return conf +} diff --git a/internal/handler.go b/internal/handler.go index 84e690d..d2b311d 100644 --- a/internal/handler.go +++ b/internal/handler.go @@ -55,13 +55,8 @@ func CORSHandler(cors Cors) mux.MiddlewareFunc { // ProxyErrorHandler catches backend errors and returns a custom response func ProxyErrorHandler(w http.ResponseWriter, r *http.Request, err error) { logger.Error("Proxy error: %v", err) - w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusBadGateway) - err = json.NewEncoder(w).Encode(map[string]interface{}{ - "success": false, - "code": http.StatusBadGateway, - "message": "The service is currently unavailable. Please try again later.", - }) + _, err = w.Write([]byte("Bad Gateway")) if err != nil { return } @@ -131,27 +126,42 @@ func allowedOrigin(origins []string, origin string) bool { return false } -func (oauth OauthRulerMiddleware) callbackHandler(w http.ResponseWriter, r *http.Request) { + +// callbackHandler handles oauth callback +func (oauth *OauthRulerMiddleware) callbackHandler(w http.ResponseWriter, r *http.Request) { oauthConfig := oauth2Config(oauth) - logger.Info("URL State: %s", r.URL.Query().Get("state")) // Verify the state to protect against CSRF if r.URL.Query().Get("state") != oauth.State { http.Error(w, "Invalid state", http.StatusBadRequest) return } - // Exchange the authorization code for an access token code := r.URL.Query().Get("code") token, err := oauthConfig.Exchange(context.Background(), code) if err != nil { - http.Error(w, "Failed to exchange token: "+err.Error(), http.StatusInternalServerError) + logger.Error("Failed to exchange token: %v", err.Error()) + http.Error(w, "Failed to exchange token", http.StatusInternalServerError) return } + // Get user info from the token + userInfo, err := oauth.getUserInfo(token) + if err != nil { + logger.Error("Error getting user info: %v", err) + http.Error(w, "Error getting user info: ", http.StatusInternalServerError) + return + } + // Generate JWT with user's email + jwtToken, err := createJWT(userInfo.Email, oauth.JWTSecret) + if err != nil { + logger.Error("Error creating JWT: %v", err) + http.Error(w, "Error creating JWT ", http.StatusInternalServerError) + return + } // Save token to a cookie for simplicity http.SetCookie(w, &http.Cookie{ - Name: "oauth-token", - Value: token.AccessToken, + Name: "goma.JWT", + Value: jwtToken, Path: oauth.CookiePath, }) diff --git a/internal/helpers.go b/internal/helpers.go index e8fcbaa..eb224f6 100644 --- a/internal/helpers.go +++ b/internal/helpers.go @@ -10,11 +10,16 @@ You may get a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 */ import ( + "context" "crypto/tls" + "encoding/json" "fmt" + "github.com/golang-jwt/jwt" "github.com/jedib0t/go-pretty/v6/table" "github.com/jkaninda/goma-gateway/pkg/logger" + "golang.org/x/oauth2" "net/http" + "time" ) // printRoute prints routes @@ -53,3 +58,40 @@ func loadTLS(cert, key string) (*tls.Config, error) { } return tlsConfig, nil } +func (oauth *OauthRulerMiddleware) getUserInfo(token *oauth2.Token) (UserInfo, error) { + oauthConfig := oauth2Config(oauth) + // Call the user info endpoint with the token + client := oauthConfig.Client(context.Background(), token) + resp, err := client.Get(oauth.Endpoint.UserInfoURL) + if err != nil { + return UserInfo{}, err + } + defer resp.Body.Close() + + // Parse the user info + var userInfo UserInfo + if err := json.NewDecoder(resp.Body).Decode(&userInfo); err != nil { + return UserInfo{}, err + } + + return userInfo, nil +} +func createJWT(email, jwtSecret string) (string, error) { + // Define JWT claims + claims := jwt.MapClaims{ + "email": email, + "exp": jwt.TimeFunc().Add(time.Hour * 24).Unix(), // Token expiration + "iss": "Goma-Gateway", // Issuer claim + } + + // Create a new token with HS256 signing method + token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) + + // Sign the token with a secret + signedToken, err := token.SignedString([]byte(jwtSecret)) + if err != nil { + return "", err + } + + return signedToken, nil +} diff --git a/internal/middleware/config.go b/internal/middleware/config.go new file mode 100644 index 0000000..c26f218 --- /dev/null +++ b/internal/middleware/config.go @@ -0,0 +1,59 @@ +/* + * Copyright 2024 Jonas Kaninda + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package middleware + +import ( + "github.com/jkaninda/goma-gateway/pkg/logger" + "golang.org/x/oauth2" + "golang.org/x/oauth2/amazon" + "golang.org/x/oauth2/facebook" + "golang.org/x/oauth2/github" + "golang.org/x/oauth2/gitlab" + "golang.org/x/oauth2/google" +) + +func oauth2Config(oauth Oauth) *oauth2.Config { + config := &oauth2.Config{ + ClientID: oauth.ClientID, + ClientSecret: oauth.ClientSecret, + RedirectURL: oauth.RedirectURL, + Scopes: oauth.Scopes, + Endpoint: oauth2.Endpoint{ + AuthURL: oauth.Endpoint.AuthURL, + TokenURL: oauth.Endpoint.TokenURL, + }, + } + switch oauth.Provider { + case "google": + config.Endpoint = google.Endpoint + case "amazon": + config.Endpoint = amazon.Endpoint + case "facebook": + config.Endpoint = facebook.Endpoint + case "github": + config.Endpoint = github.Endpoint + case "gitlab": + config.Endpoint = gitlab.Endpoint + default: + if oauth.Provider != "custom" { + logger.Error("Unknown provider: %s", oauth.Provider) + } + + } + return config +} diff --git a/internal/middleware/oauth-middleware.go b/internal/middleware/oauth-middleware.go index 16c977c..e3c6e5e 100644 --- a/internal/middleware/oauth-middleware.go +++ b/internal/middleware/oauth-middleware.go @@ -18,39 +18,71 @@ package middleware import ( + "fmt" + "github.com/golang-jwt/jwt" "github.com/jkaninda/goma-gateway/pkg/logger" - "golang.org/x/oauth2" "net/http" + "time" ) -func oauth2Config(oauth Oauth) *oauth2.Config { - return &oauth2.Config{ - ClientID: oauth.ClientID, - ClientSecret: oauth.ClientSecret, - RedirectURL: oauth.RedirectURL, - Scopes: oauth.Scopes, - Endpoint: oauth2.Endpoint{ - AuthURL: oauth.Endpoint.AuthURL, - TokenURL: oauth.Endpoint.TokenURL, - DeviceAuthURL: oauth.Endpoint.DeviceAuthURL, - }, - } -} func (oauth Oauth) AuthMiddleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { logger.Info("%s: %s Oauth", getRealIP(r), r.URL.Path) - oauthConfig := oauth2Config(oauth) + oauthConf := oauth2Config(oauth) // Check if the user is authenticated - _, err := r.Cookie("oauth-token") + token, err := r.Cookie("goma.JWT") if err != nil { // If no token, redirect to OAuth provider - url := oauthConfig.AuthCodeURL(oauth.State) + url := oauthConf.AuthCodeURL(oauth.State) + http.Redirect(w, r, url, http.StatusTemporaryRedirect) + return + } + ok, err := validateJWT(token.Value, oauth) + if err != nil { + // If no token, redirect to OAuth provider + url := oauthConf.AuthCodeURL(oauth.State) + http.Redirect(w, r, url, http.StatusTemporaryRedirect) + return + } + if !ok { + // If no token, redirect to OAuth provider + url := oauthConf.AuthCodeURL(oauth.State) http.Redirect(w, r, url, http.StatusTemporaryRedirect) return } - //TODO: Check if the token stored in the cookie is valid - // Token exists, proceed with request next.ServeHTTP(w, r) }) } + +func validateJWT(signedToken string, oauth Oauth) (bool, error) { + // Parse the JWT token and provide the key function + token, err := jwt.Parse(signedToken, func(token *jwt.Token) (interface{}, error) { + // Ensure the signing method is HMAC and specifically HS256 + if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { + return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"]) + } + // Return the shared secret key for validation + return []byte(oauth.JWTSecret), nil + }) + + // If there's an error or token is invalid, return false + if err != nil || !token.Valid { + return false, fmt.Errorf("token is invalid: %v", err) + } + + // Check if token claims are valid + if claims, ok := token.Claims.(jwt.MapClaims); ok && token.Valid { + // Optional: Check token expiration + if exp, ok := claims["exp"].(float64); ok { + if time.Unix(int64(exp), 0).Before(time.Now()) { + return false, fmt.Errorf("token has expired") + } + } + + // Token is valid and not expired + return true, nil + } + + return false, fmt.Errorf("token is invalid or missing claims") +} diff --git a/internal/middleware/types.go b/internal/middleware/types.go index 94b1d71..54bebee 100644 --- a/internal/middleware/types.go +++ b/internal/middleware/types.go @@ -120,11 +120,13 @@ type Oauth struct { // Scope specifies optional requested permissions. Scopes []string // contains filtered or unexported fields - State string - Origins []string + State string + Origins []string + JWTSecret string + Provider string } type OauthEndpoint struct { - AuthURL string - TokenURL string - DeviceAuthURL string + AuthURL string + TokenURL string + UserInfoURL string } diff --git a/internal/proxy.go b/internal/proxy.go index 793d81a..0680a8f 100644 --- a/internal/proxy.go +++ b/internal/proxy.go @@ -16,7 +16,6 @@ See the License for the specific language governing permissions and limitations under the License. */ import ( - "encoding/json" "fmt" "github.com/jkaninda/goma-gateway/pkg/logger" "net/http" @@ -44,18 +43,12 @@ func (proxyRoute ProxyRoute) ProxyHandler() http.HandlerFunc { w.Header().Set(accessControlAllowOrigin, r.Header.Get("Origin")) } } - // Parse the target backend URL targetURL, err := url.Parse(proxyRoute.destination) if err != nil { logger.Error("Error parsing backend URL: %s", err) - w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusInternalServerError) - err := json.NewEncoder(w).Encode(ErrorResponse{ - Message: "Internal server error", - Code: http.StatusInternalServerError, - Success: false, - }) + _, err := w.Write([]byte("Internal Server Error")) if err != nil { return } diff --git a/internal/route.go b/internal/route.go index 7c70188..5497d05 100644 --- a/internal/route.go +++ b/internal/route.go @@ -132,18 +132,24 @@ func (gatewayServer GatewayServer) Initialize() *mux.Router { if err != nil { logger.Error("Error: %s", err.Error()) } else { + redirectURL := "/callback" + route.Path + if oauth.RedirectURL != "" { + redirectURL = oauth.RedirectURL + } amw := middleware.Oauth{ ClientID: oauth.ClientID, ClientSecret: oauth.ClientSecret, - RedirectURL: oauth.RedirectURL + route.Path, + RedirectURL: redirectURL, Scopes: oauth.Scopes, Endpoint: middleware.OauthEndpoint{ - AuthURL: oauth.Endpoint.AuthURL, - TokenURL: oauth.Endpoint.TokenURL, - DeviceAuthURL: oauth.Endpoint.DeviceAuthURL, + AuthURL: oauth.Endpoint.AuthURL, + TokenURL: oauth.Endpoint.TokenURL, + UserInfoURL: oauth.Endpoint.UserInfoURL, }, - State: oauth.State, - Origins: gateway.Cors.Origins, + State: oauth.State, + Origins: gateway.Cors.Origins, + JWTSecret: oauth.JWTSecret, + Provider: oauth.Provider, } oauthRuler := oauthRulerMiddleware(amw) // Check if a cookie path is defined @@ -154,12 +160,15 @@ func (gatewayServer GatewayServer) Initialize() *mux.Router { if oauthRuler.RedirectPath == "" { oauthRuler.RedirectPath = util.ParseRoutePath(route.Path, midPath) } + if oauthRuler.Provider == "" { + oauthRuler.Provider = "custom" + } secureRouter.Use(amw.AuthMiddleware) secureRouter.Use(CORSHandler(route.Cors)) secureRouter.PathPrefix("/").Handler(proxyRoute.ProxyHandler()) // Proxy handler secureRouter.PathPrefix("").Handler(proxyRoute.ProxyHandler()) // Proxy handler // Callback route - r.HandleFunc("/callback"+route.Path, oauthRuler.callbackHandler).Methods("GET") + r.HandleFunc(util.UrlParsePath(redirectURL), oauthRuler.callbackHandler).Methods("GET") } default: if !doesExist(rMiddleware.Type) { diff --git a/internal/types.go b/internal/types.go index f36ea97..cf2a468 100644 --- a/internal/types.go +++ b/internal/types.go @@ -78,7 +78,8 @@ type OauthRulerMiddleware struct { // ClientSecret is the application's secret. ClientSecret string `yaml:"clientSecret"` - + // oauth provider google, gitlab, github, amazon, facebook, custom + Provider string `yaml:"provider"` // Endpoint contains the resource server's token endpoint Endpoint OauthEndpoint `yaml:"endpoint"` @@ -93,12 +94,13 @@ type OauthRulerMiddleware struct { // Scope specifies optional requested permissions. Scopes []string `yaml:"scopes"` // contains filtered or unexported fields - State string `yaml:"state"` + State string `yaml:"state"` + JWTSecret string `yaml:"jwtSecret"` } type OauthEndpoint struct { - AuthURL string `yaml:"authUrl"` - TokenURL string `yaml:"tokenUrl"` - DeviceAuthURL string `yaml:"deviceAuthUrl"` + AuthURL string `yaml:"authUrl"` + TokenURL string `yaml:"tokenUrl"` + UserInfoURL string `yaml:"userInfoUrl"` } type RateLimiter struct { // ipBased, tokenBased @@ -242,3 +244,11 @@ type HealthCheckRouteResponse struct { Status string `json:"status"` Error string `json:"error"` } +type UserInfo struct { + Email string `json:"email"` +} + +type JWTSecret struct { + ISS string `yaml:"iss"` + Secret string `yaml:"secret"` +} diff --git a/util/helpers.go b/util/helpers.go index 5fb0fa5..6daf98d 100644 --- a/util/helpers.go +++ b/util/helpers.go @@ -10,6 +10,7 @@ You may get a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 */ import ( + "net/url" "os" "strconv" "strings" @@ -96,3 +97,11 @@ func ParseRoutePath(path, blockedPath string) string { return basePath + blockedPath } } + +func UrlParsePath(uri string) string { + parse, err := url.Parse(uri) + if err != nil { + return "" + } + return parse.Path +}