PKCE implementation

This commit is contained in:
Niko Abeler 2022-11-06 16:27:35 +01:00
parent ae387f3d7d
commit 4d5af131c2
5 changed files with 197 additions and 33 deletions

View File

@ -1,6 +1,8 @@
package web_test package web_test
import ( import (
"crypto/sha256"
"encoding/base64"
"encoding/json" "encoding/json"
main "h4kor/owl-blogs/cmd/owl/web" main "h4kor/owl-blogs/cmd/owl/web"
"h4kor/owl-blogs/test/assertions" "h4kor/owl-blogs/test/assertions"
@ -73,7 +75,7 @@ func TestAuthPostCorrectPassword(t *testing.T) {
func TestAuthPostWithIncorrectCode(t *testing.T) { func TestAuthPostWithIncorrectCode(t *testing.T) {
repo, user := getSingleUserTestRepo() repo, user := getSingleUserTestRepo()
user.ResetPassword("testpassword") user.ResetPassword("testpassword")
user.GenerateAuthCode("http://example.com", "http://example.com/response") user.GenerateAuthCode("http://example.com", "http://example.com/response", "", "")
// Create Request and Response // Create Request and Response
form := url.Values{} form := url.Values{}
@ -95,7 +97,7 @@ func TestAuthPostWithIncorrectCode(t *testing.T) {
func TestAuthPostWithCorrectCode(t *testing.T) { func TestAuthPostWithCorrectCode(t *testing.T) {
repo, user := getSingleUserTestRepo() repo, user := getSingleUserTestRepo()
user.ResetPassword("testpassword") user.ResetPassword("testpassword")
code, _ := user.GenerateAuthCode("http://example.com", "http://example.com/response") code, _ := user.GenerateAuthCode("http://example.com", "http://example.com/response", "", "")
// Create Request and Response // Create Request and Response
form := url.Values{} form := url.Values{}
@ -123,6 +125,128 @@ func TestAuthPostWithCorrectCode(t *testing.T) {
} }
func TestAuthPostWithCorrectCodeAndPKCE(t *testing.T) {
repo, user := getSingleUserTestRepo()
user.ResetPassword("testpassword")
// Create Request and Response
code_verifier := "test_code_verifier"
// create code challenge
h := sha256.New()
h.Write([]byte(code_verifier))
code_challenge := base64.RawURLEncoding.EncodeToString(h.Sum(nil))
code, _ := user.GenerateAuthCode("http://example.com", "http://example.com/response", code_challenge, "S256")
form := url.Values{}
form.Add("code", code)
form.Add("client_id", "http://example.com")
form.Add("redirect_uri", "http://example.com/response")
form.Add("grant_type", "authorization_code")
form.Add("code_verifier", code_verifier)
req, err := http.NewRequest("POST", user.AuthUrl(), strings.NewReader(form.Encode()))
req.Header.Add("Content-Type", "application/x-www-form-urlencoded")
req.Header.Add("Content-Length", strconv.Itoa(len(form.Encode())))
req.Header.Add("Accept", "application/json")
assertions.AssertNoError(t, err, "Error creating request")
rr := httptest.NewRecorder()
router := main.SingleUserRouter(&repo)
router.ServeHTTP(rr, req)
assertions.AssertStatus(t, rr, http.StatusOK)
// parse response as json
type responseType struct {
Me string `json:"me"`
}
var response responseType
json.Unmarshal(rr.Body.Bytes(), &response)
assertions.AssertEqual(t, response.Me, user.FullUrl())
}
func TestAuthPostWithCorrectCodeAndWrongPKCE(t *testing.T) {
repo, user := getSingleUserTestRepo()
user.ResetPassword("testpassword")
// Create Request and Response
code_verifier := "test_code_verifier"
// create code challenge
h := sha256.New()
h.Write([]byte(code_verifier + "wrong"))
code_challenge := base64.RawURLEncoding.EncodeToString(h.Sum(nil))
code, _ := user.GenerateAuthCode("http://example.com", "http://example.com/response", code_challenge, "S256")
form := url.Values{}
form.Add("code", code)
form.Add("client_id", "http://example.com")
form.Add("redirect_uri", "http://example.com/response")
form.Add("grant_type", "authorization_code")
form.Add("code_verifier", code_verifier)
req, err := http.NewRequest("POST", user.AuthUrl(), strings.NewReader(form.Encode()))
req.Header.Add("Content-Type", "application/x-www-form-urlencoded")
req.Header.Add("Content-Length", strconv.Itoa(len(form.Encode())))
req.Header.Add("Accept", "application/json")
assertions.AssertNoError(t, err, "Error creating request")
rr := httptest.NewRecorder()
router := main.SingleUserRouter(&repo)
router.ServeHTTP(rr, req)
assertions.AssertStatus(t, rr, http.StatusUnauthorized)
}
func TestAuthPostWithCorrectCodePKCEPlain(t *testing.T) {
repo, user := getSingleUserTestRepo()
user.ResetPassword("testpassword")
// Create Request and Response
code_verifier := "test_code_verifier"
code_challenge := code_verifier
code, _ := user.GenerateAuthCode("http://example.com", "http://example.com/response", code_challenge, "plain")
form := url.Values{}
form.Add("code", code)
form.Add("client_id", "http://example.com")
form.Add("redirect_uri", "http://example.com/response")
form.Add("grant_type", "authorization_code")
form.Add("code_verifier", code_verifier)
req, err := http.NewRequest("POST", user.AuthUrl(), strings.NewReader(form.Encode()))
req.Header.Add("Content-Type", "application/x-www-form-urlencoded")
req.Header.Add("Content-Length", strconv.Itoa(len(form.Encode())))
req.Header.Add("Accept", "application/json")
assertions.AssertNoError(t, err, "Error creating request")
rr := httptest.NewRecorder()
router := main.SingleUserRouter(&repo)
router.ServeHTTP(rr, req)
assertions.AssertStatus(t, rr, http.StatusOK)
}
func TestAuthPostWithCorrectCodePKCEPlainWrong(t *testing.T) {
repo, user := getSingleUserTestRepo()
user.ResetPassword("testpassword")
// Create Request and Response
code_verifier := "test_code_verifier"
code_challenge := code_verifier + "wrong"
code, _ := user.GenerateAuthCode("http://example.com", "http://example.com/response", code_challenge, "plain")
form := url.Values{}
form.Add("code", code)
form.Add("client_id", "http://example.com")
form.Add("redirect_uri", "http://example.com/response")
form.Add("grant_type", "authorization_code")
form.Add("code_verifier", code_verifier)
req, err := http.NewRequest("POST", user.AuthUrl(), strings.NewReader(form.Encode()))
req.Header.Add("Content-Type", "application/x-www-form-urlencoded")
req.Header.Add("Content-Length", strconv.Itoa(len(form.Encode())))
req.Header.Add("Accept", "application/json")
assertions.AssertNoError(t, err, "Error creating request")
rr := httptest.NewRecorder()
router := main.SingleUserRouter(&repo)
router.ServeHTTP(rr, req)
assertions.AssertStatus(t, rr, http.StatusUnauthorized)
}
func TestAuthRedirectUriNotSet(t *testing.T) { func TestAuthRedirectUriNotSet(t *testing.T) {
repo, user := getSingleUserTestRepo() repo, user := getSingleUserTestRepo()
repo.HttpClient = &mocks.MockHttpClient{} repo.HttpClient = &mocks.MockHttpClient{}

View File

@ -74,6 +74,8 @@ func userAuthHandler(repo *owl.Repository) func(http.ResponseWriter, *http.Reque
redirectUri := r.URL.Query().Get("redirect_uri") redirectUri := r.URL.Query().Get("redirect_uri")
state := r.URL.Query().Get("state") state := r.URL.Query().Get("state")
responseType := r.URL.Query().Get("response_type") responseType := r.URL.Query().Get("response_type")
codeChallenge := r.URL.Query().Get("code_challenge")
codeChallengeMethod := r.URL.Query().Get("code_challenge_method")
// check if request is valid // check if request is valid
missing_params := []string{} missing_params := []string{}
@ -99,6 +101,11 @@ func userAuthHandler(repo *owl.Repository) func(http.ResponseWriter, *http.Reque
w.Write([]byte("Invalid response_type. Must be 'code' ('id' converted to 'code' for legacy support).")) w.Write([]byte("Invalid response_type. Must be 'code' ('id' converted to 'code' for legacy support)."))
return return
} }
if codeChallengeMethod != "" && (codeChallengeMethod != "S256" && codeChallengeMethod != "plain") {
w.WriteHeader(http.StatusBadRequest)
w.Write([]byte("Invalid code_challenge_method. Must be 'S256' or 'plain'."))
return
}
// check if redirect_uri is registered // check if redirect_uri is registered
resp, _ := repo.HttpClient.Get(clientId) resp, _ := repo.HttpClient.Get(clientId)
@ -132,6 +139,8 @@ func userAuthHandler(repo *owl.Repository) func(http.ResponseWriter, *http.Reque
RedirectUri: redirectUri, RedirectUri: redirectUri,
State: state, State: state,
ResponseType: responseType, ResponseType: responseType,
CodeChallenge: codeChallenge,
CodeChallengeMethod: codeChallengeMethod,
User: user, User: user,
CsrfToken: csrfToken, CsrfToken: csrfToken,
} }
@ -168,9 +177,10 @@ func userAuthProfileHandler(repo *owl.Repository) func(http.ResponseWriter, *htt
code := r.Form.Get("code") code := r.Form.Get("code")
client_id := r.Form.Get("client_id") client_id := r.Form.Get("client_id")
redirect_uri := r.Form.Get("redirect_uri") redirect_uri := r.Form.Get("redirect_uri")
code_verifier := r.Form.Get("code_verifier")
// check if request is valid // check if request is valid
valid := user.VerifyAuthCode(code, client_id, redirect_uri) valid := user.VerifyAuthCode(code, client_id, redirect_uri, code_verifier)
if !valid { if !valid {
w.WriteHeader(http.StatusUnauthorized) w.WriteHeader(http.StatusUnauthorized)
w.Write([]byte("Invalid code")) w.Write([]byte("Invalid code"))
@ -225,11 +235,12 @@ func userAuthVerifyHandler(repo *owl.Repository) func(http.ResponseWriter, *http
return return
} }
password := r.FormValue("password") password := r.FormValue("password")
println("Password: ", password)
client_id := r.FormValue("client_id") client_id := r.FormValue("client_id")
redirect_uri := r.FormValue("redirect_uri") redirect_uri := r.FormValue("redirect_uri")
response_type := r.FormValue("response_type") response_type := r.FormValue("response_type")
state := r.FormValue("state") state := r.FormValue("state")
code_challenge := r.FormValue("code_challenge")
code_challenge_method := r.FormValue("code_challenge_method")
// CSRF check // CSRF check
formCsrfToken := r.FormValue("csrf_token") formCsrfToken := r.FormValue("csrf_token")
@ -250,17 +261,22 @@ func userAuthVerifyHandler(repo *owl.Repository) func(http.ResponseWriter, *http
password_valid := user.VerifyPassword(password) password_valid := user.VerifyPassword(password)
if !password_valid { if !password_valid {
http.Redirect(w, r, redirect := fmt.Sprintf(
fmt.Sprintf(
"%s?error=invalid_password&client_id=%s&redirect_uri=%s&response_type=%s&state=%s", "%s?error=invalid_password&client_id=%s&redirect_uri=%s&response_type=%s&state=%s",
user.AuthUrl(), client_id, redirect_uri, response_type, state, user.AuthUrl(), client_id, redirect_uri, response_type, state,
), )
if code_challenge != "" {
redirect += fmt.Sprintf("&code_challenge=%s&code_challenge_method=%s", code_challenge, code_challenge_method)
}
http.Redirect(w, r,
redirect,
http.StatusFound, http.StatusFound,
) )
return return
} else { } else {
// password is valid, generate code // password is valid, generate code
code, err := user.GenerateAuthCode(client_id, redirect_uri) code, err := user.GenerateAuthCode(
client_id, redirect_uri, code_challenge, code_challenge_method)
if err != nil { if err != nil {
println("Error generating code: ", err.Error()) println("Error generating code: ", err.Error())
w.WriteHeader(http.StatusInternalServerError) w.WriteHeader(http.StatusInternalServerError)

View File

@ -8,5 +8,7 @@
<input type="hidden" name="response_type" value="{{.ResponseType}}"> <input type="hidden" name="response_type" value="{{.ResponseType}}">
<input type="hidden" name="state" value="{{.State}}"> <input type="hidden" name="state" value="{{.State}}">
<input type="hidden" name="csrf_token" value="{{.CsrfToken}}"> <input type="hidden" name="csrf_token" value="{{.CsrfToken}}">
<input type="hidden" name="code_challenge" value="{{.CodeChallenge}}">
<input type="hidden" name="code_challenge_method" value="{{.CodeChallengeMethod}}">
<input type="submit" value="Login"> <input type="submit" value="Login">
</form> </form>

View File

@ -26,6 +26,8 @@ type AuthRequestData struct {
RedirectUri string RedirectUri string
State string State string
ResponseType string ResponseType string
CodeChallenge string
CodeChallengeMethod string
User User User User
CsrfToken string CsrfToken string
} }

24
user.go
View File

@ -1,6 +1,8 @@
package owl package owl
import ( import (
"crypto/sha256"
"encoding/base64"
"fmt" "fmt"
"net/url" "net/url"
"os" "os"
@ -35,6 +37,8 @@ type AuthCode struct {
Code string `yaml:"code"` Code string `yaml:"code"`
ClientId string `yaml:"client_id"` ClientId string `yaml:"client_id"`
RedirectUri string `yaml:"redirect_uri"` RedirectUri string `yaml:"redirect_uri"`
CodeChallenge string `yaml:"code_challenge"`
CodeChallengeMethod string `yaml:"code_challenge_method"`
Created time.Time `yaml:"created"` Created time.Time `yaml:"created"`
} }
@ -284,22 +288,38 @@ func (user User) addAuthCode(code AuthCode) error {
return saveToYaml(user.AuthCodesFile(), codes) return saveToYaml(user.AuthCodesFile(), codes)
} }
func (user User) GenerateAuthCode(client_id string, redirect_uri string) (string, error) { func (user User) GenerateAuthCode(
client_id string, redirect_uri string,
code_challenge string, code_challenge_method string,
) (string, error) {
// generate code // generate code
code := GenerateRandomString(32) code := GenerateRandomString(32)
return code, user.addAuthCode(AuthCode{ return code, user.addAuthCode(AuthCode{
Code: code, Code: code,
ClientId: client_id, ClientId: client_id,
RedirectUri: redirect_uri, RedirectUri: redirect_uri,
CodeChallenge: code_challenge,
CodeChallengeMethod: code_challenge_method,
Created: time.Now(),
}) })
} }
func (user User) VerifyAuthCode(code string, client_id string, redirect_uri string) bool { func (user User) VerifyAuthCode(
code string, client_id string, redirect_uri string, code_verifier string,
) bool {
codes := user.getAuthCodes() codes := user.getAuthCodes()
for _, c := range codes { for _, c := range codes {
if c.Code == code && c.ClientId == client_id && c.RedirectUri == redirect_uri { if c.Code == code && c.ClientId == client_id && c.RedirectUri == redirect_uri {
if c.CodeChallengeMethod == "plain" {
return c.CodeChallenge == code_verifier
} else if c.CodeChallengeMethod == "S256" {
// hash code_verifier
hash := sha256.Sum256([]byte(code_verifier))
return c.CodeChallenge == base64.RawURLEncoding.EncodeToString(hash[:])
} else if c.CodeChallengeMethod == "" {
return true return true
} }
} }
}
return false return false
} }