Explorar el Código

feat: add SSE support to request-client

zhongming4762 hace 9 meses
padre
commit
66822a5f95

+ 16 - 5
packages/effects/request/src/request-client/modules/sse.test.ts

@@ -89,7 +89,8 @@ describe('sSE', () => {
 
     expect(onMessage).toHaveBeenCalledTimes(2);
     expect(messages.join('')).toBe('hello world');
-    expect(onEnd).toHaveBeenCalledWith('hello world');
+    // onEnd 不再带参数
+    expect(onEnd).toHaveBeenCalled();
   });
 
   it('should apply request interceptors', async () => {
@@ -101,20 +102,30 @@ describe('sSE', () => {
       fulfilled: interceptor,
     });
 
-    vi.stubGlobal('fetch', createFetchMock(['data']));
-
     // 创建 fetch mock,并挂到全局
     const fetchMock = createFetchMock(['data']);
     vi.stubGlobal('fetch', fetchMock);
+
     await sse.requestSSE('/sse', undefined, {});
 
     expect(interceptor).toHaveBeenCalled();
     expect(fetchMock).toHaveBeenCalledWith(
-      'http://localhost//sse',
+      'http://localhost/sse',
       expect.objectContaining({
-        headers: expect.objectContaining({ 'x-test': 'intercepted' }),
+        headers: expect.any(Headers),
       }),
     );
+
+    const calls = fetchMock.mock?.calls;
+    expect(calls).toBeDefined();
+    expect(calls?.length).toBeGreaterThan(0);
+
+    const init = calls?.[0]?.[1] as RequestInit;
+    expect(init).toBeDefined();
+
+    const headers = init?.headers as Headers;
+    expect(headers?.get('x-test')).toBe('intercepted');
+    expect(headers?.get('accept')).toBe('text/event-stream');
   });
 
   it('should throw error when no reader', async () => {

+ 56 - 16
packages/effects/request/src/request-client/modules/sse.ts

@@ -36,9 +36,10 @@ class SSE {
     requestOptions?: SseRequestOptions,
   ) {
     const baseUrl = this.client.getBaseUrl() || '';
-    const hasUrlSplit = baseUrl.endsWith('/') && url.startsWith('/');
 
-    const axiosConfig: InternalAxiosRequestConfig = {
+    let axiosConfig: InternalAxiosRequestConfig<any> = {
+      url,
+      method: (requestOptions?.method as any) ?? 'GET',
       headers: {} as AxiosRequestHeaders,
     };
     const requestInterceptors = this.client.instance.interceptors
@@ -48,25 +49,45 @@ class SSE {
       requestInterceptors.handlers.length > 0
     ) {
       for (const handler of requestInterceptors.handlers) {
-        if (handler.fulfilled) {
-          await handler.fulfilled(axiosConfig);
+        if (typeof handler?.fulfilled === 'function') {
+          const next = await handler.fulfilled(axiosConfig as any);
+          if (next) axiosConfig = next as InternalAxiosRequestConfig<any>;
         }
       }
     }
 
+    const merged = new Headers();
+    Object.entries(
+      (axiosConfig.headers ?? {}) as Record<string, string>,
+    ).forEach(([k, v]) => merged.set(k, String(v)));
+    if (requestOptions?.headers) {
+      new Headers(requestOptions.headers).forEach((v, k) => merged.set(k, v));
+    }
+    if (!merged.has('accept')) {
+      merged.set('accept', 'text/event-stream');
+    }
+
+    let bodyInit = requestOptions?.body ?? data;
+    const ct = (merged.get('content-type') || '').toLowerCase();
+    if (
+      bodyInit &&
+      typeof bodyInit === 'object' &&
+      !ArrayBuffer.isView(bodyInit as any) &&
+      !(bodyInit instanceof ArrayBuffer) &&
+      !(bodyInit instanceof Blob) &&
+      !(bodyInit instanceof FormData) &&
+      ct.includes('application/json')
+    ) {
+      bodyInit = JSON.stringify(bodyInit);
+    }
     const requestInit: RequestInit = {
       ...requestOptions,
-      body: data,
-      headers: {
-        ...(axiosConfig.headers as Record<string, string>),
-        ...requestOptions?.headers,
-      },
+      method: axiosConfig.method,
+      headers: merged,
+      body: bodyInit,
     };
 
-    const response = await fetch(
-      `${baseUrl}${hasUrlSplit ? '' : '/'}${url}`,
-      requestInit,
-    );
+    const response = await fetch(safeJoinUrl(baseUrl, url), requestInit);
     if (!response.ok) {
       throw new Error(`HTTP error! status: ${response.status}`);
     }
@@ -78,19 +99,38 @@ class SSE {
       throw new Error('No reader');
     }
     let isEnd = false;
-    let allMessage = '';
     while (!isEnd) {
       const { done, value } = await reader.read();
       if (done) {
         isEnd = true;
-        requestOptions?.onEnd?.(allMessage);
+        decoder.decode(new Uint8Array(0), { stream: false });
+        requestOptions?.onEnd?.();
+        reader.releaseLock?.();
         break;
       }
       const content = decoder.decode(value, { stream: true });
       requestOptions?.onMessage?.(content);
-      allMessage += content;
     }
   }
 }
 
+function safeJoinUrl(baseUrl: string | undefined, url: string): string {
+  if (!baseUrl) {
+    return url; // 没有 baseUrl,直接返回 url
+  }
+
+  // 如果 url 本身就是绝对地址,直接返回
+  if (/^https?:\/\//i.test(url)) {
+    return url;
+  }
+
+  // 如果 baseUrl 是完整 URL,就用 new URL
+  if (/^https?:\/\//i.test(baseUrl)) {
+    return new URL(url, baseUrl).toString();
+  }
+
+  // 否则,当作路径拼接
+  return `${baseUrl.replace(/\/+$/, '')}/${url.replace(/^\/+/, '')}`;
+}
+
 export { SSE };

+ 1 - 1
packages/effects/request/src/request-client/types.ts

@@ -46,7 +46,7 @@ type RequestClientOptions = CreateAxiosDefaults & ExtendOptions;
  */
 interface SseRequestOptions extends RequestInit {
   onMessage?: (message: string) => void;
-  onEnd?: (message: string) => void;
+  onEnd?: () => void;
 }
 
 interface RequestInterceptorConfig {