PKCE implementation

This commit is contained in:
Niko Abeler 2022-11-06 16:27:35 +01:00
parent ae387f3d7d
commit 4d5af131c2
5 changed files with 197 additions and 33 deletions

View File

@ -1,6 +1,8 @@
package web_test package web_test
import ( import (
"crypto/sha256"
"encoding/base64"
"encoding/json" "encoding/json"
main "h4kor/owl-blogs/cmd/owl/web" main "h4kor/owl-blogs/cmd/owl/web"
"h4kor/owl-blogs/test/assertions" "h4kor/owl-blogs/test/assertions"
@ -73,7 +75,7 @@ func TestAuthPostCorrectPassword(t *testing.T) {
func TestAuthPostWithIncorrectCode(t *testing.T) { func TestAuthPostWithIncorrectCode(t *testing.T) {
repo, user := getSingleUserTestRepo() repo, user := getSingleUserTestRepo()
user.ResetPassword("testpassword") user.ResetPassword("testpassword")
user.GenerateAuthCode("http://example.com", "http://example.com/response") user.GenerateAuthCode("http://example.com", "http://example.com/response", "", "")
// Create Request and Response // Create Request and Response
form := url.Values{} form := url.Values{}
@ -95,7 +97,7 @@ func TestAuthPostWithIncorrectCode(t *testing.T) {
func TestAuthPostWithCorrectCode(t *testing.T) { func TestAuthPostWithCorrectCode(t *testing.T) {
repo, user := getSingleUserTestRepo() repo, user := getSingleUserTestRepo()
user.ResetPassword("testpassword") user.ResetPassword("testpassword")
code, _ := user.GenerateAuthCode("http://example.com", "http://example.com/response") code, _ := user.GenerateAuthCode("http://example.com", "http://example.com/response", "", "")
// Create Request and Response // Create Request and Response
form := url.Values{} form := url.Values{}
@ -123,6 +125,128 @@ func TestAuthPostWithCorrectCode(t *testing.T) {
} }
func TestAuthPostWithCorrectCodeAndPKCE(t *testing.T) {
repo, user := getSingleUserTestRepo()
user.ResetPassword("testpassword")
// Create Request and Response
code_verifier := "test_code_verifier"
// create code challenge
h := sha256.New()
h.Write([]byte(code_verifier))
code_challenge := base64.RawURLEncoding.EncodeToString(h.Sum(nil))
code, _ := user.GenerateAuthCode("http://example.com", "http://example.com/response", code_challenge, "S256")
form := url.Values{}
form.Add("code", code)
form.Add("client_id", "http://example.com")
form.Add("redirect_uri", "http://example.com/response")
form.Add("grant_type", "authorization_code")
form.Add("code_verifier", code_verifier)
req, err := http.NewRequest("POST", user.AuthUrl(), strings.NewReader(form.Encode()))
req.Header.Add("Content-Type", "application/x-www-form-urlencoded")
req.Header.Add("Content-Length", strconv.Itoa(len(form.Encode())))
req.Header.Add("Accept", "application/json")
assertions.AssertNoError(t, err, "Error creating request")
rr := httptest.NewRecorder()
router := main.SingleUserRouter(&repo)
router.ServeHTTP(rr, req)
assertions.AssertStatus(t, rr, http.StatusOK)
// parse response as json
type responseType struct {
Me string `json:"me"`
}
var response responseType
json.Unmarshal(rr.Body.Bytes(), &response)
assertions.AssertEqual(t, response.Me, user.FullUrl())
}
func TestAuthPostWithCorrectCodeAndWrongPKCE(t *testing.T) {
repo, user := getSingleUserTestRepo()
user.ResetPassword("testpassword")
// Create Request and Response
code_verifier := "test_code_verifier"
// create code challenge
h := sha256.New()
h.Write([]byte(code_verifier + "wrong"))
code_challenge := base64.RawURLEncoding.EncodeToString(h.Sum(nil))
code, _ := user.GenerateAuthCode("http://example.com", "http://example.com/response", code_challenge, "S256")
form := url.Values{}
form.Add("code", code)
form.Add("client_id", "http://example.com")
form.Add("redirect_uri", "http://example.com/response")
form.Add("grant_type", "authorization_code")
form.Add("code_verifier", code_verifier)
req, err := http.NewRequest("POST", user.AuthUrl(), strings.NewReader(form.Encode()))
req.Header.Add("Content-Type", "application/x-www-form-urlencoded")
req.Header.Add("Content-Length", strconv.Itoa(len(form.Encode())))
req.Header.Add("Accept", "application/json")
assertions.AssertNoError(t, err, "Error creating request")
rr := httptest.NewRecorder()
router := main.SingleUserRouter(&repo)
router.ServeHTTP(rr, req)
assertions.AssertStatus(t, rr, http.StatusUnauthorized)
}
func TestAuthPostWithCorrectCodePKCEPlain(t *testing.T) {
repo, user := getSingleUserTestRepo()
user.ResetPassword("testpassword")
// Create Request and Response
code_verifier := "test_code_verifier"
code_challenge := code_verifier
code, _ := user.GenerateAuthCode("http://example.com", "http://example.com/response", code_challenge, "plain")
form := url.Values{}
form.Add("code", code)
form.Add("client_id", "http://example.com")
form.Add("redirect_uri", "http://example.com/response")
form.Add("grant_type", "authorization_code")
form.Add("code_verifier", code_verifier)
req, err := http.NewRequest("POST", user.AuthUrl(), strings.NewReader(form.Encode()))
req.Header.Add("Content-Type", "application/x-www-form-urlencoded")
req.Header.Add("Content-Length", strconv.Itoa(len(form.Encode())))
req.Header.Add("Accept", "application/json")
assertions.AssertNoError(t, err, "Error creating request")
rr := httptest.NewRecorder()
router := main.SingleUserRouter(&repo)
router.ServeHTTP(rr, req)
assertions.AssertStatus(t, rr, http.StatusOK)
}
func TestAuthPostWithCorrectCodePKCEPlainWrong(t *testing.T) {
repo, user := getSingleUserTestRepo()
user.ResetPassword("testpassword")
// Create Request and Response
code_verifier := "test_code_verifier"
code_challenge := code_verifier + "wrong"
code, _ := user.GenerateAuthCode("http://example.com", "http://example.com/response", code_challenge, "plain")
form := url.Values{}
form.Add("code", code)
form.Add("client_id", "http://example.com")
form.Add("redirect_uri", "http://example.com/response")
form.Add("grant_type", "authorization_code")
form.Add("code_verifier", code_verifier)
req, err := http.NewRequest("POST", user.AuthUrl(), strings.NewReader(form.Encode()))
req.Header.Add("Content-Type", "application/x-www-form-urlencoded")
req.Header.Add("Content-Length", strconv.Itoa(len(form.Encode())))
req.Header.Add("Accept", "application/json")
assertions.AssertNoError(t, err, "Error creating request")
rr := httptest.NewRecorder()
router := main.SingleUserRouter(&repo)
router.ServeHTTP(rr, req)
assertions.AssertStatus(t, rr, http.StatusUnauthorized)
}
func TestAuthRedirectUriNotSet(t *testing.T) { func TestAuthRedirectUriNotSet(t *testing.T) {
repo, user := getSingleUserTestRepo() repo, user := getSingleUserTestRepo()
repo.HttpClient = &mocks.MockHttpClient{} repo.HttpClient = &mocks.MockHttpClient{}

View File

@ -74,6 +74,8 @@ func userAuthHandler(repo *owl.Repository) func(http.ResponseWriter, *http.Reque
redirectUri := r.URL.Query().Get("redirect_uri") redirectUri := r.URL.Query().Get("redirect_uri")
state := r.URL.Query().Get("state") state := r.URL.Query().Get("state")
responseType := r.URL.Query().Get("response_type") responseType := r.URL.Query().Get("response_type")
codeChallenge := r.URL.Query().Get("code_challenge")
codeChallengeMethod := r.URL.Query().Get("code_challenge_method")
// check if request is valid // check if request is valid
missing_params := []string{} missing_params := []string{}
@ -99,6 +101,11 @@ func userAuthHandler(repo *owl.Repository) func(http.ResponseWriter, *http.Reque
w.Write([]byte("Invalid response_type. Must be 'code' ('id' converted to 'code' for legacy support).")) w.Write([]byte("Invalid response_type. Must be 'code' ('id' converted to 'code' for legacy support)."))
return return
} }
if codeChallengeMethod != "" && (codeChallengeMethod != "S256" && codeChallengeMethod != "plain") {
w.WriteHeader(http.StatusBadRequest)
w.Write([]byte("Invalid code_challenge_method. Must be 'S256' or 'plain'."))
return
}
// check if redirect_uri is registered // check if redirect_uri is registered
resp, _ := repo.HttpClient.Get(clientId) resp, _ := repo.HttpClient.Get(clientId)
@ -127,13 +134,15 @@ func userAuthHandler(repo *owl.Repository) func(http.ResponseWriter, *http.Reque
http.SetCookie(w, &cookie) http.SetCookie(w, &cookie)
reqData := owl.AuthRequestData{ reqData := owl.AuthRequestData{
Me: me, Me: me,
ClientId: clientId, ClientId: clientId,
RedirectUri: redirectUri, RedirectUri: redirectUri,
State: state, State: state,
ResponseType: responseType, ResponseType: responseType,
User: user, CodeChallenge: codeChallenge,
CsrfToken: csrfToken, CodeChallengeMethod: codeChallengeMethod,
User: user,
CsrfToken: csrfToken,
} }
html, err := owl.RenderUserAuthPage(reqData) html, err := owl.RenderUserAuthPage(reqData)
@ -168,9 +177,10 @@ func userAuthProfileHandler(repo *owl.Repository) func(http.ResponseWriter, *htt
code := r.Form.Get("code") code := r.Form.Get("code")
client_id := r.Form.Get("client_id") client_id := r.Form.Get("client_id")
redirect_uri := r.Form.Get("redirect_uri") redirect_uri := r.Form.Get("redirect_uri")
code_verifier := r.Form.Get("code_verifier")
// check if request is valid // check if request is valid
valid := user.VerifyAuthCode(code, client_id, redirect_uri) valid := user.VerifyAuthCode(code, client_id, redirect_uri, code_verifier)
if !valid { if !valid {
w.WriteHeader(http.StatusUnauthorized) w.WriteHeader(http.StatusUnauthorized)
w.Write([]byte("Invalid code")) w.Write([]byte("Invalid code"))
@ -225,11 +235,12 @@ func userAuthVerifyHandler(repo *owl.Repository) func(http.ResponseWriter, *http
return return
} }
password := r.FormValue("password") password := r.FormValue("password")
println("Password: ", password)
client_id := r.FormValue("client_id") client_id := r.FormValue("client_id")
redirect_uri := r.FormValue("redirect_uri") redirect_uri := r.FormValue("redirect_uri")
response_type := r.FormValue("response_type") response_type := r.FormValue("response_type")
state := r.FormValue("state") state := r.FormValue("state")
code_challenge := r.FormValue("code_challenge")
code_challenge_method := r.FormValue("code_challenge_method")
// CSRF check // CSRF check
formCsrfToken := r.FormValue("csrf_token") formCsrfToken := r.FormValue("csrf_token")
@ -250,17 +261,22 @@ func userAuthVerifyHandler(repo *owl.Repository) func(http.ResponseWriter, *http
password_valid := user.VerifyPassword(password) password_valid := user.VerifyPassword(password)
if !password_valid { if !password_valid {
redirect := fmt.Sprintf(
"%s?error=invalid_password&client_id=%s&redirect_uri=%s&response_type=%s&state=%s",
user.AuthUrl(), client_id, redirect_uri, response_type, state,
)
if code_challenge != "" {
redirect += fmt.Sprintf("&code_challenge=%s&code_challenge_method=%s", code_challenge, code_challenge_method)
}
http.Redirect(w, r, http.Redirect(w, r,
fmt.Sprintf( redirect,
"%s?error=invalid_password&client_id=%s&redirect_uri=%s&response_type=%s&state=%s",
user.AuthUrl(), client_id, redirect_uri, response_type, state,
),
http.StatusFound, http.StatusFound,
) )
return return
} else { } else {
// password is valid, generate code // password is valid, generate code
code, err := user.GenerateAuthCode(client_id, redirect_uri) code, err := user.GenerateAuthCode(
client_id, redirect_uri, code_challenge, code_challenge_method)
if err != nil { if err != nil {
println("Error generating code: ", err.Error()) println("Error generating code: ", err.Error())
w.WriteHeader(http.StatusInternalServerError) w.WriteHeader(http.StatusInternalServerError)

View File

@ -8,5 +8,7 @@
<input type="hidden" name="response_type" value="{{.ResponseType}}"> <input type="hidden" name="response_type" value="{{.ResponseType}}">
<input type="hidden" name="state" value="{{.State}}"> <input type="hidden" name="state" value="{{.State}}">
<input type="hidden" name="csrf_token" value="{{.CsrfToken}}"> <input type="hidden" name="csrf_token" value="{{.CsrfToken}}">
<input type="hidden" name="code_challenge" value="{{.CodeChallenge}}">
<input type="hidden" name="code_challenge_method" value="{{.CodeChallengeMethod}}">
<input type="submit" value="Login"> <input type="submit" value="Login">
</form> </form>

View File

@ -21,13 +21,15 @@ type PostRenderData struct {
} }
type AuthRequestData struct { type AuthRequestData struct {
Me string Me string
ClientId string ClientId string
RedirectUri string RedirectUri string
State string State string
ResponseType string ResponseType string
User User CodeChallenge string
CsrfToken string CodeChallengeMethod string
User User
CsrfToken string
} }
func renderEmbedTemplate(templateFile string, data interface{}) (string, error) { func renderEmbedTemplate(templateFile string, data interface{}) (string, error) {

40
user.go
View File

@ -1,6 +1,8 @@
package owl package owl
import ( import (
"crypto/sha256"
"encoding/base64"
"fmt" "fmt"
"net/url" "net/url"
"os" "os"
@ -32,10 +34,12 @@ type UserMe struct {
} }
type AuthCode struct { type AuthCode struct {
Code string `yaml:"code"` Code string `yaml:"code"`
ClientId string `yaml:"client_id"` ClientId string `yaml:"client_id"`
RedirectUri string `yaml:"redirect_uri"` RedirectUri string `yaml:"redirect_uri"`
Created time.Time `yaml:"created"` CodeChallenge string `yaml:"code_challenge"`
CodeChallengeMethod string `yaml:"code_challenge_method"`
Created time.Time `yaml:"created"`
} }
func (user User) Dir() string { func (user User) Dir() string {
@ -284,21 +288,37 @@ func (user User) addAuthCode(code AuthCode) error {
return saveToYaml(user.AuthCodesFile(), codes) return saveToYaml(user.AuthCodesFile(), codes)
} }
func (user User) GenerateAuthCode(client_id string, redirect_uri string) (string, error) { func (user User) GenerateAuthCode(
client_id string, redirect_uri string,
code_challenge string, code_challenge_method string,
) (string, error) {
// generate code // generate code
code := GenerateRandomString(32) code := GenerateRandomString(32)
return code, user.addAuthCode(AuthCode{ return code, user.addAuthCode(AuthCode{
Code: code, Code: code,
ClientId: client_id, ClientId: client_id,
RedirectUri: redirect_uri, RedirectUri: redirect_uri,
CodeChallenge: code_challenge,
CodeChallengeMethod: code_challenge_method,
Created: time.Now(),
}) })
} }
func (user User) VerifyAuthCode(code string, client_id string, redirect_uri string) bool { func (user User) VerifyAuthCode(
code string, client_id string, redirect_uri string, code_verifier string,
) bool {
codes := user.getAuthCodes() codes := user.getAuthCodes()
for _, c := range codes { for _, c := range codes {
if c.Code == code && c.ClientId == client_id && c.RedirectUri == redirect_uri { if c.Code == code && c.ClientId == client_id && c.RedirectUri == redirect_uri {
return true if c.CodeChallengeMethod == "plain" {
return c.CodeChallenge == code_verifier
} else if c.CodeChallengeMethod == "S256" {
// hash code_verifier
hash := sha256.Sum256([]byte(code_verifier))
return c.CodeChallenge == base64.RawURLEncoding.EncodeToString(hash[:])
} else if c.CodeChallengeMethod == "" {
return true
}
} }
} }
return false return false