This commit is contained in:
Andrea Spacca 2021-08-19 22:21:31 +02:00
parent 788dfa203f
commit a88c5ebf7a
10 changed files with 247 additions and 160 deletions

View file

@ -12,6 +12,7 @@ import (
"google.golang.org/api/googleapi" "google.golang.org/api/googleapi"
) )
// Version is inject at build time
var Version = "0.0.0" var Version = "0.0.0"
var helpTemplate = `NAME: var helpTemplate = `NAME:
{{.Name}} - {{.Usage}} {{.Name}} - {{.Usage}}
@ -282,14 +283,16 @@ var globalFlags = []cli.Flag{
}, },
} }
// Cmd wraps cli.app
type Cmd struct { type Cmd struct {
*cli.App *cli.App
} }
func VersionAction(c *cli.Context) { func versionAction(c *cli.Context) {
fmt.Println(color.YellowString(fmt.Sprintf("transfer.sh %s: Easy file sharing from the command line", Version))) fmt.Println(color.YellowString(fmt.Sprintf("transfer.sh %s: Easy file sharing from the command line", Version)))
} }
// New is the factory for transfer.sh
func New() *Cmd { func New() *Cmd {
logger := log.New(os.Stdout, "[transfer.sh]", log.LstdFlags) logger := log.New(os.Stdout, "[transfer.sh]", log.LstdFlags)
@ -304,7 +307,7 @@ func New() *Cmd {
app.Commands = []cli.Command{ app.Commands = []cli.Command{
{ {
Name: "version", Name: "version",
Action: VersionAction, Action: versionAction,
}, },
} }
@ -403,13 +406,13 @@ func New() *Cmd {
} }
if c.Bool("force-https") { if c.Bool("force-https") {
options = append(options, server.ForceHTTPs()) options = append(options, server.ForceHTTPS())
} }
if httpAuthUser := c.String("http-auth-user"); httpAuthUser == "" { if httpAuthUser := c.String("http-auth-user"); httpAuthUser == "" {
} else if httpAuthPass := c.String("http-auth-pass"); httpAuthPass == "" { } else if httpAuthPass := c.String("http-auth-pass"); httpAuthPass == "" {
} else { } else {
options = append(options, server.HttpAuthCredentials(httpAuthUser, httpAuthPass)) options = append(options, server.HTTPAuthCredentials(httpAuthUser, httpAuthPass))
} }
applyIPFilter := false applyIPFilter := false
@ -445,13 +448,13 @@ func New() *Cmd {
case "gdrive": case "gdrive":
chunkSize := c.Int("gdrive-chunk-size") chunkSize := c.Int("gdrive-chunk-size")
if clientJsonFilepath := c.String("gdrive-client-json-filepath"); clientJsonFilepath == "" { if clientJSONFilepath := c.String("gdrive-client-json-filepath"); clientJSONFilepath == "" {
panic("client-json-filepath not set.") panic("client-json-filepath not set.")
} else if localConfigPath := c.String("gdrive-local-config-path"); localConfigPath == "" { } else if localConfigPath := c.String("gdrive-local-config-path"); localConfigPath == "" {
panic("local-config-path not set.") panic("local-config-path not set.")
} else if basedir := c.String("basedir"); basedir == "" { } else if basedir := c.String("basedir"); basedir == "" {
panic("basedir not set.") panic("basedir not set.")
} else if storage, err := server.NewGDriveStorage(clientJsonFilepath, localConfigPath, basedir, chunkSize, logger); err != nil { } else if storage, err := server.NewGDriveStorage(clientJSONFilepath, localConfigPath, basedir, chunkSize, logger); err != nil {
panic(err) panic(err)
} else { } else {
options = append(options, server.UseStorage(storage)) options = append(options, server.UseStorage(storage))

View file

@ -123,7 +123,7 @@ func (s *Server) previewHandler(w http.ResponseWriter, r *http.Request) {
token := vars["token"] token := vars["token"]
filename := vars["filename"] filename := vars["filename"]
metadata, err := s.CheckMetadata(token, filename, false) metadata, err := s.checkMetadata(token, filename, false)
if err != nil { if err != nil {
s.logger.Printf("Error metadata: %s", err.Error()) s.logger.Printf("Error metadata: %s", err.Error())
@ -198,9 +198,9 @@ func (s *Server) previewHandler(w http.ResponseWriter, r *http.Request) {
ContentType string ContentType string
Content html_template.HTML Content html_template.HTML
Filename string Filename string
Url string URL string
UrlGet string URLGet string
UrlRandomToken string URLRandomToken string
Hostname string Hostname string
WebAddress string WebAddress string
ContentLength uint64 ContentLength uint64
@ -264,8 +264,8 @@ func (s *Server) viewHandler(w http.ResponseWriter, r *http.Request) {
s.userVoiceKey, s.userVoiceKey,
purgeTime, purgeTime,
maxUploadSize, maxUploadSize,
Token(s.randomTokenLength), token(s.randomTokenLength),
Token(s.randomTokenLength), token(s.randomTokenLength),
} }
if acceptsHTML(r.Header) { if acceptsHTML(r.Header) {
@ -296,7 +296,7 @@ func (s *Server) postHandler(w http.ResponseWriter, r *http.Request) {
return return
} }
token := Token(s.randomTokenLength) token := token(s.randomTokenLength)
w.Header().Set("Content-Type", "text/plain") w.Header().Set("Content-Type", "text/plain")
@ -354,7 +354,7 @@ func (s *Server) postHandler(w http.ResponseWriter, r *http.Request) {
return return
} }
metadata := MetadataForRequest(contentType, s.randomTokenLength, r) metadata := metadataForRequest(contentType, s.randomTokenLength, r)
buffer := &bytes.Buffer{} buffer := &bytes.Buffer{}
if err := json.NewEncoder(buffer).Encode(metadata); err != nil { if err := json.NewEncoder(buffer).Encode(metadata); err != nil {
@ -403,7 +403,7 @@ func (s *Server) cleanTmpFile(f *os.File) {
} }
} }
type Metadata struct { type metadata struct {
// ContentType is the original uploading content type // ContentType is the original uploading content type
ContentType string ContentType string
// Secret as knowledge to delete file // Secret as knowledge to delete file
@ -418,13 +418,13 @@ type Metadata struct {
DeletionToken string DeletionToken string
} }
func MetadataForRequest(contentType string, randomTokenLength int, r *http.Request) Metadata { func metadataForRequest(contentType string, randomTokenLength int, r *http.Request) metadata {
metadata := Metadata{ metadata := metadata{
ContentType: strings.ToLower(contentType), ContentType: strings.ToLower(contentType),
MaxDate: time.Time{}, MaxDate: time.Time{},
Downloads: 0, Downloads: 0,
MaxDownloads: -1, MaxDownloads: -1,
DeletionToken: Token(randomTokenLength) + Token(randomTokenLength), DeletionToken: token(randomTokenLength) + token(randomTokenLength),
} }
if v := r.Header.Get("Max-Downloads"); v == "" { if v := r.Header.Get("Max-Downloads"); v == "" {
@ -512,9 +512,9 @@ func (s *Server) putHandler(w http.ResponseWriter, r *http.Request) {
contentType := mime.TypeByExtension(filepath.Ext(vars["filename"])) contentType := mime.TypeByExtension(filepath.Ext(vars["filename"]))
token := Token(s.randomTokenLength) token := token(s.randomTokenLength)
metadata := MetadataForRequest(contentType, s.randomTokenLength, r) metadata := metadataForRequest(contentType, s.randomTokenLength, r)
buffer := &bytes.Buffer{} buffer := &bytes.Buffer{}
if err := json.NewEncoder(buffer).Encode(metadata); err != nil { if err := json.NewEncoder(buffer).Encode(metadata); err != nil {
@ -639,7 +639,7 @@ func getURL(r *http.Request, proxyPort string) *url.URL {
return u return u
} }
func (metadata Metadata) remainingLimitHeaderValues() (remainingDownloads, remainingDays string) { func (metadata metadata) remainingLimitHeaderValues() (remainingDownloads, remainingDays string) {
if metadata.MaxDate.IsZero() { if metadata.MaxDate.IsZero() {
remainingDays = "n/a" remainingDays = "n/a"
} else { } else {
@ -656,7 +656,7 @@ func (metadata Metadata) remainingLimitHeaderValues() (remainingDownloads, remai
return remainingDownloads, remainingDays return remainingDownloads, remainingDays
} }
func (s *Server) Lock(token, filename string) { func (s *Server) lock(token, filename string) {
key := path.Join(token, filename) key := path.Join(token, filename)
lock, _ := s.locks.LoadOrStore(key, &sync.Mutex{}) lock, _ := s.locks.LoadOrStore(key, &sync.Mutex{})
@ -666,7 +666,7 @@ func (s *Server) Lock(token, filename string) {
return return
} }
func (s *Server) Unlock(token, filename string) { func (s *Server) unlock(token, filename string) {
key := path.Join(token, filename) key := path.Join(token, filename)
lock, _ := s.locks.LoadOrStore(key, &sync.Mutex{}) lock, _ := s.locks.LoadOrStore(key, &sync.Mutex{})
@ -674,11 +674,11 @@ func (s *Server) Unlock(token, filename string) {
lock.(*sync.Mutex).Unlock() lock.(*sync.Mutex).Unlock()
} }
func (s *Server) CheckMetadata(token, filename string, increaseDownload bool) (Metadata, error) { func (s *Server) checkMetadata(token, filename string, increaseDownload bool) (metadata, error) {
s.Lock(token, filename) s.lock(token, filename)
defer s.Unlock(token, filename) defer s.unlock(token, filename)
var metadata Metadata var metadata metadata
r, _, err := s.storage.Get(token, fmt.Sprintf("%s.metadata", filename)) r, _, err := s.storage.Get(token, fmt.Sprintf("%s.metadata", filename))
if err != nil { if err != nil {
@ -690,9 +690,9 @@ func (s *Server) CheckMetadata(token, filename string, increaseDownload bool) (M
if err := json.NewDecoder(r).Decode(&metadata); err != nil { if err := json.NewDecoder(r).Decode(&metadata); err != nil {
return metadata, err return metadata, err
} else if metadata.MaxDownloads != -1 && metadata.Downloads >= metadata.MaxDownloads { } else if metadata.MaxDownloads != -1 && metadata.Downloads >= metadata.MaxDownloads {
return metadata, errors.New("MaxDownloads expired.") return metadata, errors.New("maxDownloads expired")
} else if !metadata.MaxDate.IsZero() && time.Now().After(metadata.MaxDate) { } else if !metadata.MaxDate.IsZero() && time.Now().After(metadata.MaxDate) {
return metadata, errors.New("MaxDate expired.") return metadata, errors.New("maxDate expired")
} else if metadata.MaxDownloads != -1 && increaseDownload { } else if metadata.MaxDownloads != -1 && increaseDownload {
// todo(nl5887): mutex? // todo(nl5887): mutex?
@ -710,15 +710,15 @@ func (s *Server) CheckMetadata(token, filename string, increaseDownload bool) (M
return metadata, nil return metadata, nil
} }
func (s *Server) CheckDeletionToken(deletionToken, token, filename string) error { func (s *Server) checkDeletionToken(deletionToken, token, filename string) error {
s.Lock(token, filename) s.lock(token, filename)
defer s.Unlock(token, filename) defer s.unlock(token, filename)
var metadata Metadata var metadata metadata
r, _, err := s.storage.Get(token, fmt.Sprintf("%s.metadata", filename)) r, _, err := s.storage.Get(token, fmt.Sprintf("%s.metadata", filename))
if s.storage.IsNotExist(err) { if s.storage.IsNotExist(err) {
return errors.New("Metadata doesn't exist") return errors.New("metadata doesn't exist")
} else if err != nil { } else if err != nil {
return err return err
} }
@ -728,7 +728,7 @@ func (s *Server) CheckDeletionToken(deletionToken, token, filename string) error
if err := json.NewDecoder(r).Decode(&metadata); err != nil { if err := json.NewDecoder(r).Decode(&metadata); err != nil {
return err return err
} else if metadata.DeletionToken != deletionToken { } else if metadata.DeletionToken != deletionToken {
return errors.New("Deletion token doesn't match.") return errors.New("deletion token doesn't match")
} }
return nil return nil
@ -754,7 +754,7 @@ func (s *Server) deleteHandler(w http.ResponseWriter, r *http.Request) {
filename := vars["filename"] filename := vars["filename"]
deletionToken := vars["deletionToken"] deletionToken := vars["deletionToken"]
if err := s.CheckDeletionToken(deletionToken, token, filename); err != nil { if err := s.checkDeletionToken(deletionToken, token, filename); err != nil {
s.logger.Printf("Error metadata: %s", err.Error()) s.logger.Printf("Error metadata: %s", err.Error())
http.Error(w, http.StatusText(http.StatusNotFound), http.StatusNotFound) http.Error(w, http.StatusText(http.StatusNotFound), http.StatusNotFound)
return return
@ -790,7 +790,7 @@ func (s *Server) zipHandler(w http.ResponseWriter, r *http.Request) {
token := strings.Split(key, "/")[0] token := strings.Split(key, "/")[0]
filename := sanitize(strings.Split(key, "/")[1]) filename := sanitize(strings.Split(key, "/")[1])
if _, err := s.CheckMetadata(token, filename, true); err != nil { if _, err := s.checkMetadata(token, filename, true); err != nil {
s.logger.Printf("Error metadata: %s", err.Error()) s.logger.Printf("Error metadata: %s", err.Error())
continue continue
} }
@ -801,11 +801,11 @@ func (s *Server) zipHandler(w http.ResponseWriter, r *http.Request) {
if s.storage.IsNotExist(err) { if s.storage.IsNotExist(err) {
http.Error(w, "File not found", 404) http.Error(w, "File not found", 404)
return return
} else {
s.logger.Printf("%s", err.Error())
http.Error(w, "Could not retrieve file.", 500)
return
} }
s.logger.Printf("%s", err.Error())
http.Error(w, "Could not retrieve file.", 500)
return
} }
defer reader.Close() defer reader.Close()
@ -862,7 +862,7 @@ func (s *Server) tarGzHandler(w http.ResponseWriter, r *http.Request) {
token := strings.Split(key, "/")[0] token := strings.Split(key, "/")[0]
filename := sanitize(strings.Split(key, "/")[1]) filename := sanitize(strings.Split(key, "/")[1])
if _, err := s.CheckMetadata(token, filename, true); err != nil { if _, err := s.checkMetadata(token, filename, true); err != nil {
s.logger.Printf("Error metadata: %s", err.Error()) s.logger.Printf("Error metadata: %s", err.Error())
continue continue
} }
@ -872,11 +872,11 @@ func (s *Server) tarGzHandler(w http.ResponseWriter, r *http.Request) {
if s.storage.IsNotExist(err) { if s.storage.IsNotExist(err) {
http.Error(w, "File not found", 404) http.Error(w, "File not found", 404)
return return
} else {
s.logger.Printf("%s", err.Error())
http.Error(w, "Could not retrieve file.", 500)
return
} }
s.logger.Printf("%s", err.Error())
http.Error(w, "Could not retrieve file.", 500)
return
} }
defer reader.Close() defer reader.Close()
@ -921,7 +921,7 @@ func (s *Server) tarHandler(w http.ResponseWriter, r *http.Request) {
token := strings.Split(key, "/")[0] token := strings.Split(key, "/")[0]
filename := strings.Split(key, "/")[1] filename := strings.Split(key, "/")[1]
if _, err := s.CheckMetadata(token, filename, true); err != nil { if _, err := s.checkMetadata(token, filename, true); err != nil {
s.logger.Printf("Error metadata: %s", err.Error()) s.logger.Printf("Error metadata: %s", err.Error())
continue continue
} }
@ -931,11 +931,11 @@ func (s *Server) tarHandler(w http.ResponseWriter, r *http.Request) {
if s.storage.IsNotExist(err) { if s.storage.IsNotExist(err) {
http.Error(w, "File not found", 404) http.Error(w, "File not found", 404)
return return
} else {
s.logger.Printf("%s", err.Error())
http.Error(w, "Could not retrieve file.", 500)
return
} }
s.logger.Printf("%s", err.Error())
http.Error(w, "Could not retrieve file.", 500)
return
} }
defer reader.Close() defer reader.Close()
@ -966,7 +966,7 @@ func (s *Server) headHandler(w http.ResponseWriter, r *http.Request) {
token := vars["token"] token := vars["token"]
filename := vars["filename"] filename := vars["filename"]
metadata, err := s.CheckMetadata(token, filename, false) metadata, err := s.checkMetadata(token, filename, false)
if err != nil { if err != nil {
s.logger.Printf("Error metadata: %s", err.Error()) s.logger.Printf("Error metadata: %s", err.Error())
@ -1001,7 +1001,7 @@ func (s *Server) getHandler(w http.ResponseWriter, r *http.Request) {
token := vars["token"] token := vars["token"]
filename := vars["filename"] filename := vars["filename"]
metadata, err := s.CheckMetadata(token, filename, true) metadata, err := s.checkMetadata(token, filename, true)
if err != nil { if err != nil {
s.logger.Printf("Error metadata: %s", err.Error()) s.logger.Printf("Error metadata: %s", err.Error())
@ -1073,9 +1073,10 @@ func (s *Server) getHandler(w http.ResponseWriter, r *http.Request) {
return return
} }
// RedirectHandler handles redirect
func (s *Server) RedirectHandler(h http.Handler) http.HandlerFunc { func (s *Server) RedirectHandler(h http.Handler) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
if !s.forceHTTPs { if !s.forceHTTPS {
// we don't want to enforce https // we don't want to enforce https
} else if r.URL.Path == "/health.html" { } else if r.URL.Path == "/health.html" {
// health check url won't redirect // health check url won't redirect
@ -1095,17 +1096,17 @@ func (s *Server) RedirectHandler(h http.Handler) http.HandlerFunc {
} }
} }
// Create a log handler for every request it receives. // LoveHandler Create a log handler for every request it receives.
func LoveHandler(h http.Handler) http.HandlerFunc { func LoveHandler(h http.Handler) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("x-made-with", "<3 by DutchCoders") w.Header().Set("x-made-with", "<3 by DutchCoders")
w.Header().Set("x-served-by", "Proudly served by DutchCoders") w.Header().Set("x-served-by", "Proudly served by DutchCoders")
w.Header().Set("Server", "Transfer.sh HTTP Server 1.0") w.Header().Set("server", "Transfer.sh HTTP Server")
h.ServeHTTP(w, r) h.ServeHTTP(w, r)
} }
} }
func IPFilterHandler(h http.Handler, ipFilterOptions *IPFilterOptions) http.HandlerFunc { func ipFilterHandler(h http.Handler, ipFilterOptions *IPFilterOptions) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
if ipFilterOptions == nil { if ipFilterOptions == nil {
h.ServeHTTP(w, r) h.ServeHTTP(w, r)
@ -1116,7 +1117,7 @@ func IPFilterHandler(h http.Handler, ipFilterOptions *IPFilterOptions) http.Hand
} }
} }
func (s *Server) BasicAuthHandler(h http.Handler) http.HandlerFunc { func (s *Server) basicAuthHandler(h http.Handler) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
if s.AuthUser == "" || s.AuthPass == "" { if s.AuthUser == "" || s.AuthPass == "" {
h.ServeHTTP(w, r) h.ServeHTTP(w, r)

View file

@ -13,16 +13,16 @@ import (
func Test(t *testing.T) { TestingT(t) } func Test(t *testing.T) { TestingT(t) }
var ( var (
_ = Suite(&SuiteRedirectWithForceHTTPs{}) _ = Suite(&suiteRedirectWithForceHTTPS{})
_ = Suite(&SuiteRedirectWithoutForceHTTPs{}) _ = Suite(&suiteRedirectWithoutForceHTTPS{})
) )
type SuiteRedirectWithForceHTTPs struct { type suiteRedirectWithForceHTTPS struct {
handler http.HandlerFunc handler http.HandlerFunc
} }
func (s *SuiteRedirectWithForceHTTPs) SetUpTest(c *C) { func (s *suiteRedirectWithForceHTTPS) SetUpTest(c *C) {
srvr, err := New(ForceHTTPs()) srvr, err := New(ForceHTTPS())
c.Assert(err, IsNil) c.Assert(err, IsNil)
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
@ -32,7 +32,7 @@ func (s *SuiteRedirectWithForceHTTPs) SetUpTest(c *C) {
s.handler = srvr.RedirectHandler(handler) s.handler = srvr.RedirectHandler(handler)
} }
func (s *SuiteRedirectWithForceHTTPs) TestHTTPs(c *C) { func (s *suiteRedirectWithForceHTTPS) TestHTTPs(c *C) {
req := httptest.NewRequest("GET", "https://test/test", nil) req := httptest.NewRequest("GET", "https://test/test", nil)
w := httptest.NewRecorder() w := httptest.NewRecorder()
@ -42,7 +42,7 @@ func (s *SuiteRedirectWithForceHTTPs) TestHTTPs(c *C) {
c.Assert(resp.StatusCode, Equals, http.StatusOK) c.Assert(resp.StatusCode, Equals, http.StatusOK)
} }
func (s *SuiteRedirectWithForceHTTPs) TestOnion(c *C) { func (s *suiteRedirectWithForceHTTPS) TestOnion(c *C) {
req := httptest.NewRequest("GET", "http://test.onion/test", nil) req := httptest.NewRequest("GET", "http://test.onion/test", nil)
w := httptest.NewRecorder() w := httptest.NewRecorder()
@ -52,7 +52,7 @@ func (s *SuiteRedirectWithForceHTTPs) TestOnion(c *C) {
c.Assert(resp.StatusCode, Equals, http.StatusOK) c.Assert(resp.StatusCode, Equals, http.StatusOK)
} }
func (s *SuiteRedirectWithForceHTTPs) TestXForwardedFor(c *C) { func (s *suiteRedirectWithForceHTTPS) TestXForwardedFor(c *C) {
req := httptest.NewRequest("GET", "http://127.0.0.1/test", nil) req := httptest.NewRequest("GET", "http://127.0.0.1/test", nil)
req.Header.Set("X-Forwarded-Proto", "https") req.Header.Set("X-Forwarded-Proto", "https")
@ -63,7 +63,7 @@ func (s *SuiteRedirectWithForceHTTPs) TestXForwardedFor(c *C) {
c.Assert(resp.StatusCode, Equals, http.StatusOK) c.Assert(resp.StatusCode, Equals, http.StatusOK)
} }
func (s *SuiteRedirectWithForceHTTPs) TestHTTP(c *C) { func (s *suiteRedirectWithForceHTTPS) TestHTTP(c *C) {
req := httptest.NewRequest("GET", "http://127.0.0.1/test", nil) req := httptest.NewRequest("GET", "http://127.0.0.1/test", nil)
w := httptest.NewRecorder() w := httptest.NewRecorder()
@ -74,11 +74,11 @@ func (s *SuiteRedirectWithForceHTTPs) TestHTTP(c *C) {
c.Assert(resp.Header.Get("Location"), Equals, "https://127.0.0.1/test") c.Assert(resp.Header.Get("Location"), Equals, "https://127.0.0.1/test")
} }
type SuiteRedirectWithoutForceHTTPs struct { type suiteRedirectWithoutForceHTTPS struct {
handler http.HandlerFunc handler http.HandlerFunc
} }
func (s *SuiteRedirectWithoutForceHTTPs) SetUpTest(c *C) { func (s *suiteRedirectWithoutForceHTTPS) SetUpTest(c *C) {
srvr, err := New() srvr, err := New()
c.Assert(err, IsNil) c.Assert(err, IsNil)
@ -89,7 +89,7 @@ func (s *SuiteRedirectWithoutForceHTTPs) SetUpTest(c *C) {
s.handler = srvr.RedirectHandler(handler) s.handler = srvr.RedirectHandler(handler)
} }
func (s *SuiteRedirectWithoutForceHTTPs) TestHTTP(c *C) { func (s *suiteRedirectWithoutForceHTTPS) TestHTTP(c *C) {
req := httptest.NewRequest("GET", "http://127.0.0.1/test", nil) req := httptest.NewRequest("GET", "http://127.0.0.1/test", nil)
w := httptest.NewRecorder() w := httptest.NewRecorder()
@ -99,7 +99,7 @@ func (s *SuiteRedirectWithoutForceHTTPs) TestHTTP(c *C) {
c.Assert(resp.StatusCode, Equals, http.StatusOK) c.Assert(resp.StatusCode, Equals, http.StatusOK)
} }
func (s *SuiteRedirectWithoutForceHTTPs) TestHTTPs(c *C) { func (s *suiteRedirectWithoutForceHTTPS) TestHTTPs(c *C) {
req := httptest.NewRequest("GET", "https://127.0.0.1/test", nil) req := httptest.NewRequest("GET", "https://127.0.0.1/test", nil)
w := httptest.NewRecorder() w := httptest.NewRecorder()

View file

@ -21,7 +21,7 @@ import (
"github.com/tomasen/realip" "github.com/tomasen/realip"
) )
//IPFilterOptions for IPFilter. Allowed takes precendence over Blocked. //IPFilterOptions for ipFilter. Allowed takes precedence over Blocked.
//IPs can be IPv4 or IPv6 and can optionally contain subnet //IPs can be IPv4 or IPv6 and can optionally contain subnet
//masks (/24). Note however, determining if a given IP is //masks (/24). Note however, determining if a given IP is
//included in a subnet requires a linear scan so is less performant //included in a subnet requires a linear scan so is less performant
@ -43,7 +43,8 @@ type IPFilterOptions struct {
} }
} }
type IPFilter struct { // ipFilter
type ipFilter struct {
opts IPFilterOptions opts IPFilterOptions
//mut protects the below //mut protects the below
//rw since writes are rare //rw since writes are rare
@ -59,13 +60,12 @@ type subnet struct {
allowed bool allowed bool
} }
//New constructs IPFilter instance. func newIPFilter(opts IPFilterOptions) *ipFilter {
func NewIPFilter(opts IPFilterOptions) *IPFilter {
if opts.Logger == nil { if opts.Logger == nil {
flags := log.LstdFlags flags := log.LstdFlags
opts.Logger = log.New(os.Stdout, "", flags) opts.Logger = log.New(os.Stdout, "", flags)
} }
f := &IPFilter{ f := &ipFilter{
opts: opts, opts: opts,
ips: map[string]bool{}, ips: map[string]bool{},
defaultAllowed: !opts.BlockByDefault, defaultAllowed: !opts.BlockByDefault,
@ -79,15 +79,15 @@ func NewIPFilter(opts IPFilterOptions) *IPFilter {
return f return f
} }
func (f *IPFilter) AllowIP(ip string) bool { func (f *ipFilter) AllowIP(ip string) bool {
return f.ToggleIP(ip, true) return f.ToggleIP(ip, true)
} }
func (f *IPFilter) BlockIP(ip string) bool { func (f *ipFilter) BlockIP(ip string) bool {
return f.ToggleIP(ip, false) return f.ToggleIP(ip, false)
} }
func (f *IPFilter) ToggleIP(str string, allowed bool) bool { func (f *ipFilter) ToggleIP(str string, allowed bool) bool {
//check if has subnet //check if has subnet
if ip, net, err := net.ParseCIDR(str); err == nil { if ip, net, err := net.ParseCIDR(str); err == nil {
// containing only one ip? // containing only one ip?
@ -128,19 +128,19 @@ func (f *IPFilter) ToggleIP(str string, allowed bool) bool {
} }
//ToggleDefault alters the default setting //ToggleDefault alters the default setting
func (f *IPFilter) ToggleDefault(allowed bool) { func (f *ipFilter) ToggleDefault(allowed bool) {
f.mut.Lock() f.mut.Lock()
f.defaultAllowed = allowed f.defaultAllowed = allowed
f.mut.Unlock() f.mut.Unlock()
} }
//Allowed returns if a given IP can pass through the filter //Allowed returns if a given IP can pass through the filter
func (f *IPFilter) Allowed(ipstr string) bool { func (f *ipFilter) Allowed(ipstr string) bool {
return f.NetAllowed(net.ParseIP(ipstr)) return f.NetAllowed(net.ParseIP(ipstr))
} }
//NetAllowed returns if a given net.IP can pass through the filter //NetAllowed returns if a given net.IP can pass through the filter
func (f *IPFilter) NetAllowed(ip net.IP) bool { func (f *ipFilter) NetAllowed(ip net.IP) bool {
//invalid ip //invalid ip
if ip == nil { if ip == nil {
return false return false
@ -173,35 +173,35 @@ func (f *IPFilter) NetAllowed(ip net.IP) bool {
} }
//Blocked returns if a given IP can NOT pass through the filter //Blocked returns if a given IP can NOT pass through the filter
func (f *IPFilter) Blocked(ip string) bool { func (f *ipFilter) Blocked(ip string) bool {
return !f.Allowed(ip) return !f.Allowed(ip)
} }
//NetBlocked returns if a given net.IP can NOT pass through the filter //NetBlocked returns if a given net.IP can NOT pass through the filter
func (f *IPFilter) NetBlocked(ip net.IP) bool { func (f *ipFilter) NetBlocked(ip net.IP) bool {
return !f.NetAllowed(ip) return !f.NetAllowed(ip)
} }
//WrapIPFilter the provided handler with simple IP blocking middleware //WrapIPFilter the provided handler with simple IP blocking middleware
//using this IP filter and its configuration //using this IP filter and its configuration
func (f *IPFilter) Wrap(next http.Handler) http.Handler { func (f *ipFilter) Wrap(next http.Handler) http.Handler {
return &ipFilterMiddleware{IPFilter: f, next: next} return &ipFilterMiddleware{ipFilter: f, next: next}
} }
//WrapIPFilter is equivalent to NewIPFilter(opts) then Wrap(next) //WrapIPFilter is equivalent to newIPFilter(opts) then Wrap(next)
func WrapIPFilter(next http.Handler, opts IPFilterOptions) http.Handler { func WrapIPFilter(next http.Handler, opts IPFilterOptions) http.Handler {
return NewIPFilter(opts).Wrap(next) return newIPFilter(opts).Wrap(next)
} }
type ipFilterMiddleware struct { type ipFilterMiddleware struct {
*IPFilter *ipFilter
next http.Handler next http.Handler
} }
func (m *ipFilterMiddleware) ServeHTTP(w http.ResponseWriter, r *http.Request) { func (m *ipFilterMiddleware) ServeHTTP(w http.ResponseWriter, r *http.Request) {
remoteIP := realip.FromRequest(r) remoteIP := realip.FromRequest(r)
if !m.IPFilter.Allowed(remoteIP) { if !m.ipFilter.Allowed(remoteIP) {
//show simple forbidden text //show simple forbidden text
http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden) http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden)
return return

View file

@ -48,6 +48,7 @@ import (
"github.com/VojtechVitek/ratelimit/memory" "github.com/VojtechVitek/ratelimit/memory"
"github.com/gorilla/mux" "github.com/gorilla/mux"
// import pprof
_ "net/http/pprof" _ "net/http/pprof"
"crypto/tls" "crypto/tls"
@ -59,28 +60,30 @@ import (
"path/filepath" "path/filepath"
) )
const SERVER_INFO = "transfer.sh"
// parse request with maximum memory of _24Kilobits // parse request with maximum memory of _24Kilobits
const _24K = (1 << 3) * 24 const _24K = (1 << 3) * 24
// parse request with maximum memory of _5Megabytes // parse request with maximum memory of _5Megabytes
const _5M = (1 << 20) * 5 const _5M = (1 << 20) * 5
// OptionFn is the option function type
type OptionFn func(*Server) type OptionFn func(*Server)
// ClamavHost sets clamav host
func ClamavHost(s string) OptionFn { func ClamavHost(s string) OptionFn {
return func(srvr *Server) { return func(srvr *Server) {
srvr.ClamAVDaemonHost = s srvr.ClamAVDaemonHost = s
} }
} }
// VirustotalKey sets virus total key
func VirustotalKey(s string) OptionFn { func VirustotalKey(s string) OptionFn {
return func(srvr *Server) { return func(srvr *Server) {
srvr.VirusTotalKey = s srvr.VirusTotalKey = s
} }
} }
// Listener set listener
func Listener(s string) OptionFn { func Listener(s string) OptionFn {
return func(srvr *Server) { return func(srvr *Server) {
srvr.ListenerString = s srvr.ListenerString = s
@ -88,6 +91,7 @@ func Listener(s string) OptionFn {
} }
// CorsDomains sets CORS domains
func CorsDomains(s string) OptionFn { func CorsDomains(s string) OptionFn {
return func(srvr *Server) { return func(srvr *Server) {
srvr.CorsDomains = s srvr.CorsDomains = s
@ -95,18 +99,21 @@ func CorsDomains(s string) OptionFn {
} }
// GoogleAnalytics sets GA key
func GoogleAnalytics(gaKey string) OptionFn { func GoogleAnalytics(gaKey string) OptionFn {
return func(srvr *Server) { return func(srvr *Server) {
srvr.gaKey = gaKey srvr.gaKey = gaKey
} }
} }
// UserVoice sets UV key
func UserVoice(userVoiceKey string) OptionFn { func UserVoice(userVoiceKey string) OptionFn {
return func(srvr *Server) { return func(srvr *Server) {
srvr.userVoiceKey = userVoiceKey srvr.userVoiceKey = userVoiceKey
} }
} }
// TLSListener sets TLS listener and option
func TLSListener(s string, t bool) OptionFn { func TLSListener(s string, t bool) OptionFn {
return func(srvr *Server) { return func(srvr *Server) {
srvr.TLSListenerString = s srvr.TLSListenerString = s
@ -115,12 +122,14 @@ func TLSListener(s string, t bool) OptionFn {
} }
// ProfileListener sets profile listener
func ProfileListener(s string) OptionFn { func ProfileListener(s string) OptionFn {
return func(srvr *Server) { return func(srvr *Server) {
srvr.ProfileListenerString = s srvr.ProfileListenerString = s
} }
} }
// WebPath sets web path
func WebPath(s string) OptionFn { func WebPath(s string) OptionFn {
return func(srvr *Server) { return func(srvr *Server) {
if s[len(s)-1:] != "/" { if s[len(s)-1:] != "/" {
@ -131,6 +140,7 @@ func WebPath(s string) OptionFn {
} }
} }
// ProxyPath sets proxy path
func ProxyPath(s string) OptionFn { func ProxyPath(s string) OptionFn {
return func(srvr *Server) { return func(srvr *Server) {
if s[len(s)-1:] != "/" { if s[len(s)-1:] != "/" {
@ -141,12 +151,14 @@ func ProxyPath(s string) OptionFn {
} }
} }
// ProxyPort sets proxy port
func ProxyPort(s string) OptionFn { func ProxyPort(s string) OptionFn {
return func(srvr *Server) { return func(srvr *Server) {
srvr.proxyPort = s srvr.proxyPort = s
} }
} }
// TempPath sets temp path
func TempPath(s string) OptionFn { func TempPath(s string) OptionFn {
return func(srvr *Server) { return func(srvr *Server) {
if s[len(s)-1:] != "/" { if s[len(s)-1:] != "/" {
@ -157,6 +169,7 @@ func TempPath(s string) OptionFn {
} }
} }
// LogFile sets log file
func LogFile(logger *log.Logger, s string) OptionFn { func LogFile(logger *log.Logger, s string) OptionFn {
return func(srvr *Server) { return func(srvr *Server) {
f, err := os.OpenFile(s, os.O_RDWR|os.O_CREATE|os.O_APPEND, 0666) f, err := os.OpenFile(s, os.O_RDWR|os.O_CREATE|os.O_APPEND, 0666)
@ -169,30 +182,36 @@ func LogFile(logger *log.Logger, s string) OptionFn {
} }
} }
// Logger sets logger
func Logger(logger *log.Logger) OptionFn { func Logger(logger *log.Logger) OptionFn {
return func(srvr *Server) { return func(srvr *Server) {
srvr.logger = logger srvr.logger = logger
} }
} }
// MaxUploadSize sets max upload size
func MaxUploadSize(kbytes int64) OptionFn { func MaxUploadSize(kbytes int64) OptionFn {
return func(srvr *Server) { return func(srvr *Server) {
srvr.maxUploadSize = kbytes * 1024 srvr.maxUploadSize = kbytes * 1024
} }
} }
// RateLimit set rate limit
func RateLimit(requests int) OptionFn { func RateLimit(requests int) OptionFn {
return func(srvr *Server) { return func(srvr *Server) {
srvr.rateLimitRequests = requests srvr.rateLimitRequests = requests
} }
} }
// RandomTokenLength sets random token length
func RandomTokenLength(length int) OptionFn { func RandomTokenLength(length int) OptionFn {
return func(srvr *Server) { return func(srvr *Server) {
srvr.randomTokenLength = length srvr.randomTokenLength = length
} }
} }
// Purge sets purge days and option
func Purge(days, interval int) OptionFn { func Purge(days, interval int) OptionFn {
return func(srvr *Server) { return func(srvr *Server) {
srvr.purgeDays = time.Duration(days) * time.Hour * 24 srvr.purgeDays = time.Duration(days) * time.Hour * 24
@ -200,24 +219,28 @@ func Purge(days, interval int) OptionFn {
} }
} }
func ForceHTTPs() OptionFn { // ForceHTTPS sets forcing https
func ForceHTTPS() OptionFn {
return func(srvr *Server) { return func(srvr *Server) {
srvr.forceHTTPs = true srvr.forceHTTPS = true
} }
} }
// EnableProfiler sets enable profiler
func EnableProfiler() OptionFn { func EnableProfiler() OptionFn {
return func(srvr *Server) { return func(srvr *Server) {
srvr.profilerEnabled = true srvr.profilerEnabled = true
} }
} }
// UseStorage set storage to use
func UseStorage(s Storage) OptionFn { func UseStorage(s Storage) OptionFn {
return func(srvr *Server) { return func(srvr *Server) {
srvr.storage = s srvr.storage = s
} }
} }
// UseLetsEncrypt set letsencrypt usage
func UseLetsEncrypt(hosts []string) OptionFn { func UseLetsEncrypt(hosts []string) OptionFn {
return func(srvr *Server) { return func(srvr *Server) {
cacheDir := "./cache/" cacheDir := "./cache/"
@ -246,6 +269,7 @@ func UseLetsEncrypt(hosts []string) OptionFn {
} }
} }
// TLSConfig sets TLS config
func TLSConfig(cert, pk string) OptionFn { func TLSConfig(cert, pk string) OptionFn {
certificate, err := tls.LoadX509KeyPair(cert, pk) certificate, err := tls.LoadX509KeyPair(cert, pk)
return func(srvr *Server) { return func(srvr *Server) {
@ -257,13 +281,15 @@ func TLSConfig(cert, pk string) OptionFn {
} }
} }
func HttpAuthCredentials(user string, pass string) OptionFn { // HTTPAuthCredentials sets basic http auth credentials
func HTTPAuthCredentials(user string, pass string) OptionFn {
return func(srvr *Server) { return func(srvr *Server) {
srvr.AuthUser = user srvr.AuthUser = user
srvr.AuthPass = pass srvr.AuthPass = pass
} }
} }
// FilterOptions sets ip filtering
func FilterOptions(options IPFilterOptions) OptionFn { func FilterOptions(options IPFilterOptions) OptionFn {
for i, allowedIP := range options.AllowedIPs { for i, allowedIP := range options.AllowedIPs {
options.AllowedIPs[i] = strings.TrimSpace(allowedIP) options.AllowedIPs[i] = strings.TrimSpace(allowedIP)
@ -278,6 +304,7 @@ func FilterOptions(options IPFilterOptions) OptionFn {
} }
} }
// Server is the main application
type Server struct { type Server struct {
AuthUser string AuthUser string
AuthPass string AuthPass string
@ -298,7 +325,7 @@ type Server struct {
storage Storage storage Storage
forceHTTPs bool forceHTTPS bool
randomTokenLength int randomTokenLength int
@ -327,6 +354,7 @@ type Server struct {
LetsEncryptCache string LetsEncryptCache string
} }
// New is the factory fot Server
func New(options ...OptionFn) (*Server, error) { func New(options ...OptionFn) (*Server, error) {
s := &Server{ s := &Server{
locks: sync.Map{}, locks: sync.Map{},
@ -347,6 +375,7 @@ func init() {
rand.Seed(int64(binary.LittleEndian.Uint64(seedBytes[:]))) rand.Seed(int64(binary.LittleEndian.Uint64(seedBytes[:])))
} }
// Run starts Server
func (s *Server) Run() { func (s *Server) Run() {
listening := false listening := false
@ -402,7 +431,7 @@ func (s *Server) Run() {
r.HandleFunc("/favicon.ico", staticHandler.ServeHTTP).Methods("GET") r.HandleFunc("/favicon.ico", staticHandler.ServeHTTP).Methods("GET")
r.HandleFunc("/robots.txt", staticHandler.ServeHTTP).Methods("GET") r.HandleFunc("/robots.txt", staticHandler.ServeHTTP).Methods("GET")
r.HandleFunc("/{filename:(?:favicon\\.ico|robots\\.txt|health\\.html)}", s.BasicAuthHandler(http.HandlerFunc(s.putHandler))).Methods("PUT") r.HandleFunc("/{filename:(?:favicon\\.ico|robots\\.txt|health\\.html)}", s.basicAuthHandler(http.HandlerFunc(s.putHandler))).Methods("PUT")
r.HandleFunc("/health.html", healthHandler).Methods("GET") r.HandleFunc("/health.html", healthHandler).Methods("GET")
r.HandleFunc("/", s.viewHandler).Methods("GET") r.HandleFunc("/", s.viewHandler).Methods("GET")
@ -446,10 +475,10 @@ func (s *Server) Run() {
r.HandleFunc("/{filename}/virustotal", s.virusTotalHandler).Methods("PUT") r.HandleFunc("/{filename}/virustotal", s.virusTotalHandler).Methods("PUT")
r.HandleFunc("/{filename}/scan", s.scanHandler).Methods("PUT") r.HandleFunc("/{filename}/scan", s.scanHandler).Methods("PUT")
r.HandleFunc("/put/{filename}", s.BasicAuthHandler(http.HandlerFunc(s.putHandler))).Methods("PUT") r.HandleFunc("/put/{filename}", s.basicAuthHandler(http.HandlerFunc(s.putHandler))).Methods("PUT")
r.HandleFunc("/upload/{filename}", s.BasicAuthHandler(http.HandlerFunc(s.putHandler))).Methods("PUT") r.HandleFunc("/upload/{filename}", s.basicAuthHandler(http.HandlerFunc(s.putHandler))).Methods("PUT")
r.HandleFunc("/{filename}", s.BasicAuthHandler(http.HandlerFunc(s.putHandler))).Methods("PUT") r.HandleFunc("/{filename}", s.basicAuthHandler(http.HandlerFunc(s.putHandler))).Methods("PUT")
r.HandleFunc("/", s.BasicAuthHandler(http.HandlerFunc(s.postHandler))).Methods("POST") r.HandleFunc("/", s.basicAuthHandler(http.HandlerFunc(s.postHandler))).Methods("POST")
// r.HandleFunc("/{page}", viewHandler).Methods("GET") // r.HandleFunc("/{page}", viewHandler).Methods("GET")
r.HandleFunc("/{token}/{filename}/{deletionToken}", s.deleteHandler).Methods("DELETE") r.HandleFunc("/{token}/{filename}/{deletionToken}", s.deleteHandler).Methods("DELETE")
@ -474,7 +503,7 @@ func (s *Server) Run() {
} }
h := handlers.PanicHandler( h := handlers.PanicHandler(
IPFilterHandler( ipFilterHandler(
handlers.LogHandler( handlers.LogHandler(
LoveHandler( LoveHandler(
s.RedirectHandler(cors(r))), s.RedirectHandler(cors(r))),

View file

@ -27,31 +27,43 @@ import (
"storj.io/uplink" "storj.io/uplink"
) )
// Storage is the interface for storage operation
type Storage interface { type Storage interface {
// Get retrieves a file from storage
Get(token string, filename string) (reader io.ReadCloser, contentLength uint64, err error) Get(token string, filename string) (reader io.ReadCloser, contentLength uint64, err error)
// Head retrieves content length of a file from storage
Head(token string, filename string) (contentLength uint64, err error) Head(token string, filename string) (contentLength uint64, err error)
// Put saves a file on storage
Put(token string, filename string, reader io.Reader, contentType string, contentLength uint64) error Put(token string, filename string, reader io.Reader, contentType string, contentLength uint64) error
// Delete removes a file from storage
Delete(token string, filename string) error Delete(token string, filename string) error
// IsNotExist indicates if a file doesn't exist on storage
IsNotExist(err error) bool IsNotExist(err error) bool
// Purge cleans up the storage
Purge(days time.Duration) error Purge(days time.Duration) error
// Type returns the storage type
Type() string Type() string
} }
// LocalStorage is a local storage
type LocalStorage struct { type LocalStorage struct {
Storage Storage
basedir string basedir string
logger *log.Logger logger *log.Logger
} }
// NewLocalStorage is the factory for LocalStorage
func NewLocalStorage(basedir string, logger *log.Logger) (*LocalStorage, error) { func NewLocalStorage(basedir string, logger *log.Logger) (*LocalStorage, error) {
return &LocalStorage{basedir: basedir, logger: logger}, nil return &LocalStorage{basedir: basedir, logger: logger}, nil
} }
// Type returns the storage type
func (s *LocalStorage) Type() string { func (s *LocalStorage) Type() string {
return "local" return "local"
} }
// Head retrieves content length of a file from storage
func (s *LocalStorage) Head(token string, filename string) (contentLength uint64, err error) { func (s *LocalStorage) Head(token string, filename string) (contentLength uint64, err error) {
path := filepath.Join(s.basedir, token, filename) path := filepath.Join(s.basedir, token, filename)
@ -65,6 +77,7 @@ func (s *LocalStorage) Head(token string, filename string) (contentLength uint64
return return
} }
// Get retrieves a file from storage
func (s *LocalStorage) Get(token string, filename string) (reader io.ReadCloser, contentLength uint64, err error) { func (s *LocalStorage) Get(token string, filename string) (reader io.ReadCloser, contentLength uint64, err error) {
path := filepath.Join(s.basedir, token, filename) path := filepath.Join(s.basedir, token, filename)
@ -83,6 +96,7 @@ func (s *LocalStorage) Get(token string, filename string) (reader io.ReadCloser,
return return
} }
// Delete removes a file from storage
func (s *LocalStorage) Delete(token string, filename string) (err error) { func (s *LocalStorage) Delete(token string, filename string) (err error) {
metadata := filepath.Join(s.basedir, token, fmt.Sprintf("%s.metadata", filename)) metadata := filepath.Join(s.basedir, token, fmt.Sprintf("%s.metadata", filename))
os.Remove(metadata) os.Remove(metadata)
@ -92,6 +106,7 @@ func (s *LocalStorage) Delete(token string, filename string) (err error) {
return return
} }
// Purge cleans up the storage
func (s *LocalStorage) Purge(days time.Duration) (err error) { func (s *LocalStorage) Purge(days time.Duration) (err error) {
err = filepath.Walk(s.basedir, err = filepath.Walk(s.basedir,
func(path string, info os.FileInfo, err error) error { func(path string, info os.FileInfo, err error) error {
@ -113,6 +128,7 @@ func (s *LocalStorage) Purge(days time.Duration) (err error) {
return return
} }
// IsNotExist indicates if a file doesn't exist on storage
func (s *LocalStorage) IsNotExist(err error) bool { func (s *LocalStorage) IsNotExist(err error) bool {
if err == nil { if err == nil {
return false return false
@ -121,6 +137,7 @@ func (s *LocalStorage) IsNotExist(err error) bool {
return os.IsNotExist(err) return os.IsNotExist(err)
} }
// Put saves a file on storage
func (s *LocalStorage) Put(token string, filename string, reader io.Reader, contentType string, contentLength uint64) error { func (s *LocalStorage) Put(token string, filename string, reader io.Reader, contentType string, contentLength uint64) error {
var f io.WriteCloser var f io.WriteCloser
var err error var err error
@ -144,6 +161,7 @@ func (s *LocalStorage) Put(token string, filename string, reader io.Reader, cont
return nil return nil
} }
// S3Storage is a storage backed by AWS S3
type S3Storage struct { type S3Storage struct {
Storage Storage
bucket string bucket string
@ -154,6 +172,7 @@ type S3Storage struct {
noMultipart bool noMultipart bool
} }
// NewS3Storage is the factory for S3Storage
func NewS3Storage(accessKey, secretKey, bucketName string, purgeDays int, region, endpoint string, disableMultipart bool, forcePathStyle bool, logger *log.Logger) (*S3Storage, error) { func NewS3Storage(accessKey, secretKey, bucketName string, purgeDays int, region, endpoint string, disableMultipart bool, forcePathStyle bool, logger *log.Logger) (*S3Storage, error) {
sess := getAwsSession(accessKey, secretKey, region, endpoint, forcePathStyle) sess := getAwsSession(accessKey, secretKey, region, endpoint, forcePathStyle)
@ -167,10 +186,12 @@ func NewS3Storage(accessKey, secretKey, bucketName string, purgeDays int, region
}, nil }, nil
} }
// Type returns the storage type
func (s *S3Storage) Type() string { func (s *S3Storage) Type() string {
return "s3" return "s3"
} }
// Head retrieves content length of a file from storage
func (s *S3Storage) Head(token string, filename string) (contentLength uint64, err error) { func (s *S3Storage) Head(token string, filename string) (contentLength uint64, err error) {
key := fmt.Sprintf("%s/%s", token, filename) key := fmt.Sprintf("%s/%s", token, filename)
@ -192,11 +213,13 @@ func (s *S3Storage) Head(token string, filename string) (contentLength uint64, e
return return
} }
// Purge cleans up the storage
func (s *S3Storage) Purge(days time.Duration) (err error) { func (s *S3Storage) Purge(days time.Duration) (err error) {
// NOOP expiration is set at upload time // NOOP expiration is set at upload time
return nil return nil
} }
// IsNotExist indicates if a file doesn't exist on storage
func (s *S3Storage) IsNotExist(err error) bool { func (s *S3Storage) IsNotExist(err error) bool {
if err == nil { if err == nil {
return false return false
@ -212,6 +235,7 @@ func (s *S3Storage) IsNotExist(err error) bool {
return false return false
} }
// Get retrieves a file from storage
func (s *S3Storage) Get(token string, filename string) (reader io.ReadCloser, contentLength uint64, err error) { func (s *S3Storage) Get(token string, filename string) (reader io.ReadCloser, contentLength uint64, err error) {
key := fmt.Sprintf("%s/%s", token, filename) key := fmt.Sprintf("%s/%s", token, filename)
@ -233,6 +257,7 @@ func (s *S3Storage) Get(token string, filename string) (reader io.ReadCloser, co
return return
} }
// Delete removes a file from storage
func (s *S3Storage) Delete(token string, filename string) (err error) { func (s *S3Storage) Delete(token string, filename string) (err error) {
metadata := fmt.Sprintf("%s/%s.metadata", token, filename) metadata := fmt.Sprintf("%s/%s.metadata", token, filename)
deleteRequest := &s3.DeleteObjectInput{ deleteRequest := &s3.DeleteObjectInput{
@ -256,6 +281,7 @@ func (s *S3Storage) Delete(token string, filename string) (err error) {
return return
} }
// Put saves a file on storage
func (s *S3Storage) Put(token string, filename string, reader io.Reader, contentType string, contentLength uint64) (err error) { func (s *S3Storage) Put(token string, filename string, reader io.Reader, contentType string, contentLength uint64) (err error) {
key := fmt.Sprintf("%s/%s", token, filename) key := fmt.Sprintf("%s/%s", token, filename)
@ -288,17 +314,19 @@ func (s *S3Storage) Put(token string, filename string, reader io.Reader, content
return return
} }
// GDrive is a storage backed by GDrive
type GDrive struct { type GDrive struct {
service *drive.Service service *drive.Service
rootId string rootID string
basedir string basedir string
localConfigPath string localConfigPath string
chunkSize int chunkSize int
logger *log.Logger logger *log.Logger
} }
func NewGDriveStorage(clientJsonFilepath string, localConfigPath string, basedir string, chunkSize int, logger *log.Logger) (*GDrive, error) { // NewGDriveStorage is the factory for GDrive
b, err := ioutil.ReadFile(clientJsonFilepath) func NewGDriveStorage(clientJSONFilepath string, localConfigPath string, basedir string, chunkSize int, logger *log.Logger) (*GDrive, error) {
b, err := ioutil.ReadFile(clientJSONFilepath)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -315,7 +343,7 @@ func NewGDriveStorage(clientJsonFilepath string, localConfigPath string, basedir
} }
chunkSize = chunkSize * 1024 * 1024 chunkSize = chunkSize * 1024 * 1024
storage := &GDrive{service: srv, basedir: basedir, rootId: "", localConfigPath: localConfigPath, chunkSize: chunkSize, logger: logger} storage := &GDrive{service: srv, basedir: basedir, rootID: "", localConfigPath: localConfigPath, chunkSize: chunkSize, logger: logger}
err = storage.setupRoot() err = storage.setupRoot()
if err != nil { if err != nil {
return nil, err return nil, err
@ -324,26 +352,26 @@ func NewGDriveStorage(clientJsonFilepath string, localConfigPath string, basedir
return storage, nil return storage, nil
} }
const GDriveRootConfigFile = "root_id.conf" const gdriveRootConfigFile = "root_id.conf"
const GDriveTokenJsonFile = "token.json" const gdriveTokenJSONFile = "token.json"
const GDriveDirectoryMimeType = "application/vnd.google-apps.folder" const gdriveDirectoryMimeType = "application/vnd.google-apps.folder"
func (s *GDrive) setupRoot() error { func (s *GDrive) setupRoot() error {
rootFileConfig := filepath.Join(s.localConfigPath, GDriveRootConfigFile) rootFileConfig := filepath.Join(s.localConfigPath, gdriveRootConfigFile)
rootId, err := ioutil.ReadFile(rootFileConfig) rootID, err := ioutil.ReadFile(rootFileConfig)
if err != nil && !os.IsNotExist(err) { if err != nil && !os.IsNotExist(err) {
return err return err
} }
if string(rootId) != "" { if string(rootID) != "" {
s.rootId = string(rootId) s.rootID = string(rootID)
return nil return nil
} }
dir := &drive.File{ dir := &drive.File{
Name: s.basedir, Name: s.basedir,
MimeType: GDriveDirectoryMimeType, MimeType: gdriveDirectoryMimeType,
} }
di, err := s.service.Files.Create(dir).Fields("id").Do() di, err := s.service.Files.Create(dir).Fields("id").Do()
@ -351,8 +379,8 @@ func (s *GDrive) setupRoot() error {
return err return err
} }
s.rootId = di.Id s.rootID = di.Id
err = ioutil.WriteFile(rootFileConfig, []byte(s.rootId), os.FileMode(0600)) err = ioutil.WriteFile(rootFileConfig, []byte(s.rootID), os.FileMode(0600))
if err != nil { if err != nil {
return err return err
} }
@ -368,13 +396,13 @@ func (s *GDrive) list(nextPageToken string, q string) (*drive.FileList, error) {
return s.service.Files.List().Fields("nextPageToken, files(id, name, mimeType)").Q(q).PageToken(nextPageToken).Do() return s.service.Files.List().Fields("nextPageToken, files(id, name, mimeType)").Q(q).PageToken(nextPageToken).Do()
} }
func (s *GDrive) findId(filename string, token string) (string, error) { func (s *GDrive) findID(filename string, token string) (string, error) {
filename = strings.Replace(filename, `'`, `\'`, -1) filename = strings.Replace(filename, `'`, `\'`, -1)
filename = strings.Replace(filename, `"`, `\"`, -1) filename = strings.Replace(filename, `"`, `\"`, -1)
fileId, tokenId, nextPageToken := "", "", "" fileID, tokenID, nextPageToken := "", "", ""
q := fmt.Sprintf("'%s' in parents and name='%s' and mimeType='%s' and trashed=false", s.rootId, token, GDriveDirectoryMimeType) q := fmt.Sprintf("'%s' in parents and name='%s' and mimeType='%s' and trashed=false", s.rootID, token, gdriveDirectoryMimeType)
l, err := s.list(nextPageToken, q) l, err := s.list(nextPageToken, q)
if err != nil { if err != nil {
return "", err return "", err
@ -382,7 +410,7 @@ func (s *GDrive) findId(filename string, token string) (string, error) {
for 0 < len(l.Files) { for 0 < len(l.Files) {
for _, fi := range l.Files { for _, fi := range l.Files {
tokenId = fi.Id tokenID = fi.Id
break break
} }
@ -391,15 +419,18 @@ func (s *GDrive) findId(filename string, token string) (string, error) {
} }
l, err = s.list(l.NextPageToken, q) l, err = s.list(l.NextPageToken, q)
if err != nil {
return "", err
}
} }
if filename == "" { if filename == "" {
return tokenId, nil return tokenID, nil
} else if tokenId == "" { } else if tokenID == "" {
return "", fmt.Errorf("Cannot find file %s/%s", token, filename) return "", fmt.Errorf("Cannot find file %s/%s", token, filename)
} }
q = fmt.Sprintf("'%s' in parents and name='%s' and mimeType!='%s' and trashed=false", tokenId, filename, GDriveDirectoryMimeType) q = fmt.Sprintf("'%s' in parents and name='%s' and mimeType!='%s' and trashed=false", tokenID, filename, gdriveDirectoryMimeType)
l, err = s.list(nextPageToken, q) l, err = s.list(nextPageToken, q)
if err != nil { if err != nil {
return "", err return "", err
@ -408,7 +439,7 @@ func (s *GDrive) findId(filename string, token string) (string, error) {
for 0 < len(l.Files) { for 0 < len(l.Files) {
for _, fi := range l.Files { for _, fi := range l.Files {
fileId = fi.Id fileID = fi.Id
break break
} }
@ -417,28 +448,33 @@ func (s *GDrive) findId(filename string, token string) (string, error) {
} }
l, err = s.list(l.NextPageToken, q) l, err = s.list(l.NextPageToken, q)
if err != nil {
return "", err
}
} }
if fileId == "" { if fileID == "" {
return "", fmt.Errorf("Cannot find file %s/%s", token, filename) return "", fmt.Errorf("Cannot find file %s/%s", token, filename)
} }
return fileId, nil return fileID, nil
} }
// Type returns the storage type
func (s *GDrive) Type() string { func (s *GDrive) Type() string {
return "gdrive" return "gdrive"
} }
// Head retrieves content length of a file from storage
func (s *GDrive) Head(token string, filename string) (contentLength uint64, err error) { func (s *GDrive) Head(token string, filename string) (contentLength uint64, err error) {
var fileId string var fileID string
fileId, err = s.findId(filename, token) fileID, err = s.findID(filename, token)
if err != nil { if err != nil {
return return
} }
var fi *drive.File var fi *drive.File
if fi, err = s.service.Files.Get(fileId).Fields("size").Do(); err != nil { if fi, err = s.service.Files.Get(fileID).Fields("size").Do(); err != nil {
return return
} }
@ -447,15 +483,16 @@ func (s *GDrive) Head(token string, filename string) (contentLength uint64, err
return return
} }
// Get retrieves a file from storage
func (s *GDrive) Get(token string, filename string) (reader io.ReadCloser, contentLength uint64, err error) { func (s *GDrive) Get(token string, filename string) (reader io.ReadCloser, contentLength uint64, err error) {
var fileId string var fileID string
fileId, err = s.findId(filename, token) fileID, err = s.findID(filename, token)
if err != nil { if err != nil {
return return
} }
var fi *drive.File var fi *drive.File
fi, err = s.service.Files.Get(fileId).Fields("size", "md5Checksum").Do() fi, err = s.service.Files.Get(fileID).Fields("size", "md5Checksum").Do()
if !s.hasChecksum(fi) { if !s.hasChecksum(fi) {
err = fmt.Errorf("Cannot find file %s/%s", token, filename) err = fmt.Errorf("Cannot find file %s/%s", token, filename)
return return
@ -465,7 +502,7 @@ func (s *GDrive) Get(token string, filename string) (reader io.ReadCloser, conte
ctx := context.Background() ctx := context.Background()
var res *http.Response var res *http.Response
res, err = s.service.Files.Get(fileId).Context(ctx).Download() res, err = s.service.Files.Get(fileID).Context(ctx).Download()
if err != nil { if err != nil {
return return
} }
@ -475,25 +512,27 @@ func (s *GDrive) Get(token string, filename string) (reader io.ReadCloser, conte
return return
} }
// Delete removes a file from storage
func (s *GDrive) Delete(token string, filename string) (err error) { func (s *GDrive) Delete(token string, filename string) (err error) {
metadata, _ := s.findId(fmt.Sprintf("%s.metadata", filename), token) metadata, _ := s.findID(fmt.Sprintf("%s.metadata", filename), token)
s.service.Files.Delete(metadata).Do() s.service.Files.Delete(metadata).Do()
var fileId string var fileID string
fileId, err = s.findId(filename, token) fileID, err = s.findID(filename, token)
if err != nil { if err != nil {
return return
} }
err = s.service.Files.Delete(fileId).Do() err = s.service.Files.Delete(fileID).Do()
return return
} }
// Purge cleans up the storage
func (s *GDrive) Purge(days time.Duration) (err error) { func (s *GDrive) Purge(days time.Duration) (err error) {
nextPageToken := "" nextPageToken := ""
expirationDate := time.Now().Add(-1 * days).Format(time.RFC3339) expirationDate := time.Now().Add(-1 * days).Format(time.RFC3339)
q := fmt.Sprintf("'%s' in parents and modifiedTime < '%s' and mimeType!='%s' and trashed=false", s.rootId, expirationDate, GDriveDirectoryMimeType) q := fmt.Sprintf("'%s' in parents and modifiedTime < '%s' and mimeType!='%s' and trashed=false", s.rootID, expirationDate, gdriveDirectoryMimeType)
l, err := s.list(nextPageToken, q) l, err := s.list(nextPageToken, q)
if err != nil { if err != nil {
return err return err
@ -512,32 +551,39 @@ func (s *GDrive) Purge(days time.Duration) (err error) {
} }
l, err = s.list(l.NextPageToken, q) l, err = s.list(l.NextPageToken, q)
if err != nil {
return
}
} }
return return
} }
// IsNotExist indicates if a file doesn't exist on storage
func (s *GDrive) IsNotExist(err error) bool { func (s *GDrive) IsNotExist(err error) bool {
if err != nil { if err == nil {
if e, ok := err.(*googleapi.Error); ok { return false
return e.Code == http.StatusNotFound }
}
if e, ok := err.(*googleapi.Error); ok {
return e.Code == http.StatusNotFound
} }
return false return false
} }
// Put saves a file on storage
func (s *GDrive) Put(token string, filename string, reader io.Reader, contentType string, contentLength uint64) error { func (s *GDrive) Put(token string, filename string, reader io.Reader, contentType string, contentLength uint64) error {
dirId, err := s.findId("", token) dirID, err := s.findID("", token)
if err != nil { if err != nil {
return err return err
} }
if dirId == "" { if dirID == "" {
dir := &drive.File{ dir := &drive.File{
Name: token, Name: token,
Parents: []string{s.rootId}, Parents: []string{s.rootID},
MimeType: GDriveDirectoryMimeType, MimeType: gdriveDirectoryMimeType,
} }
di, err := s.service.Files.Create(dir).Fields("id").Do() di, err := s.service.Files.Create(dir).Fields("id").Do()
@ -545,13 +591,13 @@ func (s *GDrive) Put(token string, filename string, reader io.Reader, contentTyp
return err return err
} }
dirId = di.Id dirID = di.Id
} }
// Instantiate empty drive file // Instantiate empty drive file
dst := &drive.File{ dst := &drive.File{
Name: filename, Name: filename,
Parents: []string{dirId}, Parents: []string{dirID},
MimeType: contentType, MimeType: contentType,
} }
@ -567,7 +613,7 @@ func (s *GDrive) Put(token string, filename string, reader io.Reader, contentTyp
// Retrieve a token, saves the token, then returns the generated client. // Retrieve a token, saves the token, then returns the generated client.
func getGDriveClient(config *oauth2.Config, localConfigPath string, logger *log.Logger) *http.Client { func getGDriveClient(config *oauth2.Config, localConfigPath string, logger *log.Logger) *http.Client {
tokenFile := filepath.Join(localConfigPath, GDriveTokenJsonFile) tokenFile := filepath.Join(localConfigPath, gdriveTokenJSONFile)
tok, err := gDriveTokenFromFile(tokenFile) tok, err := gDriveTokenFromFile(tokenFile)
if err != nil { if err != nil {
tok = getGDriveTokenFromWeb(config, logger) tok = getGDriveTokenFromWeb(config, logger)
@ -619,6 +665,7 @@ func saveGDriveToken(path string, token *oauth2.Token, logger *log.Logger) {
json.NewEncoder(f).Encode(token) json.NewEncoder(f).Encode(token)
} }
// StorjStorage is a storage backed by Storj
type StorjStorage struct { type StorjStorage struct {
Storage Storage
project *uplink.Project project *uplink.Project
@ -627,6 +674,7 @@ type StorjStorage struct {
logger *log.Logger logger *log.Logger
} }
// NewStorjStorage is the factory for StorjStorage
func NewStorjStorage(access, bucket string, purgeDays int, logger *log.Logger) (*StorjStorage, error) { func NewStorjStorage(access, bucket string, purgeDays int, logger *log.Logger) (*StorjStorage, error) {
var instance StorjStorage var instance StorjStorage
var err error var err error
@ -657,10 +705,12 @@ func NewStorjStorage(access, bucket string, purgeDays int, logger *log.Logger) (
return &instance, nil return &instance, nil
} }
// Type returns the storage type
func (s *StorjStorage) Type() string { func (s *StorjStorage) Type() string {
return "storj" return "storj"
} }
// Head retrieves content length of a file from storage
func (s *StorjStorage) Head(token string, filename string) (contentLength uint64, err error) { func (s *StorjStorage) Head(token string, filename string) (contentLength uint64, err error) {
key := storj.JoinPaths(token, filename) key := storj.JoinPaths(token, filename)
@ -676,6 +726,7 @@ func (s *StorjStorage) Head(token string, filename string) (contentLength uint64
return return
} }
// Get retrieves a file from storage
func (s *StorjStorage) Get(token string, filename string) (reader io.ReadCloser, contentLength uint64, err error) { func (s *StorjStorage) Get(token string, filename string) (reader io.ReadCloser, contentLength uint64, err error) {
key := storj.JoinPaths(token, filename) key := storj.JoinPaths(token, filename)
@ -694,6 +745,7 @@ func (s *StorjStorage) Get(token string, filename string) (reader io.ReadCloser,
return return
} }
// Delete removes a file from storage
func (s *StorjStorage) Delete(token string, filename string) (err error) { func (s *StorjStorage) Delete(token string, filename string) (err error) {
key := storj.JoinPaths(token, filename) key := storj.JoinPaths(token, filename)
@ -706,11 +758,13 @@ func (s *StorjStorage) Delete(token string, filename string) (err error) {
return return
} }
// Purge cleans up the storage
func (s *StorjStorage) Purge(days time.Duration) (err error) { func (s *StorjStorage) Purge(days time.Duration) (err error) {
// NOOP expiration is set at upload time // NOOP expiration is set at upload time
return nil return nil
} }
// Put saves a file on storage
func (s *StorjStorage) Put(token string, filename string, reader io.Reader, contentType string, contentLength uint64) (err error) { func (s *StorjStorage) Put(token string, filename string, reader io.Reader, contentType string, contentLength uint64) (err error) {
key := storj.JoinPaths(token, filename) key := storj.JoinPaths(token, filename)
@ -745,6 +799,7 @@ func (s *StorjStorage) Put(token string, filename string, reader io.Reader, cont
return err return err
} }
// IsNotExist indicates if a file doesn't exist on storage
func (s *StorjStorage) IsNotExist(err error) bool { func (s *StorjStorage) IsNotExist(err error) bool {
return errors.Is(err, uplink.ErrObjectNotFound) return errors.Is(err, uplink.ErrObjectNotFound)
} }

View file

@ -29,12 +29,12 @@ import (
) )
const ( const (
// characters used for short-urls // SYMBOLS characters used for short-urls
SYMBOLS = "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ" SYMBOLS = "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
) )
// generate a token // generate a token
func Token(length int) string { func token(length int) string {
result := "" result := ""
for i := 0; i < length; i++ { for i := 0; i < length; i++ {
x := rand.Intn(len(SYMBOLS) - 1) x := rand.Intn(len(SYMBOLS) - 1)

View file

@ -4,12 +4,12 @@ import "testing"
func BenchmarkTokenConcat(b *testing.B) { func BenchmarkTokenConcat(b *testing.B) {
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
_ = Token(5) + Token(5) _ = token(5) + token(5)
} }
} }
func BenchmarkTokenLonger(b *testing.B) { func BenchmarkTokenLonger(b *testing.B) {
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
_ = Token(10) _ = token(10)
} }
} }

View file

@ -50,7 +50,7 @@ func getAwsSession(accessKey, secretKey, region, endpoint string, forcePathStyle
} }
func formatNumber(format string, s uint64) string { func formatNumber(format string, s uint64) string {
return RenderFloat(format, float64(s)) return renderFloat(format, float64(s))
} }
var renderFloatPrecisionMultipliers = [10]float64{ var renderFloatPrecisionMultipliers = [10]float64{
@ -79,7 +79,7 @@ var renderFloatPrecisionRounders = [10]float64{
0.0000000005, 0.0000000005,
} }
func RenderFloat(format string, n float64) string { func renderFloat(format string, n float64) string {
// Special cases: // Special cases:
// NaN = "NaN" // NaN = "NaN"
// +Inf = "+Infinity" // +Inf = "+Infinity"
@ -127,7 +127,7 @@ func RenderFloat(format string, n float64) string {
// +0000 // +0000
if formatDirectiveIndices[0] == 0 { if formatDirectiveIndices[0] == 0 {
if formatDirectiveChars[formatDirectiveIndices[0]] != '+' { if formatDirectiveChars[formatDirectiveIndices[0]] != '+' {
panic("RenderFloat(): invalid positive sign directive") panic("renderFloat(): invalid positive sign directive")
} }
positiveStr = "+" positiveStr = "+"
formatDirectiveIndices = formatDirectiveIndices[1:] formatDirectiveIndices = formatDirectiveIndices[1:]
@ -141,7 +141,7 @@ func RenderFloat(format string, n float64) string {
// 000,000.00 // 000,000.00
if len(formatDirectiveIndices) == 2 { if len(formatDirectiveIndices) == 2 {
if (formatDirectiveIndices[1] - formatDirectiveIndices[0]) != 4 { if (formatDirectiveIndices[1] - formatDirectiveIndices[0]) != 4 {
panic("RenderFloat(): thousands separator directive must be followed by 3 digit-specifiers") panic("renderFloat(): thousands separator directive must be followed by 3 digit-specifiers")
} }
thousandStr = string(formatDirectiveChars[formatDirectiveIndices[0]]) thousandStr = string(formatDirectiveChars[formatDirectiveIndices[0]])
formatDirectiveIndices = formatDirectiveIndices[1:] formatDirectiveIndices = formatDirectiveIndices[1:]
@ -201,8 +201,8 @@ func RenderFloat(format string, n float64) string {
return signStr + intStr + decimalStr + fracStr return signStr + intStr + decimalStr + fracStr
} }
func RenderInteger(format string, n int) string { func renderInteger(format string, n int) string {
return RenderFloat(format, float64(n)) return renderFloat(format, float64(n))
} }
// Request.RemoteAddress contains port, which we want to remove i.e.: // Request.RemoteAddress contains port, which we want to remove i.e.:

View file

@ -29,7 +29,6 @@ import (
"io" "io"
"net/http" "net/http"
_ "github.com/PuerkitoBio/ghost/handlers"
"github.com/gorilla/mux" "github.com/gorilla/mux"
virustotal "github.com/dutchcoders/go-virustotal" virustotal "github.com/dutchcoders/go-virustotal"