PKCE implementation
This commit is contained in:
parent
ae387f3d7d
commit
4d5af131c2
|
@ -1,6 +1,8 @@
|
||||||
package web_test
|
package web_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"crypto/sha256"
|
||||||
|
"encoding/base64"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
main "h4kor/owl-blogs/cmd/owl/web"
|
main "h4kor/owl-blogs/cmd/owl/web"
|
||||||
"h4kor/owl-blogs/test/assertions"
|
"h4kor/owl-blogs/test/assertions"
|
||||||
|
@ -73,7 +75,7 @@ func TestAuthPostCorrectPassword(t *testing.T) {
|
||||||
func TestAuthPostWithIncorrectCode(t *testing.T) {
|
func TestAuthPostWithIncorrectCode(t *testing.T) {
|
||||||
repo, user := getSingleUserTestRepo()
|
repo, user := getSingleUserTestRepo()
|
||||||
user.ResetPassword("testpassword")
|
user.ResetPassword("testpassword")
|
||||||
user.GenerateAuthCode("http://example.com", "http://example.com/response")
|
user.GenerateAuthCode("http://example.com", "http://example.com/response", "", "")
|
||||||
|
|
||||||
// Create Request and Response
|
// Create Request and Response
|
||||||
form := url.Values{}
|
form := url.Values{}
|
||||||
|
@ -95,7 +97,7 @@ func TestAuthPostWithIncorrectCode(t *testing.T) {
|
||||||
func TestAuthPostWithCorrectCode(t *testing.T) {
|
func TestAuthPostWithCorrectCode(t *testing.T) {
|
||||||
repo, user := getSingleUserTestRepo()
|
repo, user := getSingleUserTestRepo()
|
||||||
user.ResetPassword("testpassword")
|
user.ResetPassword("testpassword")
|
||||||
code, _ := user.GenerateAuthCode("http://example.com", "http://example.com/response")
|
code, _ := user.GenerateAuthCode("http://example.com", "http://example.com/response", "", "")
|
||||||
|
|
||||||
// Create Request and Response
|
// Create Request and Response
|
||||||
form := url.Values{}
|
form := url.Values{}
|
||||||
|
@ -123,6 +125,128 @@ func TestAuthPostWithCorrectCode(t *testing.T) {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestAuthPostWithCorrectCodeAndPKCE(t *testing.T) {
|
||||||
|
repo, user := getSingleUserTestRepo()
|
||||||
|
user.ResetPassword("testpassword")
|
||||||
|
|
||||||
|
// Create Request and Response
|
||||||
|
code_verifier := "test_code_verifier"
|
||||||
|
// create code challenge
|
||||||
|
h := sha256.New()
|
||||||
|
h.Write([]byte(code_verifier))
|
||||||
|
code_challenge := base64.RawURLEncoding.EncodeToString(h.Sum(nil))
|
||||||
|
code, _ := user.GenerateAuthCode("http://example.com", "http://example.com/response", code_challenge, "S256")
|
||||||
|
|
||||||
|
form := url.Values{}
|
||||||
|
form.Add("code", code)
|
||||||
|
form.Add("client_id", "http://example.com")
|
||||||
|
form.Add("redirect_uri", "http://example.com/response")
|
||||||
|
form.Add("grant_type", "authorization_code")
|
||||||
|
form.Add("code_verifier", code_verifier)
|
||||||
|
req, err := http.NewRequest("POST", user.AuthUrl(), strings.NewReader(form.Encode()))
|
||||||
|
req.Header.Add("Content-Type", "application/x-www-form-urlencoded")
|
||||||
|
req.Header.Add("Content-Length", strconv.Itoa(len(form.Encode())))
|
||||||
|
req.Header.Add("Accept", "application/json")
|
||||||
|
assertions.AssertNoError(t, err, "Error creating request")
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
router := main.SingleUserRouter(&repo)
|
||||||
|
router.ServeHTTP(rr, req)
|
||||||
|
|
||||||
|
assertions.AssertStatus(t, rr, http.StatusOK)
|
||||||
|
// parse response as json
|
||||||
|
type responseType struct {
|
||||||
|
Me string `json:"me"`
|
||||||
|
}
|
||||||
|
var response responseType
|
||||||
|
json.Unmarshal(rr.Body.Bytes(), &response)
|
||||||
|
assertions.AssertEqual(t, response.Me, user.FullUrl())
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAuthPostWithCorrectCodeAndWrongPKCE(t *testing.T) {
|
||||||
|
repo, user := getSingleUserTestRepo()
|
||||||
|
user.ResetPassword("testpassword")
|
||||||
|
|
||||||
|
// Create Request and Response
|
||||||
|
code_verifier := "test_code_verifier"
|
||||||
|
// create code challenge
|
||||||
|
h := sha256.New()
|
||||||
|
h.Write([]byte(code_verifier + "wrong"))
|
||||||
|
code_challenge := base64.RawURLEncoding.EncodeToString(h.Sum(nil))
|
||||||
|
code, _ := user.GenerateAuthCode("http://example.com", "http://example.com/response", code_challenge, "S256")
|
||||||
|
|
||||||
|
form := url.Values{}
|
||||||
|
form.Add("code", code)
|
||||||
|
form.Add("client_id", "http://example.com")
|
||||||
|
form.Add("redirect_uri", "http://example.com/response")
|
||||||
|
form.Add("grant_type", "authorization_code")
|
||||||
|
form.Add("code_verifier", code_verifier)
|
||||||
|
req, err := http.NewRequest("POST", user.AuthUrl(), strings.NewReader(form.Encode()))
|
||||||
|
req.Header.Add("Content-Type", "application/x-www-form-urlencoded")
|
||||||
|
req.Header.Add("Content-Length", strconv.Itoa(len(form.Encode())))
|
||||||
|
req.Header.Add("Accept", "application/json")
|
||||||
|
assertions.AssertNoError(t, err, "Error creating request")
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
router := main.SingleUserRouter(&repo)
|
||||||
|
router.ServeHTTP(rr, req)
|
||||||
|
|
||||||
|
assertions.AssertStatus(t, rr, http.StatusUnauthorized)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAuthPostWithCorrectCodePKCEPlain(t *testing.T) {
|
||||||
|
repo, user := getSingleUserTestRepo()
|
||||||
|
user.ResetPassword("testpassword")
|
||||||
|
|
||||||
|
// Create Request and Response
|
||||||
|
code_verifier := "test_code_verifier"
|
||||||
|
code_challenge := code_verifier
|
||||||
|
code, _ := user.GenerateAuthCode("http://example.com", "http://example.com/response", code_challenge, "plain")
|
||||||
|
|
||||||
|
form := url.Values{}
|
||||||
|
form.Add("code", code)
|
||||||
|
form.Add("client_id", "http://example.com")
|
||||||
|
form.Add("redirect_uri", "http://example.com/response")
|
||||||
|
form.Add("grant_type", "authorization_code")
|
||||||
|
form.Add("code_verifier", code_verifier)
|
||||||
|
req, err := http.NewRequest("POST", user.AuthUrl(), strings.NewReader(form.Encode()))
|
||||||
|
req.Header.Add("Content-Type", "application/x-www-form-urlencoded")
|
||||||
|
req.Header.Add("Content-Length", strconv.Itoa(len(form.Encode())))
|
||||||
|
req.Header.Add("Accept", "application/json")
|
||||||
|
assertions.AssertNoError(t, err, "Error creating request")
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
router := main.SingleUserRouter(&repo)
|
||||||
|
router.ServeHTTP(rr, req)
|
||||||
|
|
||||||
|
assertions.AssertStatus(t, rr, http.StatusOK)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAuthPostWithCorrectCodePKCEPlainWrong(t *testing.T) {
|
||||||
|
repo, user := getSingleUserTestRepo()
|
||||||
|
user.ResetPassword("testpassword")
|
||||||
|
|
||||||
|
// Create Request and Response
|
||||||
|
code_verifier := "test_code_verifier"
|
||||||
|
code_challenge := code_verifier + "wrong"
|
||||||
|
code, _ := user.GenerateAuthCode("http://example.com", "http://example.com/response", code_challenge, "plain")
|
||||||
|
|
||||||
|
form := url.Values{}
|
||||||
|
form.Add("code", code)
|
||||||
|
form.Add("client_id", "http://example.com")
|
||||||
|
form.Add("redirect_uri", "http://example.com/response")
|
||||||
|
form.Add("grant_type", "authorization_code")
|
||||||
|
form.Add("code_verifier", code_verifier)
|
||||||
|
req, err := http.NewRequest("POST", user.AuthUrl(), strings.NewReader(form.Encode()))
|
||||||
|
req.Header.Add("Content-Type", "application/x-www-form-urlencoded")
|
||||||
|
req.Header.Add("Content-Length", strconv.Itoa(len(form.Encode())))
|
||||||
|
req.Header.Add("Accept", "application/json")
|
||||||
|
assertions.AssertNoError(t, err, "Error creating request")
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
router := main.SingleUserRouter(&repo)
|
||||||
|
router.ServeHTTP(rr, req)
|
||||||
|
|
||||||
|
assertions.AssertStatus(t, rr, http.StatusUnauthorized)
|
||||||
|
}
|
||||||
|
|
||||||
func TestAuthRedirectUriNotSet(t *testing.T) {
|
func TestAuthRedirectUriNotSet(t *testing.T) {
|
||||||
repo, user := getSingleUserTestRepo()
|
repo, user := getSingleUserTestRepo()
|
||||||
repo.HttpClient = &mocks.MockHttpClient{}
|
repo.HttpClient = &mocks.MockHttpClient{}
|
||||||
|
|
|
@ -74,6 +74,8 @@ func userAuthHandler(repo *owl.Repository) func(http.ResponseWriter, *http.Reque
|
||||||
redirectUri := r.URL.Query().Get("redirect_uri")
|
redirectUri := r.URL.Query().Get("redirect_uri")
|
||||||
state := r.URL.Query().Get("state")
|
state := r.URL.Query().Get("state")
|
||||||
responseType := r.URL.Query().Get("response_type")
|
responseType := r.URL.Query().Get("response_type")
|
||||||
|
codeChallenge := r.URL.Query().Get("code_challenge")
|
||||||
|
codeChallengeMethod := r.URL.Query().Get("code_challenge_method")
|
||||||
|
|
||||||
// check if request is valid
|
// check if request is valid
|
||||||
missing_params := []string{}
|
missing_params := []string{}
|
||||||
|
@ -99,6 +101,11 @@ func userAuthHandler(repo *owl.Repository) func(http.ResponseWriter, *http.Reque
|
||||||
w.Write([]byte("Invalid response_type. Must be 'code' ('id' converted to 'code' for legacy support)."))
|
w.Write([]byte("Invalid response_type. Must be 'code' ('id' converted to 'code' for legacy support)."))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
if codeChallengeMethod != "" && (codeChallengeMethod != "S256" && codeChallengeMethod != "plain") {
|
||||||
|
w.WriteHeader(http.StatusBadRequest)
|
||||||
|
w.Write([]byte("Invalid code_challenge_method. Must be 'S256' or 'plain'."))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
// check if redirect_uri is registered
|
// check if redirect_uri is registered
|
||||||
resp, _ := repo.HttpClient.Get(clientId)
|
resp, _ := repo.HttpClient.Get(clientId)
|
||||||
|
@ -127,13 +134,15 @@ func userAuthHandler(repo *owl.Repository) func(http.ResponseWriter, *http.Reque
|
||||||
http.SetCookie(w, &cookie)
|
http.SetCookie(w, &cookie)
|
||||||
|
|
||||||
reqData := owl.AuthRequestData{
|
reqData := owl.AuthRequestData{
|
||||||
Me: me,
|
Me: me,
|
||||||
ClientId: clientId,
|
ClientId: clientId,
|
||||||
RedirectUri: redirectUri,
|
RedirectUri: redirectUri,
|
||||||
State: state,
|
State: state,
|
||||||
ResponseType: responseType,
|
ResponseType: responseType,
|
||||||
User: user,
|
CodeChallenge: codeChallenge,
|
||||||
CsrfToken: csrfToken,
|
CodeChallengeMethod: codeChallengeMethod,
|
||||||
|
User: user,
|
||||||
|
CsrfToken: csrfToken,
|
||||||
}
|
}
|
||||||
|
|
||||||
html, err := owl.RenderUserAuthPage(reqData)
|
html, err := owl.RenderUserAuthPage(reqData)
|
||||||
|
@ -168,9 +177,10 @@ func userAuthProfileHandler(repo *owl.Repository) func(http.ResponseWriter, *htt
|
||||||
code := r.Form.Get("code")
|
code := r.Form.Get("code")
|
||||||
client_id := r.Form.Get("client_id")
|
client_id := r.Form.Get("client_id")
|
||||||
redirect_uri := r.Form.Get("redirect_uri")
|
redirect_uri := r.Form.Get("redirect_uri")
|
||||||
|
code_verifier := r.Form.Get("code_verifier")
|
||||||
|
|
||||||
// check if request is valid
|
// check if request is valid
|
||||||
valid := user.VerifyAuthCode(code, client_id, redirect_uri)
|
valid := user.VerifyAuthCode(code, client_id, redirect_uri, code_verifier)
|
||||||
if !valid {
|
if !valid {
|
||||||
w.WriteHeader(http.StatusUnauthorized)
|
w.WriteHeader(http.StatusUnauthorized)
|
||||||
w.Write([]byte("Invalid code"))
|
w.Write([]byte("Invalid code"))
|
||||||
|
@ -225,11 +235,12 @@ func userAuthVerifyHandler(repo *owl.Repository) func(http.ResponseWriter, *http
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
password := r.FormValue("password")
|
password := r.FormValue("password")
|
||||||
println("Password: ", password)
|
|
||||||
client_id := r.FormValue("client_id")
|
client_id := r.FormValue("client_id")
|
||||||
redirect_uri := r.FormValue("redirect_uri")
|
redirect_uri := r.FormValue("redirect_uri")
|
||||||
response_type := r.FormValue("response_type")
|
response_type := r.FormValue("response_type")
|
||||||
state := r.FormValue("state")
|
state := r.FormValue("state")
|
||||||
|
code_challenge := r.FormValue("code_challenge")
|
||||||
|
code_challenge_method := r.FormValue("code_challenge_method")
|
||||||
|
|
||||||
// CSRF check
|
// CSRF check
|
||||||
formCsrfToken := r.FormValue("csrf_token")
|
formCsrfToken := r.FormValue("csrf_token")
|
||||||
|
@ -250,17 +261,22 @@ func userAuthVerifyHandler(repo *owl.Repository) func(http.ResponseWriter, *http
|
||||||
|
|
||||||
password_valid := user.VerifyPassword(password)
|
password_valid := user.VerifyPassword(password)
|
||||||
if !password_valid {
|
if !password_valid {
|
||||||
|
redirect := fmt.Sprintf(
|
||||||
|
"%s?error=invalid_password&client_id=%s&redirect_uri=%s&response_type=%s&state=%s",
|
||||||
|
user.AuthUrl(), client_id, redirect_uri, response_type, state,
|
||||||
|
)
|
||||||
|
if code_challenge != "" {
|
||||||
|
redirect += fmt.Sprintf("&code_challenge=%s&code_challenge_method=%s", code_challenge, code_challenge_method)
|
||||||
|
}
|
||||||
http.Redirect(w, r,
|
http.Redirect(w, r,
|
||||||
fmt.Sprintf(
|
redirect,
|
||||||
"%s?error=invalid_password&client_id=%s&redirect_uri=%s&response_type=%s&state=%s",
|
|
||||||
user.AuthUrl(), client_id, redirect_uri, response_type, state,
|
|
||||||
),
|
|
||||||
http.StatusFound,
|
http.StatusFound,
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
} else {
|
} else {
|
||||||
// password is valid, generate code
|
// password is valid, generate code
|
||||||
code, err := user.GenerateAuthCode(client_id, redirect_uri)
|
code, err := user.GenerateAuthCode(
|
||||||
|
client_id, redirect_uri, code_challenge, code_challenge_method)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
println("Error generating code: ", err.Error())
|
println("Error generating code: ", err.Error())
|
||||||
w.WriteHeader(http.StatusInternalServerError)
|
w.WriteHeader(http.StatusInternalServerError)
|
||||||
|
|
|
@ -8,5 +8,7 @@
|
||||||
<input type="hidden" name="response_type" value="{{.ResponseType}}">
|
<input type="hidden" name="response_type" value="{{.ResponseType}}">
|
||||||
<input type="hidden" name="state" value="{{.State}}">
|
<input type="hidden" name="state" value="{{.State}}">
|
||||||
<input type="hidden" name="csrf_token" value="{{.CsrfToken}}">
|
<input type="hidden" name="csrf_token" value="{{.CsrfToken}}">
|
||||||
|
<input type="hidden" name="code_challenge" value="{{.CodeChallenge}}">
|
||||||
|
<input type="hidden" name="code_challenge_method" value="{{.CodeChallengeMethod}}">
|
||||||
<input type="submit" value="Login">
|
<input type="submit" value="Login">
|
||||||
</form>
|
</form>
|
16
renderer.go
16
renderer.go
|
@ -21,13 +21,15 @@ type PostRenderData struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
type AuthRequestData struct {
|
type AuthRequestData struct {
|
||||||
Me string
|
Me string
|
||||||
ClientId string
|
ClientId string
|
||||||
RedirectUri string
|
RedirectUri string
|
||||||
State string
|
State string
|
||||||
ResponseType string
|
ResponseType string
|
||||||
User User
|
CodeChallenge string
|
||||||
CsrfToken string
|
CodeChallengeMethod string
|
||||||
|
User User
|
||||||
|
CsrfToken string
|
||||||
}
|
}
|
||||||
|
|
||||||
func renderEmbedTemplate(templateFile string, data interface{}) (string, error) {
|
func renderEmbedTemplate(templateFile string, data interface{}) (string, error) {
|
||||||
|
|
40
user.go
40
user.go
|
@ -1,6 +1,8 @@
|
||||||
package owl
|
package owl
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"crypto/sha256"
|
||||||
|
"encoding/base64"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/url"
|
"net/url"
|
||||||
"os"
|
"os"
|
||||||
|
@ -32,10 +34,12 @@ type UserMe struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
type AuthCode struct {
|
type AuthCode struct {
|
||||||
Code string `yaml:"code"`
|
Code string `yaml:"code"`
|
||||||
ClientId string `yaml:"client_id"`
|
ClientId string `yaml:"client_id"`
|
||||||
RedirectUri string `yaml:"redirect_uri"`
|
RedirectUri string `yaml:"redirect_uri"`
|
||||||
Created time.Time `yaml:"created"`
|
CodeChallenge string `yaml:"code_challenge"`
|
||||||
|
CodeChallengeMethod string `yaml:"code_challenge_method"`
|
||||||
|
Created time.Time `yaml:"created"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (user User) Dir() string {
|
func (user User) Dir() string {
|
||||||
|
@ -284,21 +288,37 @@ func (user User) addAuthCode(code AuthCode) error {
|
||||||
return saveToYaml(user.AuthCodesFile(), codes)
|
return saveToYaml(user.AuthCodesFile(), codes)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (user User) GenerateAuthCode(client_id string, redirect_uri string) (string, error) {
|
func (user User) GenerateAuthCode(
|
||||||
|
client_id string, redirect_uri string,
|
||||||
|
code_challenge string, code_challenge_method string,
|
||||||
|
) (string, error) {
|
||||||
// generate code
|
// generate code
|
||||||
code := GenerateRandomString(32)
|
code := GenerateRandomString(32)
|
||||||
return code, user.addAuthCode(AuthCode{
|
return code, user.addAuthCode(AuthCode{
|
||||||
Code: code,
|
Code: code,
|
||||||
ClientId: client_id,
|
ClientId: client_id,
|
||||||
RedirectUri: redirect_uri,
|
RedirectUri: redirect_uri,
|
||||||
|
CodeChallenge: code_challenge,
|
||||||
|
CodeChallengeMethod: code_challenge_method,
|
||||||
|
Created: time.Now(),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (user User) VerifyAuthCode(code string, client_id string, redirect_uri string) bool {
|
func (user User) VerifyAuthCode(
|
||||||
|
code string, client_id string, redirect_uri string, code_verifier string,
|
||||||
|
) bool {
|
||||||
codes := user.getAuthCodes()
|
codes := user.getAuthCodes()
|
||||||
for _, c := range codes {
|
for _, c := range codes {
|
||||||
if c.Code == code && c.ClientId == client_id && c.RedirectUri == redirect_uri {
|
if c.Code == code && c.ClientId == client_id && c.RedirectUri == redirect_uri {
|
||||||
return true
|
if c.CodeChallengeMethod == "plain" {
|
||||||
|
return c.CodeChallenge == code_verifier
|
||||||
|
} else if c.CodeChallengeMethod == "S256" {
|
||||||
|
// hash code_verifier
|
||||||
|
hash := sha256.Sum256([]byte(code_verifier))
|
||||||
|
return c.CodeChallenge == base64.RawURLEncoding.EncodeToString(hash[:])
|
||||||
|
} else if c.CodeChallengeMethod == "" {
|
||||||
|
return true
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return false
|
return false
|
||||||
|
|
Loading…
Reference in New Issue