diff --git a/cmd/owl/web/auth_test.go b/cmd/owl/web/auth_test.go
index 0bf6e76..67d443d 100644
--- a/cmd/owl/web/auth_test.go
+++ b/cmd/owl/web/auth_test.go
@@ -1,6 +1,8 @@
package web_test
import (
+ "crypto/sha256"
+ "encoding/base64"
"encoding/json"
main "h4kor/owl-blogs/cmd/owl/web"
"h4kor/owl-blogs/test/assertions"
@@ -73,7 +75,7 @@ func TestAuthPostCorrectPassword(t *testing.T) {
func TestAuthPostWithIncorrectCode(t *testing.T) {
repo, user := getSingleUserTestRepo()
user.ResetPassword("testpassword")
- user.GenerateAuthCode("http://example.com", "http://example.com/response")
+ user.GenerateAuthCode("http://example.com", "http://example.com/response", "", "")
// Create Request and Response
form := url.Values{}
@@ -95,7 +97,7 @@ func TestAuthPostWithIncorrectCode(t *testing.T) {
func TestAuthPostWithCorrectCode(t *testing.T) {
repo, user := getSingleUserTestRepo()
user.ResetPassword("testpassword")
- code, _ := user.GenerateAuthCode("http://example.com", "http://example.com/response")
+ code, _ := user.GenerateAuthCode("http://example.com", "http://example.com/response", "", "")
// Create Request and Response
form := url.Values{}
@@ -123,6 +125,128 @@ func TestAuthPostWithCorrectCode(t *testing.T) {
}
+func TestAuthPostWithCorrectCodeAndPKCE(t *testing.T) {
+ repo, user := getSingleUserTestRepo()
+ user.ResetPassword("testpassword")
+
+ // Create Request and Response
+ code_verifier := "test_code_verifier"
+ // create code challenge
+ h := sha256.New()
+ h.Write([]byte(code_verifier))
+ code_challenge := base64.RawURLEncoding.EncodeToString(h.Sum(nil))
+ code, _ := user.GenerateAuthCode("http://example.com", "http://example.com/response", code_challenge, "S256")
+
+ form := url.Values{}
+ form.Add("code", code)
+ form.Add("client_id", "http://example.com")
+ form.Add("redirect_uri", "http://example.com/response")
+ form.Add("grant_type", "authorization_code")
+ form.Add("code_verifier", code_verifier)
+ req, err := http.NewRequest("POST", user.AuthUrl(), strings.NewReader(form.Encode()))
+ req.Header.Add("Content-Type", "application/x-www-form-urlencoded")
+ req.Header.Add("Content-Length", strconv.Itoa(len(form.Encode())))
+ req.Header.Add("Accept", "application/json")
+ assertions.AssertNoError(t, err, "Error creating request")
+ rr := httptest.NewRecorder()
+ router := main.SingleUserRouter(&repo)
+ router.ServeHTTP(rr, req)
+
+ assertions.AssertStatus(t, rr, http.StatusOK)
+ // parse response as json
+ type responseType struct {
+ Me string `json:"me"`
+ }
+ var response responseType
+ json.Unmarshal(rr.Body.Bytes(), &response)
+ assertions.AssertEqual(t, response.Me, user.FullUrl())
+
+}
+
+func TestAuthPostWithCorrectCodeAndWrongPKCE(t *testing.T) {
+ repo, user := getSingleUserTestRepo()
+ user.ResetPassword("testpassword")
+
+ // Create Request and Response
+ code_verifier := "test_code_verifier"
+ // create code challenge
+ h := sha256.New()
+ h.Write([]byte(code_verifier + "wrong"))
+ code_challenge := base64.RawURLEncoding.EncodeToString(h.Sum(nil))
+ code, _ := user.GenerateAuthCode("http://example.com", "http://example.com/response", code_challenge, "S256")
+
+ form := url.Values{}
+ form.Add("code", code)
+ form.Add("client_id", "http://example.com")
+ form.Add("redirect_uri", "http://example.com/response")
+ form.Add("grant_type", "authorization_code")
+ form.Add("code_verifier", code_verifier)
+ req, err := http.NewRequest("POST", user.AuthUrl(), strings.NewReader(form.Encode()))
+ req.Header.Add("Content-Type", "application/x-www-form-urlencoded")
+ req.Header.Add("Content-Length", strconv.Itoa(len(form.Encode())))
+ req.Header.Add("Accept", "application/json")
+ assertions.AssertNoError(t, err, "Error creating request")
+ rr := httptest.NewRecorder()
+ router := main.SingleUserRouter(&repo)
+ router.ServeHTTP(rr, req)
+
+ assertions.AssertStatus(t, rr, http.StatusUnauthorized)
+}
+
+func TestAuthPostWithCorrectCodePKCEPlain(t *testing.T) {
+ repo, user := getSingleUserTestRepo()
+ user.ResetPassword("testpassword")
+
+ // Create Request and Response
+ code_verifier := "test_code_verifier"
+ code_challenge := code_verifier
+ code, _ := user.GenerateAuthCode("http://example.com", "http://example.com/response", code_challenge, "plain")
+
+ form := url.Values{}
+ form.Add("code", code)
+ form.Add("client_id", "http://example.com")
+ form.Add("redirect_uri", "http://example.com/response")
+ form.Add("grant_type", "authorization_code")
+ form.Add("code_verifier", code_verifier)
+ req, err := http.NewRequest("POST", user.AuthUrl(), strings.NewReader(form.Encode()))
+ req.Header.Add("Content-Type", "application/x-www-form-urlencoded")
+ req.Header.Add("Content-Length", strconv.Itoa(len(form.Encode())))
+ req.Header.Add("Accept", "application/json")
+ assertions.AssertNoError(t, err, "Error creating request")
+ rr := httptest.NewRecorder()
+ router := main.SingleUserRouter(&repo)
+ router.ServeHTTP(rr, req)
+
+ assertions.AssertStatus(t, rr, http.StatusOK)
+}
+
+func TestAuthPostWithCorrectCodePKCEPlainWrong(t *testing.T) {
+ repo, user := getSingleUserTestRepo()
+ user.ResetPassword("testpassword")
+
+ // Create Request and Response
+ code_verifier := "test_code_verifier"
+ code_challenge := code_verifier + "wrong"
+ code, _ := user.GenerateAuthCode("http://example.com", "http://example.com/response", code_challenge, "plain")
+
+ form := url.Values{}
+ form.Add("code", code)
+ form.Add("client_id", "http://example.com")
+ form.Add("redirect_uri", "http://example.com/response")
+ form.Add("grant_type", "authorization_code")
+ form.Add("code_verifier", code_verifier)
+ req, err := http.NewRequest("POST", user.AuthUrl(), strings.NewReader(form.Encode()))
+ req.Header.Add("Content-Type", "application/x-www-form-urlencoded")
+ req.Header.Add("Content-Length", strconv.Itoa(len(form.Encode())))
+ req.Header.Add("Accept", "application/json")
+ assertions.AssertNoError(t, err, "Error creating request")
+ rr := httptest.NewRecorder()
+ router := main.SingleUserRouter(&repo)
+ router.ServeHTTP(rr, req)
+
+ assertions.AssertStatus(t, rr, http.StatusUnauthorized)
+}
+
func TestAuthRedirectUriNotSet(t *testing.T) {
repo, user := getSingleUserTestRepo()
repo.HttpClient = &mocks.MockHttpClient{}
diff --git a/cmd/owl/web/handler.go b/cmd/owl/web/handler.go
index 62ad48f..bad57c9 100644
--- a/cmd/owl/web/handler.go
+++ b/cmd/owl/web/handler.go
@@ -74,6 +74,8 @@ func userAuthHandler(repo *owl.Repository) func(http.ResponseWriter, *http.Reque
redirectUri := r.URL.Query().Get("redirect_uri")
state := r.URL.Query().Get("state")
responseType := r.URL.Query().Get("response_type")
+ codeChallenge := r.URL.Query().Get("code_challenge")
+ codeChallengeMethod := r.URL.Query().Get("code_challenge_method")
// check if request is valid
missing_params := []string{}
@@ -99,6 +101,11 @@ func userAuthHandler(repo *owl.Repository) func(http.ResponseWriter, *http.Reque
w.Write([]byte("Invalid response_type. Must be 'code' ('id' converted to 'code' for legacy support)."))
return
}
+ if codeChallengeMethod != "" && (codeChallengeMethod != "S256" && codeChallengeMethod != "plain") {
+ w.WriteHeader(http.StatusBadRequest)
+ w.Write([]byte("Invalid code_challenge_method. Must be 'S256' or 'plain'."))
+ return
+ }
// check if redirect_uri is registered
resp, _ := repo.HttpClient.Get(clientId)
@@ -127,13 +134,15 @@ func userAuthHandler(repo *owl.Repository) func(http.ResponseWriter, *http.Reque
http.SetCookie(w, &cookie)
reqData := owl.AuthRequestData{
- Me: me,
- ClientId: clientId,
- RedirectUri: redirectUri,
- State: state,
- ResponseType: responseType,
- User: user,
- CsrfToken: csrfToken,
+ Me: me,
+ ClientId: clientId,
+ RedirectUri: redirectUri,
+ State: state,
+ ResponseType: responseType,
+ CodeChallenge: codeChallenge,
+ CodeChallengeMethod: codeChallengeMethod,
+ User: user,
+ CsrfToken: csrfToken,
}
html, err := owl.RenderUserAuthPage(reqData)
@@ -168,9 +177,10 @@ func userAuthProfileHandler(repo *owl.Repository) func(http.ResponseWriter, *htt
code := r.Form.Get("code")
client_id := r.Form.Get("client_id")
redirect_uri := r.Form.Get("redirect_uri")
+ code_verifier := r.Form.Get("code_verifier")
// check if request is valid
- valid := user.VerifyAuthCode(code, client_id, redirect_uri)
+ valid := user.VerifyAuthCode(code, client_id, redirect_uri, code_verifier)
if !valid {
w.WriteHeader(http.StatusUnauthorized)
w.Write([]byte("Invalid code"))
@@ -225,11 +235,12 @@ func userAuthVerifyHandler(repo *owl.Repository) func(http.ResponseWriter, *http
return
}
password := r.FormValue("password")
- println("Password: ", password)
client_id := r.FormValue("client_id")
redirect_uri := r.FormValue("redirect_uri")
response_type := r.FormValue("response_type")
state := r.FormValue("state")
+ code_challenge := r.FormValue("code_challenge")
+ code_challenge_method := r.FormValue("code_challenge_method")
// CSRF check
formCsrfToken := r.FormValue("csrf_token")
@@ -250,17 +261,22 @@ func userAuthVerifyHandler(repo *owl.Repository) func(http.ResponseWriter, *http
password_valid := user.VerifyPassword(password)
if !password_valid {
+ redirect := fmt.Sprintf(
+ "%s?error=invalid_password&client_id=%s&redirect_uri=%s&response_type=%s&state=%s",
+ user.AuthUrl(), client_id, redirect_uri, response_type, state,
+ )
+ if code_challenge != "" {
+ redirect += fmt.Sprintf("&code_challenge=%s&code_challenge_method=%s", code_challenge, code_challenge_method)
+ }
http.Redirect(w, r,
- fmt.Sprintf(
- "%s?error=invalid_password&client_id=%s&redirect_uri=%s&response_type=%s&state=%s",
- user.AuthUrl(), client_id, redirect_uri, response_type, state,
- ),
+ redirect,
http.StatusFound,
)
return
} else {
// password is valid, generate code
- code, err := user.GenerateAuthCode(client_id, redirect_uri)
+ code, err := user.GenerateAuthCode(
+ client_id, redirect_uri, code_challenge, code_challenge_method)
if err != nil {
println("Error generating code: ", err.Error())
w.WriteHeader(http.StatusInternalServerError)
diff --git a/embed/auth.html b/embed/auth.html
index 03884ce..d2f322a 100644
--- a/embed/auth.html
+++ b/embed/auth.html
@@ -8,5 +8,7 @@
+
+
\ No newline at end of file
diff --git a/renderer.go b/renderer.go
index 03e5540..b72bf7d 100644
--- a/renderer.go
+++ b/renderer.go
@@ -21,13 +21,15 @@ type PostRenderData struct {
}
type AuthRequestData struct {
- Me string
- ClientId string
- RedirectUri string
- State string
- ResponseType string
- User User
- CsrfToken string
+ Me string
+ ClientId string
+ RedirectUri string
+ State string
+ ResponseType string
+ CodeChallenge string
+ CodeChallengeMethod string
+ User User
+ CsrfToken string
}
func renderEmbedTemplate(templateFile string, data interface{}) (string, error) {
diff --git a/user.go b/user.go
index 4141814..8b042ef 100644
--- a/user.go
+++ b/user.go
@@ -1,6 +1,8 @@
package owl
import (
+ "crypto/sha256"
+ "encoding/base64"
"fmt"
"net/url"
"os"
@@ -32,10 +34,12 @@ type UserMe struct {
}
type AuthCode struct {
- Code string `yaml:"code"`
- ClientId string `yaml:"client_id"`
- RedirectUri string `yaml:"redirect_uri"`
- Created time.Time `yaml:"created"`
+ Code string `yaml:"code"`
+ ClientId string `yaml:"client_id"`
+ RedirectUri string `yaml:"redirect_uri"`
+ CodeChallenge string `yaml:"code_challenge"`
+ CodeChallengeMethod string `yaml:"code_challenge_method"`
+ Created time.Time `yaml:"created"`
}
func (user User) Dir() string {
@@ -284,21 +288,37 @@ func (user User) addAuthCode(code AuthCode) error {
return saveToYaml(user.AuthCodesFile(), codes)
}
-func (user User) GenerateAuthCode(client_id string, redirect_uri string) (string, error) {
+func (user User) GenerateAuthCode(
+ client_id string, redirect_uri string,
+ code_challenge string, code_challenge_method string,
+) (string, error) {
// generate code
code := GenerateRandomString(32)
return code, user.addAuthCode(AuthCode{
- Code: code,
- ClientId: client_id,
- RedirectUri: redirect_uri,
+ Code: code,
+ ClientId: client_id,
+ RedirectUri: redirect_uri,
+ CodeChallenge: code_challenge,
+ CodeChallengeMethod: code_challenge_method,
+ Created: time.Now(),
})
}
-func (user User) VerifyAuthCode(code string, client_id string, redirect_uri string) bool {
+func (user User) VerifyAuthCode(
+ code string, client_id string, redirect_uri string, code_verifier string,
+) bool {
codes := user.getAuthCodes()
for _, c := range codes {
if c.Code == code && c.ClientId == client_id && c.RedirectUri == redirect_uri {
- return true
+ if c.CodeChallengeMethod == "plain" {
+ return c.CodeChallenge == code_verifier
+ } else if c.CodeChallengeMethod == "S256" {
+ // hash code_verifier
+ hash := sha256.Sum256([]byte(code_verifier))
+ return c.CodeChallenge == base64.RawURLEncoding.EncodeToString(hash[:])
+ } else if c.CodeChallengeMethod == "" {
+ return true
+ }
}
}
return false