// Copyright 2023 The Gitea Authors. All rights reserved. // SPDX-License-Identifier: MIT package poll import ( "context" "errors" "fmt" "sync" "sync/atomic" runnerv1 "code.gitea.io/actions-proto-go/runner/v1" "connectrpc.com/connect" log "github.com/sirupsen/logrus" "golang.org/x/time/rate" "gitea.com/gitea/act_runner/internal/app/run" "gitea.com/gitea/act_runner/internal/pkg/client" "gitea.com/gitea/act_runner/internal/pkg/config" ) const PollerID = "PollerID" type Poller interface { Poll() Shutdown(ctx context.Context) error } type poller struct { client client.Client runner run.RunnerInterface cfg *config.Config tasksVersion atomic.Int64 // tasksVersion used to store the version of the last task fetched from the Gitea. pollingCtx context.Context shutdownPolling context.CancelFunc jobsCtx context.Context shutdownJobs context.CancelFunc done chan any } func New(cfg *config.Config, client client.Client, runner run.RunnerInterface) Poller { return (&poller{}).init(cfg, client, runner) } func (p *poller) init(cfg *config.Config, client client.Client, runner run.RunnerInterface) Poller { pollingCtx, shutdownPolling := context.WithCancel(context.Background()) jobsCtx, shutdownJobs := context.WithCancel(context.Background()) done := make(chan any) p.client = client p.runner = runner p.cfg = cfg p.pollingCtx = pollingCtx p.shutdownPolling = shutdownPolling p.jobsCtx = jobsCtx p.shutdownJobs = shutdownJobs p.done = done return p } func (p *poller) Poll() { limiter := rate.NewLimiter(rate.Every(p.cfg.Runner.FetchInterval), 1) wg := &sync.WaitGroup{} for i := 0; i < p.cfg.Runner.Capacity; i++ { wg.Add(1) go p.poll(i, wg, limiter) } wg.Wait() // signal the poller is finished close(p.done) } func (p *poller) Shutdown(ctx context.Context) error { p.shutdownPolling() select { case <-p.done: log.Trace("all jobs are complete") return nil case <-ctx.Done(): log.Trace("forcing the jobs to shutdown") p.shutdownJobs() <-p.done log.Trace("all jobs have been shutdown") return ctx.Err() } } func (p *poller) poll(id int, wg *sync.WaitGroup, limiter *rate.Limiter) { log.Infof("[poller %d] launched", id) defer wg.Done() for { if err := limiter.Wait(p.pollingCtx); err != nil { log.Infof("[poller %d] shutdown", id) return } task, ok := p.fetchTask(p.pollingCtx) if !ok { continue } p.runTaskWithRecover(p.jobsCtx, task) } } func (p *poller) runTaskWithRecover(ctx context.Context, task *runnerv1.Task) { defer func() { if r := recover(); r != nil { err := fmt.Errorf("panic: %v", r) log.WithError(err).Error("panic in runTaskWithRecover") } }() if err := p.runner.Run(ctx, task); err != nil { log.WithError(err).Error("failed to run task") } } func (p *poller) fetchTask(ctx context.Context) (*runnerv1.Task, bool) { reqCtx, cancel := context.WithTimeout(ctx, p.cfg.Runner.FetchTimeout) defer cancel() // Load the version value that was in the cache when the request was sent. v := p.tasksVersion.Load() resp, err := p.client.FetchTask(reqCtx, connect.NewRequest(&runnerv1.FetchTaskRequest{ TasksVersion: v, })) if errors.Is(err, context.DeadlineExceeded) { log.Trace("deadline exceeded") err = nil } if err != nil { if errors.Is(err, context.Canceled) { log.WithError(err).Debugf("shutdown, fetch task canceled") } else { log.WithError(err).Error("failed to fetch task") } return nil, false } if resp == nil || resp.Msg == nil { return nil, false } if resp.Msg.TasksVersion > v { p.tasksVersion.CompareAndSwap(v, resp.Msg.TasksVersion) } if resp.Msg.Task == nil { return nil, false } // got a task, set `tasksVersion` to zero to focre query db in next request. p.tasksVersion.CompareAndSwap(resp.Msg.TasksVersion, 0) return resp.Msg.Task, true }