Compare commits
34 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
00bdffe0f3 | ||
|
|
0534769016 | ||
|
|
d70f7b4baa | ||
|
|
48f941fa99 | ||
|
|
64954673e9 | ||
|
|
4248f5cd87 | ||
|
|
212aa90d7c | ||
|
|
ed099d4238 | ||
|
|
c5c6c98bc2 | ||
|
|
15a353a636 | ||
|
|
8eaa9f1309 | ||
|
|
8559a4f775 | ||
|
|
a7962380ca | ||
|
|
797e653da6 | ||
|
|
08e7f807d3 | ||
|
|
f3ff42f93a | ||
|
|
ef912dd76e | ||
|
|
a31c1782bf | ||
|
|
6137e193cd | ||
|
|
d2b5d13b92 | ||
|
|
419fd9fe2a | ||
|
|
758eb64354 | ||
|
|
3d80bc801b | ||
|
|
521ea7b17d | ||
|
|
deb579d6e0 | ||
|
|
9e1f5955c0 | ||
|
|
cf6680bc62 | ||
|
|
8771f97498 | ||
|
|
962c5bed07 | ||
|
|
e48e440e4c | ||
|
|
815b8c6a26 | ||
|
|
cb4698366a | ||
|
|
e0b5abaaae | ||
|
|
c85619274f |
75
.circleci/config.yml
Normal file
75
.circleci/config.yml
Normal file
@@ -0,0 +1,75 @@
|
||||
version: 2.0
|
||||
|
||||
jobs:
|
||||
# Base test configuration for Go library tests Each distinct version should
|
||||
# inherit this base, and override (at least) the container image used.
|
||||
"test": &test
|
||||
docker:
|
||||
- image: circleci/golang:latest
|
||||
working_directory: /go/src/github.com/gorilla/mux
|
||||
steps: &steps
|
||||
- checkout
|
||||
- run: go version
|
||||
- run: go get -t -v ./...
|
||||
# Only run gofmt, vet & lint against the latest Go version
|
||||
- run: >
|
||||
if [[ "$LATEST" = true ]]; then
|
||||
go get -u golang.org/x/lint/golint
|
||||
golint ./...
|
||||
fi
|
||||
- run: >
|
||||
if [[ "$LATEST" = true ]]; then
|
||||
diff -u <(echo -n) <(gofmt -d .)
|
||||
fi
|
||||
- run: >
|
||||
if [[ "$LATEST" = true ]]; then
|
||||
go vet -v .
|
||||
fi
|
||||
- run: go test -v -race ./...
|
||||
|
||||
"latest":
|
||||
<<: *test
|
||||
environment:
|
||||
LATEST: true
|
||||
|
||||
"1.12":
|
||||
<<: *test
|
||||
docker:
|
||||
- image: circleci/golang:1.12
|
||||
|
||||
"1.11":
|
||||
<<: *test
|
||||
docker:
|
||||
- image: circleci/golang:1.11
|
||||
|
||||
"1.10":
|
||||
<<: *test
|
||||
docker:
|
||||
- image: circleci/golang:1.10
|
||||
|
||||
"1.9":
|
||||
<<: *test
|
||||
docker:
|
||||
- image: circleci/golang:1.9
|
||||
|
||||
"1.8":
|
||||
<<: *test
|
||||
docker:
|
||||
- image: circleci/golang:1.8
|
||||
|
||||
"1.7":
|
||||
<<: *test
|
||||
docker:
|
||||
- image: circleci/golang:1.7
|
||||
|
||||
workflows:
|
||||
version: 2
|
||||
build:
|
||||
jobs:
|
||||
- "latest"
|
||||
- "1.12"
|
||||
- "1.11"
|
||||
- "1.10"
|
||||
- "1.9"
|
||||
- "1.8"
|
||||
- "1.7"
|
||||
8
.github/release-drafter.yml
vendored
Normal file
8
.github/release-drafter.yml
vendored
Normal file
@@ -0,0 +1,8 @@
|
||||
# Config for https://github.com/apps/release-drafter
|
||||
template: |
|
||||
|
||||
<summary of changes here>
|
||||
|
||||
## CHANGELOG
|
||||
|
||||
$CHANGES
|
||||
12
.github/stale.yml
vendored
Normal file
12
.github/stale.yml
vendored
Normal file
@@ -0,0 +1,12 @@
|
||||
daysUntilStale: 75
|
||||
daysUntilClose: 14
|
||||
# Issues with these labels will never be considered stale
|
||||
exemptLabels:
|
||||
- proposal
|
||||
- needs review
|
||||
- build system
|
||||
staleLabel: stale
|
||||
markComment: >
|
||||
This issue has been automatically marked as stale because it hasn't seen
|
||||
a recent update. It'll be automatically closed in a few days.
|
||||
closeComment: false
|
||||
23
.travis.yml
23
.travis.yml
@@ -1,23 +0,0 @@
|
||||
language: go
|
||||
sudo: false
|
||||
|
||||
matrix:
|
||||
include:
|
||||
- go: 1.5.x
|
||||
- go: 1.6.x
|
||||
- go: 1.7.x
|
||||
- go: 1.8.x
|
||||
- go: 1.9.x
|
||||
- go: 1.10.x
|
||||
- go: tip
|
||||
allow_failures:
|
||||
- go: tip
|
||||
|
||||
install:
|
||||
- # Skip
|
||||
|
||||
script:
|
||||
- go get -t -v ./...
|
||||
- diff -u <(echo -n) <(gofmt -d .)
|
||||
- go tool vet .
|
||||
- go test -v -race ./...
|
||||
8
AUTHORS
Normal file
8
AUTHORS
Normal file
@@ -0,0 +1,8 @@
|
||||
# This is the official list of gorilla/mux authors for copyright purposes.
|
||||
#
|
||||
# Please keep the list sorted.
|
||||
|
||||
Google LLC (https://opensource.google.com/)
|
||||
Kamil Kisielk <kamil@kamilkisiel.net>
|
||||
Matt Silverlock <matt@eatsleeprepeat.net>
|
||||
Rodrigo Moraes (https://github.com/moraes)
|
||||
@@ -1,11 +0,0 @@
|
||||
**What version of Go are you running?** (Paste the output of `go version`)
|
||||
|
||||
|
||||
**What version of gorilla/mux are you at?** (Paste the output of `git rev-parse HEAD` inside `$GOPATH/src/github.com/gorilla/mux`)
|
||||
|
||||
|
||||
**Describe your problem** (and what you have tried so far)
|
||||
|
||||
|
||||
**Paste a minimal, runnable, reproduction of your issue below** (use backticks to format it)
|
||||
|
||||
2
LICENSE
2
LICENSE
@@ -1,4 +1,4 @@
|
||||
Copyright (c) 2012 Rodrigo Moraes. All rights reserved.
|
||||
Copyright (c) 2012-2018 The Gorilla Authors. All rights reserved.
|
||||
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are
|
||||
|
||||
85
README.md
85
README.md
@@ -2,11 +2,12 @@
|
||||
|
||||
[](https://godoc.org/github.com/gorilla/mux)
|
||||
[](https://travis-ci.org/gorilla/mux)
|
||||
[](https://circleci.com/gh/gorilla/mux)
|
||||
[](https://sourcegraph.com/github.com/gorilla/mux?badge)
|
||||
|
||||

|
||||
|
||||
http://www.gorillatoolkit.org/pkg/mux
|
||||
https://www.gorillatoolkit.org/pkg/mux
|
||||
|
||||
Package `gorilla/mux` implements a request router and dispatcher for matching incoming requests to
|
||||
their respective handler.
|
||||
@@ -29,6 +30,7 @@ The name mux stands for "HTTP request multiplexer". Like the standard `http.Serv
|
||||
* [Walking Routes](#walking-routes)
|
||||
* [Graceful Shutdown](#graceful-shutdown)
|
||||
* [Middleware](#middleware)
|
||||
* [Handling CORS Requests](#handling-cors-requests)
|
||||
* [Testing Handlers](#testing-handlers)
|
||||
* [Full Example](#full-example)
|
||||
|
||||
@@ -88,7 +90,7 @@ r := mux.NewRouter()
|
||||
// Only matches if domain is "www.example.com".
|
||||
r.Host("www.example.com")
|
||||
// Matches a dynamic subdomain.
|
||||
r.Host("{subdomain:[a-z]+}.domain.com")
|
||||
r.Host("{subdomain:[a-z]+}.example.com")
|
||||
```
|
||||
|
||||
There are several other matchers that can be added. To match path prefixes:
|
||||
@@ -238,13 +240,13 @@ This also works for host and query value variables:
|
||||
|
||||
```go
|
||||
r := mux.NewRouter()
|
||||
r.Host("{subdomain}.domain.com").
|
||||
r.Host("{subdomain}.example.com").
|
||||
Path("/articles/{category}/{id:[0-9]+}").
|
||||
Queries("filter", "{filter}").
|
||||
HandlerFunc(ArticleHandler).
|
||||
Name("article")
|
||||
|
||||
// url.String() will be "http://news.domain.com/articles/technology/42?filter=gorilla"
|
||||
// url.String() will be "http://news.example.com/articles/technology/42?filter=gorilla"
|
||||
url, err := r.Get("article").URL("subdomain", "news",
|
||||
"category", "technology",
|
||||
"id", "42",
|
||||
@@ -264,7 +266,7 @@ r.HeadersRegexp("Content-Type", "application/(text|json)")
|
||||
There's also a way to build only the URL host or path for a route: use the methods `URLHost()` or `URLPath()` instead. For the previous route, we would do:
|
||||
|
||||
```go
|
||||
// "http://news.domain.com/"
|
||||
// "http://news.example.com/"
|
||||
host, err := r.Get("article").URLHost("subdomain", "news")
|
||||
|
||||
// "/articles/technology/42"
|
||||
@@ -275,12 +277,12 @@ And if you use subrouters, host and path defined separately can be built as well
|
||||
|
||||
```go
|
||||
r := mux.NewRouter()
|
||||
s := r.Host("{subdomain}.domain.com").Subrouter()
|
||||
s := r.Host("{subdomain}.example.com").Subrouter()
|
||||
s.Path("/articles/{category}/{id:[0-9]+}").
|
||||
HandlerFunc(ArticleHandler).
|
||||
Name("article")
|
||||
|
||||
// "http://news.domain.com/articles/technology/42"
|
||||
// "http://news.example.com/articles/technology/42"
|
||||
url, err := r.Get("article").URL("subdomain", "news",
|
||||
"category", "technology",
|
||||
"id", "42")
|
||||
@@ -491,6 +493,73 @@ r.Use(amw.Middleware)
|
||||
|
||||
Note: The handler chain will be stopped if your middleware doesn't call `next.ServeHTTP()` with the corresponding parameters. This can be used to abort a request if the middleware writer wants to. Middlewares _should_ write to `ResponseWriter` if they _are_ going to terminate the request, and they _should not_ write to `ResponseWriter` if they _are not_ going to terminate it.
|
||||
|
||||
### Handling CORS Requests
|
||||
|
||||
[CORSMethodMiddleware](https://godoc.org/github.com/gorilla/mux#CORSMethodMiddleware) intends to make it easier to strictly set the `Access-Control-Allow-Methods` response header.
|
||||
|
||||
* You will still need to use your own CORS handler to set the other CORS headers such as `Access-Control-Allow-Origin`
|
||||
* The middleware will set the `Access-Control-Allow-Methods` header to all the method matchers (e.g. `r.Methods(http.MethodGet, http.MethodPut, http.MethodOptions)` -> `Access-Control-Allow-Methods: GET,PUT,OPTIONS`) on a route
|
||||
* If you do not specify any methods, then:
|
||||
> _Important_: there must be an `OPTIONS` method matcher for the middleware to set the headers.
|
||||
|
||||
Here is an example of using `CORSMethodMiddleware` along with a custom `OPTIONS` handler to set all the required CORS headers:
|
||||
|
||||
```go
|
||||
package main
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"github.com/gorilla/mux"
|
||||
)
|
||||
|
||||
func main() {
|
||||
r := mux.NewRouter()
|
||||
|
||||
// IMPORTANT: you must specify an OPTIONS method matcher for the middleware to set CORS headers
|
||||
r.HandleFunc("/foo", fooHandler).Methods(http.MethodGet, http.MethodPut, http.MethodPatch, http.MethodOptions)
|
||||
r.Use(mux.CORSMethodMiddleware(r))
|
||||
|
||||
http.ListenAndServe(":8080", r)
|
||||
}
|
||||
|
||||
func fooHandler(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Access-Control-Allow-Origin", "*")
|
||||
if r.Method == http.MethodOptions {
|
||||
return
|
||||
}
|
||||
|
||||
w.Write([]byte("foo"))
|
||||
}
|
||||
```
|
||||
|
||||
And an request to `/foo` using something like:
|
||||
|
||||
```bash
|
||||
curl localhost:8080/foo -v
|
||||
```
|
||||
|
||||
Would look like:
|
||||
|
||||
```bash
|
||||
* Trying ::1...
|
||||
* TCP_NODELAY set
|
||||
* Connected to localhost (::1) port 8080 (#0)
|
||||
> GET /foo HTTP/1.1
|
||||
> Host: localhost:8080
|
||||
> User-Agent: curl/7.59.0
|
||||
> Accept: */*
|
||||
>
|
||||
< HTTP/1.1 200 OK
|
||||
< Access-Control-Allow-Methods: GET,PUT,PATCH,OPTIONS
|
||||
< Access-Control-Allow-Origin: *
|
||||
< Date: Fri, 28 Jun 2019 20:13:30 GMT
|
||||
< Content-Length: 3
|
||||
< Content-Type: text/plain; charset=utf-8
|
||||
<
|
||||
* Connection #0 to host localhost left intact
|
||||
foo
|
||||
```
|
||||
|
||||
### Testing Handlers
|
||||
|
||||
Testing handlers in a Go web application is straightforward, and _mux_ doesn't complicate this any further. Given two files: `endpoints.go` and `endpoints_test.go`, here's how we'd test an application using _mux_.
|
||||
@@ -503,8 +572,8 @@ package main
|
||||
|
||||
func HealthCheckHandler(w http.ResponseWriter, r *http.Request) {
|
||||
// A very simple health check.
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
|
||||
// In the future we could report back on the status of our DB, or our cache
|
||||
// (e.g. Redis) by performing a simple PING, and include them in the response.
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
// +build go1.7
|
||||
|
||||
package mux
|
||||
|
||||
import (
|
||||
@@ -18,7 +16,3 @@ func contextSet(r *http.Request, key, val interface{}) *http.Request {
|
||||
|
||||
return r.WithContext(context.WithValue(r.Context(), key, val))
|
||||
}
|
||||
|
||||
func contextClear(r *http.Request) {
|
||||
return
|
||||
}
|
||||
@@ -1,26 +0,0 @@
|
||||
// +build !go1.7
|
||||
|
||||
package mux
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/gorilla/context"
|
||||
)
|
||||
|
||||
func contextGet(r *http.Request, key interface{}) interface{} {
|
||||
return context.Get(r, key)
|
||||
}
|
||||
|
||||
func contextSet(r *http.Request, key, val interface{}) *http.Request {
|
||||
if val == nil {
|
||||
return r
|
||||
}
|
||||
|
||||
context.Set(r, key, val)
|
||||
return r
|
||||
}
|
||||
|
||||
func contextClear(r *http.Request) {
|
||||
context.Clear(r)
|
||||
}
|
||||
@@ -1,40 +0,0 @@
|
||||
// +build !go1.7
|
||||
|
||||
package mux
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/gorilla/context"
|
||||
)
|
||||
|
||||
// Tests that the context is cleared or not cleared properly depending on
|
||||
// the configuration of the router
|
||||
func TestKeepContext(t *testing.T) {
|
||||
func1 := func(w http.ResponseWriter, r *http.Request) {}
|
||||
|
||||
r := NewRouter()
|
||||
r.HandleFunc("/", func1).Name("func1")
|
||||
|
||||
req, _ := http.NewRequest("GET", "http://localhost/", nil)
|
||||
context.Set(req, "t", 1)
|
||||
|
||||
res := new(http.ResponseWriter)
|
||||
r.ServeHTTP(*res, req)
|
||||
|
||||
if _, ok := context.GetOk(req, "t"); ok {
|
||||
t.Error("Context should have been cleared at end of request")
|
||||
}
|
||||
|
||||
r.KeepContext = true
|
||||
|
||||
req, _ = http.NewRequest("GET", "http://localhost/", nil)
|
||||
context.Set(req, "t", 1)
|
||||
|
||||
r.ServeHTTP(*res, req)
|
||||
if _, ok := context.GetOk(req, "t"); !ok {
|
||||
t.Error("Context should NOT have been cleared at end of request")
|
||||
}
|
||||
|
||||
}
|
||||
@@ -1,5 +1,3 @@
|
||||
// +build go1.7
|
||||
|
||||
package mux
|
||||
|
||||
import (
|
||||
2
doc.go
2
doc.go
@@ -295,7 +295,7 @@ A more complex authentication middleware, which maps session token to users, cou
|
||||
r := mux.NewRouter()
|
||||
r.HandleFunc("/", handler)
|
||||
|
||||
amw := authenticationMiddleware{}
|
||||
amw := authenticationMiddleware{tokenUsers: make(map[string]string)}
|
||||
amw.Populate()
|
||||
|
||||
r.Use(amw.Middleware)
|
||||
|
||||
@@ -40,7 +40,7 @@ func Example_authenticationMiddleware() {
|
||||
r.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
|
||||
// Do something here
|
||||
})
|
||||
amw := authenticationMiddleware{}
|
||||
amw := authenticationMiddleware{make(map[string]string)}
|
||||
amw.Populate()
|
||||
r.Use(amw.Middleware)
|
||||
}
|
||||
|
||||
37
example_cors_method_middleware_test.go
Normal file
37
example_cors_method_middleware_test.go
Normal file
@@ -0,0 +1,37 @@
|
||||
package mux_test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
)
|
||||
|
||||
func ExampleCORSMethodMiddleware() {
|
||||
r := mux.NewRouter()
|
||||
|
||||
r.HandleFunc("/foo", func(w http.ResponseWriter, r *http.Request) {
|
||||
// Handle the request
|
||||
}).Methods(http.MethodGet, http.MethodPut, http.MethodPatch)
|
||||
r.HandleFunc("/foo", func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Access-Control-Allow-Origin", "http://example.com")
|
||||
w.Header().Set("Access-Control-Max-Age", "86400")
|
||||
}).Methods(http.MethodOptions)
|
||||
|
||||
r.Use(mux.CORSMethodMiddleware(r))
|
||||
|
||||
rw := httptest.NewRecorder()
|
||||
req, _ := http.NewRequest("OPTIONS", "/foo", nil) // needs to be OPTIONS
|
||||
req.Header.Set("Access-Control-Request-Method", "POST") // needs to be non-empty
|
||||
req.Header.Set("Access-Control-Request-Headers", "Authorization") // needs to be non-empty
|
||||
req.Header.Set("Origin", "http://example.com") // needs to be non-empty
|
||||
|
||||
r.ServeHTTP(rw, req)
|
||||
|
||||
fmt.Println(rw.Header().Get("Access-Control-Allow-Methods"))
|
||||
fmt.Println(rw.Header().Get("Access-Control-Allow-Origin"))
|
||||
// Output:
|
||||
// GET,PUT,PATCH,OPTIONS
|
||||
// http://example.com
|
||||
}
|
||||
@@ -32,37 +32,19 @@ func (r *Router) useInterface(mw middleware) {
|
||||
r.middlewares = append(r.middlewares, mw)
|
||||
}
|
||||
|
||||
// CORSMethodMiddleware sets the Access-Control-Allow-Methods response header
|
||||
// on a request, by matching routes based only on paths. It also handles
|
||||
// OPTIONS requests, by settings Access-Control-Allow-Methods, and then
|
||||
// returning without calling the next http handler.
|
||||
// CORSMethodMiddleware automatically sets the Access-Control-Allow-Methods response header
|
||||
// on requests for routes that have an OPTIONS method matcher to all the method matchers on
|
||||
// the route. Routes that do not explicitly handle OPTIONS requests will not be processed
|
||||
// by the middleware. See examples for usage.
|
||||
func CORSMethodMiddleware(r *Router) MiddlewareFunc {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||
var allMethods []string
|
||||
|
||||
err := r.Walk(func(route *Route, _ *Router, _ []*Route) error {
|
||||
for _, m := range route.matchers {
|
||||
if _, ok := m.(*routeRegexp); ok {
|
||||
if m.Match(req, &RouteMatch{}) {
|
||||
methods, err := route.GetMethods()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
allMethods = append(allMethods, methods...)
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
allMethods, err := getAllMethodsForRoute(r, req)
|
||||
if err == nil {
|
||||
w.Header().Set("Access-Control-Allow-Methods", strings.Join(append(allMethods, "OPTIONS"), ","))
|
||||
|
||||
if req.Method == "OPTIONS" {
|
||||
return
|
||||
for _, v := range allMethods {
|
||||
if v == http.MethodOptions {
|
||||
w.Header().Set("Access-Control-Allow-Methods", strings.Join(allMethods, ","))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -70,3 +52,28 @@ func CORSMethodMiddleware(r *Router) MiddlewareFunc {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// getAllMethodsForRoute returns all the methods from method matchers matching a given
|
||||
// request.
|
||||
func getAllMethodsForRoute(r *Router, req *http.Request) ([]string, error) {
|
||||
var allMethods []string
|
||||
|
||||
err := r.Walk(func(route *Route, _ *Router, _ []*Route) error {
|
||||
for _, m := range route.matchers {
|
||||
if _, ok := m.(*routeRegexp); ok {
|
||||
if m.Match(req, &RouteMatch{}) {
|
||||
methods, err := route.GetMethods()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
allMethods = append(allMethods, methods...)
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
return allMethods, err
|
||||
}
|
||||
|
||||
@@ -3,7 +3,6 @@ package mux
|
||||
import (
|
||||
"bytes"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
)
|
||||
|
||||
@@ -28,12 +27,12 @@ func TestMiddlewareAdd(t *testing.T) {
|
||||
|
||||
router.useInterface(mw)
|
||||
if len(router.middlewares) != 1 || router.middlewares[0] != mw {
|
||||
t.Fatal("Middleware was not added correctly")
|
||||
t.Fatal("Middleware interface was not added correctly")
|
||||
}
|
||||
|
||||
router.Use(mw.Middleware)
|
||||
if len(router.middlewares) != 2 {
|
||||
t.Fatal("MiddlewareFunc method was not added correctly")
|
||||
t.Fatal("Middleware method was not added correctly")
|
||||
}
|
||||
|
||||
banalMw := func(handler http.Handler) http.Handler {
|
||||
@@ -41,7 +40,7 @@ func TestMiddlewareAdd(t *testing.T) {
|
||||
}
|
||||
router.Use(banalMw)
|
||||
if len(router.middlewares) != 3 {
|
||||
t.Fatal("MiddlewareFunc method was not added correctly")
|
||||
t.Fatal("Middleware function was not added correctly")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -55,34 +54,37 @@ func TestMiddleware(t *testing.T) {
|
||||
rw := NewRecorder()
|
||||
req := newRequest("GET", "/")
|
||||
|
||||
// Test regular middleware call
|
||||
router.ServeHTTP(rw, req)
|
||||
if mw.timesCalled != 1 {
|
||||
t.Fatalf("Expected %d calls, but got only %d", 1, mw.timesCalled)
|
||||
}
|
||||
t.Run("regular middleware call", func(t *testing.T) {
|
||||
router.ServeHTTP(rw, req)
|
||||
if mw.timesCalled != 1 {
|
||||
t.Fatalf("Expected %d calls, but got only %d", 1, mw.timesCalled)
|
||||
}
|
||||
})
|
||||
|
||||
// Middleware should not be called for 404
|
||||
req = newRequest("GET", "/not/found")
|
||||
router.ServeHTTP(rw, req)
|
||||
if mw.timesCalled != 1 {
|
||||
t.Fatalf("Expected %d calls, but got only %d", 1, mw.timesCalled)
|
||||
}
|
||||
t.Run("not called for 404", func(t *testing.T) {
|
||||
req = newRequest("GET", "/not/found")
|
||||
router.ServeHTTP(rw, req)
|
||||
if mw.timesCalled != 1 {
|
||||
t.Fatalf("Expected %d calls, but got only %d", 1, mw.timesCalled)
|
||||
}
|
||||
})
|
||||
|
||||
// Middleware should not be called if there is a method mismatch
|
||||
req = newRequest("POST", "/")
|
||||
router.ServeHTTP(rw, req)
|
||||
if mw.timesCalled != 1 {
|
||||
t.Fatalf("Expected %d calls, but got only %d", 1, mw.timesCalled)
|
||||
}
|
||||
|
||||
// Add the middleware again as function
|
||||
router.Use(mw.Middleware)
|
||||
req = newRequest("GET", "/")
|
||||
router.ServeHTTP(rw, req)
|
||||
if mw.timesCalled != 3 {
|
||||
t.Fatalf("Expected %d calls, but got only %d", 3, mw.timesCalled)
|
||||
}
|
||||
t.Run("not called for method mismatch", func(t *testing.T) {
|
||||
req = newRequest("POST", "/")
|
||||
router.ServeHTTP(rw, req)
|
||||
if mw.timesCalled != 1 {
|
||||
t.Fatalf("Expected %d calls, but got only %d", 1, mw.timesCalled)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("regular call using function middleware", func(t *testing.T) {
|
||||
router.Use(mw.Middleware)
|
||||
req = newRequest("GET", "/")
|
||||
router.ServeHTTP(rw, req)
|
||||
if mw.timesCalled != 3 {
|
||||
t.Fatalf("Expected %d calls, but got only %d", 3, mw.timesCalled)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestMiddlewareSubrouter(t *testing.T) {
|
||||
@@ -98,42 +100,56 @@ func TestMiddlewareSubrouter(t *testing.T) {
|
||||
rw := NewRecorder()
|
||||
req := newRequest("GET", "/")
|
||||
|
||||
router.ServeHTTP(rw, req)
|
||||
if mw.timesCalled != 0 {
|
||||
t.Fatalf("Expected %d calls, but got only %d", 0, mw.timesCalled)
|
||||
}
|
||||
t.Run("not called for route outside subrouter", func(t *testing.T) {
|
||||
router.ServeHTTP(rw, req)
|
||||
if mw.timesCalled != 0 {
|
||||
t.Fatalf("Expected %d calls, but got only %d", 0, mw.timesCalled)
|
||||
}
|
||||
})
|
||||
|
||||
req = newRequest("GET", "/sub/")
|
||||
router.ServeHTTP(rw, req)
|
||||
if mw.timesCalled != 0 {
|
||||
t.Fatalf("Expected %d calls, but got only %d", 0, mw.timesCalled)
|
||||
}
|
||||
t.Run("not called for subrouter root 404", func(t *testing.T) {
|
||||
req = newRequest("GET", "/sub/")
|
||||
router.ServeHTTP(rw, req)
|
||||
if mw.timesCalled != 0 {
|
||||
t.Fatalf("Expected %d calls, but got only %d", 0, mw.timesCalled)
|
||||
}
|
||||
})
|
||||
|
||||
req = newRequest("GET", "/sub/x")
|
||||
router.ServeHTTP(rw, req)
|
||||
if mw.timesCalled != 1 {
|
||||
t.Fatalf("Expected %d calls, but got only %d", 1, mw.timesCalled)
|
||||
}
|
||||
t.Run("called once for route inside subrouter", func(t *testing.T) {
|
||||
req = newRequest("GET", "/sub/x")
|
||||
router.ServeHTTP(rw, req)
|
||||
if mw.timesCalled != 1 {
|
||||
t.Fatalf("Expected %d calls, but got only %d", 1, mw.timesCalled)
|
||||
}
|
||||
})
|
||||
|
||||
req = newRequest("GET", "/sub/not/found")
|
||||
router.ServeHTTP(rw, req)
|
||||
if mw.timesCalled != 1 {
|
||||
t.Fatalf("Expected %d calls, but got only %d", 1, mw.timesCalled)
|
||||
}
|
||||
t.Run("not called for 404 inside subrouter", func(t *testing.T) {
|
||||
req = newRequest("GET", "/sub/not/found")
|
||||
router.ServeHTTP(rw, req)
|
||||
if mw.timesCalled != 1 {
|
||||
t.Fatalf("Expected %d calls, but got only %d", 1, mw.timesCalled)
|
||||
}
|
||||
})
|
||||
|
||||
router.useInterface(mw)
|
||||
t.Run("middleware added to router", func(t *testing.T) {
|
||||
router.useInterface(mw)
|
||||
|
||||
req = newRequest("GET", "/")
|
||||
router.ServeHTTP(rw, req)
|
||||
if mw.timesCalled != 2 {
|
||||
t.Fatalf("Expected %d calls, but got only %d", 2, mw.timesCalled)
|
||||
}
|
||||
t.Run("called once for route outside subrouter", func(t *testing.T) {
|
||||
req = newRequest("GET", "/")
|
||||
router.ServeHTTP(rw, req)
|
||||
if mw.timesCalled != 2 {
|
||||
t.Fatalf("Expected %d calls, but got only %d", 2, mw.timesCalled)
|
||||
}
|
||||
})
|
||||
|
||||
req = newRequest("GET", "/sub/x")
|
||||
router.ServeHTTP(rw, req)
|
||||
if mw.timesCalled != 4 {
|
||||
t.Fatalf("Expected %d calls, but got only %d", 4, mw.timesCalled)
|
||||
}
|
||||
t.Run("called twice for route inside subrouter", func(t *testing.T) {
|
||||
req = newRequest("GET", "/sub/x")
|
||||
router.ServeHTTP(rw, req)
|
||||
if mw.timesCalled != 4 {
|
||||
t.Fatalf("Expected %d calls, but got only %d", 4, mw.timesCalled)
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestMiddlewareExecution(t *testing.T) {
|
||||
@@ -145,30 +161,33 @@ func TestMiddlewareExecution(t *testing.T) {
|
||||
w.Write(handlerStr)
|
||||
})
|
||||
|
||||
rw := NewRecorder()
|
||||
req := newRequest("GET", "/")
|
||||
t.Run("responds normally without middleware", func(t *testing.T) {
|
||||
rw := NewRecorder()
|
||||
req := newRequest("GET", "/")
|
||||
|
||||
// Test handler-only call
|
||||
router.ServeHTTP(rw, req)
|
||||
router.ServeHTTP(rw, req)
|
||||
|
||||
if bytes.Compare(rw.Body.Bytes(), handlerStr) != 0 {
|
||||
t.Fatal("Handler response is not what it should be")
|
||||
}
|
||||
|
||||
// Test middleware call
|
||||
rw = NewRecorder()
|
||||
|
||||
router.Use(func(h http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Write(mwStr)
|
||||
h.ServeHTTP(w, r)
|
||||
})
|
||||
if !bytes.Equal(rw.Body.Bytes(), handlerStr) {
|
||||
t.Fatal("Handler response is not what it should be")
|
||||
}
|
||||
})
|
||||
|
||||
router.ServeHTTP(rw, req)
|
||||
if bytes.Compare(rw.Body.Bytes(), append(mwStr, handlerStr...)) != 0 {
|
||||
t.Fatal("Middleware + handler response is not what it should be")
|
||||
}
|
||||
t.Run("responds with handler and middleware response", func(t *testing.T) {
|
||||
rw := NewRecorder()
|
||||
req := newRequest("GET", "/")
|
||||
|
||||
router.Use(func(h http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Write(mwStr)
|
||||
h.ServeHTTP(w, r)
|
||||
})
|
||||
})
|
||||
|
||||
router.ServeHTTP(rw, req)
|
||||
if !bytes.Equal(rw.Body.Bytes(), append(mwStr, handlerStr...)) {
|
||||
t.Fatal("Middleware + handler response is not what it should be")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestMiddlewareNotFound(t *testing.T) {
|
||||
@@ -187,26 +206,29 @@ func TestMiddlewareNotFound(t *testing.T) {
|
||||
})
|
||||
|
||||
// Test not found call with default handler
|
||||
rw := NewRecorder()
|
||||
req := newRequest("GET", "/notfound")
|
||||
t.Run("not called", func(t *testing.T) {
|
||||
rw := NewRecorder()
|
||||
req := newRequest("GET", "/notfound")
|
||||
|
||||
router.ServeHTTP(rw, req)
|
||||
if bytes.Contains(rw.Body.Bytes(), mwStr) {
|
||||
t.Fatal("Middleware was called for a 404")
|
||||
}
|
||||
|
||||
// Test not found call with custom handler
|
||||
rw = NewRecorder()
|
||||
req = newRequest("GET", "/notfound")
|
||||
|
||||
router.NotFoundHandler = http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
||||
rw.Write([]byte("Custom 404 handler"))
|
||||
router.ServeHTTP(rw, req)
|
||||
if bytes.Contains(rw.Body.Bytes(), mwStr) {
|
||||
t.Fatal("Middleware was called for a 404")
|
||||
}
|
||||
})
|
||||
router.ServeHTTP(rw, req)
|
||||
|
||||
if bytes.Contains(rw.Body.Bytes(), mwStr) {
|
||||
t.Fatal("Middleware was called for a custom 404")
|
||||
}
|
||||
t.Run("not called with custom not found handler", func(t *testing.T) {
|
||||
rw := NewRecorder()
|
||||
req := newRequest("GET", "/notfound")
|
||||
|
||||
router.NotFoundHandler = http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
||||
rw.Write([]byte("Custom 404 handler"))
|
||||
})
|
||||
router.ServeHTTP(rw, req)
|
||||
|
||||
if bytes.Contains(rw.Body.Bytes(), mwStr) {
|
||||
t.Fatal("Middleware was called for a custom 404")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestMiddlewareMethodMismatch(t *testing.T) {
|
||||
@@ -225,27 +247,29 @@ func TestMiddlewareMethodMismatch(t *testing.T) {
|
||||
})
|
||||
})
|
||||
|
||||
// Test method mismatch
|
||||
rw := NewRecorder()
|
||||
req := newRequest("POST", "/")
|
||||
t.Run("not called", func(t *testing.T) {
|
||||
rw := NewRecorder()
|
||||
req := newRequest("POST", "/")
|
||||
|
||||
router.ServeHTTP(rw, req)
|
||||
if bytes.Contains(rw.Body.Bytes(), mwStr) {
|
||||
t.Fatal("Middleware was called for a method mismatch")
|
||||
}
|
||||
|
||||
// Test not found call
|
||||
rw = NewRecorder()
|
||||
req = newRequest("POST", "/")
|
||||
|
||||
router.MethodNotAllowedHandler = http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
||||
rw.Write([]byte("Method not allowed"))
|
||||
router.ServeHTTP(rw, req)
|
||||
if bytes.Contains(rw.Body.Bytes(), mwStr) {
|
||||
t.Fatal("Middleware was called for a method mismatch")
|
||||
}
|
||||
})
|
||||
router.ServeHTTP(rw, req)
|
||||
|
||||
if bytes.Contains(rw.Body.Bytes(), mwStr) {
|
||||
t.Fatal("Middleware was called for a method mismatch")
|
||||
}
|
||||
t.Run("not called with custom method not allowed handler", func(t *testing.T) {
|
||||
rw := NewRecorder()
|
||||
req := newRequest("POST", "/")
|
||||
|
||||
router.MethodNotAllowedHandler = http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
||||
rw.Write([]byte("Method not allowed"))
|
||||
})
|
||||
router.ServeHTTP(rw, req)
|
||||
|
||||
if bytes.Contains(rw.Body.Bytes(), mwStr) {
|
||||
t.Fatal("Middleware was called for a method mismatch")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestMiddlewareNotFoundSubrouter(t *testing.T) {
|
||||
@@ -269,27 +293,29 @@ func TestMiddlewareNotFoundSubrouter(t *testing.T) {
|
||||
})
|
||||
})
|
||||
|
||||
// Test not found call for default handler
|
||||
rw := NewRecorder()
|
||||
req := newRequest("GET", "/sub/notfound")
|
||||
t.Run("not called", func(t *testing.T) {
|
||||
rw := NewRecorder()
|
||||
req := newRequest("GET", "/sub/notfound")
|
||||
|
||||
router.ServeHTTP(rw, req)
|
||||
if bytes.Contains(rw.Body.Bytes(), mwStr) {
|
||||
t.Fatal("Middleware was called for a 404")
|
||||
}
|
||||
|
||||
// Test not found call with custom handler
|
||||
rw = NewRecorder()
|
||||
req = newRequest("GET", "/sub/notfound")
|
||||
|
||||
subrouter.NotFoundHandler = http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
||||
rw.Write([]byte("Custom 404 handler"))
|
||||
router.ServeHTTP(rw, req)
|
||||
if bytes.Contains(rw.Body.Bytes(), mwStr) {
|
||||
t.Fatal("Middleware was called for a 404")
|
||||
}
|
||||
})
|
||||
router.ServeHTTP(rw, req)
|
||||
|
||||
if bytes.Contains(rw.Body.Bytes(), mwStr) {
|
||||
t.Fatal("Middleware was called for a custom 404")
|
||||
}
|
||||
t.Run("not called with custom not found handler", func(t *testing.T) {
|
||||
rw := NewRecorder()
|
||||
req := newRequest("GET", "/sub/notfound")
|
||||
|
||||
subrouter.NotFoundHandler = http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
||||
rw.Write([]byte("Custom 404 handler"))
|
||||
})
|
||||
router.ServeHTTP(rw, req)
|
||||
|
||||
if bytes.Contains(rw.Body.Bytes(), mwStr) {
|
||||
t.Fatal("Middleware was called for a custom 404")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestMiddlewareMethodMismatchSubrouter(t *testing.T) {
|
||||
@@ -313,65 +339,207 @@ func TestMiddlewareMethodMismatchSubrouter(t *testing.T) {
|
||||
})
|
||||
})
|
||||
|
||||
// Test method mismatch without custom handler
|
||||
rw := NewRecorder()
|
||||
req := newRequest("POST", "/sub/")
|
||||
t.Run("not called", func(t *testing.T) {
|
||||
rw := NewRecorder()
|
||||
req := newRequest("POST", "/sub/")
|
||||
|
||||
router.ServeHTTP(rw, req)
|
||||
if bytes.Contains(rw.Body.Bytes(), mwStr) {
|
||||
t.Fatal("Middleware was called for a method mismatch")
|
||||
}
|
||||
|
||||
// Test method mismatch with custom handler
|
||||
rw = NewRecorder()
|
||||
req = newRequest("POST", "/sub/")
|
||||
|
||||
router.MethodNotAllowedHandler = http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
||||
rw.Write([]byte("Method not allowed"))
|
||||
router.ServeHTTP(rw, req)
|
||||
if bytes.Contains(rw.Body.Bytes(), mwStr) {
|
||||
t.Fatal("Middleware was called for a method mismatch")
|
||||
}
|
||||
})
|
||||
router.ServeHTTP(rw, req)
|
||||
|
||||
if bytes.Contains(rw.Body.Bytes(), mwStr) {
|
||||
t.Fatal("Middleware was called for a method mismatch")
|
||||
}
|
||||
t.Run("not called with custom method not allowed handler", func(t *testing.T) {
|
||||
rw := NewRecorder()
|
||||
req := newRequest("POST", "/sub/")
|
||||
|
||||
router.MethodNotAllowedHandler = http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
||||
rw.Write([]byte("Method not allowed"))
|
||||
})
|
||||
router.ServeHTTP(rw, req)
|
||||
|
||||
if bytes.Contains(rw.Body.Bytes(), mwStr) {
|
||||
t.Fatal("Middleware was called for a method mismatch")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestCORSMethodMiddleware(t *testing.T) {
|
||||
router := NewRouter()
|
||||
|
||||
cases := []struct {
|
||||
path string
|
||||
response string
|
||||
method string
|
||||
testURL string
|
||||
expectedAllowedMethods string
|
||||
testCases := []struct {
|
||||
name string
|
||||
registerRoutes func(r *Router)
|
||||
requestHeader http.Header
|
||||
requestMethod string
|
||||
requestPath string
|
||||
expectedAccessControlAllowMethodsHeader string
|
||||
expectedResponse string
|
||||
}{
|
||||
{"/g/{o}", "a", "POST", "/g/asdf", "POST,PUT,GET,OPTIONS"},
|
||||
{"/g/{o}", "b", "PUT", "/g/bla", "POST,PUT,GET,OPTIONS"},
|
||||
{"/g/{o}", "c", "GET", "/g/orilla", "POST,PUT,GET,OPTIONS"},
|
||||
{"/g", "d", "POST", "/g", "POST,OPTIONS"},
|
||||
{
|
||||
name: "does not set without OPTIONS matcher",
|
||||
registerRoutes: func(r *Router) {
|
||||
r.HandleFunc("/foo", stringHandler("a")).Methods(http.MethodGet, http.MethodPut, http.MethodPatch)
|
||||
},
|
||||
requestMethod: "GET",
|
||||
requestPath: "/foo",
|
||||
expectedAccessControlAllowMethodsHeader: "",
|
||||
expectedResponse: "a",
|
||||
},
|
||||
{
|
||||
name: "sets on non OPTIONS",
|
||||
registerRoutes: func(r *Router) {
|
||||
r.HandleFunc("/foo", stringHandler("a")).Methods(http.MethodGet, http.MethodPut, http.MethodPatch)
|
||||
r.HandleFunc("/foo", stringHandler("b")).Methods(http.MethodOptions)
|
||||
},
|
||||
requestMethod: "GET",
|
||||
requestPath: "/foo",
|
||||
expectedAccessControlAllowMethodsHeader: "GET,PUT,PATCH,OPTIONS",
|
||||
expectedResponse: "a",
|
||||
},
|
||||
{
|
||||
name: "sets without preflight headers",
|
||||
registerRoutes: func(r *Router) {
|
||||
r.HandleFunc("/foo", stringHandler("a")).Methods(http.MethodGet, http.MethodPut, http.MethodPatch)
|
||||
r.HandleFunc("/foo", stringHandler("b")).Methods(http.MethodOptions)
|
||||
},
|
||||
requestMethod: "OPTIONS",
|
||||
requestPath: "/foo",
|
||||
expectedAccessControlAllowMethodsHeader: "GET,PUT,PATCH,OPTIONS",
|
||||
expectedResponse: "b",
|
||||
},
|
||||
{
|
||||
name: "does not set on error",
|
||||
registerRoutes: func(r *Router) {
|
||||
r.HandleFunc("/foo", stringHandler("a"))
|
||||
},
|
||||
requestMethod: "OPTIONS",
|
||||
requestPath: "/foo",
|
||||
expectedAccessControlAllowMethodsHeader: "",
|
||||
expectedResponse: "a",
|
||||
},
|
||||
{
|
||||
name: "sets header on valid preflight",
|
||||
registerRoutes: func(r *Router) {
|
||||
r.HandleFunc("/foo", stringHandler("a")).Methods(http.MethodGet, http.MethodPut, http.MethodPatch)
|
||||
r.HandleFunc("/foo", stringHandler("b")).Methods(http.MethodOptions)
|
||||
},
|
||||
requestMethod: "OPTIONS",
|
||||
requestPath: "/foo",
|
||||
requestHeader: http.Header{
|
||||
"Access-Control-Request-Method": []string{"GET"},
|
||||
"Access-Control-Request-Headers": []string{"Authorization"},
|
||||
"Origin": []string{"http://example.com"},
|
||||
},
|
||||
expectedAccessControlAllowMethodsHeader: "GET,PUT,PATCH,OPTIONS",
|
||||
expectedResponse: "b",
|
||||
},
|
||||
{
|
||||
name: "does not set methods from unmatching routes",
|
||||
registerRoutes: func(r *Router) {
|
||||
r.HandleFunc("/foo", stringHandler("c")).Methods(http.MethodDelete)
|
||||
r.HandleFunc("/foo/bar", stringHandler("a")).Methods(http.MethodGet, http.MethodPut, http.MethodPatch)
|
||||
r.HandleFunc("/foo/bar", stringHandler("b")).Methods(http.MethodOptions)
|
||||
},
|
||||
requestMethod: "OPTIONS",
|
||||
requestPath: "/foo/bar",
|
||||
requestHeader: http.Header{
|
||||
"Access-Control-Request-Method": []string{"GET"},
|
||||
"Access-Control-Request-Headers": []string{"Authorization"},
|
||||
"Origin": []string{"http://example.com"},
|
||||
},
|
||||
expectedAccessControlAllowMethodsHeader: "GET,PUT,PATCH,OPTIONS",
|
||||
expectedResponse: "b",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range cases {
|
||||
router.HandleFunc(tt.path, stringHandler(tt.response)).Methods(tt.method)
|
||||
}
|
||||
for _, tt := range testCases {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
router := NewRouter()
|
||||
|
||||
router.Use(CORSMethodMiddleware(router))
|
||||
tt.registerRoutes(router)
|
||||
|
||||
for _, tt := range cases {
|
||||
rr := httptest.NewRecorder()
|
||||
req := newRequest(tt.method, tt.testURL)
|
||||
router.Use(CORSMethodMiddleware(router))
|
||||
|
||||
router.ServeHTTP(rr, req)
|
||||
rw := NewRecorder()
|
||||
req := newRequest(tt.requestMethod, tt.requestPath)
|
||||
req.Header = tt.requestHeader
|
||||
|
||||
if rr.Body.String() != tt.response {
|
||||
t.Errorf("Expected body '%s', found '%s'", tt.response, rr.Body.String())
|
||||
}
|
||||
router.ServeHTTP(rw, req)
|
||||
|
||||
allowedMethods := rr.HeaderMap.Get("Access-Control-Allow-Methods")
|
||||
actualMethodsHeader := rw.Header().Get("Access-Control-Allow-Methods")
|
||||
if actualMethodsHeader != tt.expectedAccessControlAllowMethodsHeader {
|
||||
t.Fatalf("Expected Access-Control-Allow-Methods to equal %s but got %s", tt.expectedAccessControlAllowMethodsHeader, actualMethodsHeader)
|
||||
}
|
||||
|
||||
if allowedMethods != tt.expectedAllowedMethods {
|
||||
t.Errorf("Expected Access-Control-Allow-Methods '%s', found '%s'", tt.expectedAllowedMethods, allowedMethods)
|
||||
}
|
||||
actualResponse := rw.Body.String()
|
||||
if actualResponse != tt.expectedResponse {
|
||||
t.Fatalf("Expected response to equal %s but got %s", tt.expectedResponse, actualResponse)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestMiddlewareOnMultiSubrouter(t *testing.T) {
|
||||
first := "first"
|
||||
second := "second"
|
||||
notFound := "404 not found"
|
||||
|
||||
router := NewRouter()
|
||||
firstSubRouter := router.PathPrefix("/").Subrouter()
|
||||
secondSubRouter := router.PathPrefix("/").Subrouter()
|
||||
|
||||
router.NotFoundHandler = http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
||||
rw.Write([]byte(notFound))
|
||||
})
|
||||
|
||||
firstSubRouter.HandleFunc("/first", func(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
})
|
||||
|
||||
secondSubRouter.HandleFunc("/second", func(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
})
|
||||
|
||||
firstSubRouter.Use(func(h http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Write([]byte(first))
|
||||
h.ServeHTTP(w, r)
|
||||
})
|
||||
})
|
||||
|
||||
secondSubRouter.Use(func(h http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Write([]byte(second))
|
||||
h.ServeHTTP(w, r)
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("/first uses first middleware", func(t *testing.T) {
|
||||
rw := NewRecorder()
|
||||
req := newRequest("GET", "/first")
|
||||
|
||||
router.ServeHTTP(rw, req)
|
||||
if rw.Body.String() != first {
|
||||
t.Fatalf("Middleware did not run: expected %s middleware to write a response (got %s)", first, rw.Body.String())
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("/second uses second middleware", func(t *testing.T) {
|
||||
rw := NewRecorder()
|
||||
req := newRequest("GET", "/second")
|
||||
|
||||
router.ServeHTTP(rw, req)
|
||||
if rw.Body.String() != second {
|
||||
t.Fatalf("Middleware did not run: expected %s middleware to write a response (got %s)", second, rw.Body.String())
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("uses not found handler", func(t *testing.T) {
|
||||
rw := NewRecorder()
|
||||
req := newRequest("GET", "/second/not-exist")
|
||||
|
||||
router.ServeHTTP(rw, req)
|
||||
if rw.Body.String() != notFound {
|
||||
t.Fatalf("Notfound handler did not run: expected %s for not-exist, (got %s)", notFound, rw.Body.String())
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
129
mux.go
129
mux.go
@@ -22,7 +22,7 @@ var (
|
||||
|
||||
// NewRouter returns a new router instance.
|
||||
func NewRouter() *Router {
|
||||
return &Router{namedRoutes: make(map[string]*Route), KeepContext: false}
|
||||
return &Router{namedRoutes: make(map[string]*Route)}
|
||||
}
|
||||
|
||||
// Router registers routes to be matched and dispatches a handler.
|
||||
@@ -50,24 +50,78 @@ type Router struct {
|
||||
// Configurable Handler to be used when the request method does not match the route.
|
||||
MethodNotAllowedHandler http.Handler
|
||||
|
||||
// Parent route, if this is a subrouter.
|
||||
parent parentRoute
|
||||
// Routes to be matched, in order.
|
||||
routes []*Route
|
||||
|
||||
// Routes by name for URL building.
|
||||
namedRoutes map[string]*Route
|
||||
// See Router.StrictSlash(). This defines the flag for new routes.
|
||||
strictSlash bool
|
||||
// See Router.SkipClean(). This defines the flag for new routes.
|
||||
skipClean bool
|
||||
|
||||
// If true, do not clear the request context after handling the request.
|
||||
// This has no effect when go1.7+ is used, since the context is stored
|
||||
//
|
||||
// Deprecated: No effect when go1.7+ is used, since the context is stored
|
||||
// on the request itself.
|
||||
KeepContext bool
|
||||
// see Router.UseEncodedPath(). This defines a flag for all routes.
|
||||
useEncodedPath bool
|
||||
|
||||
// Slice of middlewares to be called after a match is found
|
||||
middlewares []middleware
|
||||
|
||||
// configuration shared with `Route`
|
||||
routeConf
|
||||
}
|
||||
|
||||
// common route configuration shared between `Router` and `Route`
|
||||
type routeConf struct {
|
||||
// If true, "/path/foo%2Fbar/to" will match the path "/path/{var}/to"
|
||||
useEncodedPath bool
|
||||
|
||||
// If true, when the path pattern is "/path/", accessing "/path" will
|
||||
// redirect to the former and vice versa.
|
||||
strictSlash bool
|
||||
|
||||
// If true, when the path pattern is "/path//to", accessing "/path//to"
|
||||
// will not redirect
|
||||
skipClean bool
|
||||
|
||||
// Manager for the variables from host and path.
|
||||
regexp routeRegexpGroup
|
||||
|
||||
// List of matchers.
|
||||
matchers []matcher
|
||||
|
||||
// The scheme used when building URLs.
|
||||
buildScheme string
|
||||
|
||||
buildVarsFunc BuildVarsFunc
|
||||
}
|
||||
|
||||
// returns an effective deep copy of `routeConf`
|
||||
func copyRouteConf(r routeConf) routeConf {
|
||||
c := r
|
||||
|
||||
if r.regexp.path != nil {
|
||||
c.regexp.path = copyRouteRegexp(r.regexp.path)
|
||||
}
|
||||
|
||||
if r.regexp.host != nil {
|
||||
c.regexp.host = copyRouteRegexp(r.regexp.host)
|
||||
}
|
||||
|
||||
c.regexp.queries = make([]*routeRegexp, 0, len(r.regexp.queries))
|
||||
for _, q := range r.regexp.queries {
|
||||
c.regexp.queries = append(c.regexp.queries, copyRouteRegexp(q))
|
||||
}
|
||||
|
||||
c.matchers = make([]matcher, 0, len(r.matchers))
|
||||
for _, m := range r.matchers {
|
||||
c.matchers = append(c.matchers, m)
|
||||
}
|
||||
|
||||
return c
|
||||
}
|
||||
|
||||
func copyRouteRegexp(r *routeRegexp) *routeRegexp {
|
||||
c := *r
|
||||
return &c
|
||||
}
|
||||
|
||||
// Match attempts to match the given request against the router's registered routes.
|
||||
@@ -155,22 +209,18 @@ func (r *Router) ServeHTTP(w http.ResponseWriter, req *http.Request) {
|
||||
handler = http.NotFoundHandler()
|
||||
}
|
||||
|
||||
if !r.KeepContext {
|
||||
defer contextClear(req)
|
||||
}
|
||||
|
||||
handler.ServeHTTP(w, req)
|
||||
}
|
||||
|
||||
// Get returns a route registered with the given name.
|
||||
func (r *Router) Get(name string) *Route {
|
||||
return r.getNamedRoutes()[name]
|
||||
return r.namedRoutes[name]
|
||||
}
|
||||
|
||||
// GetRoute returns a route registered with the given name. This method
|
||||
// was renamed to Get() and remains here for backwards compatibility.
|
||||
func (r *Router) GetRoute(name string) *Route {
|
||||
return r.getNamedRoutes()[name]
|
||||
return r.namedRoutes[name]
|
||||
}
|
||||
|
||||
// StrictSlash defines the trailing slash behavior for new routes. The initial
|
||||
@@ -221,55 +271,24 @@ func (r *Router) UseEncodedPath() *Router {
|
||||
return r
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// parentRoute
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
func (r *Router) getBuildScheme() string {
|
||||
if r.parent != nil {
|
||||
return r.parent.getBuildScheme()
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// getNamedRoutes returns the map where named routes are registered.
|
||||
func (r *Router) getNamedRoutes() map[string]*Route {
|
||||
if r.namedRoutes == nil {
|
||||
if r.parent != nil {
|
||||
r.namedRoutes = r.parent.getNamedRoutes()
|
||||
} else {
|
||||
r.namedRoutes = make(map[string]*Route)
|
||||
}
|
||||
}
|
||||
return r.namedRoutes
|
||||
}
|
||||
|
||||
// getRegexpGroup returns regexp definitions from the parent route, if any.
|
||||
func (r *Router) getRegexpGroup() *routeRegexpGroup {
|
||||
if r.parent != nil {
|
||||
return r.parent.getRegexpGroup()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *Router) buildVars(m map[string]string) map[string]string {
|
||||
if r.parent != nil {
|
||||
m = r.parent.buildVars(m)
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// Route factories
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
// NewRoute registers an empty route.
|
||||
func (r *Router) NewRoute() *Route {
|
||||
route := &Route{parent: r, strictSlash: r.strictSlash, skipClean: r.skipClean, useEncodedPath: r.useEncodedPath}
|
||||
// initialize a route with a copy of the parent router's configuration
|
||||
route := &Route{routeConf: copyRouteConf(r.routeConf), namedRoutes: r.namedRoutes}
|
||||
r.routes = append(r.routes, route)
|
||||
return route
|
||||
}
|
||||
|
||||
// Name registers a new route with a name.
|
||||
// See Route.Name().
|
||||
func (r *Router) Name(name string) *Route {
|
||||
return r.NewRoute().Name(name)
|
||||
}
|
||||
|
||||
// Handle registers a new route with a matcher for the URL path.
|
||||
// See Route.Path() and Route.Handler().
|
||||
func (r *Router) Handle(path string, handler http.Handler) *Route {
|
||||
|
||||
630
mux_test.go
630
mux_test.go
@@ -48,15 +48,6 @@ type routeTest struct {
|
||||
}
|
||||
|
||||
func TestHost(t *testing.T) {
|
||||
// newRequestHost a new request with a method, url, and host header
|
||||
newRequestHost := func(method, url, host string) *http.Request {
|
||||
req, err := http.NewRequest(method, url, nil)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
req.Host = host
|
||||
return req
|
||||
}
|
||||
|
||||
tests := []routeTest{
|
||||
{
|
||||
@@ -113,7 +104,15 @@ func TestHost(t *testing.T) {
|
||||
path: "",
|
||||
shouldMatch: false,
|
||||
},
|
||||
// BUG {new(Route).Host("aaa.bbb.ccc:1234"), newRequestHost("GET", "/111/222/333", "aaa.bbb.ccc:1234"), map[string]string{}, "aaa.bbb.ccc:1234", "", true},
|
||||
{
|
||||
title: "Host route with port, match with request header",
|
||||
route: new(Route).Host("aaa.bbb.ccc:1234"),
|
||||
request: newRequestHost("GET", "/111/222/333", "aaa.bbb.ccc:1234"),
|
||||
vars: map[string]string{},
|
||||
host: "aaa.bbb.ccc:1234",
|
||||
path: "",
|
||||
shouldMatch: true,
|
||||
},
|
||||
{
|
||||
title: "Host route with port, wrong host in request header",
|
||||
route: new(Route).Host("aaa.bbb.ccc:1234"),
|
||||
@@ -123,6 +122,16 @@ func TestHost(t *testing.T) {
|
||||
path: "",
|
||||
shouldMatch: false,
|
||||
},
|
||||
{
|
||||
title: "Host route with pattern, match with request header",
|
||||
route: new(Route).Host("aaa.{v1:[a-z]{3}}.ccc:1{v2:(?:23|4)}"),
|
||||
request: newRequestHost("GET", "/111/222/333", "aaa.bbb.ccc:123"),
|
||||
vars: map[string]string{"v1": "bbb", "v2": "23"},
|
||||
host: "aaa.bbb.ccc:123",
|
||||
path: "",
|
||||
hostTemplate: `aaa.{v1:[a-z]{3}}.ccc:1{v2:(?:23|4)}`,
|
||||
shouldMatch: true,
|
||||
},
|
||||
{
|
||||
title: "Host route with pattern, match",
|
||||
route: new(Route).Host("aaa.{v1:[a-z]{3}}.ccc"),
|
||||
@@ -205,8 +214,10 @@ func TestHost(t *testing.T) {
|
||||
},
|
||||
}
|
||||
for _, test := range tests {
|
||||
testRoute(t, test)
|
||||
testTemplate(t, test)
|
||||
t.Run(test.title, func(t *testing.T) {
|
||||
testRoute(t, test)
|
||||
testTemplate(t, test)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -437,10 +448,12 @@ func TestPath(t *testing.T) {
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
testRoute(t, test)
|
||||
testTemplate(t, test)
|
||||
testUseEscapedRoute(t, test)
|
||||
testRegexp(t, test)
|
||||
t.Run(test.title, func(t *testing.T) {
|
||||
testRoute(t, test)
|
||||
testTemplate(t, test)
|
||||
testUseEscapedRoute(t, test)
|
||||
testRegexp(t, test)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -516,9 +529,11 @@ func TestPathPrefix(t *testing.T) {
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
testRoute(t, test)
|
||||
testTemplate(t, test)
|
||||
testUseEscapedRoute(t, test)
|
||||
t.Run(test.title, func(t *testing.T) {
|
||||
testRoute(t, test)
|
||||
testTemplate(t, test)
|
||||
testUseEscapedRoute(t, test)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -623,9 +638,11 @@ func TestSchemeHostPath(t *testing.T) {
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
testRoute(t, test)
|
||||
testTemplate(t, test)
|
||||
testUseEscapedRoute(t, test)
|
||||
t.Run(test.title, func(t *testing.T) {
|
||||
testRoute(t, test)
|
||||
testTemplate(t, test)
|
||||
testUseEscapedRoute(t, test)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -682,8 +699,10 @@ func TestHeaders(t *testing.T) {
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
testRoute(t, test)
|
||||
testTemplate(t, test)
|
||||
t.Run(test.title, func(t *testing.T) {
|
||||
testRoute(t, test)
|
||||
testTemplate(t, test)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -732,9 +751,11 @@ func TestMethods(t *testing.T) {
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
testRoute(t, test)
|
||||
testTemplate(t, test)
|
||||
testMethods(t, test)
|
||||
t.Run(test.title, func(t *testing.T) {
|
||||
testRoute(t, test)
|
||||
testTemplate(t, test)
|
||||
testMethods(t, test)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1039,11 +1060,12 @@ func TestQueries(t *testing.T) {
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
testRoute(t, test)
|
||||
testTemplate(t, test)
|
||||
testQueriesTemplates(t, test)
|
||||
testUseEscapedRoute(t, test)
|
||||
testQueriesRegexp(t, test)
|
||||
t.Run(test.title, func(t *testing.T) {
|
||||
testTemplate(t, test)
|
||||
testQueriesTemplates(t, test)
|
||||
testUseEscapedRoute(t, test)
|
||||
testQueriesRegexp(t, test)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1092,17 +1114,16 @@ func TestSchemes(t *testing.T) {
|
||||
},
|
||||
}
|
||||
for _, test := range tests {
|
||||
testRoute(t, test)
|
||||
testTemplate(t, test)
|
||||
t.Run(test.title, func(t *testing.T) {
|
||||
testRoute(t, test)
|
||||
testTemplate(t, test)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestMatcherFunc(t *testing.T) {
|
||||
m := func(r *http.Request, m *RouteMatch) bool {
|
||||
if r.URL.Host == "aaa.bbb.ccc" {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
return r.URL.Host == "aaa.bbb.ccc"
|
||||
}
|
||||
|
||||
tests := []routeTest{
|
||||
@@ -1127,8 +1148,10 @@ func TestMatcherFunc(t *testing.T) {
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
testRoute(t, test)
|
||||
testTemplate(t, test)
|
||||
t.Run(test.title, func(t *testing.T) {
|
||||
testRoute(t, test)
|
||||
testTemplate(t, test)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1163,8 +1186,10 @@ func TestBuildVarsFunc(t *testing.T) {
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
testRoute(t, test)
|
||||
testTemplate(t, test)
|
||||
t.Run(test.title, func(t *testing.T) {
|
||||
testRoute(t, test)
|
||||
testTemplate(t, test)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1174,7 +1199,6 @@ func TestSubRouter(t *testing.T) {
|
||||
subrouter3 := new(Route).PathPrefix("/foo").Subrouter()
|
||||
subrouter4 := new(Route).PathPrefix("/foo/bar").Subrouter()
|
||||
subrouter5 := new(Route).PathPrefix("/{category}").Subrouter()
|
||||
|
||||
tests := []routeTest{
|
||||
{
|
||||
route: subrouter1.Path("/{v2:[a-z]+}"),
|
||||
@@ -1269,6 +1293,106 @@ func TestSubRouter(t *testing.T) {
|
||||
pathTemplate: `/{category}`,
|
||||
shouldMatch: true,
|
||||
},
|
||||
{
|
||||
title: "Mismatch method specified on parent route",
|
||||
route: new(Route).Methods("POST").PathPrefix("/foo").Subrouter().Path("/"),
|
||||
request: newRequest("GET", "http://localhost/foo/"),
|
||||
vars: map[string]string{},
|
||||
host: "",
|
||||
path: "/foo/",
|
||||
pathTemplate: `/foo/`,
|
||||
shouldMatch: false,
|
||||
},
|
||||
{
|
||||
title: "Match method specified on parent route",
|
||||
route: new(Route).Methods("POST").PathPrefix("/foo").Subrouter().Path("/"),
|
||||
request: newRequest("POST", "http://localhost/foo/"),
|
||||
vars: map[string]string{},
|
||||
host: "",
|
||||
path: "/foo/",
|
||||
pathTemplate: `/foo/`,
|
||||
shouldMatch: true,
|
||||
},
|
||||
{
|
||||
title: "Mismatch scheme specified on parent route",
|
||||
route: new(Route).Schemes("https").Subrouter().PathPrefix("/"),
|
||||
request: newRequest("GET", "http://localhost/"),
|
||||
vars: map[string]string{},
|
||||
host: "",
|
||||
path: "/",
|
||||
pathTemplate: `/`,
|
||||
shouldMatch: false,
|
||||
},
|
||||
{
|
||||
title: "Match scheme specified on parent route",
|
||||
route: new(Route).Schemes("http").Subrouter().PathPrefix("/"),
|
||||
request: newRequest("GET", "http://localhost/"),
|
||||
vars: map[string]string{},
|
||||
host: "",
|
||||
path: "/",
|
||||
pathTemplate: `/`,
|
||||
shouldMatch: true,
|
||||
},
|
||||
{
|
||||
title: "No match header specified on parent route",
|
||||
route: new(Route).Headers("X-Forwarded-Proto", "https").Subrouter().PathPrefix("/"),
|
||||
request: newRequest("GET", "http://localhost/"),
|
||||
vars: map[string]string{},
|
||||
host: "",
|
||||
path: "/",
|
||||
pathTemplate: `/`,
|
||||
shouldMatch: false,
|
||||
},
|
||||
{
|
||||
title: "Header mismatch value specified on parent route",
|
||||
route: new(Route).Headers("X-Forwarded-Proto", "https").Subrouter().PathPrefix("/"),
|
||||
request: newRequestWithHeaders("GET", "http://localhost/", "X-Forwarded-Proto", "http"),
|
||||
vars: map[string]string{},
|
||||
host: "",
|
||||
path: "/",
|
||||
pathTemplate: `/`,
|
||||
shouldMatch: false,
|
||||
},
|
||||
{
|
||||
title: "Header match value specified on parent route",
|
||||
route: new(Route).Headers("X-Forwarded-Proto", "https").Subrouter().PathPrefix("/"),
|
||||
request: newRequestWithHeaders("GET", "http://localhost/", "X-Forwarded-Proto", "https"),
|
||||
vars: map[string]string{},
|
||||
host: "",
|
||||
path: "/",
|
||||
pathTemplate: `/`,
|
||||
shouldMatch: true,
|
||||
},
|
||||
{
|
||||
title: "Query specified on parent route not present",
|
||||
route: new(Route).Headers("key", "foobar").Subrouter().PathPrefix("/"),
|
||||
request: newRequest("GET", "http://localhost/"),
|
||||
vars: map[string]string{},
|
||||
host: "",
|
||||
path: "/",
|
||||
pathTemplate: `/`,
|
||||
shouldMatch: false,
|
||||
},
|
||||
{
|
||||
title: "Query mismatch value specified on parent route",
|
||||
route: new(Route).Queries("key", "foobar").Subrouter().PathPrefix("/"),
|
||||
request: newRequest("GET", "http://localhost/?key=notfoobar"),
|
||||
vars: map[string]string{},
|
||||
host: "",
|
||||
path: "/",
|
||||
pathTemplate: `/`,
|
||||
shouldMatch: false,
|
||||
},
|
||||
{
|
||||
title: "Query match value specified on subroute",
|
||||
route: new(Route).Queries("key", "foobar").Subrouter().PathPrefix("/"),
|
||||
request: newRequest("GET", "http://localhost/?key=foobar"),
|
||||
vars: map[string]string{},
|
||||
host: "",
|
||||
path: "/",
|
||||
pathTemplate: `/`,
|
||||
shouldMatch: true,
|
||||
},
|
||||
{
|
||||
title: "Build with scheme on parent router",
|
||||
route: new(Route).Schemes("ftp").Host("google.com").Subrouter().Path("/"),
|
||||
@@ -1294,9 +1418,11 @@ func TestSubRouter(t *testing.T) {
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
testRoute(t, test)
|
||||
testTemplate(t, test)
|
||||
testUseEscapedRoute(t, test)
|
||||
t.Run(test.title, func(t *testing.T) {
|
||||
testRoute(t, test)
|
||||
testTemplate(t, test)
|
||||
testUseEscapedRoute(t, test)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1315,14 +1441,24 @@ func TestNamedRoutes(t *testing.T) {
|
||||
r3.NewRoute().Name("g")
|
||||
r3.NewRoute().Name("h")
|
||||
r3.NewRoute().Name("i")
|
||||
r3.Name("j")
|
||||
|
||||
if r1.namedRoutes == nil || len(r1.namedRoutes) != 9 {
|
||||
t.Errorf("Expected 9 named routes, got %v", r1.namedRoutes)
|
||||
} else if r1.Get("i") == nil {
|
||||
if r1.namedRoutes == nil || len(r1.namedRoutes) != 10 {
|
||||
t.Errorf("Expected 10 named routes, got %v", r1.namedRoutes)
|
||||
} else if r1.Get("j") == nil {
|
||||
t.Errorf("Subroute name not registered")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNameMultipleCalls(t *testing.T) {
|
||||
r1 := NewRouter()
|
||||
rt := r1.NewRoute().Name("foo").Name("bar")
|
||||
err := rt.GetError()
|
||||
if err == nil {
|
||||
t.Errorf("Expected an error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestStrictSlash(t *testing.T) {
|
||||
r := NewRouter()
|
||||
r.StrictSlash(true)
|
||||
@@ -1391,9 +1527,11 @@ func TestStrictSlash(t *testing.T) {
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
testRoute(t, test)
|
||||
testTemplate(t, test)
|
||||
testUseEscapedRoute(t, test)
|
||||
t.Run(test.title, func(t *testing.T) {
|
||||
testRoute(t, test)
|
||||
testTemplate(t, test)
|
||||
testUseEscapedRoute(t, test)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1425,8 +1563,10 @@ func TestUseEncodedPath(t *testing.T) {
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
testRoute(t, test)
|
||||
testTemplate(t, test)
|
||||
t.Run(test.title, func(t *testing.T) {
|
||||
testRoute(t, test)
|
||||
testTemplate(t, test)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1478,12 +1618,16 @@ func TestWalkSingleDepth(t *testing.T) {
|
||||
func TestWalkNested(t *testing.T) {
|
||||
router := NewRouter()
|
||||
|
||||
g := router.Path("/g").Subrouter()
|
||||
o := g.PathPrefix("/o").Subrouter()
|
||||
r := o.PathPrefix("/r").Subrouter()
|
||||
i := r.PathPrefix("/i").Subrouter()
|
||||
l1 := i.PathPrefix("/l").Subrouter()
|
||||
l2 := l1.PathPrefix("/l").Subrouter()
|
||||
routeSubrouter := func(r *Route) (*Route, *Router) {
|
||||
return r, r.Subrouter()
|
||||
}
|
||||
|
||||
gRoute, g := routeSubrouter(router.Path("/g"))
|
||||
oRoute, o := routeSubrouter(g.PathPrefix("/o"))
|
||||
rRoute, r := routeSubrouter(o.PathPrefix("/r"))
|
||||
iRoute, i := routeSubrouter(r.PathPrefix("/i"))
|
||||
l1Route, l1 := routeSubrouter(i.PathPrefix("/l"))
|
||||
l2Route, l2 := routeSubrouter(l1.PathPrefix("/l"))
|
||||
l2.Path("/a")
|
||||
|
||||
testCases := []struct {
|
||||
@@ -1491,12 +1635,12 @@ func TestWalkNested(t *testing.T) {
|
||||
ancestors []*Route
|
||||
}{
|
||||
{"/g", []*Route{}},
|
||||
{"/g/o", []*Route{g.parent.(*Route)}},
|
||||
{"/g/o/r", []*Route{g.parent.(*Route), o.parent.(*Route)}},
|
||||
{"/g/o/r/i", []*Route{g.parent.(*Route), o.parent.(*Route), r.parent.(*Route)}},
|
||||
{"/g/o/r/i/l", []*Route{g.parent.(*Route), o.parent.(*Route), r.parent.(*Route), i.parent.(*Route)}},
|
||||
{"/g/o/r/i/l/l", []*Route{g.parent.(*Route), o.parent.(*Route), r.parent.(*Route), i.parent.(*Route), l1.parent.(*Route)}},
|
||||
{"/g/o/r/i/l/l/a", []*Route{g.parent.(*Route), o.parent.(*Route), r.parent.(*Route), i.parent.(*Route), l1.parent.(*Route), l2.parent.(*Route)}},
|
||||
{"/g/o", []*Route{gRoute}},
|
||||
{"/g/o/r", []*Route{gRoute, oRoute}},
|
||||
{"/g/o/r/i", []*Route{gRoute, oRoute, rRoute}},
|
||||
{"/g/o/r/i/l", []*Route{gRoute, oRoute, rRoute, iRoute}},
|
||||
{"/g/o/r/i/l/l", []*Route{gRoute, oRoute, rRoute, iRoute, l1Route}},
|
||||
{"/g/o/r/i/l/l/a", []*Route{gRoute, oRoute, rRoute, iRoute, l1Route, l2Route}},
|
||||
}
|
||||
|
||||
idx := 0
|
||||
@@ -1529,8 +1673,8 @@ func TestWalkSubrouters(t *testing.T) {
|
||||
o.Methods("GET")
|
||||
o.Methods("PUT")
|
||||
|
||||
// all 4 routes should be matched, but final 2 routes do not have path templates
|
||||
paths := []string{"/g", "/g/o", "", ""}
|
||||
// all 4 routes should be matched
|
||||
paths := []string{"/g", "/g/o", "/g/o", "/g/o"}
|
||||
idx := 0
|
||||
err := router.Walk(func(route *Route, router *Router, ancestors []*Route) error {
|
||||
path := paths[idx]
|
||||
@@ -1711,7 +1855,11 @@ func testRoute(t *testing.T, test routeTest) {
|
||||
}
|
||||
}
|
||||
if query != "" {
|
||||
u, _ := route.URL(mapToPairs(match.Vars)...)
|
||||
u, err := route.URL(mapToPairs(match.Vars)...)
|
||||
if err != nil {
|
||||
t.Errorf("(%v) erred while creating url: %v", test.title, err)
|
||||
return
|
||||
}
|
||||
if query != u.RawQuery {
|
||||
t.Errorf("(%v) URL query not equal: expected %v, got %v", test.title, query, u.RawQuery)
|
||||
return
|
||||
@@ -2031,7 +2179,9 @@ func TestMethodsSubrouterCatchall(t *testing.T) {
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
testMethodsSubrouter(t, test)
|
||||
t.Run(test.title, func(t *testing.T) {
|
||||
testMethodsSubrouter(t, test)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2087,7 +2237,9 @@ func TestMethodsSubrouterStrictSlash(t *testing.T) {
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
testMethodsSubrouter(t, test)
|
||||
t.Run(test.title, func(t *testing.T) {
|
||||
testMethodsSubrouter(t, test)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2134,7 +2286,9 @@ func TestMethodsSubrouterPathPrefix(t *testing.T) {
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
testMethodsSubrouter(t, test)
|
||||
t.Run(test.title, func(t *testing.T) {
|
||||
testMethodsSubrouter(t, test)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2190,7 +2344,9 @@ func TestMethodsSubrouterSubrouter(t *testing.T) {
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
testMethodsSubrouter(t, test)
|
||||
t.Run(test.title, func(t *testing.T) {
|
||||
testMethodsSubrouter(t, test)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2244,7 +2400,9 @@ func TestMethodsSubrouterPathVariable(t *testing.T) {
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
testMethodsSubrouter(t, test)
|
||||
t.Run(test.title, func(t *testing.T) {
|
||||
testMethodsSubrouter(t, test)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2288,6 +2446,305 @@ func testMethodsSubrouter(t *testing.T, test methodsSubrouterTest) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestSubrouterMatching(t *testing.T) {
|
||||
const (
|
||||
none, stdOnly, subOnly uint8 = 0, 1 << 0, 1 << 1
|
||||
both = subOnly | stdOnly
|
||||
)
|
||||
|
||||
type request struct {
|
||||
Name string
|
||||
Request *http.Request
|
||||
Flags uint8
|
||||
}
|
||||
|
||||
cases := []struct {
|
||||
Name string
|
||||
Standard, Subrouter func(*Router)
|
||||
Requests []request
|
||||
}{
|
||||
{
|
||||
"pathPrefix",
|
||||
func(r *Router) {
|
||||
r.PathPrefix("/before").PathPrefix("/after")
|
||||
},
|
||||
func(r *Router) {
|
||||
r.PathPrefix("/before").Subrouter().PathPrefix("/after")
|
||||
},
|
||||
[]request{
|
||||
{"no match final path prefix", newRequest("GET", "/after"), none},
|
||||
{"no match parent path prefix", newRequest("GET", "/before"), none},
|
||||
{"matches append", newRequest("GET", "/before/after"), both},
|
||||
{"matches as prefix", newRequest("GET", "/before/after/1234"), both},
|
||||
},
|
||||
},
|
||||
{
|
||||
"path",
|
||||
func(r *Router) {
|
||||
r.Path("/before").Path("/after")
|
||||
},
|
||||
func(r *Router) {
|
||||
r.Path("/before").Subrouter().Path("/after")
|
||||
},
|
||||
[]request{
|
||||
{"no match subroute path", newRequest("GET", "/after"), none},
|
||||
{"no match parent path", newRequest("GET", "/before"), none},
|
||||
{"no match as prefix", newRequest("GET", "/before/after/1234"), none},
|
||||
{"no match append", newRequest("GET", "/before/after"), none},
|
||||
},
|
||||
},
|
||||
{
|
||||
"host",
|
||||
func(r *Router) {
|
||||
r.Host("before.com").Host("after.com")
|
||||
},
|
||||
func(r *Router) {
|
||||
r.Host("before.com").Subrouter().Host("after.com")
|
||||
},
|
||||
[]request{
|
||||
{"no match before", newRequestHost("GET", "/", "before.com"), none},
|
||||
{"no match other", newRequestHost("GET", "/", "other.com"), none},
|
||||
{"matches after", newRequestHost("GET", "/", "after.com"), none},
|
||||
},
|
||||
},
|
||||
{
|
||||
"queries variant keys",
|
||||
func(r *Router) {
|
||||
r.Queries("foo", "bar").Queries("cricket", "baseball")
|
||||
},
|
||||
func(r *Router) {
|
||||
r.Queries("foo", "bar").Subrouter().Queries("cricket", "baseball")
|
||||
},
|
||||
[]request{
|
||||
{"matches with all", newRequest("GET", "/?foo=bar&cricket=baseball"), both},
|
||||
{"matches with more", newRequest("GET", "/?foo=bar&cricket=baseball&something=else"), both},
|
||||
{"no match with none", newRequest("GET", "/"), none},
|
||||
{"no match with some", newRequest("GET", "/?cricket=baseball"), none},
|
||||
},
|
||||
},
|
||||
{
|
||||
"queries overlapping keys",
|
||||
func(r *Router) {
|
||||
r.Queries("foo", "bar").Queries("foo", "baz")
|
||||
},
|
||||
func(r *Router) {
|
||||
r.Queries("foo", "bar").Subrouter().Queries("foo", "baz")
|
||||
},
|
||||
[]request{
|
||||
{"no match old value", newRequest("GET", "/?foo=bar"), none},
|
||||
{"no match diff value", newRequest("GET", "/?foo=bak"), none},
|
||||
{"no match with none", newRequest("GET", "/"), none},
|
||||
{"matches override", newRequest("GET", "/?foo=baz"), none},
|
||||
},
|
||||
},
|
||||
{
|
||||
"header variant keys",
|
||||
func(r *Router) {
|
||||
r.Headers("foo", "bar").Headers("cricket", "baseball")
|
||||
},
|
||||
func(r *Router) {
|
||||
r.Headers("foo", "bar").Subrouter().Headers("cricket", "baseball")
|
||||
},
|
||||
[]request{
|
||||
{
|
||||
"matches with all",
|
||||
newRequestWithHeaders("GET", "/", "foo", "bar", "cricket", "baseball"),
|
||||
both,
|
||||
},
|
||||
{
|
||||
"matches with more",
|
||||
newRequestWithHeaders("GET", "/", "foo", "bar", "cricket", "baseball", "something", "else"),
|
||||
both,
|
||||
},
|
||||
{"no match with none", newRequest("GET", "/"), none},
|
||||
{"no match with some", newRequestWithHeaders("GET", "/", "cricket", "baseball"), none},
|
||||
},
|
||||
},
|
||||
{
|
||||
"header overlapping keys",
|
||||
func(r *Router) {
|
||||
r.Headers("foo", "bar").Headers("foo", "baz")
|
||||
},
|
||||
func(r *Router) {
|
||||
r.Headers("foo", "bar").Subrouter().Headers("foo", "baz")
|
||||
},
|
||||
[]request{
|
||||
{"no match old value", newRequestWithHeaders("GET", "/", "foo", "bar"), none},
|
||||
{"no match diff value", newRequestWithHeaders("GET", "/", "foo", "bak"), none},
|
||||
{"no match with none", newRequest("GET", "/"), none},
|
||||
{"matches override", newRequestWithHeaders("GET", "/", "foo", "baz"), none},
|
||||
},
|
||||
},
|
||||
{
|
||||
"method",
|
||||
func(r *Router) {
|
||||
r.Methods("POST").Methods("GET")
|
||||
},
|
||||
func(r *Router) {
|
||||
r.Methods("POST").Subrouter().Methods("GET")
|
||||
},
|
||||
[]request{
|
||||
{"matches before", newRequest("POST", "/"), none},
|
||||
{"no match other", newRequest("HEAD", "/"), none},
|
||||
{"matches override", newRequest("GET", "/"), none},
|
||||
},
|
||||
},
|
||||
{
|
||||
"schemes",
|
||||
func(r *Router) {
|
||||
r.Schemes("http").Schemes("https")
|
||||
},
|
||||
func(r *Router) {
|
||||
r.Schemes("http").Subrouter().Schemes("https")
|
||||
},
|
||||
[]request{
|
||||
{"matches overrides", newRequest("GET", "https://www.example.com/"), none},
|
||||
{"matches original", newRequest("GET", "http://www.example.com/"), none},
|
||||
{"no match other", newRequest("GET", "ftp://www.example.com/"), none},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// case -> request -> router
|
||||
for _, c := range cases {
|
||||
t.Run(c.Name, func(t *testing.T) {
|
||||
for _, req := range c.Requests {
|
||||
t.Run(req.Name, func(t *testing.T) {
|
||||
for _, v := range []struct {
|
||||
Name string
|
||||
Config func(*Router)
|
||||
Expected bool
|
||||
}{
|
||||
{"subrouter", c.Subrouter, (req.Flags & subOnly) != 0},
|
||||
{"standard", c.Standard, (req.Flags & stdOnly) != 0},
|
||||
} {
|
||||
r := NewRouter()
|
||||
v.Config(r)
|
||||
if r.Match(req.Request, &RouteMatch{}) != v.Expected {
|
||||
if v.Expected {
|
||||
t.Errorf("expected %v match", v.Name)
|
||||
} else {
|
||||
t.Errorf("expected %v no match", v.Name)
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// verify that copyRouteConf copies fields as expected.
|
||||
func Test_copyRouteConf(t *testing.T) {
|
||||
var (
|
||||
m MatcherFunc = func(*http.Request, *RouteMatch) bool {
|
||||
return true
|
||||
}
|
||||
b BuildVarsFunc = func(i map[string]string) map[string]string {
|
||||
return i
|
||||
}
|
||||
r, _ = newRouteRegexp("hi", regexpTypeHost, routeRegexpOptions{})
|
||||
)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
args routeConf
|
||||
want routeConf
|
||||
}{
|
||||
{
|
||||
"empty",
|
||||
routeConf{},
|
||||
routeConf{},
|
||||
},
|
||||
{
|
||||
"full",
|
||||
routeConf{
|
||||
useEncodedPath: true,
|
||||
strictSlash: true,
|
||||
skipClean: true,
|
||||
regexp: routeRegexpGroup{host: r, path: r, queries: []*routeRegexp{r}},
|
||||
matchers: []matcher{m},
|
||||
buildScheme: "https",
|
||||
buildVarsFunc: b,
|
||||
},
|
||||
routeConf{
|
||||
useEncodedPath: true,
|
||||
strictSlash: true,
|
||||
skipClean: true,
|
||||
regexp: routeRegexpGroup{host: r, path: r, queries: []*routeRegexp{r}},
|
||||
matchers: []matcher{m},
|
||||
buildScheme: "https",
|
||||
buildVarsFunc: b,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// special case some incomparable fields of routeConf before delegating to reflect.DeepEqual
|
||||
got := copyRouteConf(tt.args)
|
||||
|
||||
// funcs not comparable, just compare length of slices
|
||||
if len(got.matchers) != len(tt.want.matchers) {
|
||||
t.Errorf("matchers different lengths: %v %v", len(got.matchers), len(tt.want.matchers))
|
||||
}
|
||||
got.matchers, tt.want.matchers = nil, nil
|
||||
|
||||
// deep equal treats nil slice differently to empty slice so check for zero len first
|
||||
{
|
||||
bothZero := len(got.regexp.queries) == 0 && len(tt.want.regexp.queries) == 0
|
||||
if !bothZero && !reflect.DeepEqual(got.regexp.queries, tt.want.regexp.queries) {
|
||||
t.Errorf("queries unequal: %v %v", got.regexp.queries, tt.want.regexp.queries)
|
||||
}
|
||||
got.regexp.queries, tt.want.regexp.queries = nil, nil
|
||||
}
|
||||
|
||||
// funcs not comparable, just compare nullity
|
||||
if (got.buildVarsFunc == nil) != (tt.want.buildVarsFunc == nil) {
|
||||
t.Errorf("build vars funcs unequal: %v %v", got.buildVarsFunc == nil, tt.want.buildVarsFunc == nil)
|
||||
}
|
||||
got.buildVarsFunc, tt.want.buildVarsFunc = nil, nil
|
||||
|
||||
// finish the deal
|
||||
if !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("route confs unequal: %v %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestMethodNotAllowed(t *testing.T) {
|
||||
handler := func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) }
|
||||
router := NewRouter()
|
||||
router.HandleFunc("/thing", handler).Methods(http.MethodGet)
|
||||
router.HandleFunc("/something", handler).Methods(http.MethodGet)
|
||||
|
||||
w := NewRecorder()
|
||||
req := newRequest(http.MethodPut, "/thing")
|
||||
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != 405 {
|
||||
t.Fatalf("Expected status code 405 (got %d)", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSubrouterNotFound(t *testing.T) {
|
||||
handler := func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) }
|
||||
router := NewRouter()
|
||||
router.Path("/a").Subrouter().HandleFunc("/thing", handler).Methods(http.MethodGet)
|
||||
router.Path("/b").Subrouter().HandleFunc("/something", handler).Methods(http.MethodGet)
|
||||
|
||||
w := NewRecorder()
|
||||
req := newRequest(http.MethodPut, "/not-present")
|
||||
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != 404 {
|
||||
t.Fatalf("Expected status code 404 (got %d)", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
// mapToPairs converts a string map to a slice of string pairs
|
||||
func mapToPairs(m map[string]string) []string {
|
||||
var i int
|
||||
@@ -2362,3 +2819,28 @@ func newRequest(method, url string) *http.Request {
|
||||
}
|
||||
return req
|
||||
}
|
||||
|
||||
// create a new request with the provided headers
|
||||
func newRequestWithHeaders(method, url string, headers ...string) *http.Request {
|
||||
req := newRequest(method, url)
|
||||
|
||||
if len(headers)%2 != 0 {
|
||||
panic(fmt.Sprintf("Expected headers length divisible by 2 but got %v", len(headers)))
|
||||
}
|
||||
|
||||
for i := 0; i < len(headers); i += 2 {
|
||||
req.Header.Set(headers[i], headers[i+1])
|
||||
}
|
||||
|
||||
return req
|
||||
}
|
||||
|
||||
// newRequestHost a new request with a method, url, and host header
|
||||
func newRequestHost(method, url, host string) *http.Request {
|
||||
req, err := http.NewRequest(method, url, nil)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
req.Host = host
|
||||
return req
|
||||
}
|
||||
|
||||
51
regexp.go
51
regexp.go
@@ -113,6 +113,13 @@ func newRouteRegexp(tpl string, typ regexpType, options routeRegexpOptions) (*ro
|
||||
if typ != regexpTypePrefix {
|
||||
pattern.WriteByte('$')
|
||||
}
|
||||
|
||||
var wildcardHostPort bool
|
||||
if typ == regexpTypeHost {
|
||||
if !strings.Contains(pattern.String(), ":") {
|
||||
wildcardHostPort = true
|
||||
}
|
||||
}
|
||||
reverse.WriteString(raw)
|
||||
if endSlash {
|
||||
reverse.WriteByte('/')
|
||||
@@ -131,13 +138,14 @@ func newRouteRegexp(tpl string, typ regexpType, options routeRegexpOptions) (*ro
|
||||
|
||||
// Done!
|
||||
return &routeRegexp{
|
||||
template: template,
|
||||
regexpType: typ,
|
||||
options: options,
|
||||
regexp: reg,
|
||||
reverse: reverse.String(),
|
||||
varsN: varsN,
|
||||
varsR: varsR,
|
||||
template: template,
|
||||
regexpType: typ,
|
||||
options: options,
|
||||
regexp: reg,
|
||||
reverse: reverse.String(),
|
||||
varsN: varsN,
|
||||
varsR: varsR,
|
||||
wildcardHostPort: wildcardHostPort,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -158,11 +166,22 @@ type routeRegexp struct {
|
||||
varsN []string
|
||||
// Variable regexps (validators).
|
||||
varsR []*regexp.Regexp
|
||||
// Wildcard host-port (no strict port match in hostname)
|
||||
wildcardHostPort bool
|
||||
}
|
||||
|
||||
// Match matches the regexp against the URL host or path.
|
||||
func (r *routeRegexp) Match(req *http.Request, match *RouteMatch) bool {
|
||||
if r.regexpType != regexpTypeHost {
|
||||
if r.regexpType == regexpTypeHost {
|
||||
host := getHost(req)
|
||||
if r.wildcardHostPort {
|
||||
// Don't be strict on the port match
|
||||
if i := strings.Index(host, ":"); i != -1 {
|
||||
host = host[:i]
|
||||
}
|
||||
}
|
||||
return r.regexp.MatchString(host)
|
||||
} else {
|
||||
if r.regexpType == regexpTypeQuery {
|
||||
return r.matchQueryString(req)
|
||||
}
|
||||
@@ -172,8 +191,6 @@ func (r *routeRegexp) Match(req *http.Request, match *RouteMatch) bool {
|
||||
}
|
||||
return r.regexp.MatchString(path)
|
||||
}
|
||||
|
||||
return r.regexp.MatchString(getHost(req))
|
||||
}
|
||||
|
||||
// url builds a URL part using the given values.
|
||||
@@ -267,7 +284,7 @@ type routeRegexpGroup struct {
|
||||
}
|
||||
|
||||
// setMatch extracts the variables from the URL once a route matches.
|
||||
func (v *routeRegexpGroup) setMatch(req *http.Request, m *RouteMatch, r *Route) {
|
||||
func (v routeRegexpGroup) setMatch(req *http.Request, m *RouteMatch, r *Route) {
|
||||
// Store host variables.
|
||||
if v.host != nil {
|
||||
host := getHost(req)
|
||||
@@ -296,7 +313,7 @@ func (v *routeRegexpGroup) setMatch(req *http.Request, m *RouteMatch, r *Route)
|
||||
} else {
|
||||
u.Path += "/"
|
||||
}
|
||||
m.Handler = http.RedirectHandler(u.String(), 301)
|
||||
m.Handler = http.RedirectHandler(u.String(), http.StatusMovedPermanently)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -312,17 +329,13 @@ func (v *routeRegexpGroup) setMatch(req *http.Request, m *RouteMatch, r *Route)
|
||||
}
|
||||
|
||||
// getHost tries its best to return the request host.
|
||||
// According to section 14.23 of RFC 2616 the Host header
|
||||
// can include the port number if the default value of 80 is not used.
|
||||
func getHost(r *http.Request) string {
|
||||
if r.URL.IsAbs() {
|
||||
return r.URL.Host
|
||||
}
|
||||
host := r.Host
|
||||
// Slice off any port information.
|
||||
if i := strings.Index(host, ":"); i != -1 {
|
||||
host = host[:i]
|
||||
}
|
||||
return host
|
||||
|
||||
return r.Host
|
||||
}
|
||||
|
||||
func extractVars(input string, matches []int, names []string, output map[string]string) {
|
||||
|
||||
141
route.go
141
route.go
@@ -15,24 +15,8 @@ import (
|
||||
|
||||
// Route stores information to match a request and build URLs.
|
||||
type Route struct {
|
||||
// Parent where the route was registered (a Router).
|
||||
parent parentRoute
|
||||
// Request handler for the route.
|
||||
handler http.Handler
|
||||
// List of matchers.
|
||||
matchers []matcher
|
||||
// Manager for the variables from host and path.
|
||||
regexp *routeRegexpGroup
|
||||
// If true, when the path pattern is "/path/", accessing "/path" will
|
||||
// redirect to the former and vice versa.
|
||||
strictSlash bool
|
||||
// If true, when the path pattern is "/path//to", accessing "/path//to"
|
||||
// will not redirect
|
||||
skipClean bool
|
||||
// If true, "/path/foo%2Fbar/to" will match the path "/path/{var}/to"
|
||||
useEncodedPath bool
|
||||
// The scheme used when building URLs.
|
||||
buildScheme string
|
||||
// If true, this route never matches: it is only used to build URLs.
|
||||
buildOnly bool
|
||||
// The name used to build URLs.
|
||||
@@ -40,7 +24,11 @@ type Route struct {
|
||||
// Error resulted from building a route.
|
||||
err error
|
||||
|
||||
buildVarsFunc BuildVarsFunc
|
||||
// "global" reference to all named routes
|
||||
namedRoutes map[string]*Route
|
||||
|
||||
// config possibly passed in from `Router`
|
||||
routeConf
|
||||
}
|
||||
|
||||
// SkipClean reports whether path cleaning is enabled for this route via
|
||||
@@ -64,6 +52,18 @@ func (r *Route) Match(req *http.Request, match *RouteMatch) bool {
|
||||
matchErr = ErrMethodMismatch
|
||||
continue
|
||||
}
|
||||
|
||||
// Ignore ErrNotFound errors. These errors arise from match call
|
||||
// to Subrouters.
|
||||
//
|
||||
// This prevents subsequent matching subrouters from failing to
|
||||
// run middleware. If not ignored, the middleware would see a
|
||||
// non-nil MatchErr and be skipped, even when there was a
|
||||
// matching route.
|
||||
if match.MatchErr == ErrNotFound {
|
||||
match.MatchErr = nil
|
||||
}
|
||||
|
||||
matchErr = nil
|
||||
return false
|
||||
}
|
||||
@@ -93,9 +93,7 @@ func (r *Route) Match(req *http.Request, match *RouteMatch) bool {
|
||||
}
|
||||
|
||||
// Set variables.
|
||||
if r.regexp != nil {
|
||||
r.regexp.setMatch(req, match, r)
|
||||
}
|
||||
r.regexp.setMatch(req, match, r)
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -137,7 +135,7 @@ func (r *Route) GetHandler() http.Handler {
|
||||
// Name -----------------------------------------------------------------------
|
||||
|
||||
// Name sets the name for the route, used to build URLs.
|
||||
// If the name was registered already it will be overwritten.
|
||||
// It is an error to call Name more than once on a route.
|
||||
func (r *Route) Name(name string) *Route {
|
||||
if r.name != "" {
|
||||
r.err = fmt.Errorf("mux: route already has name %q, can't set %q",
|
||||
@@ -145,7 +143,7 @@ func (r *Route) Name(name string) *Route {
|
||||
}
|
||||
if r.err == nil {
|
||||
r.name = name
|
||||
r.getNamedRoutes()[name] = r
|
||||
r.namedRoutes[name] = r
|
||||
}
|
||||
return r
|
||||
}
|
||||
@@ -177,7 +175,6 @@ func (r *Route) addRegexpMatcher(tpl string, typ regexpType) error {
|
||||
if r.err != nil {
|
||||
return r.err
|
||||
}
|
||||
r.regexp = r.getRegexpGroup()
|
||||
if typ == regexpTypePath || typ == regexpTypePrefix {
|
||||
if len(tpl) > 0 && tpl[0] != '/' {
|
||||
return fmt.Errorf("mux: path must start with a slash, got %q", tpl)
|
||||
@@ -386,7 +383,7 @@ func (r *Route) PathPrefix(tpl string) *Route {
|
||||
// The above route will only match if the URL contains the defined queries
|
||||
// values, e.g.: ?foo=bar&id=42.
|
||||
//
|
||||
// It the value is an empty string, it will match any value if the key is set.
|
||||
// If the value is an empty string, it will match any value if the key is set.
|
||||
//
|
||||
// Variables can define an optional regexp pattern to be matched:
|
||||
//
|
||||
@@ -424,7 +421,7 @@ func (r *Route) Schemes(schemes ...string) *Route {
|
||||
for k, v := range schemes {
|
||||
schemes[k] = strings.ToLower(v)
|
||||
}
|
||||
if r.buildScheme == "" && len(schemes) > 0 {
|
||||
if len(schemes) > 0 {
|
||||
r.buildScheme = schemes[0]
|
||||
}
|
||||
return r.addMatcher(schemeMatcher(schemes))
|
||||
@@ -439,7 +436,15 @@ type BuildVarsFunc func(map[string]string) map[string]string
|
||||
// BuildVarsFunc adds a custom function to be used to modify build variables
|
||||
// before a route's URL is built.
|
||||
func (r *Route) BuildVarsFunc(f BuildVarsFunc) *Route {
|
||||
r.buildVarsFunc = f
|
||||
if r.buildVarsFunc != nil {
|
||||
// compose the old and new functions
|
||||
old := r.buildVarsFunc
|
||||
r.buildVarsFunc = func(m map[string]string) map[string]string {
|
||||
return f(old(m))
|
||||
}
|
||||
} else {
|
||||
r.buildVarsFunc = f
|
||||
}
|
||||
return r
|
||||
}
|
||||
|
||||
@@ -458,7 +463,8 @@ func (r *Route) BuildVarsFunc(f BuildVarsFunc) *Route {
|
||||
// Here, the routes registered in the subrouter won't be tested if the host
|
||||
// doesn't match.
|
||||
func (r *Route) Subrouter() *Router {
|
||||
router := &Router{parent: r, strictSlash: r.strictSlash}
|
||||
// initialize a subrouter with a copy of the parent route's configuration
|
||||
router := &Router{routeConf: copyRouteConf(r.routeConf), namedRoutes: r.namedRoutes}
|
||||
r.addMatcher(router)
|
||||
return router
|
||||
}
|
||||
@@ -502,9 +508,6 @@ func (r *Route) URL(pairs ...string) (*url.URL, error) {
|
||||
if r.err != nil {
|
||||
return nil, r.err
|
||||
}
|
||||
if r.regexp == nil {
|
||||
return nil, errors.New("mux: route doesn't have a host or path")
|
||||
}
|
||||
values, err := r.prepareVars(pairs...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -516,8 +519,8 @@ func (r *Route) URL(pairs ...string) (*url.URL, error) {
|
||||
return nil, err
|
||||
}
|
||||
scheme = "http"
|
||||
if s := r.getBuildScheme(); s != "" {
|
||||
scheme = s
|
||||
if r.buildScheme != "" {
|
||||
scheme = r.buildScheme
|
||||
}
|
||||
}
|
||||
if r.regexp.path != nil {
|
||||
@@ -547,7 +550,7 @@ func (r *Route) URLHost(pairs ...string) (*url.URL, error) {
|
||||
if r.err != nil {
|
||||
return nil, r.err
|
||||
}
|
||||
if r.regexp == nil || r.regexp.host == nil {
|
||||
if r.regexp.host == nil {
|
||||
return nil, errors.New("mux: route doesn't have a host")
|
||||
}
|
||||
values, err := r.prepareVars(pairs...)
|
||||
@@ -562,8 +565,8 @@ func (r *Route) URLHost(pairs ...string) (*url.URL, error) {
|
||||
Scheme: "http",
|
||||
Host: host,
|
||||
}
|
||||
if s := r.getBuildScheme(); s != "" {
|
||||
u.Scheme = s
|
||||
if r.buildScheme != "" {
|
||||
u.Scheme = r.buildScheme
|
||||
}
|
||||
return u, nil
|
||||
}
|
||||
@@ -575,7 +578,7 @@ func (r *Route) URLPath(pairs ...string) (*url.URL, error) {
|
||||
if r.err != nil {
|
||||
return nil, r.err
|
||||
}
|
||||
if r.regexp == nil || r.regexp.path == nil {
|
||||
if r.regexp.path == nil {
|
||||
return nil, errors.New("mux: route doesn't have a path")
|
||||
}
|
||||
values, err := r.prepareVars(pairs...)
|
||||
@@ -600,7 +603,7 @@ func (r *Route) GetPathTemplate() (string, error) {
|
||||
if r.err != nil {
|
||||
return "", r.err
|
||||
}
|
||||
if r.regexp == nil || r.regexp.path == nil {
|
||||
if r.regexp.path == nil {
|
||||
return "", errors.New("mux: route doesn't have a path")
|
||||
}
|
||||
return r.regexp.path.template, nil
|
||||
@@ -614,7 +617,7 @@ func (r *Route) GetPathRegexp() (string, error) {
|
||||
if r.err != nil {
|
||||
return "", r.err
|
||||
}
|
||||
if r.regexp == nil || r.regexp.path == nil {
|
||||
if r.regexp.path == nil {
|
||||
return "", errors.New("mux: route does not have a path")
|
||||
}
|
||||
return r.regexp.path.regexp.String(), nil
|
||||
@@ -629,7 +632,7 @@ func (r *Route) GetQueriesRegexp() ([]string, error) {
|
||||
if r.err != nil {
|
||||
return nil, r.err
|
||||
}
|
||||
if r.regexp == nil || r.regexp.queries == nil {
|
||||
if r.regexp.queries == nil {
|
||||
return nil, errors.New("mux: route doesn't have queries")
|
||||
}
|
||||
var queries []string
|
||||
@@ -648,7 +651,7 @@ func (r *Route) GetQueriesTemplates() ([]string, error) {
|
||||
if r.err != nil {
|
||||
return nil, r.err
|
||||
}
|
||||
if r.regexp == nil || r.regexp.queries == nil {
|
||||
if r.regexp.queries == nil {
|
||||
return nil, errors.New("mux: route doesn't have queries")
|
||||
}
|
||||
var queries []string
|
||||
@@ -683,7 +686,7 @@ func (r *Route) GetHostTemplate() (string, error) {
|
||||
if r.err != nil {
|
||||
return "", r.err
|
||||
}
|
||||
if r.regexp == nil || r.regexp.host == nil {
|
||||
if r.regexp.host == nil {
|
||||
return "", errors.New("mux: route doesn't have a host")
|
||||
}
|
||||
return r.regexp.host.template, nil
|
||||
@@ -700,64 +703,8 @@ func (r *Route) prepareVars(pairs ...string) (map[string]string, error) {
|
||||
}
|
||||
|
||||
func (r *Route) buildVars(m map[string]string) map[string]string {
|
||||
if r.parent != nil {
|
||||
m = r.parent.buildVars(m)
|
||||
}
|
||||
if r.buildVarsFunc != nil {
|
||||
m = r.buildVarsFunc(m)
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// parentRoute
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
// parentRoute allows routes to know about parent host and path definitions.
|
||||
type parentRoute interface {
|
||||
getBuildScheme() string
|
||||
getNamedRoutes() map[string]*Route
|
||||
getRegexpGroup() *routeRegexpGroup
|
||||
buildVars(map[string]string) map[string]string
|
||||
}
|
||||
|
||||
func (r *Route) getBuildScheme() string {
|
||||
if r.buildScheme != "" {
|
||||
return r.buildScheme
|
||||
}
|
||||
if r.parent != nil {
|
||||
return r.parent.getBuildScheme()
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// getNamedRoutes returns the map where named routes are registered.
|
||||
func (r *Route) getNamedRoutes() map[string]*Route {
|
||||
if r.parent == nil {
|
||||
// During tests router is not always set.
|
||||
r.parent = NewRouter()
|
||||
}
|
||||
return r.parent.getNamedRoutes()
|
||||
}
|
||||
|
||||
// getRegexpGroup returns regexp definitions from this route.
|
||||
func (r *Route) getRegexpGroup() *routeRegexpGroup {
|
||||
if r.regexp == nil {
|
||||
if r.parent == nil {
|
||||
// During tests router is not always set.
|
||||
r.parent = NewRouter()
|
||||
}
|
||||
regexp := r.parent.getRegexpGroup()
|
||||
if regexp == nil {
|
||||
r.regexp = new(routeRegexpGroup)
|
||||
} else {
|
||||
// Copy.
|
||||
r.regexp = &routeRegexpGroup{
|
||||
host: regexp.host,
|
||||
path: regexp.path,
|
||||
queries: regexp.queries,
|
||||
}
|
||||
}
|
||||
}
|
||||
return r.regexp
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user