diff --git a/global-types/globals.d.ts b/global-types/globals.d.ts index 912b5f1..fbc4d3d 100644 --- a/global-types/globals.d.ts +++ b/global-types/globals.d.ts @@ -161,6 +161,24 @@ declare global { // @ts-ignore const console: Console; + + // Fetch API 相关类型 + interface FetchResponse { + ok: boolean; + status: number; + statusText: string; + headers: Record; + text(): Promise; + json(): Promise; + } + + interface FetchOptions { + method?: string; + headers?: Record; + body?: string; + } + + function fetch(url: string, options?: FetchOptions): Promise; } export {}; \ No newline at end of file diff --git a/pkg/filters/goja/goja.go b/pkg/filters/goja/goja.go index 4c839bc..947789d 100644 --- a/pkg/filters/goja/goja.go +++ b/pkg/filters/goja/goja.go @@ -84,6 +84,9 @@ func FilterInstGoJa(gl core.Params) (core.FilterInstance, error) { if err = EventInject(ctx, vm, jsLoop); err != nil { return err } + if err = FetchInject(vm, jsLoop); err != nil { + return err + } if global.EnableWebsocket { var closer io.Closer closer, err = WebsocketInject(ctx, vm, debug, request, jsLoop) diff --git a/pkg/filters/goja/var_fetch.go b/pkg/filters/goja/var_fetch.go new file mode 100644 index 0000000..de86d3a --- /dev/null +++ b/pkg/filters/goja/var_fetch.go @@ -0,0 +1,99 @@ +package goja + +import ( + "encoding/json" + "io" + "net/http" + "strings" + + "github.com/dop251/goja" + "github.com/dop251/goja_nodejs/eventloop" +) + +func FetchInject(jsCtx *goja.Runtime, loop *eventloop.EventLoop) error { + return jsCtx.GlobalObject().Set("fetch", func(url string, options ...map[string]interface{}) *goja.Promise { + promise, resolve, reject := jsCtx.NewPromise() + + go func() { + method := "GET" + var body io.Reader + headers := make(http.Header) + + if len(options) > 0 { + opts := options[0] + if m, ok := opts["method"].(string); ok { + method = strings.ToUpper(m) + } + if h, ok := opts["headers"].(map[string]interface{}); ok { + for k, v := range h { + if strVal, ok := v.(string); ok { + headers.Set(k, strVal) + } + } + } + if b, ok := opts["body"].(string); ok { + body = strings.NewReader(b) + } + } + + req, err := http.NewRequest(method, url, body) + if err != nil { + loop.RunOnLoop(func(*goja.Runtime) { + _ = reject(err) + }) + return + } + req.Header = headers + + client := &http.Client{} + resp, err := client.Do(req) + if err != nil { + loop.RunOnLoop(func(*goja.Runtime) { + _ = reject(err) + }) + return + } + defer resp.Body.Close() + + respBody, err := io.ReadAll(resp.Body) + if err != nil { + loop.RunOnLoop(func(*goja.Runtime) { + _ = reject(err) + }) + return + } + + headersMap := make(map[string]interface{}) + for k, v := range resp.Header { + headersMap[k] = v + } + + loop.RunOnLoop(func(vm *goja.Runtime) { + responseObj := map[string]interface{}{ + "ok": resp.StatusCode >= 200 && resp.StatusCode < 300, + "status": resp.StatusCode, + "statusText": resp.Status, + "headers": headersMap, + "text": func() *goja.Promise { + p, res, _ := vm.NewPromise() + _ = res(string(respBody)) + return p + }, + "json": func() *goja.Promise { + p, res, rej := vm.NewPromise() + var data interface{} + if err := json.Unmarshal(respBody, &data); err != nil { + _ = rej(err) + } else { + _ = res(data) + } + return p + }, + } + _ = resolve(responseObj) + }) + }() + + return promise + }) +} diff --git a/pkg/filters/goja/var_websocket.go b/pkg/filters/goja/var_websocket.go index 63f7929..3d75700 100644 --- a/pkg/filters/goja/var_websocket.go +++ b/pkg/filters/goja/var_websocket.go @@ -3,6 +3,7 @@ package goja import ( "io" "net/http" + "sync" "time" "github.com/dop251/goja" @@ -21,6 +22,9 @@ func WebsocketInject(ctx core.FilterContext, jsCtx *goja.Runtime, w http.Respons if err != nil { return nil, err } + var readMu sync.Mutex + var writeMu sync.Mutex + zap.L().Debug("websocket upgrader created") closers.AddCloser(conn.Close) go func() { @@ -61,7 +65,9 @@ func WebsocketInject(ctx core.FilterContext, jsCtx *goja.Runtime, w http.Respons }) } }() + readMu.Lock() _, p, err := conn.ReadMessage() + readMu.Unlock() loop.RunOnLoop(func(runtime *goja.Runtime) { if err != nil { _ = reject(runtime.ToValue(err)) @@ -91,7 +97,9 @@ func WebsocketInject(ctx core.FilterContext, jsCtx *goja.Runtime, w http.Respons }) } }() + readMu.Lock() messageType, p, err := conn.ReadMessage() + readMu.Unlock() loop.RunOnLoop(func(runtime *goja.Runtime) { if err != nil { _ = reject(runtime.ToValue(err)) @@ -124,7 +132,9 @@ func WebsocketInject(ctx core.FilterContext, jsCtx *goja.Runtime, w http.Respons }) } }() + writeMu.Lock() err := conn.WriteMessage(websocket.TextMessage, []byte(data)) + writeMu.Unlock() loop.RunOnLoop(func(runtime *goja.Runtime) { if err != nil { _ = reject(runtime.ToValue(err)) @@ -171,7 +181,9 @@ func WebsocketInject(ctx core.FilterContext, jsCtx *goja.Runtime, w http.Respons return } + writeMu.Lock() err := conn.WriteMessage(mType, dataRaw) + writeMu.Unlock() loop.RunOnLoop(func(runtime *goja.Runtime) { if err != nil { _ = reject(runtime.ToValue(err)) diff --git a/tests/filter_goja_test.go b/tests/filter_goja_test.go index b5507af..8577a54 100644 --- a/tests/filter_goja_test.go +++ b/tests/filter_goja_test.go @@ -2,6 +2,7 @@ package tests import ( "net/http" + "net/http/httptest" "testing" "github.com/stretchr/testify/assert" @@ -87,6 +88,37 @@ routes: assert.Equal(t, "abc", string(data)) } +func Test_GoJa_Fetch(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("X-Test", "test-header") + _, _ = w.Write([]byte("fetched-content")) + })) + defer ts.Close() + + server := core.NewDefaultTestServer() + defer server.Close() + server.AddFile("org1/repo1/gh-pages/index.js", ` +(async()=>{ + const res = await fetch('%s') + response.setHeader('X-Fetched-Header', res.headers['X-Test'] || res.headers['x-test']) + const text = await res.text() + response.write(text) +})() +`, ts.URL) + server.AddFile("org1/repo1/gh-pages/index.html", "dummy") + server.AddFile("org1/repo1/gh-pages/.pages.yaml", ` +routes: +- path: "**" + js: + exec: "index.js" +`) + + data, resp, err := server.OpenFile("https://org1.example.com/repo1/fetch") + assert.NoError(t, err) + assert.Equal(t, "fetched-content", string(data)) + assert.Equal(t, "test-header", resp.Header.Get("X-Fetched-Header")) +} + func Benchmark_GoJa_Request(b *testing.B) { b.Setenv("BM", "1") server := core.NewDefaultTestServer()