mirror of
https://github.com/dutchcoders/transfer.sh.git
synced 2025-01-04 00:20:18 +01:00
543 lines
14 KiB
Go
543 lines
14 KiB
Go
|
// Copyright 2018, OpenCensus Authors
|
||
|
//
|
||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
// you may not use this file except in compliance with the License.
|
||
|
// You may obtain a copy of the License at
|
||
|
//
|
||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||
|
//
|
||
|
// Unless required by applicable law or agreed to in writing, software
|
||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
// See the License for the specific language governing permissions and
|
||
|
// limitations under the License.
|
||
|
|
||
|
package ochttp
|
||
|
|
||
|
import (
|
||
|
"bytes"
|
||
|
"context"
|
||
|
"encoding/hex"
|
||
|
"errors"
|
||
|
"fmt"
|
||
|
"io"
|
||
|
"io/ioutil"
|
||
|
"log"
|
||
|
"net/http"
|
||
|
"net/http/httptest"
|
||
|
"reflect"
|
||
|
"strings"
|
||
|
"testing"
|
||
|
"time"
|
||
|
|
||
|
"go.opencensus.io/plugin/ochttp/propagation/b3"
|
||
|
"go.opencensus.io/plugin/ochttp/propagation/tracecontext"
|
||
|
"go.opencensus.io/trace"
|
||
|
)
|
||
|
|
||
|
type testExporter struct {
|
||
|
spans []*trace.SpanData
|
||
|
}
|
||
|
|
||
|
func (t *testExporter) ExportSpan(s *trace.SpanData) {
|
||
|
t.spans = append(t.spans, s)
|
||
|
}
|
||
|
|
||
|
// testTransport is an http.RoundTripper that performs no network I/O.
// Every request it receives is forwarded to the test on ch, after which the
// round trip fails with a "noop" error.
type testTransport struct {
	ch chan *http.Request
}

// RoundTrip sends req on tr.ch (blocking until the channel accepts it) and
// then returns a nil response together with a "noop" error.
func (tr *testTransport) RoundTrip(req *http.Request) (*http.Response, error) {
	tr.ch <- req
	var noResponse *http.Response
	return noResponse, errors.New("noop")
}
|
||
|
|
||
|
type testPropagator struct{}
|
||
|
|
||
|
func (t testPropagator) SpanContextFromRequest(req *http.Request) (sc trace.SpanContext, ok bool) {
|
||
|
header := req.Header.Get("trace")
|
||
|
buf, err := hex.DecodeString(header)
|
||
|
if err != nil {
|
||
|
log.Fatalf("Cannot decode trace header: %q", header)
|
||
|
}
|
||
|
r := bytes.NewReader(buf)
|
||
|
r.Read(sc.TraceID[:])
|
||
|
r.Read(sc.SpanID[:])
|
||
|
opts, err := r.ReadByte()
|
||
|
if err != nil {
|
||
|
log.Fatalf("Cannot read trace options from trace header: %q", header)
|
||
|
}
|
||
|
sc.TraceOptions = trace.TraceOptions(opts)
|
||
|
return sc, true
|
||
|
}
|
||
|
|
||
|
func (t testPropagator) SpanContextToRequest(sc trace.SpanContext, req *http.Request) {
|
||
|
var buf bytes.Buffer
|
||
|
buf.Write(sc.TraceID[:])
|
||
|
buf.Write(sc.SpanID[:])
|
||
|
buf.WriteByte(byte(sc.TraceOptions))
|
||
|
req.Header.Set("trace", hex.EncodeToString(buf.Bytes()))
|
||
|
}
|
||
|
|
||
|
func TestTransport_RoundTrip_Race(t *testing.T) {
|
||
|
// This tests that we don't modify the request in accordance with the
|
||
|
// specification for http.RoundTripper.
|
||
|
// We attempt to trigger a race by reading the request from a separate
|
||
|
// goroutine. If the request is modified by Transport, this should trigger
|
||
|
// the race detector.
|
||
|
|
||
|
transport := &testTransport{ch: make(chan *http.Request, 1)}
|
||
|
rt := &Transport{
|
||
|
Propagation: &testPropagator{},
|
||
|
Base: transport,
|
||
|
}
|
||
|
req, _ := http.NewRequest("GET", "http://foo.com", nil)
|
||
|
go func() {
|
||
|
fmt.Println(*req)
|
||
|
}()
|
||
|
rt.RoundTrip(req)
|
||
|
_ = <-transport.ch
|
||
|
}
|
||
|
|
||
|
func TestTransport_RoundTrip(t *testing.T) {
|
||
|
_, parent := trace.StartSpan(context.Background(), "parent")
|
||
|
tests := []struct {
|
||
|
name string
|
||
|
parent *trace.Span
|
||
|
}{
|
||
|
{
|
||
|
name: "no parent",
|
||
|
parent: nil,
|
||
|
},
|
||
|
{
|
||
|
name: "parent",
|
||
|
parent: parent,
|
||
|
},
|
||
|
}
|
||
|
|
||
|
for _, tt := range tests {
|
||
|
t.Run(tt.name, func(t *testing.T) {
|
||
|
transport := &testTransport{ch: make(chan *http.Request, 1)}
|
||
|
|
||
|
rt := &Transport{
|
||
|
Propagation: &testPropagator{},
|
||
|
Base: transport,
|
||
|
}
|
||
|
|
||
|
req, _ := http.NewRequest("GET", "http://foo.com", nil)
|
||
|
if tt.parent != nil {
|
||
|
req = req.WithContext(trace.NewContext(req.Context(), tt.parent))
|
||
|
}
|
||
|
rt.RoundTrip(req)
|
||
|
|
||
|
req = <-transport.ch
|
||
|
span := trace.FromContext(req.Context())
|
||
|
|
||
|
if header := req.Header.Get("trace"); header == "" {
|
||
|
t.Fatalf("Trace header = empty; want valid trace header")
|
||
|
}
|
||
|
if span == nil {
|
||
|
t.Fatalf("Got no spans in req context; want one")
|
||
|
}
|
||
|
if tt.parent != nil {
|
||
|
if got, want := span.SpanContext().TraceID, tt.parent.SpanContext().TraceID; got != want {
|
||
|
t.Errorf("span.SpanContext().TraceID=%v; want %v", got, want)
|
||
|
}
|
||
|
}
|
||
|
})
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func TestHandler(t *testing.T) {
|
||
|
traceID := [16]byte{16, 84, 69, 170, 120, 67, 188, 139, 242, 6, 177, 32, 0, 16, 0, 0}
|
||
|
tests := []struct {
|
||
|
header string
|
||
|
wantTraceID trace.TraceID
|
||
|
wantTraceOptions trace.TraceOptions
|
||
|
}{
|
||
|
{
|
||
|
header: "105445aa7843bc8bf206b12000100000000000000000000000",
|
||
|
wantTraceID: traceID,
|
||
|
wantTraceOptions: trace.TraceOptions(0),
|
||
|
},
|
||
|
{
|
||
|
header: "105445aa7843bc8bf206b12000100000000000000000000001",
|
||
|
wantTraceID: traceID,
|
||
|
wantTraceOptions: trace.TraceOptions(1),
|
||
|
},
|
||
|
}
|
||
|
|
||
|
for _, tt := range tests {
|
||
|
t.Run(tt.header, func(t *testing.T) {
|
||
|
handler := &Handler{
|
||
|
Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||
|
span := trace.FromContext(r.Context())
|
||
|
sc := span.SpanContext()
|
||
|
if got, want := sc.TraceID, tt.wantTraceID; got != want {
|
||
|
t.Errorf("TraceID = %q; want %q", got, want)
|
||
|
}
|
||
|
if got, want := sc.TraceOptions, tt.wantTraceOptions; got != want {
|
||
|
t.Errorf("TraceOptions = %v; want %v", got, want)
|
||
|
}
|
||
|
}),
|
||
|
StartOptions: trace.StartOptions{Sampler: trace.ProbabilitySampler(0.0)},
|
||
|
Propagation: &testPropagator{},
|
||
|
}
|
||
|
req, _ := http.NewRequest("GET", "http://foo.com", nil)
|
||
|
req.Header.Add("trace", tt.header)
|
||
|
handler.ServeHTTP(nil, req)
|
||
|
})
|
||
|
}
|
||
|
}
|
||
|
|
||
|
var _ http.RoundTripper = (*traceTransport)(nil)
|
||
|
|
||
|
type collector []*trace.SpanData
|
||
|
|
||
|
func (c *collector) ExportSpan(s *trace.SpanData) {
|
||
|
*c = append(*c, s)
|
||
|
}
|
||
|
|
||
|
func TestEndToEnd(t *testing.T) {
|
||
|
tc := []struct {
|
||
|
name string
|
||
|
handler *Handler
|
||
|
transport *Transport
|
||
|
wantSameTraceID bool
|
||
|
wantLinks bool // expect a link between client and server span
|
||
|
}{
|
||
|
{
|
||
|
name: "internal default propagation",
|
||
|
handler: &Handler{},
|
||
|
transport: &Transport{},
|
||
|
wantSameTraceID: true,
|
||
|
},
|
||
|
{
|
||
|
name: "external default propagation",
|
||
|
handler: &Handler{IsPublicEndpoint: true},
|
||
|
transport: &Transport{},
|
||
|
wantSameTraceID: false,
|
||
|
wantLinks: true,
|
||
|
},
|
||
|
{
|
||
|
name: "internal TraceContext propagation",
|
||
|
handler: &Handler{Propagation: &tracecontext.HTTPFormat{}},
|
||
|
transport: &Transport{Propagation: &tracecontext.HTTPFormat{}},
|
||
|
wantSameTraceID: true,
|
||
|
},
|
||
|
{
|
||
|
name: "misconfigured propagation",
|
||
|
handler: &Handler{IsPublicEndpoint: true, Propagation: &tracecontext.HTTPFormat{}},
|
||
|
transport: &Transport{Propagation: &b3.HTTPFormat{}},
|
||
|
wantSameTraceID: false,
|
||
|
wantLinks: false,
|
||
|
},
|
||
|
}
|
||
|
|
||
|
for _, tt := range tc {
|
||
|
t.Run(tt.name, func(t *testing.T) {
|
||
|
var spans collector
|
||
|
trace.RegisterExporter(&spans)
|
||
|
defer trace.UnregisterExporter(&spans)
|
||
|
|
||
|
// Start the server.
|
||
|
serverDone := make(chan struct{})
|
||
|
serverReturn := make(chan time.Time)
|
||
|
tt.handler.StartOptions.Sampler = trace.AlwaysSample()
|
||
|
url := serveHTTP(tt.handler, serverDone, serverReturn)
|
||
|
|
||
|
ctx := context.Background()
|
||
|
// Make the request.
|
||
|
req, err := http.NewRequest(
|
||
|
http.MethodPost,
|
||
|
fmt.Sprintf("%s/example/url/path?qparam=val", url),
|
||
|
strings.NewReader("expected-request-body"))
|
||
|
if err != nil {
|
||
|
t.Fatal(err)
|
||
|
}
|
||
|
req = req.WithContext(ctx)
|
||
|
tt.transport.StartOptions.Sampler = trace.AlwaysSample()
|
||
|
c := &http.Client{
|
||
|
Transport: tt.transport,
|
||
|
}
|
||
|
resp, err := c.Do(req)
|
||
|
if err != nil {
|
||
|
t.Fatal(err)
|
||
|
}
|
||
|
if resp.StatusCode != http.StatusOK {
|
||
|
t.Fatalf("resp.StatusCode = %d", resp.StatusCode)
|
||
|
}
|
||
|
|
||
|
// Tell the server to return from request handling.
|
||
|
serverReturn <- time.Now().Add(time.Millisecond)
|
||
|
|
||
|
respBody, err := ioutil.ReadAll(resp.Body)
|
||
|
if err != nil {
|
||
|
t.Fatal(err)
|
||
|
}
|
||
|
if got, want := string(respBody), "expected-response"; got != want {
|
||
|
t.Fatalf("respBody = %q; want %q", got, want)
|
||
|
}
|
||
|
|
||
|
resp.Body.Close()
|
||
|
|
||
|
<-serverDone
|
||
|
trace.UnregisterExporter(&spans)
|
||
|
|
||
|
if got, want := len(spans), 2; got != want {
|
||
|
t.Fatalf("len(spans) = %d; want %d", got, want)
|
||
|
}
|
||
|
|
||
|
var client, server *trace.SpanData
|
||
|
for _, sp := range spans {
|
||
|
switch sp.SpanKind {
|
||
|
case trace.SpanKindClient:
|
||
|
client = sp
|
||
|
if got, want := client.Name, "/example/url/path"; got != want {
|
||
|
t.Errorf("Span name: %q; want %q", got, want)
|
||
|
}
|
||
|
case trace.SpanKindServer:
|
||
|
server = sp
|
||
|
if got, want := server.Name, "/example/url/path"; got != want {
|
||
|
t.Errorf("Span name: %q; want %q", got, want)
|
||
|
}
|
||
|
default:
|
||
|
t.Fatalf("server or client span missing; kind = %v", sp.SpanKind)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
if tt.wantSameTraceID {
|
||
|
if server.TraceID != client.TraceID {
|
||
|
t.Errorf("TraceID does not match: server.TraceID=%q client.TraceID=%q", server.TraceID, client.TraceID)
|
||
|
}
|
||
|
if !server.HasRemoteParent {
|
||
|
t.Errorf("server span should have remote parent")
|
||
|
}
|
||
|
if server.ParentSpanID != client.SpanID {
|
||
|
t.Errorf("server span should have client span as parent")
|
||
|
}
|
||
|
}
|
||
|
if !tt.wantSameTraceID {
|
||
|
if server.TraceID == client.TraceID {
|
||
|
t.Errorf("TraceID should not be trusted")
|
||
|
}
|
||
|
}
|
||
|
if tt.wantLinks {
|
||
|
if got, want := len(server.Links), 1; got != want {
|
||
|
t.Errorf("len(server.Links) = %d; want %d", got, want)
|
||
|
} else {
|
||
|
link := server.Links[0]
|
||
|
if got, want := link.Type, trace.LinkTypeParent; got != want {
|
||
|
t.Errorf("link.Type = %v; want %v", got, want)
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
if server.StartTime.Before(client.StartTime) {
|
||
|
t.Errorf("server span starts before client span")
|
||
|
}
|
||
|
if server.EndTime.After(client.EndTime) {
|
||
|
t.Errorf("client span ends before server span")
|
||
|
}
|
||
|
})
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func serveHTTP(handler *Handler, done chan struct{}, wait chan time.Time) string {
|
||
|
handler.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||
|
w.WriteHeader(200)
|
||
|
w.(http.Flusher).Flush()
|
||
|
|
||
|
// Simulate a slow-responding server.
|
||
|
sleepUntil := <-wait
|
||
|
for time.Now().Before(sleepUntil) {
|
||
|
time.Sleep(sleepUntil.Sub(time.Now()))
|
||
|
}
|
||
|
|
||
|
io.WriteString(w, "expected-response")
|
||
|
close(done)
|
||
|
})
|
||
|
server := httptest.NewServer(handler)
|
||
|
go func() {
|
||
|
<-done
|
||
|
server.Close()
|
||
|
}()
|
||
|
return server.URL
|
||
|
}
|
||
|
|
||
|
func TestSpanNameFromURL(t *testing.T) {
|
||
|
tests := []struct {
|
||
|
u string
|
||
|
want string
|
||
|
}{
|
||
|
{
|
||
|
u: "http://localhost:80/hello?q=a",
|
||
|
want: "/hello",
|
||
|
},
|
||
|
{
|
||
|
u: "/a/b?q=c",
|
||
|
want: "/a/b",
|
||
|
},
|
||
|
}
|
||
|
for _, tt := range tests {
|
||
|
t.Run(tt.u, func(t *testing.T) {
|
||
|
req, err := http.NewRequest("GET", tt.u, nil)
|
||
|
if err != nil {
|
||
|
t.Errorf("url issue = %v", err)
|
||
|
}
|
||
|
if got := spanNameFromURL(req); got != tt.want {
|
||
|
t.Errorf("spanNameFromURL() = %v, want %v", got, tt.want)
|
||
|
}
|
||
|
})
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func TestFormatSpanName(t *testing.T) {
|
||
|
formatSpanName := func(r *http.Request) string {
|
||
|
return r.Method + " " + r.URL.Path
|
||
|
}
|
||
|
|
||
|
handler := &Handler{
|
||
|
Handler: http.HandlerFunc(func(resp http.ResponseWriter, req *http.Request) {
|
||
|
resp.Write([]byte("Hello, world!"))
|
||
|
}),
|
||
|
FormatSpanName: formatSpanName,
|
||
|
}
|
||
|
|
||
|
server := httptest.NewServer(handler)
|
||
|
defer server.Close()
|
||
|
|
||
|
client := &http.Client{
|
||
|
Transport: &Transport{
|
||
|
FormatSpanName: formatSpanName,
|
||
|
StartOptions: trace.StartOptions{
|
||
|
Sampler: trace.AlwaysSample(),
|
||
|
},
|
||
|
},
|
||
|
}
|
||
|
|
||
|
tests := []struct {
|
||
|
u string
|
||
|
want string
|
||
|
}{
|
||
|
{
|
||
|
u: "/hello?q=a",
|
||
|
want: "GET /hello",
|
||
|
},
|
||
|
{
|
||
|
u: "/a/b?q=c",
|
||
|
want: "GET /a/b",
|
||
|
},
|
||
|
}
|
||
|
|
||
|
for _, tt := range tests {
|
||
|
t.Run(tt.u, func(t *testing.T) {
|
||
|
var te testExporter
|
||
|
trace.RegisterExporter(&te)
|
||
|
res, err := client.Get(server.URL + tt.u)
|
||
|
if err != nil {
|
||
|
t.Fatalf("error creating request: %v", err)
|
||
|
}
|
||
|
res.Body.Close()
|
||
|
trace.UnregisterExporter(&te)
|
||
|
if want, got := 2, len(te.spans); want != got {
|
||
|
t.Fatalf("got exported spans %#v, wanted two spans", te.spans)
|
||
|
}
|
||
|
if got := te.spans[0].Name; got != tt.want {
|
||
|
t.Errorf("spanNameFromURL() = %v, want %v", got, tt.want)
|
||
|
}
|
||
|
if got := te.spans[1].Name; got != tt.want {
|
||
|
t.Errorf("spanNameFromURL() = %v, want %v", got, tt.want)
|
||
|
}
|
||
|
})
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func TestRequestAttributes(t *testing.T) {
|
||
|
tests := []struct {
|
||
|
name string
|
||
|
makeReq func() *http.Request
|
||
|
wantAttrs []trace.Attribute
|
||
|
}{
|
||
|
{
|
||
|
name: "GET example.com/hello",
|
||
|
makeReq: func() *http.Request {
|
||
|
req, _ := http.NewRequest("GET", "http://example.com:779/hello", nil)
|
||
|
req.Header.Add("User-Agent", "ua")
|
||
|
return req
|
||
|
},
|
||
|
wantAttrs: []trace.Attribute{
|
||
|
trace.StringAttribute("http.path", "/hello"),
|
||
|
trace.StringAttribute("http.host", "example.com:779"),
|
||
|
trace.StringAttribute("http.method", "GET"),
|
||
|
trace.StringAttribute("http.user_agent", "ua"),
|
||
|
},
|
||
|
},
|
||
|
}
|
||
|
for _, tt := range tests {
|
||
|
t.Run(tt.name, func(t *testing.T) {
|
||
|
req := tt.makeReq()
|
||
|
attrs := requestAttrs(req)
|
||
|
|
||
|
if got, want := attrs, tt.wantAttrs; !reflect.DeepEqual(got, want) {
|
||
|
t.Errorf("Request attributes = %#v; want %#v", got, want)
|
||
|
}
|
||
|
})
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func TestResponseAttributes(t *testing.T) {
|
||
|
tests := []struct {
|
||
|
name string
|
||
|
resp *http.Response
|
||
|
wantAttrs []trace.Attribute
|
||
|
}{
|
||
|
{
|
||
|
name: "non-zero HTTP 200 response",
|
||
|
resp: &http.Response{StatusCode: 200},
|
||
|
wantAttrs: []trace.Attribute{
|
||
|
trace.Int64Attribute("http.status_code", 200),
|
||
|
},
|
||
|
},
|
||
|
{
|
||
|
name: "zero HTTP 500 response",
|
||
|
resp: &http.Response{StatusCode: 500},
|
||
|
wantAttrs: []trace.Attribute{
|
||
|
trace.Int64Attribute("http.status_code", 500),
|
||
|
},
|
||
|
},
|
||
|
}
|
||
|
for _, tt := range tests {
|
||
|
t.Run(tt.name, func(t *testing.T) {
|
||
|
attrs := responseAttrs(tt.resp)
|
||
|
if got, want := attrs, tt.wantAttrs; !reflect.DeepEqual(got, want) {
|
||
|
t.Errorf("Response attributes = %#v; want %#v", got, want)
|
||
|
}
|
||
|
})
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func TestStatusUnitTest(t *testing.T) {
|
||
|
tests := []struct {
|
||
|
in int
|
||
|
want trace.Status
|
||
|
}{
|
||
|
{200, trace.Status{Code: trace.StatusCodeOK, Message: `OK`}},
|
||
|
{204, trace.Status{Code: trace.StatusCodeOK, Message: `OK`}},
|
||
|
{100, trace.Status{Code: trace.StatusCodeUnknown, Message: `UNKNOWN`}},
|
||
|
{500, trace.Status{Code: trace.StatusCodeUnknown, Message: `UNKNOWN`}},
|
||
|
{404, trace.Status{Code: trace.StatusCodeNotFound, Message: `NOT_FOUND`}},
|
||
|
{600, trace.Status{Code: trace.StatusCodeUnknown, Message: `UNKNOWN`}},
|
||
|
{401, trace.Status{Code: trace.StatusCodeUnauthenticated, Message: `UNAUTHENTICATED`}},
|
||
|
{403, trace.Status{Code: trace.StatusCodePermissionDenied, Message: `PERMISSION_DENIED`}},
|
||
|
{301, trace.Status{Code: trace.StatusCodeOK, Message: `OK`}},
|
||
|
{501, trace.Status{Code: trace.StatusCodeUnimplemented, Message: `UNIMPLEMENTED`}},
|
||
|
}
|
||
|
|
||
|
for _, tt := range tests {
|
||
|
got, want := TraceStatus(tt.in, ""), tt.want
|
||
|
if got != want {
|
||
|
t.Errorf("status(%d) got = (%#v) want = (%#v)", tt.in, got, want)
|
||
|
}
|
||
|
}
|
||
|
}
|