From e069bd25363dbb949332e375ac5202d86a64c716 Mon Sep 17 00:00:00 2001 From: Alexey Kostin Date: Mon, 3 Jun 2024 16:29:57 +0300 Subject: [PATCH] Add fallback to math/rand instead of crypto/rand. Bypass logger to token function. Zero logger for test functions --- server/handlers.go | 17 +++++++++-------- server/token.go | 19 ++++++++++++++----- server/token_test.go | 12 +++++++++--- 3 files changed, 32 insertions(+), 16 deletions(-) diff --git a/server/handlers.go b/server/handlers.go index 265aab4..79a7517 100644 --- a/server/handlers.go +++ b/server/handlers.go @@ -39,6 +39,7 @@ import ( "html" htmlTemplate "html/template" "io" + "log" "mime" "net" "net/http" @@ -395,8 +396,8 @@ func (s *Server) viewHandler(w http.ResponseWriter, r *http.Request) { s.userVoiceKey, purgeTime, maxUploadSize, - token(s.randomTokenLength), - token(s.randomTokenLength), + token(s.randomTokenLength, s.logger), + token(s.randomTokenLength, s.logger), } w.Header().Set("Vary", "Accept") @@ -428,7 +429,7 @@ func (s *Server) postHandler(w http.ResponseWriter, r *http.Request) { return } - token := token(s.randomTokenLength) + token := token(s.randomTokenLength, s.logger) w.Header().Set("Content-Type", "text/plain") @@ -493,7 +494,7 @@ func (s *Server) postHandler(w http.ResponseWriter, r *http.Request) { } } - metadata := metadataForRequest(contentType, contentLength, s.randomTokenLength, r) + metadata := metadataForRequest(contentType, contentLength, s.randomTokenLength, s.logger, r) buffer := &bytes.Buffer{} if err := json.NewEncoder(buffer).Encode(metadata); err != nil { @@ -570,14 +571,14 @@ type metadata struct { DecryptedContentType string } -func metadataForRequest(contentType string, contentLength int64, randomTokenLength int, r *http.Request) metadata { +func metadataForRequest(contentType string, contentLength int64, randomTokenLength int, logger *log.Logger, r *http.Request) metadata { metadata := metadata{ ContentType: strings.ToLower(contentType), ContentLength: contentLength, MaxDate: time.Time{}, Downloads: 0, MaxDownloads: -1, - DeletionToken: token(randomTokenLength) + token(randomTokenLength), + DeletionToken: token(randomTokenLength, logger) + token(randomTokenLength, logger), } if v := r.Header.Get("Max-Downloads"); v == "" { @@ -675,9 +676,9 @@ func (s *Server) putHandler(w http.ResponseWriter, r *http.Request) { contentType := mime.TypeByExtension(filepath.Ext(vars["filename"])) - token := token(s.randomTokenLength) + token := token(s.randomTokenLength, s.logger) - metadata := metadataForRequest(contentType, contentLength, s.randomTokenLength, r) + metadata := metadataForRequest(contentType, contentLength, s.randomTokenLength, s.logger, r) buffer := &bytes.Buffer{} if err := json.NewEncoder(buffer).Encode(metadata); err != nil { diff --git a/server/token.go b/server/token.go index 9bbdf81..c079caf 100644 --- a/server/token.go +++ b/server/token.go @@ -28,6 +28,7 @@ import ( "crypto/rand" "log" "math/big" + mathrand "math/rand" ) const ( @@ -36,14 +37,22 @@ const ( ) // generate a token -func token(length int) string { +func token(length int, logger *log.Logger) string { result := make([]byte, length) + var err error for i := 0; i < length; i++ { - x, err := rand.Int(rand.Reader, big.NewInt(int64(len(SYMBOLS)))) - if err != nil { - log.Fatal("Failed to generate token") + if err == nil { + var x *big.Int + x, err = rand.Int(rand.Reader, big.NewInt(int64(len(SYMBOLS)))) + if err != nil { + logger.Printf("Fallback to math/rand instead of crypto/rand due error %s", err.Error()) + x = big.NewInt(int64(mathrand.Intn(len(SYMBOLS) - 1))) + } + result[i] = SYMBOLS[x.Int64()] + } else { // fallback to math rand + x := int64(mathrand.Intn(len(SYMBOLS) - 1)) + result[i] = SYMBOLS[x] } - result[i] = SYMBOLS[x.Int64()] } return string(result) diff --git a/server/token_test.go b/server/token_test.go index cec3d79..fa55704 100644 --- a/server/token_test.go +++ b/server/token_test.go @@ -1,15 +1,21 @@ package server -import "testing" +import ( + "io" + "log" + "testing" +) + +var logger = log.New(io.Discard, "", log.LstdFlags) func BenchmarkTokenConcat(b *testing.B) { for i := 0; i < b.N; i++ { - _ = token(5) + token(5) + _ = token(5, logger) + token(5, logger) } } func BenchmarkTokenLonger(b *testing.B) { for i := 0; i < b.N; i++ { - _ = token(10) + _ = token(10, logger) } }