Bladeren bron

feat: add SSE support to request-client

zhongming4762 9 maanden geleden
bovenliggende
commit
eb4f1f8164

+ 131 - 0
packages/effects/request/src/request-client/modules/sse.test.ts

@@ -0,0 +1,131 @@
+import type { RequestClient } from '../request-client';
+
+import { beforeEach, describe, expect, it, vi } from 'vitest';
+
+import { SSE } from './sse';
+
+// 模拟 TextDecoder
+const OriginalTextDecoder = globalThis.TextDecoder;
+
+beforeEach(() => {
+  vi.stubGlobal(
+    'TextDecoder',
+    class {
+      private decoder = new OriginalTextDecoder();
+      decode(value: Uint8Array, opts?: any) {
+        return this.decoder.decode(value, opts);
+      }
+    },
+  );
+});
+
+// 创建 fetch mock
+const createFetchMock = (chunks: string[], ok = true) => {
+  const encoder = new TextEncoder();
+  let index = 0;
+  return vi.fn().mockResolvedValue({
+    ok,
+    status: ok ? 200 : 500,
+    body: {
+      getReader: () => ({
+        read: async () => {
+          if (index < chunks.length) {
+            return { done: false, value: encoder.encode(chunks[index++]) };
+          }
+          return { done: true, value: undefined };
+        },
+      }),
+    },
+  });
+};
+
+describe('sSE', () => {
+  let client: RequestClient;
+  let sse: SSE;
+
+  beforeEach(() => {
+    vi.restoreAllMocks();
+    client = {
+      getBaseUrl: () => 'http://localhost',
+      instance: {
+        interceptors: {
+          request: {
+            handlers: [],
+          },
+        },
+      },
+    } as unknown as RequestClient;
+    sse = new SSE(client);
+  });
+
+  it('should call requestSSE when postSSE is used', async () => {
+    const spy = vi.spyOn(sse, 'requestSSE').mockResolvedValue(undefined);
+    await sse.postSSE('/test', { foo: 'bar' }, { headers: { a: '1' } });
+    expect(spy).toHaveBeenCalledWith(
+      '/test',
+      { foo: 'bar' },
+      {
+        headers: { a: '1' },
+        method: 'POST',
+      },
+    );
+  });
+
+  it('should throw error if fetch response not ok', async () => {
+    vi.stubGlobal('fetch', createFetchMock([], false));
+    await expect(sse.requestSSE('/bad')).rejects.toThrow(
+      'HTTP error! status: 500',
+    );
+  });
+
+  it('should trigger onMessage and onEnd callbacks', async () => {
+    const messages: string[] = [];
+    const onMessage = vi.fn((msg: string) => messages.push(msg));
+    const onEnd = vi.fn();
+
+    vi.stubGlobal('fetch', createFetchMock(['hello', ' world']));
+
+    await sse.requestSSE('/sse', undefined, { onMessage, onEnd });
+
+    expect(onMessage).toHaveBeenCalledTimes(2);
+    expect(messages.join('')).toBe('hello world');
+    expect(onEnd).toHaveBeenCalledWith('hello world');
+  });
+
+  it('should apply request interceptors', async () => {
+    const interceptor = vi.fn(async (config) => {
+      config.headers['x-test'] = 'intercepted';
+      return config;
+    });
+    (client.instance.interceptors.request as any).handlers.push({
+      fulfilled: interceptor,
+    });
+
+    vi.stubGlobal('fetch', createFetchMock(['data']));
+
+    // 创建 fetch mock,并挂到全局
+    const fetchMock = createFetchMock(['data']);
+    vi.stubGlobal('fetch', fetchMock);
+    await sse.requestSSE('/sse', undefined, {});
+
+    expect(interceptor).toHaveBeenCalled();
+    expect(fetchMock).toHaveBeenCalledWith(
+      'http://localhost//sse',
+      expect.objectContaining({
+        headers: expect.objectContaining({ 'x-test': 'intercepted' }),
+      }),
+    );
+  });
+
+  it('should throw error when no reader', async () => {
+    vi.stubGlobal(
+      'fetch',
+      vi.fn().mockResolvedValue({
+        ok: true,
+        status: 200,
+        body: null,
+      }),
+    );
+    await expect(sse.requestSSE('/sse')).rejects.toThrow('No reader');
+  });
+});

+ 96 - 0
packages/effects/request/src/request-client/modules/sse.ts

@@ -0,0 +1,96 @@
+import type { AxiosRequestHeaders, InternalAxiosRequestConfig } from 'axios';
+
+import type { RequestClient } from '../request-client';
+import type { SseRequestOptions } from '../types';
+
+/**
+ * SSE模块
+ */
+class SSE {
+  private client: RequestClient;
+
+  constructor(client: RequestClient) {
+    this.client = client;
+  }
+
+  public async postSSE(
+    url: string,
+    data?: any,
+    requestOptions?: SseRequestOptions,
+  ) {
+    return this.requestSSE(url, data, {
+      ...requestOptions,
+      method: 'POST',
+    });
+  }
+
+  /**
+   * SSE请求方法
+   * @param url - 请求URL
+   * @param data - 请求数据
+   * @param requestOptions - SSE请求选项
+   */
+  public async requestSSE(
+    url: string,
+    data?: any,
+    requestOptions?: SseRequestOptions,
+  ) {
+    const baseUrl = this.client.getBaseUrl() || '';
+    const hasUrlSplit = baseUrl.endsWith('/') && url.startsWith('/');
+
+    const axiosConfig: InternalAxiosRequestConfig = {
+      headers: {} as AxiosRequestHeaders,
+    };
+    const requestInterceptors = this.client.instance.interceptors
+      .request as any;
+    if (
+      requestInterceptors.handlers &&
+      requestInterceptors.handlers.length > 0
+    ) {
+      for (const handler of requestInterceptors.handlers) {
+        if (handler.fulfilled) {
+          await handler.fulfilled(axiosConfig);
+        }
+      }
+    }
+
+    const requestInit: RequestInit = {
+      ...requestOptions,
+      body: data,
+      headers: {
+        ...(axiosConfig.headers as Record<string, string>),
+        ...requestOptions?.headers,
+      },
+    };
+
+    const response = await fetch(
+      `${baseUrl}${hasUrlSplit ? '' : '/'}${url}`,
+      requestInit,
+    );
+    if (!response.ok) {
+      throw new Error(`HTTP error! status: ${response.status}`);
+    }
+
+    const reader = response.body?.getReader();
+    const decoder = new TextDecoder();
+
+    if (!reader) {
+      throw new Error('No reader');
+    }
+    let isEnd = false;
+    let allMessage = '';
+    while (!isEnd) {
+      const { done, value } = await reader.read();
+      if (done) {
+        isEnd = true;
+        requestOptions?.onEnd?.(allMessage);
+        break;
+      }
+      const content = decoder.decode(value, { stream: true });
+      requestOptions?.onMessage?.(content);
+      allMessage += content;
+    }
+  }
+}
+
+export { SSE };

+ 15 - 1
packages/effects/request/src/request-client/request-client.ts

@@ -9,6 +9,7 @@ import qs from 'qs';
 
 import { FileDownloader } from './modules/downloader';
 import { InterceptorManager } from './modules/interceptor';
+import { SSE } from './modules/sse';
 import { FileUploader } from './modules/uploader';
 
 function getParamsSerializer(
@@ -41,12 +42,14 @@ class RequestClient {
   public addResponseInterceptor: InterceptorManager['addResponseInterceptor'];
   public download: FileDownloader['download'];
 
+  public readonly instance: AxiosInstance;
   // 是否正在刷新token
   public isRefreshing = false;
+  public postSSE: SSE['postSSE'];
   // 刷新token队列
   public refreshTokenQueue: ((token: string) => void)[] = [];
+  public requestSSE: SSE['requestSSE'];
   public upload: FileUploader['upload'];
-  private readonly instance: AxiosInstance;
 
   /**
    * 构造函数,用于创建Axios实例
@@ -84,6 +87,10 @@ class RequestClient {
     // 实例化文件下载器
     const fileDownloader = new FileDownloader(this);
     this.download = fileDownloader.download.bind(fileDownloader);
+    // 实例化SSE模块
+    const sse = new SSE(this);
+    this.postSSE = sse.postSSE.bind(sse);
+    this.requestSSE = sse.requestSSE.bind(sse);
   }
 
   /**
@@ -103,6 +110,13 @@ class RequestClient {
     return this.request<T>(url, { ...config, method: 'GET' });
   }
 
+  /**
+   * 获取基础URL
+   */
+  public getBaseUrl() {
+    return this.instance.defaults.baseURL;
+  }
+
   /**
    * POST请求方法
    */

+ 9 - 0
packages/effects/request/src/request-client/types.ts

@@ -41,6 +41,14 @@ type RequestContentType =
 
 type RequestClientOptions = CreateAxiosDefaults & ExtendOptions;
 
+/**
+ * SSE 请求选项
+ */
+interface SseRequestOptions extends RequestInit {
+  onMessage?: (message: string) => void;
+  onEnd?: (message: string) => void;
+}
+
 interface RequestInterceptorConfig {
   fulfilled?: (
     config: ExtendOptions & InternalAxiosRequestConfig,
@@ -78,4 +86,5 @@ export type {
   RequestInterceptorConfig,
   RequestResponse,
   ResponseInterceptorConfig,
+  SseRequestOptions,
 };