From 4d5af131c2de1c867bc3ef1d9063f36398cc02bc Mon Sep 17 00:00:00 2001 From: Niko Abeler Date: Sun, 6 Nov 2022 16:27:35 +0100 Subject: [PATCH] PKCE implementation --- cmd/owl/web/auth_test.go | 128 ++++++++++++++++++++++++++++++++++++++- cmd/owl/web/handler.go | 44 +++++++++----- embed/auth.html | 2 + renderer.go | 16 ++--- user.go | 40 +++++++++--- 5 files changed, 197 insertions(+), 33 deletions(-) diff --git a/cmd/owl/web/auth_test.go b/cmd/owl/web/auth_test.go index 0bf6e76..67d443d 100644 --- a/cmd/owl/web/auth_test.go +++ b/cmd/owl/web/auth_test.go @@ -1,6 +1,8 @@ package web_test import ( + "crypto/sha256" + "encoding/base64" "encoding/json" main "h4kor/owl-blogs/cmd/owl/web" "h4kor/owl-blogs/test/assertions" @@ -73,7 +75,7 @@ func TestAuthPostCorrectPassword(t *testing.T) { func TestAuthPostWithIncorrectCode(t *testing.T) { repo, user := getSingleUserTestRepo() user.ResetPassword("testpassword") - user.GenerateAuthCode("http://example.com", "http://example.com/response") + user.GenerateAuthCode("http://example.com", "http://example.com/response", "", "") // Create Request and Response form := url.Values{} @@ -95,7 +97,7 @@ func TestAuthPostWithIncorrectCode(t *testing.T) { func TestAuthPostWithCorrectCode(t *testing.T) { repo, user := getSingleUserTestRepo() user.ResetPassword("testpassword") - code, _ := user.GenerateAuthCode("http://example.com", "http://example.com/response") + code, _ := user.GenerateAuthCode("http://example.com", "http://example.com/response", "", "") // Create Request and Response form := url.Values{} @@ -123,6 +125,128 @@ func TestAuthPostWithCorrectCode(t *testing.T) { } +func TestAuthPostWithCorrectCodeAndPKCE(t *testing.T) { + repo, user := getSingleUserTestRepo() + user.ResetPassword("testpassword") + + // Create Request and Response + code_verifier := "test_code_verifier" + // create code challenge + h := sha256.New() + h.Write([]byte(code_verifier)) + code_challenge := base64.RawURLEncoding.EncodeToString(h.Sum(nil)) + code, _ := user.GenerateAuthCode("http://example.com", "http://example.com/response", code_challenge, "S256") + + form := url.Values{} + form.Add("code", code) + form.Add("client_id", "http://example.com") + form.Add("redirect_uri", "http://example.com/response") + form.Add("grant_type", "authorization_code") + form.Add("code_verifier", code_verifier) + req, err := http.NewRequest("POST", user.AuthUrl(), strings.NewReader(form.Encode())) + req.Header.Add("Content-Type", "application/x-www-form-urlencoded") + req.Header.Add("Content-Length", strconv.Itoa(len(form.Encode()))) + req.Header.Add("Accept", "application/json") + assertions.AssertNoError(t, err, "Error creating request") + rr := httptest.NewRecorder() + router := main.SingleUserRouter(&repo) + router.ServeHTTP(rr, req) + + assertions.AssertStatus(t, rr, http.StatusOK) + // parse response as json + type responseType struct { + Me string `json:"me"` + } + var response responseType + json.Unmarshal(rr.Body.Bytes(), &response) + assertions.AssertEqual(t, response.Me, user.FullUrl()) + +} + +func TestAuthPostWithCorrectCodeAndWrongPKCE(t *testing.T) { + repo, user := getSingleUserTestRepo() + user.ResetPassword("testpassword") + + // Create Request and Response + code_verifier := "test_code_verifier" + // create code challenge + h := sha256.New() + h.Write([]byte(code_verifier + "wrong")) + code_challenge := base64.RawURLEncoding.EncodeToString(h.Sum(nil)) + code, _ := user.GenerateAuthCode("http://example.com", "http://example.com/response", code_challenge, "S256") + + form := url.Values{} + form.Add("code", code) + form.Add("client_id", "http://example.com") + form.Add("redirect_uri", "http://example.com/response") + form.Add("grant_type", "authorization_code") + form.Add("code_verifier", code_verifier) + req, err := http.NewRequest("POST", user.AuthUrl(), strings.NewReader(form.Encode())) + req.Header.Add("Content-Type", "application/x-www-form-urlencoded") + req.Header.Add("Content-Length", strconv.Itoa(len(form.Encode()))) + req.Header.Add("Accept", "application/json") + assertions.AssertNoError(t, err, "Error creating request") + rr := httptest.NewRecorder() + router := main.SingleUserRouter(&repo) + router.ServeHTTP(rr, req) + + assertions.AssertStatus(t, rr, http.StatusUnauthorized) +} + +func TestAuthPostWithCorrectCodePKCEPlain(t *testing.T) { + repo, user := getSingleUserTestRepo() + user.ResetPassword("testpassword") + + // Create Request and Response + code_verifier := "test_code_verifier" + code_challenge := code_verifier + code, _ := user.GenerateAuthCode("http://example.com", "http://example.com/response", code_challenge, "plain") + + form := url.Values{} + form.Add("code", code) + form.Add("client_id", "http://example.com") + form.Add("redirect_uri", "http://example.com/response") + form.Add("grant_type", "authorization_code") + form.Add("code_verifier", code_verifier) + req, err := http.NewRequest("POST", user.AuthUrl(), strings.NewReader(form.Encode())) + req.Header.Add("Content-Type", "application/x-www-form-urlencoded") + req.Header.Add("Content-Length", strconv.Itoa(len(form.Encode()))) + req.Header.Add("Accept", "application/json") + assertions.AssertNoError(t, err, "Error creating request") + rr := httptest.NewRecorder() + router := main.SingleUserRouter(&repo) + router.ServeHTTP(rr, req) + + assertions.AssertStatus(t, rr, http.StatusOK) +} + +func TestAuthPostWithCorrectCodePKCEPlainWrong(t *testing.T) { + repo, user := getSingleUserTestRepo() + user.ResetPassword("testpassword") + + // Create Request and Response + code_verifier := "test_code_verifier" + code_challenge := code_verifier + "wrong" + code, _ := user.GenerateAuthCode("http://example.com", "http://example.com/response", code_challenge, "plain") + + form := url.Values{} + form.Add("code", code) + form.Add("client_id", "http://example.com") + form.Add("redirect_uri", "http://example.com/response") + form.Add("grant_type", "authorization_code") + form.Add("code_verifier", code_verifier) + req, err := http.NewRequest("POST", user.AuthUrl(), strings.NewReader(form.Encode())) + req.Header.Add("Content-Type", "application/x-www-form-urlencoded") + req.Header.Add("Content-Length", strconv.Itoa(len(form.Encode()))) + req.Header.Add("Accept", "application/json") + assertions.AssertNoError(t, err, "Error creating request") + rr := httptest.NewRecorder() + router := main.SingleUserRouter(&repo) + router.ServeHTTP(rr, req) + + assertions.AssertStatus(t, rr, http.StatusUnauthorized) +} + func TestAuthRedirectUriNotSet(t *testing.T) { repo, user := getSingleUserTestRepo() repo.HttpClient = &mocks.MockHttpClient{} diff --git a/cmd/owl/web/handler.go b/cmd/owl/web/handler.go index 62ad48f..bad57c9 100644 --- a/cmd/owl/web/handler.go +++ b/cmd/owl/web/handler.go @@ -74,6 +74,8 @@ func userAuthHandler(repo *owl.Repository) func(http.ResponseWriter, *http.Reque redirectUri := r.URL.Query().Get("redirect_uri") state := r.URL.Query().Get("state") responseType := r.URL.Query().Get("response_type") + codeChallenge := r.URL.Query().Get("code_challenge") + codeChallengeMethod := r.URL.Query().Get("code_challenge_method") // check if request is valid missing_params := []string{} @@ -99,6 +101,11 @@ func userAuthHandler(repo *owl.Repository) func(http.ResponseWriter, *http.Reque w.Write([]byte("Invalid response_type. Must be 'code' ('id' converted to 'code' for legacy support).")) return } + if codeChallengeMethod != "" && (codeChallengeMethod != "S256" && codeChallengeMethod != "plain") { + w.WriteHeader(http.StatusBadRequest) + w.Write([]byte("Invalid code_challenge_method. Must be 'S256' or 'plain'.")) + return + } // check if redirect_uri is registered resp, _ := repo.HttpClient.Get(clientId) @@ -127,13 +134,15 @@ func userAuthHandler(repo *owl.Repository) func(http.ResponseWriter, *http.Reque http.SetCookie(w, &cookie) reqData := owl.AuthRequestData{ - Me: me, - ClientId: clientId, - RedirectUri: redirectUri, - State: state, - ResponseType: responseType, - User: user, - CsrfToken: csrfToken, + Me: me, + ClientId: clientId, + RedirectUri: redirectUri, + State: state, + ResponseType: responseType, + CodeChallenge: codeChallenge, + CodeChallengeMethod: codeChallengeMethod, + User: user, + CsrfToken: csrfToken, } html, err := owl.RenderUserAuthPage(reqData) @@ -168,9 +177,10 @@ func userAuthProfileHandler(repo *owl.Repository) func(http.ResponseWriter, *htt code := r.Form.Get("code") client_id := r.Form.Get("client_id") redirect_uri := r.Form.Get("redirect_uri") + code_verifier := r.Form.Get("code_verifier") // check if request is valid - valid := user.VerifyAuthCode(code, client_id, redirect_uri) + valid := user.VerifyAuthCode(code, client_id, redirect_uri, code_verifier) if !valid { w.WriteHeader(http.StatusUnauthorized) w.Write([]byte("Invalid code")) @@ -225,11 +235,12 @@ func userAuthVerifyHandler(repo *owl.Repository) func(http.ResponseWriter, *http return } password := r.FormValue("password") - println("Password: ", password) client_id := r.FormValue("client_id") redirect_uri := r.FormValue("redirect_uri") response_type := r.FormValue("response_type") state := r.FormValue("state") + code_challenge := r.FormValue("code_challenge") + code_challenge_method := r.FormValue("code_challenge_method") // CSRF check formCsrfToken := r.FormValue("csrf_token") @@ -250,17 +261,22 @@ func userAuthVerifyHandler(repo *owl.Repository) func(http.ResponseWriter, *http password_valid := user.VerifyPassword(password) if !password_valid { + redirect := fmt.Sprintf( + "%s?error=invalid_password&client_id=%s&redirect_uri=%s&response_type=%s&state=%s", + user.AuthUrl(), client_id, redirect_uri, response_type, state, + ) + if code_challenge != "" { + redirect += fmt.Sprintf("&code_challenge=%s&code_challenge_method=%s", code_challenge, code_challenge_method) + } http.Redirect(w, r, - fmt.Sprintf( - "%s?error=invalid_password&client_id=%s&redirect_uri=%s&response_type=%s&state=%s", - user.AuthUrl(), client_id, redirect_uri, response_type, state, - ), + redirect, http.StatusFound, ) return } else { // password is valid, generate code - code, err := user.GenerateAuthCode(client_id, redirect_uri) + code, err := user.GenerateAuthCode( + client_id, redirect_uri, code_challenge, code_challenge_method) if err != nil { println("Error generating code: ", err.Error()) w.WriteHeader(http.StatusInternalServerError) diff --git a/embed/auth.html b/embed/auth.html index 03884ce..d2f322a 100644 --- a/embed/auth.html +++ b/embed/auth.html @@ -8,5 +8,7 @@ + + \ No newline at end of file diff --git a/renderer.go b/renderer.go index 03e5540..b72bf7d 100644 --- a/renderer.go +++ b/renderer.go @@ -21,13 +21,15 @@ type PostRenderData struct { } type AuthRequestData struct { - Me string - ClientId string - RedirectUri string - State string - ResponseType string - User User - CsrfToken string + Me string + ClientId string + RedirectUri string + State string + ResponseType string + CodeChallenge string + CodeChallengeMethod string + User User + CsrfToken string } func renderEmbedTemplate(templateFile string, data interface{}) (string, error) { diff --git a/user.go b/user.go index 4141814..8b042ef 100644 --- a/user.go +++ b/user.go @@ -1,6 +1,8 @@ package owl import ( + "crypto/sha256" + "encoding/base64" "fmt" "net/url" "os" @@ -32,10 +34,12 @@ type UserMe struct { } type AuthCode struct { - Code string `yaml:"code"` - ClientId string `yaml:"client_id"` - RedirectUri string `yaml:"redirect_uri"` - Created time.Time `yaml:"created"` + Code string `yaml:"code"` + ClientId string `yaml:"client_id"` + RedirectUri string `yaml:"redirect_uri"` + CodeChallenge string `yaml:"code_challenge"` + CodeChallengeMethod string `yaml:"code_challenge_method"` + Created time.Time `yaml:"created"` } func (user User) Dir() string { @@ -284,21 +288,37 @@ func (user User) addAuthCode(code AuthCode) error { return saveToYaml(user.AuthCodesFile(), codes) } -func (user User) GenerateAuthCode(client_id string, redirect_uri string) (string, error) { +func (user User) GenerateAuthCode( + client_id string, redirect_uri string, + code_challenge string, code_challenge_method string, +) (string, error) { // generate code code := GenerateRandomString(32) return code, user.addAuthCode(AuthCode{ - Code: code, - ClientId: client_id, - RedirectUri: redirect_uri, + Code: code, + ClientId: client_id, + RedirectUri: redirect_uri, + CodeChallenge: code_challenge, + CodeChallengeMethod: code_challenge_method, + Created: time.Now(), }) } -func (user User) VerifyAuthCode(code string, client_id string, redirect_uri string) bool { +func (user User) VerifyAuthCode( + code string, client_id string, redirect_uri string, code_verifier string, +) bool { codes := user.getAuthCodes() for _, c := range codes { if c.Code == code && c.ClientId == client_id && c.RedirectUri == redirect_uri { - return true + if c.CodeChallengeMethod == "plain" { + return c.CodeChallenge == code_verifier + } else if c.CodeChallengeMethod == "S256" { + // hash code_verifier + hash := sha256.Sum256([]byte(code_verifier)) + return c.CodeChallenge == base64.RawURLEncoding.EncodeToString(hash[:]) + } else if c.CodeChallengeMethod == "" { + return true + } } } return false