53 Commits

Author SHA1 Message Date
Matt Silverlock
00bdffe0f3 Update stale.yml (#494) 2019-06-29 21:17:52 -07:00
Franklin Harding
0534769016 Improve CORS Method Middleware (#477)
* More sensical CORSMethodMiddleware

* Only sets Access-Control-Allow-Methods on valid preflight requests
* Does not return after setting the Access-Control-Allow-Methods header
* Does not append OPTIONS header to Access-Control-Allow-Methods
regardless of whether there is an OPTIONS method matcher
* Adds tests for the listed behavior

* Add example for CORSMethodMiddleware

* Do not check for preflight and add documentation to the README

* Use http.MethodOptions instead of "OPTIONS"

* Add link to CORSMethodMiddleware section to readme

* Add test for unmatching route methods

* Rename CORS Method Middleware to Handling CORS Requests in README

* Link CORSMethodMiddleware in README to godoc

* Break CORSMethodMiddleware doc into bullets for readability

* Add comment about specifying OPTIONS to example in README for CORSMethodMiddleware

* Document cURL command used for testing CORS Method Middleware

* Update comment in example to "Handle the request"

* Add explicit comment about OPTIONS matchers to CORSMethodMiddleware doc

* Update circleci config to only check gofmt diff on latest go version

* Break up gofmt and go vet checks into separate steps.

* Use canonical circleci config
2019-06-29 13:52:29 -07:00
Matt Silverlock
d70f7b4baa Delete ISSUE_TEMPLATE.md (#492) 2019-06-29 11:01:59 -07:00
Franklin Harding
48f941fa99 Use subtests for middleware tests (#478)
* Use subtests for middleware tests
* Don't use subtests for MiddlewareAdd
2019-06-29 10:24:12 -07:00
Matt Silverlock
64954673e9 Delete .travis.yml (#490) 2019-06-28 16:07:30 -07:00
Franklin Harding
4248f5cd87 Fix nil panic in authentication middleware example (#489) 2019-06-28 08:33:07 -07:00
Matt Silverlock
212aa90d7c [WIP] Create CircleCI config (#484)
* [ci] Create CircleCI config
* Fix typos in container versions
* Add CircleCI badge
2019-06-24 09:05:39 -07:00
M@
ed099d4238 host:port matching does not require a :port to be specified.
In lieu of checking the template pattern on every Match request, a bool is added to the routeRegexp, and set
if the routeRegexp is a host AND there is no ":" in the template. I dislike extending the type, but I'd dislike
doing a string match on every single Match, even more.
2019-05-16 17:20:44 -07:00
sekky0905
c5c6c98bc2 [build] Remove sudo setting from travis.yml (#462) 2019-03-16 06:32:43 -07:00
Benjamin Boudreau
15a353a636 adding Router.Name to create new Route (#457) 2019-02-28 10:12:03 -08:00
Benjamin Boudreau
8eaa9f1309 fix go1.12 go vet usage (#458) 2019-02-28 09:36:07 -08:00
Souvik Haldar
8559a4f775 [docs] typo (#454) 2019-02-17 07:38:49 -08:00
moeryomenko
a7962380ca replace rr.HeaderMap by rr.Header() (#443) 2019-01-25 10:05:53 -06:00
Tim
797e653da6 Call WriteHeader after setting other header(s) in the example (#442)
From the docs: Changing the header map after a call to WriteHeader (or
Write) has no effect unless the modified headers are
trailers.
2019-01-25 05:41:49 -06:00
Gregor Weckbecker
08e7f807d3 Ignore ErrNotFound while matching Subrouters (#438)
MatchErr is set by the router to ErrNotFound if no route matches. If
no route of a Subrouter matches the error can by safely ignored. This
implementation only ignores these errors and does not ignore other
errors like ErrMethodMismatch.
2019-01-08 08:29:30 -06:00
santsai
f3ff42f93a getHost() now returns full host & port information (#383)
Previously, getHost only returned the host. As it now returns the
port as well, any .Host matches on a route will need to be updated
to also support matching on the port for cases where the port is
non default, eg: 80 for http or 443 for https.
2019-01-04 07:08:45 -08:00
tomare
ef912dd76e [bugfix] Clear matchErr when traversing subrouters.
Previously, when searching for a match, matchErr would be erroneously set, and prevent middleware from running (no match == no middleware runs).

This fix clears matchErr before traversing the next subrouter in a multi-subrouter router.
2018-12-27 16:42:16 -08:00
Raees
a31c1782bf Replace domain.com with example.com (#434)
Because domain.com is an actual business, example.com should be used for example purposes.
2018-12-25 08:41:17 -08:00
Michael Li
6137e193cd remove redundant code that remove support gorilla/context (#427)
* remove redundant code that remove support gorilla/context

* backward compatible for remove redundant code
2018-12-17 09:42:43 -05:00
Matt Silverlock
d2b5d13b92 Update and rename stale to stale.yml (#425) 2018-12-08 12:40:53 -08:00
Matt Silverlock
419fd9fe2a Add stalebot config (#424) 2018-12-07 08:41:48 -08:00
Joe Wilner
758eb64354 Improve subroute configuration propagation #422
* Pull out common shared `routeConf` so that config is pushed on to child
routers and routes.
* Removes obsolete usages of `parentRoute`
* Add tests defining compositional behavior
* Exercise `copyRouteConf` for posterity
2018-12-07 09:48:26 -06:00
kanozec
3d80bc801b Use subtests in mux_test.go (#415) 2018-10-30 08:25:28 -07:00
Nguyen Ngoc Trung (Steven)
521ea7b17d Use constant for 301 status code in regexp.go (#412) 2018-10-23 19:08:00 -07:00
Kamil Kisiel
deb579d6e0 README.md: Update site URL 2018-10-12 08:31:51 -07:00
Matt Silverlock
9e1f5955c0 Always run on the latest stable Go version. (#402)
Only run vet on the latest Go version.
2018-09-03 08:43:05 -07:00
Matt Silverlock
cf6680bc62 Create release-drafter.yml (#399) 2018-09-02 15:36:45 -07:00
Franklin Harding
8771f97498 Drop support for Go < 1.7: remove gorilla/context (#391)
* Drop support for Go < 1.7: remove gorilla/context
* Remove Go < 1.7 from Travis CI config
* Remove unneeded _native from context files
2018-09-02 15:22:40 -07:00
Shalom Yerushalmy
962c5bed07 Add 1.11 to build in travis (#398) 2018-08-30 07:23:24 -07:00
Kamil Kisiel
e48e440e4c Add test for multiple calls to Name().
Fixes #394
2018-08-07 00:52:56 -07:00
Kamil Kisiel
815b8c6a26 Clarify behaviour of Name method if called multiple times. 2018-08-07 00:50:18 -07:00
Matt Silverlock
cb4698366a Update LICENSE & AUTHORS files. (#386) 2018-06-05 14:15:56 -07:00
Jim Kalafut
e0b5abaaae Initialize user map (#371) 2018-05-26 15:17:21 -07:00
Matt Silverlock
c85619274f [deps] Add go.mod for versioned Go (#376) 2018-05-17 10:36:23 -07:00
Matt Silverlock
e3702bed27 [docs] Improve docstrings for middleware, skipclean (#375) 2018-05-12 20:22:33 -07:00
Sean Walberg
fdeb7bc314 [docs] Doc fix for testing variables in path (#374)
The example in the README does not pass the request through a mux therefore the request variables from the path are never populated. Update the sample to create a minimum viable router to use.

Fixes #373
2018-05-12 20:09:30 -07:00
Franklin Harding
5e55a4adb8 Add CORSMethodMiddleware (#366)
CORSMethodMiddleware sets the Access-Control-Allow-Methods response header
on a request, by matching routes based only on paths. It also handles
OPTIONS requests, by settings Access-Control-Allow-Methods, and then
returning without calling the next HTTP handler.
2018-05-11 18:30:14 -07:00
Matt Silverlock
ded0c29b24 Fix linter issues (docs) (#370) 2018-04-30 20:11:36 -07:00
Matt Silverlock
b57cb1605f [build] Update Go versions; add 1.10.x (#364) 2018-04-16 13:45:19 -07:00
brandon-height
94231ffd98 Fix table-driven example documentation (#363)
Prior to this change, the example documentation
found in the README.md has an errant code which
won't work in the table-driven code example.

This change modifies the variable name from `t` to `tc`
so it does not conflict with the `t *testing.T` struct
definition.

* Adds a range clause to the `for` statement
* Modifies `for` statement scope to use `tc.shouldPass`, and `tc.routeVariable`

Doc: https://github.com/gorilla/mux#testing-handlers
2018-04-03 11:23:30 -07:00
Johan Svensson
4dbd923b0c Make Use() variadic (#355)
Enables neater syntax when chaining several middleware functions.
2018-03-14 09:31:26 -07:00
Geon Kim
07ba1fd60e Modify http status code to variable in README (#350)
* Modify http status code to variable

* Modify doc

* Modify README
2018-02-25 21:11:51 -08:00
Geon Kim
d284fd8421 Modify 403 status code to const variable (#349)
* Modify http status code to variable

* Modify doc
2018-02-25 08:08:54 -08:00
Kamil Kisiel
c0091a0299 Create authentication middleware example. (#340)
* Create authentication middleware example.

For #339

* Fix example test filename.
2018-01-19 23:58:19 -08:00
Franklin Harding
0fdf828bb2 [docs] Clarify SetURLVars (#335)
* [docs] Clarify SetURLVars

Clarify in documentation that SetURLVars does not modify the given
*htttp.Request, provide an example of usage.

* Short and sweet function doc, example test.
2018-01-19 22:28:49 -08:00
Kamil Kisiel
077b44c2cf [docs] Document route.Get* methods consistently (#338)
They actually return an error instead of an empty list. `GetMethods` happened to not return an error, but it should for consistency, so I added that as well.
2018-01-19 20:51:41 -08:00
Kamil Kisiel
dc83507598 [docs] README.md: Improve "walking routes" example. (#337) (#323)
Fixes #323.

Also removed the duplicate "listing routes" example.
2018-01-19 20:47:48 -08:00
safeoy
3dbb9ed96e README.md: add miss "time" (#336) 2018-01-19 20:20:16 -08:00
Matt Silverlock
ad8790881f [docs] Fix doc.go (#333)
Addresses https://github.com/gorilla/mux/pull/294#discussion_r162309666
2018-01-18 09:53:57 -08:00
Matt Silverlock
69dae3b874 [docs] Add testing example (#331) 2018-01-16 23:16:36 -08:00
Matt Silverlock
63c5c2f1f0 [docs] Fix Middleware docs typos (#332) 2018-01-16 23:16:06 -08:00
Kamil Kisiel
85e6bfff1a Update doc.go: r.AddMiddleware(...) -> r.Use(...) 2018-01-16 17:18:53 -08:00
Kush Mansingh
0b74e3d0fe Make shutdown docs compilable (#330) 2018-01-16 14:43:47 -08:00
23 changed files with 1609 additions and 628 deletions

75
.circleci/config.yml Normal file
View File

@@ -0,0 +1,75 @@
version: 2.0
jobs:
# Base test configuration for Go library tests Each distinct version should
# inherit this base, and override (at least) the container image used.
"test": &test
docker:
- image: circleci/golang:latest
working_directory: /go/src/github.com/gorilla/mux
steps: &steps
- checkout
- run: go version
- run: go get -t -v ./...
# Only run gofmt, vet & lint against the latest Go version
- run: >
if [[ "$LATEST" = true ]]; then
go get -u golang.org/x/lint/golint
golint ./...
fi
- run: >
if [[ "$LATEST" = true ]]; then
diff -u <(echo -n) <(gofmt -d .)
fi
- run: >
if [[ "$LATEST" = true ]]; then
go vet -v .
fi
- run: go test -v -race ./...
"latest":
<<: *test
environment:
LATEST: true
"1.12":
<<: *test
docker:
- image: circleci/golang:1.12
"1.11":
<<: *test
docker:
- image: circleci/golang:1.11
"1.10":
<<: *test
docker:
- image: circleci/golang:1.10
"1.9":
<<: *test
docker:
- image: circleci/golang:1.9
"1.8":
<<: *test
docker:
- image: circleci/golang:1.8
"1.7":
<<: *test
docker:
- image: circleci/golang:1.7
workflows:
version: 2
build:
jobs:
- "latest"
- "1.12"
- "1.11"
- "1.10"
- "1.9"
- "1.8"
- "1.7"

8
.github/release-drafter.yml vendored Normal file
View File

@@ -0,0 +1,8 @@
# Config for https://github.com/apps/release-drafter
template: |
<summary of changes here>
## CHANGELOG
$CHANGES

12
.github/stale.yml vendored Normal file
View File

@@ -0,0 +1,12 @@
daysUntilStale: 75
daysUntilClose: 14
# Issues with these labels will never be considered stale
exemptLabels:
- proposal
- needs review
- build system
staleLabel: stale
markComment: >
This issue has been automatically marked as stale because it hasn't seen
a recent update. It'll be automatically closed in a few days.
closeComment: false

View File

@@ -1,22 +0,0 @@
language: go
sudo: false
matrix:
include:
- go: 1.5
- go: 1.6
- go: 1.7
- go: 1.8
- go: 1.9
- go: tip
allow_failures:
- go: tip
install:
- # Skip
script:
- go get -t -v ./...
- diff -u <(echo -n) <(gofmt -d .)
- go tool vet .
- go test -v -race ./...

8
AUTHORS Normal file
View File

@@ -0,0 +1,8 @@
# This is the official list of gorilla/mux authors for copyright purposes.
#
# Please keep the list sorted.
Google LLC (https://opensource.google.com/)
Kamil Kisielk <kamil@kamilkisiel.net>
Matt Silverlock <matt@eatsleeprepeat.net>
Rodrigo Moraes (https://github.com/moraes)

View File

@@ -1,11 +0,0 @@
**What version of Go are you running?** (Paste the output of `go version`)
**What version of gorilla/mux are you at?** (Paste the output of `git rev-parse HEAD` inside `$GOPATH/src/github.com/gorilla/mux`)
**Describe your problem** (and what you have tried so far)
**Paste a minimal, runnable, reproduction of your issue below** (use backticks to format it)

View File

@@ -1,4 +1,4 @@
Copyright (c) 2012 Rodrigo Moraes. All rights reserved.
Copyright (c) 2012-2018 The Gorilla Authors. All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are

382
README.md
View File

@@ -1,12 +1,13 @@
gorilla/mux
===
# gorilla/mux
[![GoDoc](https://godoc.org/github.com/gorilla/mux?status.svg)](https://godoc.org/github.com/gorilla/mux)
[![Build Status](https://travis-ci.org/gorilla/mux.svg?branch=master)](https://travis-ci.org/gorilla/mux)
[![CircleCI](https://circleci.com/gh/gorilla/mux.svg?style=svg)](https://circleci.com/gh/gorilla/mux)
[![Sourcegraph](https://sourcegraph.com/github.com/gorilla/mux/-/badge.svg)](https://sourcegraph.com/github.com/gorilla/mux?badge)
![Gorilla Logo](http://www.gorillatoolkit.org/static/images/gorilla-icon-64.png)
http://www.gorillatoolkit.org/pkg/mux
https://www.gorillatoolkit.org/pkg/mux
Package `gorilla/mux` implements a request router and dispatcher for matching incoming requests to
their respective handler.
@@ -29,6 +30,8 @@ The name mux stands for "HTTP request multiplexer". Like the standard `http.Serv
* [Walking Routes](#walking-routes)
* [Graceful Shutdown](#graceful-shutdown)
* [Middleware](#middleware)
* [Handling CORS Requests](#handling-cors-requests)
* [Testing Handlers](#testing-handlers)
* [Full Example](#full-example)
---
@@ -87,7 +90,7 @@ r := mux.NewRouter()
// Only matches if domain is "www.example.com".
r.Host("www.example.com")
// Matches a dynamic subdomain.
r.Host("{subdomain:[a-z]+}.domain.com")
r.Host("{subdomain:[a-z]+}.example.com")
```
There are several other matchers that can be added. To match path prefixes:
@@ -178,70 +181,13 @@ s.HandleFunc("/{key}/", ProductHandler)
// "/products/{key}/details"
s.HandleFunc("/{key}/details", ProductDetailsHandler)
```
### Listing Routes
Routes on a mux can be listed using the Router.Walk method—useful for generating documentation:
```go
package main
import (
"fmt"
"net/http"
"strings"
"github.com/gorilla/mux"
)
func handler(w http.ResponseWriter, r *http.Request) {
return
}
func main() {
r := mux.NewRouter()
r.HandleFunc("/", handler)
r.HandleFunc("/products", handler).Methods("POST")
r.HandleFunc("/articles", handler).Methods("GET")
r.HandleFunc("/articles/{id}", handler).Methods("GET", "PUT")
r.HandleFunc("/authors", handler).Queries("surname", "{surname}")
r.Walk(func(route *mux.Route, router *mux.Router, ancestors []*mux.Route) error {
t, err := route.GetPathTemplate()
if err != nil {
return err
}
qt, err := route.GetQueriesTemplates()
if err != nil {
return err
}
// p will contain regular expression is compatible with regular expression in Perl, Python, and other languages.
// for instance the regular expression for path '/articles/{id}' will be '^/articles/(?P<v0>[^/]+)$'
p, err := route.GetPathRegexp()
if err != nil {
return err
}
// qr will contain a list of regular expressions with the same semantics as GetPathRegexp,
// just applied to the Queries pairs instead, e.g., 'Queries("surname", "{surname}") will return
// {"^surname=(?P<v0>.*)$}. Where each combined query pair will have an entry in the list.
qr, err := route.GetQueriesRegexp()
if err != nil {
return err
}
m, err := route.GetMethods()
if err != nil {
return err
}
fmt.Println(strings.Join(m, ","), strings.Join(qt, ","), strings.Join(qr, ","), t, p)
return nil
})
http.Handle("/", r)
}
```
### Static Files
Note that the path provided to `PathPrefix()` represents a "wildcard": calling
`PathPrefix("/static/").Handler(...)` means that the handler will be passed any
request that matches "/static/*". This makes it easy to serve static files with mux:
request that matches "/static/\*". This makes it easy to serve static files with mux:
```go
func main() {
@@ -294,13 +240,13 @@ This also works for host and query value variables:
```go
r := mux.NewRouter()
r.Host("{subdomain}.domain.com").
r.Host("{subdomain}.example.com").
Path("/articles/{category}/{id:[0-9]+}").
Queries("filter", "{filter}").
HandlerFunc(ArticleHandler).
Name("article")
// url.String() will be "http://news.domain.com/articles/technology/42?filter=gorilla"
// url.String() will be "http://news.example.com/articles/technology/42?filter=gorilla"
url, err := r.Get("article").URL("subdomain", "news",
"category", "technology",
"id", "42",
@@ -320,7 +266,7 @@ r.HeadersRegexp("Content-Type", "application/(text|json)")
There's also a way to build only the URL host or path for a route: use the methods `URLHost()` or `URLPath()` instead. For the previous route, we would do:
```go
// "http://news.domain.com/"
// "http://news.example.com/"
host, err := r.Get("article").URLHost("subdomain", "news")
// "/articles/technology/42"
@@ -331,12 +277,12 @@ And if you use subrouters, host and path defined separately can be built as well
```go
r := mux.NewRouter()
s := r.Host("{subdomain}.domain.com").Subrouter()
s := r.Host("{subdomain}.example.com").Subrouter()
s.Path("/articles/{category}/{id:[0-9]+}").
HandlerFunc(ArticleHandler).
Name("article")
// "http://news.domain.com/articles/technology/42"
// "http://news.example.com/articles/technology/42"
url, err := r.Get("article").URL("subdomain", "news",
"category", "technology",
"id", "42")
@@ -348,41 +294,58 @@ The `Walk` function on `mux.Router` can be used to visit all of the routes that
the following prints all of the registered routes:
```go
r := mux.NewRouter()
r.HandleFunc("/", handler)
r.HandleFunc("/products", handler).Methods("POST")
r.HandleFunc("/articles", handler).Methods("GET")
r.HandleFunc("/articles/{id}", handler).Methods("GET", "PUT")
r.HandleFunc("/authors", handler).Queries("surname", "{surname}")
r.Walk(func(route *mux.Route, router *mux.Router, ancestors []*mux.Route) error {
t, err := route.GetPathTemplate()
if err != nil {
return err
}
qt, err := route.GetQueriesTemplates()
if err != nil {
return err
}
// p will contain a regular expression that is compatible with regular expressions in Perl, Python, and other languages.
// For example, the regular expression for path '/articles/{id}' will be '^/articles/(?P<v0>[^/]+)$'.
p, err := route.GetPathRegexp()
if err != nil {
return err
}
// qr will contain a list of regular expressions with the same semantics as GetPathRegexp,
// just applied to the Queries pairs instead, e.g., 'Queries("surname", "{surname}") will return
// {"^surname=(?P<v0>.*)$}. Where each combined query pair will have an entry in the list.
qr, err := route.GetQueriesRegexp()
if err != nil {
return err
}
m, err := route.GetMethods()
if err != nil {
return err
}
fmt.Println(strings.Join(m, ","), strings.Join(qt, ","), strings.Join(qr, ","), t, p)
return nil
})
package main
import (
"fmt"
"net/http"
"strings"
"github.com/gorilla/mux"
)
func handler(w http.ResponseWriter, r *http.Request) {
return
}
func main() {
r := mux.NewRouter()
r.HandleFunc("/", handler)
r.HandleFunc("/products", handler).Methods("POST")
r.HandleFunc("/articles", handler).Methods("GET")
r.HandleFunc("/articles/{id}", handler).Methods("GET", "PUT")
r.HandleFunc("/authors", handler).Queries("surname", "{surname}")
err := r.Walk(func(route *mux.Route, router *mux.Router, ancestors []*mux.Route) error {
pathTemplate, err := route.GetPathTemplate()
if err == nil {
fmt.Println("ROUTE:", pathTemplate)
}
pathRegexp, err := route.GetPathRegexp()
if err == nil {
fmt.Println("Path regexp:", pathRegexp)
}
queriesTemplates, err := route.GetQueriesTemplates()
if err == nil {
fmt.Println("Queries templates:", strings.Join(queriesTemplates, ","))
}
queriesRegexps, err := route.GetQueriesRegexp()
if err == nil {
fmt.Println("Queries regexps:", strings.Join(queriesRegexps, ","))
}
methods, err := route.GetMethods()
if err == nil {
fmt.Println("Methods:", strings.Join(methods, ","))
}
fmt.Println()
return nil
})
if err != nil {
fmt.Println(err)
}
http.Handle("/", r)
}
```
### Graceful Shutdown
@@ -399,6 +362,7 @@ import (
"net/http"
"os"
"os/signal"
"time"
"github.com/gorilla/mux"
)
@@ -410,7 +374,7 @@ func main() {
r := mux.NewRouter()
// Add your routes as needed
srv := &http.Server{
Addr: "0.0.0.0:8080",
// Good practice to set timeouts to avoid Slowloris attacks.
@@ -426,7 +390,7 @@ func main() {
log.Println(err)
}
}()
c := make(chan os.Signal, 1)
// We'll accept graceful shutdowns when quit via SIGINT (Ctrl+C)
// SIGKILL, SIGQUIT or SIGTERM (Ctrl+/) will not be caught.
@@ -436,7 +400,8 @@ func main() {
<-c
// Create a deadline to wait for.
ctx, cancel := context.WithTimeout(ctx, wait)
ctx, cancel := context.WithTimeout(context.Background(), wait)
defer cancel()
// Doesn't block if no connections, but will otherwise wait
// until the timeout deadline.
srv.Shutdown(ctx)
@@ -464,7 +429,7 @@ Typically, the returned handler is a closure which does something with the http.
A very basic middleware which logs the URI of the request being handled could be written as:
```go
func simpleMw(next http.Handler) http.Handler {
func loggingMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Do stuff here
log.Println(r.RequestURI)
@@ -474,12 +439,12 @@ func simpleMw(next http.Handler) http.Handler {
}
```
Middlewares can be added to a router using `Router.AddMiddlewareFunc()`:
Middlewares can be added to a router using `Router.Use()`:
```go
r := mux.NewRouter()
r.HandleFunc("/", handler)
r.AddMiddleware(simpleMw)
r.Use(loggingMiddleware)
```
A more complex authentication middleware, which maps session token to users, could be written as:
@@ -502,7 +467,7 @@ func (amw *authenticationMiddleware) Populate() {
func (amw *authenticationMiddleware) Middleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
token := r.Header.Get("X-Session-Token")
if user, found := amw.tokenUsers[token]; found {
// We found the token in our map
log.Printf("Authenticated user %s\n", user)
@@ -510,7 +475,7 @@ func (amw *authenticationMiddleware) Middleware(next http.Handler) http.Handler
next.ServeHTTP(w, r)
} else {
// Write an error and stop the handler chain
http.Error(w, "Forbidden", 403)
http.Error(w, "Forbidden", http.StatusForbidden)
}
})
}
@@ -523,10 +488,203 @@ r.HandleFunc("/", handler)
amw := authenticationMiddleware{}
amw.Populate()
r.AddMiddlewareFunc(amw.Middleware)
r.Use(amw.Middleware)
```
Note: The handler chain will be stopped if your middleware doesn't call `next.ServeHTTP()` with the corresponding parameters. This can be used to abort a request if the middleware writer wants to. Middlewares *should* write to `ResponseWriter` if they *are* going to terminate the request, and they *should not* write to `ResponseWriter` if they *are not* going to terminate it.
Note: The handler chain will be stopped if your middleware doesn't call `next.ServeHTTP()` with the corresponding parameters. This can be used to abort a request if the middleware writer wants to. Middlewares _should_ write to `ResponseWriter` if they _are_ going to terminate the request, and they _should not_ write to `ResponseWriter` if they _are not_ going to terminate it.
### Handling CORS Requests
[CORSMethodMiddleware](https://godoc.org/github.com/gorilla/mux#CORSMethodMiddleware) intends to make it easier to strictly set the `Access-Control-Allow-Methods` response header.
* You will still need to use your own CORS handler to set the other CORS headers such as `Access-Control-Allow-Origin`
* The middleware will set the `Access-Control-Allow-Methods` header to all the method matchers (e.g. `r.Methods(http.MethodGet, http.MethodPut, http.MethodOptions)` -> `Access-Control-Allow-Methods: GET,PUT,OPTIONS`) on a route
* If you do not specify any methods, then:
> _Important_: there must be an `OPTIONS` method matcher for the middleware to set the headers.
Here is an example of using `CORSMethodMiddleware` along with a custom `OPTIONS` handler to set all the required CORS headers:
```go
package main
import (
"net/http"
"github.com/gorilla/mux"
)
func main() {
r := mux.NewRouter()
// IMPORTANT: you must specify an OPTIONS method matcher for the middleware to set CORS headers
r.HandleFunc("/foo", fooHandler).Methods(http.MethodGet, http.MethodPut, http.MethodPatch, http.MethodOptions)
r.Use(mux.CORSMethodMiddleware(r))
http.ListenAndServe(":8080", r)
}
func fooHandler(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Access-Control-Allow-Origin", "*")
if r.Method == http.MethodOptions {
return
}
w.Write([]byte("foo"))
}
```
And an request to `/foo` using something like:
```bash
curl localhost:8080/foo -v
```
Would look like:
```bash
* Trying ::1...
* TCP_NODELAY set
* Connected to localhost (::1) port 8080 (#0)
> GET /foo HTTP/1.1
> Host: localhost:8080
> User-Agent: curl/7.59.0
> Accept: */*
>
< HTTP/1.1 200 OK
< Access-Control-Allow-Methods: GET,PUT,PATCH,OPTIONS
< Access-Control-Allow-Origin: *
< Date: Fri, 28 Jun 2019 20:13:30 GMT
< Content-Length: 3
< Content-Type: text/plain; charset=utf-8
<
* Connection #0 to host localhost left intact
foo
```
### Testing Handlers
Testing handlers in a Go web application is straightforward, and _mux_ doesn't complicate this any further. Given two files: `endpoints.go` and `endpoints_test.go`, here's how we'd test an application using _mux_.
First, our simple HTTP handler:
```go
// endpoints.go
package main
func HealthCheckHandler(w http.ResponseWriter, r *http.Request) {
// A very simple health check.
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
// In the future we could report back on the status of our DB, or our cache
// (e.g. Redis) by performing a simple PING, and include them in the response.
io.WriteString(w, `{"alive": true}`)
}
func main() {
r := mux.NewRouter()
r.HandleFunc("/health", HealthCheckHandler)
log.Fatal(http.ListenAndServe("localhost:8080", r))
}
```
Our test code:
```go
// endpoints_test.go
package main
import (
"net/http"
"net/http/httptest"
"testing"
)
func TestHealthCheckHandler(t *testing.T) {
// Create a request to pass to our handler. We don't have any query parameters for now, so we'll
// pass 'nil' as the third parameter.
req, err := http.NewRequest("GET", "/health", nil)
if err != nil {
t.Fatal(err)
}
// We create a ResponseRecorder (which satisfies http.ResponseWriter) to record the response.
rr := httptest.NewRecorder()
handler := http.HandlerFunc(HealthCheckHandler)
// Our handlers satisfy http.Handler, so we can call their ServeHTTP method
// directly and pass in our Request and ResponseRecorder.
handler.ServeHTTP(rr, req)
// Check the status code is what we expect.
if status := rr.Code; status != http.StatusOK {
t.Errorf("handler returned wrong status code: got %v want %v",
status, http.StatusOK)
}
// Check the response body is what we expect.
expected := `{"alive": true}`
if rr.Body.String() != expected {
t.Errorf("handler returned unexpected body: got %v want %v",
rr.Body.String(), expected)
}
}
```
In the case that our routes have [variables](#examples), we can pass those in the request. We could write
[table-driven tests](https://dave.cheney.net/2013/06/09/writing-table-driven-tests-in-go) to test multiple
possible route variables as needed.
```go
// endpoints.go
func main() {
r := mux.NewRouter()
// A route with a route variable:
r.HandleFunc("/metrics/{type}", MetricsHandler)
log.Fatal(http.ListenAndServe("localhost:8080", r))
}
```
Our test file, with a table-driven test of `routeVariables`:
```go
// endpoints_test.go
func TestMetricsHandler(t *testing.T) {
tt := []struct{
routeVariable string
shouldPass bool
}{
{"goroutines", true},
{"heap", true},
{"counters", true},
{"queries", true},
{"adhadaeqm3k", false},
}
for _, tc := range tt {
path := fmt.Sprintf("/metrics/%s", tc.routeVariable)
req, err := http.NewRequest("GET", path, nil)
if err != nil {
t.Fatal(err)
}
rr := httptest.NewRecorder()
// Need to create a router that we can pass the request through so that the vars will be added to the context
router := mux.NewRouter()
router.HandleFunc("/metrics/{type}", MetricsHandler)
router.ServeHTTP(rr, req)
// In this case, our MetricsHandler returns a non-200 response
// for a route variable it doesn't know about.
if rr.Code == http.StatusOK && !tc.shouldPass {
t.Errorf("handler should have failed on routeVariable %s: got %v want %v",
tc.routeVariable, rr.Code, http.StatusOK)
}
}
}
```
## Full Example

View File

@@ -1,5 +1,3 @@
// +build go1.7
package mux
import (
@@ -18,7 +16,3 @@ func contextSet(r *http.Request, key, val interface{}) *http.Request {
return r.WithContext(context.WithValue(r.Context(), key, val))
}
func contextClear(r *http.Request) {
return
}

View File

@@ -1,26 +0,0 @@
// +build !go1.7
package mux
import (
"net/http"
"github.com/gorilla/context"
)
func contextGet(r *http.Request, key interface{}) interface{} {
return context.Get(r, key)
}
func contextSet(r *http.Request, key, val interface{}) *http.Request {
if val == nil {
return r
}
context.Set(r, key, val)
return r
}
func contextClear(r *http.Request) {
context.Clear(r)
}

View File

@@ -1,40 +0,0 @@
// +build !go1.7
package mux
import (
"net/http"
"testing"
"github.com/gorilla/context"
)
// Tests that the context is cleared or not cleared properly depending on
// the configuration of the router
func TestKeepContext(t *testing.T) {
func1 := func(w http.ResponseWriter, r *http.Request) {}
r := NewRouter()
r.HandleFunc("/", func1).Name("func1")
req, _ := http.NewRequest("GET", "http://localhost/", nil)
context.Set(req, "t", 1)
res := new(http.ResponseWriter)
r.ServeHTTP(*res, req)
if _, ok := context.GetOk(req, "t"); ok {
t.Error("Context should have been cleared at end of request")
}
r.KeepContext = true
req, _ = http.NewRequest("GET", "http://localhost/", nil)
context.Set(req, "t", 1)
r.ServeHTTP(*res, req)
if _, ok := context.GetOk(req, "t"); !ok {
t.Error("Context should NOT have been cleared at end of request")
}
}

View File

@@ -1,5 +1,3 @@
// +build go1.7
package mux
import (

9
doc.go
View File

@@ -239,8 +239,7 @@ as well:
"category", "technology",
"id", "42")
Since **vX.Y.Z**, mux supports the addition of middlewares to a [Router](https://godoc.org/github.com/gorilla/mux#Router), which are executed if a
match is found (including subrouters). Middlewares are defined using the de facto standard type:
Mux supports the addition of middlewares to a Router, which are executed in the order they are added if a match is found, including its subrouters. Middlewares are (typically) small pieces of code which take one request, do something with it, and pass it down to another middleware or the final handler. Some common use cases for middleware are request logging, header manipulation, or ResponseWriter hijacking.
type MiddlewareFunc func(http.Handler) http.Handler
@@ -261,7 +260,7 @@ Middlewares can be added to a router using `Router.Use()`:
r := mux.NewRouter()
r.HandleFunc("/", handler)
r.AddMiddleware(simpleMw)
r.Use(simpleMw)
A more complex authentication middleware, which maps session token to users, could be written as:
@@ -288,7 +287,7 @@ A more complex authentication middleware, which maps session token to users, cou
log.Printf("Authenticated user %s\n", user)
next.ServeHTTP(w, r)
} else {
http.Error(w, "Forbidden", 403)
http.Error(w, "Forbidden", http.StatusForbidden)
}
})
}
@@ -296,7 +295,7 @@ A more complex authentication middleware, which maps session token to users, cou
r := mux.NewRouter()
r.HandleFunc("/", handler)
amw := authenticationMiddleware{}
amw := authenticationMiddleware{tokenUsers: make(map[string]string)}
amw.Populate()
r.Use(amw.Middleware)

View File

@@ -0,0 +1,46 @@
package mux_test
import (
"log"
"net/http"
"github.com/gorilla/mux"
)
// Define our struct
type authenticationMiddleware struct {
tokenUsers map[string]string
}
// Initialize it somewhere
func (amw *authenticationMiddleware) Populate() {
amw.tokenUsers["00000000"] = "user0"
amw.tokenUsers["aaaaaaaa"] = "userA"
amw.tokenUsers["05f717e5"] = "randomUser"
amw.tokenUsers["deadbeef"] = "user0"
}
// Middleware function, which will be called for each request
func (amw *authenticationMiddleware) Middleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
token := r.Header.Get("X-Session-Token")
if user, found := amw.tokenUsers[token]; found {
// We found the token in our map
log.Printf("Authenticated user %s\n", user)
next.ServeHTTP(w, r)
} else {
http.Error(w, "Forbidden", http.StatusForbidden)
}
})
}
func Example_authenticationMiddleware() {
r := mux.NewRouter()
r.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
// Do something here
})
amw := authenticationMiddleware{make(map[string]string)}
amw.Populate()
r.Use(amw.Middleware)
}

View File

@@ -0,0 +1,37 @@
package mux_test
import (
"fmt"
"net/http"
"net/http/httptest"
"github.com/gorilla/mux"
)
func ExampleCORSMethodMiddleware() {
r := mux.NewRouter()
r.HandleFunc("/foo", func(w http.ResponseWriter, r *http.Request) {
// Handle the request
}).Methods(http.MethodGet, http.MethodPut, http.MethodPatch)
r.HandleFunc("/foo", func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Access-Control-Allow-Origin", "http://example.com")
w.Header().Set("Access-Control-Max-Age", "86400")
}).Methods(http.MethodOptions)
r.Use(mux.CORSMethodMiddleware(r))
rw := httptest.NewRecorder()
req, _ := http.NewRequest("OPTIONS", "/foo", nil) // needs to be OPTIONS
req.Header.Set("Access-Control-Request-Method", "POST") // needs to be non-empty
req.Header.Set("Access-Control-Request-Headers", "Authorization") // needs to be non-empty
req.Header.Set("Origin", "http://example.com") // needs to be non-empty
r.ServeHTTP(rw, req)
fmt.Println(rw.Header().Get("Access-Control-Allow-Methods"))
fmt.Println(rw.Header().Get("Access-Control-Allow-Origin"))
// Output:
// GET,PUT,PATCH,OPTIONS
// http://example.com
}

1
go.mod Normal file
View File

@@ -0,0 +1 @@
module github.com/gorilla/mux

View File

@@ -1,6 +1,9 @@
package mux
import "net/http"
import (
"net/http"
"strings"
)
// MiddlewareFunc is a function which receives an http.Handler and returns another http.Handler.
// Typically, the returned handler is a closure which does something with the http.ResponseWriter and http.Request passed
@@ -12,17 +15,65 @@ type middleware interface {
Middleware(handler http.Handler) http.Handler
}
// MiddlewareFunc also implements the middleware interface.
// Middleware allows MiddlewareFunc to implement the middleware interface.
func (mw MiddlewareFunc) Middleware(handler http.Handler) http.Handler {
return mw(handler)
}
// Use appends a MiddlewareFunc to the chain. Middleware can be used to intercept or otherwise modify requests and/or responses, and are executed in the order that they are applied to the Router.
func (r *Router) Use(mwf MiddlewareFunc) {
r.middlewares = append(r.middlewares, mwf)
func (r *Router) Use(mwf ...MiddlewareFunc) {
for _, fn := range mwf {
r.middlewares = append(r.middlewares, fn)
}
}
// useInterface appends a middleware to the chain. Middleware can be used to intercept or otherwise modify requests and/or responses, and are executed in the order that they are applied to the Router.
func (r *Router) useInterface(mw middleware) {
r.middlewares = append(r.middlewares, mw)
}
// CORSMethodMiddleware automatically sets the Access-Control-Allow-Methods response header
// on requests for routes that have an OPTIONS method matcher to all the method matchers on
// the route. Routes that do not explicitly handle OPTIONS requests will not be processed
// by the middleware. See examples for usage.
func CORSMethodMiddleware(r *Router) MiddlewareFunc {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
allMethods, err := getAllMethodsForRoute(r, req)
if err == nil {
for _, v := range allMethods {
if v == http.MethodOptions {
w.Header().Set("Access-Control-Allow-Methods", strings.Join(allMethods, ","))
}
}
}
next.ServeHTTP(w, req)
})
}
}
// getAllMethodsForRoute returns all the methods from method matchers matching a given
// request.
func getAllMethodsForRoute(r *Router, req *http.Request) ([]string, error) {
var allMethods []string
err := r.Walk(func(route *Route, _ *Router, _ []*Route) error {
for _, m := range route.matchers {
if _, ok := m.(*routeRegexp); ok {
if m.Match(req, &RouteMatch{}) {
methods, err := route.GetMethods()
if err != nil {
return err
}
allMethods = append(allMethods, methods...)
}
break
}
}
return nil
})
return allMethods, err
}

View File

@@ -27,12 +27,12 @@ func TestMiddlewareAdd(t *testing.T) {
router.useInterface(mw)
if len(router.middlewares) != 1 || router.middlewares[0] != mw {
t.Fatal("Middleware was not added correctly")
t.Fatal("Middleware interface was not added correctly")
}
router.Use(mw.Middleware)
if len(router.middlewares) != 2 {
t.Fatal("MiddlewareFunc method was not added correctly")
t.Fatal("Middleware method was not added correctly")
}
banalMw := func(handler http.Handler) http.Handler {
@@ -40,7 +40,7 @@ func TestMiddlewareAdd(t *testing.T) {
}
router.Use(banalMw)
if len(router.middlewares) != 3 {
t.Fatal("MiddlewareFunc method was not added correctly")
t.Fatal("Middleware function was not added correctly")
}
}
@@ -54,34 +54,37 @@ func TestMiddleware(t *testing.T) {
rw := NewRecorder()
req := newRequest("GET", "/")
// Test regular middleware call
router.ServeHTTP(rw, req)
if mw.timesCalled != 1 {
t.Fatalf("Expected %d calls, but got only %d", 1, mw.timesCalled)
}
t.Run("regular middleware call", func(t *testing.T) {
router.ServeHTTP(rw, req)
if mw.timesCalled != 1 {
t.Fatalf("Expected %d calls, but got only %d", 1, mw.timesCalled)
}
})
// Middleware should not be called for 404
req = newRequest("GET", "/not/found")
router.ServeHTTP(rw, req)
if mw.timesCalled != 1 {
t.Fatalf("Expected %d calls, but got only %d", 1, mw.timesCalled)
}
t.Run("not called for 404", func(t *testing.T) {
req = newRequest("GET", "/not/found")
router.ServeHTTP(rw, req)
if mw.timesCalled != 1 {
t.Fatalf("Expected %d calls, but got only %d", 1, mw.timesCalled)
}
})
// Middleware should not be called if there is a method mismatch
req = newRequest("POST", "/")
router.ServeHTTP(rw, req)
if mw.timesCalled != 1 {
t.Fatalf("Expected %d calls, but got only %d", 1, mw.timesCalled)
}
// Add the middleware again as function
router.Use(mw.Middleware)
req = newRequest("GET", "/")
router.ServeHTTP(rw, req)
if mw.timesCalled != 3 {
t.Fatalf("Expected %d calls, but got only %d", 3, mw.timesCalled)
}
t.Run("not called for method mismatch", func(t *testing.T) {
req = newRequest("POST", "/")
router.ServeHTTP(rw, req)
if mw.timesCalled != 1 {
t.Fatalf("Expected %d calls, but got only %d", 1, mw.timesCalled)
}
})
t.Run("regular call using function middleware", func(t *testing.T) {
router.Use(mw.Middleware)
req = newRequest("GET", "/")
router.ServeHTTP(rw, req)
if mw.timesCalled != 3 {
t.Fatalf("Expected %d calls, but got only %d", 3, mw.timesCalled)
}
})
}
func TestMiddlewareSubrouter(t *testing.T) {
@@ -97,42 +100,56 @@ func TestMiddlewareSubrouter(t *testing.T) {
rw := NewRecorder()
req := newRequest("GET", "/")
router.ServeHTTP(rw, req)
if mw.timesCalled != 0 {
t.Fatalf("Expected %d calls, but got only %d", 0, mw.timesCalled)
}
t.Run("not called for route outside subrouter", func(t *testing.T) {
router.ServeHTTP(rw, req)
if mw.timesCalled != 0 {
t.Fatalf("Expected %d calls, but got only %d", 0, mw.timesCalled)
}
})
req = newRequest("GET", "/sub/")
router.ServeHTTP(rw, req)
if mw.timesCalled != 0 {
t.Fatalf("Expected %d calls, but got only %d", 0, mw.timesCalled)
}
t.Run("not called for subrouter root 404", func(t *testing.T) {
req = newRequest("GET", "/sub/")
router.ServeHTTP(rw, req)
if mw.timesCalled != 0 {
t.Fatalf("Expected %d calls, but got only %d", 0, mw.timesCalled)
}
})
req = newRequest("GET", "/sub/x")
router.ServeHTTP(rw, req)
if mw.timesCalled != 1 {
t.Fatalf("Expected %d calls, but got only %d", 1, mw.timesCalled)
}
t.Run("called once for route inside subrouter", func(t *testing.T) {
req = newRequest("GET", "/sub/x")
router.ServeHTTP(rw, req)
if mw.timesCalled != 1 {
t.Fatalf("Expected %d calls, but got only %d", 1, mw.timesCalled)
}
})
req = newRequest("GET", "/sub/not/found")
router.ServeHTTP(rw, req)
if mw.timesCalled != 1 {
t.Fatalf("Expected %d calls, but got only %d", 1, mw.timesCalled)
}
t.Run("not called for 404 inside subrouter", func(t *testing.T) {
req = newRequest("GET", "/sub/not/found")
router.ServeHTTP(rw, req)
if mw.timesCalled != 1 {
t.Fatalf("Expected %d calls, but got only %d", 1, mw.timesCalled)
}
})
router.useInterface(mw)
t.Run("middleware added to router", func(t *testing.T) {
router.useInterface(mw)
req = newRequest("GET", "/")
router.ServeHTTP(rw, req)
if mw.timesCalled != 2 {
t.Fatalf("Expected %d calls, but got only %d", 2, mw.timesCalled)
}
t.Run("called once for route outside subrouter", func(t *testing.T) {
req = newRequest("GET", "/")
router.ServeHTTP(rw, req)
if mw.timesCalled != 2 {
t.Fatalf("Expected %d calls, but got only %d", 2, mw.timesCalled)
}
})
req = newRequest("GET", "/sub/x")
router.ServeHTTP(rw, req)
if mw.timesCalled != 4 {
t.Fatalf("Expected %d calls, but got only %d", 4, mw.timesCalled)
}
t.Run("called twice for route inside subrouter", func(t *testing.T) {
req = newRequest("GET", "/sub/x")
router.ServeHTTP(rw, req)
if mw.timesCalled != 4 {
t.Fatalf("Expected %d calls, but got only %d", 4, mw.timesCalled)
}
})
})
}
func TestMiddlewareExecution(t *testing.T) {
@@ -144,30 +161,33 @@ func TestMiddlewareExecution(t *testing.T) {
w.Write(handlerStr)
})
rw := NewRecorder()
req := newRequest("GET", "/")
t.Run("responds normally without middleware", func(t *testing.T) {
rw := NewRecorder()
req := newRequest("GET", "/")
// Test handler-only call
router.ServeHTTP(rw, req)
router.ServeHTTP(rw, req)
if bytes.Compare(rw.Body.Bytes(), handlerStr) != 0 {
t.Fatal("Handler response is not what it should be")
}
// Test middleware call
rw = NewRecorder()
router.Use(func(h http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write(mwStr)
h.ServeHTTP(w, r)
})
if !bytes.Equal(rw.Body.Bytes(), handlerStr) {
t.Fatal("Handler response is not what it should be")
}
})
router.ServeHTTP(rw, req)
if bytes.Compare(rw.Body.Bytes(), append(mwStr, handlerStr...)) != 0 {
t.Fatal("Middleware + handler response is not what it should be")
}
t.Run("responds with handler and middleware response", func(t *testing.T) {
rw := NewRecorder()
req := newRequest("GET", "/")
router.Use(func(h http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write(mwStr)
h.ServeHTTP(w, r)
})
})
router.ServeHTTP(rw, req)
if !bytes.Equal(rw.Body.Bytes(), append(mwStr, handlerStr...)) {
t.Fatal("Middleware + handler response is not what it should be")
}
})
}
func TestMiddlewareNotFound(t *testing.T) {
@@ -186,26 +206,29 @@ func TestMiddlewareNotFound(t *testing.T) {
})
// Test not found call with default handler
rw := NewRecorder()
req := newRequest("GET", "/notfound")
t.Run("not called", func(t *testing.T) {
rw := NewRecorder()
req := newRequest("GET", "/notfound")
router.ServeHTTP(rw, req)
if bytes.Contains(rw.Body.Bytes(), mwStr) {
t.Fatal("Middleware was called for a 404")
}
// Test not found call with custom handler
rw = NewRecorder()
req = newRequest("GET", "/notfound")
router.NotFoundHandler = http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
rw.Write([]byte("Custom 404 handler"))
router.ServeHTTP(rw, req)
if bytes.Contains(rw.Body.Bytes(), mwStr) {
t.Fatal("Middleware was called for a 404")
}
})
router.ServeHTTP(rw, req)
if bytes.Contains(rw.Body.Bytes(), mwStr) {
t.Fatal("Middleware was called for a custom 404")
}
t.Run("not called with custom not found handler", func(t *testing.T) {
rw := NewRecorder()
req := newRequest("GET", "/notfound")
router.NotFoundHandler = http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
rw.Write([]byte("Custom 404 handler"))
})
router.ServeHTTP(rw, req)
if bytes.Contains(rw.Body.Bytes(), mwStr) {
t.Fatal("Middleware was called for a custom 404")
}
})
}
func TestMiddlewareMethodMismatch(t *testing.T) {
@@ -224,27 +247,29 @@ func TestMiddlewareMethodMismatch(t *testing.T) {
})
})
// Test method mismatch
rw := NewRecorder()
req := newRequest("POST", "/")
t.Run("not called", func(t *testing.T) {
rw := NewRecorder()
req := newRequest("POST", "/")
router.ServeHTTP(rw, req)
if bytes.Contains(rw.Body.Bytes(), mwStr) {
t.Fatal("Middleware was called for a method mismatch")
}
// Test not found call
rw = NewRecorder()
req = newRequest("POST", "/")
router.MethodNotAllowedHandler = http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
rw.Write([]byte("Method not allowed"))
router.ServeHTTP(rw, req)
if bytes.Contains(rw.Body.Bytes(), mwStr) {
t.Fatal("Middleware was called for a method mismatch")
}
})
router.ServeHTTP(rw, req)
if bytes.Contains(rw.Body.Bytes(), mwStr) {
t.Fatal("Middleware was called for a method mismatch")
}
t.Run("not called with custom method not allowed handler", func(t *testing.T) {
rw := NewRecorder()
req := newRequest("POST", "/")
router.MethodNotAllowedHandler = http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
rw.Write([]byte("Method not allowed"))
})
router.ServeHTTP(rw, req)
if bytes.Contains(rw.Body.Bytes(), mwStr) {
t.Fatal("Middleware was called for a method mismatch")
}
})
}
func TestMiddlewareNotFoundSubrouter(t *testing.T) {
@@ -268,27 +293,29 @@ func TestMiddlewareNotFoundSubrouter(t *testing.T) {
})
})
// Test not found call for default handler
rw := NewRecorder()
req := newRequest("GET", "/sub/notfound")
t.Run("not called", func(t *testing.T) {
rw := NewRecorder()
req := newRequest("GET", "/sub/notfound")
router.ServeHTTP(rw, req)
if bytes.Contains(rw.Body.Bytes(), mwStr) {
t.Fatal("Middleware was called for a 404")
}
// Test not found call with custom handler
rw = NewRecorder()
req = newRequest("GET", "/sub/notfound")
subrouter.NotFoundHandler = http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
rw.Write([]byte("Custom 404 handler"))
router.ServeHTTP(rw, req)
if bytes.Contains(rw.Body.Bytes(), mwStr) {
t.Fatal("Middleware was called for a 404")
}
})
router.ServeHTTP(rw, req)
if bytes.Contains(rw.Body.Bytes(), mwStr) {
t.Fatal("Middleware was called for a custom 404")
}
t.Run("not called with custom not found handler", func(t *testing.T) {
rw := NewRecorder()
req := newRequest("GET", "/sub/notfound")
subrouter.NotFoundHandler = http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
rw.Write([]byte("Custom 404 handler"))
})
router.ServeHTTP(rw, req)
if bytes.Contains(rw.Body.Bytes(), mwStr) {
t.Fatal("Middleware was called for a custom 404")
}
})
}
func TestMiddlewareMethodMismatchSubrouter(t *testing.T) {
@@ -312,25 +339,207 @@ func TestMiddlewareMethodMismatchSubrouter(t *testing.T) {
})
})
// Test method mismatch without custom handler
rw := NewRecorder()
req := newRequest("POST", "/sub/")
t.Run("not called", func(t *testing.T) {
rw := NewRecorder()
req := newRequest("POST", "/sub/")
router.ServeHTTP(rw, req)
if bytes.Contains(rw.Body.Bytes(), mwStr) {
t.Fatal("Middleware was called for a method mismatch")
router.ServeHTTP(rw, req)
if bytes.Contains(rw.Body.Bytes(), mwStr) {
t.Fatal("Middleware was called for a method mismatch")
}
})
t.Run("not called with custom method not allowed handler", func(t *testing.T) {
rw := NewRecorder()
req := newRequest("POST", "/sub/")
router.MethodNotAllowedHandler = http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
rw.Write([]byte("Method not allowed"))
})
router.ServeHTTP(rw, req)
if bytes.Contains(rw.Body.Bytes(), mwStr) {
t.Fatal("Middleware was called for a method mismatch")
}
})
}
func TestCORSMethodMiddleware(t *testing.T) {
testCases := []struct {
name string
registerRoutes func(r *Router)
requestHeader http.Header
requestMethod string
requestPath string
expectedAccessControlAllowMethodsHeader string
expectedResponse string
}{
{
name: "does not set without OPTIONS matcher",
registerRoutes: func(r *Router) {
r.HandleFunc("/foo", stringHandler("a")).Methods(http.MethodGet, http.MethodPut, http.MethodPatch)
},
requestMethod: "GET",
requestPath: "/foo",
expectedAccessControlAllowMethodsHeader: "",
expectedResponse: "a",
},
{
name: "sets on non OPTIONS",
registerRoutes: func(r *Router) {
r.HandleFunc("/foo", stringHandler("a")).Methods(http.MethodGet, http.MethodPut, http.MethodPatch)
r.HandleFunc("/foo", stringHandler("b")).Methods(http.MethodOptions)
},
requestMethod: "GET",
requestPath: "/foo",
expectedAccessControlAllowMethodsHeader: "GET,PUT,PATCH,OPTIONS",
expectedResponse: "a",
},
{
name: "sets without preflight headers",
registerRoutes: func(r *Router) {
r.HandleFunc("/foo", stringHandler("a")).Methods(http.MethodGet, http.MethodPut, http.MethodPatch)
r.HandleFunc("/foo", stringHandler("b")).Methods(http.MethodOptions)
},
requestMethod: "OPTIONS",
requestPath: "/foo",
expectedAccessControlAllowMethodsHeader: "GET,PUT,PATCH,OPTIONS",
expectedResponse: "b",
},
{
name: "does not set on error",
registerRoutes: func(r *Router) {
r.HandleFunc("/foo", stringHandler("a"))
},
requestMethod: "OPTIONS",
requestPath: "/foo",
expectedAccessControlAllowMethodsHeader: "",
expectedResponse: "a",
},
{
name: "sets header on valid preflight",
registerRoutes: func(r *Router) {
r.HandleFunc("/foo", stringHandler("a")).Methods(http.MethodGet, http.MethodPut, http.MethodPatch)
r.HandleFunc("/foo", stringHandler("b")).Methods(http.MethodOptions)
},
requestMethod: "OPTIONS",
requestPath: "/foo",
requestHeader: http.Header{
"Access-Control-Request-Method": []string{"GET"},
"Access-Control-Request-Headers": []string{"Authorization"},
"Origin": []string{"http://example.com"},
},
expectedAccessControlAllowMethodsHeader: "GET,PUT,PATCH,OPTIONS",
expectedResponse: "b",
},
{
name: "does not set methods from unmatching routes",
registerRoutes: func(r *Router) {
r.HandleFunc("/foo", stringHandler("c")).Methods(http.MethodDelete)
r.HandleFunc("/foo/bar", stringHandler("a")).Methods(http.MethodGet, http.MethodPut, http.MethodPatch)
r.HandleFunc("/foo/bar", stringHandler("b")).Methods(http.MethodOptions)
},
requestMethod: "OPTIONS",
requestPath: "/foo/bar",
requestHeader: http.Header{
"Access-Control-Request-Method": []string{"GET"},
"Access-Control-Request-Headers": []string{"Authorization"},
"Origin": []string{"http://example.com"},
},
expectedAccessControlAllowMethodsHeader: "GET,PUT,PATCH,OPTIONS",
expectedResponse: "b",
},
}
// Test method mismatch with custom handler
rw = NewRecorder()
req = newRequest("POST", "/sub/")
for _, tt := range testCases {
t.Run(tt.name, func(t *testing.T) {
router := NewRouter()
router.MethodNotAllowedHandler = http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
rw.Write([]byte("Method not allowed"))
})
router.ServeHTTP(rw, req)
tt.registerRoutes(router)
if bytes.Contains(rw.Body.Bytes(), mwStr) {
t.Fatal("Middleware was called for a method mismatch")
router.Use(CORSMethodMiddleware(router))
rw := NewRecorder()
req := newRequest(tt.requestMethod, tt.requestPath)
req.Header = tt.requestHeader
router.ServeHTTP(rw, req)
actualMethodsHeader := rw.Header().Get("Access-Control-Allow-Methods")
if actualMethodsHeader != tt.expectedAccessControlAllowMethodsHeader {
t.Fatalf("Expected Access-Control-Allow-Methods to equal %s but got %s", tt.expectedAccessControlAllowMethodsHeader, actualMethodsHeader)
}
actualResponse := rw.Body.String()
if actualResponse != tt.expectedResponse {
t.Fatalf("Expected response to equal %s but got %s", tt.expectedResponse, actualResponse)
}
})
}
}
func TestMiddlewareOnMultiSubrouter(t *testing.T) {
first := "first"
second := "second"
notFound := "404 not found"
router := NewRouter()
firstSubRouter := router.PathPrefix("/").Subrouter()
secondSubRouter := router.PathPrefix("/").Subrouter()
router.NotFoundHandler = http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
rw.Write([]byte(notFound))
})
firstSubRouter.HandleFunc("/first", func(w http.ResponseWriter, r *http.Request) {
})
secondSubRouter.HandleFunc("/second", func(w http.ResponseWriter, r *http.Request) {
})
firstSubRouter.Use(func(h http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte(first))
h.ServeHTTP(w, r)
})
})
secondSubRouter.Use(func(h http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte(second))
h.ServeHTTP(w, r)
})
})
t.Run("/first uses first middleware", func(t *testing.T) {
rw := NewRecorder()
req := newRequest("GET", "/first")
router.ServeHTTP(rw, req)
if rw.Body.String() != first {
t.Fatalf("Middleware did not run: expected %s middleware to write a response (got %s)", first, rw.Body.String())
}
})
t.Run("/second uses second middleware", func(t *testing.T) {
rw := NewRecorder()
req := newRequest("GET", "/second")
router.ServeHTTP(rw, req)
if rw.Body.String() != second {
t.Fatalf("Middleware did not run: expected %s middleware to write a response (got %s)", second, rw.Body.String())
}
})
t.Run("uses not found handler", func(t *testing.T) {
rw := NewRecorder()
req := newRequest("GET", "/second/not-exist")
router.ServeHTTP(rw, req)
if rw.Body.String() != notFound {
t.Fatalf("Notfound handler did not run: expected %s for not-exist, (got %s)", notFound, rw.Body.String())
}
})
}

138
mux.go
View File

@@ -13,13 +13,16 @@ import (
)
var (
// ErrMethodMismatch is returned when the method in the request does not match
// the method defined against the route.
ErrMethodMismatch = errors.New("method is not allowed")
ErrNotFound = errors.New("no matching route was found")
// ErrNotFound is returned when no route match is found.
ErrNotFound = errors.New("no matching route was found")
)
// NewRouter returns a new router instance.
func NewRouter() *Router {
return &Router{namedRoutes: make(map[string]*Route), KeepContext: false}
return &Router{namedRoutes: make(map[string]*Route)}
}
// Router registers routes to be matched and dispatches a handler.
@@ -47,24 +50,78 @@ type Router struct {
// Configurable Handler to be used when the request method does not match the route.
MethodNotAllowedHandler http.Handler
// Parent route, if this is a subrouter.
parent parentRoute
// Routes to be matched, in order.
routes []*Route
// Routes by name for URL building.
namedRoutes map[string]*Route
// See Router.StrictSlash(). This defines the flag for new routes.
strictSlash bool
// See Router.SkipClean(). This defines the flag for new routes.
skipClean bool
// If true, do not clear the request context after handling the request.
// This has no effect when go1.7+ is used, since the context is stored
//
// Deprecated: No effect when go1.7+ is used, since the context is stored
// on the request itself.
KeepContext bool
// see Router.UseEncodedPath(). This defines a flag for all routes.
useEncodedPath bool
// Slice of middlewares to be called after a match is found
middlewares []middleware
// configuration shared with `Route`
routeConf
}
// common route configuration shared between `Router` and `Route`
type routeConf struct {
// If true, "/path/foo%2Fbar/to" will match the path "/path/{var}/to"
useEncodedPath bool
// If true, when the path pattern is "/path/", accessing "/path" will
// redirect to the former and vice versa.
strictSlash bool
// If true, when the path pattern is "/path//to", accessing "/path//to"
// will not redirect
skipClean bool
// Manager for the variables from host and path.
regexp routeRegexpGroup
// List of matchers.
matchers []matcher
// The scheme used when building URLs.
buildScheme string
buildVarsFunc BuildVarsFunc
}
// returns an effective deep copy of `routeConf`
func copyRouteConf(r routeConf) routeConf {
c := r
if r.regexp.path != nil {
c.regexp.path = copyRouteRegexp(r.regexp.path)
}
if r.regexp.host != nil {
c.regexp.host = copyRouteRegexp(r.regexp.host)
}
c.regexp.queries = make([]*routeRegexp, 0, len(r.regexp.queries))
for _, q := range r.regexp.queries {
c.regexp.queries = append(c.regexp.queries, copyRouteRegexp(q))
}
c.matchers = make([]matcher, 0, len(r.matchers))
for _, m := range r.matchers {
c.matchers = append(c.matchers, m)
}
return c
}
func copyRouteRegexp(r *routeRegexp) *routeRegexp {
c := *r
return &c
}
// Match attempts to match the given request against the router's registered routes.
@@ -95,9 +152,9 @@ func (r *Router) Match(req *http.Request, match *RouteMatch) bool {
if r.MethodNotAllowedHandler != nil {
match.Handler = r.MethodNotAllowedHandler
return true
} else {
return false
}
return false
}
// Closest match for a router (includes sub-routers)
@@ -152,22 +209,18 @@ func (r *Router) ServeHTTP(w http.ResponseWriter, req *http.Request) {
handler = http.NotFoundHandler()
}
if !r.KeepContext {
defer contextClear(req)
}
handler.ServeHTTP(w, req)
}
// Get returns a route registered with the given name.
func (r *Router) Get(name string) *Route {
return r.getNamedRoutes()[name]
return r.namedRoutes[name]
}
// GetRoute returns a route registered with the given name. This method
// was renamed to Get() and remains here for backwards compatibility.
func (r *Router) GetRoute(name string) *Route {
return r.getNamedRoutes()[name]
return r.namedRoutes[name]
}
// StrictSlash defines the trailing slash behavior for new routes. The initial
@@ -218,55 +271,24 @@ func (r *Router) UseEncodedPath() *Router {
return r
}
// ----------------------------------------------------------------------------
// parentRoute
// ----------------------------------------------------------------------------
func (r *Router) getBuildScheme() string {
if r.parent != nil {
return r.parent.getBuildScheme()
}
return ""
}
// getNamedRoutes returns the map where named routes are registered.
func (r *Router) getNamedRoutes() map[string]*Route {
if r.namedRoutes == nil {
if r.parent != nil {
r.namedRoutes = r.parent.getNamedRoutes()
} else {
r.namedRoutes = make(map[string]*Route)
}
}
return r.namedRoutes
}
// getRegexpGroup returns regexp definitions from the parent route, if any.
func (r *Router) getRegexpGroup() *routeRegexpGroup {
if r.parent != nil {
return r.parent.getRegexpGroup()
}
return nil
}
func (r *Router) buildVars(m map[string]string) map[string]string {
if r.parent != nil {
m = r.parent.buildVars(m)
}
return m
}
// ----------------------------------------------------------------------------
// Route factories
// ----------------------------------------------------------------------------
// NewRoute registers an empty route.
func (r *Router) NewRoute() *Route {
route := &Route{parent: r, strictSlash: r.strictSlash, skipClean: r.skipClean, useEncodedPath: r.useEncodedPath}
// initialize a route with a copy of the parent router's configuration
route := &Route{routeConf: copyRouteConf(r.routeConf), namedRoutes: r.namedRoutes}
r.routes = append(r.routes, route)
return route
}
// Name registers a new route with a name.
// See Route.Name().
func (r *Router) Name(name string) *Route {
return r.NewRoute().Name(name)
}
// Handle registers a new route with a matcher for the URL path.
// See Route.Path() and Route.Handler().
func (r *Router) Handle(path string, handler http.Handler) *Route {

View File

@@ -48,15 +48,6 @@ type routeTest struct {
}
func TestHost(t *testing.T) {
// newRequestHost a new request with a method, url, and host header
newRequestHost := func(method, url, host string) *http.Request {
req, err := http.NewRequest(method, url, nil)
if err != nil {
panic(err)
}
req.Host = host
return req
}
tests := []routeTest{
{
@@ -113,7 +104,15 @@ func TestHost(t *testing.T) {
path: "",
shouldMatch: false,
},
// BUG {new(Route).Host("aaa.bbb.ccc:1234"), newRequestHost("GET", "/111/222/333", "aaa.bbb.ccc:1234"), map[string]string{}, "aaa.bbb.ccc:1234", "", true},
{
title: "Host route with port, match with request header",
route: new(Route).Host("aaa.bbb.ccc:1234"),
request: newRequestHost("GET", "/111/222/333", "aaa.bbb.ccc:1234"),
vars: map[string]string{},
host: "aaa.bbb.ccc:1234",
path: "",
shouldMatch: true,
},
{
title: "Host route with port, wrong host in request header",
route: new(Route).Host("aaa.bbb.ccc:1234"),
@@ -123,6 +122,16 @@ func TestHost(t *testing.T) {
path: "",
shouldMatch: false,
},
{
title: "Host route with pattern, match with request header",
route: new(Route).Host("aaa.{v1:[a-z]{3}}.ccc:1{v2:(?:23|4)}"),
request: newRequestHost("GET", "/111/222/333", "aaa.bbb.ccc:123"),
vars: map[string]string{"v1": "bbb", "v2": "23"},
host: "aaa.bbb.ccc:123",
path: "",
hostTemplate: `aaa.{v1:[a-z]{3}}.ccc:1{v2:(?:23|4)}`,
shouldMatch: true,
},
{
title: "Host route with pattern, match",
route: new(Route).Host("aaa.{v1:[a-z]{3}}.ccc"),
@@ -205,8 +214,10 @@ func TestHost(t *testing.T) {
},
}
for _, test := range tests {
testRoute(t, test)
testTemplate(t, test)
t.Run(test.title, func(t *testing.T) {
testRoute(t, test)
testTemplate(t, test)
})
}
}
@@ -437,10 +448,12 @@ func TestPath(t *testing.T) {
}
for _, test := range tests {
testRoute(t, test)
testTemplate(t, test)
testUseEscapedRoute(t, test)
testRegexp(t, test)
t.Run(test.title, func(t *testing.T) {
testRoute(t, test)
testTemplate(t, test)
testUseEscapedRoute(t, test)
testRegexp(t, test)
})
}
}
@@ -516,9 +529,11 @@ func TestPathPrefix(t *testing.T) {
}
for _, test := range tests {
testRoute(t, test)
testTemplate(t, test)
testUseEscapedRoute(t, test)
t.Run(test.title, func(t *testing.T) {
testRoute(t, test)
testTemplate(t, test)
testUseEscapedRoute(t, test)
})
}
}
@@ -623,9 +638,11 @@ func TestSchemeHostPath(t *testing.T) {
}
for _, test := range tests {
testRoute(t, test)
testTemplate(t, test)
testUseEscapedRoute(t, test)
t.Run(test.title, func(t *testing.T) {
testRoute(t, test)
testTemplate(t, test)
testUseEscapedRoute(t, test)
})
}
}
@@ -682,8 +699,10 @@ func TestHeaders(t *testing.T) {
}
for _, test := range tests {
testRoute(t, test)
testTemplate(t, test)
t.Run(test.title, func(t *testing.T) {
testRoute(t, test)
testTemplate(t, test)
})
}
}
@@ -732,9 +751,11 @@ func TestMethods(t *testing.T) {
}
for _, test := range tests {
testRoute(t, test)
testTemplate(t, test)
testMethods(t, test)
t.Run(test.title, func(t *testing.T) {
testRoute(t, test)
testTemplate(t, test)
testMethods(t, test)
})
}
}
@@ -1039,11 +1060,12 @@ func TestQueries(t *testing.T) {
}
for _, test := range tests {
testRoute(t, test)
testTemplate(t, test)
testQueriesTemplates(t, test)
testUseEscapedRoute(t, test)
testQueriesRegexp(t, test)
t.Run(test.title, func(t *testing.T) {
testTemplate(t, test)
testQueriesTemplates(t, test)
testUseEscapedRoute(t, test)
testQueriesRegexp(t, test)
})
}
}
@@ -1092,17 +1114,16 @@ func TestSchemes(t *testing.T) {
},
}
for _, test := range tests {
testRoute(t, test)
testTemplate(t, test)
t.Run(test.title, func(t *testing.T) {
testRoute(t, test)
testTemplate(t, test)
})
}
}
func TestMatcherFunc(t *testing.T) {
m := func(r *http.Request, m *RouteMatch) bool {
if r.URL.Host == "aaa.bbb.ccc" {
return true
}
return false
return r.URL.Host == "aaa.bbb.ccc"
}
tests := []routeTest{
@@ -1127,8 +1148,10 @@ func TestMatcherFunc(t *testing.T) {
}
for _, test := range tests {
testRoute(t, test)
testTemplate(t, test)
t.Run(test.title, func(t *testing.T) {
testRoute(t, test)
testTemplate(t, test)
})
}
}
@@ -1163,8 +1186,10 @@ func TestBuildVarsFunc(t *testing.T) {
}
for _, test := range tests {
testRoute(t, test)
testTemplate(t, test)
t.Run(test.title, func(t *testing.T) {
testRoute(t, test)
testTemplate(t, test)
})
}
}
@@ -1174,7 +1199,6 @@ func TestSubRouter(t *testing.T) {
subrouter3 := new(Route).PathPrefix("/foo").Subrouter()
subrouter4 := new(Route).PathPrefix("/foo/bar").Subrouter()
subrouter5 := new(Route).PathPrefix("/{category}").Subrouter()
tests := []routeTest{
{
route: subrouter1.Path("/{v2:[a-z]+}"),
@@ -1269,6 +1293,106 @@ func TestSubRouter(t *testing.T) {
pathTemplate: `/{category}`,
shouldMatch: true,
},
{
title: "Mismatch method specified on parent route",
route: new(Route).Methods("POST").PathPrefix("/foo").Subrouter().Path("/"),
request: newRequest("GET", "http://localhost/foo/"),
vars: map[string]string{},
host: "",
path: "/foo/",
pathTemplate: `/foo/`,
shouldMatch: false,
},
{
title: "Match method specified on parent route",
route: new(Route).Methods("POST").PathPrefix("/foo").Subrouter().Path("/"),
request: newRequest("POST", "http://localhost/foo/"),
vars: map[string]string{},
host: "",
path: "/foo/",
pathTemplate: `/foo/`,
shouldMatch: true,
},
{
title: "Mismatch scheme specified on parent route",
route: new(Route).Schemes("https").Subrouter().PathPrefix("/"),
request: newRequest("GET", "http://localhost/"),
vars: map[string]string{},
host: "",
path: "/",
pathTemplate: `/`,
shouldMatch: false,
},
{
title: "Match scheme specified on parent route",
route: new(Route).Schemes("http").Subrouter().PathPrefix("/"),
request: newRequest("GET", "http://localhost/"),
vars: map[string]string{},
host: "",
path: "/",
pathTemplate: `/`,
shouldMatch: true,
},
{
title: "No match header specified on parent route",
route: new(Route).Headers("X-Forwarded-Proto", "https").Subrouter().PathPrefix("/"),
request: newRequest("GET", "http://localhost/"),
vars: map[string]string{},
host: "",
path: "/",
pathTemplate: `/`,
shouldMatch: false,
},
{
title: "Header mismatch value specified on parent route",
route: new(Route).Headers("X-Forwarded-Proto", "https").Subrouter().PathPrefix("/"),
request: newRequestWithHeaders("GET", "http://localhost/", "X-Forwarded-Proto", "http"),
vars: map[string]string{},
host: "",
path: "/",
pathTemplate: `/`,
shouldMatch: false,
},
{
title: "Header match value specified on parent route",
route: new(Route).Headers("X-Forwarded-Proto", "https").Subrouter().PathPrefix("/"),
request: newRequestWithHeaders("GET", "http://localhost/", "X-Forwarded-Proto", "https"),
vars: map[string]string{},
host: "",
path: "/",
pathTemplate: `/`,
shouldMatch: true,
},
{
title: "Query specified on parent route not present",
route: new(Route).Headers("key", "foobar").Subrouter().PathPrefix("/"),
request: newRequest("GET", "http://localhost/"),
vars: map[string]string{},
host: "",
path: "/",
pathTemplate: `/`,
shouldMatch: false,
},
{
title: "Query mismatch value specified on parent route",
route: new(Route).Queries("key", "foobar").Subrouter().PathPrefix("/"),
request: newRequest("GET", "http://localhost/?key=notfoobar"),
vars: map[string]string{},
host: "",
path: "/",
pathTemplate: `/`,
shouldMatch: false,
},
{
title: "Query match value specified on subroute",
route: new(Route).Queries("key", "foobar").Subrouter().PathPrefix("/"),
request: newRequest("GET", "http://localhost/?key=foobar"),
vars: map[string]string{},
host: "",
path: "/",
pathTemplate: `/`,
shouldMatch: true,
},
{
title: "Build with scheme on parent router",
route: new(Route).Schemes("ftp").Host("google.com").Subrouter().Path("/"),
@@ -1294,9 +1418,11 @@ func TestSubRouter(t *testing.T) {
}
for _, test := range tests {
testRoute(t, test)
testTemplate(t, test)
testUseEscapedRoute(t, test)
t.Run(test.title, func(t *testing.T) {
testRoute(t, test)
testTemplate(t, test)
testUseEscapedRoute(t, test)
})
}
}
@@ -1315,14 +1441,24 @@ func TestNamedRoutes(t *testing.T) {
r3.NewRoute().Name("g")
r3.NewRoute().Name("h")
r3.NewRoute().Name("i")
r3.Name("j")
if r1.namedRoutes == nil || len(r1.namedRoutes) != 9 {
t.Errorf("Expected 9 named routes, got %v", r1.namedRoutes)
} else if r1.Get("i") == nil {
if r1.namedRoutes == nil || len(r1.namedRoutes) != 10 {
t.Errorf("Expected 10 named routes, got %v", r1.namedRoutes)
} else if r1.Get("j") == nil {
t.Errorf("Subroute name not registered")
}
}
func TestNameMultipleCalls(t *testing.T) {
r1 := NewRouter()
rt := r1.NewRoute().Name("foo").Name("bar")
err := rt.GetError()
if err == nil {
t.Errorf("Expected an error")
}
}
func TestStrictSlash(t *testing.T) {
r := NewRouter()
r.StrictSlash(true)
@@ -1391,9 +1527,11 @@ func TestStrictSlash(t *testing.T) {
}
for _, test := range tests {
testRoute(t, test)
testTemplate(t, test)
testUseEscapedRoute(t, test)
t.Run(test.title, func(t *testing.T) {
testRoute(t, test)
testTemplate(t, test)
testUseEscapedRoute(t, test)
})
}
}
@@ -1425,8 +1563,10 @@ func TestUseEncodedPath(t *testing.T) {
}
for _, test := range tests {
testRoute(t, test)
testTemplate(t, test)
t.Run(test.title, func(t *testing.T) {
testRoute(t, test)
testTemplate(t, test)
})
}
}
@@ -1478,12 +1618,16 @@ func TestWalkSingleDepth(t *testing.T) {
func TestWalkNested(t *testing.T) {
router := NewRouter()
g := router.Path("/g").Subrouter()
o := g.PathPrefix("/o").Subrouter()
r := o.PathPrefix("/r").Subrouter()
i := r.PathPrefix("/i").Subrouter()
l1 := i.PathPrefix("/l").Subrouter()
l2 := l1.PathPrefix("/l").Subrouter()
routeSubrouter := func(r *Route) (*Route, *Router) {
return r, r.Subrouter()
}
gRoute, g := routeSubrouter(router.Path("/g"))
oRoute, o := routeSubrouter(g.PathPrefix("/o"))
rRoute, r := routeSubrouter(o.PathPrefix("/r"))
iRoute, i := routeSubrouter(r.PathPrefix("/i"))
l1Route, l1 := routeSubrouter(i.PathPrefix("/l"))
l2Route, l2 := routeSubrouter(l1.PathPrefix("/l"))
l2.Path("/a")
testCases := []struct {
@@ -1491,12 +1635,12 @@ func TestWalkNested(t *testing.T) {
ancestors []*Route
}{
{"/g", []*Route{}},
{"/g/o", []*Route{g.parent.(*Route)}},
{"/g/o/r", []*Route{g.parent.(*Route), o.parent.(*Route)}},
{"/g/o/r/i", []*Route{g.parent.(*Route), o.parent.(*Route), r.parent.(*Route)}},
{"/g/o/r/i/l", []*Route{g.parent.(*Route), o.parent.(*Route), r.parent.(*Route), i.parent.(*Route)}},
{"/g/o/r/i/l/l", []*Route{g.parent.(*Route), o.parent.(*Route), r.parent.(*Route), i.parent.(*Route), l1.parent.(*Route)}},
{"/g/o/r/i/l/l/a", []*Route{g.parent.(*Route), o.parent.(*Route), r.parent.(*Route), i.parent.(*Route), l1.parent.(*Route), l2.parent.(*Route)}},
{"/g/o", []*Route{gRoute}},
{"/g/o/r", []*Route{gRoute, oRoute}},
{"/g/o/r/i", []*Route{gRoute, oRoute, rRoute}},
{"/g/o/r/i/l", []*Route{gRoute, oRoute, rRoute, iRoute}},
{"/g/o/r/i/l/l", []*Route{gRoute, oRoute, rRoute, iRoute, l1Route}},
{"/g/o/r/i/l/l/a", []*Route{gRoute, oRoute, rRoute, iRoute, l1Route, l2Route}},
}
idx := 0
@@ -1529,8 +1673,8 @@ func TestWalkSubrouters(t *testing.T) {
o.Methods("GET")
o.Methods("PUT")
// all 4 routes should be matched, but final 2 routes do not have path templates
paths := []string{"/g", "/g/o", "", ""}
// all 4 routes should be matched
paths := []string{"/g", "/g/o", "/g/o", "/g/o"}
idx := 0
err := router.Walk(func(route *Route, router *Router, ancestors []*Route) error {
path := paths[idx]
@@ -1711,7 +1855,11 @@ func testRoute(t *testing.T, test routeTest) {
}
}
if query != "" {
u, _ := route.URL(mapToPairs(match.Vars)...)
u, err := route.URL(mapToPairs(match.Vars)...)
if err != nil {
t.Errorf("(%v) erred while creating url: %v", test.title, err)
return
}
if query != u.RawQuery {
t.Errorf("(%v) URL query not equal: expected %v, got %v", test.title, query, u.RawQuery)
return
@@ -2031,7 +2179,9 @@ func TestMethodsSubrouterCatchall(t *testing.T) {
}
for _, test := range tests {
testMethodsSubrouter(t, test)
t.Run(test.title, func(t *testing.T) {
testMethodsSubrouter(t, test)
})
}
}
@@ -2087,7 +2237,9 @@ func TestMethodsSubrouterStrictSlash(t *testing.T) {
}
for _, test := range tests {
testMethodsSubrouter(t, test)
t.Run(test.title, func(t *testing.T) {
testMethodsSubrouter(t, test)
})
}
}
@@ -2134,7 +2286,9 @@ func TestMethodsSubrouterPathPrefix(t *testing.T) {
}
for _, test := range tests {
testMethodsSubrouter(t, test)
t.Run(test.title, func(t *testing.T) {
testMethodsSubrouter(t, test)
})
}
}
@@ -2190,7 +2344,9 @@ func TestMethodsSubrouterSubrouter(t *testing.T) {
}
for _, test := range tests {
testMethodsSubrouter(t, test)
t.Run(test.title, func(t *testing.T) {
testMethodsSubrouter(t, test)
})
}
}
@@ -2244,10 +2400,21 @@ func TestMethodsSubrouterPathVariable(t *testing.T) {
}
for _, test := range tests {
testMethodsSubrouter(t, test)
t.Run(test.title, func(t *testing.T) {
testMethodsSubrouter(t, test)
})
}
}
func ExampleSetURLVars() {
req, _ := http.NewRequest("GET", "/foo", nil)
req = SetURLVars(req, map[string]string{"foo": "bar"})
fmt.Println(Vars(req)["foo"])
// Output: bar
}
// testMethodsSubrouter runs an individual methodsSubrouterTest.
func testMethodsSubrouter(t *testing.T, test methodsSubrouterTest) {
// Execute request
@@ -2279,6 +2446,305 @@ func testMethodsSubrouter(t *testing.T, test methodsSubrouterTest) {
}
}
func TestSubrouterMatching(t *testing.T) {
const (
none, stdOnly, subOnly uint8 = 0, 1 << 0, 1 << 1
both = subOnly | stdOnly
)
type request struct {
Name string
Request *http.Request
Flags uint8
}
cases := []struct {
Name string
Standard, Subrouter func(*Router)
Requests []request
}{
{
"pathPrefix",
func(r *Router) {
r.PathPrefix("/before").PathPrefix("/after")
},
func(r *Router) {
r.PathPrefix("/before").Subrouter().PathPrefix("/after")
},
[]request{
{"no match final path prefix", newRequest("GET", "/after"), none},
{"no match parent path prefix", newRequest("GET", "/before"), none},
{"matches append", newRequest("GET", "/before/after"), both},
{"matches as prefix", newRequest("GET", "/before/after/1234"), both},
},
},
{
"path",
func(r *Router) {
r.Path("/before").Path("/after")
},
func(r *Router) {
r.Path("/before").Subrouter().Path("/after")
},
[]request{
{"no match subroute path", newRequest("GET", "/after"), none},
{"no match parent path", newRequest("GET", "/before"), none},
{"no match as prefix", newRequest("GET", "/before/after/1234"), none},
{"no match append", newRequest("GET", "/before/after"), none},
},
},
{
"host",
func(r *Router) {
r.Host("before.com").Host("after.com")
},
func(r *Router) {
r.Host("before.com").Subrouter().Host("after.com")
},
[]request{
{"no match before", newRequestHost("GET", "/", "before.com"), none},
{"no match other", newRequestHost("GET", "/", "other.com"), none},
{"matches after", newRequestHost("GET", "/", "after.com"), none},
},
},
{
"queries variant keys",
func(r *Router) {
r.Queries("foo", "bar").Queries("cricket", "baseball")
},
func(r *Router) {
r.Queries("foo", "bar").Subrouter().Queries("cricket", "baseball")
},
[]request{
{"matches with all", newRequest("GET", "/?foo=bar&cricket=baseball"), both},
{"matches with more", newRequest("GET", "/?foo=bar&cricket=baseball&something=else"), both},
{"no match with none", newRequest("GET", "/"), none},
{"no match with some", newRequest("GET", "/?cricket=baseball"), none},
},
},
{
"queries overlapping keys",
func(r *Router) {
r.Queries("foo", "bar").Queries("foo", "baz")
},
func(r *Router) {
r.Queries("foo", "bar").Subrouter().Queries("foo", "baz")
},
[]request{
{"no match old value", newRequest("GET", "/?foo=bar"), none},
{"no match diff value", newRequest("GET", "/?foo=bak"), none},
{"no match with none", newRequest("GET", "/"), none},
{"matches override", newRequest("GET", "/?foo=baz"), none},
},
},
{
"header variant keys",
func(r *Router) {
r.Headers("foo", "bar").Headers("cricket", "baseball")
},
func(r *Router) {
r.Headers("foo", "bar").Subrouter().Headers("cricket", "baseball")
},
[]request{
{
"matches with all",
newRequestWithHeaders("GET", "/", "foo", "bar", "cricket", "baseball"),
both,
},
{
"matches with more",
newRequestWithHeaders("GET", "/", "foo", "bar", "cricket", "baseball", "something", "else"),
both,
},
{"no match with none", newRequest("GET", "/"), none},
{"no match with some", newRequestWithHeaders("GET", "/", "cricket", "baseball"), none},
},
},
{
"header overlapping keys",
func(r *Router) {
r.Headers("foo", "bar").Headers("foo", "baz")
},
func(r *Router) {
r.Headers("foo", "bar").Subrouter().Headers("foo", "baz")
},
[]request{
{"no match old value", newRequestWithHeaders("GET", "/", "foo", "bar"), none},
{"no match diff value", newRequestWithHeaders("GET", "/", "foo", "bak"), none},
{"no match with none", newRequest("GET", "/"), none},
{"matches override", newRequestWithHeaders("GET", "/", "foo", "baz"), none},
},
},
{
"method",
func(r *Router) {
r.Methods("POST").Methods("GET")
},
func(r *Router) {
r.Methods("POST").Subrouter().Methods("GET")
},
[]request{
{"matches before", newRequest("POST", "/"), none},
{"no match other", newRequest("HEAD", "/"), none},
{"matches override", newRequest("GET", "/"), none},
},
},
{
"schemes",
func(r *Router) {
r.Schemes("http").Schemes("https")
},
func(r *Router) {
r.Schemes("http").Subrouter().Schemes("https")
},
[]request{
{"matches overrides", newRequest("GET", "https://www.example.com/"), none},
{"matches original", newRequest("GET", "http://www.example.com/"), none},
{"no match other", newRequest("GET", "ftp://www.example.com/"), none},
},
},
}
// case -> request -> router
for _, c := range cases {
t.Run(c.Name, func(t *testing.T) {
for _, req := range c.Requests {
t.Run(req.Name, func(t *testing.T) {
for _, v := range []struct {
Name string
Config func(*Router)
Expected bool
}{
{"subrouter", c.Subrouter, (req.Flags & subOnly) != 0},
{"standard", c.Standard, (req.Flags & stdOnly) != 0},
} {
r := NewRouter()
v.Config(r)
if r.Match(req.Request, &RouteMatch{}) != v.Expected {
if v.Expected {
t.Errorf("expected %v match", v.Name)
} else {
t.Errorf("expected %v no match", v.Name)
}
}
}
})
}
})
}
}
// verify that copyRouteConf copies fields as expected.
func Test_copyRouteConf(t *testing.T) {
var (
m MatcherFunc = func(*http.Request, *RouteMatch) bool {
return true
}
b BuildVarsFunc = func(i map[string]string) map[string]string {
return i
}
r, _ = newRouteRegexp("hi", regexpTypeHost, routeRegexpOptions{})
)
tests := []struct {
name string
args routeConf
want routeConf
}{
{
"empty",
routeConf{},
routeConf{},
},
{
"full",
routeConf{
useEncodedPath: true,
strictSlash: true,
skipClean: true,
regexp: routeRegexpGroup{host: r, path: r, queries: []*routeRegexp{r}},
matchers: []matcher{m},
buildScheme: "https",
buildVarsFunc: b,
},
routeConf{
useEncodedPath: true,
strictSlash: true,
skipClean: true,
regexp: routeRegexpGroup{host: r, path: r, queries: []*routeRegexp{r}},
matchers: []matcher{m},
buildScheme: "https",
buildVarsFunc: b,
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// special case some incomparable fields of routeConf before delegating to reflect.DeepEqual
got := copyRouteConf(tt.args)
// funcs not comparable, just compare length of slices
if len(got.matchers) != len(tt.want.matchers) {
t.Errorf("matchers different lengths: %v %v", len(got.matchers), len(tt.want.matchers))
}
got.matchers, tt.want.matchers = nil, nil
// deep equal treats nil slice differently to empty slice so check for zero len first
{
bothZero := len(got.regexp.queries) == 0 && len(tt.want.regexp.queries) == 0
if !bothZero && !reflect.DeepEqual(got.regexp.queries, tt.want.regexp.queries) {
t.Errorf("queries unequal: %v %v", got.regexp.queries, tt.want.regexp.queries)
}
got.regexp.queries, tt.want.regexp.queries = nil, nil
}
// funcs not comparable, just compare nullity
if (got.buildVarsFunc == nil) != (tt.want.buildVarsFunc == nil) {
t.Errorf("build vars funcs unequal: %v %v", got.buildVarsFunc == nil, tt.want.buildVarsFunc == nil)
}
got.buildVarsFunc, tt.want.buildVarsFunc = nil, nil
// finish the deal
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("route confs unequal: %v %v", got, tt.want)
}
})
}
}
func TestMethodNotAllowed(t *testing.T) {
handler := func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) }
router := NewRouter()
router.HandleFunc("/thing", handler).Methods(http.MethodGet)
router.HandleFunc("/something", handler).Methods(http.MethodGet)
w := NewRecorder()
req := newRequest(http.MethodPut, "/thing")
router.ServeHTTP(w, req)
if w.Code != 405 {
t.Fatalf("Expected status code 405 (got %d)", w.Code)
}
}
func TestSubrouterNotFound(t *testing.T) {
handler := func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) }
router := NewRouter()
router.Path("/a").Subrouter().HandleFunc("/thing", handler).Methods(http.MethodGet)
router.Path("/b").Subrouter().HandleFunc("/something", handler).Methods(http.MethodGet)
w := NewRecorder()
req := newRequest(http.MethodPut, "/not-present")
router.ServeHTTP(w, req)
if w.Code != 404 {
t.Fatalf("Expected status code 404 (got %d)", w.Code)
}
}
// mapToPairs converts a string map to a slice of string pairs
func mapToPairs(m map[string]string) []string {
var i int
@@ -2306,6 +2772,14 @@ func stringMapEqual(m1, m2 map[string]string) bool {
return true
}
// stringHandler returns a handler func that writes a message 's' to the
// http.ResponseWriter.
func stringHandler(s string) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte(s))
}
}
// newRequest is a helper function to create a new request with a method and url.
// The request returned is a 'server' request as opposed to a 'client' one through
// simulated write onto the wire and read off of the wire.
@@ -2345,3 +2819,28 @@ func newRequest(method, url string) *http.Request {
}
return req
}
// create a new request with the provided headers
func newRequestWithHeaders(method, url string, headers ...string) *http.Request {
req := newRequest(method, url)
if len(headers)%2 != 0 {
panic(fmt.Sprintf("Expected headers length divisible by 2 but got %v", len(headers)))
}
for i := 0; i < len(headers); i += 2 {
req.Header.Set(headers[i], headers[i+1])
}
return req
}
// newRequestHost a new request with a method, url, and host header
func newRequestHost(method, url, host string) *http.Request {
req, err := http.NewRequest(method, url, nil)
if err != nil {
panic(err)
}
req.Host = host
return req
}

View File

@@ -113,6 +113,13 @@ func newRouteRegexp(tpl string, typ regexpType, options routeRegexpOptions) (*ro
if typ != regexpTypePrefix {
pattern.WriteByte('$')
}
var wildcardHostPort bool
if typ == regexpTypeHost {
if !strings.Contains(pattern.String(), ":") {
wildcardHostPort = true
}
}
reverse.WriteString(raw)
if endSlash {
reverse.WriteByte('/')
@@ -131,13 +138,14 @@ func newRouteRegexp(tpl string, typ regexpType, options routeRegexpOptions) (*ro
// Done!
return &routeRegexp{
template: template,
regexpType: typ,
options: options,
regexp: reg,
reverse: reverse.String(),
varsN: varsN,
varsR: varsR,
template: template,
regexpType: typ,
options: options,
regexp: reg,
reverse: reverse.String(),
varsN: varsN,
varsR: varsR,
wildcardHostPort: wildcardHostPort,
}, nil
}
@@ -158,11 +166,22 @@ type routeRegexp struct {
varsN []string
// Variable regexps (validators).
varsR []*regexp.Regexp
// Wildcard host-port (no strict port match in hostname)
wildcardHostPort bool
}
// Match matches the regexp against the URL host or path.
func (r *routeRegexp) Match(req *http.Request, match *RouteMatch) bool {
if r.regexpType != regexpTypeHost {
if r.regexpType == regexpTypeHost {
host := getHost(req)
if r.wildcardHostPort {
// Don't be strict on the port match
if i := strings.Index(host, ":"); i != -1 {
host = host[:i]
}
}
return r.regexp.MatchString(host)
} else {
if r.regexpType == regexpTypeQuery {
return r.matchQueryString(req)
}
@@ -172,8 +191,6 @@ func (r *routeRegexp) Match(req *http.Request, match *RouteMatch) bool {
}
return r.regexp.MatchString(path)
}
return r.regexp.MatchString(getHost(req))
}
// url builds a URL part using the given values.
@@ -267,7 +284,7 @@ type routeRegexpGroup struct {
}
// setMatch extracts the variables from the URL once a route matches.
func (v *routeRegexpGroup) setMatch(req *http.Request, m *RouteMatch, r *Route) {
func (v routeRegexpGroup) setMatch(req *http.Request, m *RouteMatch, r *Route) {
// Store host variables.
if v.host != nil {
host := getHost(req)
@@ -296,7 +313,7 @@ func (v *routeRegexpGroup) setMatch(req *http.Request, m *RouteMatch, r *Route)
} else {
u.Path += "/"
}
m.Handler = http.RedirectHandler(u.String(), 301)
m.Handler = http.RedirectHandler(u.String(), http.StatusMovedPermanently)
}
}
}
@@ -312,17 +329,13 @@ func (v *routeRegexpGroup) setMatch(req *http.Request, m *RouteMatch, r *Route)
}
// getHost tries its best to return the request host.
// According to section 14.23 of RFC 2616 the Host header
// can include the port number if the default value of 80 is not used.
func getHost(r *http.Request) string {
if r.URL.IsAbs() {
return r.URL.Host
}
host := r.Host
// Slice off any port information.
if i := strings.Index(host, ":"); i != -1 {
host = host[:i]
}
return host
return r.Host
}
func extractVars(input string, matches []int, names []string, output map[string]string) {

151
route.go
View File

@@ -15,24 +15,8 @@ import (
// Route stores information to match a request and build URLs.
type Route struct {
// Parent where the route was registered (a Router).
parent parentRoute
// Request handler for the route.
handler http.Handler
// List of matchers.
matchers []matcher
// Manager for the variables from host and path.
regexp *routeRegexpGroup
// If true, when the path pattern is "/path/", accessing "/path" will
// redirect to the former and vice versa.
strictSlash bool
// If true, when the path pattern is "/path//to", accessing "/path//to"
// will not redirect
skipClean bool
// If true, "/path/foo%2Fbar/to" will match the path "/path/{var}/to"
useEncodedPath bool
// The scheme used when building URLs.
buildScheme string
// If true, this route never matches: it is only used to build URLs.
buildOnly bool
// The name used to build URLs.
@@ -40,9 +24,15 @@ type Route struct {
// Error resulted from building a route.
err error
buildVarsFunc BuildVarsFunc
// "global" reference to all named routes
namedRoutes map[string]*Route
// config possibly passed in from `Router`
routeConf
}
// SkipClean reports whether path cleaning is enabled for this route via
// Router.SkipClean.
func (r *Route) SkipClean() bool {
return r.skipClean
}
@@ -62,6 +52,18 @@ func (r *Route) Match(req *http.Request, match *RouteMatch) bool {
matchErr = ErrMethodMismatch
continue
}
// Ignore ErrNotFound errors. These errors arise from match call
// to Subrouters.
//
// This prevents subsequent matching subrouters from failing to
// run middleware. If not ignored, the middleware would see a
// non-nil MatchErr and be skipped, even when there was a
// matching route.
if match.MatchErr == ErrNotFound {
match.MatchErr = nil
}
matchErr = nil
return false
}
@@ -91,9 +93,7 @@ func (r *Route) Match(req *http.Request, match *RouteMatch) bool {
}
// Set variables.
if r.regexp != nil {
r.regexp.setMatch(req, match, r)
}
r.regexp.setMatch(req, match, r)
return true
}
@@ -135,7 +135,7 @@ func (r *Route) GetHandler() http.Handler {
// Name -----------------------------------------------------------------------
// Name sets the name for the route, used to build URLs.
// If the name was registered already it will be overwritten.
// It is an error to call Name more than once on a route.
func (r *Route) Name(name string) *Route {
if r.name != "" {
r.err = fmt.Errorf("mux: route already has name %q, can't set %q",
@@ -143,7 +143,7 @@ func (r *Route) Name(name string) *Route {
}
if r.err == nil {
r.name = name
r.getNamedRoutes()[name] = r
r.namedRoutes[name] = r
}
return r
}
@@ -175,7 +175,6 @@ func (r *Route) addRegexpMatcher(tpl string, typ regexpType) error {
if r.err != nil {
return r.err
}
r.regexp = r.getRegexpGroup()
if typ == regexpTypePath || typ == regexpTypePrefix {
if len(tpl) > 0 && tpl[0] != '/' {
return fmt.Errorf("mux: path must start with a slash, got %q", tpl)
@@ -384,7 +383,7 @@ func (r *Route) PathPrefix(tpl string) *Route {
// The above route will only match if the URL contains the defined queries
// values, e.g.: ?foo=bar&id=42.
//
// It the value is an empty string, it will match any value if the key is set.
// If the value is an empty string, it will match any value if the key is set.
//
// Variables can define an optional regexp pattern to be matched:
//
@@ -422,7 +421,7 @@ func (r *Route) Schemes(schemes ...string) *Route {
for k, v := range schemes {
schemes[k] = strings.ToLower(v)
}
if r.buildScheme == "" && len(schemes) > 0 {
if len(schemes) > 0 {
r.buildScheme = schemes[0]
}
return r.addMatcher(schemeMatcher(schemes))
@@ -437,7 +436,15 @@ type BuildVarsFunc func(map[string]string) map[string]string
// BuildVarsFunc adds a custom function to be used to modify build variables
// before a route's URL is built.
func (r *Route) BuildVarsFunc(f BuildVarsFunc) *Route {
r.buildVarsFunc = f
if r.buildVarsFunc != nil {
// compose the old and new functions
old := r.buildVarsFunc
r.buildVarsFunc = func(m map[string]string) map[string]string {
return f(old(m))
}
} else {
r.buildVarsFunc = f
}
return r
}
@@ -456,7 +463,8 @@ func (r *Route) BuildVarsFunc(f BuildVarsFunc) *Route {
// Here, the routes registered in the subrouter won't be tested if the host
// doesn't match.
func (r *Route) Subrouter() *Router {
router := &Router{parent: r, strictSlash: r.strictSlash}
// initialize a subrouter with a copy of the parent route's configuration
router := &Router{routeConf: copyRouteConf(r.routeConf), namedRoutes: r.namedRoutes}
r.addMatcher(router)
return router
}
@@ -500,9 +508,6 @@ func (r *Route) URL(pairs ...string) (*url.URL, error) {
if r.err != nil {
return nil, r.err
}
if r.regexp == nil {
return nil, errors.New("mux: route doesn't have a host or path")
}
values, err := r.prepareVars(pairs...)
if err != nil {
return nil, err
@@ -514,8 +519,8 @@ func (r *Route) URL(pairs ...string) (*url.URL, error) {
return nil, err
}
scheme = "http"
if s := r.getBuildScheme(); s != "" {
scheme = s
if r.buildScheme != "" {
scheme = r.buildScheme
}
}
if r.regexp.path != nil {
@@ -545,7 +550,7 @@ func (r *Route) URLHost(pairs ...string) (*url.URL, error) {
if r.err != nil {
return nil, r.err
}
if r.regexp == nil || r.regexp.host == nil {
if r.regexp.host == nil {
return nil, errors.New("mux: route doesn't have a host")
}
values, err := r.prepareVars(pairs...)
@@ -560,8 +565,8 @@ func (r *Route) URLHost(pairs ...string) (*url.URL, error) {
Scheme: "http",
Host: host,
}
if s := r.getBuildScheme(); s != "" {
u.Scheme = s
if r.buildScheme != "" {
u.Scheme = r.buildScheme
}
return u, nil
}
@@ -573,7 +578,7 @@ func (r *Route) URLPath(pairs ...string) (*url.URL, error) {
if r.err != nil {
return nil, r.err
}
if r.regexp == nil || r.regexp.path == nil {
if r.regexp.path == nil {
return nil, errors.New("mux: route doesn't have a path")
}
values, err := r.prepareVars(pairs...)
@@ -598,7 +603,7 @@ func (r *Route) GetPathTemplate() (string, error) {
if r.err != nil {
return "", r.err
}
if r.regexp == nil || r.regexp.path == nil {
if r.regexp.path == nil {
return "", errors.New("mux: route doesn't have a path")
}
return r.regexp.path.template, nil
@@ -612,7 +617,7 @@ func (r *Route) GetPathRegexp() (string, error) {
if r.err != nil {
return "", r.err
}
if r.regexp == nil || r.regexp.path == nil {
if r.regexp.path == nil {
return "", errors.New("mux: route does not have a path")
}
return r.regexp.path.regexp.String(), nil
@@ -622,12 +627,12 @@ func (r *Route) GetPathRegexp() (string, error) {
// route queries.
// This is useful for building simple REST API documentation and for instrumentation
// against third-party services.
// An empty list will be returned if the route does not have queries.
// An error will be returned if the route does not have queries.
func (r *Route) GetQueriesRegexp() ([]string, error) {
if r.err != nil {
return nil, r.err
}
if r.regexp == nil || r.regexp.queries == nil {
if r.regexp.queries == nil {
return nil, errors.New("mux: route doesn't have queries")
}
var queries []string
@@ -641,12 +646,12 @@ func (r *Route) GetQueriesRegexp() ([]string, error) {
// query matching.
// This is useful for building simple REST API documentation and for instrumentation
// against third-party services.
// An empty list will be returned if the route does not define queries.
// An error will be returned if the route does not define queries.
func (r *Route) GetQueriesTemplates() ([]string, error) {
if r.err != nil {
return nil, r.err
}
if r.regexp == nil || r.regexp.queries == nil {
if r.regexp.queries == nil {
return nil, errors.New("mux: route doesn't have queries")
}
var queries []string
@@ -659,7 +664,7 @@ func (r *Route) GetQueriesTemplates() ([]string, error) {
// GetMethods returns the methods the route matches against
// This is useful for building simple REST API documentation and for instrumentation
// against third-party services.
// An empty list will be returned if route does not have methods.
// An error will be returned if route does not have methods.
func (r *Route) GetMethods() ([]string, error) {
if r.err != nil {
return nil, r.err
@@ -669,7 +674,7 @@ func (r *Route) GetMethods() ([]string, error) {
return []string(methods), nil
}
}
return nil, nil
return nil, errors.New("mux: route doesn't have methods")
}
// GetHostTemplate returns the template used to build the
@@ -681,7 +686,7 @@ func (r *Route) GetHostTemplate() (string, error) {
if r.err != nil {
return "", r.err
}
if r.regexp == nil || r.regexp.host == nil {
if r.regexp.host == nil {
return "", errors.New("mux: route doesn't have a host")
}
return r.regexp.host.template, nil
@@ -698,64 +703,8 @@ func (r *Route) prepareVars(pairs ...string) (map[string]string, error) {
}
func (r *Route) buildVars(m map[string]string) map[string]string {
if r.parent != nil {
m = r.parent.buildVars(m)
}
if r.buildVarsFunc != nil {
m = r.buildVarsFunc(m)
}
return m
}
// ----------------------------------------------------------------------------
// parentRoute
// ----------------------------------------------------------------------------
// parentRoute allows routes to know about parent host and path definitions.
type parentRoute interface {
getBuildScheme() string
getNamedRoutes() map[string]*Route
getRegexpGroup() *routeRegexpGroup
buildVars(map[string]string) map[string]string
}
func (r *Route) getBuildScheme() string {
if r.buildScheme != "" {
return r.buildScheme
}
if r.parent != nil {
return r.parent.getBuildScheme()
}
return ""
}
// getNamedRoutes returns the map where named routes are registered.
func (r *Route) getNamedRoutes() map[string]*Route {
if r.parent == nil {
// During tests router is not always set.
r.parent = NewRouter()
}
return r.parent.getNamedRoutes()
}
// getRegexpGroup returns regexp definitions from this route.
func (r *Route) getRegexpGroup() *routeRegexpGroup {
if r.regexp == nil {
if r.parent == nil {
// During tests router is not always set.
r.parent = NewRouter()
}
regexp := r.parent.getRegexpGroup()
if regexp == nil {
r.regexp = new(routeRegexpGroup)
} else {
// Copy.
r.regexp = &routeRegexpGroup{
host: regexp.host,
path: regexp.path,
queries: regexp.queries,
}
}
}
return r.regexp
}

View File

@@ -7,7 +7,8 @@ package mux
import "net/http"
// SetURLVars sets the URL variables for the given request, to be accessed via
// mux.Vars for testing route behaviour.
// mux.Vars for testing route behaviour. Arguments are not modified, a shallow
// copy is returned.
//
// This API should only be used for testing purposes; it provides a way to
// inject variables into the request context. Alternatively, URL variables