Merge pull request #169 from ejholmes/go1.7-context
Store vars and route in context.Context when go1.7+ is used
This commit is contained in:
26
context_gorilla.go
Normal file
26
context_gorilla.go
Normal file
@@ -0,0 +1,26 @@
|
|||||||
|
// +build !go1.7
|
||||||
|
|
||||||
|
package mux
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"github.com/gorilla/context"
|
||||||
|
)
|
||||||
|
|
||||||
|
func contextGet(r *http.Request, key interface{}) interface{} {
|
||||||
|
return context.Get(r, key)
|
||||||
|
}
|
||||||
|
|
||||||
|
func contextSet(r *http.Request, key, val interface{}) *http.Request {
|
||||||
|
if val == nil {
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
context.Set(r, key, val)
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
func contextClear(r *http.Request) {
|
||||||
|
context.Clear(r)
|
||||||
|
}
|
||||||
40
context_gorilla_test.go
Normal file
40
context_gorilla_test.go
Normal file
@@ -0,0 +1,40 @@
|
|||||||
|
// +build !go1.7
|
||||||
|
|
||||||
|
package mux
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/gorilla/context"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Tests that the context is cleared or not cleared properly depending on
|
||||||
|
// the configuration of the router
|
||||||
|
func TestKeepContext(t *testing.T) {
|
||||||
|
func1 := func(w http.ResponseWriter, r *http.Request) {}
|
||||||
|
|
||||||
|
r := NewRouter()
|
||||||
|
r.HandleFunc("/", func1).Name("func1")
|
||||||
|
|
||||||
|
req, _ := http.NewRequest("GET", "http://localhost/", nil)
|
||||||
|
context.Set(req, "t", 1)
|
||||||
|
|
||||||
|
res := new(http.ResponseWriter)
|
||||||
|
r.ServeHTTP(*res, req)
|
||||||
|
|
||||||
|
if _, ok := context.GetOk(req, "t"); ok {
|
||||||
|
t.Error("Context should have been cleared at end of request")
|
||||||
|
}
|
||||||
|
|
||||||
|
r.KeepContext = true
|
||||||
|
|
||||||
|
req, _ = http.NewRequest("GET", "http://localhost/", nil)
|
||||||
|
context.Set(req, "t", 1)
|
||||||
|
|
||||||
|
r.ServeHTTP(*res, req)
|
||||||
|
if _, ok := context.GetOk(req, "t"); !ok {
|
||||||
|
t.Error("Context should NOT have been cleared at end of request")
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
24
context_native.go
Normal file
24
context_native.go
Normal file
@@ -0,0 +1,24 @@
|
|||||||
|
// +build go1.7
|
||||||
|
|
||||||
|
package mux
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net/http"
|
||||||
|
)
|
||||||
|
|
||||||
|
func contextGet(r *http.Request, key interface{}) interface{} {
|
||||||
|
return r.Context().Value(key)
|
||||||
|
}
|
||||||
|
|
||||||
|
func contextSet(r *http.Request, key, val interface{}) *http.Request {
|
||||||
|
if val == nil {
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
return r.WithContext(context.WithValue(r.Context(), key, val))
|
||||||
|
}
|
||||||
|
|
||||||
|
func contextClear(r *http.Request) {
|
||||||
|
return
|
||||||
|
}
|
||||||
32
context_native_test.go
Normal file
32
context_native_test.go
Normal file
@@ -0,0 +1,32 @@
|
|||||||
|
// +build go1.7
|
||||||
|
|
||||||
|
package mux
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net/http"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestNativeContextMiddleware(t *testing.T) {
|
||||||
|
withTimeout := func(h http.Handler) http.Handler {
|
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
ctx, cancel := context.WithTimeout(r.Context(), time.Minute)
|
||||||
|
defer cancel()
|
||||||
|
h.ServeHTTP(w, r.WithContext(ctx))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
r := NewRouter()
|
||||||
|
r.Handle("/path/{foo}", withTimeout(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
vars := Vars(r)
|
||||||
|
if vars["foo"] != "bar" {
|
||||||
|
t.Fatal("Expected foo var to be set")
|
||||||
|
}
|
||||||
|
})))
|
||||||
|
|
||||||
|
rec := NewRecorder()
|
||||||
|
req := newRequest("GET", "/path/bar")
|
||||||
|
r.ServeHTTP(rec, req)
|
||||||
|
}
|
||||||
28
mux.go
28
mux.go
@@ -10,8 +10,6 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
"path"
|
"path"
|
||||||
"regexp"
|
"regexp"
|
||||||
|
|
||||||
"github.com/gorilla/context"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// NewRouter returns a new router instance.
|
// NewRouter returns a new router instance.
|
||||||
@@ -50,7 +48,9 @@ type Router struct {
|
|||||||
strictSlash bool
|
strictSlash bool
|
||||||
// See Router.SkipClean(). This defines the flag for new routes.
|
// See Router.SkipClean(). This defines the flag for new routes.
|
||||||
skipClean bool
|
skipClean bool
|
||||||
// If true, do not clear the request context after handling the request
|
// If true, do not clear the request context after handling the request.
|
||||||
|
// This has no effect when go1.7+ is used, since the context is stored
|
||||||
|
// on the request itself.
|
||||||
KeepContext bool
|
KeepContext bool
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -95,14 +95,14 @@ func (r *Router) ServeHTTP(w http.ResponseWriter, req *http.Request) {
|
|||||||
var handler http.Handler
|
var handler http.Handler
|
||||||
if r.Match(req, &match) {
|
if r.Match(req, &match) {
|
||||||
handler = match.Handler
|
handler = match.Handler
|
||||||
setVars(req, match.Vars)
|
req = setVars(req, match.Vars)
|
||||||
setCurrentRoute(req, match.Route)
|
req = setCurrentRoute(req, match.Route)
|
||||||
}
|
}
|
||||||
if handler == nil {
|
if handler == nil {
|
||||||
handler = http.NotFoundHandler()
|
handler = http.NotFoundHandler()
|
||||||
}
|
}
|
||||||
if !r.KeepContext {
|
if !r.KeepContext {
|
||||||
defer context.Clear(req)
|
defer contextClear(req)
|
||||||
}
|
}
|
||||||
handler.ServeHTTP(w, req)
|
handler.ServeHTTP(w, req)
|
||||||
}
|
}
|
||||||
@@ -325,7 +325,7 @@ const (
|
|||||||
|
|
||||||
// Vars returns the route variables for the current request, if any.
|
// Vars returns the route variables for the current request, if any.
|
||||||
func Vars(r *http.Request) map[string]string {
|
func Vars(r *http.Request) map[string]string {
|
||||||
if rv := context.Get(r, varsKey); rv != nil {
|
if rv := contextGet(r, varsKey); rv != nil {
|
||||||
return rv.(map[string]string)
|
return rv.(map[string]string)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
@@ -337,22 +337,18 @@ func Vars(r *http.Request) map[string]string {
|
|||||||
// after the handler returns, unless the KeepContext option is set on the
|
// after the handler returns, unless the KeepContext option is set on the
|
||||||
// Router.
|
// Router.
|
||||||
func CurrentRoute(r *http.Request) *Route {
|
func CurrentRoute(r *http.Request) *Route {
|
||||||
if rv := context.Get(r, routeKey); rv != nil {
|
if rv := contextGet(r, routeKey); rv != nil {
|
||||||
return rv.(*Route)
|
return rv.(*Route)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func setVars(r *http.Request, val interface{}) {
|
func setVars(r *http.Request, val interface{}) *http.Request {
|
||||||
if val != nil {
|
return contextSet(r, varsKey, val)
|
||||||
context.Set(r, varsKey, val)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func setCurrentRoute(r *http.Request, val interface{}) {
|
func setCurrentRoute(r *http.Request, val interface{}) *http.Request {
|
||||||
if val != nil {
|
return contextSet(r, routeKey, val)
|
||||||
context.Set(r, routeKey, val)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// ----------------------------------------------------------------------------
|
// ----------------------------------------------------------------------------
|
||||||
|
|||||||
32
mux_test.go
32
mux_test.go
@@ -9,8 +9,6 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/gorilla/context"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func (r *Route) GoString() string {
|
func (r *Route) GoString() string {
|
||||||
@@ -1316,36 +1314,6 @@ func testTemplate(t *testing.T, test routeTest) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Tests that the context is cleared or not cleared properly depending on
|
|
||||||
// the configuration of the router
|
|
||||||
func TestKeepContext(t *testing.T) {
|
|
||||||
func1 := func(w http.ResponseWriter, r *http.Request) {}
|
|
||||||
|
|
||||||
r := NewRouter()
|
|
||||||
r.HandleFunc("/", func1).Name("func1")
|
|
||||||
|
|
||||||
req, _ := http.NewRequest("GET", "http://localhost/", nil)
|
|
||||||
context.Set(req, "t", 1)
|
|
||||||
|
|
||||||
res := new(http.ResponseWriter)
|
|
||||||
r.ServeHTTP(*res, req)
|
|
||||||
|
|
||||||
if _, ok := context.GetOk(req, "t"); ok {
|
|
||||||
t.Error("Context should have been cleared at end of request")
|
|
||||||
}
|
|
||||||
|
|
||||||
r.KeepContext = true
|
|
||||||
|
|
||||||
req, _ = http.NewRequest("GET", "http://localhost/", nil)
|
|
||||||
context.Set(req, "t", 1)
|
|
||||||
|
|
||||||
r.ServeHTTP(*res, req)
|
|
||||||
if _, ok := context.GetOk(req, "t"); !ok {
|
|
||||||
t.Error("Context should NOT have been cleared at end of request")
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
type TestA301ResponseWriter struct {
|
type TestA301ResponseWriter struct {
|
||||||
hh http.Header
|
hh http.Header
|
||||||
status int
|
status int
|
||||||
|
|||||||
Reference in New Issue
Block a user