Skip to content

Commit

Permalink
Added src rewriter
Browse files Browse the repository at this point in the history
  • Loading branch information
yyewolf committed Jul 12, 2023
1 parent 965798e commit fb590ac
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 2 deletions.
14 changes: 14 additions & 0 deletions policy.go
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,11 @@ type Policy struct {
// if one would want to allow all URL schemes, they would add `.+`
allowURLSchemeRegexps []*regexp.Regexp

// If srcRewriter is not nil, it is used to rewrite the src attribute
// of tags that download resources, such as <img> and <script>.
// It requires that the URL is parsable by "net/url" url.Parse().
srcRewriter urlRewriter

// If an element has had all attributes removed as a result of a policy
// being applied, then the element would be removed from the output.
//
Expand Down Expand Up @@ -196,6 +201,8 @@ type stylePolicyBuilder struct {

type urlPolicy func(url *url.URL) (allowUrl bool)

type urlRewriter func(*url.URL)

type SandboxValue int64

const (
Expand Down Expand Up @@ -575,6 +582,13 @@ func (p *Policy) AllowURLSchemesMatching(r *regexp.Regexp) *Policy {
return p
}

// RewriteSrc will rewrite the src attribute of a resource downloading tag
// (e.g. <img>, <script>, <iframe>) using the provided function.
func (p *Policy) RewriteSrc(fn urlRewriter) *Policy {
p.srcRewriter = fn
return p
}

// RequireNoFollowOnLinks will result in all a, area, link tags having a
// rel="nofollow"added to them if one does not already exist
//
Expand Down
8 changes: 8 additions & 0 deletions sanitize.go
Original file line number Diff line number Diff line change
Expand Up @@ -612,6 +612,14 @@ attrsLoop:
case "audio", "embed", "iframe", "img", "script", "source", "track", "video":
if htmlAttr.Key == "src" {
if u, ok := p.validURL(htmlAttr.Val); ok {
if p.srcRewriter != nil {
parsedURL, err := url.Parse(u)
if err != nil {
fmt.Println(err)
}
p.srcRewriter(parsedURL)
u = parsedURL.String()
}
htmlAttr.Val = u
tmpAttrs = append(tmpAttrs, htmlAttr)
}
Expand Down
15 changes: 13 additions & 2 deletions sanitize_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -150,16 +150,27 @@ func TestLinks(t *testing.T) {
},
{
in: `<img src="giraffe.gif" />`,
expected: `<img src="giraffe.gif"/>`,
expected: `<img src="https://proxy.example.com/?u=giraffe.gif"/>`,
},
{
in: `<img src="giraffe.gif?height=500&amp;width=500&amp;flag" />`,
expected: `<img src="giraffe.gif?height=500&amp;width=500&amp;flag"/>`,
expected: `<img src="https://proxy.example.com/?u=giraffe.gif?height=500&amp;width=500&amp;flag"/>`,
},
}

p := UGCPolicy()
p.RequireParseableURLs(true)
p.RewriteSrc(func(u *url.URL) {
// Proxify all requests to "https://proxy.example.com/?u=http://example.com/"
// This is a contrived example, but it shows how to rewrite URLs
// to proxy all requests through a single URL.

url := u.String()
u.Scheme = "https"
u.Host = "proxy.example.com"
u.Path = "/"
u.RawQuery = "u=" + url
})

// These tests are run concurrently to enable the race detector to pick up
// potential issues
Expand Down

0 comments on commit fb590ac

Please sign in to comment.