diff --git a/auth_test.go b/auth_test.go
new file mode 100644
index 0000000..5efd35f
--- /dev/null
+++ b/auth_test.go
@@ -0,0 +1,48 @@
+package owl_test
+
+import (
+ "h4kor/owl-blogs"
+ "h4kor/owl-blogs/test/assertions"
+ "net/http"
+ "testing"
+)
+
+func TestGetRedirctUrisLink(t *testing.T) {
+ html := []byte("")
+ parser := &owl.OwlHtmlParser{}
+ uris, err := parser.GetRedirctUris(constructResponse(html))
+
+ assertions.AssertNoError(t, err, "Unable to parse feed")
+
+ assertions.AssertArrayContains(t, uris, "http://example.com/redirect")
+}
+
+func TestGetRedirctUrisLinkMultiple(t *testing.T) {
+ html := []byte(`
+
+
+
+
+
+ `)
+ parser := &owl.OwlHtmlParser{}
+ uris, err := parser.GetRedirctUris(constructResponse(html))
+
+ assertions.AssertNoError(t, err, "Unable to parse feed")
+
+ assertions.AssertArrayContains(t, uris, "http://example.com/redirect1")
+ assertions.AssertArrayContains(t, uris, "http://example.com/redirect2")
+ assertions.AssertArrayContains(t, uris, "http://example.com/redirect3")
+ assertions.AssertLen(t, uris, 3)
+}
+
+func TestGetRedirectUrisLinkHeader(t *testing.T) {
+ html := []byte("")
+ parser := &owl.OwlHtmlParser{}
+ resp := constructResponse(html)
+ resp.Header = http.Header{"Link": []string{"; rel=\"redirect_uri\""}}
+ uris, err := parser.GetRedirctUris(resp)
+
+ assertions.AssertNoError(t, err, "Unable to parse feed")
+ assertions.AssertArrayContains(t, uris, "http://example.com/redirect")
+}
diff --git a/cmd/owl/web/auth_test.go b/cmd/owl/web/auth_test.go
index 65ca9d6..0bf6e76 100644
--- a/cmd/owl/web/auth_test.go
+++ b/cmd/owl/web/auth_test.go
@@ -4,6 +4,7 @@ import (
"encoding/json"
main "h4kor/owl-blogs/cmd/owl/web"
"h4kor/owl-blogs/test/assertions"
+ "h4kor/owl-blogs/test/mocks"
"net/http"
"net/http/httptest"
"net/url"
@@ -121,3 +122,65 @@ func TestAuthPostWithCorrectCode(t *testing.T) {
assertions.AssertEqual(t, response.Me, user.FullUrl())
}
+
+func TestAuthRedirectUriNotSet(t *testing.T) {
+ repo, user := getSingleUserTestRepo()
+ repo.HttpClient = &mocks.MockHttpClient{}
+ repo.Parser = &mocks.MockParseLinksHtmlParser{
+ Links: []string{"http://example.com/response"},
+ }
+ user.ResetPassword("testpassword")
+
+ csrfToken := "test_csrf_token"
+
+ // Create Request and Response
+ form := url.Values{}
+ form.Add("password", "wrongpassword")
+ form.Add("client_id", "http://example.com")
+ form.Add("redirect_uri", "http://example.com/response_not_set")
+ form.Add("response_type", "code")
+ form.Add("state", "test_state")
+ form.Add("csrf_token", csrfToken)
+
+ req, err := http.NewRequest("GET", user.AuthUrl()+"?"+form.Encode(), nil)
+ req.Header.Add("Content-Type", "application/x-www-form-urlencoded")
+ req.Header.Add("Content-Length", strconv.Itoa(len(form.Encode())))
+ req.AddCookie(&http.Cookie{Name: "csrf_token", Value: csrfToken})
+ assertions.AssertNoError(t, err, "Error creating request")
+ rr := httptest.NewRecorder()
+ router := main.SingleUserRouter(&repo)
+ router.ServeHTTP(rr, req)
+
+ assertions.AssertStatus(t, rr, http.StatusBadRequest)
+}
+
+func TestAuthRedirectUriSet(t *testing.T) {
+ repo, user := getSingleUserTestRepo()
+ repo.HttpClient = &mocks.MockHttpClient{}
+ repo.Parser = &mocks.MockParseLinksHtmlParser{
+ Links: []string{"http://example.com/response"},
+ }
+ user.ResetPassword("testpassword")
+
+ csrfToken := "test_csrf_token"
+
+ // Create Request and Response
+ form := url.Values{}
+ form.Add("password", "wrongpassword")
+ form.Add("client_id", "http://example.com")
+ form.Add("redirect_uri", "http://example.com/response")
+ form.Add("response_type", "code")
+ form.Add("state", "test_state")
+ form.Add("csrf_token", csrfToken)
+
+ req, err := http.NewRequest("GET", user.AuthUrl()+"?"+form.Encode(), nil)
+ req.Header.Add("Content-Type", "application/x-www-form-urlencoded")
+ req.Header.Add("Content-Length", strconv.Itoa(len(form.Encode())))
+ req.AddCookie(&http.Cookie{Name: "csrf_token", Value: csrfToken})
+ assertions.AssertNoError(t, err, "Error creating request")
+ rr := httptest.NewRecorder()
+ router := main.SingleUserRouter(&repo)
+ router.ServeHTTP(rr, req)
+
+ assertions.AssertStatus(t, rr, http.StatusOK)
+}
diff --git a/cmd/owl/web/handler.go b/cmd/owl/web/handler.go
index 8a7ead3..62ad48f 100644
--- a/cmd/owl/web/handler.go
+++ b/cmd/owl/web/handler.go
@@ -100,6 +100,23 @@ func userAuthHandler(repo *owl.Repository) func(http.ResponseWriter, *http.Reque
return
}
+ // check if redirect_uri is registered
+ resp, _ := repo.HttpClient.Get(clientId)
+ registered_redirects, _ := repo.Parser.GetRedirctUris(resp)
+ is_registered := false
+ for _, registered_redirect := range registered_redirects {
+ if registered_redirect == redirectUri {
+ // redirect_uri is registered
+ is_registered = true
+ break
+ }
+ }
+ if !is_registered {
+ w.WriteHeader(http.StatusBadRequest)
+ w.Write([]byte("Invalid redirect_uri. Must be registered with client_id."))
+ return
+ }
+
// Double Submit Cookie Pattern
// https://cheatsheetseries.owasp.org/cheatsheets/Cross-Site_Request_Forgery_Prevention_Cheat_Sheet.html#double-submit-cookie
csrfToken := owl.GenerateRandomString(32)
diff --git a/test/assertions/asserts.go b/test/assertions/asserts.go
index d30b69c..613a6bc 100644
--- a/test/assertions/asserts.go
+++ b/test/assertions/asserts.go
@@ -27,6 +27,16 @@ func AssertContains(t *testing.T, containing string, search string) {
}
}
+func AssertArrayContains[T comparable](t *testing.T, list []T, search T) {
+ t.Helper()
+ for _, item := range list {
+ if item == search {
+ return
+ }
+ }
+ t.Errorf("Expected '%v' to be in '%v'", search, list)
+}
+
func AssertNotContains(t *testing.T, containing string, search string) {
t.Helper()
if strings.Contains(containing, search) {
diff --git a/webmention.go b/webmention.go
index 67ac2d6..046af1c 100644
--- a/webmention.go
+++ b/webmention.go
@@ -259,6 +259,30 @@ func (OwlHtmlParser) GetRedirctUris(resp *http.Response) ([]string, error) {
}
var findLinks func(*html.Node) ([]string, error)
+ // Check link headers
+ header_links := make([]string, 0)
+ for _, linkHeader := range resp.Header["Link"] {
+ linkHeaderParts := strings.Split(linkHeader, ",")
+ for _, linkHeaderPart := range linkHeaderParts {
+ linkHeaderPart = strings.TrimSpace(linkHeaderPart)
+ params := strings.Split(linkHeaderPart, ";")
+ if len(params) != 2 {
+ continue
+ }
+ for _, param := range params[1:] {
+ param = strings.TrimSpace(param)
+ if strings.Contains(param, "redirect_uri") {
+ link := strings.Split(params[0], ";")[0]
+ link = strings.Trim(link, "<>")
+ linkUrl, err := url.Parse(link)
+ if err == nil {
+ header_links = append(header_links, requestUrl.ResolveReference(linkUrl).String())
+ }
+ }
+ }
+ }
+ }
+
findLinks = func(n *html.Node) ([]string, error) {
links := make([]string, 0)
if n.Type == html.ElementNode && n.Data == "link" {
@@ -287,5 +311,6 @@ func (OwlHtmlParser) GetRedirctUris(resp *http.Response) ([]string, error) {
}
return links, nil
}
- return findLinks(doc)
+ body_links, err := findLinks(doc)
+ return append(body_links, header_links...), err
}