diff --git a/cmd/owl/web/auth_test.go b/cmd/owl/web/auth_test.go index ba1fe95..9ce37f9 100644 --- a/cmd/owl/web/auth_test.go +++ b/cmd/owl/web/auth_test.go @@ -75,7 +75,7 @@ func TestAuthPostCorrectPassword(t *testing.T) { func TestAuthPostWithIncorrectCode(t *testing.T) { repo, user := getSingleUserTestRepo() user.ResetPassword("testpassword") - user.GenerateAuthCode("http://example.com", "http://example.com/response", "", "") + user.GenerateAuthCode("http://example.com", "http://example.com/response", "", "", "profile") // Create Request and Response form := url.Values{} @@ -97,7 +97,7 @@ func TestAuthPostWithIncorrectCode(t *testing.T) { func TestAuthPostWithCorrectCode(t *testing.T) { repo, user := getSingleUserTestRepo() user.ResetPassword("testpassword") - code, _ := user.GenerateAuthCode("http://example.com", "http://example.com/response", "", "") + code, _ := user.GenerateAuthCode("http://example.com", "http://example.com/response", "", "", "profile") // Create Request and Response form := url.Values{} @@ -135,7 +135,7 @@ func TestAuthPostWithCorrectCodeAndPKCE(t *testing.T) { h := sha256.New() h.Write([]byte(code_verifier)) code_challenge := base64.RawURLEncoding.EncodeToString(h.Sum(nil)) - code, _ := user.GenerateAuthCode("http://example.com", "http://example.com/response", code_challenge, "S256") + code, _ := user.GenerateAuthCode("http://example.com", "http://example.com/response", code_challenge, "S256", "profile") form := url.Values{} form.Add("code", code) @@ -173,7 +173,7 @@ func TestAuthPostWithCorrectCodeAndWrongPKCE(t *testing.T) { h := sha256.New() h.Write([]byte(code_verifier + "wrong")) code_challenge := base64.RawURLEncoding.EncodeToString(h.Sum(nil)) - code, _ := user.GenerateAuthCode("http://example.com", "http://example.com/response", code_challenge, "S256") + code, _ := user.GenerateAuthCode("http://example.com", "http://example.com/response", code_challenge, "S256", "profile") form := url.Values{} form.Add("code", code) @@ -200,7 +200,7 @@ func TestAuthPostWithCorrectCodePKCEPlain(t *testing.T) { // Create Request and Response code_verifier := "test_code_verifier" code_challenge := code_verifier - code, _ := user.GenerateAuthCode("http://example.com", "http://example.com/response", code_challenge, "plain") + code, _ := user.GenerateAuthCode("http://example.com", "http://example.com/response", code_challenge, "plain", "profile") form := url.Values{} form.Add("code", code) @@ -227,7 +227,7 @@ func TestAuthPostWithCorrectCodePKCEPlainWrong(t *testing.T) { // Create Request and Response code_verifier := "test_code_verifier" code_challenge := code_verifier + "wrong" - code, _ := user.GenerateAuthCode("http://example.com", "http://example.com/response", code_challenge, "plain") + code, _ := user.GenerateAuthCode("http://example.com", "http://example.com/response", code_challenge, "plain", "profile") form := url.Values{} form.Add("code", code) @@ -343,7 +343,7 @@ func TestAuthRedirectUriSameHost(t *testing.T) { func TestAccessTokenCorrectPassword(t *testing.T) { repo, user := getSingleUserTestRepo() user.ResetPassword("testpassword") - code, _ := user.GenerateAuthCode("http://example.com", "http://example.com/response", "", "") + code, _ := user.GenerateAuthCode("http://example.com", "http://example.com/response", "", "", "profile create") // Create Request and Response form := url.Values{} @@ -367,11 +367,13 @@ func TestAccessTokenCorrectPassword(t *testing.T) { AccessToken string `json:"access_token"` ExpiresIn int `json:"expires_in"` RefreshToken string `json:"refresh_token"` + Scope string `json:"scope"` } var response responseType json.Unmarshal(rr.Body.Bytes(), &response) assertions.AssertEqual(t, response.Me, user.FullUrl()) assertions.AssertEqual(t, response.TokenType, "Bearer") + assertions.AssertEqual(t, response.Scope, "profile create") assertions.Assert(t, response.ExpiresIn > 0, "ExpiresIn should be greater than 0") assertions.Assert(t, len(response.AccessToken) > 0, "AccessToken should be greater than 0") } @@ -379,7 +381,7 @@ func TestAccessTokenCorrectPassword(t *testing.T) { func TestAccessTokenWithIncorrectCode(t *testing.T) { repo, user := getSingleUserTestRepo() user.ResetPassword("testpassword") - user.GenerateAuthCode("http://example.com", "http://example.com/response", "", "") + user.GenerateAuthCode("http://example.com", "http://example.com/response", "", "", "profile") // Create Request and Response form := url.Values{} diff --git a/cmd/owl/web/handler.go b/cmd/owl/web/handler.go index b5ca022..4cbf887 100644 --- a/cmd/owl/web/handler.go +++ b/cmd/owl/web/handler.go @@ -76,6 +76,7 @@ func userAuthHandler(repo *owl.Repository) func(http.ResponseWriter, *http.Reque responseType := r.URL.Query().Get("response_type") codeChallenge := r.URL.Query().Get("code_challenge") codeChallengeMethod := r.URL.Query().Get("code_challenge_method") + scope := r.URL.Query().Get("scope") // check if request is valid missing_params := []string{} @@ -152,6 +153,7 @@ func userAuthHandler(repo *owl.Repository) func(http.ResponseWriter, *http.Reque ClientId: clientId, RedirectUri: redirectUri, State: state, + Scope: scope, ResponseType: responseType, CodeChallenge: codeChallenge, CodeChallengeMethod: codeChallengeMethod, @@ -171,14 +173,14 @@ func userAuthHandler(repo *owl.Repository) func(http.ResponseWriter, *http.Reque } } -func verifyAuthCodeRequest(user owl.User, w http.ResponseWriter, r *http.Request) bool { +func verifyAuthCodeRequest(user owl.User, w http.ResponseWriter, r *http.Request) (bool, owl.AuthCode) { // get form data from post request err := r.ParseForm() if err != nil { println("Error parsing form: ", err.Error()) w.WriteHeader(http.StatusBadRequest) w.Write([]byte("Error parsing form")) - return false + return false, owl.AuthCode{} } code := r.Form.Get("code") client_id := r.Form.Get("client_id") @@ -186,13 +188,12 @@ func verifyAuthCodeRequest(user owl.User, w http.ResponseWriter, r *http.Request code_verifier := r.Form.Get("code_verifier") // check if request is valid - valid := user.VerifyAuthCode(code, client_id, redirect_uri, code_verifier) + valid, authCode := user.VerifyAuthCode(code, client_id, redirect_uri, code_verifier) if !valid { w.WriteHeader(http.StatusUnauthorized) w.Write([]byte("Invalid code")) - return false } - return true + return valid, authCode } func userAuthProfileHandler(repo *owl.Repository) func(http.ResponseWriter, *http.Request, httprouter.Params) { @@ -204,7 +205,8 @@ func userAuthProfileHandler(repo *owl.Repository) func(http.ResponseWriter, *htt return } - if verifyAuthCodeRequest(user, w, r) { + valid, _ := verifyAuthCodeRequest(user, w, r) + if valid { w.WriteHeader(http.StatusOK) type ResponseProfile struct { Name string `json:"name"` @@ -244,15 +246,23 @@ func userAuthTokenHandler(repo *owl.Repository) func(http.ResponseWriter, *http. return } - if verifyAuthCodeRequest(user, w, r) { + valid, authCode := verifyAuthCodeRequest(user, w, r) + if valid { + if authCode.Scope == "" { + w.WriteHeader(http.StatusBadRequest) + w.Write([]byte("Empty scope, no token issued")) + return + } + type Response struct { Me string `json:"me"` TokenType string `json:"token_type"` AccessToken string `json:"access_token"` + Scope string `json:"scope"` ExpiresIn int `json:"expires_in"` RefreshToken string `json:"refresh_token"` } - accessToken, duration, err := user.GenerateAccessToken() + accessToken, duration, err := user.GenerateAccessToken(authCode) if err != nil { println("Error generating access token: ", err.Error()) w.WriteHeader(http.StatusInternalServerError) @@ -263,6 +273,7 @@ func userAuthTokenHandler(repo *owl.Repository) func(http.ResponseWriter, *http. Me: user.FullUrl(), TokenType: "Bearer", AccessToken: accessToken, + Scope: authCode.Scope, ExpiresIn: duration, } jsonData, err := json.Marshal(response) @@ -301,6 +312,7 @@ func userAuthVerifyHandler(repo *owl.Repository) func(http.ResponseWriter, *http state := r.FormValue("state") code_challenge := r.FormValue("code_challenge") code_challenge_method := r.FormValue("code_challenge_method") + scope := r.FormValue("scope") // CSRF check formCsrfToken := r.FormValue("csrf_token") @@ -336,7 +348,7 @@ func userAuthVerifyHandler(repo *owl.Repository) func(http.ResponseWriter, *http } else { // password is valid, generate code code, err := user.GenerateAuthCode( - client_id, redirect_uri, code_challenge, code_challenge_method) + client_id, redirect_uri, code_challenge, code_challenge_method, scope) if err != nil { println("Error generating code: ", err.Error()) w.WriteHeader(http.StatusInternalServerError) diff --git a/embed/auth.html b/embed/auth.html index d2f322a..0c3a4b4 100644 --- a/embed/auth.html +++ b/embed/auth.html @@ -10,5 +10,6 @@ + \ No newline at end of file diff --git a/renderer.go b/renderer.go index b72bf7d..3bdfee3 100644 --- a/renderer.go +++ b/renderer.go @@ -25,6 +25,7 @@ type AuthRequestData struct { ClientId string RedirectUri string State string + Scope string ResponseType string CodeChallenge string CodeChallengeMethod string diff --git a/user.go b/user.go index 2a2373d..ec0b2c3 100644 --- a/user.go +++ b/user.go @@ -39,13 +39,17 @@ type AuthCode struct { RedirectUri string `yaml:"redirect_uri"` CodeChallenge string `yaml:"code_challenge"` CodeChallengeMethod string `yaml:"code_challenge_method"` + Scope string `yaml:"scope"` Created time.Time `yaml:"created"` } type AccessToken struct { - Token string `yaml:"token"` - Created time.Time `yaml:"created"` - ExpiresIn int `yaml:"expires_in"` + Token string `yaml:"token"` + Scope string `yaml:"scope"` + ClientId string `yaml:"client_id"` + RedirectUri string `yaml:"redirect_uri"` + Created time.Time `yaml:"created"` + ExpiresIn int `yaml:"expires_in"` } func (user User) Dir() string { @@ -306,6 +310,7 @@ func (user User) addAuthCode(code AuthCode) error { func (user User) GenerateAuthCode( client_id string, redirect_uri string, code_challenge string, code_challenge_method string, + scope string, ) (string, error) { // generate code code := GenerateRandomString(32) @@ -315,28 +320,29 @@ func (user User) GenerateAuthCode( RedirectUri: redirect_uri, CodeChallenge: code_challenge, CodeChallengeMethod: code_challenge_method, + Scope: scope, Created: time.Now(), }) } func (user User) VerifyAuthCode( code string, client_id string, redirect_uri string, code_verifier string, -) bool { +) (bool, AuthCode) { codes := user.getAuthCodes() for _, c := range codes { if c.Code == code && c.ClientId == client_id && c.RedirectUri == redirect_uri { if c.CodeChallengeMethod == "plain" { - return c.CodeChallenge == code_verifier + return c.CodeChallenge == code_verifier, c } else if c.CodeChallengeMethod == "S256" { // hash code_verifier hash := sha256.Sum256([]byte(code_verifier)) - return c.CodeChallenge == base64.RawURLEncoding.EncodeToString(hash[:]) + return c.CodeChallenge == base64.RawURLEncoding.EncodeToString(hash[:]), c } else if c.CodeChallengeMethod == "" { - return true + return true, c } } } - return false + return false, AuthCode{} } func (user User) getAccessTokens() []AccessToken { @@ -351,13 +357,16 @@ func (user User) addAccessToken(code AccessToken) error { return saveToYaml(user.AccessTokensFile(), codes) } -func (user User) GenerateAccessToken() (string, int, error) { +func (user User) GenerateAccessToken(authCode AuthCode) (string, int, error) { // generate code token := GenerateRandomString(32) duration := 24 * 60 * 60 return token, duration, user.addAccessToken(AccessToken{ - Token: token, - ExpiresIn: duration, - Created: time.Now(), + Token: token, + ClientId: authCode.ClientId, + RedirectUri: authCode.RedirectUri, + Scope: authCode.Scope, + ExpiresIn: duration, + Created: time.Now(), }) }