mirror of
https://github.com/dutchcoders/transfer.sh.git
synced 2025-01-15 21:20:19 +01:00
188 lines
4.3 KiB
Go
188 lines
4.3 KiB
Go
// Copyright 2015 Google Inc. All rights reserved.
|
|
//
|
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
// you may not use this file except in compliance with the License.
|
|
// You may obtain a copy of the License at
|
|
//
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
//
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
// See the License for the specific language governing permissions and
|
|
// limitations under the License.
|
|
|
|
package trafficshape
|
|
|
|
import (
|
|
"errors"
|
|
"runtime"
|
|
"sync/atomic"
|
|
"testing"
|
|
"time"
|
|
)
|
|
|
|
func TestBucket(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
b := NewBucket(10, 10*time.Millisecond)
|
|
defer b.Close()
|
|
|
|
if got, want := b.Capacity(), int64(10); got != want {
|
|
t.Fatalf("b.Capacity(): got %d, want %d", got, want)
|
|
}
|
|
|
|
n, err := b.Fill(func(remaining int64) (int64, error) {
|
|
if want := int64(10); remaining != want {
|
|
t.Errorf("remaining: got %d, want %d", remaining, want)
|
|
}
|
|
return 5, nil
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("Fill(): got %v, want no error", err)
|
|
}
|
|
if got, want := n, int64(5); got != want {
|
|
t.Fatalf("n: got %d, want %d", got, want)
|
|
}
|
|
|
|
n, err = b.Fill(func(remaining int64) (int64, error) {
|
|
if want := int64(5); remaining != want {
|
|
t.Errorf("remaining: got %d, want %d", remaining, want)
|
|
}
|
|
return 5, nil
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("Fill(): got %v, want no error", err)
|
|
}
|
|
if got, want := n, int64(5); got != want {
|
|
t.Fatalf("n: got %d, want %d", got, want)
|
|
}
|
|
n, err = b.Fill(func(remaining int64) (int64, error) {
|
|
t.Fatal("Fill: executed func when full, want skipped")
|
|
return 0, nil
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("Fill(): got %v, want no error", err)
|
|
}
|
|
|
|
// Wait for the bucket to drain.
|
|
for {
|
|
if atomic.LoadInt64(&b.fill) == 0 {
|
|
break
|
|
}
|
|
// Allow for a goroutine switch, required for GOMAXPROCS = 1.
|
|
runtime.Gosched()
|
|
}
|
|
|
|
wanterr := errors.New("fill function error")
|
|
n, err = b.Fill(func(remaining int64) (int64, error) {
|
|
if want := int64(10); remaining != want {
|
|
t.Errorf("remaining: got %d, want %d", remaining, want)
|
|
}
|
|
return 0, wanterr
|
|
})
|
|
if err != wanterr {
|
|
t.Fatalf("Fill(): got %v, want %v", err, wanterr)
|
|
}
|
|
if got, want := n, int64(0); got != want {
|
|
t.Fatalf("n: got %d, want %d", got, want)
|
|
}
|
|
}
|
|
|
|
func TestBucketClosed(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
b := NewBucket(0, time.Millisecond)
|
|
b.Close()
|
|
|
|
if _, err := b.Fill(nil); err != errFillClosedBucket {
|
|
t.Errorf("Fill(): got %v, want errFillClosedBucket", err)
|
|
}
|
|
if _, err := b.FillThrottle(nil); err != errFillClosedBucket {
|
|
t.Errorf("FillThrottle(): got %v, want errFillClosedBucket", err)
|
|
}
|
|
}
|
|
|
|
func TestBucketOverflow(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
b := NewBucket(10, 10*time.Millisecond)
|
|
defer b.Close()
|
|
|
|
n, err := b.Fill(func(remaining int64) (int64, error) {
|
|
return 11, nil
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("Fill(): got %v, want no error", err)
|
|
}
|
|
|
|
n, err = b.Fill(func(int64) (int64, error) {
|
|
t.Fatal("Fill: executed func when full, want skipped")
|
|
return 0, nil
|
|
})
|
|
if err != ErrBucketOverflow {
|
|
t.Fatalf("Fill(): got %v, want ErrBucketOverflow", err)
|
|
}
|
|
if got, want := n, int64(0); got != want {
|
|
t.Fatalf("n: got %d, want %d", got, want)
|
|
}
|
|
}
|
|
|
|
func TestBucketThrottle(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
b := NewBucket(50, 50*time.Millisecond)
|
|
defer b.Close()
|
|
|
|
closec := make(chan struct{})
|
|
errc := make(chan error, 1)
|
|
|
|
fill := func() {
|
|
for {
|
|
select {
|
|
case <-closec:
|
|
return
|
|
default:
|
|
if _, err := b.FillThrottle(func(remaining int64) (int64, error) {
|
|
if remaining < 10 {
|
|
return remaining, nil
|
|
}
|
|
return 10, nil
|
|
}); err != nil {
|
|
select {
|
|
case errc <- err:
|
|
default:
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
for i := 0; i < 5; i++ {
|
|
go fill()
|
|
}
|
|
|
|
time.Sleep(time.Second)
|
|
|
|
close(closec)
|
|
|
|
select {
|
|
case err := <-errc:
|
|
t.Fatalf("FillThrottle: got %v, want no error", err)
|
|
default:
|
|
}
|
|
}
|
|
|
|
func TestBucketFillThrottleCloseBeforeTick(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
b := NewBucket(0, time.Minute)
|
|
time.AfterFunc(time.Second, func() { b.Close() })
|
|
|
|
if _, err := b.FillThrottle(func(int64) (int64, error) {
|
|
t.Fatal("FillThrottle(): executed func after close, want skipped")
|
|
return 0, nil
|
|
}); err != errFillClosedBucket {
|
|
t.Errorf("b.FillThrottle(): got nil, want errFillClosedBucket")
|
|
}
|
|
}
|