transfer.sh/vendor/github.com/VojtechVitek/ratelimit/download.go
2017-03-28 17:26:32 +02:00

117 lines
2.2 KiB
Go

package ratelimit
import (
"net/http"
"time"
)
// DownloadSpeed starts the configuration of a download-speed limiter.
//
// keyFn is stored on the returned builder; the middleware produced by LimitBy
// calls it once per request to obtain the key under which that request is
// rate limited.
func DownloadSpeed(keyFn KeyFn) *downloadBuilder {
	b := new(downloadBuilder)
	b.keyFn = keyFn
	return b
}
// downloadBuilder accumulates the settings of a download-speed limiter before
// LimitBy turns them into HTTP middleware.
type downloadBuilder struct {
// keyFn maps a request to its rate-limit key; it is set by DownloadSpeed.
keyFn KeyFn
// rate is handed to TokenBucketStore.InitRate together with window.
rate int
// window is handed to TokenBucketStore.InitRate together with rate.
window time.Duration
}
// Rate records the rate and window that LimitBy later passes to every store's
// InitRate. It returns the same builder so calls can be chained.
//
// Each token taken while writing a response grants 1024 bytes (see
// limitWriter.Write); presumably rate is the number of tokens available per
// window, but that is defined by the TokenBucketStore implementation.
func (b *downloadBuilder) Rate(rate int, window time.Duration) *downloadBuilder {
	b.rate, b.window = rate, window
	return b
}
// TODO: Consider supporting a custom burst size, for example:
// func (b *downloadBuilder) Burst(burst int) *downloadBuilder {}
// LimitBy finishes the builder and returns net/http middleware that throttles
// response bodies.
//
// store and every entry of fallbackStores are initialized with the builder's
// rate and window. For each request the middleware calls keyFn: an empty key
// lets the request through untouched, otherwise the ResponseWriter handed to
// the next handler is wrapped in a limitWriter bound to that key.
func (b *downloadBuilder) LimitBy(store TokenBucketStore, fallbackStores ...TokenBucketStore) func(http.Handler) http.Handler {
	// Primary and fallback stores all get the same rate configuration.
	store.InitRate(b.rate, b.window)
	for i := range fallbackStores {
		fallbackStores[i].InitRate(b.rate, b.window)
	}

	// One limiter is shared by every request served through this middleware.
	limiter := &downloadLimiter{
		downloadBuilder: b,
		store:           store,
		fallbackStores:  fallbackStores,
	}

	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			if key := limiter.keyFn(r); key != "" {
				// Only keyed requests are throttled; the wrapper meters writes.
				w = &limitWriter{
					ResponseWriter:  w,
					downloadLimiter: limiter,
					key:             key,
				}
			}
			next.ServeHTTP(w, r)
		})
	}
}
// downloadLimiter is the state shared by all requests served through one
// middleware created by LimitBy: the builder settings plus the token stores.
type downloadLimiter struct {
// Embedded builder, giving access to keyFn, rate and window.
*downloadBuilder
// NOTE(review): next is not assigned by LimitBy and is not read in the
// visible code; it looks unused — confirm before relying on it.
next http.Handler
// store is the primary token store, consulted first by limitWriter.Write.
store TokenBucketStore
// fallbackStores are tried in order by limitWriter.Write when the primary
// store returns an error.
fallbackStores []TokenBucketStore
}
// limitWriter wraps the http.ResponseWriter of a single request and meters
// its Write calls with tokens taken from the downloadLimiter's stores.
type limitWriter struct {
// The wrapped writer; all methods other than Write are promoted from it.
http.ResponseWriter
// The shared limiter that owns the primary and fallback stores.
*downloadLimiter
// key is the rate-limit key of this request; Write prefixes it with
// "download:" when taking tokens.
key string
// NOTE(review): wroteHeader is neither read nor written in the visible
// code — confirm whether it is still needed.
wroteHeader bool
// canWrite is the number of bytes that may still be written without taking
// another token; Write adds 1024 per granted token and subtracts the bytes
// it has written.
canWrite int64
}
// Write writes buf to the wrapped ResponseWriter, throttled by tokens taken
// from the limiter's stores under the key "download:"+w.key.
//
// Each granted token adds 1024 bytes of credit to w.canWrite. While less than
// 1024 bytes of credit remain, another token is requested from the primary
// store; if that store returns an error the fallback stores are tried in
// order and the first one that answers without error decides. If every store
// fails, the last error is returned together with the number of bytes written
// so far.
//
// Write blocks until all of buf has been written or an error occurs. It
// returns the number of bytes written and the first error encountered, either
// from a store or from the underlying writer.
//
// Unlike a naive loop, Write requests tokens only while data is still
// pending, so an empty buf returns immediately and a finished write does not
// wait for a token it would not use. When no token is granted and no credit
// is left, it sleeps briefly between polls instead of spinning on the store.
func (w *limitWriter) Write(buf []byte) (int, error) {
	const (
		// chunkSize is the number of bytes one token pays for.
		chunkSize = 1024
		// pollBackoff is the pause between polls while the bucket is empty.
		pollBackoff = time.Millisecond
	)

	bucket := "download:" + w.key
	total := 0
	for total < len(buf) {
		if w.canWrite < chunkSize {
			ok, _, _, err := w.downloadLimiter.store.Take(bucket)
			if err != nil {
				// Primary store failed: use the first fallback that answers.
				for _, store := range w.fallbackStores {
					ok, _, _, err = store.Take(bucket)
					if err == nil {
						break
					}
				}
			}
			if err != nil {
				return total, err
			}
			if ok {
				w.canWrite += chunkSize
			}
		}

		if w.canWrite == 0 {
			// No token and no leftover credit: wait before asking again so
			// the loop does not busy-spin against the store.
			time.Sleep(pollBackoff)
			continue
		}

		// Write as much of the remaining data as the current credit allows.
		n := len(buf) - total
		if int64(n) > w.canWrite {
			n = int(w.canWrite)
		}

		written, err := w.ResponseWriter.Write(buf[total : total+n])
		w.canWrite -= int64(written)
		total += written
		if err != nil {
			return total, err
		}
	}
	return total, nil
}