From 3f63988e2d63519fde46520a74944599a73b12fc Mon Sep 17 00:00:00 2001 From: dragon Date: Tue, 15 Apr 2025 17:23:23 +0800 Subject: [PATCH] =?UTF-8?q?=E6=9B=B4=E6=96=B0=20gotemplate=20=E7=9A=84?= =?UTF-8?q?=E5=8F=82=E6=95=B0=E6=B3=A8=E5=85=A5?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- config.go | 26 ++++++++++++------------- pkg/renders/gotemplate.go | 40 +++------------------------------------ pkg/utils/net.go | 30 +++++++++++++++++++++++++++++ pkg/utils/template.go | 25 ++++++++++++++++++++++++ 4 files changed, 71 insertions(+), 50 deletions(-) create mode 100644 pkg/utils/net.go create mode 100644 pkg/utils/template.go diff --git a/config.go b/config.go index aa08193..bb2d4ce 100644 --- a/config.go +++ b/config.go @@ -8,6 +8,8 @@ import ( "path/filepath" "time" + sprig "github.com/go-task/slim-sprig/v3" + "github.com/alecthomas/units" "code.d7z.net/d7z-project/gitea-pages/pkg" @@ -52,13 +54,13 @@ func (c *Config) NewPageServerOptions() (*pkg.ServerOptions, error) { if c.Page.DefaultBranch == "" { c.Page.DefaultBranch = "gh-pages" } - defaultErr := template.Must(template.New("err").Parse(defaultErrPage)) + defaultErr := template.Must(template.New("err").Funcs(sprig.FuncMap()).Parse(defaultErrPage)) if c.Page.ErrUnknownPage != "" { data, err := os.ReadFile(c.Page.ErrUnknownPage) if err != nil { return nil, errors.Wrapf(err, "failed to read file %s", string(data)) } - c.pageErrUnknown = template.Must(template.New("err").Parse(c.Page.ErrUnknownPage)) + c.pageErrUnknown = template.Must(template.New("err").Funcs(sprig.FuncMap()).Parse(c.Page.ErrUnknownPage)) } else { c.pageErrUnknown = defaultErr } @@ -67,7 +69,7 @@ func (c *Config) NewPageServerOptions() (*pkg.ServerOptions, error) { if err != nil { return nil, errors.Wrapf(err, "failed to read file %s", c.Page.ErrNotFoundPage) } - c.pageErrNotFound = template.Must(template.New("err").Parse(string(data))) + c.pageErrNotFound = template.Must(template.New("err").Funcs(sprig.FuncMap()).Parse(string(data))) } else { c.pageErrNotFound = defaultErr } @@ -96,20 +98,18 @@ func (c *Config) NewPageServerOptions() (*pkg.ServerOptions, error) { func (c *Config) ErrorHandler(w http.ResponseWriter, r *http.Request, err error) { if errors.Is(err, os.ErrNotExist) { w.WriteHeader(http.StatusNotFound) - if err = c.pageErrNotFound.Execute(w, map[string]any{ - "err": err, - "req": r, - "code": 404, - }); err != nil { + if err = c.pageErrNotFound.Execute(w, utils.NewTemplateInject(r, map[string]any{ + "Error": err, + "Code": 404, + })); err != nil { zap.L().Error("failed to render error page", zap.Error(err)) } } else { w.WriteHeader(http.StatusInternalServerError) - if err = c.pageErrUnknown.Execute(w, map[string]any{ - "err": err, - "req": r, - "code": 500, - }); err != nil { + if err = c.pageErrUnknown.Execute(w, utils.NewTemplateInject(r, map[string]any{ + "Error": err, + "Code": 500, + })); err != nil { zap.L().Error("failed to render error page", zap.Error(err)) } } diff --git a/pkg/renders/gotemplate.go b/pkg/renders/gotemplate.go index b7dca25..93b1cc5 100644 --- a/pkg/renders/gotemplate.go +++ b/pkg/renders/gotemplate.go @@ -3,11 +3,11 @@ package renders import ( "bytes" "io" - "net" "net/http" - "strings" "text/template" + "code.d7z.net/d7z-project/gitea-pages/pkg/utils" + sprig "github.com/go-task/slim-sprig/v3" ) @@ -24,47 +24,13 @@ func (g GoTemplate) Render(w http.ResponseWriter, r *http.Request, input io.Read } out := &bytes.Buffer{} parse, err := template.New("tmpl").Funcs(sprig.FuncMap()).Option("missingkey=error").Parse(string(dataB)) - headers := make(map[string]string) - for k, vs := range r.Header { - headers[k] = strings.Join(vs, ",") - } if err != nil { return err } - err = parse.Execute(out, map[string]interface{}{ - "Request": map[string]any{ - "Headers": headers, - "Request": r.RequestURI, - "RemoteAddr": r.RemoteAddr, - "RemoteIP": GetRemoteIP(r), - }, - }) + err = parse.Execute(out, utils.NewTemplateInject(r, nil)) if err != nil { return err } _, err = out.WriteTo(w) return err } - -// 注意,相关 ip 获取未做反向代理安全判断,可能导致安全降级 - -func GetRemoteIP(r *http.Request) string { - // 最先取 cloudflare 的头 - if ip := r.Header.Get("CF-Connecting-IP"); ip != "" { - return ip - } - if forwardedFor := r.Header.Get("X-Forwarded-For"); forwardedFor != "" { - ips := strings.Split(forwardedFor, ",") - if len(ips) > 0 { - return strings.TrimSpace(ips[0]) - } - } - if realIP := r.Header.Get("X-Real-IP"); realIP != "" { - return realIP - } - ip, _, err := net.SplitHostPort(r.RemoteAddr) - if err != nil { - return r.RemoteAddr - } - return ip -} diff --git a/pkg/utils/net.go b/pkg/utils/net.go new file mode 100644 index 0000000..6bcb472 --- /dev/null +++ b/pkg/utils/net.go @@ -0,0 +1,30 @@ +package utils + +import ( + "net" + "net/http" + "strings" +) + +// 注意,相关 ip 获取未做反向代理安全判断,可能导致安全降级 + +func GetRemoteIP(r *http.Request) string { + // 最先取 cloudflare 的头 + if ip := r.Header.Get("CF-Connecting-IP"); ip != "" { + return ip + } + if forwardedFor := r.Header.Get("X-Forwarded-For"); forwardedFor != "" { + ips := strings.Split(forwardedFor, ",") + if len(ips) > 0 { + return strings.TrimSpace(ips[0]) + } + } + if realIP := r.Header.Get("X-Real-IP"); realIP != "" { + return realIP + } + ip, _, err := net.SplitHostPort(r.RemoteAddr) + if err != nil { + return r.RemoteAddr + } + return ip +} diff --git a/pkg/utils/template.go b/pkg/utils/template.go new file mode 100644 index 0000000..e60ce07 --- /dev/null +++ b/pkg/utils/template.go @@ -0,0 +1,25 @@ +package utils + +import ( + "net/http" + "strings" +) + +func NewTemplateInject(r *http.Request, def map[string]any) map[string]any { + if def == nil { + def = make(map[string]any) + } + headers := make(map[string]string) + for k, vs := range r.Header { + headers[k] = strings.Join(vs, ",") + } + def["Request"] = map[string]any{ + "Headers": headers, + "Path": r.URL.Path, + "Method": r.Method, + "RequestURI": r.RequestURI, + "RemoteAddr": r.RemoteAddr, + "RemoteIP": GetRemoteIP(r), + } + return def +}