Improve CORS Method Middleware (#477)
* More sensical CORSMethodMiddleware * Only sets Access-Control-Allow-Methods on valid preflight requests * Does not return after setting the Access-Control-Allow-Methods header * Does not append OPTIONS header to Access-Control-Allow-Methods regardless of whether there is an OPTIONS method matcher * Adds tests for the listed behavior * Add example for CORSMethodMiddleware * Do not check for preflight and add documentation to the README * Use http.MethodOptions instead of "OPTIONS" * Add link to CORSMethodMiddleware section to readme * Add test for unmatching route methods * Rename CORS Method Middleware to Handling CORS Requests in README * Link CORSMethodMiddleware in README to godoc * Break CORSMethodMiddleware doc into bullets for readability * Add comment about specifying OPTIONS to example in README for CORSMethodMiddleware * Document cURL command used for testing CORS Method Middleware * Update comment in example to "Handle the request" * Add explicit comment about OPTIONS matchers to CORSMethodMiddleware doc * Update circleci config to only check gofmt diff on latest go version * Break up gofmt and go vet checks into separate steps. * Use canonical circleci config
This commit is contained in:
committed by
Matt Silverlock
parent
d70f7b4baa
commit
0534769016
@@ -11,8 +11,20 @@ jobs:
|
|||||||
- checkout
|
- checkout
|
||||||
- run: go version
|
- run: go version
|
||||||
- run: go get -t -v ./...
|
- run: go get -t -v ./...
|
||||||
- run: diff -u <(echo -n) <(gofmt -d .)
|
# Only run gofmt, vet & lint against the latest Go version
|
||||||
- run: if [[ "$LATEST" = true ]]; then go vet -v .; fi
|
- run: >
|
||||||
|
if [[ "$LATEST" = true ]]; then
|
||||||
|
go get -u golang.org/x/lint/golint
|
||||||
|
golint ./...
|
||||||
|
fi
|
||||||
|
- run: >
|
||||||
|
if [[ "$LATEST" = true ]]; then
|
||||||
|
diff -u <(echo -n) <(gofmt -d .)
|
||||||
|
fi
|
||||||
|
- run: >
|
||||||
|
if [[ "$LATEST" = true ]]; then
|
||||||
|
go vet -v .
|
||||||
|
fi
|
||||||
- run: go test -v -race ./...
|
- run: go test -v -race ./...
|
||||||
|
|
||||||
"latest":
|
"latest":
|
||||||
|
|||||||
68
README.md
68
README.md
@@ -30,6 +30,7 @@ The name mux stands for "HTTP request multiplexer". Like the standard `http.Serv
|
|||||||
* [Walking Routes](#walking-routes)
|
* [Walking Routes](#walking-routes)
|
||||||
* [Graceful Shutdown](#graceful-shutdown)
|
* [Graceful Shutdown](#graceful-shutdown)
|
||||||
* [Middleware](#middleware)
|
* [Middleware](#middleware)
|
||||||
|
* [Handling CORS Requests](#handling-cors-requests)
|
||||||
* [Testing Handlers](#testing-handlers)
|
* [Testing Handlers](#testing-handlers)
|
||||||
* [Full Example](#full-example)
|
* [Full Example](#full-example)
|
||||||
|
|
||||||
@@ -492,6 +493,73 @@ r.Use(amw.Middleware)
|
|||||||
|
|
||||||
Note: The handler chain will be stopped if your middleware doesn't call `next.ServeHTTP()` with the corresponding parameters. This can be used to abort a request if the middleware writer wants to. Middlewares _should_ write to `ResponseWriter` if they _are_ going to terminate the request, and they _should not_ write to `ResponseWriter` if they _are not_ going to terminate it.
|
Note: The handler chain will be stopped if your middleware doesn't call `next.ServeHTTP()` with the corresponding parameters. This can be used to abort a request if the middleware writer wants to. Middlewares _should_ write to `ResponseWriter` if they _are_ going to terminate the request, and they _should not_ write to `ResponseWriter` if they _are not_ going to terminate it.
|
||||||
|
|
||||||
|
### Handling CORS Requests
|
||||||
|
|
||||||
|
[CORSMethodMiddleware](https://godoc.org/github.com/gorilla/mux#CORSMethodMiddleware) intends to make it easier to strictly set the `Access-Control-Allow-Methods` response header.
|
||||||
|
|
||||||
|
* You will still need to use your own CORS handler to set the other CORS headers such as `Access-Control-Allow-Origin`
|
||||||
|
* The middleware will set the `Access-Control-Allow-Methods` header to all the method matchers (e.g. `r.Methods(http.MethodGet, http.MethodPut, http.MethodOptions)` -> `Access-Control-Allow-Methods: GET,PUT,OPTIONS`) on a route
|
||||||
|
* If you do not specify any methods, then:
|
||||||
|
> _Important_: there must be an `OPTIONS` method matcher for the middleware to set the headers.
|
||||||
|
|
||||||
|
Here is an example of using `CORSMethodMiddleware` along with a custom `OPTIONS` handler to set all the required CORS headers:
|
||||||
|
|
||||||
|
```go
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"github.com/gorilla/mux"
|
||||||
|
)
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
r := mux.NewRouter()
|
||||||
|
|
||||||
|
// IMPORTANT: you must specify an OPTIONS method matcher for the middleware to set CORS headers
|
||||||
|
r.HandleFunc("/foo", fooHandler).Methods(http.MethodGet, http.MethodPut, http.MethodPatch, http.MethodOptions)
|
||||||
|
r.Use(mux.CORSMethodMiddleware(r))
|
||||||
|
|
||||||
|
http.ListenAndServe(":8080", r)
|
||||||
|
}
|
||||||
|
|
||||||
|
func fooHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("Access-Control-Allow-Origin", "*")
|
||||||
|
if r.Method == http.MethodOptions {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
w.Write([]byte("foo"))
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
And an request to `/foo` using something like:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
curl localhost:8080/foo -v
|
||||||
|
```
|
||||||
|
|
||||||
|
Would look like:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
* Trying ::1...
|
||||||
|
* TCP_NODELAY set
|
||||||
|
* Connected to localhost (::1) port 8080 (#0)
|
||||||
|
> GET /foo HTTP/1.1
|
||||||
|
> Host: localhost:8080
|
||||||
|
> User-Agent: curl/7.59.0
|
||||||
|
> Accept: */*
|
||||||
|
>
|
||||||
|
< HTTP/1.1 200 OK
|
||||||
|
< Access-Control-Allow-Methods: GET,PUT,PATCH,OPTIONS
|
||||||
|
< Access-Control-Allow-Origin: *
|
||||||
|
< Date: Fri, 28 Jun 2019 20:13:30 GMT
|
||||||
|
< Content-Length: 3
|
||||||
|
< Content-Type: text/plain; charset=utf-8
|
||||||
|
<
|
||||||
|
* Connection #0 to host localhost left intact
|
||||||
|
foo
|
||||||
|
```
|
||||||
|
|
||||||
### Testing Handlers
|
### Testing Handlers
|
||||||
|
|
||||||
Testing handlers in a Go web application is straightforward, and _mux_ doesn't complicate this any further. Given two files: `endpoints.go` and `endpoints_test.go`, here's how we'd test an application using _mux_.
|
Testing handlers in a Go web application is straightforward, and _mux_ doesn't complicate this any further. Given two files: `endpoints.go` and `endpoints_test.go`, here's how we'd test an application using _mux_.
|
||||||
|
|||||||
37
example_cors_method_middleware_test.go
Normal file
37
example_cors_method_middleware_test.go
Normal file
@@ -0,0 +1,37 @@
|
|||||||
|
package mux_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
|
||||||
|
"github.com/gorilla/mux"
|
||||||
|
)
|
||||||
|
|
||||||
|
func ExampleCORSMethodMiddleware() {
|
||||||
|
r := mux.NewRouter()
|
||||||
|
|
||||||
|
r.HandleFunc("/foo", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
// Handle the request
|
||||||
|
}).Methods(http.MethodGet, http.MethodPut, http.MethodPatch)
|
||||||
|
r.HandleFunc("/foo", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("Access-Control-Allow-Origin", "http://example.com")
|
||||||
|
w.Header().Set("Access-Control-Max-Age", "86400")
|
||||||
|
}).Methods(http.MethodOptions)
|
||||||
|
|
||||||
|
r.Use(mux.CORSMethodMiddleware(r))
|
||||||
|
|
||||||
|
rw := httptest.NewRecorder()
|
||||||
|
req, _ := http.NewRequest("OPTIONS", "/foo", nil) // needs to be OPTIONS
|
||||||
|
req.Header.Set("Access-Control-Request-Method", "POST") // needs to be non-empty
|
||||||
|
req.Header.Set("Access-Control-Request-Headers", "Authorization") // needs to be non-empty
|
||||||
|
req.Header.Set("Origin", "http://example.com") // needs to be non-empty
|
||||||
|
|
||||||
|
r.ServeHTTP(rw, req)
|
||||||
|
|
||||||
|
fmt.Println(rw.Header().Get("Access-Control-Allow-Methods"))
|
||||||
|
fmt.Println(rw.Header().Get("Access-Control-Allow-Origin"))
|
||||||
|
// Output:
|
||||||
|
// GET,PUT,PATCH,OPTIONS
|
||||||
|
// http://example.com
|
||||||
|
}
|
||||||
@@ -32,37 +32,19 @@ func (r *Router) useInterface(mw middleware) {
|
|||||||
r.middlewares = append(r.middlewares, mw)
|
r.middlewares = append(r.middlewares, mw)
|
||||||
}
|
}
|
||||||
|
|
||||||
// CORSMethodMiddleware sets the Access-Control-Allow-Methods response header
|
// CORSMethodMiddleware automatically sets the Access-Control-Allow-Methods response header
|
||||||
// on a request, by matching routes based only on paths. It also handles
|
// on requests for routes that have an OPTIONS method matcher to all the method matchers on
|
||||||
// OPTIONS requests, by settings Access-Control-Allow-Methods, and then
|
// the route. Routes that do not explicitly handle OPTIONS requests will not be processed
|
||||||
// returning without calling the next http handler.
|
// by the middleware. See examples for usage.
|
||||||
func CORSMethodMiddleware(r *Router) MiddlewareFunc {
|
func CORSMethodMiddleware(r *Router) MiddlewareFunc {
|
||||||
return func(next http.Handler) http.Handler {
|
return func(next http.Handler) http.Handler {
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||||
var allMethods []string
|
allMethods, err := getAllMethodsForRoute(r, req)
|
||||||
|
|
||||||
err := r.Walk(func(route *Route, _ *Router, _ []*Route) error {
|
|
||||||
for _, m := range route.matchers {
|
|
||||||
if _, ok := m.(*routeRegexp); ok {
|
|
||||||
if m.Match(req, &RouteMatch{}) {
|
|
||||||
methods, err := route.GetMethods()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
allMethods = append(allMethods, methods...)
|
|
||||||
}
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
|
|
||||||
if err == nil {
|
if err == nil {
|
||||||
w.Header().Set("Access-Control-Allow-Methods", strings.Join(append(allMethods, "OPTIONS"), ","))
|
for _, v := range allMethods {
|
||||||
|
if v == http.MethodOptions {
|
||||||
if req.Method == "OPTIONS" {
|
w.Header().Set("Access-Control-Allow-Methods", strings.Join(allMethods, ","))
|
||||||
return
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -70,3 +52,28 @@ func CORSMethodMiddleware(r *Router) MiddlewareFunc {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// getAllMethodsForRoute returns all the methods from method matchers matching a given
|
||||||
|
// request.
|
||||||
|
func getAllMethodsForRoute(r *Router, req *http.Request) ([]string, error) {
|
||||||
|
var allMethods []string
|
||||||
|
|
||||||
|
err := r.Walk(func(route *Route, _ *Router, _ []*Route) error {
|
||||||
|
for _, m := range route.matchers {
|
||||||
|
if _, ok := m.(*routeRegexp); ok {
|
||||||
|
if m.Match(req, &RouteMatch{}) {
|
||||||
|
methods, err := route.GetMethods()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
allMethods = append(allMethods, methods...)
|
||||||
|
}
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
|
return allMethods, err
|
||||||
|
}
|
||||||
|
|||||||
@@ -2,9 +2,7 @@ package mux
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"fmt"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
|
||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -367,42 +365,114 @@ func TestMiddlewareMethodMismatchSubrouter(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestCORSMethodMiddleware(t *testing.T) {
|
func TestCORSMethodMiddleware(t *testing.T) {
|
||||||
router := NewRouter()
|
testCases := []struct {
|
||||||
|
name string
|
||||||
cases := []struct {
|
registerRoutes func(r *Router)
|
||||||
path string
|
requestHeader http.Header
|
||||||
response string
|
requestMethod string
|
||||||
method string
|
requestPath string
|
||||||
testURL string
|
expectedAccessControlAllowMethodsHeader string
|
||||||
expectedAllowedMethods string
|
expectedResponse string
|
||||||
}{
|
}{
|
||||||
{"/g/{o}", "a", "POST", "/g/asdf", "POST,PUT,GET,OPTIONS"},
|
{
|
||||||
{"/g/{o}", "b", "PUT", "/g/bla", "POST,PUT,GET,OPTIONS"},
|
name: "does not set without OPTIONS matcher",
|
||||||
{"/g/{o}", "c", "GET", "/g/orilla", "POST,PUT,GET,OPTIONS"},
|
registerRoutes: func(r *Router) {
|
||||||
{"/g", "d", "POST", "/g", "POST,OPTIONS"},
|
r.HandleFunc("/foo", stringHandler("a")).Methods(http.MethodGet, http.MethodPut, http.MethodPatch)
|
||||||
|
},
|
||||||
|
requestMethod: "GET",
|
||||||
|
requestPath: "/foo",
|
||||||
|
expectedAccessControlAllowMethodsHeader: "",
|
||||||
|
expectedResponse: "a",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "sets on non OPTIONS",
|
||||||
|
registerRoutes: func(r *Router) {
|
||||||
|
r.HandleFunc("/foo", stringHandler("a")).Methods(http.MethodGet, http.MethodPut, http.MethodPatch)
|
||||||
|
r.HandleFunc("/foo", stringHandler("b")).Methods(http.MethodOptions)
|
||||||
|
},
|
||||||
|
requestMethod: "GET",
|
||||||
|
requestPath: "/foo",
|
||||||
|
expectedAccessControlAllowMethodsHeader: "GET,PUT,PATCH,OPTIONS",
|
||||||
|
expectedResponse: "a",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "sets without preflight headers",
|
||||||
|
registerRoutes: func(r *Router) {
|
||||||
|
r.HandleFunc("/foo", stringHandler("a")).Methods(http.MethodGet, http.MethodPut, http.MethodPatch)
|
||||||
|
r.HandleFunc("/foo", stringHandler("b")).Methods(http.MethodOptions)
|
||||||
|
},
|
||||||
|
requestMethod: "OPTIONS",
|
||||||
|
requestPath: "/foo",
|
||||||
|
expectedAccessControlAllowMethodsHeader: "GET,PUT,PATCH,OPTIONS",
|
||||||
|
expectedResponse: "b",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "does not set on error",
|
||||||
|
registerRoutes: func(r *Router) {
|
||||||
|
r.HandleFunc("/foo", stringHandler("a"))
|
||||||
|
},
|
||||||
|
requestMethod: "OPTIONS",
|
||||||
|
requestPath: "/foo",
|
||||||
|
expectedAccessControlAllowMethodsHeader: "",
|
||||||
|
expectedResponse: "a",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "sets header on valid preflight",
|
||||||
|
registerRoutes: func(r *Router) {
|
||||||
|
r.HandleFunc("/foo", stringHandler("a")).Methods(http.MethodGet, http.MethodPut, http.MethodPatch)
|
||||||
|
r.HandleFunc("/foo", stringHandler("b")).Methods(http.MethodOptions)
|
||||||
|
},
|
||||||
|
requestMethod: "OPTIONS",
|
||||||
|
requestPath: "/foo",
|
||||||
|
requestHeader: http.Header{
|
||||||
|
"Access-Control-Request-Method": []string{"GET"},
|
||||||
|
"Access-Control-Request-Headers": []string{"Authorization"},
|
||||||
|
"Origin": []string{"http://example.com"},
|
||||||
|
},
|
||||||
|
expectedAccessControlAllowMethodsHeader: "GET,PUT,PATCH,OPTIONS",
|
||||||
|
expectedResponse: "b",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "does not set methods from unmatching routes",
|
||||||
|
registerRoutes: func(r *Router) {
|
||||||
|
r.HandleFunc("/foo", stringHandler("c")).Methods(http.MethodDelete)
|
||||||
|
r.HandleFunc("/foo/bar", stringHandler("a")).Methods(http.MethodGet, http.MethodPut, http.MethodPatch)
|
||||||
|
r.HandleFunc("/foo/bar", stringHandler("b")).Methods(http.MethodOptions)
|
||||||
|
},
|
||||||
|
requestMethod: "OPTIONS",
|
||||||
|
requestPath: "/foo/bar",
|
||||||
|
requestHeader: http.Header{
|
||||||
|
"Access-Control-Request-Method": []string{"GET"},
|
||||||
|
"Access-Control-Request-Headers": []string{"Authorization"},
|
||||||
|
"Origin": []string{"http://example.com"},
|
||||||
|
},
|
||||||
|
expectedAccessControlAllowMethodsHeader: "GET,PUT,PATCH,OPTIONS",
|
||||||
|
expectedResponse: "b",
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tt := range cases {
|
for _, tt := range testCases {
|
||||||
router.HandleFunc(tt.path, stringHandler(tt.response)).Methods(tt.method)
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
}
|
router := NewRouter()
|
||||||
|
|
||||||
router.Use(CORSMethodMiddleware(router))
|
tt.registerRoutes(router)
|
||||||
|
|
||||||
for i, tt := range cases {
|
router.Use(CORSMethodMiddleware(router))
|
||||||
t.Run(fmt.Sprintf("cases[%d]", i), func(t *testing.T) {
|
|
||||||
rr := httptest.NewRecorder()
|
|
||||||
req := newRequest(tt.method, tt.testURL)
|
|
||||||
|
|
||||||
router.ServeHTTP(rr, req)
|
rw := NewRecorder()
|
||||||
|
req := newRequest(tt.requestMethod, tt.requestPath)
|
||||||
|
req.Header = tt.requestHeader
|
||||||
|
|
||||||
if rr.Body.String() != tt.response {
|
router.ServeHTTP(rw, req)
|
||||||
t.Errorf("Expected body '%s', found '%s'", tt.response, rr.Body.String())
|
|
||||||
|
actualMethodsHeader := rw.Header().Get("Access-Control-Allow-Methods")
|
||||||
|
if actualMethodsHeader != tt.expectedAccessControlAllowMethodsHeader {
|
||||||
|
t.Fatalf("Expected Access-Control-Allow-Methods to equal %s but got %s", tt.expectedAccessControlAllowMethodsHeader, actualMethodsHeader)
|
||||||
}
|
}
|
||||||
|
|
||||||
allowedMethods := rr.Header().Get("Access-Control-Allow-Methods")
|
actualResponse := rw.Body.String()
|
||||||
|
if actualResponse != tt.expectedResponse {
|
||||||
if allowedMethods != tt.expectedAllowedMethods {
|
t.Fatalf("Expected response to equal %s but got %s", tt.expectedResponse, actualResponse)
|
||||||
t.Errorf("Expected Access-Control-Allow-Methods '%s', found '%s'", tt.expectedAllowedMethods, allowedMethods)
|
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user