本文介绍: 前文中,我们开发的 Client 已经能够正常地与 OpenAI 交互并调用 GPT 的 API:通过 API 将 message 请求发送给 GPT,并正常接收到 GPT 的返回。在浏览 GPT 的 API 文档时,我们发现它支持流式访问,因此可以再开发一个 Stream Client,用来流式接收 GPT 的响应。流式 Client 在很多场景下也是非常有必要的。

Java版GPT的StreamClient

1 )核心代码结构设计

2 )相关程序如下

AbstractStreamListener.java

package com.xxx.gpt.client.listener;

import cn.hutool.core.util.StrUtil;
import com.alibaba.fastjson.JSON;
import com.xxx.gpt.client.entity.ChatChoice;
import com.xxx.gpt.client.entity.ChatCompletionResponse;
import com.xxx.gpt.client.entity.Message;
import lombok.Getter;
import lombok.Setter;
import lombok.SneakyThrows;
import lombok.extern.slf4j.Slf4j;
import okhttp3.Response;
import okhttp3.sse.EventSource;
import okhttp3.sse.EventSourceListener;

import java.util.List;
import java.util.Objects;
import java.util.function.Consumer;

@Slf4j
public abstract class AbstractStreamListener extends EventSourceListener {
    protected String lastMessage = "";

    /**
     * Called when all new message are received.
     *
     * @param message the new message
     */
    @Setter
    @Getter
    protected Consumer<String> onComplate = s -> {};

    /**
     * Called when a new message is received.
     * 收到消息 单个字
     *
     * @param message the new message
     */
    public abstract void onMsg(String message);
    /**
     * Called when an error occurs.
     * 出错调用
     *
     * @param throwable the throwable that caused the error
     * @param response  the response associated with the error, if any
     */
    public abstract void onError(Throwable throwable, String response);

    @Override
    public void onOpen(EventSource eventSource, Response response) {
        // do nothing
    }

    @Override
    public void onClosed(EventSource eventSource) {
        // do nothing
    }

    @Override
    public void onEvent(EventSource eventSource, String id, String type, String data) {
        if (data.equals("[DONE]")) {
            onComplate.accept(lastMessage);
            return;
        }
        // 将数据序列化为 GPT的 response
        ChatCompletionResponse response = JSON.parseObject(data, ChatCompletionResponse.class);
        // 获取GPT的返回读取Json
        List<ChatChoice> choices = response.getChoices();
        // 为空则 return
        if (choices == null || choices.isEmpty()) {
            return;
        }
        // 获取流式信息
        Message delta = choices.get(0).getDelta();
        String text = delta.getContent();
        if (text != null) {
            lastMessage += text;
            onMsg(text);
        }
    }

    @SneakyThrows
    @Override
    public void onFailure(EventSource eventSource, Throwable throwable, Response response) {
        try {
            log.error("Stream connection error: {}", throwable);
            String responseText = "";
            if (Objects.nonNull(response)) {
                responseText = response.body().string();
            }
            log.error("response:{}", responseText);
            String forbiddenText = "Your access was terminated due to violation of our policies";
            if (StrUtil.contains(responseText, forbiddenText)) {
                log.error("Chat session has been terminated due to policy violation");
                log.error("检测到号被封了");
            }
            String overloadedText = "That model is currently overloaded with other requests.";
            if (StrUtil.contains(responseText, overloadedText)) {
                log.error("检测官方超载了,赶紧优化你的代码,做重试吧");
            }
            this.onError(throwable, responseText);
        } catch (Exception e) {
            log.warn("onFailure error:{}", e);
            // do nothing
        } finally {
            eventSource.cancel();
        }
    }
}

ConsoleStreamListener.java

package com.xxx.gpt.client.listener;

import lombok.extern.slf4j.Slf4j;

/**
 * Stream listener that echoes every received fragment to standard output.
 */
@Slf4j
public class ConsoleStreamListener extends AbstractStreamListener {

    /**
     * Prints the fragment without a trailing line break, so consecutive fragments
     * appear as one continuous text.
     *
     * @param message the newly received fragment
     */
    @Override
    public void onMsg(String message) {
        System.out.print(message);
    }

    /**
     * Intentionally a no-op: this listener takes no action of its own on errors.
     *
     * @param throwable the throwable that caused the error
     * @param response  the error response text
     */
    @Override
    public void onError(Throwable throwable, String response) {
        // nothing to do
    }
}

ChatGPTStreamClient.java

package com.xxx.gpt.client;

import cn.hutool.core.util.RandomUtil;
import cn.hutool.http.ContentType;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.xxx.gpt.client.entity.ChatCompletion;
import com.xxx.gpt.client.entity.Message;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import okhttp3.MediaType;
import okhttp3.OkHttpClient;
import okhttp3.Request;
import okhttp3.RequestBody;
import okhttp3.sse.EventSource;
import okhttp3.sse.EventSourceListener;
import okhttp3.sse.EventSources;

import java.net.Proxy;
import java.util.List;
import java.util.Objects;
import java.util.concurrent.TimeUnit;

@Slf4j
@Data
@Builder
@NoArgsConstructor
@AllArgsConstructor
public class ChatGPTStreamClient {
    private String apiKey;
    private List<String> apiKeyList;
    private OkHttpClient okHttpClient;
    /**
     * 连接超时
     */
    @Builder.Default
    private long timeout = 90;

    /**
     * 网络代理
     */
    @Builder.Default
    private Proxy proxy = Proxy.NO_PROXY;
    /**
     * 反向代理
     */
    @Builder.Default
    private String apiHost = ChatApi.CHAT_GPT_API_HOST;

    /**
     * 初始化
     */
    public ChatGPTStreamClient init() {
        OkHttpClient.Builder client = new OkHttpClient.Builder();
        client.connectTimeout(timeout, TimeUnit.SECONDS);
        client.writeTimeout(timeout, TimeUnit.SECONDS);
        client.readTimeout(timeout, TimeUnit.SECONDS);
        if (Objects.nonNull(proxy)) {
            client.proxy(proxy);
        }
        okHttpClient = client.build();
        return this;
    }

    /**
     * 流式输出
     */
    public void streamChatCompletion(ChatCompletion chatCompletion,
                                     EventSourceListener eventSourceListener) {
        chatCompletion.setStream(true);
        try {
            EventSource.Factory factory = EventSources.createFactory(okHttpClient);
            ObjectMapper mapper = new ObjectMapper();
            String requestBody = mapper.writeValueAsString(chatCompletion);
            String key = apiKey;
            if (apiKeyList != null &amp;& !apiKeyList.isEmpty()) {
                key = RandomUtil.randomEle(apiKeyList);
            }
            Request request = new Request.Builder()
                    .url(apiHost + "v1/chat/completions")
                    .post(RequestBody.create(MediaType.parse(ContentType.JSON.getValue()),
                            requestBody))
                    .header("Authorization", "Bearer " + key)
                    .build();
            factory.newEventSource(request, eventSourceListener);
        } catch (Exception e) {
            log.error("请求出错:{}", e);
        }
    }

    /**
     * 流式输出
     */
    public void streamChatCompletion(List<Message> messages,
                                     EventSourceListener eventSourceListener) {
        ChatCompletion chatCompletion = ChatCompletion.builder()
                .messages(messages)
                .stream(true)
                .build();
        streamChatCompletion(chatCompletion, eventSourceListener);
    }
}

再添加一个测试类 StreamClientTest.java

package com.xxx.gpt.client.test;

import com.xxx.gpt.client.ChatGPTStreamClient;
import com.xxx.gpt.client.entity.ChatCompletion;
import com.xxx.gpt.client.entity.Message;
import com.xxx.gpt.client.listener.ConsoleStreamListener;
import com.xxx.gpt.client.util.Proxys;
import org.junit.Before;
import org.junit.Test;

import java.net.Proxy;
import java.util.Arrays;
import java.util.concurrent.CountDownLatch;

public class StreamClientTest {
    private ChatGPTStreamClient chatGPTStreamClient;

    @Before
    public void before() {
        Proxy proxy = Proxys.http("127.0.0.1", 7890);
        chatGPTStreamClient = ChatGPTStreamClient.builder()
                .apiKey("sk-6kchadsfsfkc3aIs66ct") // 填入自己的 key
                .proxy(proxy)
                .timeout(600)
                .apiHost("https://api.openai.com/")
                .build()
                .init();
    }
    @Test
    public void chatCompletions() {
        ConsoleStreamListener listener = new ConsoleStreamListener();
        Message message = Message.of("写一段七言绝句诗");
        ChatCompletion chatCompletion = ChatCompletion.builder()
                .messages(Arrays.asList(message))
                .build();
        chatGPTStreamClient.streamChatCompletion(chatCompletion, listener);
        try {
            Thread.sleep(10000);
        } catch (InterruptedException e) {
            throw new RuntimeException(e);
        }
    }
}

原文地址:https://blog.csdn.net/Tyro_java/article/details/134781021

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任

如若转载,请注明出处:http://www.7code.cn/show_42170.html

如若内容造成侵权/违法违规/事实不符,请联系代码007邮箱suwngjj01@126.com进行投诉反馈,一经查实,立即删除

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注