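1. Preface
In the previous two articles we built basic instant messaging (IM) functionality. This article goes one step further and connects the IM module to the AI service. The user asks a question, the model answers it, and the answer is streamed back and displayed in the user interface.
2. Terminology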
2.1. Spring Boot
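An open-source framework for quickly building Java applications on top of the Spring Framework. It simplifies the bootstrapping and configuration of Spring applications, so developers can concentrate on business logic.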
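2.2. Read timeout
The maximum time a receive operation may wait for data during network communication. If a request has been sent and no response data arrives from the server within this time, a read timeout error is raised. The read timeout limits how long the client blocks while waiting for the server. In OkHttp it applies to each individual read operation, not to the response as a whole. A streamed answer can therefore take longer than the read timeout in total, provided new data keeps arriving within the limit.
2.3. Write timeout
The maximum time a send operation may take during network communication. If the request data cannot be written completely within this time, a write timeout error is raised. The write timeout limits how long the client can stay blocked while sending. In OkHttp it also applies to each individual write operation.
2.4. Connection timeout
The maximum time the client may spend establishing a connection to the server. If the connection cannot be established within this time, a connection timeout error is raised, so the client does not wait indefinitely for a server that cannot be reached.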
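The following JDK-only sketch makes the difference between the connection timeout and the read timeout concrete. It does not use OkHttp, and the class and method names are hypothetical. It starts a local server that accepts TCP connections but never writes anything. It connects with a connection timeout, then blocks on a read with a read timeout (`Socket.setSoTimeout`), which ends with a `SocketTimeoutException`. Plain `java.net.Socket` has no write-timeout setting, so the write timeout is not shown here.

```java
import java.io.IOException;
import java.io.InputStream;
import java.io.UncheckedIOException;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.net.ServerSocket;
import java.net.Socket;
import java.net.SocketTimeoutException;

final class TimeoutDemo {

    /**
     * Connects to a local server that never sends data and reports whether the read timed out.
     *
     * @param connectTimeoutMs max time allowed to establish the TCP connection
     * @param readTimeoutMs    max time a single read() may block waiting for data
     * @return true if the read ended with a SocketTimeoutException
     */
    static boolean readTimesOut(int connectTimeoutMs, int readTimeoutMs) {
        InetAddress loopback = InetAddress.getLoopbackAddress();
        // The OS completes the TCP handshake into the backlog even though accept() is never called
        try (ServerSocket server = new ServerSocket(0, 50, loopback);
             Socket client = new Socket()) {
            // Connection timeout: only limits the connect phase
            client.connect(new InetSocketAddress(loopback, server.getLocalPort()), connectTimeoutMs);
            // Read timeout: limits how long each read() may wait for the next bytes
            client.setSoTimeout(readTimeoutMs);
            InputStream in = client.getInputStream();
            try {
                in.read(); // blocks: the server side never writes anything
                return false;
            } catch (SocketTimeoutException e) {
                return true;
            }
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }
}
```

The connect succeeds almost immediately because the server is local. The 200 ms read limit is what ends the call in `TimeoutDemo.readTimesOut(1000, 200)`.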
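3. Prerequisites
3.1. The basic WebSocket flow works end to end (see 开源模型应用落地-业务整合篇(二), Part 2 of this series)
3.2. At least one AI service node has been deployed
4. Technical implementation
# Connecting the IM module to the AI service
4.1. Add a utility class for calling the AI service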
import com.alibaba.fastjson.JSON;
import lombok.extern.slf4j.Slf4j;
import okhttp3.MediaType;
import okhttp3.Request;
import okhttp3.RequestBody;
import okhttp3.Response;
import okhttp3.ResponseBody;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;

import java.io.InputStreamReader;
import java.io.Reader;
import java.nio.charset.StandardCharsets;
import java.security.MessageDigest;

@Slf4j
@Component
public class AIChatUtils {

    @Autowired
    private AIConfig aiConfig;

    private Request buildRequest(Long userId, String prompt) throws Exception {
        // Create the JSON request body
        MediaType mediaType = MediaType.parse("application/json");
        RequestBody requestBody = RequestBody.create(mediaType, prompt);
        return buildHeader(userId, new Request.Builder().post(requestBody))
                .url(aiConfig.getUrl())
                .build();
    }

    private Request.Builder buildHeader(Long userId, Request.Builder builder) throws Exception {
        return builder
                .addHeader("Content-Type", "application/json")
                .addHeader("userId", String.valueOf(userId))
                .addHeader("secret", generateSecret(userId));
    }

    /**
     * Generate the request secret: hex(SHA-256(serverKey + userId + serverKey))
     *
     * @param userId user ID
     * @return lowercase hex digest
     */
    private String generateSecret(Long userId) throws Exception {
        String key = aiConfig.getServerKey();
        String content = key + userId + key;
        MessageDigest digest = MessageDigest.getInstance("SHA-256");
        byte[] hash = digest.digest(content.getBytes(StandardCharsets.UTF_8));
        StringBuilder hexString = new StringBuilder();
        for (byte b : hash) {
            String hex = Integer.toHexString(0xff & b);
            if (hex.length() == 1) {
                hexString.append('0');
            }
            hexString.append(hex);
        }
        return hexString.toString();
    }

    public String chatStream(ApiReqMessage apiReqMessage) throws Exception {
        // Build the request parameters
        String prompt = JSON.toJSONString(AIChatReqVO.init(apiReqMessage.getContents(), apiReqMessage.getHistory()));
        log.info("[AIChatUtils] calling AI chat, user ({}), prompt: {}", apiReqMessage.getUserId(), prompt);
        // Create the request object
        Request request = buildRequest(apiReqMessage.getUserId(), prompt);
        // Execute on the shared, pooled OkHttpClient; try-with-resources closes the response and its stream
        try (Response response = OkHttpUtils.getInstance(aiConfig).getOkHttpClient().newCall(request).execute()) {
            ResponseBody body = response.body();
            if (response.code() != 200 || body == null) {
                log.warn("AI service returned an error, code: {}, body: {}", response.code(), body == null ? null : body.string());
                return null;
            }
            StringBuilder resultBuff = new StringBuilder();
            // Decode through a Reader so a multi-byte UTF-8 character split across two reads is not garbled
            try (Reader reader = new InputStreamReader(body.byteStream(), StandardCharsets.UTF_8)) {
                char[] buffer = new char[1024];
                int len;
                while ((len = reader.read(buffer)) != -1) {
                    // Data read in this round
                    String result = new String(buffer, 0, len);
                    resultBuff.append(result);
                    // Push each chunk to the user as soon as it arrives
                    AbstractBusinessLogicHandler.pushChatMessageForUser(apiReqMessage.getUserId(), result);
                }
            }
            return resultBuff.toString();
        } catch (Throwable e) {
            log.error("[AIChatUtils] message ({}) chatStream call to AI chat failed, error: {}", apiReqMessage.getMessageId(), e.getMessage(), e);
        }
        return null;
    }
}
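The AI service streams its answer, and each `read` returns an arbitrary slice of bytes. If every slice is decoded on its own with `new String(bytes, 0, len, UTF_8)`, any multi-byte UTF-8 character that straddles two reads is garbled. Each Chinese character takes 3 bytes, so this is a real risk here. An `InputStreamReader` avoids the problem because it carries the incomplete bytes over to the next read. The JDK-only sketch below compares both approaches on an in-memory stream that hands out at most 4 bytes per read. `ChunkDecodingDemo` and its methods are hypothetical names used for illustration.

```java
import java.io.ByteArrayInputStream;
import java.io.FilterInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.Reader;
import java.io.UncheckedIOException;
import java.nio.charset.StandardCharsets;

final class ChunkDecodingDemo {

    /** In-memory stream that returns at most maxPerRead bytes per read, like small network packets. */
    static InputStream chunkedStream(String text, int maxPerRead) {
        return new FilterInputStream(new ByteArrayInputStream(text.getBytes(StandardCharsets.UTF_8))) {
            @Override
            public int read(byte[] b, int off, int len) throws IOException {
                return super.read(b, off, Math.min(len, maxPerRead));
            }
        };
    }

    /** Decodes every chunk independently: a character split across reads becomes U+FFFD. */
    static String decodePerChunk(InputStream in, int bufferSize) {
        StringBuilder sb = new StringBuilder();
        byte[] buffer = new byte[bufferSize];
        try {
            int len;
            while ((len = in.read(buffer)) != -1) {
                sb.append(new String(buffer, 0, len, StandardCharsets.UTF_8));
            }
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
        return sb.toString();
    }

    /** Decodes through a Reader, which keeps an incomplete byte sequence until the rest arrives. */
    static String decodeWithReader(InputStream in, int bufferSize) {
        StringBuilder sb = new StringBuilder();
        char[] buffer = new char[bufferSize];
        try (Reader reader = new InputStreamReader(in, StandardCharsets.UTF_8)) {
            int len;
            while ((len = reader.read(buffer)) != -1) {
                sb.append(buffer, 0, len);
            }
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
        return sb.toString();
    }
}
```

With the text "你好世界" (12 bytes) and 4-byte reads, the per-chunk version produces replacement characters. The reader-based version returns the original text.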
4.2. Add a utility class for OkHttp calls
import lombok.Getter;
import okhttp3.ConnectionPool;
import okhttp3.OkHttpClient;

import java.util.concurrent.TimeUnit;

/**
 * HTTP connection-pool utility: holds one shared OkHttpClient
 */
public class OkHttpUtils {

    // volatile is required for double-checked locking to publish the instance safely
    private static volatile OkHttpUtils okHttpUtils;

    @Getter
    private final OkHttpClient okHttpClient;

    private OkHttpUtils(AIConfig aiConfig) {
        this.okHttpClient = new OkHttpClient.Builder()
                .readTimeout(aiConfig.getReadTimeout(), TimeUnit.SECONDS)
                .connectTimeout(aiConfig.getConnectionTimeout(), TimeUnit.SECONDS)
                .writeTimeout(aiConfig.getWriteTimeout(), TimeUnit.SECONDS)
                .connectionPool(new ConnectionPool(aiConfig.getKeepAliveConnections(), aiConfig.getKeepAliveDuration(), TimeUnit.SECONDS))
                .build();
    }

    public static OkHttpUtils getInstance(AIConfig aiConfig) {
        if (okHttpUtils == null) {
            synchronized (OkHttpUtils.class) {
                if (okHttpUtils == null) {
                    // Cache the instance so all calls share one client and one connection pool
                    okHttpUtils = new OkHttpUtils(aiConfig);
                }
            }
        }
        return okHttpUtils;
    }
}
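`getInstance` is meant to be a lazily created singleton, so that every call shares one `OkHttpClient` and one connection pool. The double-checked-locking form needs a `volatile` field, and it must store the new instance in that field before returning it. The JDK-only sketch below shows the pattern with a hypothetical `LazySingleton` class that stands in for `OkHttpUtils`. It counts constructor calls while several threads request the instance at the same time.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.atomic.AtomicInteger;

final class LazySingleton {

    // Counts how many times the constructor ran (demonstration only)
    static final AtomicInteger CONSTRUCTED = new AtomicInteger();

    // volatile makes the fully constructed instance visible to all threads
    private static volatile LazySingleton instance;

    private LazySingleton() {
        CONSTRUCTED.incrementAndGet();
    }

    static LazySingleton getInstance() {
        LazySingleton local = instance;
        if (local == null) {
            synchronized (LazySingleton.class) {
                local = instance;
                if (local == null) {
                    // Store the instance so later calls reuse it
                    instance = local = new LazySingleton();
                }
            }
        }
        return local;
    }

    /** Calls getInstance() from several threads and returns how many distinct instances were seen. */
    static long distinctInstances(int threads) {
        ExecutorService pool = Executors.newFixedThreadPool(threads);
        try {
            List<Callable<LazySingleton>> tasks = new ArrayList<>();
            for (int i = 0; i < threads; i++) {
                tasks.add(LazySingleton::getInstance);
            }
            List<LazySingleton> seen = new ArrayList<>();
            for (Future<LazySingleton> f : pool.invokeAll(tasks)) {
                seen.add(f.get());
            }
            return seen.stream().distinct().count();
        } catch (InterruptedException | ExecutionException e) {
            throw new IllegalStateException(e);
        } finally {
            pool.shutdown();
        }
    }
}
```

With this pattern, the configuration passed on the first call determines the client's settings. Arguments passed on later calls are ignored.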
4.3. Modify the business handler defined in Part 2
import com.alibaba.fastjson.JSON;
import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.http.websocketx.TextWebSocketFrame;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;

/**
 * @Description: handler that processes incoming messages
 */
@Slf4j
@ChannelHandler.Sharable
@Component
public class BusinessHandler extends AbstractBusinessLogicHandler<TextWebSocketFrame> {

    @Autowired
    private AIChatUtils aiChatUtils;

    @Override
    public void handlerAdded(ChannelHandlerContext ctx) throws Exception {
        String channelId = ctx.channel().id().asShortText();
        log.info("add client, channelId: {}", channelId);
    }

    @Override
    public void handlerRemoved(ChannelHandlerContext ctx) throws Exception {
        String channelId = ctx.channel().id().asShortText();
        log.info("remove client, channelId: {}", channelId);
    }

    @Override
    protected void channelRead0(ChannelHandlerContext channelHandlerContext, TextWebSocketFrame textWebSocketFrame) throws Exception {
        // Message sent by the client
        String content = textWebSocketFrame.text();
        log.info("Received message from client: {}", content);
        Long userIdForReq;
        String msgType = "";
        String contents = "";
        try {
            ApiReqMessage apiReqMessage = JSON.parseObject(content, ApiReqMessage.class);
            msgType = apiReqMessage.getMsgType();
            contents = apiReqMessage.getContents();
            userIdForReq = apiReqMessage.getUserId();
            // Register the user -> channel mapping; overwriting it after a reconnect keeps replies on the live channel
            addChannel(channelHandlerContext, userIdForReq);
            log.info("User ID: {}, message type: {}, content: {}", userIdForReq, msgType, contents);
            if (StringUtils.equals(msgType, String.valueOf(MsgType.CHAT.getCode()))) {
                // Fixed test reply from Part 2, now replaced by the AI call below:
                // ApiRespMessage apiRespMessage = ApiRespMessage.builder().code(String.valueOf(StatusCode.SUCCESS.getCode()))
                //         .respTime(String.valueOf(System.currentTimeMillis()))
                //         .contents("Test passed, glad to receive your message")
                //         .msgType(String.valueOf(MsgType.CHAT.getCode()))
                //         .build();
                // String response = JSON.toJSONString(apiRespMessage);
                // channelHandlerContext.writeAndFlush(new TextWebSocketFrame(response));

                // Forward the question to the AI service; answer chunks are pushed back to the user as they arrive.
                // chatStream blocks until the answer is complete, so this handler is best run on a business
                // thread pool rather than on the Netty I/O event loop.
                aiChatUtils.chatStream(apiReqMessage);
            } else {
                log.info("User ID: {}, unsupported message type: {}", userIdForReq, msgType);
            }
        } catch (Exception e) {
            log.warn("[BusinessHandler] request content: {}, exception: {}", content, e.getMessage(), e);
        }
    }
}
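PS:
1. The handler previously extended SimpleChannelInboundHandler<TextWebSocketFrame>. It now extends the custom AbstractBusinessLogicHandler<TextWebSocketFrame>.
2. After a user connects to the WebSocket server, the mapping between the user and the channel must be stored so that the AI's answer can later be pushed to the right connection. The globally unique userId is used as the key. See AbstractBusinessLogicHandler for details.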
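In practice the userId → channel map also needs cleanup. When a user disconnects and reconnects, the entry must point at the new channel, and removing the old channel must not delete the new entry. The following JDK-only sketch uses a hypothetical `ChannelRegistry<C>`, where `C` stands in for Netty's `ChannelHandlerContext`. `ConcurrentMap.remove(key, value)` removes the mapping atomically, and only if it still points at the given channel.

```java
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;

/** Hypothetical userId -> channel registry; C stands in for Netty's ChannelHandlerContext. */
final class ChannelRegistry<C> {

    private final ConcurrentMap<Long, C> userIdToChannel = new ConcurrentHashMap<>();

    /** Registers the user's channel, replacing any older one (e.g. after a reconnect). */
    void register(Long userId, C channel) {
        userIdToChannel.put(userId, channel);
    }

    /** Removes the mapping only if it still points at this channel; returns true if removed. */
    boolean unregister(Long userId, C channel) {
        return userIdToChannel.remove(userId, channel);
    }

    /** Returns the user's current channel, or null if the user is not connected. */
    C lookup(Long userId) {
        return userIdToChannel.get(userId);
    }
}
```

In the Netty handler, `handlerRemoved` (already overridden in BusinessHandler) is a natural place to call something like `unregister`. It would need the userId that belongs to the closing channel, which could, for example, be stored as a channel attribute when the user registers.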
4.4. Add AbstractBusinessLogicHandler
import com.alibaba.fastjson.JSON;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.SimpleChannelInboundHandler;
import io.netty.handler.codec.http.websocketx.TextWebSocketFrame;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.DisposableBean;

import java.util.concurrent.ConcurrentHashMap;

@SuppressWarnings("all")
@Slf4j
public abstract class AbstractBusinessLogicHandler<I> extends SimpleChannelInboundHandler<I> implements DisposableBean {

    // userId (globally unique) -> context of that user's WebSocket channel
    protected static final ConcurrentHashMap<Long, ChannelHandlerContext> USER_ID_TO_CHANNEL = new ConcurrentHashMap<>();

    /**
     * Register the socket channel of a user
     *
     * @param channelHandlerContext socket channel context
     * @param userId                user ID
     */
    protected void addChannel(ChannelHandlerContext channelHandlerContext, Long userId) {
        // Store the current channel
        USER_ID_TO_CHANNEL.put(userId, channelHandlerContext);
    }

    /**
     * Check whether a channel is already registered for the user
     *
     * @param userId user ID
     * @return true if the user has a stored channel
     */
    protected boolean isExists(Long userId) {
        return USER_ID_TO_CHANNEL.containsKey(userId);
    }

    protected static void buildResponse(ChannelHandlerContext channelHandlerContext, int code, long respTime, int msgType, String msg) {
        buildResponse(channelHandlerContext, ApiRespMessage.builder()
                .code(String.valueOf(code))
                .respTime(String.valueOf(respTime))
                .msgType(String.valueOf(msgType))
                .contents(msg)
                .build());
    }

    protected static void buildResponseIncludeOperateId(ChannelHandlerContext channelHandlerContext, int code, long respTime, int msgType, String msg, String operateId) {
        buildResponse(channelHandlerContext, ApiRespMessage.builder()
                .code(String.valueOf(code))
                .respTime(String.valueOf(respTime))
                .msgType(String.valueOf(msgType))
                .operateId(operateId)
                .contents(msg)
                .build());
    }

    protected static void buildResponse(ChannelHandlerContext channelHandlerContext, ApiRespMessage apiRespMessage) {
        String response = JSON.toJSONString(apiRespMessage);
        channelHandlerContext.writeAndFlush(new TextWebSocketFrame(response));
    }

    @Override
    public void destroy() throws Exception {
        USER_ID_TO_CHANNEL.clear();
    }

    /**
     * Push one chunk of the AI answer to the user's WebSocket channel, if the user is connected
     */
    public static void pushChatMessageForUser(Long userId, String chatRespMessage) {
        ChannelHandlerContext channelHandlerContext = USER_ID_TO_CHANNEL.get(userId);
        if (channelHandlerContext != null) {
            buildResponse(channelHandlerContext, ApiRespMessage.builder()
                    .code("200")
                    .respTime(String.valueOf(System.currentTimeMillis()))
                    .msgType(String.valueOf(MsgType.CHAT.getCode()))
                    .contents(chatRespMessage)
                    .build());
        }
    }
}
4.5. AI configuration class
import lombok.Getter;
import lombok.Setter;
import lombok.ToString;
import org.springframework.boot.context.properties.ConfigurationProperties;
import org.springframework.stereotype.Component;

@ConfigurationProperties(prefix = "ai.server")
@Component("aiConfig")
@Setter
@Getter
@ToString
public class AIConfig {
    // URL of the AI chat API
    private String url;
    // Timeouts, in seconds (applied in OkHttpUtils)
    private Integer connectionTimeout;
    private Integer writeTimeout;
    private Integer readTimeout;
    // Shared key used to generate the "secret" request header
    private String serverKey;
    // Maximum number of idle connections kept in the pool
    private Integer keepAliveConnections;
    // How long an idle connection is kept alive, in seconds
    private Integer keepAliveDuration;
}
4.6. Configuration values for the AI configuration class
ai:
  server:
    url: http://127.0.0.1:9999/api/chat
    connection_timeout: 3
    write_timeout: 30
    read_timeout: 30
    server_key: 88888888
    keep_alive_connections: 30
    keep_alive_duration: 60
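PS:
1. Change url and server_key to match your environment. server_key is used to generate the secret request header, so it must match the key configured on the AI service side.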
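`AIChatUtils.generateSecret` sends `SHA-256(serverKey + userId + serverKey)` as a lowercase hex string in the `secret` header, together with a `userId` header. The AI service is not part of this article, so the verification side below is only a sketch of how such a check could work. It is written in Java for consistency, and the class and method names are hypothetical. It recomputes the secret for the received userId and compares the two values with `MessageDigest.isEqual`, a constant-time comparison.

```java
import java.nio.charset.StandardCharsets;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;

final class SecretCheck {

    /** Lowercase hex SHA-256 of a UTF-8 string. */
    static String sha256Hex(String content) {
        try {
            MessageDigest digest = MessageDigest.getInstance("SHA-256");
            byte[] hash = digest.digest(content.getBytes(StandardCharsets.UTF_8));
            StringBuilder hex = new StringBuilder(hash.length * 2);
            for (byte b : hash) {
                hex.append(String.format("%02x", b));
            }
            return hex.toString();
        } catch (NoSuchAlgorithmException e) {
            // Every Java platform implementation is required to support SHA-256
            throw new IllegalStateException(e);
        }
    }

    /** Same formula as the client: SHA-256(serverKey + userId + serverKey). */
    static String secretFor(String serverKey, long userId) {
        return sha256Hex(serverKey + userId + serverKey);
    }

    /** Recomputes the expected secret for the userId header and compares it in constant time. */
    static boolean verify(String serverKey, long userId, String secretHeader) {
        if (secretHeader == null) {
            return false;
        }
        byte[] expected = secretFor(serverKey, userId).getBytes(StandardCharsets.US_ASCII);
        byte[] actual = secretHeader.getBytes(StandardCharsets.US_ASCII);
        return MessageDigest.isEqual(expected, actual);
    }
}
```

The secret depends only on the key and the userId, so a captured header stays valid for that user until server_key changes. Adding a timestamp or nonce to the signed content would be one way to narrow that window.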
4.7. Netty configuration class
package com.zwzt.communication.config;

import lombok.Getter;
import lombok.Setter;
import lombok.ToString;
import org.springframework.boot.context.properties.ConfigurationProperties;
import org.springframework.stereotype.Component;

@ConfigurationProperties(prefix = "ws.server")
@Component
@Setter
@Getter
@ToString
public class NettyConfig {
    private String path;
    private int port;
    private int backlog;
    private int bossThread;
    private int workThread;
    private int businessThread;
    private int idleRead;
    private int idleWrite;
    private int idleAll;
    private int aggregator;
}
4.8. Configuration values for the Netty configuration class
ws:
  server:
    path: /ws
    port: 7778
    backlog: 1024
    boss_thread: 1
    work_thread: 8
    business_thread: 16
    idle_read: 30
    idle_write: 30
    idle_all: 60
    aggregator: 65536
4.9. VO class
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;

import java.util.List;

@Data
@Builder
@NoArgsConstructor
@AllArgsConstructor
public class AIChatReqVO {
    // The question
    private String prompt;
    // Conversation history
    private List<ChatContext> history;
    // AI model (sampling) parameters
    private Double top_p;
    private Double temperature;
    private Double repetition_penalty;
    private Long max_new_tokens;

    public static AIChatReqVO init(String prompt, List<ChatContext> history) {
        return AIChatReqVO.builder()
                .prompt(prompt)
                .history(history)
                .top_p(0.9)
                .temperature(0.45)
                .repetition_penalty(1.1)
                .max_new_tokens(8192L)
                .build();
    }
}
4.10. Entity class
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;

@Data
@Builder
@NoArgsConstructor
@AllArgsConstructor
public class ChatContext {
    // Sender
    private String from;
    // Message content
    private String value;
}
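`AIChatReqVO.init` forwards `history` as a list of `ChatContext` entries (`from`, `value`), but the article does not show how that list is maintained. One common approach is to append the question and the model's answer after each round and keep only the most recent rounds, so the prompt stays bounded. The sketch below follows that approach under stated assumptions. The values `"user"` and `"assistant"` for `from` are assumptions, because the real values depend on what the AI service expects. `ChatHistory` is a hypothetical helper, and its nested `Entry` class is a minimal stand-in for `ChatContext` so the example compiles on its own.

```java
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Deque;
import java.util.List;

/** Hypothetical helper that keeps a bounded conversation history. */
final class ChatHistory {

    /** Minimal stand-in for ChatContext: from = sender (assumed values), value = text. */
    static final class Entry {
        final String from;
        final String value;

        Entry(String from, String value) {
            this.from = from;
            this.value = value;
        }
    }

    private final Deque<Entry> entries = new ArrayDeque<>();
    private final int maxRounds;

    ChatHistory(int maxRounds) {
        this.maxRounds = maxRounds;
    }

    /** Appends one question/answer round and drops the oldest rounds beyond the limit. */
    void addRound(String question, String answer) {
        entries.addLast(new Entry("user", question));
        entries.addLast(new Entry("assistant", answer));
        // Entries always come in pairs, so remove whole rounds to keep question/answer aligned
        while (entries.size() > 2 * maxRounds) {
            entries.removeFirst();
            entries.removeFirst();
        }
    }

    /** Copy of the current history, oldest first, to send with the next request. */
    List<Entry> snapshot() {
        return new ArrayList<>(entries);
    }
}
```

The full answer to store is what `chatStream` returns after streaming completes.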
# Integrating Netty into the Spring Boot project
4.11. Add the Spring Boot application class
package com.zwzt.communication;

import com.zwzt.communication.netty.server.Server;
import com.zwzt.communication.utils.SpringContextUtils;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.BeansException;
import org.springframework.beans.factory.DisposableBean;
import org.springframework.boot.SpringApplication;
import org.springframework.boot.autoconfigure.SpringBootApplication;
import org.springframework.context.ApplicationContext;
import org.springframework.context.ApplicationContextAware;
import org.springframework.context.ApplicationListener;
import org.springframework.context.event.ContextRefreshedEvent;

@SpringBootApplication
@Slf4j
public class Application implements ApplicationListener<ContextRefreshedEvent>, ApplicationContextAware, DisposableBean {

    public static void main(String[] args) {
        SpringApplication.run(Application.class, args);
    }

    @Override
    public void onApplicationEvent(ContextRefreshedEvent contextRefreshedEvent) {
        // Only react to the root application context
        if (contextRefreshedEvent.getApplicationContext().getParent() == null) {
            // Start the WebSocket server on its own thread. Startup errors are caught inside run(),
            // because an exception thrown in the new thread never reaches this method.
            new Thread(() -> {
                try {
                    Server.getInstance().start();
                } catch (Exception e) {
                    log.error("webSocket server startup exception!", e);
                    System.exit(-1);
                }
            }).start();
        }
    }

    @Override
    public void setApplicationContext(ApplicationContext applicationContext) throws BeansException {
        SpringContextUtils.setApplicationContext(applicationContext);
    }

    @Override
    public void destroy() throws Exception {
        try {
            Server.getInstance().close();
        } catch (Throwable e) {
            log.warn("webSocket server close exception", e);
        }
    }
}
4.12. Spring Boot configuration
application.yml
server:
  port: 7777
  tomcat:
    uri-encoding: UTF-8
spring:
  application:
    name: ai_business_project
  main:
    banner-mode: "off"
  profiles:
    active: ai-dev
# Logging configuration
logging:
  config: classpath:logback-spring.xml
application-ai-dev.yml
ai:
  server:
    url: http://127.0.0.1:9999/api/chat
    connection_timeout: 3
    write_timeout: 30
    read_timeout: 30
    server_key: 88888888
    keep_alive_connections: 30
    keep_alive_duration: 60
ws:
  server:
    path: /ws
    port: 7778
    backlog: 1024
    boss_thread: 1
    work_thread: 8
    business_thread: 16
    idle_read: 30
    idle_write: 30
    idle_all: 60
    aggregator: 65536
4.13. Spring context utility class
import org.springframework.context.ApplicationContext;

public class SpringContextUtils {

    private static ApplicationContext applicationContext;

    public static void setApplicationContext(ApplicationContext applicationContext) {
        SpringContextUtils.applicationContext = applicationContext;
    }

    public static ApplicationContext getApplicationContext() {
        return applicationContext;
    }
}
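4.14. Start the service by running the Application class
After a successful start, Spring Boot listens on port 7777 and the WebSocket server listens on port 7778.
5. Testing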
# Reuses the code from the previous article; no changes are needed
5.1. Page test
5.2. Online test
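The endpoint can also be exercised from plain Java. The sketch below uses the JDK's `java.net.http.WebSocket` (Java 11+) to connect to `ws://127.0.0.1:7778/ws`, the port and path from the configuration above, and prints every frame the server pushes back. It rests on a few assumptions. The JSON field names `userId`, `msgType` and `contents` are inferred from the getters used in `BusinessHandler`. The numeric value of `MsgType.CHAT` is not shown in this article, so `chatMsgType` is a placeholder to replace. The JSON is built by hand only to keep the example dependency-free. `ChatClientSketch` is a hypothetical name.

```java
import java.net.URI;
import java.net.http.HttpClient;
import java.net.http.WebSocket;
import java.util.concurrent.CompletionStage;

final class ChatClientSketch {

    /** Builds the request JSON by hand; only backslashes and quotes are escaped (single-line text). */
    static String chatMessage(long userId, String chatMsgType, String contents) {
        String escaped = contents.replace("\\", "\\\\").replace("\"", "\\\"");
        return "{\"userId\":" + userId
                + ",\"msgType\":\"" + chatMsgType + "\""
                + ",\"contents\":\"" + escaped + "\"}";
    }

    /** Connects, sends one question, and prints each text frame pushed back by the server. */
    static WebSocket connectAndAsk(String url, String message) {
        WebSocket.Listener listener = new WebSocket.Listener() {
            @Override
            public CompletionStage<?> onText(WebSocket webSocket, CharSequence data, boolean last) {
                System.out.println("received: " + data);
                webSocket.request(1); // ask for the next frame
                return null;          // data may be reclaimed immediately
            }
        };
        WebSocket ws = HttpClient.newHttpClient()
                .newWebSocketBuilder()
                .buildAsync(URI.create(url), listener)
                .join();
        ws.sendText(message, true).join();
        return ws;
    }
}
```

Usage: `ChatClientSketch.connectAndAsk("ws://127.0.0.1:7778/ws", ChatClientSketch.chatMessage(1L, "<CHAT code>", "Hello"))`. Keep the JVM running long enough (for example with `Thread.sleep`) to see the streamed frames. Each frame is one serialized `ApiRespMessage` that carries a chunk of the answer.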
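At this point the full round trip between the IM module and the AI service is working.
6. Additional notes
6.1. The code above still has plenty of room for improvement. In particular, it does not yet address non-functional requirements. The main goal here is to get the whole program running end to end and then improve it step by step.
6.2. Many mature tutorials on setting up a Spring Boot project are available online, so this article does not cover that step, to keep its length reasonable.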