diff --git a/auth_test.go b/auth_test.go new file mode 100644 index 0000000..5efd35f --- /dev/null +++ b/auth_test.go @@ -0,0 +1,48 @@ +package owl_test + +import ( + "h4kor/owl-blogs" + "h4kor/owl-blogs/test/assertions" + "net/http" + "testing" +) + +func TestGetRedirctUrisLink(t *testing.T) { + html := []byte("") + parser := &owl.OwlHtmlParser{} + uris, err := parser.GetRedirctUris(constructResponse(html)) + + assertions.AssertNoError(t, err, "Unable to parse feed") + + assertions.AssertArrayContains(t, uris, "http://example.com/redirect") +} + +func TestGetRedirctUrisLinkMultiple(t *testing.T) { + html := []byte(` + + + + + + `) + parser := &owl.OwlHtmlParser{} + uris, err := parser.GetRedirctUris(constructResponse(html)) + + assertions.AssertNoError(t, err, "Unable to parse feed") + + assertions.AssertArrayContains(t, uris, "http://example.com/redirect1") + assertions.AssertArrayContains(t, uris, "http://example.com/redirect2") + assertions.AssertArrayContains(t, uris, "http://example.com/redirect3") + assertions.AssertLen(t, uris, 3) +} + +func TestGetRedirectUrisLinkHeader(t *testing.T) { + html := []byte("") + parser := &owl.OwlHtmlParser{} + resp := constructResponse(html) + resp.Header = http.Header{"Link": []string{"; rel=\"redirect_uri\""}} + uris, err := parser.GetRedirctUris(resp) + + assertions.AssertNoError(t, err, "Unable to parse feed") + assertions.AssertArrayContains(t, uris, "http://example.com/redirect") +} diff --git a/cmd/owl/web/auth_test.go b/cmd/owl/web/auth_test.go index 65ca9d6..0bf6e76 100644 --- a/cmd/owl/web/auth_test.go +++ b/cmd/owl/web/auth_test.go @@ -4,6 +4,7 @@ import ( "encoding/json" main "h4kor/owl-blogs/cmd/owl/web" "h4kor/owl-blogs/test/assertions" + "h4kor/owl-blogs/test/mocks" "net/http" "net/http/httptest" "net/url" @@ -121,3 +122,65 @@ func TestAuthPostWithCorrectCode(t *testing.T) { assertions.AssertEqual(t, response.Me, user.FullUrl()) } + +func TestAuthRedirectUriNotSet(t *testing.T) { + repo, user := getSingleUserTestRepo() + repo.HttpClient = &mocks.MockHttpClient{} + repo.Parser = &mocks.MockParseLinksHtmlParser{ + Links: []string{"http://example.com/response"}, + } + user.ResetPassword("testpassword") + + csrfToken := "test_csrf_token" + + // Create Request and Response + form := url.Values{} + form.Add("password", "wrongpassword") + form.Add("client_id", "http://example.com") + form.Add("redirect_uri", "http://example.com/response_not_set") + form.Add("response_type", "code") + form.Add("state", "test_state") + form.Add("csrf_token", csrfToken) + + req, err := http.NewRequest("GET", user.AuthUrl()+"?"+form.Encode(), nil) + req.Header.Add("Content-Type", "application/x-www-form-urlencoded") + req.Header.Add("Content-Length", strconv.Itoa(len(form.Encode()))) + req.AddCookie(&http.Cookie{Name: "csrf_token", Value: csrfToken}) + assertions.AssertNoError(t, err, "Error creating request") + rr := httptest.NewRecorder() + router := main.SingleUserRouter(&repo) + router.ServeHTTP(rr, req) + + assertions.AssertStatus(t, rr, http.StatusBadRequest) +} + +func TestAuthRedirectUriSet(t *testing.T) { + repo, user := getSingleUserTestRepo() + repo.HttpClient = &mocks.MockHttpClient{} + repo.Parser = &mocks.MockParseLinksHtmlParser{ + Links: []string{"http://example.com/response"}, + } + user.ResetPassword("testpassword") + + csrfToken := "test_csrf_token" + + // Create Request and Response + form := url.Values{} + form.Add("password", "wrongpassword") + form.Add("client_id", "http://example.com") + form.Add("redirect_uri", "http://example.com/response") + form.Add("response_type", "code") + form.Add("state", "test_state") + form.Add("csrf_token", csrfToken) + + req, err := http.NewRequest("GET", user.AuthUrl()+"?"+form.Encode(), nil) + req.Header.Add("Content-Type", "application/x-www-form-urlencoded") + req.Header.Add("Content-Length", strconv.Itoa(len(form.Encode()))) + req.AddCookie(&http.Cookie{Name: "csrf_token", Value: csrfToken}) + assertions.AssertNoError(t, err, "Error creating request") + rr := httptest.NewRecorder() + router := main.SingleUserRouter(&repo) + router.ServeHTTP(rr, req) + + assertions.AssertStatus(t, rr, http.StatusOK) +} diff --git a/cmd/owl/web/handler.go b/cmd/owl/web/handler.go index 8a7ead3..62ad48f 100644 --- a/cmd/owl/web/handler.go +++ b/cmd/owl/web/handler.go @@ -100,6 +100,23 @@ func userAuthHandler(repo *owl.Repository) func(http.ResponseWriter, *http.Reque return } + // check if redirect_uri is registered + resp, _ := repo.HttpClient.Get(clientId) + registered_redirects, _ := repo.Parser.GetRedirctUris(resp) + is_registered := false + for _, registered_redirect := range registered_redirects { + if registered_redirect == redirectUri { + // redirect_uri is registered + is_registered = true + break + } + } + if !is_registered { + w.WriteHeader(http.StatusBadRequest) + w.Write([]byte("Invalid redirect_uri. Must be registered with client_id.")) + return + } + // Double Submit Cookie Pattern // https://cheatsheetseries.owasp.org/cheatsheets/Cross-Site_Request_Forgery_Prevention_Cheat_Sheet.html#double-submit-cookie csrfToken := owl.GenerateRandomString(32) diff --git a/test/assertions/asserts.go b/test/assertions/asserts.go index d30b69c..613a6bc 100644 --- a/test/assertions/asserts.go +++ b/test/assertions/asserts.go @@ -27,6 +27,16 @@ func AssertContains(t *testing.T, containing string, search string) { } } +func AssertArrayContains[T comparable](t *testing.T, list []T, search T) { + t.Helper() + for _, item := range list { + if item == search { + return + } + } + t.Errorf("Expected '%v' to be in '%v'", search, list) +} + func AssertNotContains(t *testing.T, containing string, search string) { t.Helper() if strings.Contains(containing, search) { diff --git a/webmention.go b/webmention.go index 67ac2d6..046af1c 100644 --- a/webmention.go +++ b/webmention.go @@ -259,6 +259,30 @@ func (OwlHtmlParser) GetRedirctUris(resp *http.Response) ([]string, error) { } var findLinks func(*html.Node) ([]string, error) + // Check link headers + header_links := make([]string, 0) + for _, linkHeader := range resp.Header["Link"] { + linkHeaderParts := strings.Split(linkHeader, ",") + for _, linkHeaderPart := range linkHeaderParts { + linkHeaderPart = strings.TrimSpace(linkHeaderPart) + params := strings.Split(linkHeaderPart, ";") + if len(params) != 2 { + continue + } + for _, param := range params[1:] { + param = strings.TrimSpace(param) + if strings.Contains(param, "redirect_uri") { + link := strings.Split(params[0], ";")[0] + link = strings.Trim(link, "<>") + linkUrl, err := url.Parse(link) + if err == nil { + header_links = append(header_links, requestUrl.ResolveReference(linkUrl).String()) + } + } + } + } + } + findLinks = func(n *html.Node) ([]string, error) { links := make([]string, 0) if n.Type == html.ElementNode && n.Data == "link" { @@ -287,5 +311,6 @@ func (OwlHtmlParser) GetRedirctUris(resp *http.Response) ([]string, error) { } return links, nil } - return findLinks(doc) + body_links, err := findLinks(doc) + return append(body_links, header_links...), err }