transfer.sh/vendor/github.com/google/martian/context.go
2019-03-17 20:19:56 +01:00

312 lines
7 KiB
Go

// Copyright 2015 Google Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package martian
import (
"bufio"
"crypto/rand"
"encoding/hex"
"fmt"
"net"
"net/http"
"sync"
)
// Context carries information and storage for a single request/response
// pair. Every Context is tied to a Session, which is shared by all of the
// requests made over one connection.
type Context struct {
	session *Session
	id      string

	// mu guards vals and the boolean flags below; session and id are only
	// read through methods that take no lock.
	mu   sync.RWMutex
	vals map[string]interface{}

	skipRoundTrip, skipLogging, apiRequest bool
}
// Session holds information and storage about a single connection. Values
// stored on a Session are shared by every request made over that connection.
type Session struct {
	// mu guards every field below.
	mu sync.RWMutex

	id               string
	secure, hijacked bool
	conn             net.Conn
	brw              *bufio.ReadWriter
	vals             map[string]interface{}
}
var (
	// ctxmu guards ctxs.
	ctxmu sync.RWMutex
	// ctxs associates each in-flight request with its Context.
	ctxs = map[*http.Request]*Context{}
)
// NewContext returns the context associated with the in-flight request req.
// Despite its name it does not create anything: it looks up the context that
// was previously linked to req and returns nil when there is none.
func NewContext(req *http.Request) *Context {
	ctxmu.RLock()
	found := ctxs[req]
	ctxmu.RUnlock()
	return found
}
// TestContext returns the context linked to req, creating a fresh session
// (backed by conn and bw) and context for it if none exists yet. The returned
// remove function deletes req's entry from the context map again. If a
// context already exists for req, it is returned as-is and no session is
// created. An error is returned only when generating the session or context
// ID fails.
//
// Intended for tests only.
func TestContext(req *http.Request, conn net.Conn, bw *bufio.ReadWriter) (ctx *Context, remove func(), err error) {
	remove = func() { unlink(req) }

	ctxmu.Lock()
	defer ctxmu.Unlock()

	if existing, found := ctxs[req]; found {
		return existing, remove, nil
	}

	var s *Session
	if s, err = newSession(conn, bw); err != nil {
		return nil, nil, err
	}
	if ctx, err = withSession(s); err != nil {
		return nil, nil, err
	}

	ctxs[req] = ctx
	return ctx, remove, nil
}
// ID returns the identifier of the session (a random hex string when the
// session was built by newSession).
func (s *Session) ID() string {
	s.mu.RLock()
	id := s.id
	s.mu.RUnlock()
	return id
}
// IsSecure reports whether the session belongs to a secure connection, for
// example a TLS connection that has been MITM'd by the proxy. The flag is
// controlled by MarkSecure and MarkInsecure.
func (s *Session) IsSecure() bool {
	s.mu.RLock()
	secure := s.secure
	s.mu.RUnlock()
	return secure
}
// MarkSecure flags the session as secure; IsSecure reports true afterwards.
func (s *Session) MarkSecure() {
	s.mu.Lock()
	s.secure = true
	s.mu.Unlock()
}
// MarkInsecure clears the secure flag; IsSecure reports false afterwards.
func (s *Session) MarkInsecure() {
	s.mu.Lock()
	s.secure = false
	s.mu.Unlock()
}
// Hijack hands the session's underlying connection and buffered reader/writer
// to the caller. Per the package contract, the proxy takes no further action
// on a hijacked connection and closes it once the hijacker returns. Only the
// first call succeeds; every later call returns an error.
func (s *Session) Hijack() (net.Conn, *bufio.ReadWriter, error) {
	s.mu.Lock()
	defer s.mu.Unlock()

	if !s.hijacked {
		s.hijacked = true
		return s.conn, s.brw, nil
	}
	return nil, nil, fmt.Errorf("martian: session has already been hijacked")
}
// Hijacked reports whether the connection has already been taken over via
// Hijack.
func (s *Session) Hijacked() bool {
	s.mu.RLock()
	taken := s.hijacked
	s.mu.RUnlock()
	return taken
}
// setConn replaces the session's underlying connection and bufio.ReadWriter.
// The proxy uses it when the connection is upgraded to TLS.
func (s *Session) setConn(conn net.Conn, brw *bufio.ReadWriter) {
	s.mu.Lock()
	s.conn, s.brw = conn, brw
	s.mu.Unlock()
}
// Get looks up key in the session's storage and returns the stored value
// along with whether the key was present.
func (s *Session) Get(key string) (interface{}, bool) {
	s.mu.RLock()
	val, ok := s.vals[key]
	s.mu.RUnlock()
	return val, ok
}
// Set associates val with key in the session. The value persists for the
// entire session, across multiple requests and responses.
//
// The backing map is allocated on first use, so Set is safe on a zero-value
// Session as well as on one built by newSession.
func (s *Session) Set(key string, val interface{}) {
	s.mu.Lock()
	defer s.mu.Unlock()
	if s.vals == nil {
		// Writing to a nil map panics; a zero-value Session has no map yet.
		s.vals = make(map[string]interface{})
	}
	s.vals[key] = val
}
// Session returns the Session this context belongs to. No lock is taken:
// within this file the field is only assigned when the context is built.
func (ctx *Context) Session() *Session {
	return ctx.session
}
// ID returns the identifier of this context. No lock is taken: within this
// file the field is only assigned when the context is built.
func (ctx *Context) ID() string {
	return ctx.id
}
// Get looks up key in the context's storage and returns the stored value
// along with whether the key was present.
func (ctx *Context) Get(key string) (interface{}, bool) {
	ctx.mu.RLock()
	val, ok := ctx.vals[key]
	ctx.mu.RUnlock()
	return val, ok
}
// Set associates val with key in the context. The value is persisted for the
// duration of the request and is removed on the following request.
//
// The backing map is allocated on first use, so Set is safe on a zero-value
// Context as well as on one built by withSession.
func (ctx *Context) Set(key string, val interface{}) {
	ctx.mu.Lock()
	defer ctx.mu.Unlock()
	if ctx.vals == nil {
		// Writing to a nil map panics; a zero-value Context has no map yet.
		ctx.vals = make(map[string]interface{})
	}
	ctx.vals[key] = val
}
// SkipRoundTrip requests that the round trip for the current request be
// skipped; SkippingRoundTrip reports true afterwards.
func (ctx *Context) SkipRoundTrip() {
	ctx.mu.Lock()
	ctx.skipRoundTrip = true
	ctx.mu.Unlock()
}
// SkippingRoundTrip reports whether SkipRoundTrip has been called, i.e.
// whether the round trip for the current request will be skipped.
func (ctx *Context) SkippingRoundTrip() bool {
	ctx.mu.RLock()
	skip := ctx.skipRoundTrip
	ctx.mu.RUnlock()
	return skip
}
// SkipLogging asks Martian loggers not to log the current request;
// SkippingLogging reports true afterwards.
func (ctx *Context) SkipLogging() {
	ctx.mu.Lock()
	ctx.skipLogging = true
	ctx.mu.Unlock()
}
// SkippingLogging reports whether Martian loggers will skip the current
// request/response pair. It returns true once SkipLogging has been called on
// this context, i.e. true means the pair will NOT be logged.
func (ctx *Context) SkippingLogging() bool {
	ctx.mu.RLock()
	defer ctx.mu.RUnlock()
	return ctx.skipLogging
}
// APIRequest marks the request as a request to the proxy API;
// IsAPIRequest reports true afterwards.
func (ctx *Context) APIRequest() {
	ctx.mu.Lock()
	ctx.apiRequest = true
	ctx.mu.Unlock()
}
// IsAPIRequest reports whether the request has been marked via APIRequest,
// which is expected to happen when the request matches a pattern in the proxy
// mux. The mux is usually supplied to the api.Forwarder, which defaults to
// http.DefaultServeMux.
func (ctx *Context) IsAPIRequest() bool {
	ctx.mu.RLock()
	isAPI := ctx.apiRequest
	ctx.mu.RUnlock()
	return isAPI
}
// newID returns a random 16-character lowercase hex string built from 8 bytes
// of crypto/rand output. These IDs are not UUIDs. An error is returned only
// if the random source fails.
func newID() (string, error) {
	var raw [8]byte
	if _, err := rand.Read(raw[:]); err != nil {
		return "", err
	}
	return hex.EncodeToString(raw[:]), nil
}
// link records ctx as the context for req, so that NewContext(req) returns it
// until unlink(req) is called.
func link(req *http.Request, ctx *Context) {
	ctxmu.Lock()
	ctxs[req] = ctx
	ctxmu.Unlock()
}
// unlink forgets the context recorded for req; it is a no-op if none exists.
func unlink(req *http.Request) {
	ctxmu.Lock()
	delete(ctxs, req)
	ctxmu.Unlock()
}
// newSession creates a Session with a fresh random ID and empty storage,
// wrapping conn and brw. It fails only if ID generation fails.
func newSession(conn net.Conn, brw *bufio.ReadWriter) (*Session, error) {
	id, err := newID()
	if err != nil {
		return nil, err
	}

	s := &Session{id: id, conn: conn, brw: brw}
	s.vals = make(map[string]interface{})
	return s, nil
}
// withSession creates a Context with a fresh random ID and empty storage that
// belongs to s. The caller must pass a non-nil session; this is not checked
// here. It fails only if ID generation fails.
func withSession(s *Session) (*Context, error) {
	id, err := newID()
	if err != nil {
		return nil, err
	}

	ctx := &Context{session: s, id: id}
	ctx.vals = make(map[string]interface{})
	return ctx, nil
}