From 268f21c2af46e317a4053a43e0df72badf535602 Mon Sep 17 00:00:00 2001 From: ExplodingDragon Date: Tue, 18 Nov 2025 23:56:59 +0800 Subject: [PATCH] =?UTF-8?q?=E4=B8=BA=20filter=20=E6=B7=BB=E5=8A=A0?= =?UTF-8?q?=E5=85=A8=E5=B1=80=E9=85=8D=E7=BD=AE=E6=94=AF=E6=8C=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- cmd/local/main.go | 7 ++- cmd/server/config.go | 5 +- cmd/server/main.go | 6 ++- pkg/core/filter.go | 17 +++--- pkg/filters/block.go | 40 +++++++------- pkg/filters/common.go | 26 ++++++++- pkg/filters/default.go | 40 +++++++------- pkg/filters/direct.go | 100 +++++++++++++++++----------------- pkg/filters/failback.go | 58 ++++++++++---------- pkg/filters/goja/goja.go | 114 ++++++++++++++++++++------------------- pkg/filters/proxy.go | 60 +++++++++++---------- pkg/filters/redirect.go | 72 +++++++++++++------------ pkg/filters/template.go | 60 +++++++++++---------- pkg/server.go | 13 +++-- tests/core/test.go | 7 ++- 15 files changed, 340 insertions(+), 285 deletions(-) diff --git a/cmd/local/main.go b/cmd/local/main.go index 9d035a2..98564a4 100644 --- a/cmd/local/main.go +++ b/cmd/local/main.go @@ -56,7 +56,7 @@ func main() { if err != nil { zap.L().Fatal("failed to init memory provider", zap.Error(err)) } - server := pkg.NewPageServer(http.DefaultClient, + server, err := pkg.NewPageServer(http.DefaultClient, provider, domain, "gh-pages", memory, memory, 0, &nopCache{}, func(w http.ResponseWriter, r *http.Request, err error) { if errors.Is(err, os.ErrNotExist) { @@ -64,7 +64,10 @@ func main() { } else if err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) } - }) + }, make(map[string]map[string]any)) + if err != nil { + zap.L().Fatal("failed to init page", zap.Error(err)) + } err = http.ListenAndServe(port, server) if err != nil && !errors.Is(err, http.ErrServerClosed) { zap.L().Fatal("failed to start server", zap.Error(err)) diff --git a/cmd/server/config.go b/cmd/server/config.go index 6ee614e..8bc9dfa 100644 --- a/cmd/server/config.go +++ b/cmd/server/config.go @@ -29,11 +29,10 @@ type Config struct { Page ConfigPage `yaml:"page"` // 页面配置 - Render ConfigRender `yaml:"render"` // 渲染配置 - Proxy ConfigProxy `yaml:"proxy"` // 反向代理配置 - StaticDir string `yaml:"static"` // 静态资源提供路径 + Filters map[string]map[string]any `yaml:"filters"` // 渲染器配置 + pageErrNotFound, pageErrUnknown *template.Template } diff --git a/cmd/server/main.go b/cmd/server/main.go index 60e0a9d..dda9661 100644 --- a/cmd/server/main.go +++ b/cmd/server/main.go @@ -64,7 +64,7 @@ func main() { if !ok { log.Fatalln(errors.New("database not support cursor")) } - pageServer := pkg.NewPageServer( + pageServer, err := pkg.NewPageServer( http.DefaultClient, backend, config.Domain, @@ -74,7 +74,11 @@ func main() { config.Cache.MetaTTL, cacheBlob.Child("filter"), config.ErrorHandler, + config.Filters, ) + if err != nil { + log.Fatalln(err) + } ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM, syscall.SIGQUIT) defer stop() diff --git a/pkg/core/filter.go b/pkg/core/filter.go index 002c67e..80e27c2 100644 --- a/pkg/core/filter.go +++ b/pkg/core/filter.go @@ -22,14 +22,14 @@ type FilterContext struct { RepoDB kv.CursorPagedKV } -type FilterParams map[string]any +type Params map[string]any -func (f FilterParams) String() string { +func (f Params) String() string { marshal, _ := json.Marshal(f) return strings.ReplaceAll(string(marshal), "\"", "'") } -func (f FilterParams) Unmarshal(target any) error { +func (f Params) Unmarshal(target any) error { marshal, err := json.Marshal(f) if err != nil { return err @@ -38,9 +38,9 @@ func (f FilterParams) Unmarshal(target any) error { } type Filter struct { - Path string `json:"path"` - Type string `json:"type"` - Params FilterParams `json:"params"` + Path string `json:"path"` + Type string `json:"type"` + Params Params `json:"params"` } func NextCallWrapper(call FilterCall, parentCall NextCall, stack Filter) NextCall { @@ -69,4 +69,7 @@ type FilterCall func( next NextCall, ) error -type FilterInstance func(config FilterParams) (FilterCall, error) +type ( + GlobalFilter func(config Params) (FilterInstance, error) + FilterInstance func(route Params) (FilterCall, error) +) diff --git a/pkg/filters/block.go b/pkg/filters/block.go index bef006e..5ccc32f 100644 --- a/pkg/filters/block.go +++ b/pkg/filters/block.go @@ -6,25 +6,27 @@ import ( "gopkg.d7z.net/gitea-pages/pkg/core" ) -var FilterInstBlock core.FilterInstance = func(config core.FilterParams) (core.FilterCall, error) { - var param struct { - Code int `json:"code"` - Message string `json:"message"` - } - if err := config.Unmarshal(¶m); nil != err { - return nil, err - } - if param.Code == 0 { - param.Code = http.StatusForbidden - } - if param.Message == "" { - param.Message = http.StatusText(param.Code) - } - return func(ctx core.FilterContext, writer http.ResponseWriter, request *http.Request, next core.NextCall) error { - writer.WriteHeader(param.Code) - if param.Message != "" { - _, _ = writer.Write([]byte(param.Message)) +func FilterInstBlock(_ core.Params) (core.FilterInstance, error) { + return func(config core.Params) (core.FilterCall, error) { + var param struct { + Code int `json:"code"` + Message string `json:"message"` } - return nil + if err := config.Unmarshal(¶m); nil != err { + return nil, err + } + if param.Code == 0 { + param.Code = http.StatusForbidden + } + if param.Message == "" { + param.Message = http.StatusText(param.Code) + } + return func(ctx core.FilterContext, writer http.ResponseWriter, request *http.Request, next core.NextCall) error { + writer.WriteHeader(param.Code) + if param.Message != "" { + _, _ = writer.Write([]byte(param.Message)) + } + return nil + }, nil }, nil } diff --git a/pkg/filters/common.go b/pkg/filters/common.go index 39fbf1b..b9e5d6d 100644 --- a/pkg/filters/common.go +++ b/pkg/filters/common.go @@ -1,12 +1,19 @@ package filters import ( + "errors" + + "go.uber.org/zap" "gopkg.d7z.net/gitea-pages/pkg/core" "gopkg.d7z.net/gitea-pages/pkg/filters/goja" ) -func DefaultFilters() map[string]core.FilterInstance { - return map[string]core.FilterInstance{ +func DefaultFilters(config map[string]map[string]any) (map[string]core.FilterInstance, error) { + if config == nil { + return nil, errors.New("config is nil") + } + result := make(map[string]core.FilterInstance) + for key, instance := range map[string]core.GlobalFilter{ "block": FilterInstBlock, "redirect": FilterInstRedirect, "direct": FilterInstDirect, @@ -15,5 +22,20 @@ func DefaultFilters() map[string]core.FilterInstance { "failback": FilterInstFailback, "template": FilterInstTemplate, "js": goja.FilterInstGoJa, + } { + item, ok := config[key] + if !ok { + item = make(map[string]any) + } + if item["_disable"] == true { + zap.L().Debug("skip filter", zap.String("key", key)) + continue + } + inst, err := instance(item) + if err != nil { + return nil, err + } + result[key] = inst } + return result, nil } diff --git a/pkg/filters/default.go b/pkg/filters/default.go index 0770d67..887e4ca 100644 --- a/pkg/filters/default.go +++ b/pkg/filters/default.go @@ -9,25 +9,27 @@ import ( "gopkg.d7z.net/gitea-pages/pkg/core" ) -var FilterInstDefaultNotFound core.FilterInstance = func(config core.FilterParams) (core.FilterCall, error) { - return func(ctx core.FilterContext, writer http.ResponseWriter, request *http.Request, next core.NextCall) error { - err := next(ctx, writer, request) - if err != nil && errors.Is(err, os.ErrNotExist) { - open, err := ctx.NativeOpen(ctx, "/404.html", nil) - if open != nil { - defer open.Body.Close() +func FilterInstDefaultNotFound(_ core.Params) (core.FilterInstance, error) { + return func(config core.Params) (core.FilterCall, error) { + return func(ctx core.FilterContext, writer http.ResponseWriter, request *http.Request, next core.NextCall) error { + err := next(ctx, writer, request) + if err != nil && errors.Is(err, os.ErrNotExist) { + open, err := ctx.NativeOpen(ctx, "/404.html", nil) + if open != nil { + defer open.Body.Close() + } + if err != nil { + return err + } + writer.Header().Set("Content-Type", "text/html; charset=utf-8") + if l := open.Header.Get("Content-Length"); l != "" { + writer.Header().Set("Content-Length", l) + } + writer.WriteHeader(http.StatusNotFound) + _, _ = io.Copy(writer, open.Body) + return nil } - if err != nil { - return err - } - writer.Header().Set("Content-Type", "text/html; charset=utf-8") - if l := open.Header.Get("Content-Length"); l != "" { - writer.Header().Set("Content-Length", l) - } - writer.WriteHeader(http.StatusNotFound) - _, _ = io.Copy(writer, open.Body) - return nil - } - return err + return err + }, nil }, nil } diff --git a/pkg/filters/direct.go b/pkg/filters/direct.go index cadd6d8..6208448 100644 --- a/pkg/filters/direct.go +++ b/pkg/filters/direct.go @@ -14,59 +14,61 @@ import ( "gopkg.d7z.net/gitea-pages/pkg/core" ) -var FilterInstDirect core.FilterInstance = func(config core.FilterParams) (core.FilterCall, error) { - var param struct { - Prefix string `json:"prefix"` - } - if err := config.Unmarshal(¶m); err != nil { - return nil, err - } - param.Prefix = strings.Trim(param.Prefix, "/") + "/" - return func(ctx core.FilterContext, writer http.ResponseWriter, request *http.Request, next core.NextCall) error { - err := next(ctx, writer, request) - if (err != nil && !errors.Is(err, os.ErrNotExist)) || err == nil { - return err +func FilterInstDirect(_ core.Params) (core.FilterInstance, error) { + return func(config core.Params) (core.FilterCall, error) { + var param struct { + Prefix string `json:"prefix"` } - if request.Method != http.MethodHead && request.Method != http.MethodGet { - http.Error(writer, "Method not allowed", http.StatusMethodNotAllowed) - return nil + if err := config.Unmarshal(¶m); err != nil { + return nil, err } - var resp *http.Response - var path string - defaultPath := param.Prefix + strings.TrimSuffix(ctx.Path, "/") - for _, p := range []string{defaultPath, defaultPath + "/index.html"} { - zap.L().Debug("direct fetch", zap.String("path", p)) - resp, err = ctx.NativeOpen(request.Context(), p, nil) - if err != nil { - if resp != nil { - resp.Body.Close() - } - if !errors.Is(err, os.ErrNotExist) { - zap.L().Debug("error", zap.Any("error", err)) - return err - } - continue + param.Prefix = strings.Trim(param.Prefix, "/") + "/" + return func(ctx core.FilterContext, writer http.ResponseWriter, request *http.Request, next core.NextCall) error { + err := next(ctx, writer, request) + if (err != nil && !errors.Is(err, os.ErrNotExist)) || err == nil { + return err } - path = p - break - } - if resp == nil { - return os.ErrNotExist - } - defer resp.Body.Close() - if err != nil { - return err - } - - writer.Header().Set("Content-Type", mime.TypeByExtension(filepath.Ext(path))) - lastMod, err := time.Parse(http.TimeFormat, resp.Header.Get("Last-Modified")) - if err == nil { - if seeker, ok := resp.Body.(io.ReadSeeker); ok { - http.ServeContent(writer, request, filepath.Base(path), lastMod, seeker) + if request.Method != http.MethodHead && request.Method != http.MethodGet { + http.Error(writer, "Method not allowed", http.StatusMethodNotAllowed) return nil } - } - _, err = io.Copy(writer, resp.Body) - return err + var resp *http.Response + var path string + defaultPath := param.Prefix + strings.TrimSuffix(ctx.Path, "/") + for _, p := range []string{defaultPath, defaultPath + "/index.html"} { + zap.L().Debug("direct fetch", zap.String("path", p)) + resp, err = ctx.NativeOpen(request.Context(), p, nil) + if err != nil { + if resp != nil { + resp.Body.Close() + } + if !errors.Is(err, os.ErrNotExist) { + zap.L().Debug("error", zap.Any("error", err)) + return err + } + continue + } + path = p + break + } + if resp == nil { + return os.ErrNotExist + } + defer resp.Body.Close() + if err != nil { + return err + } + + writer.Header().Set("Content-Type", mime.TypeByExtension(filepath.Ext(path))) + lastMod, err := time.Parse(http.TimeFormat, resp.Header.Get("Last-Modified")) + if err == nil { + if seeker, ok := resp.Body.(io.ReadSeeker); ok { + http.ServeContent(writer, request, filepath.Base(path), lastMod, seeker) + return nil + } + } + _, err = io.Copy(writer, resp.Body) + return err + }, nil }, nil } diff --git a/pkg/filters/failback.go b/pkg/filters/failback.go index 31ebded..c7cd0a8 100644 --- a/pkg/filters/failback.go +++ b/pkg/filters/failback.go @@ -12,37 +12,39 @@ import ( "gopkg.d7z.net/gitea-pages/pkg/core" ) -var FilterInstFailback core.FilterInstance = func(config core.FilterParams) (core.FilterCall, error) { - var param struct { - Path string `json:"path"` - } - if err := config.Unmarshal(¶m); err != nil { - return nil, err - } - if param.Path == "" { - return nil, errors.Errorf("filter failback: path is empty") - } - return func(ctx core.FilterContext, writer http.ResponseWriter, request *http.Request, next core.NextCall) error { - err := next(ctx, writer, request) - if (err != nil && !errors.Is(err, os.ErrNotExist)) || err == nil { - return err +func FilterInstFailback(_ core.Params) (core.FilterInstance, error) { + return func(config core.Params) (core.FilterCall, error) { + var param struct { + Path string `json:"path"` } - resp, err := ctx.NativeOpen(ctx, param.Path, nil) - if resp != nil { - defer resp.Body.Close() + if err := config.Unmarshal(¶m); err != nil { + return nil, err } - if err != nil { - return err + if param.Path == "" { + return nil, errors.Errorf("filter failback: path is empty") } - writer.Header().Set("Content-Type", mime.TypeByExtension(filepath.Ext(param.Path))) - lastMod, err := time.Parse(http.TimeFormat, resp.Header.Get("Last-Modified")) - if err == nil { - if seeker, ok := resp.Body.(io.ReadSeeker); ok { - http.ServeContent(writer, request, filepath.Base(param.Path), lastMod, seeker) - return nil + return func(ctx core.FilterContext, writer http.ResponseWriter, request *http.Request, next core.NextCall) error { + err := next(ctx, writer, request) + if (err != nil && !errors.Is(err, os.ErrNotExist)) || err == nil { + return err } - } - _, err = io.Copy(writer, resp.Body) - return err + resp, err := ctx.NativeOpen(ctx, param.Path, nil) + if resp != nil { + defer resp.Body.Close() + } + if err != nil { + return err + } + writer.Header().Set("Content-Type", mime.TypeByExtension(filepath.Ext(param.Path))) + lastMod, err := time.Parse(http.TimeFormat, resp.Header.Get("Last-Modified")) + if err == nil { + if seeker, ok := resp.Body.(io.ReadSeeker); ok { + http.ServeContent(writer, request, filepath.Base(param.Path), lastMod, seeker) + return nil + } + } + _, err = io.Copy(writer, resp.Body) + return err + }, nil }, nil } diff --git a/pkg/filters/goja/goja.go b/pkg/filters/goja/goja.go index 7c321b6..2793ab2 100644 --- a/pkg/filters/goja/goja.go +++ b/pkg/filters/goja/goja.go @@ -15,66 +15,68 @@ import ( "gopkg.d7z.net/gitea-pages/pkg/core" ) -var FilterInstGoJa core.FilterInstance = func(config core.FilterParams) (core.FilterCall, error) { - var param struct { - Exec string `json:"exec"` - Debug bool `json:"debug"` - } - if err := config.Unmarshal(¶m); err != nil { - return nil, err - } - if param.Exec == "" { - return nil, errors.New("no exec specified") - } - return func(ctx core.FilterContext, w http.ResponseWriter, request *http.Request, next core.NextCall) error { - js, err := ctx.ReadString(ctx, param.Exec) - if err != nil { - return err +func FilterInstGoJa(_ core.Params) (core.FilterInstance, error) { + return func(config core.Params) (core.FilterCall, error) { + var param struct { + Exec string `json:"exec"` + Debug bool `json:"debug"` } - prg, err := goja.Compile("main.js", js, false) - if err != nil { - return err + if err := config.Unmarshal(¶m); err != nil { + return nil, err } - debug := NewDebug(param.Debug && request.URL.Query().Get("debug") == "true", request, w) - registry := newRegistry(ctx) - registry.RegisterNativeModule(console.ModuleName, console.RequireWithPrinter(debug)) - loop := eventloop.NewEventLoop(eventloop.WithRegistry(registry), eventloop.EnableConsole(true)) - stop := make(chan struct{}, 1) - shutdown := make(chan struct{}, 1) - timeout, cancelFunc := context.WithTimeout(ctx, 3*time.Second) - defer cancelFunc() - count := 0 - go func() { - defer func() { - shutdown <- struct{}{} - close(shutdown) + if param.Exec == "" { + return nil, errors.New("no exec specified") + } + return func(ctx core.FilterContext, w http.ResponseWriter, request *http.Request, next core.NextCall) error { + js, err := ctx.ReadString(ctx, param.Exec) + if err != nil { + return err + } + prg, err := goja.Compile("main.js", js, false) + if err != nil { + return err + } + debug := NewDebug(param.Debug && request.URL.Query().Get("debug") == "true", request, w) + registry := newRegistry(ctx) + registry.RegisterNativeModule(console.ModuleName, console.RequireWithPrinter(debug)) + loop := eventloop.NewEventLoop(eventloop.WithRegistry(registry), eventloop.EnableConsole(true)) + stop := make(chan struct{}, 1) + shutdown := make(chan struct{}, 1) + timeout, cancelFunc := context.WithTimeout(ctx, 3*time.Second) + defer cancelFunc() + count := 0 + go func() { + defer func() { + shutdown <- struct{}{} + close(shutdown) + }() + select { + case <-timeout.Done(): + case <-stop: + } + count = loop.Stop() }() - select { - case <-timeout.Done(): - case <-stop: + loop.Run(func(vm *goja.Runtime) { + url.Enable(vm) + if err = RequestInject(ctx, vm, request); err != nil { + panic(err) + } + if err = ResponseInject(vm, debug, request); err != nil { + panic(err) + } + if err = KVInject(ctx, vm); err != nil { + panic(err) + } + _, err = vm.RunProgram(prg) + }) + stop <- struct{}{} + close(stop) + <-shutdown + if count != 0 { + err = errors.Join(context.DeadlineExceeded, err) } - count = loop.Stop() - }() - loop.Run(func(vm *goja.Runtime) { - url.Enable(vm) - if err = RequestInject(ctx, vm, request); err != nil { - panic(err) - } - if err = ResponseInject(vm, debug, request); err != nil { - panic(err) - } - if err = KVInject(ctx, vm); err != nil { - panic(err) - } - _, err = vm.RunProgram(prg) - }) - stop <- struct{}{} - close(stop) - <-shutdown - if count != 0 { - err = errors.Join(context.DeadlineExceeded, err) - } - return debug.Flush(err) + return debug.Flush(err) + }, nil }, nil } diff --git a/pkg/filters/proxy.go b/pkg/filters/proxy.go index 6a4c295..0075f10 100644 --- a/pkg/filters/proxy.go +++ b/pkg/filters/proxy.go @@ -13,36 +13,38 @@ import ( "gopkg.d7z.net/gitea-pages/pkg/utils" ) -var FilterInstProxy core.FilterInstance = func(config core.FilterParams) (core.FilterCall, error) { - var param struct { - Prefix string `json:"prefix"` - Target string `json:"target"` - } - if err := config.Unmarshal(¶m); err != nil { - return nil, err - } - return func(ctx core.FilterContext, writer http.ResponseWriter, request *http.Request, next core.NextCall) error { - proxyPath := "/" + ctx.Path - targetPath := strings.TrimPrefix(proxyPath, param.Prefix) - if !strings.HasPrefix(targetPath, "/") { - targetPath = "/" + targetPath +func FilterInstProxy(_ core.Params) (core.FilterInstance, error) { + return func(config core.Params) (core.FilterCall, error) { + var param struct { + Prefix string `json:"prefix"` + Target string `json:"target"` } - u, _ := url.Parse(param.Target) - request.URL.Path = targetPath - request.RequestURI = request.URL.RequestURI() - proxy := httputil.NewSingleHostReverseProxy(u) - // todo: 处理透传 - // proxy.Transport = s.options.HTTPClient.Transport - if host, _, err := net.SplitHostPort(request.RemoteAddr); err == nil { - request.Header.Set("X-Real-IP", host) + if err := config.Unmarshal(¶m); err != nil { + return nil, err } - request.Header.Set("X-Page-IP", utils.GetRemoteIP(request)) - request.Header.Set("X-Page-Refer", fmt.Sprintf("%s/%s/%s", ctx.Owner, ctx.Repo, ctx.Path)) - request.Header.Set("X-Page-Host", request.Host) - zap.L().Debug("命中反向代理", zap.Any("prefix", param.Prefix), zap.Any("target", param.Target), - zap.Any("path", proxyPath), zap.Any("target", fmt.Sprintf("%s%s", u, targetPath))) - // todo(security): 处理 websocket - proxy.ServeHTTP(writer, request) - return nil + return func(ctx core.FilterContext, writer http.ResponseWriter, request *http.Request, next core.NextCall) error { + proxyPath := "/" + ctx.Path + targetPath := strings.TrimPrefix(proxyPath, param.Prefix) + if !strings.HasPrefix(targetPath, "/") { + targetPath = "/" + targetPath + } + u, _ := url.Parse(param.Target) + request.URL.Path = targetPath + request.RequestURI = request.URL.RequestURI() + proxy := httputil.NewSingleHostReverseProxy(u) + // todo: 处理透传 + // proxy.Transport = s.options.HTTPClient.Transport + if host, _, err := net.SplitHostPort(request.RemoteAddr); err == nil { + request.Header.Set("X-Real-IP", host) + } + request.Header.Set("X-Page-IP", utils.GetRemoteIP(request)) + request.Header.Set("X-Page-Refer", fmt.Sprintf("%s/%s/%s", ctx.Owner, ctx.Repo, ctx.Path)) + request.Header.Set("X-Page-Host", request.Host) + zap.L().Debug("命中反向代理", zap.Any("prefix", param.Prefix), zap.Any("target", param.Target), + zap.Any("path", proxyPath), zap.Any("target", fmt.Sprintf("%s%s", u, targetPath))) + // todo(security): 处理 websocket + proxy.ServeHTTP(writer, request) + return nil + }, nil }, nil } diff --git a/pkg/filters/redirect.go b/pkg/filters/redirect.go index edb1ae4..5cddde0 100644 --- a/pkg/filters/redirect.go +++ b/pkg/filters/redirect.go @@ -15,41 +15,43 @@ import ( var portExp = regexp.MustCompile(`:\d+$`) -var FilterInstRedirect core.FilterInstance = func(config core.FilterParams) (core.FilterCall, error) { - var param struct { - Targets []string `json:"targets"` - Code int `json:"code"` - } - if err := config.Unmarshal(¶m); err != nil { - return nil, err - } - if len(param.Targets) == 0 { - return nil, errors.New("no targets") - } - if param.Code == 0 { - param.Code = http.StatusFound - } - if param.Code < 300 || param.Code > 399 { - return nil, fmt.Errorf("invalid code: %d", param.Code) - } - return func(ctx core.FilterContext, writer http.ResponseWriter, request *http.Request, next core.NextCall) error { - domain := portExp.ReplaceAllString(strings.ToLower(request.Host), "") - if len(param.Targets) > 0 && !slices.Contains(ctx.Alias, domain) { - // 重定向到配置的地址 - zap.L().Debug("redirect", zap.Any("src", request.Host), zap.Any("dst", param.Targets[0])) - path := ctx.Path - if strings.HasSuffix(path, "/index.html") || path == "index.html" { - path = strings.TrimSuffix(path, "index.html") - } - target, err := url.Parse(fmt.Sprintf("https://%s/%s", param.Targets[0], path)) - if err != nil { - return err - } - target.RawQuery = request.URL.RawQuery - - http.Redirect(writer, request, target.String(), param.Code) - return nil +func FilterInstRedirect(_ core.Params) (core.FilterInstance, error) { + return func(config core.Params) (core.FilterCall, error) { + var param struct { + Targets []string `json:"targets"` + Code int `json:"code"` } - return next(ctx, writer, request) + if err := config.Unmarshal(¶m); err != nil { + return nil, err + } + if len(param.Targets) == 0 { + return nil, errors.New("no targets") + } + if param.Code == 0 { + param.Code = http.StatusFound + } + if param.Code < 300 || param.Code > 399 { + return nil, fmt.Errorf("invalid code: %d", param.Code) + } + return func(ctx core.FilterContext, writer http.ResponseWriter, request *http.Request, next core.NextCall) error { + domain := portExp.ReplaceAllString(strings.ToLower(request.Host), "") + if len(param.Targets) > 0 && !slices.Contains(ctx.Alias, domain) { + // 重定向到配置的地址 + zap.L().Debug("redirect", zap.Any("src", request.Host), zap.Any("dst", param.Targets[0])) + path := ctx.Path + if strings.HasSuffix(path, "/index.html") || path == "index.html" { + path = strings.TrimSuffix(path, "index.html") + } + target, err := url.Parse(fmt.Sprintf("https://%s/%s", param.Targets[0], path)) + if err != nil { + return err + } + target.RawQuery = request.URL.RawQuery + + http.Redirect(writer, request, target.String(), param.Code) + return nil + } + return next(ctx, writer, request) + }, nil }, nil } diff --git a/pkg/filters/template.go b/pkg/filters/template.go index 569f029..b007dc7 100644 --- a/pkg/filters/template.go +++ b/pkg/filters/template.go @@ -9,36 +9,38 @@ import ( "gopkg.d7z.net/gitea-pages/pkg/utils" ) -var FilterInstTemplate core.FilterInstance = func(config core.FilterParams) (core.FilterCall, error) { - var param struct { - Prefix string `json:"prefix"` - } - if err := config.Unmarshal(¶m); err != nil { - return nil, err - } - param.Prefix = strings.Trim(param.Prefix, "/") + "/" - return func(ctx core.FilterContext, writer http.ResponseWriter, request *http.Request, next core.NextCall) error { - data, err := ctx.ReadString(ctx, param.Prefix+ctx.Path) - if err != nil { - return err +func FilterInstTemplate(_ core.Params) (core.FilterInstance, error) { + return func(config core.Params) (core.FilterCall, error) { + var param struct { + Prefix string `json:"prefix"` } - if err != nil { - return err + if err := config.Unmarshal(¶m); err != nil { + return nil, err } - out := &bytes.Buffer{} - parse, err := utils.NewTemplate().Funcs(map[string]any{ - "load": func(path string) (any, error) { - return ctx.ReadString(ctx, param.Prefix+path) - }, - }).Parse(data) - if err != nil { - return err - } - err = parse.Execute(out, utils.NewTemplateInject(request, nil)) - if err != nil { - return err - } - _, _ = out.WriteTo(writer) - return nil + param.Prefix = strings.Trim(param.Prefix, "/") + "/" + return func(ctx core.FilterContext, writer http.ResponseWriter, request *http.Request, next core.NextCall) error { + data, err := ctx.ReadString(ctx, param.Prefix+ctx.Path) + if err != nil { + return err + } + if err != nil { + return err + } + out := &bytes.Buffer{} + parse, err := utils.NewTemplate().Funcs(map[string]any{ + "load": func(path string) (any, error) { + return ctx.ReadString(ctx, param.Prefix+path) + }, + }).Parse(data) + if err != nil { + return err + } + err = parse.Execute(out, utils.NewTemplateInject(request, nil)) + if err != nil { + return err + } + _, _ = out.WriteTo(writer) + return nil + }, nil }, nil } diff --git a/pkg/server.go b/pkg/server.go index 09e296c..48a1968 100644 --- a/pkg/server.go +++ b/pkg/server.go @@ -44,22 +44,27 @@ func NewPageServer( cacheTTL time.Duration, cacheBlob cache.Cache, errorHandler func(w http.ResponseWriter, r *http.Request, err error), -) *Server { + filterConfig map[string]map[string]any, +) (*Server, error) { svcMeta := core.NewServerMeta(client, backend, domain, cacheMeta, cacheTTL) pageMeta := core.NewPageDomain(svcMeta, core.NewDomainAlias(db.Child("config").Child("alias")), domain, defaultBranch) globCache, err := lru.New[string, glob.Glob](256) if err != nil { - panic(err) + return nil, err + } + defaultFilters, err := filters.DefaultFilters(filterConfig) + if err != nil { + return nil, err } return &Server{ backend: backend, meta: pageMeta, db: db, globCache: globCache, - filterMgr: filters.DefaultFilters(), + filterMgr: defaultFilters, errorHandler: errorHandler, cacheBlob: cacheBlob, - } + }, nil } func (s *Server) ServeHTTP(w http.ResponseWriter, request *http.Request) { diff --git a/tests/core/test.go b/tests/core/test.go index 746590c..518398d 100644 --- a/tests/core/test.go +++ b/tests/core/test.go @@ -46,7 +46,7 @@ func NewTestServer(domain string) *TestServer { CleanupInt: time.Minute, }) memoryKV, _ := kv.NewMemory("") - server := pkg.NewPageServer( + server, err := pkg.NewPageServer( http.DefaultClient, dummy, domain, @@ -62,8 +62,11 @@ func NewTestServer(domain string) *TestServer { http.Error(w, err.Error(), http.StatusInternalServerError) } }, + make(map[string]map[string]any), ) - + if err != nil { + panic(err) + } return &TestServer{ dummy: dummy, server: server,