diff --git a/cmd/owl/web/auth_test.go b/cmd/owl/web/auth_test.go index f23019f..ba1fe95 100644 --- a/cmd/owl/web/auth_test.go +++ b/cmd/owl/web/auth_test.go @@ -339,3 +339,61 @@ func TestAuthRedirectUriSameHost(t *testing.T) { assertions.AssertStatus(t, rr, http.StatusOK) } + +func TestAccessTokenCorrectPassword(t *testing.T) { + repo, user := getSingleUserTestRepo() + user.ResetPassword("testpassword") + code, _ := user.GenerateAuthCode("http://example.com", "http://example.com/response", "", "") + + // Create Request and Response + form := url.Values{} + form.Add("code", code) + form.Add("client_id", "http://example.com") + form.Add("redirect_uri", "http://example.com/response") + form.Add("grant_type", "authorization_code") + req, err := http.NewRequest("POST", user.AuthUrl()+"token/", strings.NewReader(form.Encode())) + req.Header.Add("Content-Type", "application/x-www-form-urlencoded") + req.Header.Add("Content-Length", strconv.Itoa(len(form.Encode()))) + assertions.AssertNoError(t, err, "Error creating request") + rr := httptest.NewRecorder() + router := main.SingleUserRouter(&repo) + router.ServeHTTP(rr, req) + + assertions.AssertStatus(t, rr, http.StatusOK) + // parse response as json + type responseType struct { + Me string `json:"me"` + TokenType string `json:"token_type"` + AccessToken string `json:"access_token"` + ExpiresIn int `json:"expires_in"` + RefreshToken string `json:"refresh_token"` + } + var response responseType + json.Unmarshal(rr.Body.Bytes(), &response) + assertions.AssertEqual(t, response.Me, user.FullUrl()) + assertions.AssertEqual(t, response.TokenType, "Bearer") + assertions.Assert(t, response.ExpiresIn > 0, "ExpiresIn should be greater than 0") + assertions.Assert(t, len(response.AccessToken) > 0, "AccessToken should be greater than 0") +} + +func TestAccessTokenWithIncorrectCode(t *testing.T) { + repo, user := getSingleUserTestRepo() + user.ResetPassword("testpassword") + user.GenerateAuthCode("http://example.com", "http://example.com/response", "", "") + + // Create Request and Response + form := url.Values{} + form.Add("code", "wrongcode") + form.Add("client_id", "http://example.com") + form.Add("redirect_uri", "http://example.com/response") + form.Add("grant_type", "authorization_code") + req, err := http.NewRequest("POST", user.AuthUrl()+"token/", strings.NewReader(form.Encode())) + req.Header.Add("Content-Type", "application/x-www-form-urlencoded") + req.Header.Add("Content-Length", strconv.Itoa(len(form.Encode()))) + assertions.AssertNoError(t, err, "Error creating request") + rr := httptest.NewRecorder() + router := main.SingleUserRouter(&repo) + router.ServeHTTP(rr, req) + + assertions.AssertStatus(t, rr, http.StatusUnauthorized) +} diff --git a/cmd/owl/web/handler.go b/cmd/owl/web/handler.go index 0a07aba..b5ca022 100644 --- a/cmd/owl/web/handler.go +++ b/cmd/owl/web/handler.go @@ -171,6 +171,30 @@ func userAuthHandler(repo *owl.Repository) func(http.ResponseWriter, *http.Reque } } +func verifyAuthCodeRequest(user owl.User, w http.ResponseWriter, r *http.Request) bool { + // get form data from post request + err := r.ParseForm() + if err != nil { + println("Error parsing form: ", err.Error()) + w.WriteHeader(http.StatusBadRequest) + w.Write([]byte("Error parsing form")) + return false + } + code := r.Form.Get("code") + client_id := r.Form.Get("client_id") + redirect_uri := r.Form.Get("redirect_uri") + code_verifier := r.Form.Get("code_verifier") + + // check if request is valid + valid := user.VerifyAuthCode(code, client_id, redirect_uri, code_verifier) + if !valid { + w.WriteHeader(http.StatusUnauthorized) + w.Write([]byte("Invalid code")) + return false + } + return true +} + func userAuthProfileHandler(repo *owl.Repository) func(http.ResponseWriter, *http.Request, httprouter.Params) { return func(w http.ResponseWriter, r *http.Request, ps httprouter.Params) { user, err := getUserFromRepo(repo, ps) @@ -180,26 +204,7 @@ func userAuthProfileHandler(repo *owl.Repository) func(http.ResponseWriter, *htt return } - // get form data from post request - err = r.ParseForm() - if err != nil { - println("Error parsing form: ", err.Error()) - w.WriteHeader(http.StatusBadRequest) - w.Write([]byte("Error parsing form")) - return - } - code := r.Form.Get("code") - client_id := r.Form.Get("client_id") - redirect_uri := r.Form.Get("redirect_uri") - code_verifier := r.Form.Get("code_verifier") - - // check if request is valid - valid := user.VerifyAuthCode(code, client_id, redirect_uri, code_verifier) - if !valid { - w.WriteHeader(http.StatusUnauthorized) - w.Write([]byte("Invalid code")) - return - } else { + if verifyAuthCodeRequest(user, w, r) { w.WriteHeader(http.StatusOK) type ResponseProfile struct { Name string `json:"name"` @@ -227,7 +232,48 @@ func userAuthProfileHandler(repo *owl.Repository) func(http.ResponseWriter, *htt w.Write(jsonData) return } + } +} +func userAuthTokenHandler(repo *owl.Repository) func(http.ResponseWriter, *http.Request, httprouter.Params) { + return func(w http.ResponseWriter, r *http.Request, ps httprouter.Params) { + user, err := getUserFromRepo(repo, ps) + if err != nil { + println("Error getting user: ", err.Error()) + notFoundHandler(repo)(w, r) + return + } + + if verifyAuthCodeRequest(user, w, r) { + type Response struct { + Me string `json:"me"` + TokenType string `json:"token_type"` + AccessToken string `json:"access_token"` + ExpiresIn int `json:"expires_in"` + RefreshToken string `json:"refresh_token"` + } + accessToken, duration, err := user.GenerateAccessToken() + if err != nil { + println("Error generating access token: ", err.Error()) + w.WriteHeader(http.StatusInternalServerError) + w.Write([]byte("Internal server error")) + return + } + response := Response{ + Me: user.FullUrl(), + TokenType: "Bearer", + AccessToken: accessToken, + ExpiresIn: duration, + } + jsonData, err := json.Marshal(response) + if err != nil { + println("Error marshalling json: ", err.Error()) + w.WriteHeader(http.StatusInternalServerError) + w.Write([]byte("Internal server error")) + } + w.Write(jsonData) + return + } } } diff --git a/cmd/owl/web/server.go b/cmd/owl/web/server.go index 628df4e..02937c1 100644 --- a/cmd/owl/web/server.go +++ b/cmd/owl/web/server.go @@ -17,6 +17,7 @@ func Router(repo *owl.Repository) http.Handler { router.GET("/user/:user/auth/", userAuthHandler(repo)) router.POST("/user/:user/auth/", userAuthProfileHandler(repo)) router.POST("/user/:user/auth/verify/", userAuthVerifyHandler(repo)) + router.POST("/user/:user/auth/token/", userAuthTokenHandler(repo)) router.GET("/user/:user/media/*filepath", userMediaHandler(repo)) router.GET("/user/:user/index.xml", userRSSHandler(repo)) router.GET("/user/:user/posts/:post/", postHandler(repo)) @@ -33,6 +34,7 @@ func SingleUserRouter(repo *owl.Repository) http.Handler { router.GET("/auth/", userAuthHandler(repo)) router.POST("/auth/", userAuthProfileHandler(repo)) router.POST("/auth/verify/", userAuthVerifyHandler(repo)) + router.POST("/auth/token/", userAuthTokenHandler(repo)) router.GET("/media/*filepath", userMediaHandler(repo)) router.GET("/index.xml", userRSSHandler(repo)) router.GET("/posts/:post/", postHandler(repo)) diff --git a/embed/initial/base.html b/embed/initial/base.html index 003f0bb..ff99bb4 100644 --- a/embed/initial/base.html +++ b/embed/initial/base.html @@ -29,6 +29,7 @@ {{ if .User.AuthUrl }} + {{ end }}