PKCE implementation
This commit is contained in:
parent
ae387f3d7d
commit
4d5af131c2
|
@ -1,6 +1,8 @@
|
|||
package web_test
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
main "h4kor/owl-blogs/cmd/owl/web"
|
||||
"h4kor/owl-blogs/test/assertions"
|
||||
|
@ -73,7 +75,7 @@ func TestAuthPostCorrectPassword(t *testing.T) {
|
|||
func TestAuthPostWithIncorrectCode(t *testing.T) {
|
||||
repo, user := getSingleUserTestRepo()
|
||||
user.ResetPassword("testpassword")
|
||||
user.GenerateAuthCode("http://example.com", "http://example.com/response")
|
||||
user.GenerateAuthCode("http://example.com", "http://example.com/response", "", "")
|
||||
|
||||
// Create Request and Response
|
||||
form := url.Values{}
|
||||
|
@ -95,7 +97,7 @@ func TestAuthPostWithIncorrectCode(t *testing.T) {
|
|||
func TestAuthPostWithCorrectCode(t *testing.T) {
|
||||
repo, user := getSingleUserTestRepo()
|
||||
user.ResetPassword("testpassword")
|
||||
code, _ := user.GenerateAuthCode("http://example.com", "http://example.com/response")
|
||||
code, _ := user.GenerateAuthCode("http://example.com", "http://example.com/response", "", "")
|
||||
|
||||
// Create Request and Response
|
||||
form := url.Values{}
|
||||
|
@ -123,6 +125,128 @@ func TestAuthPostWithCorrectCode(t *testing.T) {
|
|||
|
||||
}
|
||||
|
||||
func TestAuthPostWithCorrectCodeAndPKCE(t *testing.T) {
|
||||
repo, user := getSingleUserTestRepo()
|
||||
user.ResetPassword("testpassword")
|
||||
|
||||
// Create Request and Response
|
||||
code_verifier := "test_code_verifier"
|
||||
// create code challenge
|
||||
h := sha256.New()
|
||||
h.Write([]byte(code_verifier))
|
||||
code_challenge := base64.RawURLEncoding.EncodeToString(h.Sum(nil))
|
||||
code, _ := user.GenerateAuthCode("http://example.com", "http://example.com/response", code_challenge, "S256")
|
||||
|
||||
form := url.Values{}
|
||||
form.Add("code", code)
|
||||
form.Add("client_id", "http://example.com")
|
||||
form.Add("redirect_uri", "http://example.com/response")
|
||||
form.Add("grant_type", "authorization_code")
|
||||
form.Add("code_verifier", code_verifier)
|
||||
req, err := http.NewRequest("POST", user.AuthUrl(), strings.NewReader(form.Encode()))
|
||||
req.Header.Add("Content-Type", "application/x-www-form-urlencoded")
|
||||
req.Header.Add("Content-Length", strconv.Itoa(len(form.Encode())))
|
||||
req.Header.Add("Accept", "application/json")
|
||||
assertions.AssertNoError(t, err, "Error creating request")
|
||||
rr := httptest.NewRecorder()
|
||||
router := main.SingleUserRouter(&repo)
|
||||
router.ServeHTTP(rr, req)
|
||||
|
||||
assertions.AssertStatus(t, rr, http.StatusOK)
|
||||
// parse response as json
|
||||
type responseType struct {
|
||||
Me string `json:"me"`
|
||||
}
|
||||
var response responseType
|
||||
json.Unmarshal(rr.Body.Bytes(), &response)
|
||||
assertions.AssertEqual(t, response.Me, user.FullUrl())
|
||||
|
||||
}
|
||||
|
||||
func TestAuthPostWithCorrectCodeAndWrongPKCE(t *testing.T) {
|
||||
repo, user := getSingleUserTestRepo()
|
||||
user.ResetPassword("testpassword")
|
||||
|
||||
// Create Request and Response
|
||||
code_verifier := "test_code_verifier"
|
||||
// create code challenge
|
||||
h := sha256.New()
|
||||
h.Write([]byte(code_verifier + "wrong"))
|
||||
code_challenge := base64.RawURLEncoding.EncodeToString(h.Sum(nil))
|
||||
code, _ := user.GenerateAuthCode("http://example.com", "http://example.com/response", code_challenge, "S256")
|
||||
|
||||
form := url.Values{}
|
||||
form.Add("code", code)
|
||||
form.Add("client_id", "http://example.com")
|
||||
form.Add("redirect_uri", "http://example.com/response")
|
||||
form.Add("grant_type", "authorization_code")
|
||||
form.Add("code_verifier", code_verifier)
|
||||
req, err := http.NewRequest("POST", user.AuthUrl(), strings.NewReader(form.Encode()))
|
||||
req.Header.Add("Content-Type", "application/x-www-form-urlencoded")
|
||||
req.Header.Add("Content-Length", strconv.Itoa(len(form.Encode())))
|
||||
req.Header.Add("Accept", "application/json")
|
||||
assertions.AssertNoError(t, err, "Error creating request")
|
||||
rr := httptest.NewRecorder()
|
||||
router := main.SingleUserRouter(&repo)
|
||||
router.ServeHTTP(rr, req)
|
||||
|
||||
assertions.AssertStatus(t, rr, http.StatusUnauthorized)
|
||||
}
|
||||
|
||||
func TestAuthPostWithCorrectCodePKCEPlain(t *testing.T) {
|
||||
repo, user := getSingleUserTestRepo()
|
||||
user.ResetPassword("testpassword")
|
||||
|
||||
// Create Request and Response
|
||||
code_verifier := "test_code_verifier"
|
||||
code_challenge := code_verifier
|
||||
code, _ := user.GenerateAuthCode("http://example.com", "http://example.com/response", code_challenge, "plain")
|
||||
|
||||
form := url.Values{}
|
||||
form.Add("code", code)
|
||||
form.Add("client_id", "http://example.com")
|
||||
form.Add("redirect_uri", "http://example.com/response")
|
||||
form.Add("grant_type", "authorization_code")
|
||||
form.Add("code_verifier", code_verifier)
|
||||
req, err := http.NewRequest("POST", user.AuthUrl(), strings.NewReader(form.Encode()))
|
||||
req.Header.Add("Content-Type", "application/x-www-form-urlencoded")
|
||||
req.Header.Add("Content-Length", strconv.Itoa(len(form.Encode())))
|
||||
req.Header.Add("Accept", "application/json")
|
||||
assertions.AssertNoError(t, err, "Error creating request")
|
||||
rr := httptest.NewRecorder()
|
||||
router := main.SingleUserRouter(&repo)
|
||||
router.ServeHTTP(rr, req)
|
||||
|
||||
assertions.AssertStatus(t, rr, http.StatusOK)
|
||||
}
|
||||
|
||||
func TestAuthPostWithCorrectCodePKCEPlainWrong(t *testing.T) {
|
||||
repo, user := getSingleUserTestRepo()
|
||||
user.ResetPassword("testpassword")
|
||||
|
||||
// Create Request and Response
|
||||
code_verifier := "test_code_verifier"
|
||||
code_challenge := code_verifier + "wrong"
|
||||
code, _ := user.GenerateAuthCode("http://example.com", "http://example.com/response", code_challenge, "plain")
|
||||
|
||||
form := url.Values{}
|
||||
form.Add("code", code)
|
||||
form.Add("client_id", "http://example.com")
|
||||
form.Add("redirect_uri", "http://example.com/response")
|
||||
form.Add("grant_type", "authorization_code")
|
||||
form.Add("code_verifier", code_verifier)
|
||||
req, err := http.NewRequest("POST", user.AuthUrl(), strings.NewReader(form.Encode()))
|
||||
req.Header.Add("Content-Type", "application/x-www-form-urlencoded")
|
||||
req.Header.Add("Content-Length", strconv.Itoa(len(form.Encode())))
|
||||
req.Header.Add("Accept", "application/json")
|
||||
assertions.AssertNoError(t, err, "Error creating request")
|
||||
rr := httptest.NewRecorder()
|
||||
router := main.SingleUserRouter(&repo)
|
||||
router.ServeHTTP(rr, req)
|
||||
|
||||
assertions.AssertStatus(t, rr, http.StatusUnauthorized)
|
||||
}
|
||||
|
||||
func TestAuthRedirectUriNotSet(t *testing.T) {
|
||||
repo, user := getSingleUserTestRepo()
|
||||
repo.HttpClient = &mocks.MockHttpClient{}
|
||||
|
|
|
@ -74,6 +74,8 @@ func userAuthHandler(repo *owl.Repository) func(http.ResponseWriter, *http.Reque
|
|||
redirectUri := r.URL.Query().Get("redirect_uri")
|
||||
state := r.URL.Query().Get("state")
|
||||
responseType := r.URL.Query().Get("response_type")
|
||||
codeChallenge := r.URL.Query().Get("code_challenge")
|
||||
codeChallengeMethod := r.URL.Query().Get("code_challenge_method")
|
||||
|
||||
// check if request is valid
|
||||
missing_params := []string{}
|
||||
|
@ -99,6 +101,11 @@ func userAuthHandler(repo *owl.Repository) func(http.ResponseWriter, *http.Reque
|
|||
w.Write([]byte("Invalid response_type. Must be 'code' ('id' converted to 'code' for legacy support)."))
|
||||
return
|
||||
}
|
||||
if codeChallengeMethod != "" && (codeChallengeMethod != "S256" && codeChallengeMethod != "plain") {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
w.Write([]byte("Invalid code_challenge_method. Must be 'S256' or 'plain'."))
|
||||
return
|
||||
}
|
||||
|
||||
// check if redirect_uri is registered
|
||||
resp, _ := repo.HttpClient.Get(clientId)
|
||||
|
@ -132,6 +139,8 @@ func userAuthHandler(repo *owl.Repository) func(http.ResponseWriter, *http.Reque
|
|||
RedirectUri: redirectUri,
|
||||
State: state,
|
||||
ResponseType: responseType,
|
||||
CodeChallenge: codeChallenge,
|
||||
CodeChallengeMethod: codeChallengeMethod,
|
||||
User: user,
|
||||
CsrfToken: csrfToken,
|
||||
}
|
||||
|
@ -168,9 +177,10 @@ func userAuthProfileHandler(repo *owl.Repository) func(http.ResponseWriter, *htt
|
|||
code := r.Form.Get("code")
|
||||
client_id := r.Form.Get("client_id")
|
||||
redirect_uri := r.Form.Get("redirect_uri")
|
||||
code_verifier := r.Form.Get("code_verifier")
|
||||
|
||||
// check if request is valid
|
||||
valid := user.VerifyAuthCode(code, client_id, redirect_uri)
|
||||
valid := user.VerifyAuthCode(code, client_id, redirect_uri, code_verifier)
|
||||
if !valid {
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
w.Write([]byte("Invalid code"))
|
||||
|
@ -225,11 +235,12 @@ func userAuthVerifyHandler(repo *owl.Repository) func(http.ResponseWriter, *http
|
|||
return
|
||||
}
|
||||
password := r.FormValue("password")
|
||||
println("Password: ", password)
|
||||
client_id := r.FormValue("client_id")
|
||||
redirect_uri := r.FormValue("redirect_uri")
|
||||
response_type := r.FormValue("response_type")
|
||||
state := r.FormValue("state")
|
||||
code_challenge := r.FormValue("code_challenge")
|
||||
code_challenge_method := r.FormValue("code_challenge_method")
|
||||
|
||||
// CSRF check
|
||||
formCsrfToken := r.FormValue("csrf_token")
|
||||
|
@ -250,17 +261,22 @@ func userAuthVerifyHandler(repo *owl.Repository) func(http.ResponseWriter, *http
|
|||
|
||||
password_valid := user.VerifyPassword(password)
|
||||
if !password_valid {
|
||||
http.Redirect(w, r,
|
||||
fmt.Sprintf(
|
||||
redirect := fmt.Sprintf(
|
||||
"%s?error=invalid_password&client_id=%s&redirect_uri=%s&response_type=%s&state=%s",
|
||||
user.AuthUrl(), client_id, redirect_uri, response_type, state,
|
||||
),
|
||||
)
|
||||
if code_challenge != "" {
|
||||
redirect += fmt.Sprintf("&code_challenge=%s&code_challenge_method=%s", code_challenge, code_challenge_method)
|
||||
}
|
||||
http.Redirect(w, r,
|
||||
redirect,
|
||||
http.StatusFound,
|
||||
)
|
||||
return
|
||||
} else {
|
||||
// password is valid, generate code
|
||||
code, err := user.GenerateAuthCode(client_id, redirect_uri)
|
||||
code, err := user.GenerateAuthCode(
|
||||
client_id, redirect_uri, code_challenge, code_challenge_method)
|
||||
if err != nil {
|
||||
println("Error generating code: ", err.Error())
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
|
|
|
@ -8,5 +8,7 @@
|
|||
<input type="hidden" name="response_type" value="{{.ResponseType}}">
|
||||
<input type="hidden" name="state" value="{{.State}}">
|
||||
<input type="hidden" name="csrf_token" value="{{.CsrfToken}}">
|
||||
<input type="hidden" name="code_challenge" value="{{.CodeChallenge}}">
|
||||
<input type="hidden" name="code_challenge_method" value="{{.CodeChallengeMethod}}">
|
||||
<input type="submit" value="Login">
|
||||
</form>
|
|
@ -26,6 +26,8 @@ type AuthRequestData struct {
|
|||
RedirectUri string
|
||||
State string
|
||||
ResponseType string
|
||||
CodeChallenge string
|
||||
CodeChallengeMethod string
|
||||
User User
|
||||
CsrfToken string
|
||||
}
|
||||
|
|
24
user.go
24
user.go
|
@ -1,6 +1,8 @@
|
|||
package owl
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"os"
|
||||
|
@ -35,6 +37,8 @@ type AuthCode struct {
|
|||
Code string `yaml:"code"`
|
||||
ClientId string `yaml:"client_id"`
|
||||
RedirectUri string `yaml:"redirect_uri"`
|
||||
CodeChallenge string `yaml:"code_challenge"`
|
||||
CodeChallengeMethod string `yaml:"code_challenge_method"`
|
||||
Created time.Time `yaml:"created"`
|
||||
}
|
||||
|
||||
|
@ -284,22 +288,38 @@ func (user User) addAuthCode(code AuthCode) error {
|
|||
return saveToYaml(user.AuthCodesFile(), codes)
|
||||
}
|
||||
|
||||
func (user User) GenerateAuthCode(client_id string, redirect_uri string) (string, error) {
|
||||
func (user User) GenerateAuthCode(
|
||||
client_id string, redirect_uri string,
|
||||
code_challenge string, code_challenge_method string,
|
||||
) (string, error) {
|
||||
// generate code
|
||||
code := GenerateRandomString(32)
|
||||
return code, user.addAuthCode(AuthCode{
|
||||
Code: code,
|
||||
ClientId: client_id,
|
||||
RedirectUri: redirect_uri,
|
||||
CodeChallenge: code_challenge,
|
||||
CodeChallengeMethod: code_challenge_method,
|
||||
Created: time.Now(),
|
||||
})
|
||||
}
|
||||
|
||||
func (user User) VerifyAuthCode(code string, client_id string, redirect_uri string) bool {
|
||||
func (user User) VerifyAuthCode(
|
||||
code string, client_id string, redirect_uri string, code_verifier string,
|
||||
) bool {
|
||||
codes := user.getAuthCodes()
|
||||
for _, c := range codes {
|
||||
if c.Code == code && c.ClientId == client_id && c.RedirectUri == redirect_uri {
|
||||
if c.CodeChallengeMethod == "plain" {
|
||||
return c.CodeChallenge == code_verifier
|
||||
} else if c.CodeChallengeMethod == "S256" {
|
||||
// hash code_verifier
|
||||
hash := sha256.Sum256([]byte(code_verifier))
|
||||
return c.CodeChallenge == base64.RawURLEncoding.EncodeToString(hash[:])
|
||||
} else if c.CodeChallengeMethod == "" {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue