diff --git a/cmd/cmd.go b/cmd/cmd.go index 7255ead..759cb73 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -274,7 +274,7 @@ var globalFlags = []cli.Flag{ Value: "", EnvVar: "CORS_DOMAINS", }, - cli.Int64Flag{ + cli.IntFlag{ Name: "random-token-length", Usage: "", Value: 6, @@ -383,7 +383,7 @@ func New() *Cmd { options = append(options, server.RateLimit(v)) } - v := c.Int64("random-token-length") + v := c.Int("random-token-length") options = append(options, server.RandomTokenLength(v)) purgeDays := c.Int("purge-days") diff --git a/server/codec.go b/server/codec.go index a33f3c0..1ede6a9 100644 --- a/server/codec.go +++ b/server/codec.go @@ -36,28 +36,14 @@ const ( // someone set us up the bomb !! BASE = float64(len(SYMBOLS)) - - // init seed encode number - INIT_SEED = float64(-1) ) -// encodes a number into our *base* representation -// TODO can this be made better with some bitshifting? -func Encode(number float64, length int64) string { - if number == INIT_SEED { - seed := math.Pow(float64(BASE), float64(length)) - number = seed + (rand.Float64() * seed) // start with seed to enforce desired length - } - - rest := int64(math.Mod(number, BASE)) - // strings are a bit weird in go... - result := string(SYMBOLS[rest]) - if rest > 0 && number-float64(rest) != 0 { - newnumber := (number - float64(rest)) / BASE - result = Encode(newnumber, length) + result - } else { - // it would always be 1 because of starting with seed and we want to skip - return "" +// generate a token +func Token(length int) string { + result := "" + for i := 0; i < length; i++ { + x := rand.Intn(len(SYMBOLS)) + result = string(SYMBOLS[x]) + result } return result diff --git a/server/codec_test.go b/server/codec_test.go index aebd8ab..5ee1db2 100644 --- a/server/codec_test.go +++ b/server/codec_test.go @@ -4,12 +4,12 @@ import "testing" func BenchmarkEncodeConcat(b *testing.B) { for i := 0; i < b.N; i++ { - _ = Encode(INIT_SEED, 5) + Encode(INIT_SEED, 5) + _ = Token(5) + Token(5) } } func BenchmarkEncodeLonger(b *testing.B) { for i := 0; i < b.N; i++ { - _ = Encode(INIT_SEED, 10) + _ = Token(10) } } diff --git a/server/handlers.go b/server/handlers.go index ac20a0f..37b6fed 100644 --- a/server/handlers.go +++ b/server/handlers.go @@ -276,7 +276,7 @@ func (s *Server) postHandler(w http.ResponseWriter, r *http.Request) { return } - token := Encode(INIT_SEED, s.randomTokenLength) + token := Token(s.randomTokenLength) w.Header().Set("Content-Type", "text/plain") @@ -398,13 +398,13 @@ type Metadata struct { DeletionToken string } -func MetadataForRequest(contentType string, randomTokenLength int64, r *http.Request) Metadata { +func MetadataForRequest(contentType string, randomTokenLength int, r *http.Request) Metadata { metadata := Metadata{ ContentType: strings.ToLower(contentType), MaxDate: time.Time{}, Downloads: 0, MaxDownloads: -1, - DeletionToken: Encode(INIT_SEED, randomTokenLength) + Encode(INIT_SEED, randomTokenLength), + DeletionToken: Token(randomTokenLength) + Token(randomTokenLength), } if v := r.Header.Get("Max-Downloads"); v == "" { @@ -492,7 +492,7 @@ func (s *Server) putHandler(w http.ResponseWriter, r *http.Request) { contentType := mime.TypeByExtension(filepath.Ext(vars["filename"])) - token := Encode(INIT_SEED, s.randomTokenLength) + token := Token(s.randomTokenLength) metadata := MetadataForRequest(contentType, s.randomTokenLength, r) @@ -636,23 +636,22 @@ func (metadata Metadata) remainingLimitHeaderValues() (remainingDownloads, remai return remainingDownloads, remainingDays } -func (s *Server) Lock(token, filename string) error { +func (s *Server) Lock(token, filename string) { key := path.Join(token, filename) - if _, ok := s.locks[key]; !ok { - s.locks[key] = &sync.Mutex{} - } + lock, _ := s.locks.LoadOrStore(key, &sync.Mutex{}) - s.locks[key].Lock() + lock.(*sync.Mutex).Lock() - return nil + return } -func (s *Server) Unlock(token, filename string) error { +func (s *Server) Unlock(token, filename string) { key := path.Join(token, filename) - s.locks[key].Unlock() - return nil + lock, _ := s.locks.LoadOrStore(key, &sync.Mutex{}) + + lock.(*sync.Mutex).Unlock() } func (s *Server) CheckMetadata(token, filename string, increaseDownload bool) (Metadata, error) { diff --git a/server/server.go b/server/server.go index cf72989..ea0a0de 100644 --- a/server/server.go +++ b/server/server.go @@ -187,7 +187,7 @@ func RateLimit(requests int) OptionFn { } } -func RandomTokenLength(length int64) OptionFn { +func RandomTokenLength(length int) OptionFn { return func(srvr *Server) { srvr.randomTokenLength = length } @@ -288,7 +288,7 @@ type Server struct { profilerEnabled bool - locks map[string]*sync.Mutex + locks sync.Map maxUploadSize int64 rateLimitRequests int @@ -300,7 +300,7 @@ type Server struct { forceHTTPs bool - randomTokenLength int64 + randomTokenLength int ipFilterOptions *IPFilterOptions @@ -329,7 +329,7 @@ type Server struct { func New(options ...OptionFn) (*Server, error) { s := &Server{ - locks: map[string]*sync.Mutex{}, + locks: sync.Map{}, } for _, optionFn := range options {