transfer.sh/vendor/github.com/google/martian/proxy.go
2019-03-17 20:19:56 +01:00

588 lines
No EOL
16 KiB
Go

// Copyright 2015 Google Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package martian
import (
"bufio"
"bytes"
"crypto/tls"
"errors"
"io"
"net"
"net/http"
"net/http/httputil"
"net/url"
"regexp"
"sync"
"time"
"github.com/google/martian/log"
"github.com/google/martian/mitm"
"github.com/google/martian/nosigpipe"
"github.com/google/martian/proxyutil"
"github.com/google/martian/trafficshape"
)
// errClose is returned by the request-handling code to tell handleLoop that
// the client connection must be closed instead of being read from again.
// isCloseable treats it as a terminal error.
var errClose = errors.New("closing connection")

// noop is the default modifier installed for both request and response
// modification when none (or nil) is configured.
// NOTE(review): presumably Noop returns a modifier that does nothing; the
// "martian" argument looks like an identifying name — confirm against Noop.
var noop = Noop("martian")
// isCloseable reports whether err means the client connection is finished
// and should be closed: any net.Error that is a timeout, io.EOF,
// io.ErrClosedPipe, or the package's own errClose sentinel. A nil error, or
// any other error, is not closeable.
func isCloseable(err error) bool {
	if netErr, isNet := err.(net.Error); isNet && netErr.Timeout() {
		return true
	}
	return err == io.EOF || err == io.ErrClosedPipe || err == errClose
}
// Proxy is an HTTP proxy with support for TLS MITM and customizable behavior.
//
// Construct it with NewProxy; the zero value has a nil closing channel and
// nil modifiers and is not usable.
type Proxy struct {
	// roundTripper performs non-CONNECT requests (see roundTrip). When it is an
	// *http.Transport, the setters below keep its Proxy and Dial in sync.
	roundTripper http.RoundTripper
	// dial opens outbound TCP connections for CONNECT tunnels and is also
	// installed as the Transport's Dial func.
	dial func(string, string) (net.Conn, error)
	// timeout is the deadline applied to the client connection before each
	// call to handle in handleLoop.
	timeout time.Duration
	// mitm, when non-nil, makes CONNECT requests be intercepted (MITMed)
	// instead of tunnelled as raw bytes.
	mitm *mitm.Config
	// proxyURL is the downstream proxy; nil means connect to hosts directly.
	proxyURL *url.URL
	// conns counts connections currently served by handleLoop; Close waits on it.
	conns *sync.WaitGroup
	// closing is closed by Close to signal shutdown to Serve and handle.
	closing chan bool

	// reqmod and resmod are applied to every request/response; they default
	// to noop when unset.
	reqmod RequestModifier
	resmod ResponseModifier
}
// NewProxy returns a new HTTP proxy.
//
// The proxy starts with an *http.Transport that honours the environment's
// proxy settings and never negotiates HTTP/2, a five minute per-request
// connection timeout, no-op modifiers, and a net.Dialer with 30s connect
// timeout and 30s keep-alive.
func NewProxy() *Proxy {
	transport := &http.Transport{
		// TODO(adamtanner): This forces the http.Transport to not upgrade requests
		// to HTTP/2 in Go 1.6+. Remove this once Martian can support HTTP/2.
		TLSNextProto:          make(map[string]func(string, *tls.Conn) http.RoundTripper),
		Proxy:                 http.ProxyFromEnvironment,
		TLSHandshakeTimeout:   10 * time.Second,
		ExpectContinueTimeout: time.Second,
	}

	p := &Proxy{
		roundTripper: transport,
		timeout:      5 * time.Minute,
		conns:        &sync.WaitGroup{},
		closing:      make(chan bool),
		reqmod:       noop,
		resmod:       noop,
	}

	dialer := &net.Dialer{
		Timeout:   30 * time.Second,
		KeepAlive: 30 * time.Second,
	}
	p.SetDial(dialer.Dial)

	return p
}
// SetRoundTripper sets the http.RoundTripper of the proxy.
//
// If rt is an *http.Transport it is reconfigured in place: HTTP/2 upgrades are
// disabled, its Proxy is pointed at the configured downstream proxy (none if
// unset), and its Dial is replaced with the proxy's dial func.
func (p *Proxy) SetRoundTripper(rt http.RoundTripper) {
	p.roundTripper = rt

	tr, isTransport := rt.(*http.Transport)
	if !isTransport {
		return
	}

	tr.TLSNextProto = make(map[string]func(string, *tls.Conn) http.RoundTripper)
	tr.Proxy = http.ProxyURL(p.proxyURL)
	tr.Dial = p.dial
}
// SetDownstreamProxy sets the proxy that receives requests from the upstream
// proxy.
//
// The URL is used by connect for CONNECT tunnels and, when the round tripper
// is an *http.Transport, for ordinary requests as well. A nil URL means
// requests go directly to their target hosts.
func (p *Proxy) SetDownstreamProxy(proxyURL *url.URL) {
	p.proxyURL = proxyURL

	tr, isTransport := p.roundTripper.(*http.Transport)
	if !isTransport {
		return
	}
	tr.Proxy = http.ProxyURL(proxyURL)
}
// SetTimeout sets the request timeout of the proxy.
//
// The value becomes the read/write deadline that handleLoop applies to a
// client connection before each request it handles.
func (p *Proxy) SetTimeout(timeout time.Duration) {
	p.timeout = timeout
}
// SetMITM sets the config to use for MITMing of CONNECT requests.
//
// Passing nil turns interception off, so CONNECT requests are tunnelled to
// their destination unmodified.
func (p *Proxy) SetMITM(config *mitm.Config) {
	p.mitm = config
}
// SetDial sets the dial func used to establish a connection.
//
// dial is wrapped so that every successfully established connection has
// SIGPIPE suppressed via nosigpipe.IgnoreSIGPIPE. The wrapped func is used by
// connect for CONNECT tunnels and, when the round tripper is an
// *http.Transport, is also installed as that transport's Dial func.
func (p *Proxy) SetDial(dial func(string, string) (net.Conn, error)) {
	p.dial = func(network, addr string) (net.Conn, error) {
		c, err := dial(network, addr)
		// Only configure the socket when dialing succeeded; on failure c is
		// not a usable connection (typically nil).
		if err == nil {
			nosigpipe.IgnoreSIGPIPE(c)
		}
		return c, err
	}

	if tr, ok := p.roundTripper.(*http.Transport); ok {
		tr.Dial = p.dial
	}
}
// Close puts the proxy into the closing state and blocks until every
// connection served by handleLoop has finished.
//
// Closing the p.closing channel makes Serve return before its next Accept and
// makes handle stop waiting for new requests, while in-flight requests are
// allowed to complete. Close must be called at most once: a second call would
// close an already-closed channel and panic.
func (p *Proxy) Close() {
	log.Infof("martian: closing down proxy")
	close(p.closing)

	log.Infof("martian: waiting for connections to close")
	p.conns.Wait()

	log.Infof("martian: all connections closed")
}
// Closing reports whether Close has been called, without blocking.
func (p *Proxy) Closing() bool {
	select {
	case <-p.closing:
		// The channel is closed, so this receive is always ready.
	default:
		return false
	}
	return true
}
// SetRequestModifier sets the request modifier.
//
// A nil modifier restores the default no-op modifier.
func (p *Proxy) SetRequestModifier(reqmod RequestModifier) {
	if reqmod != nil {
		p.reqmod = reqmod
		return
	}
	p.reqmod = noop
}
// SetResponseModifier sets the response modifier.
//
// A nil modifier restores the default no-op modifier.
func (p *Proxy) SetResponseModifier(resmod ResponseModifier) {
	if resmod != nil {
		p.resmod = resmod
		return
	}
	p.resmod = noop
}
// Serve accepts connections from the listener and handles the requests.
//
// Each accepted connection is served by handleLoop in its own goroutine.
// Temporary accept errors are retried with an exponential back-off that
// starts at 5ms and is capped at one second. Any other accept error is logged
// and returned. Serve returns nil once the proxy is closing, which is checked
// before every Accept. The listener is closed when Serve returns.
//
// NOTE(review): a blocked Accept is not interrupted by Close; Serve only
// observes the closing state after the next Accept returns.
func (p *Proxy) Serve(l net.Listener) error {
	defer l.Close()

	const maxDelay = time.Second
	var delay time.Duration
	for {
		if p.Closing() {
			return nil
		}

		conn, err := l.Accept()
		if err != nil {
			if nerr, ok := err.(net.Error); ok && nerr.Temporary() {
				if delay == 0 {
					delay = 5 * time.Millisecond
				} else {
					delay *= 2
				}
				if delay > maxDelay {
					delay = maxDelay
				}

				log.Debugf("martian: temporary error on accept: %v", err)
				time.Sleep(delay)
				continue
			}

			log.Errorf("martian: failed to accept: %v", err)
			return err
		}

		// Only configure the socket once Accept actually produced one; on
		// error conn is not a usable connection.
		nosigpipe.IgnoreSIGPIPE(conn)

		delay = 0
		log.Debugf("martian: accepted connection from %s", conn.RemoteAddr())

		if tconn, ok := conn.(*net.TCPConn); ok {
			tconn.SetKeepAlive(true)
			tconn.SetKeepAlivePeriod(3 * time.Minute)
		}

		go p.handleLoop(conn)
	}
}
// handleLoop serves every request arriving on conn.
//
// It creates a session and context for the connection, then repeatedly
// refreshes the connection deadline and calls handle. It stops when handle
// returns an error that isCloseable reports as terminal. conn is closed and
// the proxy's connection WaitGroup is released on return.
//
// NOTE(review): conns.Add(1) runs inside the goroutine started by Serve, so it
// can race with a concurrent Close that is already blocked in conns.Wait().
// sync.WaitGroup requires such Adds to happen before Wait; consider moving the
// Add into Serve ahead of the `go` statement.
func (p *Proxy) handleLoop(conn net.Conn) {
	p.conns.Add(1)
	defer p.conns.Done()
	defer conn.Close()

	rd := bufio.NewReader(conn)
	wr := bufio.NewWriter(conn)
	brw := bufio.NewReadWriter(rd, wr)

	sess, err := newSession(conn, brw)
	if err != nil {
		log.Errorf("martian: failed to create session: %v", err)
		return
	}

	ctx, err := withSession(sess)
	if err != nil {
		log.Errorf("martian: failed to create context: %v", err)
		return
	}

	for {
		// Every call to handle gets a fresh p.timeout budget for reads and writes.
		conn.SetDeadline(time.Now().Add(p.timeout))

		herr := p.handle(ctx, conn, brw)
		if !isCloseable(herr) {
			continue
		}

		log.Debugf("martian: closing connection: %v", conn.RemoteAddr())
		return
	}
}
// handle reads a single request from brw and serves it.
//
// Non-CONNECT requests go through four steps:
//   - the request modifier is applied;
//   - the request is round-tripped (or skipped when the context says so);
//   - the response modifier is applied;
//   - the response is written back to the client.
//
// CONNECT requests are handled in one of two ways:
//   - When p.mitm is set, the proxy answers 200 itself and sniffs the first
//     tunnelled byte. It terminates TLS (when that byte is a TLS handshake
//     record) and handles the inner request recursively.
//   - Otherwise the proxy opens a tunnel via connect and copies bytes in both
//     directions until both sides finish.
//
// The returned error tells handleLoop whether to keep using the connection.
// errClose, or any error isCloseable accepts, makes it close the connection.
// nil allows another request to be read. nil is also returned when a request
// or response modifier hijacked the connection.
func (p *Proxy) handle(ctx *Context, conn net.Conn, brw *bufio.ReadWriter) error {
	log.Debugf("martian: waiting for request: %v", conn.RemoteAddr())

	var req *http.Request
	// Read the request in a goroutine so that proxy shutdown (p.closing) can
	// interrupt the wait. Both channels have capacity 1, so the goroutine's
	// final send never blocks even after handle has returned.
	reqc := make(chan *http.Request, 1)
	errc := make(chan error, 1)
	go func() {
		r, err := http.ReadRequest(brw.Reader)
		if err != nil {
			errc <- err
			return
		}
		reqc <- r
	}()

	select {
	case err := <-errc:
		if isCloseable(err) {
			log.Debugf("martian: connection closed prematurely: %v", err)
		} else {
			log.Errorf("martian: failed to read request: %v", err)
		}

		// TODO: TCPConn.WriteClose() to avoid sending an RST to the client.

		return errClose
	case req = <-reqc:
	case <-p.closing:
		return errClose
	}
	defer req.Body.Close()

	// Every request gets its own context that shares the connection's session.
	session := ctx.Session()
	ctx, err := withSession(session)
	if err != nil {
		log.Errorf("martian: failed to build new context: %v", err)
		return err
	}

	link(req, ctx)
	defer unlink(req)

	if tconn, ok := conn.(*tls.Conn); ok {
		session.MarkSecure()

		cs := tconn.ConnectionState()
		req.TLS = &cs
	}

	req.URL.Scheme = "http"
	if session.IsSecure() {
		log.Debugf("martian: forcing HTTPS inside secure session")
		req.URL.Scheme = "https"
	}

	req.RemoteAddr = conn.RemoteAddr().String()
	if req.URL.Host == "" {
		req.URL.Host = req.Host
	}

	if req.Method == "CONNECT" {
		if err := p.reqmod.ModifyRequest(req); err != nil {
			log.Errorf("martian: error modifying CONNECT request: %v", err)
			proxyutil.Warning(req.Header, err)
		}
		if session.Hijacked() {
			log.Infof("martian: connection hijacked by request modifier")
			return nil
		}

		if p.mitm != nil {
			log.Debugf("martian: attempting MITM for connection: %s", req.Host)

			res := proxyutil.NewResponse(200, nil, req)

			if err := p.resmod.ModifyResponse(res); err != nil {
				log.Errorf("martian: error modifying CONNECT response: %v", err)
				proxyutil.Warning(res.Header, err)
			}
			if session.Hijacked() {
				log.Infof("martian: connection hijacked by response modifier")
				return nil
			}

			if err := res.Write(brw); err != nil {
				log.Errorf("martian: got error while writing response back to client: %v", err)
			}
			if err := brw.Flush(); err != nil {
				log.Errorf("martian: got error while flushing response back to client: %v", err)
			}

			log.Debugf("martian: completed MITM for connection: %s", req.Host)

			// Sniff the first tunnelled byte to decide between TLS and plain HTTP.
			b := make([]byte, 1)
			if _, err := brw.Read(b); err != nil {
				log.Errorf("martian: error peeking message through CONNECT tunnel to determine type: %v", err)
				// Nothing was read, so b[0] holds no client data. Continuing
				// would replay a bogus zero byte into the next request; the
				// client side of the tunnel is unusable, so close it.
				return errClose
			}

			// Drain all of the rest of the buffered data.
			buf := make([]byte, brw.Reader.Buffered())
			brw.Read(buf)

			// 22 is the TLS handshake.
			// https://tools.ietf.org/html/rfc5246#section-6.2.1
			if b[0] == 22 {
				// Prepend the previously read data to be read again by
				// http.ReadRequest.
				tlsconn := tls.Server(&peekedConn{conn, io.MultiReader(bytes.NewReader(b), bytes.NewReader(buf), conn)}, p.mitm.TLSForHost(req.Host))

				if err := tlsconn.Handshake(); err != nil {
					p.mitm.HandshakeErrorCallback(req, err)
					return err
				}

				var finalTLSconn net.Conn
				finalTLSconn = tlsconn
				// If the original connection was a traffic shaped connection, wrap the tls
				// connection inside a traffic shaped connection too.
				if ptsconn, ok := conn.(*trafficshape.Conn); ok {
					finalTLSconn = ptsconn.Listener.GetTrafficShapedConn(tlsconn)
				}
				// Point the shared reader/writer at the decrypted stream so that
				// later requests on this connection are read through TLS too.
				brw.Writer.Reset(finalTLSconn)
				brw.Reader.Reset(finalTLSconn)
				return p.handle(ctx, finalTLSconn, brw)
			}

			// Prepend the previously read data to be read again by http.ReadRequest.
			brw.Reader.Reset(io.MultiReader(bytes.NewReader(b), bytes.NewReader(buf), conn))
			return p.handle(ctx, conn, brw)
		}

		log.Debugf("martian: attempting to establish CONNECT tunnel: %s", req.URL.Host)
		res, cconn, cerr := p.connect(req)
		if cerr != nil {
			log.Errorf("martian: failed to CONNECT: %v", cerr)
			res = proxyutil.NewResponse(502, nil, req)
			proxyutil.Warning(res.Header, cerr)

			if err := p.resmod.ModifyResponse(res); err != nil {
				log.Errorf("martian: error modifying CONNECT response: %v", err)
				proxyutil.Warning(res.Header, err)
			}
			if session.Hijacked() {
				log.Infof("martian: connection hijacked by response modifier")
				return nil
			}

			if err := res.Write(brw); err != nil {
				log.Errorf("martian: got error while writing response back to client: %v", err)
			}
			err := brw.Flush()
			if err != nil {
				log.Errorf("martian: got error while flushing response back to client: %v", err)
			}
			return err
		}
		defer res.Body.Close()
		defer cconn.Close()

		if err := p.resmod.ModifyResponse(res); err != nil {
			log.Errorf("martian: error modifying CONNECT response: %v", err)
			proxyutil.Warning(res.Header, err)
		}
		if session.Hijacked() {
			log.Infof("martian: connection hijacked by response modifier")
			return nil
		}

		res.ContentLength = -1
		if err := res.Write(brw); err != nil {
			log.Errorf("martian: got error while writing response back to client: %v", err)
		}
		if err := brw.Flush(); err != nil {
			log.Errorf("martian: got error while flushing response back to client: %v", err)
		}

		cbw := bufio.NewWriter(cconn)
		cbr := bufio.NewReader(cconn)
		defer cbw.Flush()

		// copySync copies r into w until EOF or error and then signals donec.
		copySync := func(w io.Writer, r io.Reader, donec chan<- bool) {
			if _, err := io.Copy(w, r); err != nil && err != io.EOF {
				log.Errorf("martian: failed to copy CONNECT tunnel: %v", err)
			}

			log.Debugf("martian: CONNECT tunnel finished copying")
			donec <- true
		}

		donec := make(chan bool, 2)
		go copySync(cbw, brw, donec)
		go copySync(brw, cbr, donec)

		log.Debugf("martian: established CONNECT tunnel, proxying traffic")
		// Wait for both directions to finish before tearing the tunnel down.
		<-donec
		<-donec
		log.Debugf("martian: closed CONNECT tunnel")

		return errClose
	}

	if err := p.reqmod.ModifyRequest(req); err != nil {
		log.Errorf("martian: error modifying request: %v", err)
		proxyutil.Warning(req.Header, err)
	}
	if session.Hijacked() {
		log.Infof("martian: connection hijacked by request modifier")
		return nil
	}

	res, err := p.roundTrip(ctx, req)
	if err != nil {
		log.Errorf("martian: failed to round trip: %v", err)
		res = proxyutil.NewResponse(502, nil, req)
		proxyutil.Warning(res.Header, err)
	}
	defer res.Body.Close()

	if err := p.resmod.ModifyResponse(res); err != nil {
		log.Errorf("martian: error modifying response: %v", err)
		proxyutil.Warning(res.Header, err)
	}
	if session.Hijacked() {
		log.Infof("martian: connection hijacked by response modifier")
		return nil
	}

	var closing error
	if req.Close || res.Close || p.Closing() {
		log.Debugf("martian: received close request: %v", req.RemoteAddr)
		res.Close = true
		closing = errClose
	}

	// Check if conn is a traffic shaped connection.
	if ptsconn, ok := conn.(*trafficshape.Conn); ok {
		ptsconn.Context = &trafficshape.Context{}
		// Check if the request URL matches any URLRegex in Shapes. If so, set the connections's Context
		// with the required information, so that the Write() method of the Conn has access to it.
		for urlregex, buckets := range ptsconn.LocalBuckets {
			if match, _ := regexp.MatchString(urlregex, req.URL.String()); match {
				if rangeStart := proxyutil.GetRangeStart(res); rangeStart > -1 {
					dump, err := httputil.DumpResponse(res, false)
					if err != nil {
						return err
					}
					ptsconn.Context = &trafficshape.Context{
						Shaping:            true,
						Buckets:            buckets,
						GlobalBucket:       ptsconn.GlobalBuckets[urlregex],
						URLRegex:           urlregex,
						RangeStart:         rangeStart,
						ByteOffset:         rangeStart,
						HeaderLen:          int64(len(dump)),
						HeaderBytesWritten: 0,
					}
					// Get the next action to perform, if there.
					ptsconn.Context.NextActionInfo = ptsconn.GetNextActionFromByte(rangeStart)
					// Check if response lies in a throttled byte range.
					ptsconn.Context.ThrottleContext = ptsconn.GetCurrentThrottle(rangeStart)
					if ptsconn.Context.ThrottleContext.ThrottleNow {
						ptsconn.Context.Buckets.WriteBucket.SetCapacity(
							ptsconn.Context.ThrottleContext.Bandwidth)
					}
					log.Infof(
						"trafficshape: Request %s with Range Start: %d matches a Shaping request %s. Will enforce Traffic shaping.",
						req.URL, rangeStart, urlregex)
				}
				break
			}
		}
	}

	err = res.Write(brw)
	if err != nil {
		log.Errorf("martian: got error while writing response back to client: %v", err)
		if _, ok := err.(*trafficshape.ErrForceClose); ok {
			closing = errClose
		}
	}
	err = brw.Flush()
	if err != nil {
		log.Errorf("martian: got error while flushing response back to client: %v", err)
		if _, ok := err.(*trafficshape.ErrForceClose); ok {
			closing = errClose
		}
	}
	return closing
}
// peekedConn is a net.Conn whose reads come from r rather than from the
// embedded connection, while every other method (Write, Close, addresses,
// deadlines) is delegated to the embedded net.Conn. It lets bytes that were
// already sniffed off a connection be put back in front of it.
type peekedConn struct {
	net.Conn
	r io.Reader
}

// Read serves data from c.r, bypassing the embedded net.Conn's Read. Pairing
// it with an io.MultiReader of (already-read bytes, original conn) replays
// the consumed bytes before continuing with the live connection.
func (c *peekedConn) Read(p []byte) (int, error) {
	n, err := c.r.Read(p)
	return n, err
}
// roundTrip sends req through the proxy's round tripper. When the context asks
// for the round trip to be skipped, no outbound request is made and a
// synthetic 200 response for req is returned instead.
func (p *Proxy) roundTrip(ctx *Context, req *http.Request) (*http.Response, error) {
	if !ctx.SkippingRoundTrip() {
		return p.roundTripper.RoundTrip(req)
	}

	log.Debugf("martian: skipping round trip")
	return proxyutil.NewResponse(200, nil, req), nil
}
// connect opens the outbound side of a CONNECT tunnel for req.
//
// With a downstream proxy configured, the CONNECT request is forwarded to that
// proxy and its response is returned together with the connection to it.
// Otherwise the target host (req.URL.Host) is dialed directly and a synthetic
// 200 response is returned. On success the caller owns the returned
// connection and must close it. On error no connection is returned, and any
// connection opened along the way has already been closed.
func (p *Proxy) connect(req *http.Request) (*http.Response, net.Conn, error) {
	if p.proxyURL != nil {
		log.Debugf("martian: CONNECT with downstream proxy: %s", p.proxyURL.Host)

		conn, err := p.dial("tcp", p.proxyURL.Host)
		if err != nil {
			return nil, nil, err
		}

		pbw := bufio.NewWriter(conn)
		pbr := bufio.NewReader(conn)

		if err := req.Write(pbw); err != nil {
			conn.Close()
			return nil, nil, err
		}
		if err := pbw.Flush(); err != nil {
			conn.Close()
			return nil, nil, err
		}

		// NOTE(review): any bytes the downstream proxy sends right after its
		// response headers stay buffered in pbr, while the caller reads from
		// the raw conn.
		res, err := http.ReadResponse(pbr, req)
		if err != nil {
			conn.Close()
			return nil, nil, err
		}

		return res, conn, nil
	}

	log.Debugf("martian: CONNECT to host directly: %s", req.URL.Host)

	conn, err := p.dial("tcp", req.URL.Host)
	if err != nil {
		return nil, nil, err
	}

	return proxyutil.NewResponse(200, nil, req), conn, nil
}