diff --git a/cmd/owl/web/auth_test.go b/cmd/owl/web/auth_test.go new file mode 100644 index 0000000..a7485fe --- /dev/null +++ b/cmd/owl/web/auth_test.go @@ -0,0 +1,114 @@ +package web_test + +import ( + "encoding/json" + main "h4kor/owl-blogs/cmd/owl/web" + "h4kor/owl-blogs/priv/assertions" + "net/http" + "net/http/httptest" + "net/url" + "strconv" + "strings" + "testing" +) + +func TestAuthPostWrongPassword(t *testing.T) { + repo, user := getSingleUserTestRepo() + user.ResetPassword("testpassword") + + // Create Request and Response + form := url.Values{} + form.Add("password", "wrongpassword") + form.Add("client_id", "http://example.com") + form.Add("redirect_uri", "http://example.com/response") + form.Add("response_type", "code") + form.Add("state", "test_state") + req, err := http.NewRequest("POST", user.AuthUrl()+"verify/", strings.NewReader(form.Encode())) + req.Header.Add("Content-Type", "application/x-www-form-urlencoded") + req.Header.Add("Content-Length", strconv.Itoa(len(form.Encode()))) + assertions.AssertNoError(t, err, "Error creating request") + rr := httptest.NewRecorder() + router := main.SingleUserRouter(&repo) + router.ServeHTTP(rr, req) + + assertions.AssertStatus(t, rr, http.StatusFound) + assertions.AssertContains(t, rr.Header().Get("Location"), "error=invalid_password") +} + +func TestAuthPostCorrectPassword(t *testing.T) { + repo, user := getSingleUserTestRepo() + user.ResetPassword("testpassword") + + // Create Request and Response + form := url.Values{} + form.Add("password", "testpassword") + form.Add("client_id", "http://example.com") + form.Add("redirect_uri", "http://example.com/response") + form.Add("response_type", "code") + form.Add("state", "test_state") + req, err := http.NewRequest("POST", user.AuthUrl()+"verify/", strings.NewReader(form.Encode())) + req.Header.Add("Content-Type", "application/x-www-form-urlencoded") + req.Header.Add("Content-Length", strconv.Itoa(len(form.Encode()))) + assertions.AssertNoError(t, err, "Error creating request") + rr := httptest.NewRecorder() + router := main.SingleUserRouter(&repo) + router.ServeHTTP(rr, req) + + assertions.AssertStatus(t, rr, http.StatusFound) + assertions.AssertContains(t, rr.Header().Get("Location"), "code=") + assertions.AssertContains(t, rr.Header().Get("Location"), "state=test_state") + assertions.AssertContains(t, rr.Header().Get("Location"), "http://example.com/response") +} + +func TestAuthPostWithIncorrectCode(t *testing.T) { + repo, user := getSingleUserTestRepo() + user.ResetPassword("testpassword") + user.GenerateAuthCode("http://example.com", "http://example.com/response") + + // Create Request and Response + form := url.Values{} + form.Add("code", "wrongcode") + form.Add("client_id", "http://example.com") + form.Add("redirect_uri", "http://example.com/response") + form.Add("grant_type", "authorization_code") + req, err := http.NewRequest("POST", user.AuthUrl(), strings.NewReader(form.Encode())) + req.Header.Add("Content-Type", "application/x-www-form-urlencoded") + req.Header.Add("Content-Length", strconv.Itoa(len(form.Encode()))) + assertions.AssertNoError(t, err, "Error creating request") + rr := httptest.NewRecorder() + router := main.SingleUserRouter(&repo) + router.ServeHTTP(rr, req) + + assertions.AssertStatus(t, rr, http.StatusUnauthorized) +} + +func TestAuthPostWithCorrectCode(t *testing.T) { + repo, user := getSingleUserTestRepo() + user.ResetPassword("testpassword") + code, _ := user.GenerateAuthCode("http://example.com", "http://example.com/response") + + // Create Request and Response + form := url.Values{} + form.Add("code", code) + form.Add("client_id", "http://example.com") + form.Add("redirect_uri", "http://example.com/response") + form.Add("grant_type", "authorization_code") + req, err := http.NewRequest("POST", user.AuthUrl(), strings.NewReader(form.Encode())) + req.Header.Add("Content-Type", "application/x-www-form-urlencoded") + req.Header.Add("Content-Length", strconv.Itoa(len(form.Encode()))) + req.Header.Add("Accept", "application/json") + assertions.AssertNoError(t, err, "Error creating request") + rr := httptest.NewRecorder() + router := main.SingleUserRouter(&repo) + router.ServeHTTP(rr, req) + + assertions.AssertStatus(t, rr, http.StatusOK) + // parse response as json + type responseType struct { + Me string `json:"me"` + } + var response responseType + json.Unmarshal(rr.Body.Bytes(), &response) + assertions.AssertEqual(t, response.Me, user.FullUrl()) + +} diff --git a/cmd/owl/web/handler.go b/cmd/owl/web/handler.go index 3228371..0d100aa 100644 --- a/cmd/owl/web/handler.go +++ b/cmd/owl/web/handler.go @@ -1,6 +1,7 @@ package web import ( + "encoding/json" "fmt" "h4kor/owl-blogs" "net/http" @@ -120,6 +121,121 @@ func userAuthHandler(repo *owl.Repository) func(http.ResponseWriter, *http.Reque } } +func userAuthProfileHandler(repo *owl.Repository) func(http.ResponseWriter, *http.Request, httprouter.Params) { + return func(w http.ResponseWriter, r *http.Request, ps httprouter.Params) { + user, err := getUserFromRepo(repo, ps) + if err != nil { + println("Error getting user: ", err.Error()) + notFoundHandler(repo)(w, r) + return + } + + // get form data from post request + err = r.ParseForm() + if err != nil { + println("Error parsing form: ", err.Error()) + w.WriteHeader(http.StatusBadRequest) + w.Write([]byte("Error parsing form")) + return + } + code := r.Form.Get("code") + client_id := r.Form.Get("client_id") + redirect_uri := r.Form.Get("redirect_uri") + + // check if request is valid + valid := user.VerifyAuthCode(code, client_id, redirect_uri) + if !valid { + w.WriteHeader(http.StatusUnauthorized) + w.Write([]byte("Invalid code")) + return + } else { + w.WriteHeader(http.StatusOK) + type ResponseProfile struct { + Name string `json:"name"` + Url string `json:"url"` + Photo string `json:"photo"` + } + type Response struct { + Me string `json:"me"` + Profile ResponseProfile `json:"profile"` + } + response := Response{ + Me: user.FullUrl(), + Profile: ResponseProfile{ + Name: user.Name(), + Url: user.FullUrl(), + Photo: user.AvatarUrl(), + }, + } + jsonData, err := json.Marshal(response) + if err != nil { + println("Error marshalling json: ", err.Error()) + w.WriteHeader(http.StatusInternalServerError) + w.Write([]byte("Internal server error")) + } + w.Write(jsonData) + return + } + + } +} + +func userAuthVerifyHandler(repo *owl.Repository) func(http.ResponseWriter, *http.Request, httprouter.Params) { + return func(w http.ResponseWriter, r *http.Request, ps httprouter.Params) { + user, err := getUserFromRepo(repo, ps) + if err != nil { + println("Error getting user: ", err.Error()) + notFoundHandler(repo)(w, r) + return + } + + // get form data from post request + err = r.ParseForm() + if err != nil { + println("Error parsing form: ", err.Error()) + w.WriteHeader(http.StatusBadRequest) + w.Write([]byte("Error parsing form")) + return + } + password := r.FormValue("password") + println("Password: ", password) + client_id := r.FormValue("client_id") + redirect_uri := r.FormValue("redirect_uri") + response_type := r.FormValue("response_type") + state := r.FormValue("state") + + password_valid := user.VerifyPassword(password) + if !password_valid { + http.Redirect(w, r, + fmt.Sprintf( + "%s?error=invalid_password&client_id=%s&redirect_uri=%s&response_type=%s&state=%s", + user.AuthUrl(), client_id, redirect_uri, response_type, state, + ), + http.StatusFound, + ) + return + } else { + // password is valid, generate code + code, err := user.GenerateAuthCode(client_id, redirect_uri) + if err != nil { + println("Error generating code: ", err.Error()) + w.WriteHeader(http.StatusInternalServerError) + w.Write([]byte("Internal server error")) + return + } + http.Redirect(w, r, + fmt.Sprintf( + "%s?code=%s&state=%s", + redirect_uri, code, state, + ), + http.StatusFound, + ) + return + } + + } +} + func userWebmentionHandler(repo *owl.Repository) func(http.ResponseWriter, *http.Request, httprouter.Params) { return func(w http.ResponseWriter, r *http.Request, ps httprouter.Params) { user, err := getUserFromRepo(repo, ps) diff --git a/cmd/owl/web/server.go b/cmd/owl/web/server.go index b5b542e..628df4e 100644 --- a/cmd/owl/web/server.go +++ b/cmd/owl/web/server.go @@ -15,6 +15,8 @@ func Router(repo *owl.Repository) http.Handler { router.GET("/", repoIndexHandler(repo)) router.GET("/user/:user/", userIndexHandler(repo)) router.GET("/user/:user/auth/", userAuthHandler(repo)) + router.POST("/user/:user/auth/", userAuthProfileHandler(repo)) + router.POST("/user/:user/auth/verify/", userAuthVerifyHandler(repo)) router.GET("/user/:user/media/*filepath", userMediaHandler(repo)) router.GET("/user/:user/index.xml", userRSSHandler(repo)) router.GET("/user/:user/posts/:post/", postHandler(repo)) @@ -29,6 +31,8 @@ func SingleUserRouter(repo *owl.Repository) http.Handler { router.ServeFiles("/static/*filepath", http.Dir(repo.StaticDir())) router.GET("/", userIndexHandler(repo)) router.GET("/auth/", userAuthHandler(repo)) + router.POST("/auth/", userAuthProfileHandler(repo)) + router.POST("/auth/verify/", userAuthVerifyHandler(repo)) router.GET("/media/*filepath", userMediaHandler(repo)) router.GET("/index.xml", userRSSHandler(repo)) router.GET("/posts/:post/", postHandler(repo)) diff --git a/user.go b/user.go index 7c3bc5b..2c58b84 100644 --- a/user.go +++ b/user.go @@ -2,6 +2,7 @@ package owl import ( "fmt" + "math/rand" "net/url" "os" "path" @@ -31,6 +32,13 @@ type UserMe struct { Url string `yaml:"url"` } +type AuthCode struct { + Code string `yaml:"code"` + ClientId string `yaml:"client_id"` + RedirectUri string `yaml:"redirect_uri"` + Created time.Time `yaml:"created"` +} + func (user User) Dir() string { return path.Join(user.repo.UsersDir(), user.name) } @@ -78,6 +86,10 @@ func (user User) ConfigFile() string { return path.Join(user.MetaDir(), "config.yml") } +func (user User) AuthCodesFile() string { + return path.Join(user.MetaDir(), "access_tokens.yml") +} + func (user User) Name() string { return user.name } @@ -260,3 +272,40 @@ func (user User) VerifyPassword(password string) bool { ) return err == nil } + +func (user User) getAuthCodes() []AuthCode { + codes := make([]AuthCode, 0) + loadFromYaml(user.AuthCodesFile(), &codes) + return codes +} + +func (user User) addAuthCode(code AuthCode) error { + codes := user.getAuthCodes() + codes = append(codes, code) + return saveToYaml(user.AuthCodesFile(), codes) +} + +func (user User) GenerateAuthCode(client_id string, redirect_uri string) (string, error) { + // generate code + const chars = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" + b := make([]byte, 32) + for i := range b { + b[i] = chars[rand.Intn(len(chars))] + } + code := string(b) + return code, user.addAuthCode(AuthCode{ + Code: code, + ClientId: client_id, + RedirectUri: redirect_uri, + }) +} + +func (user User) VerifyAuthCode(code string, client_id string, redirect_uri string) bool { + codes := user.getAuthCodes() + for _, c := range codes { + if c.Code == code && c.ClientId == client_id && c.RedirectUri == redirect_uri { + return true + } + } + return false +}