package main

import (
	"bufio"
	"bytes"
	"compress/flate"
	"compress/gzip"
	"context"
	"encoding/json"
	"fmt"
	"io"
	"log"
	"net/http"
	_ "net/http/pprof"
	"os"
	"os/signal"
	"strconv"
	"strings"
	"sync"
	"syscall"
	"time"

	redisprom "github.com/globocom/go-redis-prometheus"
	"github.com/go-redis/redis/v8"
	"github.com/gorilla/mux"
	"github.com/prometheus/client_golang/prometheus/promhttp"
	"github.com/tevino/abool/v2"
)

const (
	// ItemChannelBuffer is the capacity of each project's backfeed item channel.
	ItemChannelBuffer = 100000
	// ItemWrapSize caps how many items get batched into one redis pipeline.
	ItemWrapSize = 100000
)

// ProjectRedisConfig is the optional per-project redis connection config
// stored in the tracker's "trackers" hash.
type ProjectRedisConfig struct {
	Host string `json:"host"`
	Pass string `json:"pass"`
	Port int    `json:"port"`
}

// ProjectConfig is the per-project configuration blob from tracker redis.
type ProjectConfig struct {
	RedisConfig *ProjectRedisConfig `json:"redis,omitempty"`
}

// BackfeedItem is a single deduplication candidate, pre-sharded for the
// backfeed redis cluster.
type BackfeedItem struct {
	PrimaryShard   byte
	SecondaryShard string
	Item           []byte
}

// ProjectBackfeedManager owns one project's item channel and the goroutine
// (Do) that drains it into redis.
type ProjectBackfeedManager struct {
	Context       context.Context
	Cancel        context.CancelFunc
	Done          chan bool
	C             chan *BackfeedItem
	Name          string
	BackfeedRedis *redis.ClusterClient
	ProjectRedis  *redis.Client
	ProjectConfig ProjectConfig
}

// RedisConfigDiffers reports whether newConfig differs from the config this
// manager was created with (nil vs non-nil counts as a difference).
func (that *ProjectBackfeedManager) RedisConfigDiffers(newConfig *ProjectRedisConfig) bool {
	if that.ProjectConfig.RedisConfig == nil && newConfig == nil {
		return false
	}
	return that.ProjectConfig.RedisConfig == nil || newConfig == nil ||
		*that.ProjectConfig.RedisConfig != *newConfig
}

// PushItem enqueues item, blocking until it is accepted, the caller's ctx is
// done, or this manager has been shut down.
func (that *ProjectBackfeedManager) PushItem(ctx context.Context, item *BackfeedItem) error {
	select {
	case <-ctx.Done():
		return ctx.Err()
	case <-that.Context.Done():
		return fmt.Errorf("backfeed channel closed")
	case that.C <- item:
		return nil
	}
}

// PopItem dequeues one item. With blocking=false it returns immediately with
// ok=false when the channel is empty. ok=false also signals shutdown.
func (that *ProjectBackfeedManager) PopItem(blocking bool) (*BackfeedItem, bool) {
	if blocking {
		select {
		case <-that.Context.Done():
			return nil, false
		case item, ok := <-that.C:
			return item, ok
		}
	}
	select {
	case <-that.Context.Done():
		return nil, false
	case item, ok := <-that.C:
		return item, ok
	default:
		return nil, false
	}
}

// Do is the per-project worker loop: it batches items off the channel,
// deduplicates them via bf.madd on the backfeed cluster, SADDs fresh items to
// the project's "<name>:todo:backfeed" set, and counts dupes in the ":" hash.
//
// Cancellation is observed inside PopItem: once that.Context is done the
// blocking PopItem returns ok=false and the loop exits. (The original code
// also had `select { case <-...: break }` guards here, but `break` inside a
// select only exits the select, so they were dead code and have been removed.)
func (that *ProjectBackfeedManager) Do() {
	defer close(that.Done)
	defer that.Cancel()
	for {
		item, ok := that.PopItem(true)
		if !ok {
			break
		}
		keyMap := map[string][][]byte{}
		key := fmt.Sprintf("%s:%02x:%s", that.Name, item.PrimaryShard, item.SecondaryShard)
		keyMap[key] = append(keyMap[key], item.Item)
		wrapped := 1
		// Opportunistically drain more items without blocking so they share
		// one pipeline round trip.
		for wrapped < ItemWrapSize {
			item, ok := that.PopItem(false)
			if !ok {
				break
			}
			key := fmt.Sprintf("%s:%02x:%s", that.Name, item.PrimaryShard, item.SecondaryShard)
			keyMap[key] = append(keyMap[key], item.Item)
			wrapped++
		}
		now := time.Now()
		resultMap := map[string]*redis.Cmd{}
		pipe := that.BackfeedRedis.Pipeline()
		lastTS := make([]any, 0, len(keyMap)*2)
		for key := range keyMap {
			lastTS = append(lastTS, key, strconv.FormatInt(now.Unix(), 10))
		}
		pipe.HSet(context.Background(), ":last_ts", lastTS...)
		for key, items := range keyMap {
			args := make([]any, 0, len(items)+2)
			args = append(args, "bf.madd", key)
			for _, item := range items {
				args = append(args, item)
			}
			resultMap[key] = pipe.Do(context.Background(), args...)
		}
		if _, err := pipe.Exec(context.Background()); err != nil {
			log.Printf("%s", err)
		}
		var sAddItems []any
		for key, items := range keyMap {
			res, err := resultMap[key].BoolSlice()
			if err != nil {
				log.Printf("%s", err)
				continue
			}
			if len(res) != len(items) {
				continue
			}
			for i, v := range res {
				if v {
					// bf.madd returns true for items not seen before.
					sAddItems = append(sAddItems, items[i])
				}
			}
		}
		dupes := wrapped - len(sAddItems)
		if len(sAddItems) != 0 {
			if err := that.ProjectRedis.SAdd(context.Background(), fmt.Sprintf("%s:todo:backfeed", that.Name), sAddItems...).Err(); err != nil {
				log.Printf("failed to sadd items for %s: %s", that.Name, err)
			}
		}
		if dupes > 0 {
			that.BackfeedRedis.HIncrBy(context.Background(), ":", that.Name, int64(dupes))
		}
	}
}

// GlobalBackfeedManager tracks all per-project managers plus the slug->project
// routing table. Lock guards ActiveFeeds/ActiveSlugs against concurrent reads
// from HTTP handlers while RefreshFeeds mutates them.
type GlobalBackfeedManager struct {
	Context       context.Context
	Cancel        context.CancelFunc
	ActiveFeeds   map[string]*ProjectBackfeedManager
	ActiveSlugs   map[string]string
	TrackerRedis  *redis.Client
	BackfeedRedis *redis.ClusterClient
	Lock          sync.RWMutex
	Populated     *abool.AtomicBool
}

// RefreshFeeds synchronizes ActiveFeeds/ActiveSlugs with the "backfeed"
// (slug->project) and "trackers" (project->config) hashes in tracker redis:
// new projects get a manager, changed redis configs restart the manager, and
// vanished projects have theirs shut down.
//
// NOTE(review): ActiveFeeds is read here without the lock; that is only safe
// because RefreshFeeds is called from a single goroutine (main) — confirm if
// that ever changes.
func (that *GlobalBackfeedManager) RefreshFeeds() error {
	slugProjectMap, err := that.TrackerRedis.HGetAll(that.Context, "backfeed").Result()
	if err != nil {
		return err
	}
	var projects []string
	projectSlugMap := map[string][]string{}
	for slug, project := range slugProjectMap {
		projectSlugMap[project] = append(projectSlugMap[project], slug)
	}
	for project := range projectSlugMap {
		projects = append(projects, project)
	}
	projectConfigs := map[string]ProjectConfig{}
	if len(projects) != 0 {
		cfgi, err := that.TrackerRedis.HMGet(that.Context, "trackers", projects...).Result()
		if err != nil {
			return err
		}
		if len(projects) != len(cfgi) {
			return fmt.Errorf("hmget result had unexpected length")
		}
		for i, project := range projects {
			configString, ok := cfgi[i].(string)
			if !ok {
				continue
			}
			config := ProjectConfig{}
			if err := json.Unmarshal([]byte(configString), &config); err != nil {
				continue
			}
			projectConfigs[project] = config
		}
	}
	// Drop projects (and their slugs) that have no parseable config.
	projects = nil
	for project := range projectSlugMap {
		if _, has := projectConfigs[project]; !has {
			delete(projectSlugMap, project)
			continue
		}
		projects = append(projects, project)
	}
	for slug, project := range slugProjectMap {
		if _, has := projectConfigs[project]; !has {
			delete(slugProjectMap, slug)
		}
	}
	// Add feeds for new projects; restart feeds whose redis config changed.
	for _, project := range projects {
		projectConfig := projectConfigs[project]
		var outdatedProjectBackfeedManager *ProjectBackfeedManager
		if projectBackfeedManager, has := that.ActiveFeeds[project]; has {
			if projectBackfeedManager.RedisConfigDiffers(projectConfig.RedisConfig) {
				outdatedProjectBackfeedManager = projectBackfeedManager
			} else {
				continue
			}
		}
		ctx, cancel := context.WithCancel(that.Context)
		projectBackfeedManager := &ProjectBackfeedManager{
			Context:       ctx,
			Cancel:        cancel,
			Done:          make(chan bool),
			C:             make(chan *BackfeedItem, ItemChannelBuffer),
			BackfeedRedis: that.BackfeedRedis,
			Name:          project,
			ProjectConfig: projectConfig,
		}
		if projectConfig.RedisConfig != nil {
			projectBackfeedManager.ProjectRedis = redis.NewClient(&redis.Options{
				Addr:        fmt.Sprintf("%s:%d", projectConfig.RedisConfig.Host, projectConfig.RedisConfig.Port),
				Username:    "default",
				Password:    projectConfig.RedisConfig.Pass,
				ReadTimeout: 15 * time.Minute,
			})
		} else {
			projectBackfeedManager.ProjectRedis = that.TrackerRedis
		}
		go projectBackfeedManager.Do()
		that.Lock.Lock()
		that.ActiveFeeds[project] = projectBackfeedManager
		that.Lock.Unlock()
		if outdatedProjectBackfeedManager != nil {
			// Stop the stale manager only after the replacement is live.
			outdatedProjectBackfeedManager.Cancel()
			<-outdatedProjectBackfeedManager.Done
			log.Printf("updated project: %s", project)
		} else {
			log.Printf("added project: %s", project)
		}
	}
	that.Lock.Lock()
	that.ActiveSlugs = slugProjectMap
	that.Lock.Unlock()
	// Remove feeds for projects no longer in the tracker.
	for project, projectBackfeedManager := range that.ActiveFeeds {
		if _, has := projectSlugMap[project]; has {
			continue
		}
		log.Printf("removing project: %s", project)
		that.Lock.Lock()
		delete(that.ActiveFeeds, project)
		that.Lock.Unlock()
		projectBackfeedManager.Cancel()
		<-projectBackfeedManager.Done
		log.Printf("removed project: %s", project)
	}
	if !that.Populated.IsSet() {
		that.Populated.Set()
	}
	return nil
}

// Splitter is a bufio.SplitFunc factory that splits on an arbitrary
// delimiter. With IgnoreEOF set, a trailing undelimited chunk is emitted
// instead of being treated as an error.
type Splitter struct {
	Delimiter []byte
	IgnoreEOF bool
}

// Split implements bufio.SplitFunc.
//
// BUG FIX: the original compared data[i:i+len(Delimiter)] for every i, which
// reads past len(data) near the end of the buffer (panic or garbage compare).
// bytes.Index finds the same first full match safely; a partial match at the
// buffer end correctly falls through to "request more data".
func (that *Splitter) Split(data []byte, atEOF bool) (int, []byte, error) {
	if i := bytes.Index(data, that.Delimiter); i >= 0 {
		return i + len(that.Delimiter), data[:i], nil
	}
	if len(data) == 0 || !atEOF {
		return 0, nil, nil
	}
	if atEOF && that.IgnoreEOF {
		return len(data), data, nil
	}
	return 0, data, io.ErrUnexpectedEOF
}

// GenShardHash folds b into a single byte used as the primary shard key.
func GenShardHash(b []byte) (final byte) {
	for i, b := range b {
		final = (b ^ final ^ byte(i)) + final + byte(i) + final*byte(i)
	}
	return final
}

// WriteResponse writes v as a JSON envelope ({"data":...} or {"error":...})
// with the given status code. 204 responses carry no body.
func WriteResponse(res http.ResponseWriter, statusCode int, v any) {
	res.Header().Set("Content-Type", "application/json")
	res.WriteHeader(statusCode)
	if statusCode == http.StatusNoContent {
		return
	}
	if err, isError := v.(error); isError {
		v = map[string]any{
			"error":       fmt.Sprintf("%v", err),
			"status_code": statusCode,
		}
	} else {
		v = map[string]any{
			"data":        v,
			"status_code": statusCode,
		}
	}
	// Encode errors are ignored on purpose: the status line is already sent.
	json.NewEncoder(res).Encode(v)
}

// GetFeed resolves a slug to its project's backfeed manager, or nil.
func (that *GlobalBackfeedManager) GetFeed(slug string) *ProjectBackfeedManager {
	that.Lock.RLock()
	defer that.Lock.RUnlock()
	project, has := that.ActiveSlugs[slug]
	if !has {
		return nil
	}
	projectBackfeedManager, has := that.ActiveFeeds[project]
	if !has {
		return nil
	}
	return projectBackfeedManager
}

// LastAccessStatsKey identifies one "project:shard:subshard" last-access entry.
type LastAccessStatsKey struct {
	Project  string
	Shard    string
	SubShard string
}

// LastAccessStatsMap maps keys to the most recent access time.
type LastAccessStatsMap map[LastAccessStatsKey]time.Time

// MarshalJSON renders the map as {"project:shard:subshard": RFC3339 time}.
func (that LastAccessStatsMap) MarshalJSON() ([]byte, error) {
	mapped := make(map[string]string, len(that))
	for key, value := range that {
		mapped[fmt.Sprintf("%s:%s:%s", key.Project, key.Shard, key.SubShard)] = value.Format(time.RFC3339)
	}
	return json.Marshal(mapped)
}

// LastAccessStatsKeyFromString parses "project:shard:subshard" (the subshard
// may itself contain colons).
func LastAccessStatsKeyFromString(s string) (LastAccessStatsKey, error) {
	parts := strings.SplitN(s, ":", 3)
	if len(parts) != 3 {
		return LastAccessStatsKey{}, fmt.Errorf("invalid key: %s", s)
	}
	return LastAccessStatsKey{
		Project:  parts[0],
		Shard:    parts[1],
		SubShard: parts[2],
	}, nil
}

// HandleLastAccessStats serves the ":last_ts" hash, optionally collapsing the
// project/shard/sub_shard dimensions via ?merge= query parameters (keeping the
// newest timestamp per merged key).
func (that *GlobalBackfeedManager) HandleLastAccessStats(res http.ResponseWriter, req *http.Request) {
	defer req.Body.Close()
	merge := map[string]bool{}
	if vv, ok := req.URL.Query()["merge"]; ok {
		for _, v := range vv {
			merge[v] = true
		}
	}
	lastTs, err := that.BackfeedRedis.HGetAll(req.Context(), ":last_ts").Result()
	if err != nil {
		WriteResponse(res, http.StatusInternalServerError, err)
		return
	}
	lastAccessStats := LastAccessStatsMap{}
	for key, value := range lastTs {
		// Values are stored as unix timestamps.
		ts, err := strconv.ParseInt(value, 10, 64)
		if err != nil {
			WriteResponse(res, http.StatusInternalServerError, err)
			return
		}
		lastAccessStatsKey, err := LastAccessStatsKeyFromString(key)
		if err != nil {
			WriteResponse(res, http.StatusInternalServerError, err)
			return
		}
		if merge["project"] {
			lastAccessStatsKey.Project = "*"
		}
		if merge["shard"] {
			lastAccessStatsKey.Shard = "*"
		}
		if merge["sub_shard"] {
			lastAccessStatsKey.SubShard = "*"
		}
		parsedTs := time.Unix(ts, 0)
		if v, has := lastAccessStats[lastAccessStatsKey]; !has || v.Before(parsedTs) {
			lastAccessStats[lastAccessStatsKey] = parsedTs
		}
	}
	WriteResponse(res, http.StatusOK, lastAccessStats)
}

// HandleLegacy accepts a delimiter-separated item stream (optionally
// gzip/deflate compressed) and queues each item on the slug's backfeed
// channel.
func (that *GlobalBackfeedManager) HandleLegacy(res http.ResponseWriter, req *http.Request) {
	defer req.Body.Close()
	vars := mux.Vars(req)
	slug := vars["slug"]
	secondaryShard := req.URL.Query().Get("shard")
	projectBackfeedManager := that.GetFeed(slug)
	if projectBackfeedManager == nil {
		WriteResponse(res, http.StatusNotFound, fmt.Errorf("%s", "no such backfeed channel"))
		return
	}
	splitter := &Splitter{
		Delimiter: []byte(req.URL.Query().Get("delimiter")),
		IgnoreEOF: req.URL.Query().Get("ignoreeof") != "",
	}
	if len(splitter.Delimiter) == 0 {
		splitter.Delimiter = []byte{0x00}
	}
	var body io.ReadCloser
	switch req.Header.Get("Content-Encoding") {
	case "":
		body = req.Body
	case "gzip":
		var err error
		body, err = gzip.NewReader(req.Body)
		if err != nil {
			WriteResponse(res, http.StatusBadRequest, err)
			return
		}
		defer body.Close()
	case "deflate":
		body = flate.NewReader(req.Body)
		defer body.Close()
	default:
		WriteResponse(res, http.StatusBadRequest, fmt.Errorf("unsupported Content-Encoding: %s", req.Header.Get("Content-Encoding")))
		// BUG FIX: previously fell through here with a nil body.
		return
	}
	// BUG FIX: previously scanned req.Body, bypassing the decompressing
	// reader built above, so compressed payloads were split raw.
	scanner := bufio.NewScanner(body)
	scanner.Split(splitter.Split)
	// NOTE(review): scan errors are answered with 204 No Content, which
	// carries no error body — looks unintentional, but kept for
	// compatibility; confirm with API consumers before changing.
	statusCode := http.StatusNoContent
	n := 0
	for scanner.Scan() {
		b := scanner.Bytes()
		if len(b) == 0 {
			continue
		}
		// Copy out: scanner.Bytes() is only valid until the next Scan.
		bcopy := make([]byte, len(b))
		copy(bcopy, b)
		item := &BackfeedItem{
			PrimaryShard:   GenShardHash(bcopy),
			SecondaryShard: secondaryShard,
			Item:           bcopy,
		}
		if err := projectBackfeedManager.PushItem(req.Context(), item); err != nil {
			WriteResponse(res, http.StatusInternalServerError, err)
			return
		}
		n++
	}
	if err := scanner.Err(); err != nil {
		WriteResponse(res, statusCode, err)
		return
	}
	WriteResponse(res, http.StatusOK, fmt.Sprintf("%d items queued for deduplication", n))
}

// HandleHealth reports 503 until the first successful RefreshFeeds, then
// pings every backfeed redis shard.
func (that *GlobalBackfeedManager) HandleHealth(res http.ResponseWriter, req *http.Request) {
	if that.Populated.IsNotSet() {
		WriteResponse(res, http.StatusServiceUnavailable, fmt.Errorf("%s", "backfeed not populated"))
		return
	}
	if err := that.BackfeedRedis.ForEachShard(req.Context(), func(ctx context.Context, client *redis.Client) error {
		client.ClientGetName(ctx)
		return client.Ping(ctx).Err()
	}); err != nil {
		WriteResponse(res, http.StatusInternalServerError, fmt.Errorf("failed to ping backfeed redis: %s", err))
		return
	}
	WriteResponse(res, http.StatusOK, "ok")
}

// HandlePing is a trivial liveness endpoint.
func (that *GlobalBackfeedManager) HandlePing(res http.ResponseWriter, _ *http.Request) {
	WriteResponse(res, http.StatusOK, "pong")
}

// CancelAllFeeds stops every project manager and waits for each to drain.
// Called from main's shutdown path only, so ActiveFeeds is not locked here.
func (that *GlobalBackfeedManager) CancelAllFeeds() {
	that.Populated.UnSet()
	that.Cancel()
	for project, projectBackfeedManager := range that.ActiveFeeds {
		log.Printf("waiting for %s channel to shut down...", project)
		<-projectBackfeedManager.Done
		delete(that.ActiveFeeds, project)
	}
}

func main() {
	log.SetFlags(log.Flags() | log.Lshortfile)
	trackerRedisOptions, err := redis.ParseURL(os.Getenv("REDIS_TRACKER"))
	if err != nil {
		log.Panicf("invalid REDIS_TRACKER url: %s", err)
	}
	trackerRedisOptions.ReadTimeout = 15 * time.Minute
	trackerRedisClient := redis.NewClient(trackerRedisOptions)
	backfeedRedisClient := redis.NewClusterClient(&redis.ClusterOptions{
		Addrs:       strings.Split(os.Getenv("REDIS_BACKFEED_ADDRS"), ","),
		Username:    os.Getenv("REDIS_BACKFEED_USERNAME"),
		Password:    os.Getenv("REDIS_BACKFEED_PASSWORD"),
		ReadTimeout: 15 * time.Minute,
		PoolSize:    256,
	})
	backfeedRedisMetricsHook := redisprom.NewHook(
		redisprom.WithInstanceName("backfeed"),
	)
	backfeedRedisClient.AddHook(backfeedRedisMetricsHook)
	trackerRedisMetricsHook := redisprom.NewHook(
		redisprom.WithInstanceName("tracker"),
	)
	trackerRedisClient.AddHook(trackerRedisMetricsHook)
	if err := trackerRedisClient.Ping(context.Background()).Err(); err != nil {
		log.Panicf("unable to ping tracker redis: %s", err)
	}
	if err := backfeedRedisClient.Ping(context.Background()).Err(); err != nil {
		log.Panicf("unable to ping backfeed redis: %s", err)
	}
	// BUG FIX: this error was previously assigned to err but never checked.
	if err := backfeedRedisClient.ForEachShard(context.Background(), func(ctx context.Context, client *redis.Client) error {
		client.ClientGetName(ctx)
		return client.Ping(ctx).Err()
	}); err != nil {
		log.Panicf("unable to ping backfeed redis shards: %s", err)
	}
	globalBackfeedManager := &GlobalBackfeedManager{
		ActiveFeeds:   map[string]*ProjectBackfeedManager{},
		ActiveSlugs:   map[string]string{},
		TrackerRedis:  trackerRedisClient,
		BackfeedRedis: backfeedRedisClient,
		Populated:     abool.New(),
	}
	globalBackfeedManager.Context, globalBackfeedManager.Cancel = context.WithCancel(context.Background())
	defer globalBackfeedManager.CancelAllFeeds()
	if err := globalBackfeedManager.RefreshFeeds(); err != nil {
		log.Panicf("unable to set up backfeed projects: %s", err)
	}
	r := mux.NewRouter()
	r.Methods(http.MethodPost).Path("/legacy/{slug}").HandlerFunc(globalBackfeedManager.HandleLegacy)
	r.Methods(http.MethodGet).Path("/ping").HandlerFunc(globalBackfeedManager.HandlePing)
	r.Methods(http.MethodGet).Path("/health").HandlerFunc(globalBackfeedManager.HandleHealth)
	r.Methods(http.MethodGet).Path("/lastaccessstats").HandlerFunc(globalBackfeedManager.HandleLastAccessStats)
	rMetrics := mux.NewRouter()
	rMetrics.PathPrefix("/debug/pprof/").Handler(http.DefaultServeMux)
	rMetrics.Path("/metrics").Handler(promhttp.Handler())
	doneChan := make(chan bool)
	serveErrChan := make(chan error)
	go func() {
		s := &http.Server{
			Addr:           os.Getenv("HTTP_ADDR"),
			IdleTimeout:    1 * time.Hour,
			MaxHeaderBytes: 1 * 1024 * 1024,
			Handler:        r,
		}
		serveErrChan <- s.ListenAndServe()
	}()
	metricsErrChan := make(chan error)
	go func() {
		if os.Getenv("METRICS_ADDR") != "" {
			s := &http.Server{
				Addr:           os.Getenv("METRICS_ADDR"),
				IdleTimeout:    1 * time.Hour,
				MaxHeaderBytes: 1 * 1024 * 1024,
				Handler:        rMetrics,
			}
			metricsErrChan <- s.ListenAndServe()
		} else {
			// NOTE(review): doneChan is never closed, so this goroutine
			// parks until process exit; harmless, but worth confirming.
			<-doneChan
			metricsErrChan <- nil
		}
	}()
	log.Printf("backfeed listening on %s", os.Getenv("HTTP_ADDR"))
	if os.Getenv("METRICS_ADDR") != "" {
		log.Printf("metrics/debug listening on %s", os.Getenv("METRICS_ADDR"))
	}
	sc := make(chan os.Signal, 1)
	// os.Kill (SIGKILL) was removed from this list: it cannot be caught.
	signal.Notify(sc, syscall.SIGINT, syscall.SIGTERM, os.Interrupt)
	ticker := time.NewTicker(1 * time.Second)
	defer ticker.Stop()
	for {
		select {
		case <-sc:
			return
		case <-ticker.C:
		}
		if err := globalBackfeedManager.RefreshFeeds(); err != nil {
			log.Printf("unable to refresh backfeed projects: %s", err)
		}
	}
}