39 Commits

Author SHA1 Message Date
Matt Silverlock
e67b3c02c7 Remove TravisCI badge (#503) 2019-07-20 13:14:35 -07:00
Franklin Harding
7a1bf406d6 [docs] Add documentation for using mux to serve a SPA (#493)
* Add documentation for using mux to serve a SPA

* r -> router to prevent shadowing

* Expand SPA acronym

* BrowserRouter link

* Add more comments to explain how the spaHandler.ServeHTTP method works
2019-07-20 07:53:35 -07:00
Christian Muehlhaeuser
eab9c4f3d2 Simplify code (#501)
Use a single append call instead of a ranged for loop.
2019-07-20 07:49:38 -07:00
Christian Muehlhaeuser
50fbc3e7fb Avoid unnecessary conversion (#502)
No need to convert here.
2019-07-20 07:48:32 -07:00
Matt Silverlock
d83b6ffe49 Update config.yml (#495)
* Update config.yml

* Update config.yml
2019-07-01 13:26:33 -07:00
Matt Silverlock
00bdffe0f3 Update stale.yml (#494) 2019-06-29 21:17:52 -07:00
Franklin Harding
0534769016 Improve CORS Method Middleware (#477)
* More sensical CORSMethodMiddleware

* Only sets Access-Control-Allow-Methods on valid preflight requests
* Does not return after setting the Access-Control-Allow-Methods header
* Does not append OPTIONS header to Access-Control-Allow-Methods
regardless of whether there is an OPTIONS method matcher
* Adds tests for the listed behavior

* Add example for CORSMethodMiddleware

* Do not check for preflight and add documentation to the README

* Use http.MethodOptions instead of "OPTIONS"

* Add link to CORSMethodMiddleware section to readme

* Add test for unmatching route methods

* Rename CORS Method Middleware to Handling CORS Requests in README

* Link CORSMethodMiddleware in README to godoc

* Break CORSMethodMiddleware doc into bullets for readability

* Add comment about specifying OPTIONS to example in README for CORSMethodMiddleware

* Document cURL command used for testing CORS Method Middleware

* Update comment in example to "Handle the request"

* Add explicit comment about OPTIONS matchers to CORSMethodMiddleware doc

* Update circleci config to only check gofmt diff on latest go version

* Break up gofmt and go vet checks into separate steps.

* Use canonical circleci config
2019-06-29 13:52:29 -07:00
Matt Silverlock
d70f7b4baa Delete ISSUE_TEMPLATE.md (#492) 2019-06-29 11:01:59 -07:00
Franklin Harding
48f941fa99 Use subtests for middleware tests (#478)
* Use subtests for middleware tests
* Don't use subtests for MiddlewareAdd
2019-06-29 10:24:12 -07:00
Matt Silverlock
64954673e9 Delete .travis.yml (#490) 2019-06-28 16:07:30 -07:00
Franklin Harding
4248f5cd87 Fix nil panic in authentication middleware example (#489) 2019-06-28 08:33:07 -07:00
Matt Silverlock
212aa90d7c [WIP] Create CircleCI config (#484)
* [ci] Create CircleCI config
* Fix typos in container versions
* Add CircleCI badge
2019-06-24 09:05:39 -07:00
M@
ed099d4238 host:port matching does not require a :port to be specified.
In lieu of checking the template pattern on every Match request, a bool is added to the routeRegexp, and set
if the routeRegexp is a host AND there is no ":" in the template. I dislike extending the type, but I'd dislike
doing a string match on every single Match, even more.
2019-05-16 17:20:44 -07:00
sekky0905
c5c6c98bc2 [build] Remove sudo setting from travis.yml (#462) 2019-03-16 06:32:43 -07:00
Benjamin Boudreau
15a353a636 adding Router.Name to create new Route (#457) 2019-02-28 10:12:03 -08:00
Benjamin Boudreau
8eaa9f1309 fix go1.12 go vet usage (#458) 2019-02-28 09:36:07 -08:00
Souvik Haldar
8559a4f775 [docs] typo (#454) 2019-02-17 07:38:49 -08:00
moeryomenko
a7962380ca replace rr.HeaderMap by rr.Header() (#443) 2019-01-25 10:05:53 -06:00
Tim
797e653da6 Call WriteHeader after setting other header(s) in the example (#442)
From the docs: Changing the header map after a call to WriteHeader (or
Write) has no effect unless the modified headers are
trailers.
2019-01-25 05:41:49 -06:00
Gregor Weckbecker
08e7f807d3 Ignore ErrNotFound while matching Subrouters (#438)
MatchErr is set by the router to ErrNotFound if no route matches. If
no route of a Subrouter matches the error can by safely ignored. This
implementation only ignores these errors and does not ignore other
errors like ErrMethodMismatch.
2019-01-08 08:29:30 -06:00
santsai
f3ff42f93a getHost() now returns full host & port information (#383)
Previously, getHost only returned the host. As it now returns the
port as well, any .Host matches on a route will need to be updated
to also support matching on the port for cases where the port is
non default, eg: 80 for http or 443 for https.
2019-01-04 07:08:45 -08:00
tomare
ef912dd76e [bugfix] Clear matchErr when traversing subrouters.
Previously, when searching for a match, matchErr would be erroneously set, and prevent middleware from running (no match == no middleware runs).

This fix clears matchErr before traversing the next subrouter in a multi-subrouter router.
2018-12-27 16:42:16 -08:00
Raees
a31c1782bf Replace domain.com with example.com (#434)
Because domain.com is an actual business, example.com should be used for example purposes.
2018-12-25 08:41:17 -08:00
Michael Li
6137e193cd remove redundant code that remove support gorilla/context (#427)
* remove redundant code that remove support gorilla/context

* backward compatible for remove redundant code
2018-12-17 09:42:43 -05:00
Matt Silverlock
d2b5d13b92 Update and rename stale to stale.yml (#425) 2018-12-08 12:40:53 -08:00
Matt Silverlock
419fd9fe2a Add stalebot config (#424) 2018-12-07 08:41:48 -08:00
Joe Wilner
758eb64354 Improve subroute configuration propagation #422
* Pull out common shared `routeConf` so that config is pushed on to child
routers and routes.
* Removes obsolete usages of `parentRoute`
* Add tests defining compositional behavior
* Exercise `copyRouteConf` for posterity
2018-12-07 09:48:26 -06:00
kanozec
3d80bc801b Use subtests in mux_test.go (#415) 2018-10-30 08:25:28 -07:00
Nguyen Ngoc Trung (Steven)
521ea7b17d Use constant for 301 status code in regexp.go (#412) 2018-10-23 19:08:00 -07:00
Kamil Kisiel
deb579d6e0 README.md: Update site URL 2018-10-12 08:31:51 -07:00
Matt Silverlock
9e1f5955c0 Always run on the latest stable Go version. (#402)
Only run vet on the latest Go version.
2018-09-03 08:43:05 -07:00
Matt Silverlock
cf6680bc62 Create release-drafter.yml (#399) 2018-09-02 15:36:45 -07:00
Franklin Harding
8771f97498 Drop support for Go < 1.7: remove gorilla/context (#391)
* Drop support for Go < 1.7: remove gorilla/context
* Remove Go < 1.7 from Travis CI config
* Remove unneeded _native from context files
2018-09-02 15:22:40 -07:00
Shalom Yerushalmy
962c5bed07 Add 1.11 to build in travis (#398) 2018-08-30 07:23:24 -07:00
Kamil Kisiel
e48e440e4c Add test for multiple calls to Name().
Fixes #394
2018-08-07 00:52:56 -07:00
Kamil Kisiel
815b8c6a26 Clarify behaviour of Name method if called multiple times. 2018-08-07 00:50:18 -07:00
Matt Silverlock
cb4698366a Update LICENSE & AUTHORS files. (#386) 2018-06-05 14:15:56 -07:00
Jim Kalafut
e0b5abaaae Initialize user map (#371) 2018-05-26 15:17:21 -07:00
Matt Silverlock
c85619274f [deps] Add go.mod for versioned Go (#376) 2018-05-17 10:36:23 -07:00
22 changed files with 1407 additions and 570 deletions

87
.circleci/config.yml Normal file
View File

@@ -0,0 +1,87 @@
version: 2.0
jobs:
# Base test configuration for Go library tests Each distinct version should
# inherit this base, and override (at least) the container image used.
"test": &test
docker:
- image: circleci/golang:latest
working_directory: /go/src/github.com/gorilla/mux
steps: &steps
# Our build steps: we checkout the repo, fetch our deps, lint, and finally
# run "go test" on the package.
- checkout
# Logs the version in our build logs, for posterity
- run: go version
- run:
name: "Fetch dependencies"
command: >
go get -t -v ./...
# Only run gofmt, vet & lint against the latest Go version
- run:
name: "Run golint"
command: >
if [ "${LATEST}" = true ] && [ -z "${SKIP_GOLINT}" ]; then
go get -u golang.org/x/lint/golint
golint ./...
fi
- run:
name: "Run gofmt"
command: >
if [[ "${LATEST}" = true ]]; then
diff -u <(echo -n) <(gofmt -d -e .)
fi
- run:
name: "Run go vet"
command: >
if [[ "${LATEST}" = true ]]; then
go vet -v ./...
fi
- run: go test -v -race ./...
"latest":
<<: *test
environment:
LATEST: true
"1.12":
<<: *test
docker:
- image: circleci/golang:1.12
"1.11":
<<: *test
docker:
- image: circleci/golang:1.11
"1.10":
<<: *test
docker:
- image: circleci/golang:1.10
"1.9":
<<: *test
docker:
- image: circleci/golang:1.9
"1.8":
<<: *test
docker:
- image: circleci/golang:1.8
"1.7":
<<: *test
docker:
- image: circleci/golang:1.7
workflows:
version: 2
build:
jobs:
- "latest"
- "1.12"
- "1.11"
- "1.10"
- "1.9"
- "1.8"
- "1.7"

8
.github/release-drafter.yml vendored Normal file
View File

@@ -0,0 +1,8 @@
# Config for https://github.com/apps/release-drafter
template: |
<summary of changes here>
## CHANGELOG
$CHANGES

12
.github/stale.yml vendored Normal file
View File

@@ -0,0 +1,12 @@
daysUntilStale: 75
daysUntilClose: 14
# Issues with these labels will never be considered stale
exemptLabels:
- proposal
- needs review
- build system
staleLabel: stale
markComment: >
This issue has been automatically marked as stale because it hasn't seen
a recent update. It'll be automatically closed in a few days.
closeComment: false

View File

@@ -1,23 +0,0 @@
language: go
sudo: false
matrix:
include:
- go: 1.5.x
- go: 1.6.x
- go: 1.7.x
- go: 1.8.x
- go: 1.9.x
- go: 1.10.x
- go: tip
allow_failures:
- go: tip
install:
- # Skip
script:
- go get -t -v ./...
- diff -u <(echo -n) <(gofmt -d .)
- go tool vet .
- go test -v -race ./...

8
AUTHORS Normal file
View File

@@ -0,0 +1,8 @@
# This is the official list of gorilla/mux authors for copyright purposes.
#
# Please keep the list sorted.
Google LLC (https://opensource.google.com/)
Kamil Kisielk <kamil@kamilkisiel.net>
Matt Silverlock <matt@eatsleeprepeat.net>
Rodrigo Moraes (https://github.com/moraes)

View File

@@ -1,11 +0,0 @@
**What version of Go are you running?** (Paste the output of `go version`)
**What version of gorilla/mux are you at?** (Paste the output of `git rev-parse HEAD` inside `$GOPATH/src/github.com/gorilla/mux`)
**Describe your problem** (and what you have tried so far)
**Paste a minimal, runnable, reproduction of your issue below** (use backticks to format it)

View File

@@ -1,4 +1,4 @@
Copyright (c) 2012 Rodrigo Moraes. All rights reserved.
Copyright (c) 2012-2018 The Gorilla Authors. All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are

174
README.md
View File

@@ -1,12 +1,12 @@
# gorilla/mux
[![GoDoc](https://godoc.org/github.com/gorilla/mux?status.svg)](https://godoc.org/github.com/gorilla/mux)
[![Build Status](https://travis-ci.org/gorilla/mux.svg?branch=master)](https://travis-ci.org/gorilla/mux)
[![CircleCI](https://circleci.com/gh/gorilla/mux.svg?style=svg)](https://circleci.com/gh/gorilla/mux)
[![Sourcegraph](https://sourcegraph.com/github.com/gorilla/mux/-/badge.svg)](https://sourcegraph.com/github.com/gorilla/mux?badge)
![Gorilla Logo](http://www.gorillatoolkit.org/static/images/gorilla-icon-64.png)
http://www.gorillatoolkit.org/pkg/mux
https://www.gorillatoolkit.org/pkg/mux
Package `gorilla/mux` implements a request router and dispatcher for matching incoming requests to
their respective handler.
@@ -25,10 +25,12 @@ The name mux stands for "HTTP request multiplexer". Like the standard `http.Serv
* [Examples](#examples)
* [Matching Routes](#matching-routes)
* [Static Files](#static-files)
* [Serving Single Page Applications](#serving-single-page-applications) (e.g. React, Vue, Ember.js, etc.)
* [Registered URLs](#registered-urls)
* [Walking Routes](#walking-routes)
* [Graceful Shutdown](#graceful-shutdown)
* [Middleware](#middleware)
* [Handling CORS Requests](#handling-cors-requests)
* [Testing Handlers](#testing-handlers)
* [Full Example](#full-example)
@@ -88,7 +90,7 @@ r := mux.NewRouter()
// Only matches if domain is "www.example.com".
r.Host("www.example.com")
// Matches a dynamic subdomain.
r.Host("{subdomain:[a-z]+}.domain.com")
r.Host("{subdomain:[a-z]+}.example.com")
```
There are several other matchers that can be added. To match path prefixes:
@@ -210,6 +212,93 @@ func main() {
}
```
### Serving Single Page Applications
Most of the time it makes sense to serve your SPA on a separate web server from your API,
but sometimes it's desirable to serve them both from one place. It's possible to write a simple
handler for serving your SPA (for use with React Router's [BrowserRouter](https://reacttraining.com/react-router/web/api/BrowserRouter) for example), and leverage
mux's powerful routing for your API endpoints.
```go
package main
import (
"encoding/json"
"log"
"net/http"
"os"
"path/filepath"
"time"
"github.com/gorilla/mux"
)
// spaHandler implements the http.Handler interface, so we can use it
// to respond to HTTP requests. The path to the static directory and
// path to the index file within that static directory are used to
// serve the SPA in the given static directory.
type spaHandler struct {
staticPath string
indexPath string
}
// ServeHTTP inspects the URL path to locate a file within the static dir
// on the SPA handler. If a file is found, it will be served. If not, the
// file located at the index path on the SPA handler will be served. This
// is suitable behavior for serving an SPA (single page application).
func (h spaHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// get the absolute path to prevent directory traversal
path, err := filepath.Abs(r.URL.Path)
if err != nil {
// if we failed to get the absolute path respond with a 400 bad request
// and stop
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
// prepend the path with the path to the static directory
path = filepath.Join(h.staticPath, path)
// check whether a file exists at the given path
_, err = os.Stat(path)
if os.IsNotExist(err) {
// file does not exist, serve index.html
http.ServeFile(w, r, filepath.Join(h.staticPath, h.indexPath))
return
} else if err != nil {
// if we got an error (that wasn't that the file doesn't exist) stating the
// file, return a 500 internal server error and stop
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
// otherwise, use http.FileServer to serve the static dir
http.FileServer(http.Dir(h.staticPath)).ServeHTTP(w, r)
}
func main() {
router := mux.NewRouter()
router.HandleFunc("/api/health", func(w http.ResponseWriter, r *http.Request) {
// an example API handler
json.NewEncoder(w).Encode(map[string]bool{"ok": true})
})
spa := spaHandler{staticPath: "build", indexPath: "index.html"}
router.PathPrefix("/").Handler(spa)
srv := &http.Server{
Handler: router,
Addr: "127.0.0.1:8000",
// Good practice: enforce timeouts for servers you create!
WriteTimeout: 15 * time.Second,
ReadTimeout: 15 * time.Second,
}
log.Fatal(srv.ListenAndServe())
}
```
### Registered URLs
Now let's see how to build registered URLs.
@@ -238,13 +327,13 @@ This also works for host and query value variables:
```go
r := mux.NewRouter()
r.Host("{subdomain}.domain.com").
r.Host("{subdomain}.example.com").
Path("/articles/{category}/{id:[0-9]+}").
Queries("filter", "{filter}").
HandlerFunc(ArticleHandler).
Name("article")
// url.String() will be "http://news.domain.com/articles/technology/42?filter=gorilla"
// url.String() will be "http://news.example.com/articles/technology/42?filter=gorilla"
url, err := r.Get("article").URL("subdomain", "news",
"category", "technology",
"id", "42",
@@ -264,7 +353,7 @@ r.HeadersRegexp("Content-Type", "application/(text|json)")
There's also a way to build only the URL host or path for a route: use the methods `URLHost()` or `URLPath()` instead. For the previous route, we would do:
```go
// "http://news.domain.com/"
// "http://news.example.com/"
host, err := r.Get("article").URLHost("subdomain", "news")
// "/articles/technology/42"
@@ -275,12 +364,12 @@ And if you use subrouters, host and path defined separately can be built as well
```go
r := mux.NewRouter()
s := r.Host("{subdomain}.domain.com").Subrouter()
s := r.Host("{subdomain}.example.com").Subrouter()
s.Path("/articles/{category}/{id:[0-9]+}").
HandlerFunc(ArticleHandler).
Name("article")
// "http://news.domain.com/articles/technology/42"
// "http://news.example.com/articles/technology/42"
url, err := r.Get("article").URL("subdomain", "news",
"category", "technology",
"id", "42")
@@ -491,6 +580,73 @@ r.Use(amw.Middleware)
Note: The handler chain will be stopped if your middleware doesn't call `next.ServeHTTP()` with the corresponding parameters. This can be used to abort a request if the middleware writer wants to. Middlewares _should_ write to `ResponseWriter` if they _are_ going to terminate the request, and they _should not_ write to `ResponseWriter` if they _are not_ going to terminate it.
### Handling CORS Requests
[CORSMethodMiddleware](https://godoc.org/github.com/gorilla/mux#CORSMethodMiddleware) intends to make it easier to strictly set the `Access-Control-Allow-Methods` response header.
* You will still need to use your own CORS handler to set the other CORS headers such as `Access-Control-Allow-Origin`
* The middleware will set the `Access-Control-Allow-Methods` header to all the method matchers (e.g. `r.Methods(http.MethodGet, http.MethodPut, http.MethodOptions)` -> `Access-Control-Allow-Methods: GET,PUT,OPTIONS`) on a route
* If you do not specify any methods, then:
> _Important_: there must be an `OPTIONS` method matcher for the middleware to set the headers.
Here is an example of using `CORSMethodMiddleware` along with a custom `OPTIONS` handler to set all the required CORS headers:
```go
package main
import (
"net/http"
"github.com/gorilla/mux"
)
func main() {
r := mux.NewRouter()
// IMPORTANT: you must specify an OPTIONS method matcher for the middleware to set CORS headers
r.HandleFunc("/foo", fooHandler).Methods(http.MethodGet, http.MethodPut, http.MethodPatch, http.MethodOptions)
r.Use(mux.CORSMethodMiddleware(r))
http.ListenAndServe(":8080", r)
}
func fooHandler(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Access-Control-Allow-Origin", "*")
if r.Method == http.MethodOptions {
return
}
w.Write([]byte("foo"))
}
```
And an request to `/foo` using something like:
```bash
curl localhost:8080/foo -v
```
Would look like:
```bash
* Trying ::1...
* TCP_NODELAY set
* Connected to localhost (::1) port 8080 (#0)
> GET /foo HTTP/1.1
> Host: localhost:8080
> User-Agent: curl/7.59.0
> Accept: */*
>
< HTTP/1.1 200 OK
< Access-Control-Allow-Methods: GET,PUT,PATCH,OPTIONS
< Access-Control-Allow-Origin: *
< Date: Fri, 28 Jun 2019 20:13:30 GMT
< Content-Length: 3
< Content-Type: text/plain; charset=utf-8
<
* Connection #0 to host localhost left intact
foo
```
### Testing Handlers
Testing handlers in a Go web application is straightforward, and _mux_ doesn't complicate this any further. Given two files: `endpoints.go` and `endpoints_test.go`, here's how we'd test an application using _mux_.
@@ -503,8 +659,8 @@ package main
func HealthCheckHandler(w http.ResponseWriter, r *http.Request) {
// A very simple health check.
w.WriteHeader(http.StatusOK)
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
// In the future we could report back on the status of our DB, or our cache
// (e.g. Redis) by performing a simple PING, and include them in the response.

View File

@@ -1,5 +1,3 @@
// +build go1.7
package mux
import (
@@ -18,7 +16,3 @@ func contextSet(r *http.Request, key, val interface{}) *http.Request {
return r.WithContext(context.WithValue(r.Context(), key, val))
}
func contextClear(r *http.Request) {
return
}

View File

@@ -1,26 +0,0 @@
// +build !go1.7
package mux
import (
"net/http"
"github.com/gorilla/context"
)
func contextGet(r *http.Request, key interface{}) interface{} {
return context.Get(r, key)
}
func contextSet(r *http.Request, key, val interface{}) *http.Request {
if val == nil {
return r
}
context.Set(r, key, val)
return r
}
func contextClear(r *http.Request) {
context.Clear(r)
}

View File

@@ -1,40 +0,0 @@
// +build !go1.7
package mux
import (
"net/http"
"testing"
"github.com/gorilla/context"
)
// Tests that the context is cleared or not cleared properly depending on
// the configuration of the router
func TestKeepContext(t *testing.T) {
func1 := func(w http.ResponseWriter, r *http.Request) {}
r := NewRouter()
r.HandleFunc("/", func1).Name("func1")
req, _ := http.NewRequest("GET", "http://localhost/", nil)
context.Set(req, "t", 1)
res := new(http.ResponseWriter)
r.ServeHTTP(*res, req)
if _, ok := context.GetOk(req, "t"); ok {
t.Error("Context should have been cleared at end of request")
}
r.KeepContext = true
req, _ = http.NewRequest("GET", "http://localhost/", nil)
context.Set(req, "t", 1)
r.ServeHTTP(*res, req)
if _, ok := context.GetOk(req, "t"); !ok {
t.Error("Context should NOT have been cleared at end of request")
}
}

View File

@@ -1,5 +1,3 @@
// +build go1.7
package mux
import (

2
doc.go
View File

@@ -295,7 +295,7 @@ A more complex authentication middleware, which maps session token to users, cou
r := mux.NewRouter()
r.HandleFunc("/", handler)
amw := authenticationMiddleware{}
amw := authenticationMiddleware{tokenUsers: make(map[string]string)}
amw.Populate()
r.Use(amw.Middleware)

View File

@@ -40,7 +40,7 @@ func Example_authenticationMiddleware() {
r.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
// Do something here
})
amw := authenticationMiddleware{}
amw := authenticationMiddleware{make(map[string]string)}
amw.Populate()
r.Use(amw.Middleware)
}

View File

@@ -0,0 +1,37 @@
package mux_test
import (
"fmt"
"net/http"
"net/http/httptest"
"github.com/gorilla/mux"
)
func ExampleCORSMethodMiddleware() {
r := mux.NewRouter()
r.HandleFunc("/foo", func(w http.ResponseWriter, r *http.Request) {
// Handle the request
}).Methods(http.MethodGet, http.MethodPut, http.MethodPatch)
r.HandleFunc("/foo", func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Access-Control-Allow-Origin", "http://example.com")
w.Header().Set("Access-Control-Max-Age", "86400")
}).Methods(http.MethodOptions)
r.Use(mux.CORSMethodMiddleware(r))
rw := httptest.NewRecorder()
req, _ := http.NewRequest("OPTIONS", "/foo", nil) // needs to be OPTIONS
req.Header.Set("Access-Control-Request-Method", "POST") // needs to be non-empty
req.Header.Set("Access-Control-Request-Headers", "Authorization") // needs to be non-empty
req.Header.Set("Origin", "http://example.com") // needs to be non-empty
r.ServeHTTP(rw, req)
fmt.Println(rw.Header().Get("Access-Control-Allow-Methods"))
fmt.Println(rw.Header().Get("Access-Control-Allow-Origin"))
// Output:
// GET,PUT,PATCH,OPTIONS
// http://example.com
}

3
go.mod Normal file
View File

@@ -0,0 +1,3 @@
module github.com/gorilla/mux
go 1.12

View File

@@ -32,37 +32,19 @@ func (r *Router) useInterface(mw middleware) {
r.middlewares = append(r.middlewares, mw)
}
// CORSMethodMiddleware sets the Access-Control-Allow-Methods response header
// on a request, by matching routes based only on paths. It also handles
// OPTIONS requests, by settings Access-Control-Allow-Methods, and then
// returning without calling the next http handler.
// CORSMethodMiddleware automatically sets the Access-Control-Allow-Methods response header
// on requests for routes that have an OPTIONS method matcher to all the method matchers on
// the route. Routes that do not explicitly handle OPTIONS requests will not be processed
// by the middleware. See examples for usage.
func CORSMethodMiddleware(r *Router) MiddlewareFunc {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
var allMethods []string
err := r.Walk(func(route *Route, _ *Router, _ []*Route) error {
for _, m := range route.matchers {
if _, ok := m.(*routeRegexp); ok {
if m.Match(req, &RouteMatch{}) {
methods, err := route.GetMethods()
if err != nil {
return err
}
allMethods = append(allMethods, methods...)
}
break
}
}
return nil
})
allMethods, err := getAllMethodsForRoute(r, req)
if err == nil {
w.Header().Set("Access-Control-Allow-Methods", strings.Join(append(allMethods, "OPTIONS"), ","))
if req.Method == "OPTIONS" {
return
for _, v := range allMethods {
if v == http.MethodOptions {
w.Header().Set("Access-Control-Allow-Methods", strings.Join(allMethods, ","))
}
}
}
@@ -70,3 +52,28 @@ func CORSMethodMiddleware(r *Router) MiddlewareFunc {
})
}
}
// getAllMethodsForRoute returns all the methods from method matchers matching a given
// request.
func getAllMethodsForRoute(r *Router, req *http.Request) ([]string, error) {
var allMethods []string
err := r.Walk(func(route *Route, _ *Router, _ []*Route) error {
for _, m := range route.matchers {
if _, ok := m.(*routeRegexp); ok {
if m.Match(req, &RouteMatch{}) {
methods, err := route.GetMethods()
if err != nil {
return err
}
allMethods = append(allMethods, methods...)
}
break
}
}
return nil
})
return allMethods, err
}

View File

@@ -3,7 +3,6 @@ package mux
import (
"bytes"
"net/http"
"net/http/httptest"
"testing"
)
@@ -28,12 +27,12 @@ func TestMiddlewareAdd(t *testing.T) {
router.useInterface(mw)
if len(router.middlewares) != 1 || router.middlewares[0] != mw {
t.Fatal("Middleware was not added correctly")
t.Fatal("Middleware interface was not added correctly")
}
router.Use(mw.Middleware)
if len(router.middlewares) != 2 {
t.Fatal("MiddlewareFunc method was not added correctly")
t.Fatal("Middleware method was not added correctly")
}
banalMw := func(handler http.Handler) http.Handler {
@@ -41,7 +40,7 @@ func TestMiddlewareAdd(t *testing.T) {
}
router.Use(banalMw)
if len(router.middlewares) != 3 {
t.Fatal("MiddlewareFunc method was not added correctly")
t.Fatal("Middleware function was not added correctly")
}
}
@@ -55,34 +54,37 @@ func TestMiddleware(t *testing.T) {
rw := NewRecorder()
req := newRequest("GET", "/")
// Test regular middleware call
router.ServeHTTP(rw, req)
if mw.timesCalled != 1 {
t.Fatalf("Expected %d calls, but got only %d", 1, mw.timesCalled)
}
t.Run("regular middleware call", func(t *testing.T) {
router.ServeHTTP(rw, req)
if mw.timesCalled != 1 {
t.Fatalf("Expected %d calls, but got only %d", 1, mw.timesCalled)
}
})
// Middleware should not be called for 404
req = newRequest("GET", "/not/found")
router.ServeHTTP(rw, req)
if mw.timesCalled != 1 {
t.Fatalf("Expected %d calls, but got only %d", 1, mw.timesCalled)
}
t.Run("not called for 404", func(t *testing.T) {
req = newRequest("GET", "/not/found")
router.ServeHTTP(rw, req)
if mw.timesCalled != 1 {
t.Fatalf("Expected %d calls, but got only %d", 1, mw.timesCalled)
}
})
// Middleware should not be called if there is a method mismatch
req = newRequest("POST", "/")
router.ServeHTTP(rw, req)
if mw.timesCalled != 1 {
t.Fatalf("Expected %d calls, but got only %d", 1, mw.timesCalled)
}
// Add the middleware again as function
router.Use(mw.Middleware)
req = newRequest("GET", "/")
router.ServeHTTP(rw, req)
if mw.timesCalled != 3 {
t.Fatalf("Expected %d calls, but got only %d", 3, mw.timesCalled)
}
t.Run("not called for method mismatch", func(t *testing.T) {
req = newRequest("POST", "/")
router.ServeHTTP(rw, req)
if mw.timesCalled != 1 {
t.Fatalf("Expected %d calls, but got only %d", 1, mw.timesCalled)
}
})
t.Run("regular call using function middleware", func(t *testing.T) {
router.Use(mw.Middleware)
req = newRequest("GET", "/")
router.ServeHTTP(rw, req)
if mw.timesCalled != 3 {
t.Fatalf("Expected %d calls, but got only %d", 3, mw.timesCalled)
}
})
}
func TestMiddlewareSubrouter(t *testing.T) {
@@ -98,42 +100,56 @@ func TestMiddlewareSubrouter(t *testing.T) {
rw := NewRecorder()
req := newRequest("GET", "/")
router.ServeHTTP(rw, req)
if mw.timesCalled != 0 {
t.Fatalf("Expected %d calls, but got only %d", 0, mw.timesCalled)
}
t.Run("not called for route outside subrouter", func(t *testing.T) {
router.ServeHTTP(rw, req)
if mw.timesCalled != 0 {
t.Fatalf("Expected %d calls, but got only %d", 0, mw.timesCalled)
}
})
req = newRequest("GET", "/sub/")
router.ServeHTTP(rw, req)
if mw.timesCalled != 0 {
t.Fatalf("Expected %d calls, but got only %d", 0, mw.timesCalled)
}
t.Run("not called for subrouter root 404", func(t *testing.T) {
req = newRequest("GET", "/sub/")
router.ServeHTTP(rw, req)
if mw.timesCalled != 0 {
t.Fatalf("Expected %d calls, but got only %d", 0, mw.timesCalled)
}
})
req = newRequest("GET", "/sub/x")
router.ServeHTTP(rw, req)
if mw.timesCalled != 1 {
t.Fatalf("Expected %d calls, but got only %d", 1, mw.timesCalled)
}
t.Run("called once for route inside subrouter", func(t *testing.T) {
req = newRequest("GET", "/sub/x")
router.ServeHTTP(rw, req)
if mw.timesCalled != 1 {
t.Fatalf("Expected %d calls, but got only %d", 1, mw.timesCalled)
}
})
req = newRequest("GET", "/sub/not/found")
router.ServeHTTP(rw, req)
if mw.timesCalled != 1 {
t.Fatalf("Expected %d calls, but got only %d", 1, mw.timesCalled)
}
t.Run("not called for 404 inside subrouter", func(t *testing.T) {
req = newRequest("GET", "/sub/not/found")
router.ServeHTTP(rw, req)
if mw.timesCalled != 1 {
t.Fatalf("Expected %d calls, but got only %d", 1, mw.timesCalled)
}
})
router.useInterface(mw)
t.Run("middleware added to router", func(t *testing.T) {
router.useInterface(mw)
req = newRequest("GET", "/")
router.ServeHTTP(rw, req)
if mw.timesCalled != 2 {
t.Fatalf("Expected %d calls, but got only %d", 2, mw.timesCalled)
}
t.Run("called once for route outside subrouter", func(t *testing.T) {
req = newRequest("GET", "/")
router.ServeHTTP(rw, req)
if mw.timesCalled != 2 {
t.Fatalf("Expected %d calls, but got only %d", 2, mw.timesCalled)
}
})
req = newRequest("GET", "/sub/x")
router.ServeHTTP(rw, req)
if mw.timesCalled != 4 {
t.Fatalf("Expected %d calls, but got only %d", 4, mw.timesCalled)
}
t.Run("called twice for route inside subrouter", func(t *testing.T) {
req = newRequest("GET", "/sub/x")
router.ServeHTTP(rw, req)
if mw.timesCalled != 4 {
t.Fatalf("Expected %d calls, but got only %d", 4, mw.timesCalled)
}
})
})
}
func TestMiddlewareExecution(t *testing.T) {
@@ -145,30 +161,33 @@ func TestMiddlewareExecution(t *testing.T) {
w.Write(handlerStr)
})
rw := NewRecorder()
req := newRequest("GET", "/")
t.Run("responds normally without middleware", func(t *testing.T) {
rw := NewRecorder()
req := newRequest("GET", "/")
// Test handler-only call
router.ServeHTTP(rw, req)
router.ServeHTTP(rw, req)
if bytes.Compare(rw.Body.Bytes(), handlerStr) != 0 {
t.Fatal("Handler response is not what it should be")
}
// Test middleware call
rw = NewRecorder()
router.Use(func(h http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write(mwStr)
h.ServeHTTP(w, r)
})
if !bytes.Equal(rw.Body.Bytes(), handlerStr) {
t.Fatal("Handler response is not what it should be")
}
})
router.ServeHTTP(rw, req)
if bytes.Compare(rw.Body.Bytes(), append(mwStr, handlerStr...)) != 0 {
t.Fatal("Middleware + handler response is not what it should be")
}
t.Run("responds with handler and middleware response", func(t *testing.T) {
rw := NewRecorder()
req := newRequest("GET", "/")
router.Use(func(h http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write(mwStr)
h.ServeHTTP(w, r)
})
})
router.ServeHTTP(rw, req)
if !bytes.Equal(rw.Body.Bytes(), append(mwStr, handlerStr...)) {
t.Fatal("Middleware + handler response is not what it should be")
}
})
}
func TestMiddlewareNotFound(t *testing.T) {
@@ -187,26 +206,29 @@ func TestMiddlewareNotFound(t *testing.T) {
})
// Test not found call with default handler
rw := NewRecorder()
req := newRequest("GET", "/notfound")
t.Run("not called", func(t *testing.T) {
rw := NewRecorder()
req := newRequest("GET", "/notfound")
router.ServeHTTP(rw, req)
if bytes.Contains(rw.Body.Bytes(), mwStr) {
t.Fatal("Middleware was called for a 404")
}
// Test not found call with custom handler
rw = NewRecorder()
req = newRequest("GET", "/notfound")
router.NotFoundHandler = http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
rw.Write([]byte("Custom 404 handler"))
router.ServeHTTP(rw, req)
if bytes.Contains(rw.Body.Bytes(), mwStr) {
t.Fatal("Middleware was called for a 404")
}
})
router.ServeHTTP(rw, req)
if bytes.Contains(rw.Body.Bytes(), mwStr) {
t.Fatal("Middleware was called for a custom 404")
}
t.Run("not called with custom not found handler", func(t *testing.T) {
rw := NewRecorder()
req := newRequest("GET", "/notfound")
router.NotFoundHandler = http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
rw.Write([]byte("Custom 404 handler"))
})
router.ServeHTTP(rw, req)
if bytes.Contains(rw.Body.Bytes(), mwStr) {
t.Fatal("Middleware was called for a custom 404")
}
})
}
func TestMiddlewareMethodMismatch(t *testing.T) {
@@ -225,27 +247,29 @@ func TestMiddlewareMethodMismatch(t *testing.T) {
})
})
// Test method mismatch
rw := NewRecorder()
req := newRequest("POST", "/")
t.Run("not called", func(t *testing.T) {
rw := NewRecorder()
req := newRequest("POST", "/")
router.ServeHTTP(rw, req)
if bytes.Contains(rw.Body.Bytes(), mwStr) {
t.Fatal("Middleware was called for a method mismatch")
}
// Test not found call
rw = NewRecorder()
req = newRequest("POST", "/")
router.MethodNotAllowedHandler = http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
rw.Write([]byte("Method not allowed"))
router.ServeHTTP(rw, req)
if bytes.Contains(rw.Body.Bytes(), mwStr) {
t.Fatal("Middleware was called for a method mismatch")
}
})
router.ServeHTTP(rw, req)
if bytes.Contains(rw.Body.Bytes(), mwStr) {
t.Fatal("Middleware was called for a method mismatch")
}
t.Run("not called with custom method not allowed handler", func(t *testing.T) {
rw := NewRecorder()
req := newRequest("POST", "/")
router.MethodNotAllowedHandler = http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
rw.Write([]byte("Method not allowed"))
})
router.ServeHTTP(rw, req)
if bytes.Contains(rw.Body.Bytes(), mwStr) {
t.Fatal("Middleware was called for a method mismatch")
}
})
}
func TestMiddlewareNotFoundSubrouter(t *testing.T) {
@@ -269,27 +293,29 @@ func TestMiddlewareNotFoundSubrouter(t *testing.T) {
})
})
// Test not found call for default handler
rw := NewRecorder()
req := newRequest("GET", "/sub/notfound")
t.Run("not called", func(t *testing.T) {
rw := NewRecorder()
req := newRequest("GET", "/sub/notfound")
router.ServeHTTP(rw, req)
if bytes.Contains(rw.Body.Bytes(), mwStr) {
t.Fatal("Middleware was called for a 404")
}
// Test not found call with custom handler
rw = NewRecorder()
req = newRequest("GET", "/sub/notfound")
subrouter.NotFoundHandler = http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
rw.Write([]byte("Custom 404 handler"))
router.ServeHTTP(rw, req)
if bytes.Contains(rw.Body.Bytes(), mwStr) {
t.Fatal("Middleware was called for a 404")
}
})
router.ServeHTTP(rw, req)
if bytes.Contains(rw.Body.Bytes(), mwStr) {
t.Fatal("Middleware was called for a custom 404")
}
t.Run("not called with custom not found handler", func(t *testing.T) {
rw := NewRecorder()
req := newRequest("GET", "/sub/notfound")
subrouter.NotFoundHandler = http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
rw.Write([]byte("Custom 404 handler"))
})
router.ServeHTTP(rw, req)
if bytes.Contains(rw.Body.Bytes(), mwStr) {
t.Fatal("Middleware was called for a custom 404")
}
})
}
func TestMiddlewareMethodMismatchSubrouter(t *testing.T) {
@@ -313,65 +339,207 @@ func TestMiddlewareMethodMismatchSubrouter(t *testing.T) {
})
})
// Test method mismatch without custom handler
rw := NewRecorder()
req := newRequest("POST", "/sub/")
t.Run("not called", func(t *testing.T) {
rw := NewRecorder()
req := newRequest("POST", "/sub/")
router.ServeHTTP(rw, req)
if bytes.Contains(rw.Body.Bytes(), mwStr) {
t.Fatal("Middleware was called for a method mismatch")
}
// Test method mismatch with custom handler
rw = NewRecorder()
req = newRequest("POST", "/sub/")
router.MethodNotAllowedHandler = http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
rw.Write([]byte("Method not allowed"))
router.ServeHTTP(rw, req)
if bytes.Contains(rw.Body.Bytes(), mwStr) {
t.Fatal("Middleware was called for a method mismatch")
}
})
router.ServeHTTP(rw, req)
if bytes.Contains(rw.Body.Bytes(), mwStr) {
t.Fatal("Middleware was called for a method mismatch")
}
t.Run("not called with custom method not allowed handler", func(t *testing.T) {
rw := NewRecorder()
req := newRequest("POST", "/sub/")
router.MethodNotAllowedHandler = http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
rw.Write([]byte("Method not allowed"))
})
router.ServeHTTP(rw, req)
if bytes.Contains(rw.Body.Bytes(), mwStr) {
t.Fatal("Middleware was called for a method mismatch")
}
})
}
func TestCORSMethodMiddleware(t *testing.T) {
router := NewRouter()
cases := []struct {
path string
response string
method string
testURL string
expectedAllowedMethods string
testCases := []struct {
name string
registerRoutes func(r *Router)
requestHeader http.Header
requestMethod string
requestPath string
expectedAccessControlAllowMethodsHeader string
expectedResponse string
}{
{"/g/{o}", "a", "POST", "/g/asdf", "POST,PUT,GET,OPTIONS"},
{"/g/{o}", "b", "PUT", "/g/bla", "POST,PUT,GET,OPTIONS"},
{"/g/{o}", "c", "GET", "/g/orilla", "POST,PUT,GET,OPTIONS"},
{"/g", "d", "POST", "/g", "POST,OPTIONS"},
{
name: "does not set without OPTIONS matcher",
registerRoutes: func(r *Router) {
r.HandleFunc("/foo", stringHandler("a")).Methods(http.MethodGet, http.MethodPut, http.MethodPatch)
},
requestMethod: "GET",
requestPath: "/foo",
expectedAccessControlAllowMethodsHeader: "",
expectedResponse: "a",
},
{
name: "sets on non OPTIONS",
registerRoutes: func(r *Router) {
r.HandleFunc("/foo", stringHandler("a")).Methods(http.MethodGet, http.MethodPut, http.MethodPatch)
r.HandleFunc("/foo", stringHandler("b")).Methods(http.MethodOptions)
},
requestMethod: "GET",
requestPath: "/foo",
expectedAccessControlAllowMethodsHeader: "GET,PUT,PATCH,OPTIONS",
expectedResponse: "a",
},
{
name: "sets without preflight headers",
registerRoutes: func(r *Router) {
r.HandleFunc("/foo", stringHandler("a")).Methods(http.MethodGet, http.MethodPut, http.MethodPatch)
r.HandleFunc("/foo", stringHandler("b")).Methods(http.MethodOptions)
},
requestMethod: "OPTIONS",
requestPath: "/foo",
expectedAccessControlAllowMethodsHeader: "GET,PUT,PATCH,OPTIONS",
expectedResponse: "b",
},
{
name: "does not set on error",
registerRoutes: func(r *Router) {
r.HandleFunc("/foo", stringHandler("a"))
},
requestMethod: "OPTIONS",
requestPath: "/foo",
expectedAccessControlAllowMethodsHeader: "",
expectedResponse: "a",
},
{
name: "sets header on valid preflight",
registerRoutes: func(r *Router) {
r.HandleFunc("/foo", stringHandler("a")).Methods(http.MethodGet, http.MethodPut, http.MethodPatch)
r.HandleFunc("/foo", stringHandler("b")).Methods(http.MethodOptions)
},
requestMethod: "OPTIONS",
requestPath: "/foo",
requestHeader: http.Header{
"Access-Control-Request-Method": []string{"GET"},
"Access-Control-Request-Headers": []string{"Authorization"},
"Origin": []string{"http://example.com"},
},
expectedAccessControlAllowMethodsHeader: "GET,PUT,PATCH,OPTIONS",
expectedResponse: "b",
},
{
name: "does not set methods from unmatching routes",
registerRoutes: func(r *Router) {
r.HandleFunc("/foo", stringHandler("c")).Methods(http.MethodDelete)
r.HandleFunc("/foo/bar", stringHandler("a")).Methods(http.MethodGet, http.MethodPut, http.MethodPatch)
r.HandleFunc("/foo/bar", stringHandler("b")).Methods(http.MethodOptions)
},
requestMethod: "OPTIONS",
requestPath: "/foo/bar",
requestHeader: http.Header{
"Access-Control-Request-Method": []string{"GET"},
"Access-Control-Request-Headers": []string{"Authorization"},
"Origin": []string{"http://example.com"},
},
expectedAccessControlAllowMethodsHeader: "GET,PUT,PATCH,OPTIONS",
expectedResponse: "b",
},
}
for _, tt := range cases {
router.HandleFunc(tt.path, stringHandler(tt.response)).Methods(tt.method)
}
for _, tt := range testCases {
t.Run(tt.name, func(t *testing.T) {
router := NewRouter()
router.Use(CORSMethodMiddleware(router))
tt.registerRoutes(router)
for _, tt := range cases {
rr := httptest.NewRecorder()
req := newRequest(tt.method, tt.testURL)
router.Use(CORSMethodMiddleware(router))
router.ServeHTTP(rr, req)
rw := NewRecorder()
req := newRequest(tt.requestMethod, tt.requestPath)
req.Header = tt.requestHeader
if rr.Body.String() != tt.response {
t.Errorf("Expected body '%s', found '%s'", tt.response, rr.Body.String())
}
router.ServeHTTP(rw, req)
allowedMethods := rr.HeaderMap.Get("Access-Control-Allow-Methods")
actualMethodsHeader := rw.Header().Get("Access-Control-Allow-Methods")
if actualMethodsHeader != tt.expectedAccessControlAllowMethodsHeader {
t.Fatalf("Expected Access-Control-Allow-Methods to equal %s but got %s", tt.expectedAccessControlAllowMethodsHeader, actualMethodsHeader)
}
if allowedMethods != tt.expectedAllowedMethods {
t.Errorf("Expected Access-Control-Allow-Methods '%s', found '%s'", tt.expectedAllowedMethods, allowedMethods)
}
actualResponse := rw.Body.String()
if actualResponse != tt.expectedResponse {
t.Fatalf("Expected response to equal %s but got %s", tt.expectedResponse, actualResponse)
}
})
}
}
func TestMiddlewareOnMultiSubrouter(t *testing.T) {
first := "first"
second := "second"
notFound := "404 not found"
router := NewRouter()
firstSubRouter := router.PathPrefix("/").Subrouter()
secondSubRouter := router.PathPrefix("/").Subrouter()
router.NotFoundHandler = http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
rw.Write([]byte(notFound))
})
firstSubRouter.HandleFunc("/first", func(w http.ResponseWriter, r *http.Request) {
})
secondSubRouter.HandleFunc("/second", func(w http.ResponseWriter, r *http.Request) {
})
firstSubRouter.Use(func(h http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte(first))
h.ServeHTTP(w, r)
})
})
secondSubRouter.Use(func(h http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte(second))
h.ServeHTTP(w, r)
})
})
t.Run("/first uses first middleware", func(t *testing.T) {
rw := NewRecorder()
req := newRequest("GET", "/first")
router.ServeHTTP(rw, req)
if rw.Body.String() != first {
t.Fatalf("Middleware did not run: expected %s middleware to write a response (got %s)", first, rw.Body.String())
}
})
t.Run("/second uses second middleware", func(t *testing.T) {
rw := NewRecorder()
req := newRequest("GET", "/second")
router.ServeHTTP(rw, req)
if rw.Body.String() != second {
t.Fatalf("Middleware did not run: expected %s middleware to write a response (got %s)", second, rw.Body.String())
}
})
t.Run("uses not found handler", func(t *testing.T) {
rw := NewRecorder()
req := newRequest("GET", "/second/not-exist")
router.ServeHTTP(rw, req)
if rw.Body.String() != notFound {
t.Fatalf("Notfound handler did not run: expected %s for not-exist, (got %s)", notFound, rw.Body.String())
}
})
}

127
mux.go
View File

@@ -22,7 +22,7 @@ var (
// NewRouter returns a new router instance.
func NewRouter() *Router {
return &Router{namedRoutes: make(map[string]*Route), KeepContext: false}
return &Router{namedRoutes: make(map[string]*Route)}
}
// Router registers routes to be matched and dispatches a handler.
@@ -50,24 +50,76 @@ type Router struct {
// Configurable Handler to be used when the request method does not match the route.
MethodNotAllowedHandler http.Handler
// Parent route, if this is a subrouter.
parent parentRoute
// Routes to be matched, in order.
routes []*Route
// Routes by name for URL building.
namedRoutes map[string]*Route
// See Router.StrictSlash(). This defines the flag for new routes.
strictSlash bool
// See Router.SkipClean(). This defines the flag for new routes.
skipClean bool
// If true, do not clear the request context after handling the request.
// This has no effect when go1.7+ is used, since the context is stored
//
// Deprecated: No effect when go1.7+ is used, since the context is stored
// on the request itself.
KeepContext bool
// see Router.UseEncodedPath(). This defines a flag for all routes.
useEncodedPath bool
// Slice of middlewares to be called after a match is found
middlewares []middleware
// configuration shared with `Route`
routeConf
}
// common route configuration shared between `Router` and `Route`
type routeConf struct {
// If true, "/path/foo%2Fbar/to" will match the path "/path/{var}/to"
useEncodedPath bool
// If true, when the path pattern is "/path/", accessing "/path" will
// redirect to the former and vice versa.
strictSlash bool
// If true, when the path pattern is "/path//to", accessing "/path//to"
// will not redirect
skipClean bool
// Manager for the variables from host and path.
regexp routeRegexpGroup
// List of matchers.
matchers []matcher
// The scheme used when building URLs.
buildScheme string
buildVarsFunc BuildVarsFunc
}
// returns an effective deep copy of `routeConf`
func copyRouteConf(r routeConf) routeConf {
c := r
if r.regexp.path != nil {
c.regexp.path = copyRouteRegexp(r.regexp.path)
}
if r.regexp.host != nil {
c.regexp.host = copyRouteRegexp(r.regexp.host)
}
c.regexp.queries = make([]*routeRegexp, 0, len(r.regexp.queries))
for _, q := range r.regexp.queries {
c.regexp.queries = append(c.regexp.queries, copyRouteRegexp(q))
}
c.matchers = make([]matcher, len(r.matchers))
copy(c.matchers, r.matchers)
return c
}
func copyRouteRegexp(r *routeRegexp) *routeRegexp {
c := *r
return &c
}
// Match attempts to match the given request against the router's registered routes.
@@ -155,22 +207,18 @@ func (r *Router) ServeHTTP(w http.ResponseWriter, req *http.Request) {
handler = http.NotFoundHandler()
}
if !r.KeepContext {
defer contextClear(req)
}
handler.ServeHTTP(w, req)
}
// Get returns a route registered with the given name.
func (r *Router) Get(name string) *Route {
return r.getNamedRoutes()[name]
return r.namedRoutes[name]
}
// GetRoute returns a route registered with the given name. This method
// was renamed to Get() and remains here for backwards compatibility.
func (r *Router) GetRoute(name string) *Route {
return r.getNamedRoutes()[name]
return r.namedRoutes[name]
}
// StrictSlash defines the trailing slash behavior for new routes. The initial
@@ -221,55 +269,24 @@ func (r *Router) UseEncodedPath() *Router {
return r
}
// ----------------------------------------------------------------------------
// parentRoute
// ----------------------------------------------------------------------------
func (r *Router) getBuildScheme() string {
if r.parent != nil {
return r.parent.getBuildScheme()
}
return ""
}
// getNamedRoutes returns the map where named routes are registered.
func (r *Router) getNamedRoutes() map[string]*Route {
if r.namedRoutes == nil {
if r.parent != nil {
r.namedRoutes = r.parent.getNamedRoutes()
} else {
r.namedRoutes = make(map[string]*Route)
}
}
return r.namedRoutes
}
// getRegexpGroup returns regexp definitions from the parent route, if any.
func (r *Router) getRegexpGroup() *routeRegexpGroup {
if r.parent != nil {
return r.parent.getRegexpGroup()
}
return nil
}
func (r *Router) buildVars(m map[string]string) map[string]string {
if r.parent != nil {
m = r.parent.buildVars(m)
}
return m
}
// ----------------------------------------------------------------------------
// Route factories
// ----------------------------------------------------------------------------
// NewRoute registers an empty route.
func (r *Router) NewRoute() *Route {
route := &Route{parent: r, strictSlash: r.strictSlash, skipClean: r.skipClean, useEncodedPath: r.useEncodedPath}
// initialize a route with a copy of the parent router's configuration
route := &Route{routeConf: copyRouteConf(r.routeConf), namedRoutes: r.namedRoutes}
r.routes = append(r.routes, route)
return route
}
// Name registers a new route with a name.
// See Route.Name().
func (r *Router) Name(name string) *Route {
return r.NewRoute().Name(name)
}
// Handle registers a new route with a matcher for the URL path.
// See Route.Path() and Route.Handler().
func (r *Router) Handle(path string, handler http.Handler) *Route {

View File

@@ -48,15 +48,6 @@ type routeTest struct {
}
func TestHost(t *testing.T) {
// newRequestHost a new request with a method, url, and host header
newRequestHost := func(method, url, host string) *http.Request {
req, err := http.NewRequest(method, url, nil)
if err != nil {
panic(err)
}
req.Host = host
return req
}
tests := []routeTest{
{
@@ -113,7 +104,15 @@ func TestHost(t *testing.T) {
path: "",
shouldMatch: false,
},
// BUG {new(Route).Host("aaa.bbb.ccc:1234"), newRequestHost("GET", "/111/222/333", "aaa.bbb.ccc:1234"), map[string]string{}, "aaa.bbb.ccc:1234", "", true},
{
title: "Host route with port, match with request header",
route: new(Route).Host("aaa.bbb.ccc:1234"),
request: newRequestHost("GET", "/111/222/333", "aaa.bbb.ccc:1234"),
vars: map[string]string{},
host: "aaa.bbb.ccc:1234",
path: "",
shouldMatch: true,
},
{
title: "Host route with port, wrong host in request header",
route: new(Route).Host("aaa.bbb.ccc:1234"),
@@ -123,6 +122,16 @@ func TestHost(t *testing.T) {
path: "",
shouldMatch: false,
},
{
title: "Host route with pattern, match with request header",
route: new(Route).Host("aaa.{v1:[a-z]{3}}.ccc:1{v2:(?:23|4)}"),
request: newRequestHost("GET", "/111/222/333", "aaa.bbb.ccc:123"),
vars: map[string]string{"v1": "bbb", "v2": "23"},
host: "aaa.bbb.ccc:123",
path: "",
hostTemplate: `aaa.{v1:[a-z]{3}}.ccc:1{v2:(?:23|4)}`,
shouldMatch: true,
},
{
title: "Host route with pattern, match",
route: new(Route).Host("aaa.{v1:[a-z]{3}}.ccc"),
@@ -205,8 +214,10 @@ func TestHost(t *testing.T) {
},
}
for _, test := range tests {
testRoute(t, test)
testTemplate(t, test)
t.Run(test.title, func(t *testing.T) {
testRoute(t, test)
testTemplate(t, test)
})
}
}
@@ -437,10 +448,12 @@ func TestPath(t *testing.T) {
}
for _, test := range tests {
testRoute(t, test)
testTemplate(t, test)
testUseEscapedRoute(t, test)
testRegexp(t, test)
t.Run(test.title, func(t *testing.T) {
testRoute(t, test)
testTemplate(t, test)
testUseEscapedRoute(t, test)
testRegexp(t, test)
})
}
}
@@ -516,9 +529,11 @@ func TestPathPrefix(t *testing.T) {
}
for _, test := range tests {
testRoute(t, test)
testTemplate(t, test)
testUseEscapedRoute(t, test)
t.Run(test.title, func(t *testing.T) {
testRoute(t, test)
testTemplate(t, test)
testUseEscapedRoute(t, test)
})
}
}
@@ -623,9 +638,11 @@ func TestSchemeHostPath(t *testing.T) {
}
for _, test := range tests {
testRoute(t, test)
testTemplate(t, test)
testUseEscapedRoute(t, test)
t.Run(test.title, func(t *testing.T) {
testRoute(t, test)
testTemplate(t, test)
testUseEscapedRoute(t, test)
})
}
}
@@ -682,8 +699,10 @@ func TestHeaders(t *testing.T) {
}
for _, test := range tests {
testRoute(t, test)
testTemplate(t, test)
t.Run(test.title, func(t *testing.T) {
testRoute(t, test)
testTemplate(t, test)
})
}
}
@@ -732,9 +751,11 @@ func TestMethods(t *testing.T) {
}
for _, test := range tests {
testRoute(t, test)
testTemplate(t, test)
testMethods(t, test)
t.Run(test.title, func(t *testing.T) {
testRoute(t, test)
testTemplate(t, test)
testMethods(t, test)
})
}
}
@@ -1039,11 +1060,12 @@ func TestQueries(t *testing.T) {
}
for _, test := range tests {
testRoute(t, test)
testTemplate(t, test)
testQueriesTemplates(t, test)
testUseEscapedRoute(t, test)
testQueriesRegexp(t, test)
t.Run(test.title, func(t *testing.T) {
testTemplate(t, test)
testQueriesTemplates(t, test)
testUseEscapedRoute(t, test)
testQueriesRegexp(t, test)
})
}
}
@@ -1092,17 +1114,16 @@ func TestSchemes(t *testing.T) {
},
}
for _, test := range tests {
testRoute(t, test)
testTemplate(t, test)
t.Run(test.title, func(t *testing.T) {
testRoute(t, test)
testTemplate(t, test)
})
}
}
func TestMatcherFunc(t *testing.T) {
m := func(r *http.Request, m *RouteMatch) bool {
if r.URL.Host == "aaa.bbb.ccc" {
return true
}
return false
return r.URL.Host == "aaa.bbb.ccc"
}
tests := []routeTest{
@@ -1127,8 +1148,10 @@ func TestMatcherFunc(t *testing.T) {
}
for _, test := range tests {
testRoute(t, test)
testTemplate(t, test)
t.Run(test.title, func(t *testing.T) {
testRoute(t, test)
testTemplate(t, test)
})
}
}
@@ -1163,8 +1186,10 @@ func TestBuildVarsFunc(t *testing.T) {
}
for _, test := range tests {
testRoute(t, test)
testTemplate(t, test)
t.Run(test.title, func(t *testing.T) {
testRoute(t, test)
testTemplate(t, test)
})
}
}
@@ -1174,7 +1199,6 @@ func TestSubRouter(t *testing.T) {
subrouter3 := new(Route).PathPrefix("/foo").Subrouter()
subrouter4 := new(Route).PathPrefix("/foo/bar").Subrouter()
subrouter5 := new(Route).PathPrefix("/{category}").Subrouter()
tests := []routeTest{
{
route: subrouter1.Path("/{v2:[a-z]+}"),
@@ -1269,6 +1293,106 @@ func TestSubRouter(t *testing.T) {
pathTemplate: `/{category}`,
shouldMatch: true,
},
{
title: "Mismatch method specified on parent route",
route: new(Route).Methods("POST").PathPrefix("/foo").Subrouter().Path("/"),
request: newRequest("GET", "http://localhost/foo/"),
vars: map[string]string{},
host: "",
path: "/foo/",
pathTemplate: `/foo/`,
shouldMatch: false,
},
{
title: "Match method specified on parent route",
route: new(Route).Methods("POST").PathPrefix("/foo").Subrouter().Path("/"),
request: newRequest("POST", "http://localhost/foo/"),
vars: map[string]string{},
host: "",
path: "/foo/",
pathTemplate: `/foo/`,
shouldMatch: true,
},
{
title: "Mismatch scheme specified on parent route",
route: new(Route).Schemes("https").Subrouter().PathPrefix("/"),
request: newRequest("GET", "http://localhost/"),
vars: map[string]string{},
host: "",
path: "/",
pathTemplate: `/`,
shouldMatch: false,
},
{
title: "Match scheme specified on parent route",
route: new(Route).Schemes("http").Subrouter().PathPrefix("/"),
request: newRequest("GET", "http://localhost/"),
vars: map[string]string{},
host: "",
path: "/",
pathTemplate: `/`,
shouldMatch: true,
},
{
title: "No match header specified on parent route",
route: new(Route).Headers("X-Forwarded-Proto", "https").Subrouter().PathPrefix("/"),
request: newRequest("GET", "http://localhost/"),
vars: map[string]string{},
host: "",
path: "/",
pathTemplate: `/`,
shouldMatch: false,
},
{
title: "Header mismatch value specified on parent route",
route: new(Route).Headers("X-Forwarded-Proto", "https").Subrouter().PathPrefix("/"),
request: newRequestWithHeaders("GET", "http://localhost/", "X-Forwarded-Proto", "http"),
vars: map[string]string{},
host: "",
path: "/",
pathTemplate: `/`,
shouldMatch: false,
},
{
title: "Header match value specified on parent route",
route: new(Route).Headers("X-Forwarded-Proto", "https").Subrouter().PathPrefix("/"),
request: newRequestWithHeaders("GET", "http://localhost/", "X-Forwarded-Proto", "https"),
vars: map[string]string{},
host: "",
path: "/",
pathTemplate: `/`,
shouldMatch: true,
},
{
title: "Query specified on parent route not present",
route: new(Route).Headers("key", "foobar").Subrouter().PathPrefix("/"),
request: newRequest("GET", "http://localhost/"),
vars: map[string]string{},
host: "",
path: "/",
pathTemplate: `/`,
shouldMatch: false,
},
{
title: "Query mismatch value specified on parent route",
route: new(Route).Queries("key", "foobar").Subrouter().PathPrefix("/"),
request: newRequest("GET", "http://localhost/?key=notfoobar"),
vars: map[string]string{},
host: "",
path: "/",
pathTemplate: `/`,
shouldMatch: false,
},
{
title: "Query match value specified on subroute",
route: new(Route).Queries("key", "foobar").Subrouter().PathPrefix("/"),
request: newRequest("GET", "http://localhost/?key=foobar"),
vars: map[string]string{},
host: "",
path: "/",
pathTemplate: `/`,
shouldMatch: true,
},
{
title: "Build with scheme on parent router",
route: new(Route).Schemes("ftp").Host("google.com").Subrouter().Path("/"),
@@ -1294,9 +1418,11 @@ func TestSubRouter(t *testing.T) {
}
for _, test := range tests {
testRoute(t, test)
testTemplate(t, test)
testUseEscapedRoute(t, test)
t.Run(test.title, func(t *testing.T) {
testRoute(t, test)
testTemplate(t, test)
testUseEscapedRoute(t, test)
})
}
}
@@ -1315,14 +1441,24 @@ func TestNamedRoutes(t *testing.T) {
r3.NewRoute().Name("g")
r3.NewRoute().Name("h")
r3.NewRoute().Name("i")
r3.Name("j")
if r1.namedRoutes == nil || len(r1.namedRoutes) != 9 {
t.Errorf("Expected 9 named routes, got %v", r1.namedRoutes)
} else if r1.Get("i") == nil {
if r1.namedRoutes == nil || len(r1.namedRoutes) != 10 {
t.Errorf("Expected 10 named routes, got %v", r1.namedRoutes)
} else if r1.Get("j") == nil {
t.Errorf("Subroute name not registered")
}
}
func TestNameMultipleCalls(t *testing.T) {
r1 := NewRouter()
rt := r1.NewRoute().Name("foo").Name("bar")
err := rt.GetError()
if err == nil {
t.Errorf("Expected an error")
}
}
func TestStrictSlash(t *testing.T) {
r := NewRouter()
r.StrictSlash(true)
@@ -1391,9 +1527,11 @@ func TestStrictSlash(t *testing.T) {
}
for _, test := range tests {
testRoute(t, test)
testTemplate(t, test)
testUseEscapedRoute(t, test)
t.Run(test.title, func(t *testing.T) {
testRoute(t, test)
testTemplate(t, test)
testUseEscapedRoute(t, test)
})
}
}
@@ -1425,8 +1563,10 @@ func TestUseEncodedPath(t *testing.T) {
}
for _, test := range tests {
testRoute(t, test)
testTemplate(t, test)
t.Run(test.title, func(t *testing.T) {
testRoute(t, test)
testTemplate(t, test)
})
}
}
@@ -1478,12 +1618,16 @@ func TestWalkSingleDepth(t *testing.T) {
func TestWalkNested(t *testing.T) {
router := NewRouter()
g := router.Path("/g").Subrouter()
o := g.PathPrefix("/o").Subrouter()
r := o.PathPrefix("/r").Subrouter()
i := r.PathPrefix("/i").Subrouter()
l1 := i.PathPrefix("/l").Subrouter()
l2 := l1.PathPrefix("/l").Subrouter()
routeSubrouter := func(r *Route) (*Route, *Router) {
return r, r.Subrouter()
}
gRoute, g := routeSubrouter(router.Path("/g"))
oRoute, o := routeSubrouter(g.PathPrefix("/o"))
rRoute, r := routeSubrouter(o.PathPrefix("/r"))
iRoute, i := routeSubrouter(r.PathPrefix("/i"))
l1Route, l1 := routeSubrouter(i.PathPrefix("/l"))
l2Route, l2 := routeSubrouter(l1.PathPrefix("/l"))
l2.Path("/a")
testCases := []struct {
@@ -1491,12 +1635,12 @@ func TestWalkNested(t *testing.T) {
ancestors []*Route
}{
{"/g", []*Route{}},
{"/g/o", []*Route{g.parent.(*Route)}},
{"/g/o/r", []*Route{g.parent.(*Route), o.parent.(*Route)}},
{"/g/o/r/i", []*Route{g.parent.(*Route), o.parent.(*Route), r.parent.(*Route)}},
{"/g/o/r/i/l", []*Route{g.parent.(*Route), o.parent.(*Route), r.parent.(*Route), i.parent.(*Route)}},
{"/g/o/r/i/l/l", []*Route{g.parent.(*Route), o.parent.(*Route), r.parent.(*Route), i.parent.(*Route), l1.parent.(*Route)}},
{"/g/o/r/i/l/l/a", []*Route{g.parent.(*Route), o.parent.(*Route), r.parent.(*Route), i.parent.(*Route), l1.parent.(*Route), l2.parent.(*Route)}},
{"/g/o", []*Route{gRoute}},
{"/g/o/r", []*Route{gRoute, oRoute}},
{"/g/o/r/i", []*Route{gRoute, oRoute, rRoute}},
{"/g/o/r/i/l", []*Route{gRoute, oRoute, rRoute, iRoute}},
{"/g/o/r/i/l/l", []*Route{gRoute, oRoute, rRoute, iRoute, l1Route}},
{"/g/o/r/i/l/l/a", []*Route{gRoute, oRoute, rRoute, iRoute, l1Route, l2Route}},
}
idx := 0
@@ -1529,8 +1673,8 @@ func TestWalkSubrouters(t *testing.T) {
o.Methods("GET")
o.Methods("PUT")
// all 4 routes should be matched, but final 2 routes do not have path templates
paths := []string{"/g", "/g/o", "", ""}
// all 4 routes should be matched
paths := []string{"/g", "/g/o", "/g/o", "/g/o"}
idx := 0
err := router.Walk(func(route *Route, router *Router, ancestors []*Route) error {
path := paths[idx]
@@ -1711,7 +1855,11 @@ func testRoute(t *testing.T, test routeTest) {
}
}
if query != "" {
u, _ := route.URL(mapToPairs(match.Vars)...)
u, err := route.URL(mapToPairs(match.Vars)...)
if err != nil {
t.Errorf("(%v) erred while creating url: %v", test.title, err)
return
}
if query != u.RawQuery {
t.Errorf("(%v) URL query not equal: expected %v, got %v", test.title, query, u.RawQuery)
return
@@ -1795,7 +1943,7 @@ type TestA301ResponseWriter struct {
}
func (ho *TestA301ResponseWriter) Header() http.Header {
return http.Header(ho.hh)
return ho.hh
}
func (ho *TestA301ResponseWriter) Write(b []byte) (int, error) {
@@ -2031,7 +2179,9 @@ func TestMethodsSubrouterCatchall(t *testing.T) {
}
for _, test := range tests {
testMethodsSubrouter(t, test)
t.Run(test.title, func(t *testing.T) {
testMethodsSubrouter(t, test)
})
}
}
@@ -2087,7 +2237,9 @@ func TestMethodsSubrouterStrictSlash(t *testing.T) {
}
for _, test := range tests {
testMethodsSubrouter(t, test)
t.Run(test.title, func(t *testing.T) {
testMethodsSubrouter(t, test)
})
}
}
@@ -2134,7 +2286,9 @@ func TestMethodsSubrouterPathPrefix(t *testing.T) {
}
for _, test := range tests {
testMethodsSubrouter(t, test)
t.Run(test.title, func(t *testing.T) {
testMethodsSubrouter(t, test)
})
}
}
@@ -2190,7 +2344,9 @@ func TestMethodsSubrouterSubrouter(t *testing.T) {
}
for _, test := range tests {
testMethodsSubrouter(t, test)
t.Run(test.title, func(t *testing.T) {
testMethodsSubrouter(t, test)
})
}
}
@@ -2244,7 +2400,9 @@ func TestMethodsSubrouterPathVariable(t *testing.T) {
}
for _, test := range tests {
testMethodsSubrouter(t, test)
t.Run(test.title, func(t *testing.T) {
testMethodsSubrouter(t, test)
})
}
}
@@ -2288,6 +2446,305 @@ func testMethodsSubrouter(t *testing.T, test methodsSubrouterTest) {
}
}
func TestSubrouterMatching(t *testing.T) {
const (
none, stdOnly, subOnly uint8 = 0, 1 << 0, 1 << 1
both = subOnly | stdOnly
)
type request struct {
Name string
Request *http.Request
Flags uint8
}
cases := []struct {
Name string
Standard, Subrouter func(*Router)
Requests []request
}{
{
"pathPrefix",
func(r *Router) {
r.PathPrefix("/before").PathPrefix("/after")
},
func(r *Router) {
r.PathPrefix("/before").Subrouter().PathPrefix("/after")
},
[]request{
{"no match final path prefix", newRequest("GET", "/after"), none},
{"no match parent path prefix", newRequest("GET", "/before"), none},
{"matches append", newRequest("GET", "/before/after"), both},
{"matches as prefix", newRequest("GET", "/before/after/1234"), both},
},
},
{
"path",
func(r *Router) {
r.Path("/before").Path("/after")
},
func(r *Router) {
r.Path("/before").Subrouter().Path("/after")
},
[]request{
{"no match subroute path", newRequest("GET", "/after"), none},
{"no match parent path", newRequest("GET", "/before"), none},
{"no match as prefix", newRequest("GET", "/before/after/1234"), none},
{"no match append", newRequest("GET", "/before/after"), none},
},
},
{
"host",
func(r *Router) {
r.Host("before.com").Host("after.com")
},
func(r *Router) {
r.Host("before.com").Subrouter().Host("after.com")
},
[]request{
{"no match before", newRequestHost("GET", "/", "before.com"), none},
{"no match other", newRequestHost("GET", "/", "other.com"), none},
{"matches after", newRequestHost("GET", "/", "after.com"), none},
},
},
{
"queries variant keys",
func(r *Router) {
r.Queries("foo", "bar").Queries("cricket", "baseball")
},
func(r *Router) {
r.Queries("foo", "bar").Subrouter().Queries("cricket", "baseball")
},
[]request{
{"matches with all", newRequest("GET", "/?foo=bar&cricket=baseball"), both},
{"matches with more", newRequest("GET", "/?foo=bar&cricket=baseball&something=else"), both},
{"no match with none", newRequest("GET", "/"), none},
{"no match with some", newRequest("GET", "/?cricket=baseball"), none},
},
},
{
"queries overlapping keys",
func(r *Router) {
r.Queries("foo", "bar").Queries("foo", "baz")
},
func(r *Router) {
r.Queries("foo", "bar").Subrouter().Queries("foo", "baz")
},
[]request{
{"no match old value", newRequest("GET", "/?foo=bar"), none},
{"no match diff value", newRequest("GET", "/?foo=bak"), none},
{"no match with none", newRequest("GET", "/"), none},
{"matches override", newRequest("GET", "/?foo=baz"), none},
},
},
{
"header variant keys",
func(r *Router) {
r.Headers("foo", "bar").Headers("cricket", "baseball")
},
func(r *Router) {
r.Headers("foo", "bar").Subrouter().Headers("cricket", "baseball")
},
[]request{
{
"matches with all",
newRequestWithHeaders("GET", "/", "foo", "bar", "cricket", "baseball"),
both,
},
{
"matches with more",
newRequestWithHeaders("GET", "/", "foo", "bar", "cricket", "baseball", "something", "else"),
both,
},
{"no match with none", newRequest("GET", "/"), none},
{"no match with some", newRequestWithHeaders("GET", "/", "cricket", "baseball"), none},
},
},
{
"header overlapping keys",
func(r *Router) {
r.Headers("foo", "bar").Headers("foo", "baz")
},
func(r *Router) {
r.Headers("foo", "bar").Subrouter().Headers("foo", "baz")
},
[]request{
{"no match old value", newRequestWithHeaders("GET", "/", "foo", "bar"), none},
{"no match diff value", newRequestWithHeaders("GET", "/", "foo", "bak"), none},
{"no match with none", newRequest("GET", "/"), none},
{"matches override", newRequestWithHeaders("GET", "/", "foo", "baz"), none},
},
},
{
"method",
func(r *Router) {
r.Methods("POST").Methods("GET")
},
func(r *Router) {
r.Methods("POST").Subrouter().Methods("GET")
},
[]request{
{"matches before", newRequest("POST", "/"), none},
{"no match other", newRequest("HEAD", "/"), none},
{"matches override", newRequest("GET", "/"), none},
},
},
{
"schemes",
func(r *Router) {
r.Schemes("http").Schemes("https")
},
func(r *Router) {
r.Schemes("http").Subrouter().Schemes("https")
},
[]request{
{"matches overrides", newRequest("GET", "https://www.example.com/"), none},
{"matches original", newRequest("GET", "http://www.example.com/"), none},
{"no match other", newRequest("GET", "ftp://www.example.com/"), none},
},
},
}
// case -> request -> router
for _, c := range cases {
t.Run(c.Name, func(t *testing.T) {
for _, req := range c.Requests {
t.Run(req.Name, func(t *testing.T) {
for _, v := range []struct {
Name string
Config func(*Router)
Expected bool
}{
{"subrouter", c.Subrouter, (req.Flags & subOnly) != 0},
{"standard", c.Standard, (req.Flags & stdOnly) != 0},
} {
r := NewRouter()
v.Config(r)
if r.Match(req.Request, &RouteMatch{}) != v.Expected {
if v.Expected {
t.Errorf("expected %v match", v.Name)
} else {
t.Errorf("expected %v no match", v.Name)
}
}
}
})
}
})
}
}
// verify that copyRouteConf copies fields as expected.
func Test_copyRouteConf(t *testing.T) {
var (
m MatcherFunc = func(*http.Request, *RouteMatch) bool {
return true
}
b BuildVarsFunc = func(i map[string]string) map[string]string {
return i
}
r, _ = newRouteRegexp("hi", regexpTypeHost, routeRegexpOptions{})
)
tests := []struct {
name string
args routeConf
want routeConf
}{
{
"empty",
routeConf{},
routeConf{},
},
{
"full",
routeConf{
useEncodedPath: true,
strictSlash: true,
skipClean: true,
regexp: routeRegexpGroup{host: r, path: r, queries: []*routeRegexp{r}},
matchers: []matcher{m},
buildScheme: "https",
buildVarsFunc: b,
},
routeConf{
useEncodedPath: true,
strictSlash: true,
skipClean: true,
regexp: routeRegexpGroup{host: r, path: r, queries: []*routeRegexp{r}},
matchers: []matcher{m},
buildScheme: "https",
buildVarsFunc: b,
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// special case some incomparable fields of routeConf before delegating to reflect.DeepEqual
got := copyRouteConf(tt.args)
// funcs not comparable, just compare length of slices
if len(got.matchers) != len(tt.want.matchers) {
t.Errorf("matchers different lengths: %v %v", len(got.matchers), len(tt.want.matchers))
}
got.matchers, tt.want.matchers = nil, nil
// deep equal treats nil slice differently to empty slice so check for zero len first
{
bothZero := len(got.regexp.queries) == 0 && len(tt.want.regexp.queries) == 0
if !bothZero && !reflect.DeepEqual(got.regexp.queries, tt.want.regexp.queries) {
t.Errorf("queries unequal: %v %v", got.regexp.queries, tt.want.regexp.queries)
}
got.regexp.queries, tt.want.regexp.queries = nil, nil
}
// funcs not comparable, just compare nullity
if (got.buildVarsFunc == nil) != (tt.want.buildVarsFunc == nil) {
t.Errorf("build vars funcs unequal: %v %v", got.buildVarsFunc == nil, tt.want.buildVarsFunc == nil)
}
got.buildVarsFunc, tt.want.buildVarsFunc = nil, nil
// finish the deal
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("route confs unequal: %v %v", got, tt.want)
}
})
}
}
func TestMethodNotAllowed(t *testing.T) {
handler := func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) }
router := NewRouter()
router.HandleFunc("/thing", handler).Methods(http.MethodGet)
router.HandleFunc("/something", handler).Methods(http.MethodGet)
w := NewRecorder()
req := newRequest(http.MethodPut, "/thing")
router.ServeHTTP(w, req)
if w.Code != 405 {
t.Fatalf("Expected status code 405 (got %d)", w.Code)
}
}
func TestSubrouterNotFound(t *testing.T) {
handler := func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) }
router := NewRouter()
router.Path("/a").Subrouter().HandleFunc("/thing", handler).Methods(http.MethodGet)
router.Path("/b").Subrouter().HandleFunc("/something", handler).Methods(http.MethodGet)
w := NewRecorder()
req := newRequest(http.MethodPut, "/not-present")
router.ServeHTTP(w, req)
if w.Code != 404 {
t.Fatalf("Expected status code 404 (got %d)", w.Code)
}
}
// mapToPairs converts a string map to a slice of string pairs
func mapToPairs(m map[string]string) []string {
var i int
@@ -2362,3 +2819,28 @@ func newRequest(method, url string) *http.Request {
}
return req
}
// create a new request with the provided headers
func newRequestWithHeaders(method, url string, headers ...string) *http.Request {
req := newRequest(method, url)
if len(headers)%2 != 0 {
panic(fmt.Sprintf("Expected headers length divisible by 2 but got %v", len(headers)))
}
for i := 0; i < len(headers); i += 2 {
req.Header.Set(headers[i], headers[i+1])
}
return req
}
// newRequestHost a new request with a method, url, and host header
func newRequestHost(method, url, host string) *http.Request {
req, err := http.NewRequest(method, url, nil)
if err != nil {
panic(err)
}
req.Host = host
return req
}

View File

@@ -113,6 +113,13 @@ func newRouteRegexp(tpl string, typ regexpType, options routeRegexpOptions) (*ro
if typ != regexpTypePrefix {
pattern.WriteByte('$')
}
var wildcardHostPort bool
if typ == regexpTypeHost {
if !strings.Contains(pattern.String(), ":") {
wildcardHostPort = true
}
}
reverse.WriteString(raw)
if endSlash {
reverse.WriteByte('/')
@@ -131,13 +138,14 @@ func newRouteRegexp(tpl string, typ regexpType, options routeRegexpOptions) (*ro
// Done!
return &routeRegexp{
template: template,
regexpType: typ,
options: options,
regexp: reg,
reverse: reverse.String(),
varsN: varsN,
varsR: varsR,
template: template,
regexpType: typ,
options: options,
regexp: reg,
reverse: reverse.String(),
varsN: varsN,
varsR: varsR,
wildcardHostPort: wildcardHostPort,
}, nil
}
@@ -158,11 +166,22 @@ type routeRegexp struct {
varsN []string
// Variable regexps (validators).
varsR []*regexp.Regexp
// Wildcard host-port (no strict port match in hostname)
wildcardHostPort bool
}
// Match matches the regexp against the URL host or path.
func (r *routeRegexp) Match(req *http.Request, match *RouteMatch) bool {
if r.regexpType != regexpTypeHost {
if r.regexpType == regexpTypeHost {
host := getHost(req)
if r.wildcardHostPort {
// Don't be strict on the port match
if i := strings.Index(host, ":"); i != -1 {
host = host[:i]
}
}
return r.regexp.MatchString(host)
} else {
if r.regexpType == regexpTypeQuery {
return r.matchQueryString(req)
}
@@ -172,8 +191,6 @@ func (r *routeRegexp) Match(req *http.Request, match *RouteMatch) bool {
}
return r.regexp.MatchString(path)
}
return r.regexp.MatchString(getHost(req))
}
// url builds a URL part using the given values.
@@ -267,7 +284,7 @@ type routeRegexpGroup struct {
}
// setMatch extracts the variables from the URL once a route matches.
func (v *routeRegexpGroup) setMatch(req *http.Request, m *RouteMatch, r *Route) {
func (v routeRegexpGroup) setMatch(req *http.Request, m *RouteMatch, r *Route) {
// Store host variables.
if v.host != nil {
host := getHost(req)
@@ -296,7 +313,7 @@ func (v *routeRegexpGroup) setMatch(req *http.Request, m *RouteMatch, r *Route)
} else {
u.Path += "/"
}
m.Handler = http.RedirectHandler(u.String(), 301)
m.Handler = http.RedirectHandler(u.String(), http.StatusMovedPermanently)
}
}
}
@@ -312,17 +329,13 @@ func (v *routeRegexpGroup) setMatch(req *http.Request, m *RouteMatch, r *Route)
}
// getHost tries its best to return the request host.
// According to section 14.23 of RFC 2616 the Host header
// can include the port number if the default value of 80 is not used.
func getHost(r *http.Request) string {
if r.URL.IsAbs() {
return r.URL.Host
}
host := r.Host
// Slice off any port information.
if i := strings.Index(host, ":"); i != -1 {
host = host[:i]
}
return host
return r.Host
}
func extractVars(input string, matches []int, names []string, output map[string]string) {

141
route.go
View File

@@ -15,24 +15,8 @@ import (
// Route stores information to match a request and build URLs.
type Route struct {
// Parent where the route was registered (a Router).
parent parentRoute
// Request handler for the route.
handler http.Handler
// List of matchers.
matchers []matcher
// Manager for the variables from host and path.
regexp *routeRegexpGroup
// If true, when the path pattern is "/path/", accessing "/path" will
// redirect to the former and vice versa.
strictSlash bool
// If true, when the path pattern is "/path//to", accessing "/path//to"
// will not redirect
skipClean bool
// If true, "/path/foo%2Fbar/to" will match the path "/path/{var}/to"
useEncodedPath bool
// The scheme used when building URLs.
buildScheme string
// If true, this route never matches: it is only used to build URLs.
buildOnly bool
// The name used to build URLs.
@@ -40,7 +24,11 @@ type Route struct {
// Error resulted from building a route.
err error
buildVarsFunc BuildVarsFunc
// "global" reference to all named routes
namedRoutes map[string]*Route
// config possibly passed in from `Router`
routeConf
}
// SkipClean reports whether path cleaning is enabled for this route via
@@ -64,6 +52,18 @@ func (r *Route) Match(req *http.Request, match *RouteMatch) bool {
matchErr = ErrMethodMismatch
continue
}
// Ignore ErrNotFound errors. These errors arise from match call
// to Subrouters.
//
// This prevents subsequent matching subrouters from failing to
// run middleware. If not ignored, the middleware would see a
// non-nil MatchErr and be skipped, even when there was a
// matching route.
if match.MatchErr == ErrNotFound {
match.MatchErr = nil
}
matchErr = nil
return false
}
@@ -93,9 +93,7 @@ func (r *Route) Match(req *http.Request, match *RouteMatch) bool {
}
// Set variables.
if r.regexp != nil {
r.regexp.setMatch(req, match, r)
}
r.regexp.setMatch(req, match, r)
return true
}
@@ -137,7 +135,7 @@ func (r *Route) GetHandler() http.Handler {
// Name -----------------------------------------------------------------------
// Name sets the name for the route, used to build URLs.
// If the name was registered already it will be overwritten.
// It is an error to call Name more than once on a route.
func (r *Route) Name(name string) *Route {
if r.name != "" {
r.err = fmt.Errorf("mux: route already has name %q, can't set %q",
@@ -145,7 +143,7 @@ func (r *Route) Name(name string) *Route {
}
if r.err == nil {
r.name = name
r.getNamedRoutes()[name] = r
r.namedRoutes[name] = r
}
return r
}
@@ -177,7 +175,6 @@ func (r *Route) addRegexpMatcher(tpl string, typ regexpType) error {
if r.err != nil {
return r.err
}
r.regexp = r.getRegexpGroup()
if typ == regexpTypePath || typ == regexpTypePrefix {
if len(tpl) > 0 && tpl[0] != '/' {
return fmt.Errorf("mux: path must start with a slash, got %q", tpl)
@@ -386,7 +383,7 @@ func (r *Route) PathPrefix(tpl string) *Route {
// The above route will only match if the URL contains the defined queries
// values, e.g.: ?foo=bar&id=42.
//
// It the value is an empty string, it will match any value if the key is set.
// If the value is an empty string, it will match any value if the key is set.
//
// Variables can define an optional regexp pattern to be matched:
//
@@ -424,7 +421,7 @@ func (r *Route) Schemes(schemes ...string) *Route {
for k, v := range schemes {
schemes[k] = strings.ToLower(v)
}
if r.buildScheme == "" && len(schemes) > 0 {
if len(schemes) > 0 {
r.buildScheme = schemes[0]
}
return r.addMatcher(schemeMatcher(schemes))
@@ -439,7 +436,15 @@ type BuildVarsFunc func(map[string]string) map[string]string
// BuildVarsFunc adds a custom function to be used to modify build variables
// before a route's URL is built.
func (r *Route) BuildVarsFunc(f BuildVarsFunc) *Route {
r.buildVarsFunc = f
if r.buildVarsFunc != nil {
// compose the old and new functions
old := r.buildVarsFunc
r.buildVarsFunc = func(m map[string]string) map[string]string {
return f(old(m))
}
} else {
r.buildVarsFunc = f
}
return r
}
@@ -458,7 +463,8 @@ func (r *Route) BuildVarsFunc(f BuildVarsFunc) *Route {
// Here, the routes registered in the subrouter won't be tested if the host
// doesn't match.
func (r *Route) Subrouter() *Router {
router := &Router{parent: r, strictSlash: r.strictSlash}
// initialize a subrouter with a copy of the parent route's configuration
router := &Router{routeConf: copyRouteConf(r.routeConf), namedRoutes: r.namedRoutes}
r.addMatcher(router)
return router
}
@@ -502,9 +508,6 @@ func (r *Route) URL(pairs ...string) (*url.URL, error) {
if r.err != nil {
return nil, r.err
}
if r.regexp == nil {
return nil, errors.New("mux: route doesn't have a host or path")
}
values, err := r.prepareVars(pairs...)
if err != nil {
return nil, err
@@ -516,8 +519,8 @@ func (r *Route) URL(pairs ...string) (*url.URL, error) {
return nil, err
}
scheme = "http"
if s := r.getBuildScheme(); s != "" {
scheme = s
if r.buildScheme != "" {
scheme = r.buildScheme
}
}
if r.regexp.path != nil {
@@ -547,7 +550,7 @@ func (r *Route) URLHost(pairs ...string) (*url.URL, error) {
if r.err != nil {
return nil, r.err
}
if r.regexp == nil || r.regexp.host == nil {
if r.regexp.host == nil {
return nil, errors.New("mux: route doesn't have a host")
}
values, err := r.prepareVars(pairs...)
@@ -562,8 +565,8 @@ func (r *Route) URLHost(pairs ...string) (*url.URL, error) {
Scheme: "http",
Host: host,
}
if s := r.getBuildScheme(); s != "" {
u.Scheme = s
if r.buildScheme != "" {
u.Scheme = r.buildScheme
}
return u, nil
}
@@ -575,7 +578,7 @@ func (r *Route) URLPath(pairs ...string) (*url.URL, error) {
if r.err != nil {
return nil, r.err
}
if r.regexp == nil || r.regexp.path == nil {
if r.regexp.path == nil {
return nil, errors.New("mux: route doesn't have a path")
}
values, err := r.prepareVars(pairs...)
@@ -600,7 +603,7 @@ func (r *Route) GetPathTemplate() (string, error) {
if r.err != nil {
return "", r.err
}
if r.regexp == nil || r.regexp.path == nil {
if r.regexp.path == nil {
return "", errors.New("mux: route doesn't have a path")
}
return r.regexp.path.template, nil
@@ -614,7 +617,7 @@ func (r *Route) GetPathRegexp() (string, error) {
if r.err != nil {
return "", r.err
}
if r.regexp == nil || r.regexp.path == nil {
if r.regexp.path == nil {
return "", errors.New("mux: route does not have a path")
}
return r.regexp.path.regexp.String(), nil
@@ -629,7 +632,7 @@ func (r *Route) GetQueriesRegexp() ([]string, error) {
if r.err != nil {
return nil, r.err
}
if r.regexp == nil || r.regexp.queries == nil {
if r.regexp.queries == nil {
return nil, errors.New("mux: route doesn't have queries")
}
var queries []string
@@ -648,7 +651,7 @@ func (r *Route) GetQueriesTemplates() ([]string, error) {
if r.err != nil {
return nil, r.err
}
if r.regexp == nil || r.regexp.queries == nil {
if r.regexp.queries == nil {
return nil, errors.New("mux: route doesn't have queries")
}
var queries []string
@@ -683,7 +686,7 @@ func (r *Route) GetHostTemplate() (string, error) {
if r.err != nil {
return "", r.err
}
if r.regexp == nil || r.regexp.host == nil {
if r.regexp.host == nil {
return "", errors.New("mux: route doesn't have a host")
}
return r.regexp.host.template, nil
@@ -700,64 +703,8 @@ func (r *Route) prepareVars(pairs ...string) (map[string]string, error) {
}
func (r *Route) buildVars(m map[string]string) map[string]string {
if r.parent != nil {
m = r.parent.buildVars(m)
}
if r.buildVarsFunc != nil {
m = r.buildVarsFunc(m)
}
return m
}
// ----------------------------------------------------------------------------
// parentRoute
// ----------------------------------------------------------------------------
// parentRoute allows routes to know about parent host and path definitions.
type parentRoute interface {
getBuildScheme() string
getNamedRoutes() map[string]*Route
getRegexpGroup() *routeRegexpGroup
buildVars(map[string]string) map[string]string
}
func (r *Route) getBuildScheme() string {
if r.buildScheme != "" {
return r.buildScheme
}
if r.parent != nil {
return r.parent.getBuildScheme()
}
return ""
}
// getNamedRoutes returns the map where named routes are registered.
func (r *Route) getNamedRoutes() map[string]*Route {
if r.parent == nil {
// During tests router is not always set.
r.parent = NewRouter()
}
return r.parent.getNamedRoutes()
}
// getRegexpGroup returns regexp definitions from this route.
func (r *Route) getRegexpGroup() *routeRegexpGroup {
if r.regexp == nil {
if r.parent == nil {
// During tests router is not always set.
r.parent = NewRouter()
}
regexp := r.parent.getRegexpGroup()
if regexp == nil {
r.regexp = new(routeRegexpGroup)
} else {
// Copy.
r.regexp = &routeRegexpGroup{
host: regexp.host,
path: regexp.path,
queries: regexp.queries,
}
}
}
return r.regexp
}