mirror of
https://github.com/dutchcoders/transfer.sh.git
synced 2024-12-11 21:30:19 +01:00
118 lines
2.2 KiB
Go
118 lines
2.2 KiB
Go
|
package ratelimit
|
||
|
|
||
|
import (
|
||
|
"net/http"
|
||
|
"time"
|
||
|
)
|
||
|
|
||
|
func DownloadSpeed(keyFn KeyFn) *downloadBuilder {
|
||
|
return &downloadBuilder{
|
||
|
keyFn: keyFn,
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// downloadBuilder accumulates the download-speed limiter configuration.
// DownloadSpeed creates it, Rate sets the allowance, and LimitBy turns it
// into HTTP middleware.
type downloadBuilder struct {
	// Derives the bucket key for a request; "" means the request is not limited.
	keyFn KeyFn
	// rate tokens per window; passed to every store's InitRate by LimitBy.
	rate   int
	window time.Duration
}
|
||
|
|
||
|
func (b *downloadBuilder) Rate(rate int, window time.Duration) *downloadBuilder {
|
||
|
b.rate = rate
|
||
|
b.window = window
|
||
|
return b
|
||
|
}
|
||
|
|
||
|
// TODO: Custom burst?
|
||
|
// func (b *downloadBuilder) Burst(burst int) *downloadBuilder {}
|
||
|
|
||
|
func (b *downloadBuilder) LimitBy(store TokenBucketStore, fallbackStores ...TokenBucketStore) func(http.Handler) http.Handler {
|
||
|
store.InitRate(b.rate, b.window)
|
||
|
for _, store := range fallbackStores {
|
||
|
store.InitRate(b.rate, b.window)
|
||
|
}
|
||
|
|
||
|
downloadLimiter := downloadLimiter{
|
||
|
downloadBuilder: b,
|
||
|
store: store,
|
||
|
fallbackStores: fallbackStores,
|
||
|
}
|
||
|
|
||
|
return func(next http.Handler) http.Handler {
|
||
|
fn := func(w http.ResponseWriter, r *http.Request) {
|
||
|
key := downloadLimiter.keyFn(r)
|
||
|
if key == "" {
|
||
|
next.ServeHTTP(w, r)
|
||
|
return
|
||
|
}
|
||
|
|
||
|
lw := &limitWriter{
|
||
|
ResponseWriter: w,
|
||
|
downloadLimiter: &downloadLimiter,
|
||
|
key: key,
|
||
|
}
|
||
|
|
||
|
next.ServeHTTP(lw, r)
|
||
|
}
|
||
|
return http.HandlerFunc(fn)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// downloadLimiter is the state created once per LimitBy call and shared by
// every request passing through that middleware.
type downloadLimiter struct {
	// Configuration (keyFn, rate, window).
	*downloadBuilder

	// NOTE(review): LimitBy never sets this and nothing in this file reads it;
	// possibly vestigial — confirm before removing.
	next http.Handler
	// Primary token bucket store, tried first for every token.
	store TokenBucketStore
	// Tried in order only when the primary store's Take returns an error.
	fallbackStores []TokenBucketStore
}
|
||
|
|
||
|
// limitWriter wraps the ResponseWriter of one throttled request and meters
// body writes against the token bucket store.
type limitWriter struct {
	http.ResponseWriter
	// Limiter shared with all other requests of the same middleware.
	*downloadLimiter

	// Bucket key for this request, as returned by keyFn.
	key string
	// NOTE(review): not read or written anywhere in this file — confirm use.
	wroteHeader bool
	// Remaining byte allowance; Write adds 1024 per token taken.
	canWrite int64
}
|
||
|
|
||
|
func (w *limitWriter) Write(buf []byte) (int, error) {
|
||
|
total := 0
|
||
|
for {
|
||
|
if w.canWrite < 1024 {
|
||
|
ok, _, _, err := w.downloadLimiter.store.Take("download:" + w.key)
|
||
|
if err != nil {
|
||
|
for _, store := range w.fallbackStores {
|
||
|
ok, _, _, err = store.Take("download:" + w.key)
|
||
|
if err == nil {
|
||
|
break
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
if err != nil {
|
||
|
return total, err
|
||
|
}
|
||
|
if ok {
|
||
|
w.canWrite += 1024
|
||
|
}
|
||
|
}
|
||
|
if w.canWrite == 0 {
|
||
|
continue
|
||
|
}
|
||
|
|
||
|
max := len(buf) - total
|
||
|
if int(w.canWrite) < max {
|
||
|
max = int(w.canWrite)
|
||
|
}
|
||
|
if max == 0 {
|
||
|
return total, nil
|
||
|
}
|
||
|
|
||
|
n, err := w.ResponseWriter.Write(buf[total : total+max])
|
||
|
w.canWrite -= int64(n)
|
||
|
total += n
|
||
|
if err != nil {
|
||
|
return total, err
|
||
|
}
|
||
|
}
|
||
|
}
|