mirror of
https://github.com/dutchcoders/transfer.sh.git
synced 2025-01-14 20:50:19 +01:00
280 lines
6.9 KiB
Go
280 lines
6.9 KiB
Go
// Copyright 2016 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
|
|
|
|
package gensupport
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"io"
|
|
"io/ioutil"
|
|
"net/http"
|
|
"reflect"
|
|
"strings"
|
|
"testing"
|
|
)
|
|
|
|
// unexpectedReader is an io.Reader whose Read always fails. It is used as a
// body that the code under test is not supposed to read from.
type unexpectedReader struct{}

// Read never fills the supplied buffer: it always reports zero bytes read
// together with a non-nil error.
func (unexpectedReader) Read([]byte) (int, error) {
	return 0, fmt.Errorf("unexpected read in test")
}
|
|
|
|
// event is an expected request/response pair.
type event struct {
	// byteRange is the Content-Range header value that the request must carry.
	byteRange string

	// responseStatus is the HTTP status code to send in response.
	responseStatus int
}
|
|
|
|
// interruptibleTransport is configured with a canned set of requests/responses.
|
|
// It records the incoming data, unless the corresponding event is configured to return
|
|
// http.StatusServiceUnavailable.
|
|
type interruptibleTransport struct {
|
|
events []event
|
|
buf []byte
|
|
bodies bodyTracker
|
|
}
|
|
|
|
// bodyTracker keeps track of response bodies that have not been closed.
// It is a set keyed by the body value itself, so a nil bodyTracker cannot
// accept additions.
type bodyTracker map[io.ReadCloser]struct{}

// Add records body as open.
func (bt bodyTracker) Add(body io.ReadCloser) {
	bt[body] = struct{}{}
}

// Close records body as closed by removing it from the set. Removing a body
// that is not tracked is a no-op.
func (bt bodyTracker) Close(body io.ReadCloser) {
	delete(bt, body)
}
|
|
|
|
type trackingCloser struct {
|
|
io.Reader
|
|
tracker bodyTracker
|
|
}
|
|
|
|
func (tc *trackingCloser) Close() error {
|
|
tc.tracker.Close(tc)
|
|
return nil
|
|
}
|
|
|
|
func (tc *trackingCloser) Open() {
|
|
tc.tracker.Add(tc)
|
|
}
|
|
|
|
func (t *interruptibleTransport) RoundTrip(req *http.Request) (*http.Response, error) {
|
|
ev := t.events[0]
|
|
t.events = t.events[1:]
|
|
if got, want := req.Header.Get("Content-Range"), ev.byteRange; got != want {
|
|
return nil, fmt.Errorf("byte range: got %s; want %s", got, want)
|
|
}
|
|
|
|
if ev.responseStatus != http.StatusServiceUnavailable {
|
|
buf, err := ioutil.ReadAll(req.Body)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("error reading from request data: %v", err)
|
|
}
|
|
t.buf = append(t.buf, buf...)
|
|
}
|
|
|
|
tc := &trackingCloser{unexpectedReader{}, t.bodies}
|
|
tc.Open()
|
|
h := http.Header{}
|
|
status := ev.responseStatus
|
|
|
|
// Support "X-GUploader-No-308" like Google:
|
|
if status == 308 && req.Header.Get("X-GUploader-No-308") == "yes" {
|
|
status = 200
|
|
h.Set("X-Http-Status-Code-Override", "308")
|
|
}
|
|
|
|
res := &http.Response{
|
|
StatusCode: status,
|
|
Header: h,
|
|
Body: tc,
|
|
}
|
|
return res, nil
|
|
}
|
|
|
|
// progressRecorder records updates, and calls f for every invocation of ProgressUpdate.
type progressRecorder struct {
	// updates holds every value passed to ProgressUpdate, in call order.
	updates []int64

	// f, if non-nil, is invoked after each update has been recorded.
	f func()
}

// ProgressUpdate appends current to pr.updates and then calls pr.f when it
// is set.
func (pr *progressRecorder) ProgressUpdate(current int64) {
	pr.updates = append(pr.updates, current)
	if pr.f != nil {
		pr.f()
	}
}
|
|
|
|
func TestInterruptedTransferChunks(t *testing.T) {
|
|
type testCase struct {
|
|
data string
|
|
chunkSize int
|
|
events []event
|
|
wantProgress []int64
|
|
}
|
|
|
|
for _, tc := range []testCase{
|
|
{
|
|
data: strings.Repeat("a", 300),
|
|
chunkSize: 90,
|
|
events: []event{
|
|
{"bytes 0-89/*", http.StatusServiceUnavailable},
|
|
{"bytes 0-89/*", 308},
|
|
{"bytes 90-179/*", 308},
|
|
{"bytes 180-269/*", http.StatusServiceUnavailable},
|
|
{"bytes 180-269/*", 308},
|
|
{"bytes 270-299/300", 200},
|
|
},
|
|
|
|
wantProgress: []int64{90, 180, 270, 300},
|
|
},
|
|
{
|
|
data: strings.Repeat("a", 20),
|
|
chunkSize: 10,
|
|
events: []event{
|
|
{"bytes 0-9/*", http.StatusServiceUnavailable},
|
|
{"bytes 0-9/*", 308},
|
|
{"bytes 10-19/*", http.StatusServiceUnavailable},
|
|
{"bytes 10-19/*", 308},
|
|
// 0 byte final request demands a byte range with leading asterix.
|
|
{"bytes */20", http.StatusServiceUnavailable},
|
|
{"bytes */20", 200},
|
|
},
|
|
|
|
wantProgress: []int64{10, 20},
|
|
},
|
|
} {
|
|
media := strings.NewReader(tc.data)
|
|
|
|
tr := &interruptibleTransport{
|
|
buf: make([]byte, 0, len(tc.data)),
|
|
events: tc.events,
|
|
bodies: bodyTracker{},
|
|
}
|
|
|
|
pr := progressRecorder{}
|
|
rx := &ResumableUpload{
|
|
Client: &http.Client{Transport: tr},
|
|
Media: NewMediaBuffer(media, tc.chunkSize),
|
|
MediaType: "text/plain",
|
|
Callback: pr.ProgressUpdate,
|
|
Backoff: NoPauseStrategy,
|
|
}
|
|
res, err := rx.Upload(context.Background())
|
|
if err == nil {
|
|
res.Body.Close()
|
|
}
|
|
if err != nil || res == nil || res.StatusCode != http.StatusOK {
|
|
if res == nil {
|
|
t.Errorf("Upload not successful, res=nil: %v", err)
|
|
} else {
|
|
t.Errorf("Upload not successful, statusCode=%v: %v", res.StatusCode, err)
|
|
}
|
|
}
|
|
if !reflect.DeepEqual(tr.buf, []byte(tc.data)) {
|
|
t.Errorf("transferred contents:\ngot %s\nwant %s", tr.buf, tc.data)
|
|
}
|
|
|
|
if !reflect.DeepEqual(pr.updates, tc.wantProgress) {
|
|
t.Errorf("progress updates: got %v, want %v", pr.updates, tc.wantProgress)
|
|
}
|
|
|
|
if len(tr.events) > 0 {
|
|
t.Errorf("did not observe all expected events. leftover events: %v", tr.events)
|
|
}
|
|
if len(tr.bodies) > 0 {
|
|
t.Errorf("unclosed request bodies: %v", tr.bodies)
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestCancelUploadFast(t *testing.T) {
|
|
const (
|
|
chunkSize = 90
|
|
mediaSize = 300
|
|
)
|
|
media := strings.NewReader(strings.Repeat("a", mediaSize))
|
|
|
|
tr := &interruptibleTransport{
|
|
buf: make([]byte, 0, mediaSize),
|
|
}
|
|
|
|
pr := progressRecorder{}
|
|
rx := &ResumableUpload{
|
|
Client: &http.Client{Transport: tr},
|
|
Media: NewMediaBuffer(media, chunkSize),
|
|
MediaType: "text/plain",
|
|
Callback: pr.ProgressUpdate,
|
|
Backoff: NoPauseStrategy,
|
|
}
|
|
ctx, cancelFunc := context.WithCancel(context.Background())
|
|
cancelFunc() // stop the upload that hasn't started yet
|
|
res, err := rx.Upload(ctx)
|
|
if err != context.Canceled {
|
|
t.Errorf("Upload err: got: %v; want: context cancelled", err)
|
|
}
|
|
if res != nil {
|
|
t.Errorf("Upload result: got: %v; want: nil", res)
|
|
}
|
|
if pr.updates != nil {
|
|
t.Errorf("progress updates: got %v; want: nil", pr.updates)
|
|
}
|
|
}
|
|
|
|
func TestCancelUpload(t *testing.T) {
|
|
const (
|
|
chunkSize = 90
|
|
mediaSize = 300
|
|
)
|
|
media := strings.NewReader(strings.Repeat("a", mediaSize))
|
|
|
|
tr := &interruptibleTransport{
|
|
buf: make([]byte, 0, mediaSize),
|
|
events: []event{
|
|
{"bytes 0-89/*", http.StatusServiceUnavailable},
|
|
{"bytes 0-89/*", 308},
|
|
{"bytes 90-179/*", 308},
|
|
{"bytes 180-269/*", 308}, // Upload should be cancelled before this event.
|
|
},
|
|
bodies: bodyTracker{},
|
|
}
|
|
|
|
ctx, cancelFunc := context.WithCancel(context.Background())
|
|
numUpdates := 0
|
|
|
|
pr := progressRecorder{f: func() {
|
|
numUpdates++
|
|
if numUpdates >= 2 {
|
|
cancelFunc()
|
|
}
|
|
}}
|
|
|
|
rx := &ResumableUpload{
|
|
Client: &http.Client{Transport: tr},
|
|
Media: NewMediaBuffer(media, chunkSize),
|
|
MediaType: "text/plain",
|
|
Callback: pr.ProgressUpdate,
|
|
Backoff: NoPauseStrategy,
|
|
}
|
|
res, err := rx.Upload(ctx)
|
|
if err != context.Canceled {
|
|
t.Errorf("Upload err: got: %v; want: context cancelled", err)
|
|
}
|
|
if res != nil {
|
|
t.Errorf("Upload result: got: %v; want: nil", res)
|
|
}
|
|
if got, want := tr.buf, []byte(strings.Repeat("a", chunkSize*2)); !reflect.DeepEqual(got, want) {
|
|
t.Errorf("transferred contents:\ngot %s\nwant %s", got, want)
|
|
}
|
|
if got, want := pr.updates, []int64{chunkSize, chunkSize * 2}; !reflect.DeepEqual(got, want) {
|
|
t.Errorf("progress updates: got %v; want: %v", got, want)
|
|
}
|
|
if len(tr.bodies) > 0 {
|
|
t.Errorf("unclosed request bodies: %v", tr.bodies)
|
|
}
|
|
}
|