feat: add oauth token validity verification

This commit is contained in:
2024-11-08 12:03:52 +01:00
parent d6e7791cb4
commit bd20895306
12 changed files with 326 additions and 119 deletions

20
go.mod
View File

@@ -3,28 +3,24 @@ module github.com/jkaninda/goma-gateway
go 1.23.2 go 1.23.2
require ( require (
github.com/common-nighthawk/go-figure v0.0.0-20210622060536-734e95fb86be
github.com/golang-jwt/jwt v3.2.2+incompatible
github.com/gorilla/mux v1.8.1 github.com/gorilla/mux v1.8.1
github.com/spf13/cobra v1.8.1 github.com/spf13/cobra v1.8.1
golang.org/x/oauth2 v0.24.0
gopkg.in/yaml.v3 v3.0.1 gopkg.in/yaml.v3 v3.0.1
) )
require ( require (
github.com/jedib0t/go-pretty/v6 v6.6.1 // indirect github.com/jedib0t/go-pretty/v6 v6.6.1
github.com/mattn/go-runewidth v0.0.15 // indirect github.com/mattn/go-runewidth v0.0.16 // indirect
github.com/rivo/uniseg v0.2.0 // indirect github.com/rivo/uniseg v0.4.7 // indirect
golang.org/x/sys v0.17.0 // indirect golang.org/x/sys v0.27.0 // indirect
) )
require ( require (
github.com/common-nighthawk/go-figure v0.0.0-20210622060536-734e95fb86be // indirect cloud.google.com/go/compute/metadata v0.3.0 // indirect
github.com/go-redis/redis v6.15.9+incompatible // indirect
github.com/go-redis/redis_rate v6.5.0+incompatible // indirect
github.com/inconshreveable/mousetrap v1.1.0 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect
github.com/jinzhu/copier v0.4.0 // indirect
github.com/sirupsen/logrus v1.9.3 // indirect
github.com/spf13/pflag v1.0.5 // indirect github.com/spf13/pflag v1.0.5 // indirect
golang.org/x/oauth2 v0.23.0 // indirect
golang.org/x/text v0.14.0 // indirect
golang.org/x/time v0.7.0 // indirect
) )

43
go.sum
View File

@@ -1,53 +1,46 @@
cloud.google.com/go/compute/metadata v0.3.0 h1:Tz+eQXMEqDIKRsmY3cHTL6FVaynIjX2QxYC4trgAKZc=
cloud.google.com/go/compute/metadata v0.3.0/go.mod h1:zFmK7XCadkQkj6TtorcaGlCW1hT1fIilQDwofLpJ20k=
github.com/common-nighthawk/go-figure v0.0.0-20210622060536-734e95fb86be h1:J5BL2kskAlV9ckgEsNQXscjIaLiOYiZ75d4e94E6dcQ= github.com/common-nighthawk/go-figure v0.0.0-20210622060536-734e95fb86be h1:J5BL2kskAlV9ckgEsNQXscjIaLiOYiZ75d4e94E6dcQ=
github.com/common-nighthawk/go-figure v0.0.0-20210622060536-734e95fb86be/go.mod h1:mk5IQ+Y0ZeO87b858TlA645sVcEcbiX6YqP98kt+7+w= github.com/common-nighthawk/go-figure v0.0.0-20210622060536-734e95fb86be/go.mod h1:mk5IQ+Y0ZeO87b858TlA645sVcEcbiX6YqP98kt+7+w=
github.com/cpuguy83/go-md2man/v2 v2.0.4/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o= github.com/cpuguy83/go-md2man/v2 v2.0.4/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/getsentry/sentry-go v0.29.1 h1:DyZuChN8Hz3ARxGVV8ePaNXh1dQ7d76AiB117xcREwA= github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY=
github.com/getsentry/sentry-go v0.29.1/go.mod h1:x3AtIzN01d6SiWkderzaH28Tm0lgkafpJ5Bm3li39O0= github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I=
github.com/go-redis/redis v6.15.9+incompatible h1:K0pv1D7EQUjfyoMql+r/jZqCLizCGKFlFgcHWWmHQjg= github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38=
github.com/go-redis/redis v6.15.9+incompatible/go.mod h1:NAIEuMOZ/fxfXJIrKDQDz8wamY7mA7PouImQ2Jvg6kA= github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/go-redis/redis_rate v6.5.0+incompatible h1:K/G+KaoJgO3kbkLLbfdg0kzJsHhhk0gVGTMgstKgbsM=
github.com/go-redis/redis_rate v6.5.0+incompatible/go.mod h1:Jxe7BhQuVncH6fUQ2rwoAkc8SesjCGIWkm6fNRQo4Qg=
github.com/gorilla/mux v1.8.1 h1:TuBL49tXwgrFYWhqrNgrUNEY92u81SPhu7sTdzQEiWY= github.com/gorilla/mux v1.8.1 h1:TuBL49tXwgrFYWhqrNgrUNEY92u81SPhu7sTdzQEiWY=
github.com/gorilla/mux v1.8.1/go.mod h1:AKf9I4AEqPTmMytcMc0KkNouC66V3BtZ4qD5fmWSiMQ= github.com/gorilla/mux v1.8.1/go.mod h1:AKf9I4AEqPTmMytcMc0KkNouC66V3BtZ4qD5fmWSiMQ=
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
github.com/jedib0t/go-pretty v4.3.0+incompatible h1:CGs8AVhEKg/n9YbUenWmNStRW2PHJzaeDodcfvRAbIo=
github.com/jedib0t/go-pretty v4.3.0+incompatible/go.mod h1:XemHduiw8R651AF9Pt4FwCTKeG3oo7hrHJAoznj9nag=
github.com/jedib0t/go-pretty/v6 v6.6.1 h1:iJ65Xjb680rHcikRj6DSIbzCex2huitmc7bDtxYVWyc= github.com/jedib0t/go-pretty/v6 v6.6.1 h1:iJ65Xjb680rHcikRj6DSIbzCex2huitmc7bDtxYVWyc=
github.com/jedib0t/go-pretty/v6 v6.6.1/go.mod h1:zbn98qrYlh95FIhwwsbIip0LYpwSG8SUOScs+v9/t0E= github.com/jedib0t/go-pretty/v6 v6.6.1/go.mod h1:zbn98qrYlh95FIhwwsbIip0LYpwSG8SUOScs+v9/t0E=
github.com/jinzhu/copier v0.4.0 h1:w3ciUoD19shMCRargcpm0cm91ytaBhDvuRpz1ODO/U8=
github.com/jinzhu/copier v0.4.0/go.mod h1:DfbEm0FYsaqBcKcFuvmOZb218JkPGtvSHsKg8S8hyyg=
github.com/mattn/go-runewidth v0.0.15 h1:UNAjwbU9l54TA3KzvqLGxwWjHmMgBUVhBiTjelZgg3U= github.com/mattn/go-runewidth v0.0.15 h1:UNAjwbU9l54TA3KzvqLGxwWjHmMgBUVhBiTjelZgg3U=
github.com/mattn/go-runewidth v0.0.15/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w= github.com/mattn/go-runewidth v0.0.15/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w=
github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY= github.com/mattn/go-runewidth v0.0.16 h1:E5ScNMtiwvlvB5paMFdw9p4kSQzbXFikJ5SQO6TULQc=
github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo= github.com/mattn/go-runewidth v0.0.16/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/rivo/uniseg v0.2.0 h1:S1pD9weZBuJdFmowNwbpi7BJ8TNftyUImj/0WQi72jY= github.com/rivo/uniseg v0.2.0 h1:S1pD9weZBuJdFmowNwbpi7BJ8TNftyUImj/0WQi72jY=
github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc= github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ=
github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88=
github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ=
github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ=
github.com/spf13/cobra v1.8.1 h1:e5/vxKd/rZsfSJMUX1agtjeTDf+qv1/JdBF8gg5k9ZM= github.com/spf13/cobra v1.8.1 h1:e5/vxKd/rZsfSJMUX1agtjeTDf+qv1/JdBF8gg5k9ZM=
github.com/spf13/cobra v1.8.1/go.mod h1:wHxEcudfqmLYa8iTfL+OuZPbBZkmvliBWKIezN3kD9Y= github.com/spf13/cobra v1.8.1/go.mod h1:wHxEcudfqmLYa8iTfL+OuZPbBZkmvliBWKIezN3kD9Y=
github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA=
github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
golang.org/x/oauth2 v0.23.0 h1:PbgcYx2W7i4LvjJWEbf0ngHV6qJYr86PkAV3bXdLEbs= golang.org/x/oauth2 v0.23.0 h1:PbgcYx2W7i4LvjJWEbf0ngHV6qJYr86PkAV3bXdLEbs=
golang.org/x/oauth2 v0.23.0/go.mod h1:XYTD2NtWslqkgxebSiOHnXEap4TF09sJSc7H1sXbhtI= golang.org/x/oauth2 v0.23.0/go.mod h1:XYTD2NtWslqkgxebSiOHnXEap4TF09sJSc7H1sXbhtI=
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/oauth2 v0.24.0 h1:KTBBxWqUa0ykRPLtV69rRto9TLXcqYkeswu48x/gvNE=
golang.org/x/oauth2 v0.24.0/go.mod h1:XYTD2NtWslqkgxebSiOHnXEap4TF09sJSc7H1sXbhtI=
golang.org/x/sys v0.17.0 h1:25cE3gD+tdBA7lp7QfhuV+rJiE9YXTcS3VG1SqssI/Y= golang.org/x/sys v0.17.0 h1:25cE3gD+tdBA7lp7QfhuV+rJiE9YXTcS3VG1SqssI/Y=
golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.18.0 h1:DBdB3niSjOA/O0blCZBqDefyWNYveAYMNF1Wum0DYQ4= golang.org/x/sys v0.27.0 h1:wBqf8DvsY9Y/2P8gAfPDEYNuS30J4lPHJxXSb/nJZ+s=
golang.org/x/sys v0.18.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.27.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ=
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
golang.org/x/time v0.7.0 h1:ntUhktv3OPE6TgYxXWv9vKvUSJyIFJlyohwbkEwPrKQ=
golang.org/x/time v0.7.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

View File

@@ -22,6 +22,11 @@ import (
"github.com/jkaninda/goma-gateway/util" "github.com/jkaninda/goma-gateway/util"
"github.com/spf13/cobra" "github.com/spf13/cobra"
"golang.org/x/oauth2" "golang.org/x/oauth2"
"golang.org/x/oauth2/amazon"
"golang.org/x/oauth2/facebook"
"golang.org/x/oauth2/github"
"golang.org/x/oauth2/gitlab"
"golang.org/x/oauth2/google"
"gopkg.in/yaml.v3" "gopkg.in/yaml.v3"
"os" "os"
) )
@@ -179,7 +184,7 @@ func initConfig(configFile string) {
"/example-of-jwt", "/example-of-jwt",
}, },
Rule: JWTRuleMiddleware{ Rule: JWTRuleMiddleware{
URL: "https://www.googleapis.com/auth/userinfo.email", URL: "https://example.com/auth/userinfo",
RequiredHeaders: []string{ RequiredHeaders: []string{
"Authorization", "Authorization",
}, },
@@ -199,20 +204,41 @@ func initConfig(configFile string) {
}, },
}, },
{ {
Name: "oauth", Name: "oauth-google",
Type: OAuth, Type: OAuth,
Paths: []string{ Paths: []string{
"/protected", "/protected",
"/example-of-oauth", "/example-of-oauth",
}, },
Rule: OauthRulerMiddleware{ Rule: OauthRulerMiddleware{
ClientID: "", ClientID: "xxx",
ClientSecret: "", ClientSecret: "xxx",
RedirectURL: "", Provider: "google",
Scopes: []string{"user"}, JWTSecret: "your-strong-jwt-secret | It's optional",
RedirectURL: "http://localhost:8080/callback",
Scopes: []string{"https://www.googleapis.com/auth/userinfo.email",
"https://www.googleapis.com/auth/userinfo.profile"},
Endpoint: OauthEndpoint{},
State: "randomStateString",
},
},
{
Name: "oauth-authentik",
Type: OAuth,
Paths: []string{
"/protected",
"/example-of-oauth",
},
Rule: OauthRulerMiddleware{
ClientID: "xxx",
ClientSecret: "xxx",
RedirectURL: "http://localhost:8080/callback",
Scopes: []string{"email", "openid"},
JWTSecret: "your-strong-jwt-secret | It's optional",
Endpoint: OauthEndpoint{ Endpoint: OauthEndpoint{
AuthURL: "https://accounts.google.com/o/oauth2/auth", AuthURL: "https://authentik.example.com/application/o/authorize/",
TokenURL: "https://oauth2.googleapis.com/token", TokenURL: "https://authentik.example.com/application/o/token/",
UserInfoURL: "https://authentik.example.com/application/o/userinfo/",
}, },
State: "randomStateString", State: "randomStateString",
}, },
@@ -311,21 +337,6 @@ func oAuthMiddleware(input interface{}) (OauthRulerMiddleware, error) {
} }
return *oauthRuler, nil return *oauthRuler, nil
} }
func oauth2Config(oauth OauthRulerMiddleware) *oauth2.Config {
return &oauth2.Config{
ClientID: oauth.ClientID,
ClientSecret: oauth.ClientSecret,
RedirectURL: oauth.RedirectURL,
Scopes: oauth.Scopes,
Endpoint: oauth2.Endpoint{
AuthURL: oauth.Endpoint.AuthURL,
TokenURL: oauth.Endpoint.TokenURL,
DeviceAuthURL: oauth.Endpoint.DeviceAuthURL,
},
}
}
func oauthRulerMiddleware(oauth middleware.Oauth) *OauthRulerMiddleware { func oauthRulerMiddleware(oauth middleware.Oauth) *OauthRulerMiddleware {
return &OauthRulerMiddleware{ return &OauthRulerMiddleware{
ClientID: oauth.ClientID, ClientID: oauth.ClientID,
@@ -333,10 +344,51 @@ func oauthRulerMiddleware(oauth middleware.Oauth) *OauthRulerMiddleware {
RedirectURL: oauth.RedirectURL, RedirectURL: oauth.RedirectURL,
State: oauth.State, State: oauth.State,
Scopes: oauth.Scopes, Scopes: oauth.Scopes,
JWTSecret: oauth.JWTSecret,
Provider: oauth.Provider,
Endpoint: OauthEndpoint{ Endpoint: OauthEndpoint{
AuthURL: oauth.Endpoint.AuthURL, AuthURL: oauth.Endpoint.AuthURL,
TokenURL: oauth.Endpoint.TokenURL, TokenURL: oauth.Endpoint.TokenURL,
DeviceAuthURL: oauth.Endpoint.DeviceAuthURL, UserInfoURL: oauth.Endpoint.UserInfoURL,
}, },
} }
} }
func oauth2Config(oauth *OauthRulerMiddleware) *oauth2.Config {
conf := &oauth2.Config{
ClientID: oauth.ClientID,
ClientSecret: oauth.ClientSecret,
RedirectURL: oauth.RedirectURL,
Scopes: oauth.Scopes,
Endpoint: oauth2.Endpoint{
AuthURL: oauth.Endpoint.AuthURL,
TokenURL: oauth.Endpoint.TokenURL,
},
}
switch oauth.Provider {
case "google":
conf.Endpoint = google.Endpoint
if oauth.Endpoint.UserInfoURL == "" {
oauth.Endpoint.UserInfoURL = "https://www.googleapis.com/oauth2/v2/userinfo"
}
case "amazon":
conf.Endpoint = amazon.Endpoint
case "facebook":
conf.Endpoint = facebook.Endpoint
if oauth.Endpoint.UserInfoURL == "" {
oauth.Endpoint.UserInfoURL = "https://graph.facebook.com/me"
}
case "github":
conf.Endpoint = github.Endpoint
if oauth.Endpoint.UserInfoURL == "" {
oauth.Endpoint.UserInfoURL = "https://api.github.com/user/repo"
}
case "gitlab":
conf.Endpoint = gitlab.Endpoint
default:
if oauth.Provider != "custom" {
logger.Error("Unknown provider: %s", oauth.Provider)
}
}
return conf
}

View File

@@ -55,13 +55,8 @@ func CORSHandler(cors Cors) mux.MiddlewareFunc {
// ProxyErrorHandler catches backend errors and returns a custom response // ProxyErrorHandler catches backend errors and returns a custom response
func ProxyErrorHandler(w http.ResponseWriter, r *http.Request, err error) { func ProxyErrorHandler(w http.ResponseWriter, r *http.Request, err error) {
logger.Error("Proxy error: %v", err) logger.Error("Proxy error: %v", err)
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusBadGateway) w.WriteHeader(http.StatusBadGateway)
err = json.NewEncoder(w).Encode(map[string]interface{}{ _, err = w.Write([]byte("Bad Gateway"))
"success": false,
"code": http.StatusBadGateway,
"message": "The service is currently unavailable. Please try again later.",
})
if err != nil { if err != nil {
return return
} }
@@ -131,27 +126,42 @@ func allowedOrigin(origins []string, origin string) bool {
return false return false
} }
func (oauth OauthRulerMiddleware) callbackHandler(w http.ResponseWriter, r *http.Request) {
// callbackHandler handles oauth callback
func (oauth *OauthRulerMiddleware) callbackHandler(w http.ResponseWriter, r *http.Request) {
oauthConfig := oauth2Config(oauth) oauthConfig := oauth2Config(oauth)
logger.Info("URL State: %s", r.URL.Query().Get("state"))
// Verify the state to protect against CSRF // Verify the state to protect against CSRF
if r.URL.Query().Get("state") != oauth.State { if r.URL.Query().Get("state") != oauth.State {
http.Error(w, "Invalid state", http.StatusBadRequest) http.Error(w, "Invalid state", http.StatusBadRequest)
return return
} }
// Exchange the authorization code for an access token // Exchange the authorization code for an access token
code := r.URL.Query().Get("code") code := r.URL.Query().Get("code")
token, err := oauthConfig.Exchange(context.Background(), code) token, err := oauthConfig.Exchange(context.Background(), code)
if err != nil { if err != nil {
http.Error(w, "Failed to exchange token: "+err.Error(), http.StatusInternalServerError) logger.Error("Failed to exchange token: %v", err.Error())
http.Error(w, "Failed to exchange token", http.StatusInternalServerError)
return return
} }
// Get user info from the token
userInfo, err := oauth.getUserInfo(token)
if err != nil {
logger.Error("Error getting user info: %v", err)
http.Error(w, "Error getting user info: ", http.StatusInternalServerError)
return
}
// Generate JWT with user's email
jwtToken, err := createJWT(userInfo.Email, oauth.JWTSecret)
if err != nil {
logger.Error("Error creating JWT: %v", err)
http.Error(w, "Error creating JWT ", http.StatusInternalServerError)
return
}
// Save token to a cookie for simplicity // Save token to a cookie for simplicity
http.SetCookie(w, &http.Cookie{ http.SetCookie(w, &http.Cookie{
Name: "oauth-token", Name: "goma.JWT",
Value: token.AccessToken, Value: jwtToken,
Path: oauth.CookiePath, Path: oauth.CookiePath,
}) })

View File

@@ -10,11 +10,16 @@ You may get a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
*/ */
import ( import (
"context"
"crypto/tls" "crypto/tls"
"encoding/json"
"fmt" "fmt"
"github.com/golang-jwt/jwt"
"github.com/jedib0t/go-pretty/v6/table" "github.com/jedib0t/go-pretty/v6/table"
"github.com/jkaninda/goma-gateway/pkg/logger" "github.com/jkaninda/goma-gateway/pkg/logger"
"golang.org/x/oauth2"
"net/http" "net/http"
"time"
) )
// printRoute prints routes // printRoute prints routes
@@ -53,3 +58,40 @@ func loadTLS(cert, key string) (*tls.Config, error) {
} }
return tlsConfig, nil return tlsConfig, nil
} }
func (oauth *OauthRulerMiddleware) getUserInfo(token *oauth2.Token) (UserInfo, error) {
oauthConfig := oauth2Config(oauth)
// Call the user info endpoint with the token
client := oauthConfig.Client(context.Background(), token)
resp, err := client.Get(oauth.Endpoint.UserInfoURL)
if err != nil {
return UserInfo{}, err
}
defer resp.Body.Close()
// Parse the user info
var userInfo UserInfo
if err := json.NewDecoder(resp.Body).Decode(&userInfo); err != nil {
return UserInfo{}, err
}
return userInfo, nil
}
func createJWT(email, jwtSecret string) (string, error) {
// Define JWT claims
claims := jwt.MapClaims{
"email": email,
"exp": jwt.TimeFunc().Add(time.Hour * 24).Unix(), // Token expiration
"iss": "Goma-Gateway", // Issuer claim
}
// Create a new token with HS256 signing method
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
// Sign the token with a secret
signedToken, err := token.SignedString([]byte(jwtSecret))
if err != nil {
return "", err
}
return signedToken, nil
}

View File

@@ -0,0 +1,59 @@
/*
* Copyright 2024 Jonas Kaninda
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package middleware
import (
"github.com/jkaninda/goma-gateway/pkg/logger"
"golang.org/x/oauth2"
"golang.org/x/oauth2/amazon"
"golang.org/x/oauth2/facebook"
"golang.org/x/oauth2/github"
"golang.org/x/oauth2/gitlab"
"golang.org/x/oauth2/google"
)
func oauth2Config(oauth Oauth) *oauth2.Config {
config := &oauth2.Config{
ClientID: oauth.ClientID,
ClientSecret: oauth.ClientSecret,
RedirectURL: oauth.RedirectURL,
Scopes: oauth.Scopes,
Endpoint: oauth2.Endpoint{
AuthURL: oauth.Endpoint.AuthURL,
TokenURL: oauth.Endpoint.TokenURL,
},
}
switch oauth.Provider {
case "google":
config.Endpoint = google.Endpoint
case "amazon":
config.Endpoint = amazon.Endpoint
case "facebook":
config.Endpoint = facebook.Endpoint
case "github":
config.Endpoint = github.Endpoint
case "gitlab":
config.Endpoint = gitlab.Endpoint
default:
if oauth.Provider != "custom" {
logger.Error("Unknown provider: %s", oauth.Provider)
}
}
return config
}

View File

@@ -18,39 +18,71 @@
package middleware package middleware
import ( import (
"fmt"
"github.com/golang-jwt/jwt"
"github.com/jkaninda/goma-gateway/pkg/logger" "github.com/jkaninda/goma-gateway/pkg/logger"
"golang.org/x/oauth2"
"net/http" "net/http"
"time"
) )
func oauth2Config(oauth Oauth) *oauth2.Config {
return &oauth2.Config{
ClientID: oauth.ClientID,
ClientSecret: oauth.ClientSecret,
RedirectURL: oauth.RedirectURL,
Scopes: oauth.Scopes,
Endpoint: oauth2.Endpoint{
AuthURL: oauth.Endpoint.AuthURL,
TokenURL: oauth.Endpoint.TokenURL,
DeviceAuthURL: oauth.Endpoint.DeviceAuthURL,
},
}
}
func (oauth Oauth) AuthMiddleware(next http.Handler) http.Handler { func (oauth Oauth) AuthMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
logger.Info("%s: %s Oauth", getRealIP(r), r.URL.Path) logger.Info("%s: %s Oauth", getRealIP(r), r.URL.Path)
oauthConfig := oauth2Config(oauth) oauthConf := oauth2Config(oauth)
// Check if the user is authenticated // Check if the user is authenticated
_, err := r.Cookie("oauth-token") token, err := r.Cookie("goma.JWT")
if err != nil { if err != nil {
// If no token, redirect to OAuth provider // If no token, redirect to OAuth provider
url := oauthConfig.AuthCodeURL(oauth.State) url := oauthConf.AuthCodeURL(oauth.State)
http.Redirect(w, r, url, http.StatusTemporaryRedirect)
return
}
ok, err := validateJWT(token.Value, oauth)
if err != nil {
// If no token, redirect to OAuth provider
url := oauthConf.AuthCodeURL(oauth.State)
http.Redirect(w, r, url, http.StatusTemporaryRedirect)
return
}
if !ok {
// If no token, redirect to OAuth provider
url := oauthConf.AuthCodeURL(oauth.State)
http.Redirect(w, r, url, http.StatusTemporaryRedirect) http.Redirect(w, r, url, http.StatusTemporaryRedirect)
return return
} }
//TODO: Check if the token stored in the cookie is valid
// Token exists, proceed with request // Token exists, proceed with request
next.ServeHTTP(w, r) next.ServeHTTP(w, r)
}) })
} }
func validateJWT(signedToken string, oauth Oauth) (bool, error) {
// Parse the JWT token and provide the key function
token, err := jwt.Parse(signedToken, func(token *jwt.Token) (interface{}, error) {
// Ensure the signing method is HMAC and specifically HS256
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
}
// Return the shared secret key for validation
return []byte(oauth.JWTSecret), nil
})
// If there's an error or token is invalid, return false
if err != nil || !token.Valid {
return false, fmt.Errorf("token is invalid: %v", err)
}
// Check if token claims are valid
if claims, ok := token.Claims.(jwt.MapClaims); ok && token.Valid {
// Optional: Check token expiration
if exp, ok := claims["exp"].(float64); ok {
if time.Unix(int64(exp), 0).Before(time.Now()) {
return false, fmt.Errorf("token has expired")
}
}
// Token is valid and not expired
return true, nil
}
return false, fmt.Errorf("token is invalid or missing claims")
}

View File

@@ -120,11 +120,13 @@ type Oauth struct {
// Scope specifies optional requested permissions. // Scope specifies optional requested permissions.
Scopes []string Scopes []string
// contains filtered or unexported fields // contains filtered or unexported fields
State string State string
Origins []string Origins []string
JWTSecret string
Provider string
} }
type OauthEndpoint struct { type OauthEndpoint struct {
AuthURL string AuthURL string
TokenURL string TokenURL string
DeviceAuthURL string UserInfoURL string
} }

View File

@@ -16,7 +16,6 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
*/ */
import ( import (
"encoding/json"
"fmt" "fmt"
"github.com/jkaninda/goma-gateway/pkg/logger" "github.com/jkaninda/goma-gateway/pkg/logger"
"net/http" "net/http"
@@ -44,18 +43,12 @@ func (proxyRoute ProxyRoute) ProxyHandler() http.HandlerFunc {
w.Header().Set(accessControlAllowOrigin, r.Header.Get("Origin")) w.Header().Set(accessControlAllowOrigin, r.Header.Get("Origin"))
} }
} }
// Parse the target backend URL // Parse the target backend URL
targetURL, err := url.Parse(proxyRoute.destination) targetURL, err := url.Parse(proxyRoute.destination)
if err != nil { if err != nil {
logger.Error("Error parsing backend URL: %s", err) logger.Error("Error parsing backend URL: %s", err)
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusInternalServerError) w.WriteHeader(http.StatusInternalServerError)
err := json.NewEncoder(w).Encode(ErrorResponse{ _, err := w.Write([]byte("Internal Server Error"))
Message: "Internal server error",
Code: http.StatusInternalServerError,
Success: false,
})
if err != nil { if err != nil {
return return
} }

View File

@@ -132,18 +132,24 @@ func (gatewayServer GatewayServer) Initialize() *mux.Router {
if err != nil { if err != nil {
logger.Error("Error: %s", err.Error()) logger.Error("Error: %s", err.Error())
} else { } else {
redirectURL := "/callback" + route.Path
if oauth.RedirectURL != "" {
redirectURL = oauth.RedirectURL
}
amw := middleware.Oauth{ amw := middleware.Oauth{
ClientID: oauth.ClientID, ClientID: oauth.ClientID,
ClientSecret: oauth.ClientSecret, ClientSecret: oauth.ClientSecret,
RedirectURL: oauth.RedirectURL + route.Path, RedirectURL: redirectURL,
Scopes: oauth.Scopes, Scopes: oauth.Scopes,
Endpoint: middleware.OauthEndpoint{ Endpoint: middleware.OauthEndpoint{
AuthURL: oauth.Endpoint.AuthURL, AuthURL: oauth.Endpoint.AuthURL,
TokenURL: oauth.Endpoint.TokenURL, TokenURL: oauth.Endpoint.TokenURL,
DeviceAuthURL: oauth.Endpoint.DeviceAuthURL, UserInfoURL: oauth.Endpoint.UserInfoURL,
}, },
State: oauth.State, State: oauth.State,
Origins: gateway.Cors.Origins, Origins: gateway.Cors.Origins,
JWTSecret: oauth.JWTSecret,
Provider: oauth.Provider,
} }
oauthRuler := oauthRulerMiddleware(amw) oauthRuler := oauthRulerMiddleware(amw)
// Check if a cookie path is defined // Check if a cookie path is defined
@@ -154,12 +160,15 @@ func (gatewayServer GatewayServer) Initialize() *mux.Router {
if oauthRuler.RedirectPath == "" { if oauthRuler.RedirectPath == "" {
oauthRuler.RedirectPath = util.ParseRoutePath(route.Path, midPath) oauthRuler.RedirectPath = util.ParseRoutePath(route.Path, midPath)
} }
if oauthRuler.Provider == "" {
oauthRuler.Provider = "custom"
}
secureRouter.Use(amw.AuthMiddleware) secureRouter.Use(amw.AuthMiddleware)
secureRouter.Use(CORSHandler(route.Cors)) secureRouter.Use(CORSHandler(route.Cors))
secureRouter.PathPrefix("/").Handler(proxyRoute.ProxyHandler()) // Proxy handler secureRouter.PathPrefix("/").Handler(proxyRoute.ProxyHandler()) // Proxy handler
secureRouter.PathPrefix("").Handler(proxyRoute.ProxyHandler()) // Proxy handler secureRouter.PathPrefix("").Handler(proxyRoute.ProxyHandler()) // Proxy handler
// Callback route // Callback route
r.HandleFunc("/callback"+route.Path, oauthRuler.callbackHandler).Methods("GET") r.HandleFunc(util.UrlParsePath(redirectURL), oauthRuler.callbackHandler).Methods("GET")
} }
default: default:
if !doesExist(rMiddleware.Type) { if !doesExist(rMiddleware.Type) {

View File

@@ -78,7 +78,8 @@ type OauthRulerMiddleware struct {
// ClientSecret is the application's secret. // ClientSecret is the application's secret.
ClientSecret string `yaml:"clientSecret"` ClientSecret string `yaml:"clientSecret"`
// oauth provider google, gitlab, github, amazon, facebook, custom
Provider string `yaml:"provider"`
// Endpoint contains the resource server's token endpoint // Endpoint contains the resource server's token endpoint
Endpoint OauthEndpoint `yaml:"endpoint"` Endpoint OauthEndpoint `yaml:"endpoint"`
@@ -93,12 +94,13 @@ type OauthRulerMiddleware struct {
// Scope specifies optional requested permissions. // Scope specifies optional requested permissions.
Scopes []string `yaml:"scopes"` Scopes []string `yaml:"scopes"`
// contains filtered or unexported fields // contains filtered or unexported fields
State string `yaml:"state"` State string `yaml:"state"`
JWTSecret string `yaml:"jwtSecret"`
} }
type OauthEndpoint struct { type OauthEndpoint struct {
AuthURL string `yaml:"authUrl"` AuthURL string `yaml:"authUrl"`
TokenURL string `yaml:"tokenUrl"` TokenURL string `yaml:"tokenUrl"`
DeviceAuthURL string `yaml:"deviceAuthUrl"` UserInfoURL string `yaml:"userInfoUrl"`
} }
type RateLimiter struct { type RateLimiter struct {
// ipBased, tokenBased // ipBased, tokenBased
@@ -242,3 +244,11 @@ type HealthCheckRouteResponse struct {
Status string `json:"status"` Status string `json:"status"`
Error string `json:"error"` Error string `json:"error"`
} }
type UserInfo struct {
Email string `json:"email"`
}
type JWTSecret struct {
ISS string `yaml:"iss"`
Secret string `yaml:"secret"`
}

View File

@@ -10,6 +10,7 @@ You may get a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
*/ */
import ( import (
"net/url"
"os" "os"
"strconv" "strconv"
"strings" "strings"
@@ -96,3 +97,11 @@ func ParseRoutePath(path, blockedPath string) string {
return basePath + blockedPath return basePath + blockedPath
} }
} }
func UrlParsePath(uri string) string {
parse, err := url.Parse(uri)
if err != nil {
return ""
}
return parse.Path
}