Add CORSMethodMiddleware (#366)
CORSMethodMiddleware sets the Access-Control-Allow-Methods response header on a request, by matching routes based only on paths. It also handles OPTIONS requests, by settings Access-Control-Allow-Methods, and then returning without calling the next HTTP handler.
This commit is contained in:
committed by
Matt Silverlock
parent
ded0c29b24
commit
5e55a4adb8
@@ -1,6 +1,9 @@
|
||||
package mux
|
||||
|
||||
import "net/http"
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// MiddlewareFunc is a function which receives an http.Handler and returns another http.Handler.
|
||||
// Typically, the returned handler is a closure which does something with the http.ResponseWriter and http.Request passed
|
||||
@@ -28,3 +31,42 @@ func (r *Router) Use(mwf ...MiddlewareFunc) {
|
||||
func (r *Router) useInterface(mw middleware) {
|
||||
r.middlewares = append(r.middlewares, mw)
|
||||
}
|
||||
|
||||
// CORSMethodMiddleware sets the Access-Control-Allow-Methods response header
|
||||
// on a request, by matching routes based only on paths. It also handles
|
||||
// OPTIONS requests, by settings Access-Control-Allow-Methods, and then
|
||||
// returning without calling the next http handler.
|
||||
func CORSMethodMiddleware(r *Router) MiddlewareFunc {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||
var allMethods []string
|
||||
|
||||
err := r.Walk(func(route *Route, _ *Router, _ []*Route) error {
|
||||
for _, m := range route.matchers {
|
||||
if _, ok := m.(*routeRegexp); ok {
|
||||
if m.Match(req, &RouteMatch{}) {
|
||||
methods, err := route.GetMethods()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
allMethods = append(allMethods, methods...)
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
if err == nil {
|
||||
w.Header().Set("Access-Control-Allow-Methods", strings.Join(append(allMethods, "OPTIONS"), ","))
|
||||
|
||||
if req.Method == "OPTIONS" {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
next.ServeHTTP(w, req)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@ package mux
|
||||
import (
|
||||
"bytes"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
)
|
||||
|
||||
@@ -334,3 +335,43 @@ func TestMiddlewareMethodMismatchSubrouter(t *testing.T) {
|
||||
t.Fatal("Middleware was called for a method mismatch")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCORSMethodMiddleware(t *testing.T) {
|
||||
router := NewRouter()
|
||||
|
||||
cases := []struct {
|
||||
path string
|
||||
response string
|
||||
method string
|
||||
testURL string
|
||||
expectedAllowedMethods string
|
||||
}{
|
||||
{"/g/{o}", "a", "POST", "/g/asdf", "POST,PUT,GET,OPTIONS"},
|
||||
{"/g/{o}", "b", "PUT", "/g/bla", "POST,PUT,GET,OPTIONS"},
|
||||
{"/g/{o}", "c", "GET", "/g/orilla", "POST,PUT,GET,OPTIONS"},
|
||||
{"/g", "d", "POST", "/g", "POST,OPTIONS"},
|
||||
}
|
||||
|
||||
for _, tt := range cases {
|
||||
router.HandleFunc(tt.path, stringHandler(tt.response)).Methods(tt.method)
|
||||
}
|
||||
|
||||
router.Use(CORSMethodMiddleware(router))
|
||||
|
||||
for _, tt := range cases {
|
||||
rr := httptest.NewRecorder()
|
||||
req := newRequest(tt.method, tt.testURL)
|
||||
|
||||
router.ServeHTTP(rr, req)
|
||||
|
||||
if rr.Body.String() != tt.response {
|
||||
t.Errorf("Expected body '%s', found '%s'", tt.response, rr.Body.String())
|
||||
}
|
||||
|
||||
allowedMethods := rr.HeaderMap.Get("Access-Control-Allow-Methods")
|
||||
|
||||
if allowedMethods != tt.expectedAllowedMethods {
|
||||
t.Errorf("Expected Access-Control-Allow-Methods '%s', found '%s'", tt.expectedAllowedMethods, allowedMethods)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2315,6 +2315,14 @@ func stringMapEqual(m1, m2 map[string]string) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// stringHandler returns a handler func that writes a message 's' to the
|
||||
// http.ResponseWriter.
|
||||
func stringHandler(s string) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Write([]byte(s))
|
||||
}
|
||||
}
|
||||
|
||||
// newRequest is a helper function to create a new request with a method and url.
|
||||
// The request returned is a 'server' request as opposed to a 'client' one through
|
||||
// simulated write onto the wire and read off of the wire.
|
||||
|
||||
Reference in New Issue
Block a user