diff --git a/cmd/cmd.go b/cmd/cmd.go index 9163071..5f6287b 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -95,6 +95,12 @@ var globalFlags = []cli.Flag{ Value: "", EnvVar: "BUCKET", }, + cli.IntFlag{ + Name: "rate-limit", + Usage: "requests per minute", + Value: 0, + EnvVar: "", + }, cli.StringFlag{ Name: "lets-encrypt-hosts", Usage: "host1, host2", @@ -190,6 +196,10 @@ func New() *Cmd { options = append(options, server.ClamavHost(v)) } + if v := c.Int("rate-limit"); v > 0 { + options = append(options, server.RateLimit(v)) + } + if cert := c.String("tls-cert-file"); cert == "" { } else if pk := c.String("tls-private-key"); pk == "" { } else { diff --git a/lock.json b/lock.json index 5c69d39..fa58056 100644 --- a/lock.json +++ b/lock.json @@ -1,5 +1,5 @@ { - "memo": "332a50078a5c89ced2186b6c9b5e55af4ac02ba87d5990e840080683f703ca9a", + "memo": "5b27aecb0272e40f3b8b8f9a5deeb7c9f5dbf06c53b1134de2b84eac466d27e0", "projects": [ { "name": "github.com/PuerkitoBio/ghost", @@ -10,6 +10,15 @@ "handlers" ] }, + { + "name": "github.com/VojtechVitek/ratelimit", + "branch": "master", + "revision": "dc172bc0f6d241e980010dbc63957ef1a2c8ca33", + "packages": [ + ".", + "memory" + ] + }, { "name": "github.com/dutchcoders/go-clamd", "branch": "master", @@ -100,14 +109,6 @@ "." ] }, - { - "name": "github.com/kennygrant/sanitize", - "version": "v1.2", - "revision": "6a0bfdde8629a3a3a7418a7eae45c54154692514", - "packages": [ - "." - ] - }, { "name": "github.com/mattn/go-colorable", "version": "v0.0.7", @@ -179,9 +180,7 @@ "revision": "a6577fac2d73be281a500b310739095313165611", "packages": [ "context", - "context/ctxhttp", - "html", - "html/atom" + "context/ctxhttp" ] }, { diff --git a/server/server.go b/server/server.go index 40cd5c9..c750b46 100644 --- a/server/server.go +++ b/server/server.go @@ -42,6 +42,8 @@ import ( context "golang.org/x/net/context" "github.com/PuerkitoBio/ghost/handlers" + "github.com/VojtechVitek/ratelimit" + "github.com/VojtechVitek/ratelimit/memory" "github.com/gorilla/mux" _ "net/http/pprof" @@ -116,6 +118,12 @@ func LogFile(s string) OptionFn { } } +func RateLimit(requests int) OptionFn { + return func(srvr *Server) { + srvr.rateLimitRequests = requests + } +} + func ForceHTTPs() OptionFn { return func(srvr *Server) { srvr.forceHTTPs = true @@ -180,6 +188,8 @@ type Server struct { locks map[string]*sync.Mutex + rateLimitRequests int + storage Storage forceHTTPs bool @@ -267,10 +277,12 @@ func (s *Server) Run() { r.PathPrefix("/favicon.ico").Handler(staticHandler) r.PathPrefix("/robots.txt").Handler(staticHandler) + r.HandleFunc("/health.html", healthHandler).Methods("GET") + r.HandleFunc("/", s.viewHandler).Methods("GET") + r.HandleFunc("/({files:.*}).zip", s.zipHandler).Methods("GET") r.HandleFunc("/({files:.*}).tar", s.tarHandler).Methods("GET") r.HandleFunc("/({files:.*}).tar.gz", s.tarGzHandler).Methods("GET") - r.HandleFunc("/download/{token}/{filename}", s.getHandler).Methods("GET") r.HandleFunc("/{token}/{filename}", s.previewHandler).MatcherFunc(func(r *http.Request, rm *mux.RouteMatch) (match bool) { match = false @@ -294,17 +306,22 @@ func (s *Server) Run() { return }).Methods("GET") - r.HandleFunc("/{token}/{filename}", s.getHandler).Methods("GET") - r.HandleFunc("/get/{token}/{filename}", s.getHandler).Methods("GET") + getHandlerFn := s.getHandler + if s.rateLimitRequests > 0 { + getHandlerFn = ratelimit.Request(ratelimit.IP).Rate(s.rateLimitRequests, 60*time.Second).LimitBy(memory.New())(http.HandlerFunc(getHandlerFn)).ServeHTTP + } + + r.HandleFunc("/{token}/{filename}", getHandlerFn).Methods("GET") + r.HandleFunc("/get/{token}/{filename}", getHandlerFn).Methods("GET") + r.HandleFunc("/download/{token}/{filename}", getHandlerFn).Methods("GET") + r.HandleFunc("/{filename}/virustotal", s.virusTotalHandler).Methods("PUT") r.HandleFunc("/{filename}/scan", s.scanHandler).Methods("PUT") r.HandleFunc("/put/{filename}", s.putHandler).Methods("PUT") r.HandleFunc("/upload/{filename}", s.putHandler).Methods("PUT") r.HandleFunc("/{filename}", s.putHandler).Methods("PUT") - r.HandleFunc("/health.html", healthHandler).Methods("GET") r.HandleFunc("/", s.postHandler).Methods("POST") // r.HandleFunc("/{page}", viewHandler).Methods("GET") - r.HandleFunc("/", s.viewHandler).Methods("GET") r.NotFoundHandler = http.HandlerFunc(s.notFoundHandler)