Integrating Spring Boot with ERNIE Bot (Wenxin Yiyan): non-streaming and streaming responses (front end and back end)

A non-streaming response means you wait until Baidu has generated the whole answer, which is then returned to you in one piece. A streaming response means Baidu returns the answer while it is still being generated. This is how most of us experience ChatGPT: when it answers a question, the text appears word by word. Both modes have their own use cases. If you do not need long answers (you can limit the length by writing a suitable prompt), or you can tolerate a long wait, a non-streaming response is perfectly fine.

But things change if the network request itself imposes constraints. For example, when the front end is written with Uniapp, uni.uploadFile has a default timeout of 10 s; it seems this timeout cannot be modified (at least, I did not manage to change it). That is not the key point, though. When a connection is established and the client has not received a message from the server within its timeout period, it gives up on the request. Even if the answer is ready only a few tenths of a second after the timeout, the client will still refuse to receive it. In that situation a streaming response is the only practical choice.

In this article the streamed response is filtered in the Java layer; it might be even better to forward the stream directly to the front end and process it there. Most implementations you will find use SSE (Server-Sent Events) to keep the conversation open. Because Uniapp does not support SSE, I maintain the connection with WebSocket instead; the idea is much the same.

Dependency introduction:

 <!-- Spring Boot WebSocket starter: used by the WebSocket endpoint in the back-end code below -->
 <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-websocket</artifactId>
        </dependency>

        <!-- OkHttp: HTTP client used by StreamChat to call the Baidu endpoints -->
        <dependency>
            <groupId>com.squareup.okhttp3</groupId>
            <artifactId>okhttp</artifactId>
            <version>4.9.3</version>
        </dependency>

Front-end part:

 //Disconnect and reconnect
            reconnect() {
                if(this.ohHideFlag)
                if (!this.is_open_socket) {
                    this.reconnectTimeOut = setTimeout(() => {
                        this.connectSocketInit();
                    }, 3000)
                }
            },
            connectSocketInit() {
                let token = getToken()
                this.socketTask = uni.connectSocket({
                    //in the case ofhttpthen usews,in the case ofhttpsthen usewss,Mini programs need to be recorded on the public platform
                    url: 'wss://' + this.socketUrl + '/websocket/' + token,
                    success: () => {
                        console.log("Preparing to createwebsocketmiddle...");
                        // Return instance
                        return this.socketTask
                    },
                });
                this.socketTask.onOpen((res) => {
                    console.log("WebSocketThe connection is normal!");
                    this.is_open_socket = true;
                    this.socketTask.onMessage((res) => {
                        if (result == "") {
                            return;
                            console.log("Answer completed")
                        }
                        let jsonString = res.data
                        const dataPrefix = "data: ";
                        if (jsonString.startsWith(dataPrefix)) {
                            jsonString = jsonString.substring(dataPrefix.length);
                        }

                        // parseJSONstring
                        const jsonObject = JSON.parse(jsonString);

                        // ObtainresultAttributes
                        const result = jsonObject.result;
                        console.log(result);
                        this.tempItem.content += result
                        this.scrollToBottom();
                    });
                })
                this.socketTask.onClose(() => {
                    console.log("has been closed")
                    this.is_open_socket = false;
                    this.reconnect();
                })
            },

Backend code:

package com.farm.controller;

import com.farm.chat.StreamChat;
import lombok.extern.slf4j.Slf4j;
import okhttp3.ResponseBody;
import org.json.JSONException;
import org.json.JSONObject;
import org.springframework.stereotype.Component;

import javax.websocket.*;
import javax.websocket.server.PathParam;
import javax.websocket.server.ServerEndpoint;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.nio.charset.StandardCharsets;
import java.util.Objects;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CopyOnWriteArraySet;

@Slf4j
@Component
@ServerEndpoint("/websocket/{target}") //createwsrequest path。
public class WebsocketServerEndpoint {
    private Session session;
    private String target;
    //Support continuous streaming push
    private InputStream inputStream;
    private final static CopyOnWriteArraySet<WebsocketServerEndpoint> websockets = new CopyOnWriteArraySet<>();

    @OnOpen
    public void onOpen(Session session , @PathParam("target") String target){
        this.session = session;
        this.target = target;
        websockets.add(this);
        log.info("websocket connect server success , target is {},total is {}",target,websockets.size());
    }

//This method will be triggered when the client actively contacts
    @OnMessage
    public void onMessage(String message) throws IOException, JSONException {
        log.info("message is {}",message);
        JSONObject jsonObject = new JSONObject(message);
        String user = (String)jsonObject.get("user");
        String question = (String)jsonObject.get("message");

        StreamChat streamChat = new StreamChat();
        ResponseBody body = streamChat.getAnswerStream(question);
        InputStream inputStream = body.byteStream();

        sendMessageSync(user,inputStream);
    }

    @OnClose
    public void onClose(){
        log.info("connection has been closed ,target is {},total is {}" ,this.target, websockets.size());
        this.destroy();
    }

    @OnError
    public void onError(Throwable throwable){
        this.destroy();
        log.info("websocket connect error , target is {} ,total is {}, error is {}",this.target ,websockets.size(),throwable.getMessage());
    }

    /**
     * Push messages based on target identity
     * @param target
     * @param message
     * @throws IOException
     */
    public void sendMessageOnce(String target, String message) throws IOException {
        this.sendMessage(target,message,false,null);
    }

    /**
     * stream Synchronous log output,passwebsocketPush to front desk。
     * @param target
     * @param is
     * @throws IOException
     */
    private void sendMessageSync(String target, InputStream is) throws IOException {
        WebsocketServerEndpoint websocket = getWebsocket(target);
        if (Objects.isNull(websocket)) {
            throw new RuntimeException("The websocket does not exist or has been closed.");
        }
        if (Objects.isNull(is)) {
            throw new RuntimeException("InputStream cannot be null.");
        } else {
            websocket.inputStream = is;
            CompletableFuture.runAsync(websocket::sendMessageWithInputSteam);
        }
    }


    /**
     * Send message.
     * @param target passtargetObtain{@link WebsocketServerEndpoint}.
     * @param message message
     * @param continuous pass or notinputStreamContinuously push messages。
     * @param is input stream
     * @throws IOException
     */
    private void sendMessage(String target , String message ,Boolean continuous , InputStream is) throws IOException {
        WebsocketServerEndpoint websocket = getWebsocket(target);
        if(Objects.isNull(websocket)){
            throw new RuntimeException("The websocket does not exists or has been closed.");
        }
        if(continuous){
            if(Objects.isNull(is)){
                throw new RuntimeException("InputStream can not be null when continuous is true.");
            }else{
                websocket.inputStream = is;
                CompletableFuture.runAsync(websocket::sendMessageWithInputSteam);
            }
        }else{
            websocket.session.getBasicRemote().sendText(message);
        }
    }

    /**
     * passinputStream Continuously push messages。
     * Supporting documents、information、Log etc.。
     */
    private void sendMessageWithInputSteam() {
        String message;
        BufferedReader bufferedReader = new BufferedReader(new InputStreamReader(this.inputStream));
        try {
            while ((message = bufferedReader.readLine()) != null) {
                if(message.equals(""))
                    continue;
                if (websockets.contains(this)) {
                    System.out.println(message);
                    this.session.getBasicRemote().sendText(message);
                }
            }
        } catch (IOException e) {
            log.warn("SendMessage failed {}", e.getMessage());
        } finally {
            this.closeInputStream();
        }
    }

    /**
     * Get corresponding according to the target{@link WebsocketServerEndpoint}。
     * @param target Agreed subject
     * @return WebsocketServerEndpoint
     */
    private WebsocketServerEndpoint getWebsocket(String target){
        WebsocketServerEndpoint websocket = null;
        for (WebsocketServerEndpoint ws : websockets) {
            if (target.equals(ws.target)) {
                websocket = ws;
            }
        }
        return websocket;
    }

    private void closeInputStream(){
        if(Objects.nonNull(inputStream)){
            try {
                inputStream.close();
            } catch (Exception e) {
                log.warn("websocket close failed {}",e.getMessage());
            }
        }
    }

    private void destroy(){
        websockets.remove(this);
        this.closeInputStream();
    }
} 

StreamChat

package com.farm.chat;

import com.alibaba.fastjson.JSON;
import com.alibaba.fastjson.JSONObject;
import lombok.extern.slf4j.Slf4j;
import okhttp3.*;
import org.springframework.web.bind.annotation.GetMapping;

import java.io.IOException;
import java.io.InputStream;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.TimeUnit;


@Slf4j
public class StreamChat {

    //historical dialogue,need to followuser,assistant
    List<Map<String,String>> messages = new ArrayList<>();

    private final String ACCESS_TOKEN_URI = "https://aip.baidubce.com/oauth/2.0/token";
    private final String CHAT_URI = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-8k-preview";
//Just fill in your own identification code here
    private String apiKey = " ";
    private String secretKey = " ";
    private int responseTimeOut = 5000;
    private OkHttpClient client ;
    private String accessToken = "";


    public boolean getAccessToken(){
        this.client = new OkHttpClient.Builder().readTimeout(responseTimeOut, TimeUnit.SECONDS).build();
        MediaType mediaType = MediaType.parse("application/json");
        RequestBody body = RequestBody.create(mediaType, "");
        //Create a request
        Request request = new Request.Builder()
                .url(ACCESS_TOKEN_URI+"?client_id=" + apiKey + "&client_secret=" + secretKey + "&grant_type=client_credentials")
                .method("POST",body)
                .addHeader("Content-Type", "application/json")
                .build();
        try {
            //Make a request using the browser object
            Response response = client.newCall(request).execute();
            //Can only be executed onceresponse.body().string()。The next execution will throw a stream closing exception.,Therefore, an object is needed to store the returned results
            String responseMessage = response.body().string();
            log.debug("ObtainaccessTokensuccess");
            JSONObject jsonObject = JSON.parseObject(responseMessage);
            accessToken = (String) jsonObject.get("access_token");
            return true;
        } catch (IOException e) {
            e.printStackTrace();
        }
        return false;
    }
    public ResponseBody getAnswerStream(String question){
        getAccessToken();
        OkHttpClient client = new OkHttpClient();

        HashMap<String, String> user = new HashMap<>();
        user.put("role","user");
        user.put("content",question);
        messages.add(user);
        String requestJson = constructRequestJson(1,0.95,0.8,1.0,true,messages);
        RequestBody body = RequestBody.create(MediaType.parse("application/json"), requestJson);
        Request request = new Request.Builder()
                .url(CHAT_URI + "?access_token="+accessToken)
                .method("POST", body)
                .addHeader("Content-Type", "application/json")
                .build();

        StringBuilder answer = new StringBuilder();
        // Initiate an asynchronous request
        try {
            Response response = client.newCall(request).execute();
            // Check if the response is successful
            if (response.isSuccessful()) {
                // Get response stream
                return response.body();
            }
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
        return null;
    }


    /**
     * Construct request parameters for the request
     * @param userId
     * @param temperature
     * @param topP
     * @param penaltyScore
     * @param messages
     * @return
     */
    public String constructRequestJson(Integer userId,
                                       Double temperature,
                                       Double topP,
                                       Double penaltyScore,
                                       boolean stream,
                                       List<Map<String, String>> messages) {
        Map<String,Object> request = new HashMap<>();
        request.put("user_id",userId.toString());
        request.put("temperature",temperature);
        request.put("top_p",topP);
        request.put("penalty_score",penaltyScore);
        request.put("stream",stream);
        request.put("messages",messages);
        System.out.println(JSON.toJSONString(request));
        return JSON.toJSONString(request);
    }

}