transfer.sh/vendor/github.com/google/martian/proxy_test.go

1317 lines
34 KiB
Go
Raw Normal View History

2019-03-17 20:19:56 +01:00
// Copyright 2015 Google Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package martian
import (
"bufio"
"bytes"
"crypto/tls"
"crypto/x509"
"errors"
"fmt"
"io"
"io/ioutil"
"net"
"net/http"
"net/url"
"os"
"strings"
"testing"
"time"
"github.com/google/martian/log"
"github.com/google/martian/martiantest"
"github.com/google/martian/mitm"
"github.com/google/martian/proxyutil"
)
// tempError is an error that always claims to be both a timeout and a
// temporary condition, so it satisfies the Timeout/Temporary methods of
// net.Error and can drive the proxy's handling of temporary Accept failures.
type tempError struct{}

// Error returns the fixed message "temporary".
func (*tempError) Error() string { return "temporary" }

// Timeout always reports true.
func (*tempError) Timeout() bool { return true }

// Temporary always reports true.
func (*tempError) Temporary() bool { return true }

// timeoutListener decorates a net.Listener so that the first errCount calls
// to Accept fail with err instead of returning a connection.
type timeoutListener struct {
	net.Listener
	errCount int
	err      error
}

// newTimeoutListener wraps l so that its first errCount Accept calls return
// a *tempError; later calls are forwarded to l.
func newTimeoutListener(l net.Listener, errCount int) net.Listener {
	tl := new(timeoutListener)
	tl.Listener = l
	tl.errCount = errCount
	tl.err = &tempError{}
	return tl
}

// Accept consumes one scripted failure if any remain; otherwise it delegates
// to the embedded listener.
func (tl *timeoutListener) Accept() (net.Conn, error) {
	if tl.errCount <= 0 {
		return tl.Listener.Accept()
	}
	tl.errCount--
	return nil, tl.err
}
// TestIntegrationTemporaryTimeout verifies that temporary errors from the
// listener's Accept do not stop the proxy from serving. The listener fails
// three times before handing over the real connection, and the proxied
// request must still complete with a 200.
func TestIntegrationTemporaryTimeout(t *testing.T) {
	t.Parallel()

	ln, err := net.Listen("tcp", "[::]:0")
	if err != nil {
		t.Fatalf("net.Listen(): got %v, want no error", err)
	}

	proxy := NewProxy()
	defer proxy.Close()

	proxy.SetRoundTripper(martiantest.NewTransport())
	proxy.SetTimeout(200 * time.Millisecond)

	// The wrapped listener reports a temporary error on its first three
	// Accept calls before delegating to ln.
	go proxy.Serve(newTimeoutListener(ln, 3))

	client, err := net.Dial("tcp", ln.Addr().String())
	if err != nil {
		t.Fatalf("net.Dial(): got %v, want no error", err)
	}
	defer client.Close()

	req, err := http.NewRequest("GET", "http://example.com", nil)
	if err != nil {
		t.Fatalf("http.NewRequest(): got %v, want no error", err)
	}
	req.Header.Set("Connection", "close")

	// Absolute-form request line, as sent to a proxy:
	//   GET http://example.com/ HTTP/1.1
	//   Host: example.com
	if err := req.WriteProxy(client); err != nil {
		t.Fatalf("req.WriteProxy(): got %v, want no error", err)
	}

	res, err := http.ReadResponse(bufio.NewReader(client), req)
	if err != nil {
		t.Fatalf("http.ReadResponse(): got %v, want no error", err)
	}
	defer res.Body.Close()

	if want := 200; res.StatusCode != want {
		t.Errorf("res.StatusCode: got %d, want %d", res.StatusCode, want)
	}
}
// TestIntegrationHTTP sends a plain HTTP request through the proxy and checks
// that a value stored in the request context by the request modifier is
// visible to the response modifier, which copies it into the Martian-Test
// response header.
func TestIntegrationHTTP(t *testing.T) {
	t.Parallel()
	l, err := net.Listen("tcp", "[::]:0")
	if err != nil {
		t.Fatalf("net.Listen(): got %v, want no error", err)
	}
	p := NewProxy()
	defer p.Close()

	// Clear the modifiers with nil before the real ones are installed below.
	p.SetRequestModifier(nil)
	p.SetResponseModifier(nil)

	tr := martiantest.NewTransport()
	p.SetRoundTripper(tr)
	p.SetTimeout(200 * time.Millisecond)

	tm := martiantest.NewModifier()
	tm.RequestFunc(func(req *http.Request) {
		ctx := NewContext(req)
		ctx.Set("martian.test", "true")
	})
	tm.ResponseFunc(func(res *http.Response) {
		ctx := NewContext(res.Request)
		v, _ := ctx.Get("martian.test")
		// Checked assertion: if the context value is missing, the header
		// check below fails cleanly instead of this modifier panicking on the
		// proxy's goroutine.
		s, _ := v.(string)
		res.Header.Set("Martian-Test", s)
	})
	p.SetRequestModifier(tm)
	p.SetResponseModifier(tm)

	go p.Serve(l)

	conn, err := net.Dial("tcp", l.Addr().String())
	if err != nil {
		t.Fatalf("net.Dial(): got %v, want no error", err)
	}
	defer conn.Close()

	req, err := http.NewRequest("GET", "http://example.com", nil)
	if err != nil {
		t.Fatalf("http.NewRequest(): got %v, want no error", err)
	}

	// GET http://example.com/ HTTP/1.1
	// Host: example.com
	if err := req.WriteProxy(conn); err != nil {
		t.Fatalf("req.WriteProxy(): got %v, want no error", err)
	}

	res, err := http.ReadResponse(bufio.NewReader(conn), req)
	if err != nil {
		t.Fatalf("http.ReadResponse(): got %v, want no error", err)
	}
	defer res.Body.Close()

	if got, want := res.StatusCode, 200; got != want {
		t.Fatalf("res.StatusCode: got %d, want %d", got, want)
	}
	if got, want := res.Header.Get("Martian-Test"), "true"; got != want {
		t.Errorf("res.Header.Get(%q): got %q, want %q", "Martian-Test", got, want)
	}
}
// TestIntegrationHTTP100Continue sends a POST with "Expect: 100-continue"
// through the proxy to a raw TCP upstream. The upstream answers with an
// interim 100 Continue and then echoes the request body back in a 200. The
// client delays sending the body by one second. The final response, the echoed
// body, and both modifier hooks are checked.
func TestIntegrationHTTP100Continue(t *testing.T) {
	t.Parallel()
	l, err := net.Listen("tcp", "[::]:0")
	if err != nil {
		t.Fatalf("net.Listen(): got %v, want no error", err)
	}
	p := NewProxy()
	defer p.Close()
	p.SetTimeout(2 * time.Second)

	sl, err := net.Listen("tcp", "[::]:0")
	if err != nil {
		t.Fatalf("net.Listen(): got %v, want no error", err)
	}
	// Release the upstream listener when the test ends.
	defer sl.Close()

	// Upstream server: accepts a single connection. It runs off the test
	// goroutine, so it reports problems through the package logger, not t.
	go func() {
		conn, err := sl.Accept()
		if err != nil {
			log.Errorf("proxy_test: failed to accept connection: %v", err)
			return
		}
		defer conn.Close()
		log.Infof("proxy_test: accepted connection: %s", conn.RemoteAddr())
		req, err := http.ReadRequest(bufio.NewReader(conn))
		if err != nil {
			log.Errorf("proxy_test: failed to read request: %v", err)
			return
		}
		if req.Header.Get("Expect") == "100-continue" {
			log.Infof("proxy_test: received 100-continue request")
			conn.Write([]byte("HTTP/1.1 100 Continue\r\n\r\n"))
			log.Infof("proxy_test: sent 100-continue response")
		} else {
			log.Infof("proxy_test: received non 100-continue request")
			res := proxyutil.NewResponse(417, nil, req)
			res.Header.Set("Connection", "close")
			res.Write(conn)
			return
		}
		// Echo the request body back to the client.
		res := proxyutil.NewResponse(200, req.Body, req)
		res.Header.Set("Connection", "close")
		res.Write(conn)
		log.Infof("proxy_test: sent 200 response")
	}()

	tm := martiantest.NewModifier()
	p.SetRequestModifier(tm)
	p.SetResponseModifier(tm)

	go p.Serve(l)

	conn, err := net.Dial("tcp", l.Addr().String())
	if err != nil {
		t.Fatalf("net.Dial(): got %v, want no error", err)
	}
	defer conn.Close()

	host := sl.Addr().String()
	raw := fmt.Sprintf("POST http://%s/ HTTP/1.1\r\n"+
		"Host: %s\r\n"+
		"Content-Length: 12\r\n"+
		"Expect: 100-continue\r\n\r\n", host, host)
	if _, err := conn.Write([]byte(raw)); err != nil {
		t.Fatalf("conn.Write(headers): got %v, want no error", err)
	}

	// Send the body only after a delay, so the headers arrive alone first.
	// Best effort: a failed write surfaces as a failed read below.
	go func() {
		time.Sleep(time.Second)
		conn.Write([]byte("body content"))
	}()

	res, err := http.ReadResponse(bufio.NewReader(conn), nil)
	if err != nil {
		t.Fatalf("http.ReadResponse(): got %v, want no error", err)
	}
	defer res.Body.Close()

	if got, want := res.StatusCode, 200; got != want {
		t.Fatalf("res.StatusCode: got %d, want %d", got, want)
	}
	got, err := ioutil.ReadAll(res.Body)
	if err != nil {
		t.Fatalf("ioutil.ReadAll(): got %v, want no error", err)
	}
	if want := []byte("body content"); !bytes.Equal(got, want) {
		t.Errorf("res.Body: got %q, want %q", got, want)
	}
	if !tm.RequestModified() {
		t.Error("tm.RequestModified(): got false, want true")
	}
	if !tm.ResponseModified() {
		t.Error("tm.ResponseModified(): got false, want true")
	}
}
// TestIntegrationHTTPDownstreamProxy chains two proxies for a plain HTTP
// request. The client dials the upstream proxy, which forwards to a downstream
// proxy whose test transport answers 299. That status must reach the client.
func TestIntegrationHTTPDownstreamProxy(t *testing.T) {
	t.Parallel()

	// Start first proxy to use as downstream.
	dl, err := net.Listen("tcp", "[::]:0")
	if err != nil {
		t.Fatalf("net.Listen(): got %v, want no error", err)
	}
	downstream := NewProxy()
	defer downstream.Close()
	dtr := martiantest.NewTransport()
	dtr.Respond(299)
	downstream.SetRoundTripper(dtr)
	downstream.SetTimeout(600 * time.Millisecond)
	go downstream.Serve(dl)

	// Start second proxy as upstream proxy, will write to downstream proxy.
	ul, err := net.Listen("tcp", "[::]:0")
	if err != nil {
		t.Fatalf("net.Listen(): got %v, want no error", err)
	}
	upstream := NewProxy()
	defer upstream.Close()
	// Set upstream proxy's downstream proxy to the host:port of the first proxy.
	upstream.SetDownstreamProxy(&url.URL{
		Host: dl.Addr().String(),
	})
	upstream.SetTimeout(600 * time.Millisecond)
	go upstream.Serve(ul)

	// Open connection to upstream proxy.
	conn, err := net.Dial("tcp", ul.Addr().String())
	if err != nil {
		t.Fatalf("net.Dial(): got %v, want no error", err)
	}
	defer conn.Close()

	req, err := http.NewRequest("GET", "http://example.com", nil)
	if err != nil {
		t.Fatalf("http.NewRequest(): got %v, want no error", err)
	}
	// GET http://example.com/ HTTP/1.1
	// Host: example.com
	if err := req.WriteProxy(conn); err != nil {
		t.Fatalf("req.WriteProxy(): got %v, want no error", err)
	}

	// Response from downstream proxy.
	res, err := http.ReadResponse(bufio.NewReader(conn), req)
	if err != nil {
		t.Fatalf("http.ReadResponse(): got %v, want no error", err)
	}
	defer res.Body.Close()

	if got, want := res.StatusCode, 299; got != want {
		t.Fatalf("res.StatusCode: got %d, want %d", got, want)
	}
}
// TestIntegrationHTTPDownstreamProxyError points the proxy at an unusable
// downstream proxy ("[::]:0"), so a CONNECT through it fails. The client must
// receive a 502. The second Warning header value must carry the error returned
// by the response modifier.
func TestIntegrationHTTPDownstreamProxyError(t *testing.T) {
	t.Parallel()
	l, err := net.Listen("tcp", "[::]:0")
	if err != nil {
		t.Fatalf("net.Listen(): got %v, want no error", err)
	}
	p := NewProxy()
	defer p.Close()

	// Set proxy's downstream proxy to invalid host:port to force failure.
	p.SetDownstreamProxy(&url.URL{
		Host: "[::]:0",
	})
	p.SetTimeout(600 * time.Millisecond)

	tm := martiantest.NewModifier()
	reserr := errors.New("response error")
	tm.ResponseError(reserr)
	p.SetResponseModifier(tm)

	go p.Serve(l)

	// Open connection to upstream proxy.
	conn, err := net.Dial("tcp", l.Addr().String())
	if err != nil {
		t.Fatalf("net.Dial(): got %v, want no error", err)
	}
	defer conn.Close()

	req, err := http.NewRequest("CONNECT", "//example.com:443", nil)
	if err != nil {
		t.Fatalf("http.NewRequest(): got %v, want no error", err)
	}
	// CONNECT example.com:443 HTTP/1.1
	// Host: example.com
	if err := req.Write(conn); err != nil {
		t.Fatalf("req.Write(): got %v, want no error", err)
	}

	// Response from upstream proxy, assuming downstream proxy failed to CONNECT.
	res, err := http.ReadResponse(bufio.NewReader(conn), req)
	if err != nil {
		t.Fatalf("http.ReadResponse(): got %v, want no error", err)
	}
	defer res.Body.Close()

	if got, want := res.StatusCode, 502; got != want {
		t.Fatalf("res.StatusCode: got %d, want %d", got, want)
	}

	// The response-modifier error is expected as the second Warning value.
	// Check the length first so a missing header fails the test instead of
	// panicking with index out of range.
	warnings := res.Header["Warning"]
	if len(warnings) < 2 {
		t.Fatalf("res.Header[%q]: got %d values %q, want at least 2", "Warning", len(warnings), warnings)
	}
	if got, want := warnings[1], reserr.Error(); !strings.Contains(got, want) {
		t.Errorf("res.Header[%q][1]: got %q, want to contain %q", "Warning", got, want)
	}
}
// TestIntegrationTLSHandshakeErrorCallback MITMs a CONNECT. It then starts a
// client TLS handshake with an empty root pool, so the MITM certificate cannot
// be verified and the handshake fails. The client-side error is asserted. The
// check of the error passed to the MITM handshake-error callback is currently
// skipped (see TODO below).
func TestIntegrationTLSHandshakeErrorCallback(t *testing.T) {
	t.Parallel()
	l, err := net.Listen("tcp", "[::]:0")
	if err != nil {
		t.Fatalf("net.Listen(): got %v, want no error", err)
	}
	p := NewProxy()
	defer p.Close()

	// Test TLS server.
	ca, priv, err := mitm.NewAuthority("martian.proxy", "Martian Authority", time.Hour)
	if err != nil {
		t.Fatalf("mitm.NewAuthority(): got %v, want no error", err)
	}
	mc, err := mitm.NewConfig(ca, priv)
	if err != nil {
		t.Fatalf("mitm.NewConfig(): got %v, want no error", err)
	}

	// NOTE(review): the callback ignores the error it receives, and herr is
	// written by the callback (presumably on a proxy goroutine) and read below
	// without synchronization. Today that read is skipped.
	var herr error
	mc.SetHandshakeErrorCallback(func(_ *http.Request, err error) { herr = fmt.Errorf("handshake error") })
	p.SetMITM(mc)

	tl, err := net.Listen("tcp", "[::]:0")
	if err != nil {
		t.Fatalf("tls.Listen(): got %v, want no error", err)
	}
	tl = tls.NewListener(tl, mc.TLS())
	go http.Serve(tl, http.HandlerFunc(
		func(rw http.ResponseWriter, req *http.Request) {
			rw.WriteHeader(200)
		}))

	// NOTE(review): tm is never installed on p (there is no
	// SetRequestModifier call), so this host rewrite does not take effect.
	tm := martiantest.NewModifier()
	tm.RequestFunc(func(req *http.Request) {
		req.URL.Host = tl.Addr().String()
	})

	go p.Serve(l)

	conn, err := net.Dial("tcp", l.Addr().String())
	if err != nil {
		t.Fatalf("net.Dial(): got %v, want no error", err)
	}
	defer conn.Close()

	req, err := http.NewRequest("CONNECT", "//example.com:443", nil)
	if err != nil {
		t.Fatalf("http.NewRequest(): got %v, want no error", err)
	}
	// CONNECT example.com:443 HTTP/1.1
	// Host: example.com
	if err := req.Write(conn); err != nil {
		t.Fatalf("req.Write(): got %v, want no error", err)
	}

	// CONNECT response after establishing tunnel.
	if _, err := http.ReadResponse(bufio.NewReader(conn), req); err != nil {
		t.Fatalf("http.ReadResponse(): got %v, want no error", err)
	}

	tlsconn := tls.Client(conn, &tls.Config{
		ServerName: "example.com",
		// Client has no cert so it will get "x509: certificate signed by unknown authority" from the
		// handshake and send "remote error: bad certificate" to the server.
		RootCAs: x509.NewCertPool(),
	})
	defer tlsconn.Close()

	req, err = http.NewRequest("GET", "https://example.com", nil)
	if err != nil {
		t.Fatalf("http.NewRequest(): got %v, want no error", err)
	}
	req.Header.Set("Connection", "close")

	// The write triggers the client handshake, which must fail. Check for
	// nil first: calling Error() on a nil error would panic and hide the
	// real failure.
	werr := req.Write(tlsconn)
	if werr == nil {
		t.Fatal("req.Write(): got no error, want client handshake failure")
	}
	if got, want := werr.Error(), "x509: certificate signed by unknown authority"; !strings.Contains(got, want) {
		t.Fatalf("Got incorrect error from Client Handshake(), got: %v, want: %v", got, want)
	}

	// TODO: herr is not being asserted against. It should be pushed on to a channel
	// of err, and the assertion should pull off of it and assert. That design resulted in the test
	// hanging for unknown reasons.
	t.Skip("skipping assertion of handshake error callback error due to mysterious deadlock")
	if herr == nil {
		t.Fatal("handshake error callback was not invoked")
	}
	if got, want := herr, "remote error: bad certificate"; !strings.Contains(got.Error(), want) {
		t.Fatalf("Got incorrect error from Server Handshake(), got: %v, want: %v", got, want)
	}
}
// TestIntegrationConnect tunnels through the proxy with CONNECT and then
// speaks TLS to a local test server. A request modifier rewrites the CONNECT
// target to that server. Both modifiers are configured to return errors. The
// response-modifier error must appear as a Warning on the CONNECT response,
// but not on the response read back through the tunnel.
func TestIntegrationConnect(t *testing.T) {
	t.Parallel()

	proxyLn, err := net.Listen("tcp", "[::]:0")
	if err != nil {
		t.Fatalf("net.Listen(): got %v, want no error", err)
	}

	p := NewProxy()
	defer p.Close()

	// The CA is used to serve the test TLS server and, below, to trust it.
	ca, priv, err := mitm.NewAuthority("martian.proxy", "Martian Authority", time.Hour)
	if err != nil {
		t.Fatalf("mitm.NewAuthority(): got %v, want no error", err)
	}
	mc, err := mitm.NewConfig(ca, priv)
	if err != nil {
		t.Fatalf("mitm.NewConfig(): got %v, want no error", err)
	}

	// Local TLS server that answers every request with a 299.
	rawLn, err := net.Listen("tcp", "[::]:0")
	if err != nil {
		t.Fatalf("tls.Listen(): got %v, want no error", err)
	}
	tlsLn := tls.NewListener(rawLn, mc.TLS())
	go http.Serve(tlsLn, http.HandlerFunc(func(rw http.ResponseWriter, _ *http.Request) {
		rw.WriteHeader(299)
	}))

	reqErr := errors.New("request error")
	resErr := errors.New("response error")

	mod := martiantest.NewModifier()
	// Redirect the CONNECT to the local TLS server.
	mod.RequestFunc(func(r *http.Request) {
		r.URL.Host = tlsLn.Addr().String()
	})
	mod.RequestError(reqErr)
	mod.ResponseError(resErr)
	p.SetRequestModifier(mod)
	p.SetResponseModifier(mod)

	go p.Serve(proxyLn)

	conn, err := net.Dial("tcp", proxyLn.Addr().String())
	if err != nil {
		t.Fatalf("net.Dial(): got %v, want no error", err)
	}
	defer conn.Close()

	connectReq, err := http.NewRequest("CONNECT", "//example.com:443", nil)
	if err != nil {
		t.Fatalf("http.NewRequest(): got %v, want no error", err)
	}
	// CONNECT example.com:443 HTTP/1.1 (the modifier rewrites the target).
	if err := connectReq.Write(conn); err != nil {
		t.Fatalf("req.Write(): got %v, want no error", err)
	}

	connectRes, err := http.ReadResponse(bufio.NewReader(conn), connectReq)
	if err != nil {
		t.Fatalf("http.ReadResponse(): got %v, want no error", err)
	}
	if connectRes.StatusCode != 200 {
		t.Fatalf("res.StatusCode: got %d, want %d", connectRes.StatusCode, 200)
	}
	if !mod.RequestModified() {
		t.Error("tm.RequestModified(): got false, want true")
	}
	if !mod.ResponseModified() {
		t.Error("tm.ResponseModified(): got false, want true")
	}
	if warning := connectRes.Header.Get("Warning"); !strings.Contains(warning, resErr.Error()) {
		t.Errorf("res.Header.Get(%q): got %q, want to contain %q", "Warning", warning, resErr.Error())
	}

	pool := x509.NewCertPool()
	pool.AddCert(ca)
	tlsConn := tls.Client(conn, &tls.Config{ServerName: "example.com", RootCAs: pool})
	defer tlsConn.Close()

	getReq, err := http.NewRequest("GET", "https://example.com", nil)
	if err != nil {
		t.Fatalf("http.NewRequest(): got %v, want no error", err)
	}
	getReq.Header.Set("Connection", "close")

	// GET / HTTP/1.1, sent inside the tunnel.
	if err := getReq.Write(tlsConn); err != nil {
		t.Fatalf("req.Write(): got %v, want no error", err)
	}

	getRes, err := http.ReadResponse(bufio.NewReader(tlsConn), getReq)
	if err != nil {
		t.Fatalf("http.ReadResponse(): got %v, want no error", err)
	}
	defer getRes.Body.Close()

	if getRes.StatusCode != 299 {
		t.Fatalf("res.StatusCode: got %d, want %d", getRes.StatusCode, 299)
	}
	if warning := getRes.Header.Get("Warning"); strings.Contains(warning, resErr.Error()) {
		t.Errorf("res.Header.Get(%q): got %s, want to not contain %s", "Warning", warning, resErr.Error())
	}
}
// TestIntegrationConnectDownstreamProxy chains two proxies for a CONNECT.
// The client dials the upstream proxy, which forwards to a downstream proxy
// that performs MITM and answers intercepted requests with 299 from a test
// transport. The client must accept the certificate signed by the downstream
// CA and read that 299.
func TestIntegrationConnectDownstreamProxy(t *testing.T) {
	t.Parallel()

	// Downstream proxy: MITMs the tunnel; its transport answers 299.
	downLn, err := net.Listen("tcp", "[::]:0")
	if err != nil {
		t.Fatalf("net.Listen(): got %v, want no error", err)
	}
	down := NewProxy()
	defer down.Close()

	downTransport := martiantest.NewTransport()
	downTransport.Respond(299)
	down.SetRoundTripper(downTransport)

	ca, priv, err := mitm.NewAuthority("martian.proxy", "Martian Authority", 2*time.Hour)
	if err != nil {
		t.Fatalf("mitm.NewAuthority(): got %v, want no error", err)
	}
	mc, err := mitm.NewConfig(ca, priv)
	if err != nil {
		t.Fatalf("mitm.NewConfig(): got %v, want no error", err)
	}
	down.SetMITM(mc)
	go down.Serve(downLn)

	// Upstream proxy: forwards to the downstream proxy's host:port.
	upLn, err := net.Listen("tcp", "[::]:0")
	if err != nil {
		t.Fatalf("net.Listen(): got %v, want no error", err)
	}
	up := NewProxy()
	defer up.Close()
	up.SetDownstreamProxy(&url.URL{Host: downLn.Addr().String()})
	go up.Serve(upLn)

	conn, err := net.Dial("tcp", upLn.Addr().String())
	if err != nil {
		t.Fatalf("net.Dial(): got %v, want no error", err)
	}
	defer conn.Close()

	connectReq, err := http.NewRequest("CONNECT", "//example.com:443", nil)
	if err != nil {
		t.Fatalf("http.NewRequest(): got %v, want no error", err)
	}
	// CONNECT example.com:443 HTTP/1.1
	if err := connectReq.Write(conn); err != nil {
		t.Fatalf("req.Write(): got %v, want no error", err)
	}

	// The downstream proxy acknowledges the CONNECT before starting MITM.
	connectRes, err := http.ReadResponse(bufio.NewReader(conn), connectReq)
	if err != nil {
		t.Fatalf("http.ReadResponse(): got %v, want no error", err)
	}
	if want := 200; connectRes.StatusCode != want {
		t.Fatalf("res.StatusCode: got %d, want %d", connectRes.StatusCode, want)
	}

	pool := x509.NewCertPool()
	pool.AddCert(ca)
	cfg := &tls.Config{
		ServerName: "example.com", // hostname the MITM certificate must match
		RootCAs:    pool,          // trust only the MITM CA
	}
	tlsConn := tls.Client(conn, cfg)
	defer tlsConn.Close()

	getReq, err := http.NewRequest("GET", "https://example.com", nil)
	if err != nil {
		t.Fatalf("http.NewRequest(): got %v, want no error", err)
	}
	// GET / HTTP/1.1, sent inside the tunnel.
	if err := getReq.Write(tlsConn); err != nil {
		t.Fatalf("req.Write(): got %v, want no error", err)
	}

	// Answered by the downstream proxy's MITM path.
	getRes, err := http.ReadResponse(bufio.NewReader(tlsConn), getReq)
	if err != nil {
		t.Fatalf("http.ReadResponse(): got %v, want no error", err)
	}
	defer getRes.Body.Close()

	if want := 299; getRes.StatusCode != want {
		t.Fatalf("res.StatusCode: got %d, want %d", getRes.StatusCode, want)
	}
}
// TestIntegrationMITM MITMs a CONNECT tunnel on a proxy whose test transport
// echoes the request URL scheme in a Request-Scheme header. Both modifiers
// return errors. The response-modifier error must appear as a Warning on the
// CONNECT response and on the response to the request sent inside the tunnel.
// The intercepted request must carry the "https" scheme.
func TestIntegrationMITM(t *testing.T) {
	t.Parallel()

	ln, err := net.Listen("tcp", "[::]:0")
	if err != nil {
		t.Fatalf("net.Listen(): got %v, want no error", err)
	}

	proxy := NewProxy()
	defer proxy.Close()

	transport := martiantest.NewTransport()
	transport.Func(func(r *http.Request) (*http.Response, error) {
		echo := proxyutil.NewResponse(200, nil, r)
		echo.Header.Set("Request-Scheme", r.URL.Scheme)
		return echo, nil
	})
	proxy.SetRoundTripper(transport)
	proxy.SetTimeout(600 * time.Millisecond)

	ca, priv, err := mitm.NewAuthority("martian.proxy", "Martian Authority", 2*time.Hour)
	if err != nil {
		t.Fatalf("mitm.NewAuthority(): got %v, want no error", err)
	}
	mc, err := mitm.NewConfig(ca, priv)
	if err != nil {
		t.Fatalf("mitm.NewConfig(): got %v, want no error", err)
	}
	proxy.SetMITM(mc)

	reqErr, resErr := errors.New("request error"), errors.New("response error")
	mod := martiantest.NewModifier()
	mod.RequestError(reqErr)
	mod.ResponseError(resErr)
	proxy.SetRequestModifier(mod)
	proxy.SetResponseModifier(mod)

	go proxy.Serve(ln)

	conn, err := net.Dial("tcp", ln.Addr().String())
	if err != nil {
		t.Fatalf("net.Dial(): got %v, want no error", err)
	}
	defer conn.Close()

	connectReq, err := http.NewRequest("CONNECT", "//example.com:443", nil)
	if err != nil {
		t.Fatalf("http.NewRequest(): got %v, want no error", err)
	}
	// CONNECT example.com:443 HTTP/1.1
	if err := connectReq.Write(conn); err != nil {
		t.Fatalf("req.Write(): got %v, want no error", err)
	}

	connectRes, err := http.ReadResponse(bufio.NewReader(conn), connectReq)
	if err != nil {
		t.Fatalf("http.ReadResponse(): got %v, want no error", err)
	}
	if connectRes.StatusCode != 200 {
		t.Errorf("res.StatusCode: got %d, want %d", connectRes.StatusCode, 200)
	}
	if w := connectRes.Header.Get("Warning"); !strings.Contains(w, resErr.Error()) {
		t.Errorf("res.Header.Get(%q): got %q, want to contain %q", "Warning", w, resErr.Error())
	}

	// Trust only the MITM CA so the proxy-generated certificate verifies.
	pool := x509.NewCertPool()
	pool.AddCert(ca)
	tlsConn := tls.Client(conn, &tls.Config{ServerName: "example.com", RootCAs: pool})
	defer tlsConn.Close()

	getReq, err := http.NewRequest("GET", "https://example.com", nil)
	if err != nil {
		t.Fatalf("http.NewRequest(): got %v, want no error", err)
	}
	// GET / HTTP/1.1, sent inside the tunnel.
	if err := getReq.Write(tlsConn); err != nil {
		t.Fatalf("req.Write(): got %v, want no error", err)
	}

	getRes, err := http.ReadResponse(bufio.NewReader(tlsConn), getReq)
	if err != nil {
		t.Fatalf("http.ReadResponse(): got %v, want no error", err)
	}
	defer getRes.Body.Close()

	if getRes.StatusCode != 200 {
		t.Errorf("res.StatusCode: got %d, want %d", getRes.StatusCode, 200)
	}
	if scheme := getRes.Header.Get("Request-Scheme"); scheme != "https" {
		t.Errorf("res.Header.Get(%q): got %q, want %q", "Request-Scheme", scheme, "https")
	}
	if w := getRes.Header.Get("Warning"); !strings.Contains(w, resErr.Error()) {
		t.Errorf("res.Header.Get(%q): got %q, want to contain %q", "Warning", w, resErr.Error())
	}
}
// TestIntegrationTransparentHTTP sends an origin-form request (req.Write, not
// req.WriteProxy) to the proxy, as a transparently intercepted client would.
// It checks that the request succeeds and that both modifiers ran.
func TestIntegrationTransparentHTTP(t *testing.T) {
	t.Parallel()
	l, err := net.Listen("tcp", "[::]:0")
	if err != nil {
		t.Fatalf("net.Listen(): got %v, want no error", err)
	}
	p := NewProxy()
	defer p.Close()

	tr := martiantest.NewTransport()
	p.SetRoundTripper(tr)
	p.SetTimeout(200 * time.Millisecond)

	tm := martiantest.NewModifier()
	p.SetRequestModifier(tm)
	p.SetResponseModifier(tm)

	go p.Serve(l)

	conn, err := net.Dial("tcp", l.Addr().String())
	if err != nil {
		t.Fatalf("net.Dial(): got %v, want no error", err)
	}
	defer conn.Close()

	req, err := http.NewRequest("GET", "http://example.com", nil)
	if err != nil {
		t.Fatalf("http.NewRequest(): got %v, want no error", err)
	}
	// GET / HTTP/1.1
	// Host: example.com
	if err := req.Write(conn); err != nil {
		t.Fatalf("req.Write(): got %v, want no error", err)
	}

	res, err := http.ReadResponse(bufio.NewReader(conn), req)
	if err != nil {
		t.Fatalf("http.ReadResponse(): got %v, want no error", err)
	}
	defer res.Body.Close()

	if got, want := res.StatusCode, 200; got != want {
		t.Fatalf("res.StatusCode: got %d, want %d", got, want)
	}
	if !tm.RequestModified() {
		t.Error("tm.RequestModified(): got false, want true")
	}
	if !tm.ResponseModified() {
		t.Error("tm.ResponseModified(): got false, want true")
	}
}
// TestIntegrationTransparentMITM serves the proxy directly on a TLS listener,
// with no CONNECT. The client dials TLS straight to the proxy, which presents
// a certificate generated by the MITM config for the SNI name. The request's
// scheme is echoed back by the test transport and must be "https". Both
// modifiers must run.
func TestIntegrationTransparentMITM(t *testing.T) {
	t.Parallel()

	ca, priv, err := mitm.NewAuthority("martian.proxy", "Martian Authority", 2*time.Hour)
	if err != nil {
		t.Fatalf("mitm.NewAuthority(): got %v, want no error", err)
	}
	mc, err := mitm.NewConfig(ca, priv)
	if err != nil {
		t.Fatalf("mitm.NewConfig(): got %v, want no error", err)
	}

	// BUG: tls.Listen rejects a tls.Config with no Certificates even though
	// tls.Server supports one that sets GetCertificate, so wrap a plain TCP
	// listener with tls.NewListener instead.
	tcpLn, err := net.Listen("tcp", "[::]:0")
	if err != nil {
		t.Fatalf("net.Listen(): got %v, want no error", err)
	}
	ln := tls.NewListener(tcpLn, mc.TLS())

	proxy := NewProxy()
	defer proxy.Close()

	transport := martiantest.NewTransport()
	transport.Func(func(r *http.Request) (*http.Response, error) {
		echo := proxyutil.NewResponse(200, nil, r)
		echo.Header.Set("Request-Scheme", r.URL.Scheme)
		return echo, nil
	})
	proxy.SetRoundTripper(transport)

	mod := martiantest.NewModifier()
	proxy.SetRequestModifier(mod)
	proxy.SetResponseModifier(mod)

	go proxy.Serve(ln)

	// The certificate is minted during MITM, so verify it against the MITM
	// CA while still checking the example.com hostname.
	pool := x509.NewCertPool()
	pool.AddCert(ca)
	tlsConn, err := tls.Dial("tcp", ln.Addr().String(), &tls.Config{
		ServerName: "example.com",
		RootCAs:    pool,
	})
	if err != nil {
		t.Fatalf("tls.Dial(): got %v, want no error", err)
	}
	defer tlsConn.Close()

	req, err := http.NewRequest("GET", "https://example.com", nil)
	if err != nil {
		t.Fatalf("http.NewRequest(): got %v, want no error", err)
	}
	// Encrypted origin-form request, no CONNECT:
	//   GET / HTTP/1.1
	//   Host: example.com
	if err := req.Write(tlsConn); err != nil {
		t.Fatalf("req.Write(): got %v, want no error", err)
	}

	res, err := http.ReadResponse(bufio.NewReader(tlsConn), req)
	if err != nil {
		t.Fatalf("http.ReadResponse(): got %v, want no error", err)
	}
	defer res.Body.Close()

	if want := 200; res.StatusCode != want {
		t.Fatalf("res.StatusCode: got %d, want %d", res.StatusCode, want)
	}
	if scheme := res.Header.Get("Request-Scheme"); scheme != "https" {
		t.Errorf("res.Header.Get(%q): got %q, want %q", "Request-Scheme", scheme, "https")
	}
	if !mod.RequestModified() {
		t.Errorf("tm.RequestModified(): got false, want true")
	}
	if !mod.ResponseModified() {
		t.Errorf("tm.ResponseModified(): got false, want true")
	}
}
// TestIntegrationFailedRoundTrip installs a transport whose round trip always
// fails. It checks that the client gets a 502 whose Warning header carries the
// transport's error text.
func TestIntegrationFailedRoundTrip(t *testing.T) {
	t.Parallel()

	ln, err := net.Listen("tcp", "[::]:0")
	if err != nil {
		t.Fatalf("net.Listen(): got %v, want no error", err)
	}

	proxy := NewProxy()
	defer proxy.Close()

	rtErr := errors.New("round trip error")
	transport := martiantest.NewTransport()
	transport.RespondError(rtErr)
	proxy.SetRoundTripper(transport)
	proxy.SetTimeout(200 * time.Millisecond)

	go proxy.Serve(ln)

	client, err := net.Dial("tcp", ln.Addr().String())
	if err != nil {
		t.Fatalf("net.Dial(): got %v, want no error", err)
	}
	defer client.Close()

	req, err := http.NewRequest("GET", "http://example.com", nil)
	if err != nil {
		t.Fatalf("http.NewRequest(): got %v, want no error", err)
	}
	// GET http://example.com/ HTTP/1.1
	if err := req.WriteProxy(client); err != nil {
		t.Fatalf("req.WriteProxy(): got %v, want no error", err)
	}

	// The proxy synthesizes this response after the round trip fails.
	res, err := http.ReadResponse(bufio.NewReader(client), req)
	if err != nil {
		t.Fatalf("http.ReadResponse(): got %v, want no error", err)
	}
	defer res.Body.Close()

	if want := 502; res.StatusCode != want {
		t.Errorf("res.StatusCode: got %d, want %d", res.StatusCode, want)
	}
	if warning := res.Header.Get("Warning"); !strings.Contains(warning, rtErr.Error()) {
		t.Errorf("res.Header.Get(%q): got %q, want to contain %q", "Warning", warning, rtErr.Error())
	}
}
// TestIntegrationSkipRoundTrip marks every request's context with
// SkipRoundTrip. The transport (scripted to return 500) must therefore never
// be consulted, and the client expects a 200.
func TestIntegrationSkipRoundTrip(t *testing.T) {
	t.Parallel()

	ln, err := net.Listen("tcp", "[::]:0")
	if err != nil {
		t.Fatalf("net.Listen(): got %v, want no error", err)
	}

	proxy := NewProxy()
	defer proxy.Close()

	// A 500 here would only surface if the round trip were not skipped.
	transport := martiantest.NewTransport()
	transport.Respond(500)
	proxy.SetRoundTripper(transport)
	proxy.SetTimeout(200 * time.Millisecond)

	mod := martiantest.NewModifier()
	mod.RequestFunc(func(r *http.Request) {
		NewContext(r).SkipRoundTrip()
	})
	proxy.SetRequestModifier(mod)

	go proxy.Serve(ln)

	client, err := net.Dial("tcp", ln.Addr().String())
	if err != nil {
		t.Fatalf("net.Dial(): got %v, want no error", err)
	}
	defer client.Close()

	req, err := http.NewRequest("GET", "http://example.com", nil)
	if err != nil {
		t.Fatalf("http.NewRequest(): got %v, want no error", err)
	}
	// GET http://example.com/ HTTP/1.1
	if err := req.WriteProxy(client); err != nil {
		t.Fatalf("req.WriteProxy(): got %v, want no error", err)
	}

	res, err := http.ReadResponse(bufio.NewReader(client), req)
	if err != nil {
		t.Fatalf("http.ReadResponse(): got %v, want no error", err)
	}
	defer res.Body.Close()

	if want := 200; res.StatusCode != want {
		t.Errorf("res.StatusCode: got %d, want %d", res.StatusCode, want)
	}
}
// TestHTTPThroughConnectWithMITM opens a CONNECT tunnel to port 80 on a
// MITM-enabled proxy. It then sends two plain-text (non-TLS) proxy requests
// down the same connection. The request modifier skips every round trip, and
// the test expects each response, including the CONNECT one, to be a 200. The
// modifier also reports any method other than GET or CONNECT.
func TestHTTPThroughConnectWithMITM(t *testing.T) {
	t.Parallel()
	l, err := net.Listen("tcp", "[::]:0")
	if err != nil {
		t.Fatalf("net.Listen(): got %v, want no error", err)
	}
	p := NewProxy()
	defer p.Close()

	tm := martiantest.NewModifier()
	tm.RequestFunc(func(req *http.Request) {
		ctx := NewContext(req)
		ctx.SkipRoundTrip()
		// Runs on the proxy's goroutine; t.Errorf (not Fatalf) is safe here.
		if req.Method != "GET" && req.Method != "CONNECT" {
			t.Errorf("unexpected method on request handler: %v", req.Method)
		}
	})
	p.SetRequestModifier(tm)

	ca, priv, err := mitm.NewAuthority("martian.proxy", "Martian Authority", 2*time.Hour)
	if err != nil {
		t.Fatalf("mitm.NewAuthority(): got %v, want no error", err)
	}
	mc, err := mitm.NewConfig(ca, priv)
	if err != nil {
		t.Fatalf("mitm.NewConfig(): got %v, want no error", err)
	}
	p.SetMITM(mc)

	go p.Serve(l)

	conn, err := net.Dial("tcp", l.Addr().String())
	if err != nil {
		t.Fatalf("net.Dial(): got %v, want no error", err)
	}
	defer conn.Close()

	// Every response on conn is read through this one reader. A fresh
	// bufio.Reader per response could buffer bytes of the next response and
	// silently drop them when it is discarded.
	br := bufio.NewReader(conn)

	req, err := http.NewRequest("CONNECT", "//example.com:80", nil)
	if err != nil {
		t.Fatalf("http.NewRequest(): got %v, want no error", err)
	}
	// CONNECT example.com:80 HTTP/1.1
	// Host: example.com
	if err := req.Write(conn); err != nil {
		t.Fatalf("req.Write(): got %v, want no error", err)
	}

	// Response skipped round trip.
	res, err := http.ReadResponse(br, req)
	if err != nil {
		t.Fatalf("http.ReadResponse(): got %v, want no error", err)
	}
	res.Body.Close()
	if got, want := res.StatusCode, 200; got != want {
		t.Errorf("res.StatusCode: got %d, want %d", got, want)
	}

	// Two plain HTTP requests over the established tunnel.
	for i := 0; i < 2; i++ {
		req, err := http.NewRequest("GET", "http://example.com", nil)
		if err != nil {
			t.Fatalf("http.NewRequest(): got %v, want no error", err)
		}
		// GET http://example.com/ HTTP/1.1
		// Host: example.com
		if err := req.WriteProxy(conn); err != nil {
			t.Fatalf("req.WriteProxy(): got %v, want no error", err)
		}

		// Response from skipped round trip.
		res, err := http.ReadResponse(br, req)
		if err != nil {
			t.Fatalf("http.ReadResponse(): got %v, want no error", err)
		}
		res.Body.Close()
		if got, want := res.StatusCode, 200; got != want {
			t.Errorf("res.StatusCode: got %d, want %d", got, want)
		}
	}
}
// TestServerClosesConnection MITMs a CONNECT to a local raw TCP server. That
// server answers the first request with a 301 carrying "Connection: close"
// and no Content-Length, then closes its side. The client must read that
// response, body included, through the proxy without error.
func TestServerClosesConnection(t *testing.T) {
	t.Parallel()
	dstl, err := net.Listen("tcp", "[::]:0")
	if err != nil {
		t.Fatalf("Failed to create http listener: %v", err)
	}
	defer dstl.Close()

	// Destination server. It runs off the test goroutine, so it reports
	// failures with t.Errorf and returns. t.Fatalf (FailNow) may only be
	// called from the goroutine running the test.
	go func() {
		t.Logf("Waiting for server side connection")
		conn, err := dstl.Accept()
		if err != nil {
			t.Errorf("Got error while accepting connection on destination listener: %v", err)
			return
		}
		// Close on every path, including the early returns below.
		defer conn.Close()
		t.Logf("Accepted server side connection")

		// The request is read once and not inspected. This assumes a single
		// Read is enough before replying.
		buf := make([]byte, 16384)
		if _, err := conn.Read(buf); err != nil {
			t.Errorf("Error reading: %v", err)
			return
		}
		_, err = conn.Write([]byte("HTTP/1.1 301 MOVED PERMANENTLY\r\n" +
			"Server: \r\n" +
			"Date: \r\n" +
			"Referer: \r\n" +
			"Location: http://www.foo.com/\r\n" +
			"Content-type: text/html\r\n" +
			"Connection: close\r\n\r\n"))
		if err != nil {
			t.Errorf("Got error while writting to connection on destination listener: %v", err)
		}
	}()

	l, err := net.Listen("tcp", "[::]:0")
	if err != nil {
		t.Fatalf("net.Listen(): got %v, want no error", err)
	}
	ca, priv, err := mitm.NewAuthority("martian.proxy", "Martian Authority", 2*time.Hour)
	if err != nil {
		t.Fatalf("mitm.NewAuthority(): got %v, want no error", err)
	}
	mc, err := mitm.NewConfig(ca, priv)
	if err != nil {
		t.Fatalf("mitm.NewConfig(): got %v, want no error", err)
	}
	p := NewProxy()
	p.SetMITM(mc)
	defer p.Close()

	// Start the proxy with a listener that will return a temporary error on
	// Accept() three times.
	go p.Serve(newTimeoutListener(l, 3))

	conn, err := net.Dial("tcp", l.Addr().String())
	if err != nil {
		t.Fatalf("net.Dial(): got %v, want no error", err)
	}
	defer conn.Close()

	req, err := http.NewRequest("CONNECT", fmt.Sprintf("//%s", dstl.Addr().String()), nil)
	if err != nil {
		t.Fatalf("http.NewRequest(): got %v, want no error", err)
	}
	// CONNECT <destination host:port> HTTP/1.1
	if err := req.Write(conn); err != nil {
		t.Fatalf("req.Write(): got %v, want no error", err)
	}
	res, err := http.ReadResponse(bufio.NewReader(conn), req)
	if err != nil {
		t.Fatalf("http.ReadResponse(): got %v, want no error", err)
	}
	res.Body.Close()

	// Plain-text request sent down the tunnel.
	_, err = conn.Write([]byte("GET / HTTP/1.1\r\n" +
		"User-Agent: curl/7.35.0\r\n" +
		fmt.Sprintf("Host: %s\r\n", dstl.Addr()) +
		"Accept: */*\r\n\r\n"))
	if err != nil {
		t.Fatalf("Error while writing GET request: %v", err)
	}

	// The raw response is mirrored to stderr to aid debugging.
	res, err = http.ReadResponse(bufio.NewReader(io.TeeReader(conn, os.Stderr)), req)
	if err != nil {
		t.Fatalf("http.ReadResponse(): got %v, want no error", err)
	}
	defer res.Body.Close()

	// With no Content-Length, the body runs until the server closes the
	// connection. Reading it fully must not fail.
	if _, err := ioutil.ReadAll(res.Body); err != nil {
		t.Fatalf("error while ReadAll: %v", err)
	}
}