为相关函数补充 context

This commit is contained in:
dragon
2025-08-18 16:28:13 +08:00
parent a385414e3f
commit d74e356975
10 changed files with 81 additions and 77 deletions

View File

@@ -1,3 +1,5 @@
GOPATH := $(shell go env GOPATH)
fmt:
@(test -f "$(GOPATH)/bin/gofumpt" || go install mvdan.cc/gofumpt@latest) && \
"$(GOPATH)/bin/gofumpt" -l -w .

View File

@@ -1,6 +1,7 @@
package core
import (
"context"
"encoding/json"
"fmt"
@@ -21,8 +22,8 @@ func NewDomainAlias(config utils.KVConfig) *DomainAlias {
return &DomainAlias{config: config}
}
func (a *DomainAlias) Query(domain string) (*Alias, error) {
get, err := a.config.Get("domain/alias/" + domain)
func (a *DomainAlias) Query(ctx context.Context, domain string) (*Alias, error) {
get, err := a.config.Get(ctx, "domain/alias/"+domain)
if err != nil {
return nil, err
}
@@ -33,14 +34,14 @@ func (a *DomainAlias) Query(domain string) (*Alias, error) {
return rel, nil
}
func (a *DomainAlias) Bind(domains []string, owner, repo, branch string) error {
func (a *DomainAlias) Bind(ctx context.Context, domains []string, owner, repo, branch string) error {
oldDomains := make([]string, 0)
rKey := fmt.Sprintf("domain/r-alias/%s/%s/%s", owner, repo, branch)
if oldStr, err := a.config.Get(rKey); err == nil {
if oldStr, err := a.config.Get(ctx, rKey); err == nil {
_ = json.Unmarshal([]byte(oldStr), &oldDomains)
}
for _, oldDomain := range oldDomains {
if err := a.Unbind(oldDomain); err != nil {
if err := a.Unbind(ctx, oldDomain); err != nil {
return err
}
}
@@ -54,15 +55,15 @@ func (a *DomainAlias) Bind(domains []string, owner, repo, branch string) error {
}
aliasMetaRaw, _ := json.Marshal(aliasMeta)
domainsRaw, _ := json.Marshal(domains)
_ = a.config.Put(rKey, string(domainsRaw), utils.TtlKeep)
_ = a.config.Put(ctx, rKey, string(domainsRaw), utils.TtlKeep)
for _, domain := range domains {
if err := a.config.Put("domain/alias/"+domain, string(aliasMetaRaw), utils.TtlKeep); err != nil {
if err := a.config.Put(ctx, "domain/alias/"+domain, string(aliasMetaRaw), utils.TtlKeep); err != nil {
return err
}
}
return nil
}
func (a *DomainAlias) Unbind(domain string) error {
return a.config.Delete("domain/alias/" + domain)
func (a *DomainAlias) Unbind(ctx context.Context, domain string) error {
return a.config.Delete(ctx, "domain/alias/"+domain)
}

View File

@@ -2,6 +2,7 @@ package core
import (
"bytes"
"context"
"encoding/json"
stdErr "errors"
"fmt"
@@ -25,11 +26,11 @@ type BranchInfo struct {
type Backend interface {
Close() error
// Repos return repo name + default branch
Repos(owner string) (map[string]string, error)
Repos(ctx context.Context, owner string) (map[string]string, error)
// Branches return branch + commit id
Branches(owner, repo string) (map[string]*BranchInfo, error)
Branches(ctx context.Context, owner, repo string) (map[string]*BranchInfo, error)
// Open return file or error (error)
Open(client *http.Client, owner, repo, commit, path string, headers http.Header) (*http.Response, error)
Open(ctx context.Context, client *http.Client, owner, repo, commit, path string, headers http.Header) (*http.Response, error)
}
type CacheBackend struct {
@@ -46,15 +47,15 @@ func NewCacheBackend(backend Backend, config utils.KVConfig, ttl time.Duration)
return &CacheBackend{backend: backend, config: config, ttl: ttl}
}
func (c *CacheBackend) Repos(owner string) (map[string]string, error) {
func (c *CacheBackend) Repos(ctx context.Context, owner string) (map[string]string, error) {
ret := make(map[string]string)
key := fmt.Sprintf("repos/%s", owner)
store, err := c.config.Get(key)
store, err := c.config.Get(ctx, key)
if err != nil {
ret, err = c.backend.Repos(owner)
ret, err = c.backend.Repos(ctx, owner)
if err != nil {
if errors.Is(err, os.ErrNotExist) {
_ = c.config.Put(key, "{}", c.ttl)
_ = c.config.Put(ctx, key, "{}", c.ttl)
}
return nil, err
}
@@ -62,7 +63,7 @@ func (c *CacheBackend) Repos(owner string) (map[string]string, error) {
if err != nil {
return nil, err
}
if err = c.config.Put(key, string(storeBin), c.ttl); err != nil {
if err = c.config.Put(ctx, key, string(storeBin), c.ttl); err != nil {
return nil, err
}
} else {
@@ -76,15 +77,15 @@ func (c *CacheBackend) Repos(owner string) (map[string]string, error) {
return ret, nil
}
func (c *CacheBackend) Branches(owner, repo string) (map[string]*BranchInfo, error) {
func (c *CacheBackend) Branches(ctx context.Context, owner, repo string) (map[string]*BranchInfo, error) {
ret := make(map[string]*BranchInfo)
key := fmt.Sprintf("branches/%s/%s", owner, repo)
data, err := c.config.Get(key)
data, err := c.config.Get(ctx, key)
if err != nil {
ret, err = c.backend.Branches(owner, repo)
ret, err = c.backend.Branches(ctx, owner, repo)
if err != nil {
if errors.Is(err, os.ErrNotExist) {
_ = c.config.Put(key, "{}", c.ttl)
_ = c.config.Put(ctx, key, "{}", c.ttl)
}
return nil, err
}
@@ -92,7 +93,7 @@ func (c *CacheBackend) Branches(owner, repo string) (map[string]*BranchInfo, err
if err != nil {
return nil, err
}
if err = c.config.Put(key, string(data), c.ttl); err != nil {
if err = c.config.Put(ctx, key, string(data), c.ttl); err != nil {
return nil, err
}
} else {
@@ -106,8 +107,8 @@ func (c *CacheBackend) Branches(owner, repo string) (map[string]*BranchInfo, err
return ret, nil
}
func (c *CacheBackend) Open(client *http.Client, owner, repo, commit, path string, headers http.Header) (*http.Response, error) {
return c.backend.Open(client, owner, repo, commit, path, headers)
func (c *CacheBackend) Open(ctx context.Context, client *http.Client, owner, repo, commit, path string, headers http.Header) (*http.Response, error) {
return c.backend.Open(ctx, client, owner, repo, commit, path, headers)
}
type CacheBackendBlobReader struct {
@@ -121,7 +122,7 @@ func NewCacheBackendBlobReader(client *http.Client, base Backend, cache utils.Ca
return &CacheBackendBlobReader{client: client, base: base, cache: cache, maxSize: maxCacheSize}
}
func (c *CacheBackendBlobReader) Open(owner, repo, commit, path string) (io.ReadCloser, error) {
func (c *CacheBackendBlobReader) Open(ctx context.Context, owner, repo, commit, path string) (io.ReadCloser, error) {
key := fmt.Sprintf("%s/%s/%s/%s", owner, repo, commit, path)
lastCache, err := c.cache.Get(key)
if err != nil && !errors.Is(err, os.ErrNotExist) {
@@ -132,7 +133,7 @@ func (c *CacheBackendBlobReader) Open(owner, repo, commit, path string) (io.Read
} else if lastCache != nil {
return lastCache, nil
}
open, err := c.base.Open(c.client, owner, repo, commit, path, http.Header{})
open, err := c.base.Open(ctx, c.client, owner, repo, commit, path, http.Header{})
if err != nil || open == nil {
if open != nil {
_ = open.Body.Close()

View File

@@ -1,6 +1,7 @@
package core
import (
"context"
"os"
"strings"
@@ -36,19 +37,19 @@ type PageDomainContent struct {
Path string
}
func (p *PageDomain) ParseDomainMeta(domain, path, branch string) (*PageDomainContent, error) {
func (p *PageDomain) ParseDomainMeta(ctx context.Context, domain, path, branch string) (*PageDomainContent, error) {
if branch == "" {
branch = p.defaultBranch
}
pathArr := strings.Split(strings.TrimPrefix(path, "/"), "/")
if !strings.HasSuffix(domain, "."+p.baseDomain) {
alias, err := p.alias.Query(domain) // 确定 alias 是否存在内容
alias, err := p.alias.Query(ctx, domain) // 确定 alias 是否存在内容
if err != nil {
zap.L().Warn("未知域名", zap.String("base", p.baseDomain), zap.String("domain", domain), zap.Error(err))
return nil, os.ErrNotExist
}
zap.L().Debug("命中别名", zap.String("domain", domain), zap.Any("alias", alias))
return p.ReturnMeta(alias.Owner, alias.Repo, alias.Branch, pathArr)
return p.ReturnMeta(ctx, alias.Owner, alias.Repo, alias.Branch, pathArr)
}
owner := strings.TrimSuffix(domain, "."+p.baseDomain)
repo := pathArr[0]
@@ -57,9 +58,9 @@ func (p *PageDomain) ParseDomainMeta(domain, path, branch string) (*PageDomainCo
if repo == "" {
// 回退到默认仓库 (路径未包含仓库)
zap.L().Debug("fail back to default repo", zap.String("repo", domain))
returnMeta, err = p.ReturnMeta(owner, domain, branch, pathArr)
returnMeta, err = p.ReturnMeta(ctx, owner, domain, branch, pathArr)
} else {
returnMeta, err = p.ReturnMeta(owner, repo, branch, pathArr[1:])
returnMeta, err = p.ReturnMeta(ctx, owner, repo, branch, pathArr[1:])
}
if err != nil && !errors.Is(err, os.ErrNotExist) {
return nil, err
@@ -67,17 +68,17 @@ func (p *PageDomain) ParseDomainMeta(domain, path, branch string) (*PageDomainCo
return returnMeta, nil
}
// 发现 repo 的情况下回退到默认页面
return p.ReturnMeta(owner, domain, branch, pathArr)
return p.ReturnMeta(ctx, owner, domain, branch, pathArr)
}
func (p *PageDomain) ReturnMeta(owner string, repo string, branch string, path []string) (*PageDomainContent, error) {
func (p *PageDomain) ReturnMeta(ctx context.Context, owner string, repo string, branch string, path []string) (*PageDomainContent, error) {
rel := &PageDomainContent{}
if meta, err := p.GetMeta(owner, repo, branch); err == nil {
if meta, err := p.GetMeta(ctx, owner, repo, branch); err == nil {
rel.PageMetaContent = meta
rel.Owner = owner
rel.Repo = repo
rel.Path = strings.Join(path, "/")
if err = p.alias.Bind(meta.Alias, rel.Owner, rel.Repo, branch); err != nil {
if err = p.alias.Bind(ctx, meta.Alias, rel.Owner, rel.Repo, branch); err != nil {
zap.L().Warn("别名绑定失败", zap.Error(err))
return nil, err
}

View File

@@ -1,6 +1,7 @@
package core
import (
"context"
"fmt"
"io"
"net/http"
@@ -39,9 +40,9 @@ func NewServerMeta(client *http.Client, backend Backend, kv utils.KVConfig, doma
return &ServerMeta{backend, domain, client, kv, ttl, utils.NewLocker()}
}
func (s *ServerMeta) GetMeta(owner, repo, branch string) (*PageMetaContent, error) {
func (s *ServerMeta) GetMeta(ctx context.Context, owner, repo, branch string) (*PageMetaContent, error) {
rel := NewPageMetaContent()
if repos, err := s.Repos(owner); err != nil {
if repos, err := s.Repos(ctx, owner); err != nil {
return nil, err
} else {
defBranch := repos[repo]
@@ -52,7 +53,7 @@ func (s *ServerMeta) GetMeta(owner, repo, branch string) (*PageMetaContent, erro
branch = defBranch
}
}
if branches, err := s.Branches(owner, repo); err != nil {
if branches, err := s.Branches(ctx, owner, repo); err != nil {
return nil, err
} else {
info := branches[branch]
@@ -64,7 +65,7 @@ func (s *ServerMeta) GetMeta(owner, repo, branch string) (*PageMetaContent, erro
}
key := fmt.Sprintf("meta/%s/%s/%s", owner, repo, branch)
cache, err := s.cache.Get(key)
cache, err := s.cache.Get(ctx, key)
if err != nil && !errors.Is(err, os.ErrNotExist) {
return nil, err
}
@@ -79,7 +80,7 @@ func (s *ServerMeta) GetMeta(owner, repo, branch string) (*PageMetaContent, erro
mux := s.locker.Open(key)
mux.Lock()
defer mux.Unlock()
cache, err = s.cache.Get(key)
cache, err = s.cache.Get(ctx, key)
if err == nil {
if err = rel.From(cache); err == nil {
if !rel.IsPage {
@@ -90,9 +91,9 @@ func (s *ServerMeta) GetMeta(owner, repo, branch string) (*PageMetaContent, erro
}
// 确定存在 index.html , 否则跳过
if find, _ := s.FileExists(owner, repo, rel.CommitID, "index.html"); !find {
if find, _ := s.FileExists(ctx, owner, repo, rel.CommitID, "index.html"); !find {
rel.IsPage = false
_ = s.cache.Put(key, rel.String(), s.ttl)
_ = s.cache.Put(ctx, key, rel.String(), s.ttl)
return nil, os.ErrNotExist
} else {
rel.IsPage = true
@@ -100,7 +101,7 @@ func (s *ServerMeta) GetMeta(owner, repo, branch string) (*PageMetaContent, erro
errFunc := func(err error) (*PageMetaContent, error) {
rel.IsPage = false
rel.ErrorMsg = err.Error()
_ = s.cache.Put(key, rel.String(), s.ttl)
_ = s.cache.Put(ctx, key, rel.String(), s.ttl)
return nil, err
}
// 添加默认跳过的内容
@@ -108,7 +109,7 @@ func (s *ServerMeta) GetMeta(owner, repo, branch string) (*PageMetaContent, erro
rel.ignoreL = append(rel.ignoreL, glob.MustCompile(defIgnore))
}
// 解析配置
if data, err := s.ReadString(owner, repo, rel.CommitID, ".pages.yaml"); err == nil {
if data, err := s.ReadString(ctx, owner, repo, rel.CommitID, ".pages.yaml"); err == nil {
cfg := new(PageConfig)
if err = yaml.Unmarshal([]byte(data), cfg); err != nil {
return errFunc(err)
@@ -172,7 +173,7 @@ func (s *ServerMeta) GetMeta(owner, repo, branch string) (*PageMetaContent, erro
}
// 兼容 github 的 CNAME 模式
if cname, err := s.ReadString(owner, repo, rel.CommitID, "CNAME"); err == nil {
if cname, err := s.ReadString(ctx, owner, repo, rel.CommitID, "CNAME"); err == nil {
cname = strings.TrimSpace(cname)
if regexpHostname.MatchString(cname) && !strings.HasSuffix(strings.ToLower(cname), strings.ToLower(s.Domain)) {
rel.Alias = append(rel.Alias, cname)
@@ -182,12 +183,12 @@ func (s *ServerMeta) GetMeta(owner, repo, branch string) (*PageMetaContent, erro
}
rel.Alias = utils.ClearDuplicates(rel.Alias)
rel.Ignore = utils.ClearDuplicates(rel.Ignore)
_ = s.cache.Put(key, rel.String(), s.ttl)
_ = s.cache.Put(ctx, key, rel.String(), s.ttl)
return rel, nil
}
func (s *ServerMeta) ReadString(owner, repo, branch, path string) (string, error) {
resp, err := s.Open(s.client, owner, repo, branch, path, nil)
func (s *ServerMeta) ReadString(ctx context.Context, owner, repo, branch, path string) (string, error) {
resp, err := s.Open(ctx, s.client, owner, repo, branch, path, nil)
if resp != nil {
defer resp.Body.Close()
}
@@ -204,8 +205,8 @@ func (s *ServerMeta) ReadString(owner, repo, branch, path string) (string, error
return string(all), nil
}
func (s *ServerMeta) FileExists(owner, repo, branch, path string) (bool, error) {
resp, err := s.Open(s.client, owner, repo, branch, path, nil)
func (s *ServerMeta) FileExists(ctx context.Context, owner, repo, branch, path string) (bool, error) {
resp, err := s.Open(ctx, s.client, owner, repo, branch, path, nil)
if resp != nil {
defer resp.Body.Close()
}

View File

@@ -1,6 +1,7 @@
package providers
import (
"context"
"net/http"
"net/url"
"os"
@@ -31,7 +32,7 @@ func NewGitea(url string, token string) (*ProviderGitea, error) {
}, nil
}
func (g *ProviderGitea) Repos(owner string) (map[string]string, error) {
func (g *ProviderGitea) Repos(ctx context.Context, owner string) (map[string]string, error) {
result := make(map[string]string)
if repos, resp, err := g.gitea.ListOrgRepos(owner, gitea.ListOrgReposOptions{
ListOptions: gitea.ListOptions{
@@ -74,7 +75,7 @@ func (g *ProviderGitea) Repos(owner string) (map[string]string, error) {
return result, nil
}
func (g *ProviderGitea) Branches(owner, repo string) (map[string]*core.BranchInfo, error) {
func (g *ProviderGitea) Branches(ctx context.Context, owner, repo string) (map[string]*core.BranchInfo, error) {
result := make(map[string]*core.BranchInfo)
if branches, resp, err := g.gitea.ListRepoBranches(owner, repo, gitea.ListRepoBranchesOptions{
ListOptions: gitea.ListOptions{
@@ -102,7 +103,7 @@ func (g *ProviderGitea) Branches(owner, repo string) (map[string]*core.BranchInf
return result, nil
}
func (g *ProviderGitea) Open(client *http.Client, owner, repo, commit, path string, headers http.Header) (*http.Response, error) {
func (g *ProviderGitea) Open(ctx context.Context, client *http.Client, owner, repo, commit, path string, headers http.Header) (*http.Response, error) {
giteaURL, err := url.JoinPath(g.BaseUrl, "api/v1/repos", owner, repo, "media", path)
if err != nil {
return nil, err

View File

@@ -110,8 +110,10 @@ func (s *Server) ServeHTTP(writer http.ResponseWriter, request *http.Request) {
}
func (s *Server) Serve(writer http.ResponseWriter, request *http.Request) error {
ctx := request.Context()
domainHost := portExp.ReplaceAllString(strings.ToLower(request.Host), "")
meta, err := s.meta.ParseDomainMeta(
ctx,
domainHost,
request.URL.Path,
request.URL.Query().Get("branch"))
@@ -167,19 +169,19 @@ func (s *Server) Serve(writer http.ResponseWriter, request *http.Request) error
zap.L().Debug("ignore path", zap.Any("request", request.RequestURI), zap.Any("meta.path", meta.Path))
err = os.ErrNotExist
} else {
result, err = s.reader.Open(meta.Owner, meta.Repo, meta.CommitID, meta.Path)
result, err = s.reader.Open(ctx, meta.Owner, meta.Repo, meta.CommitID, meta.Path)
}
if err != nil {
if errors.Is(err, os.ErrNotExist) {
if meta.VRoute {
// 回退 abc => index.html
result, err = s.reader.Open(meta.Owner, meta.Repo, meta.CommitID, "index.html")
result, err = s.reader.Open(ctx, meta.Owner, meta.Repo, meta.CommitID, "index.html")
if err == nil {
meta.Path = "index.html"
}
} else {
// 回退 abc => abc/ => abc/index.html
result, err = s.reader.Open(meta.Owner, meta.Repo, meta.CommitID, meta.Path+"/index.html")
result, err = s.reader.Open(ctx, meta.Owner, meta.Repo, meta.CommitID, meta.Path+"/index.html")
if err == nil {
meta.Path = strings.Trim(meta.Path+"/index.html", "/")
}
@@ -191,7 +193,7 @@ func (s *Server) Serve(writer http.ResponseWriter, request *http.Request) error
// 处理请求错误
if err != nil {
if errors.Is(err, os.ErrNotExist) {
result, err = s.reader.Open(meta.Owner, meta.Repo, meta.CommitID, "404.html")
result, err = s.reader.Open(ctx, meta.Owner, meta.Repo, meta.CommitID, "404.html")
if err != nil {
return err
}

View File

@@ -13,9 +13,9 @@ import (
const TtlKeep = -1
type KVConfig interface {
Put(key string, value string, ttl time.Duration) error
Get(key string) (string, error)
Delete(key string) error
Put(ctx context.Context, key string, value string, ttl time.Duration) error
Get(ctx context.Context, key string) (string, error)
Delete(ctx context.Context, key string) error
io.Closer
}
@@ -48,7 +48,7 @@ func NewAutoConfig(src string) (KVConfig, error) {
if err != nil {
return nil, err
}
return NewConfigRedis(context.Background(), parse.Host, pass, dbi)
return NewConfigRedis(parse.Host, pass, dbi)
default:
return nil, fmt.Errorf("unsupported scheme: %s", parse.Scheme)
}

View File

@@ -1,6 +1,7 @@
package utils
import (
"context"
"encoding/json"
"os"
"path/filepath"
@@ -52,7 +53,7 @@ type ConfigContent struct {
Ttl *time.Time `json:"ttl,omitempty"`
}
func (m *ConfigMemory) Put(key string, value string, ttl time.Duration) error {
func (m *ConfigMemory) Put(ctx context.Context, key string, value string, ttl time.Duration) error {
d := time.Now().Add(ttl)
td := &d
if ttl == -1 {
@@ -65,7 +66,7 @@ func (m *ConfigMemory) Put(key string, value string, ttl time.Duration) error {
return nil
}
func (m *ConfigMemory) Get(key string) (string, error) {
func (m *ConfigMemory) Get(ctx context.Context, key string) (string, error) {
if value, ok := m.data.Load(key); ok {
content := value.(ConfigContent)
if content.Ttl != nil && time.Now().After(*content.Ttl) {
@@ -76,7 +77,7 @@ func (m *ConfigMemory) Get(key string) (string, error) {
return "", os.ErrNotExist
}
func (m *ConfigMemory) Delete(key string) error {
func (m *ConfigMemory) Delete(ctx context.Context, key string) error {
m.data.Delete(key)
return nil
}

View File

@@ -13,11 +13,10 @@ import (
)
type ConfigRedis struct {
ctx context.Context
client valkey.Client
}
func NewConfigRedis(ctx context.Context, addr string, password string, db int) (*ConfigRedis, error) {
func NewConfigRedis(addr string, password string, db int) (*ConfigRedis, error) {
if addr == "" {
return nil, fmt.Errorf("addr is empty")
}
@@ -31,33 +30,28 @@ func NewConfigRedis(ctx context.Context, addr string, password string, db int) (
return nil, err
}
return &ConfigRedis{
ctx: ctx,
client: client,
}, nil
}
func (r *ConfigRedis) Put(key string, value string, ttl time.Duration) error {
func (r *ConfigRedis) Put(ctx context.Context, key string, value string, ttl time.Duration) error {
builder := r.client.B().Set().Key(key).Value(value)
if ttl != TtlKeep {
builder.Ex(ttl)
}
return r.client.Do(r.ctx, builder.Build()).Error()
return r.client.Do(ctx, builder.Build()).Error()
}
func (r *ConfigRedis) Get(key string) (string, error) {
v, err := r.client.Do(r.ctx, r.client.B().Get().Key(key).Build()).ToString()
func (r *ConfigRedis) Get(ctx context.Context, key string) (string, error) {
v, err := r.client.Do(ctx, r.client.B().Get().Key(key).Build()).ToString()
if err != nil && errors.Is(err, valkey.Nil) {
return "", os.ErrNotExist
}
return v, err
}
func (r *ConfigRedis) Delete(key string) error {
err := r.client.Do(r.ctx, r.client.B().Del().Key(key).Build()).Error()
if err != nil && errors.Is(err, valkey.Nil) {
return os.ErrNotExist
}
return nil
func (r *ConfigRedis) Delete(ctx context.Context, key string) error {
return r.client.Do(ctx, r.client.B().Del().Key(key).Build()).Error()
}
func (r *ConfigRedis) Close() error {