更新 gotemplate 的参数注入
This commit is contained in:
26
config.go
26
config.go
@@ -8,6 +8,8 @@ import (
|
|||||||
"path/filepath"
|
"path/filepath"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
sprig "github.com/go-task/slim-sprig/v3"
|
||||||
|
|
||||||
"github.com/alecthomas/units"
|
"github.com/alecthomas/units"
|
||||||
|
|
||||||
"code.d7z.net/d7z-project/gitea-pages/pkg"
|
"code.d7z.net/d7z-project/gitea-pages/pkg"
|
||||||
@@ -52,13 +54,13 @@ func (c *Config) NewPageServerOptions() (*pkg.ServerOptions, error) {
|
|||||||
if c.Page.DefaultBranch == "" {
|
if c.Page.DefaultBranch == "" {
|
||||||
c.Page.DefaultBranch = "gh-pages"
|
c.Page.DefaultBranch = "gh-pages"
|
||||||
}
|
}
|
||||||
defaultErr := template.Must(template.New("err").Parse(defaultErrPage))
|
defaultErr := template.Must(template.New("err").Funcs(sprig.FuncMap()).Parse(defaultErrPage))
|
||||||
if c.Page.ErrUnknownPage != "" {
|
if c.Page.ErrUnknownPage != "" {
|
||||||
data, err := os.ReadFile(c.Page.ErrUnknownPage)
|
data, err := os.ReadFile(c.Page.ErrUnknownPage)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrapf(err, "failed to read file %s", string(data))
|
return nil, errors.Wrapf(err, "failed to read file %s", string(data))
|
||||||
}
|
}
|
||||||
c.pageErrUnknown = template.Must(template.New("err").Parse(c.Page.ErrUnknownPage))
|
c.pageErrUnknown = template.Must(template.New("err").Funcs(sprig.FuncMap()).Parse(c.Page.ErrUnknownPage))
|
||||||
} else {
|
} else {
|
||||||
c.pageErrUnknown = defaultErr
|
c.pageErrUnknown = defaultErr
|
||||||
}
|
}
|
||||||
@@ -67,7 +69,7 @@ func (c *Config) NewPageServerOptions() (*pkg.ServerOptions, error) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrapf(err, "failed to read file %s", c.Page.ErrNotFoundPage)
|
return nil, errors.Wrapf(err, "failed to read file %s", c.Page.ErrNotFoundPage)
|
||||||
}
|
}
|
||||||
c.pageErrNotFound = template.Must(template.New("err").Parse(string(data)))
|
c.pageErrNotFound = template.Must(template.New("err").Funcs(sprig.FuncMap()).Parse(string(data)))
|
||||||
} else {
|
} else {
|
||||||
c.pageErrNotFound = defaultErr
|
c.pageErrNotFound = defaultErr
|
||||||
}
|
}
|
||||||
@@ -96,20 +98,18 @@ func (c *Config) NewPageServerOptions() (*pkg.ServerOptions, error) {
|
|||||||
func (c *Config) ErrorHandler(w http.ResponseWriter, r *http.Request, err error) {
|
func (c *Config) ErrorHandler(w http.ResponseWriter, r *http.Request, err error) {
|
||||||
if errors.Is(err, os.ErrNotExist) {
|
if errors.Is(err, os.ErrNotExist) {
|
||||||
w.WriteHeader(http.StatusNotFound)
|
w.WriteHeader(http.StatusNotFound)
|
||||||
if err = c.pageErrNotFound.Execute(w, map[string]any{
|
if err = c.pageErrNotFound.Execute(w, utils.NewTemplateInject(r, map[string]any{
|
||||||
"err": err,
|
"Error": err,
|
||||||
"req": r,
|
"Code": 404,
|
||||||
"code": 404,
|
})); err != nil {
|
||||||
}); err != nil {
|
|
||||||
zap.L().Error("failed to render error page", zap.Error(err))
|
zap.L().Error("failed to render error page", zap.Error(err))
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
w.WriteHeader(http.StatusInternalServerError)
|
w.WriteHeader(http.StatusInternalServerError)
|
||||||
if err = c.pageErrUnknown.Execute(w, map[string]any{
|
if err = c.pageErrUnknown.Execute(w, utils.NewTemplateInject(r, map[string]any{
|
||||||
"err": err,
|
"Error": err,
|
||||||
"req": r,
|
"Code": 500,
|
||||||
"code": 500,
|
})); err != nil {
|
||||||
}); err != nil {
|
|
||||||
zap.L().Error("failed to render error page", zap.Error(err))
|
zap.L().Error("failed to render error page", zap.Error(err))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,11 +3,11 @@ package renders
|
|||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"io"
|
"io"
|
||||||
"net"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
|
||||||
"text/template"
|
"text/template"
|
||||||
|
|
||||||
|
"code.d7z.net/d7z-project/gitea-pages/pkg/utils"
|
||||||
|
|
||||||
sprig "github.com/go-task/slim-sprig/v3"
|
sprig "github.com/go-task/slim-sprig/v3"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -24,47 +24,13 @@ func (g GoTemplate) Render(w http.ResponseWriter, r *http.Request, input io.Read
|
|||||||
}
|
}
|
||||||
out := &bytes.Buffer{}
|
out := &bytes.Buffer{}
|
||||||
parse, err := template.New("tmpl").Funcs(sprig.FuncMap()).Option("missingkey=error").Parse(string(dataB))
|
parse, err := template.New("tmpl").Funcs(sprig.FuncMap()).Option("missingkey=error").Parse(string(dataB))
|
||||||
headers := make(map[string]string)
|
|
||||||
for k, vs := range r.Header {
|
|
||||||
headers[k] = strings.Join(vs, ",")
|
|
||||||
}
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
err = parse.Execute(out, map[string]interface{}{
|
err = parse.Execute(out, utils.NewTemplateInject(r, nil))
|
||||||
"Request": map[string]any{
|
|
||||||
"Headers": headers,
|
|
||||||
"Request": r.RequestURI,
|
|
||||||
"RemoteAddr": r.RemoteAddr,
|
|
||||||
"RemoteIP": GetRemoteIP(r),
|
|
||||||
},
|
|
||||||
})
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
_, err = out.WriteTo(w)
|
_, err = out.WriteTo(w)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// 注意,相关 ip 获取未做反向代理安全判断,可能导致安全降级
|
|
||||||
|
|
||||||
func GetRemoteIP(r *http.Request) string {
|
|
||||||
// 最先取 cloudflare 的头
|
|
||||||
if ip := r.Header.Get("CF-Connecting-IP"); ip != "" {
|
|
||||||
return ip
|
|
||||||
}
|
|
||||||
if forwardedFor := r.Header.Get("X-Forwarded-For"); forwardedFor != "" {
|
|
||||||
ips := strings.Split(forwardedFor, ",")
|
|
||||||
if len(ips) > 0 {
|
|
||||||
return strings.TrimSpace(ips[0])
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if realIP := r.Header.Get("X-Real-IP"); realIP != "" {
|
|
||||||
return realIP
|
|
||||||
}
|
|
||||||
ip, _, err := net.SplitHostPort(r.RemoteAddr)
|
|
||||||
if err != nil {
|
|
||||||
return r.RemoteAddr
|
|
||||||
}
|
|
||||||
return ip
|
|
||||||
}
|
|
||||||
|
|||||||
30
pkg/utils/net.go
Normal file
30
pkg/utils/net.go
Normal file
@@ -0,0 +1,30 @@
|
|||||||
|
package utils
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
// 注意,相关 ip 获取未做反向代理安全判断,可能导致安全降级
|
||||||
|
|
||||||
|
func GetRemoteIP(r *http.Request) string {
|
||||||
|
// 最先取 cloudflare 的头
|
||||||
|
if ip := r.Header.Get("CF-Connecting-IP"); ip != "" {
|
||||||
|
return ip
|
||||||
|
}
|
||||||
|
if forwardedFor := r.Header.Get("X-Forwarded-For"); forwardedFor != "" {
|
||||||
|
ips := strings.Split(forwardedFor, ",")
|
||||||
|
if len(ips) > 0 {
|
||||||
|
return strings.TrimSpace(ips[0])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if realIP := r.Header.Get("X-Real-IP"); realIP != "" {
|
||||||
|
return realIP
|
||||||
|
}
|
||||||
|
ip, _, err := net.SplitHostPort(r.RemoteAddr)
|
||||||
|
if err != nil {
|
||||||
|
return r.RemoteAddr
|
||||||
|
}
|
||||||
|
return ip
|
||||||
|
}
|
||||||
25
pkg/utils/template.go
Normal file
25
pkg/utils/template.go
Normal file
@@ -0,0 +1,25 @@
|
|||||||
|
package utils
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
func NewTemplateInject(r *http.Request, def map[string]any) map[string]any {
|
||||||
|
if def == nil {
|
||||||
|
def = make(map[string]any)
|
||||||
|
}
|
||||||
|
headers := make(map[string]string)
|
||||||
|
for k, vs := range r.Header {
|
||||||
|
headers[k] = strings.Join(vs, ",")
|
||||||
|
}
|
||||||
|
def["Request"] = map[string]any{
|
||||||
|
"Headers": headers,
|
||||||
|
"Path": r.URL.Path,
|
||||||
|
"Method": r.Method,
|
||||||
|
"RequestURI": r.RequestURI,
|
||||||
|
"RemoteAddr": r.RemoteAddr,
|
||||||
|
"RemoteIP": GetRemoteIP(r),
|
||||||
|
}
|
||||||
|
return def
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user