langchain4j源码阅读
这篇博客主要针对langchain4j源码的阅读,框架的使用请移步上一篇博客
以下是我阅读源码的版本
<!-- <langchain4j.version>1.0.0-beta1</langchain4j.version> -->
<dependency>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-cohere</artifactId>
<version>${langchain4j.version}</version>
</dependency>
<dependency>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-cohere</artifactId>
<version>${langchain4j.version}</version>
</dependency>
当我们使用langchain4j就是在构建一个AiServices,最后内部会返回一个自定义的基于jdk代理接口
通过建造者模式创建一个AI服务实例,具体包含以下几个关键配置
chatLanguageModel 配置底层的语言模型,负责实际的AI对话生成
chatMemory 设置默认对话记忆功能
chatMemoryProvider 自定义对话记忆功能
retrievalAugmentor 配置RAG的功能,可以从外部知识库检索相关信息
build 最终构建并返回一个实现了AiCustomer接口的代理对象
interface AiCustomer {
// String call(String query);
String chat(@MemoryId String memoryId, @UserMessage String userMessage);
}
String userMessage = "余额提现什么时候到账?";
AiCustomer aiCustomer = AiServices.builder(AiCustomer.class)
.chatLanguageModel(chatModel)
// memoryId固定为default
// .chatMemory(MessageWindowChatMemory.withMaxMessages(10))
.retrievalAugmentor(retrievalAugmentor)
// 动态的memoryId
.chatMemoryProvider(memoryId -> MessageWindowChatMemory.builder()
.id(memoryId)
.maxMessages(100)
.chatMemoryStore(store)
.build())
.build();
String result = aiCustomer.call(userMessage);
构建
进入build方法,可以看到内部是如何实现的
我这里没有使用到SPI机制,所以执行的是默认的实现,也就是DefaultAiServices
public static <T> AiServices<T> builder(Class<T> aiService) {
AiServiceContext context = new AiServiceContext(aiService);
// 使用jdk自带的SPI机制获取resource中配置的AiServicesFactory
for (AiServicesFactory factory : loadFactories(AiServicesFactory.class)) {
return factory.create(context);
}
// 否则使用默认的AiServices
return new DefaultAiServices<>(context);
}
进行构建,build方法很长,所以后续代码块中的代码会被截断成多份
验证
build方法会进行验证,在创建代理对象之前进行各种配置的合法性检查,如果不合法则抛出异常,并且结束执行
检查是否配置了大语言模型
如果方法标注了@Moderate注解,必须配置相应的审核模型moderationModel,否则抛出异常
对于返回Result、List或Set类型的方法,验证泛型参数是否正确配置
方法参数使用了@MemoryId注解但没有配置chatMemoryProvider,抛出异常
public T build() {
// 进行验证,是否配置了Model或StreamModel
performBasicValidation();
for (Method method : context.aiServiceClass.getMethods()) {
// 循环判断方法上是否添加了@Moderate,如果设置了@Moderate则必须配置识别敏感词的模型
if (method.isAnnotationPresent(Moderate.class) && context.moderationModel == null) {
throw illegalConfiguration(
"The @Moderate annotation is present, but the moderationModel is not set up. Please ensure a valid moderationModel is configured before using the @Moderate annotation.");
}
// 判断返回值类型
if (method.getReturnType() == Result.class
|| method.getReturnType() == List.class
|| method.getReturnType() == Set.class) {
TypeUtils.validateReturnTypesAreProperlyParametrized(method.getName(), method.getGenericReturnType());
}
// 如果未配置了chatMemoryProvider,但配置了@MemoryId抛出异常,AiServices的builder方法可以构建一个chatMemoryProvider
if (context.chatMemoryProvider == null) {
for (Parameter parameter : method.getParameters()) {
if (parameter.isAnnotationPresent(MemoryId.class)) {
throw illegalConfiguration(
"In order to use @MemoryId, please configure the ChatMemoryProvider on the '%s'.",
context.aiServiceClass.getName());
}
}
}
}
// 截断后续build方法代码...
protected void performBasicValidation() {
if (context.chatModel == null && context.streamingChatModel == null) {
throw illegalConfiguration("Please specify either chatLanguageModel or streamingChatLanguageModel");
}
}
// TypeUtils
public static void validateReturnTypesAreProperlyParametrized(String methodName, Type type) {
TypeUtils.validateReturnTypesAreProperlyParametrized(methodName, type, new ArrayList<>());
}
private static void validateReturnTypesAreProperlyParametrized(String methodName, Type type, List<Type> typeChain) {
if (type instanceof ParameterizedType parameterizedType) {
// 递归调用,解析范型中的范型
for (Type actualTypeArgument : parameterizedType.getActualTypeArguments()) {
typeChain.add(parameterizedType);
validateReturnTypesAreProperlyParametrized(methodName, actualTypeArgument, typeChain);
}
} else if (type instanceof WildcardType) {
// 禁止范型为<?>
typeChain.add(type);
throw genericNotProperlySpecifiedException(methodName, typeChain);
} else if (type instanceof TypeVariable) {
// 禁止范型为<T>
typeChain.add(type);
throw genericNotProperlySpecifiedException(methodName, typeChain);
} else if (type instanceof Class<?> clazz && clazz.getTypeParameters().length > 0) {
// 禁止类本身有泛型参数,但是方法返回没有指定范型
typeChain.add(type);
throw genericNotProperlySpecifiedException(methodName, typeChain);
}
}
创建代理对象
这段代码是langchain4j框架的核心部分,使用JDK动态代理技术创建AiService的实例对象
newProxyInstance创建代理对象,指定要代理的接口,定义InvocationHandler
这个线程池用于处理后续异步操作
校验接口参数、获取memoryId、处理系统和用户提示词
// 创建代理对象
Object proxyInstance = Proxy.newProxyInstance(
context.aiServiceClass.getClassLoader(),
new Class<?>[] {context.aiServiceClass},
new InvocationHandler() {
private final ExecutorService executor = Executors.newCachedThreadPool();
@Override
public Object invoke(Object proxy, Method method, Object[] args) throws Exception {
// 如果是调用的方法是Object类,例如Object的hashCode等方法则直接调用不进行代理
if (method.getDeclaringClass() == Object.class) {
return method.invoke(this, args);
}
// 校验接口的参数
validateParameters(method);
// 获取@MemoryId的值,没有标注@MemoryId则使用默认的
Object memoryId = findMemoryId(method, args).orElse(DEFAULT);
// 解析系统提示词
Optional<SystemMessage> systemMessage = prepareSystemMessage(memoryId, method, args);
// 解析用户提示词
UserMessage userMessage = prepareUserMessage(method, args);
// 截断后续newProxyInstance方法代码...
}
});
private static UserMessage prepareUserMessage(Method method, Object[] args) {
// 判断方法上或者方法参数上是否添加了@UserMessage
String template = getUserMessageTemplate(method, args);
// 获取参数列表
Map<String, Object> variables = findTemplateVariables(template, method, args);
Prompt prompt = PromptTemplate.from(template).apply(variables);
// 判断是否有参数使用了@UserName(对话中 区分不同用户发言)
Optional<String> maybeUserName = findUserName(method.getParameters(), args);
// 将占位符填充到用户提示词
return maybeUserName
.map(userName -> UserMessage.from(userName, prompt.text()))
.orElseGet(prompt::toUserMessage);
}
校验参数
如果方法参数大于1个,并且@UserMessage、@V、@MemoryId、@UserName注解都为null抛出异常
@UserMessage 标记某个参数是用户输入的消息
@V 标记某个参数是 提示词Prompt模板变量
@MemoryId 对话上下文的唯一标识
@UserName 用户名身份信息
static void validateParameters(Method method) {
// 获取方法的参数列表
Parameter[] parameters = method.getParameters();
// 如果参数小于2不进行校验
if (parameters == null || parameters.length < 2) {
return;
for (Parameter parameter : parameters) {
V v = parameter.getAnnotation(V.class);
dev.langchain4j.service.UserMessage userMessage =
parameter.getAnnotation(dev.langchain4j.service.UserMessage.class);
MemoryId memoryId = parameter.getAnnotation(MemoryId.class);
UserName userName = parameter.getAnnotation(UserName.class);
// 如果这4个注解都为空则抛出异常
if (v == null && userMessage == null && memoryId == null && userName == null) {
throw illegalConfiguration(
"Parameter '%s' of method '%s' should be annotated with @V or @UserMessage or @UserName or @MemoryId",
parameter.getName(), method.getName());
}
}
}
系统提示词
private Optional<SystemMessage> prepareSystemMessage(Object memoryId, Method method, Object[] args) {
// 获取系统提示词并将@V注解填充到占位符中
return findSystemMessageTemplate(memoryId, method)
.map(systemMessageTemplate -> PromptTemplate.from(systemMessageTemplate)
.apply(findTemplateVariables(systemMessageTemplate, method, args))
.toSystemMessage());
}
private Optional<String> findSystemMessageTemplate(Object memoryId, Method method) {
// 获取方法上的@SystemMessage注解
dev.langchain4j.service.SystemMessage annotation =
method.getAnnotation(dev.langchain4j.service.SystemMessage.class);
if (annotation != null) {
// 获取@SystemMessage的值,系统提示词可以传递多个,所以内部会将系统提示词value拼接,并且通过delimiter分隔,得到最终的系统提示词.系统提示词也可以以文件的信息通过fromResource指定路径
return Optional.of(getTemplate(
method, "System", annotation.fromResource(), annotation.value(), annotation.delimiter()));
}
return context.systemMessageProvider.apply(memoryId);
}
private static Map<String, Object> findTemplateVariables(String template, Method method, Object[] args) {
Parameter[] parameters = method.getParameters();
// 获取方法上的参数,key是@V的value,value是对应的参数值
Map<String, Object> variables = new HashMap<>();
for (int i = 0; i < parameters.length; i++) {
String variableName = getVariableName(parameters[i]);
Object variableValue = args[i];
variables.put(variableName, variableValue);
}
// 如果包含it占位符,就将方法参数第1个进行填充
if (template.contains("{{it}}") && !variables.containsKey("it")) {
String itValue = getValueOfVariableIt(parameters, args);
variables.put("it", itValue);
}
return variables;
}
用户提示词
和系统提示词的构建类似
private static UserMessage prepareUserMessage(Method method, Object[] args) {
// 判断方法上或者方法参数上是否添加了@UserMessage
String template = getUserMessageTemplate(method, args);
// 获取参数列表
Map<String, Object> variables = findTemplateVariables(template, method, args);
Prompt prompt = PromptTemplate.from(template).apply(variables);
// 判断是否有参数使用了@UserName(对话中 区分不同用户发言)
Optional<String> maybeUserName = findUserName(method.getParameters(), args);
// 将占位符填充到用户提示词
return maybeUserName
.map(userName -> UserMessage.from(userName, prompt.text()))
.orElseGet(prompt::toUserMessage);
}
检索增强
RAG(Retrieval-Augmented Generation)是现代AI应用的核心技术之一,它允许模型从外部知识库中检索信息,从而提供更准确、更相关、更丰富的回答
检索增强器内可以包含一下组件,具体使用在RAG那篇博客中已经详细写了,我这里就不重复了
ContentRetriever
QueryRouter
QueryTransfomer
ContentAggregator
Contentinjector
// 创建一个检索增强器,内部挺简单就是将以下所有组件都封装到一个类中
RetrievalAugmentor retrievalAugmentor = DefaultRetrievalAugmentor.builder()
.queryRouter(queryRouter) // 查询路由
.queryTransformer(queryTransformer) // 内容转换
.contentRetriever(contentRetriever) // 内容检索器
.contentInjector(contentInjector) // 提示词注入器
.build();
// 注入AiServices代理对象中
AiCustomer aiCustomer = AiServices.builder(AiCustomer.class)
.chatLanguageModel(chatModel)
.chatMemory(MessageWindowChatMemory.withMaxMessages(10))
.retrievalAugmentor(retrievalAugmentor)
.build();
检索增强器检查
首先检查是否配置了retrievalAugmentor
上下文信息收集
如果配置了聊天记忆功能,获取当前用户的历史对话消息
检索提供上下文信息,创建包含用户消息、memoryId和聊天历史的元数据对象
检索增强请求
创建包含原始用户消息和相关元数据的AugmentationRequest对象
调用检索增强器的augment()方法进行处理
消息增强
检索增强器会根据用户消息从知识库中检索相关信息
将检索到的信息与原始用户消息结合,生成增强后的消息
用增强后的消息替换原始用户消息,提供更丰富的上下文信息
// newProxyInstance的后续代码
AugmentationResult augmentationResult = null;
// 是否添加了检索增强器
if (context.retrievalAugmentor != null) {
// 这部分移步 [消息淘汰]
List<ChatMessage> chatMemory = context.hasChatMemory()
? context.chatMemory(memoryId).messages()
: null;
Metadata metadata = Metadata.from(userMessage, memoryId, chatMemory);
AugmentationRequest augmentationRequest = new AugmentationRequest(userMessage, metadata);
// 检索增强
augmentationResult = context.retrievalAugmentor.augment(augmentationRequest);
// 获取到最终增强的用户消息
userMessage = (UserMessage) augmentationResult.chatMessage();
// 截断后续newProxyInstance方法代码...
}
执行augment方法
用户发起一个提问,RAG会将用户已有的聊天内容进行封装到augmentationRequest中并将用户的消息
通过queryTransformer转换为多个问题,通过合适的转换可以在知识库或搜索引擎中匹配度更高
queryTransformer分为以下几种,内部原理很简单就是通过PromptTemplate(内置的提示词)向LLM发起提问,然后将结果返回并封装到Content
ExpandingQueryTransformer 扩展转换 将一个问题转换为多个
CompressingQueryTransformer 压缩转换 将一个问题进行压缩
通过process将每一个问题进行路由,对每个扩展后的查询执行路由和检索,拿到相关内容
通过queryRouter.route(query)根据查询内容选择合适的检索器
如果只有一个检索器,直接调用检索并返回结果
如果有多个检索器使用retrieveFromAll方法并发检索
检索器也分为2种
EmbeddingStoreContentRetriever 基于向量数据库进行检索
WebSearchContentRetriever 基于互联网进行检索
通过contentAggregator 把不同查询返回的内容进行聚合排序
contentAggregator分为2种
DefaultContentAggregator 内部使用倒数融合算法将多个List<Content>融合到单个List<Content>中
ReRankingContentAggregator 使用scoringModel(评分模型)进行打分
最后通过contentInjector将返回的contents和用户消息进行提示词增强
@Override
public AugmentationResult augment(AugmentationRequest augmentationRequest) {
ChatMessage chatMessage = augmentationRequest.chatMessage();
Metadata metadata = augmentationRequest.metadata();
// 封装一个Query对象
Query originalQuery = Query.from(chatMessage.text(), metadata);
// 调用transform,将一个问题进行转换, 参照 [消息转换]
Collection<Query> queries = queryTransformer.transform(originalQuery);
logQueries(originalQuery, queries);
// 循环遍历每个问题并且进行路由, 参照 [内容检索器]
Map<Query, Collection<List<Content>>> queryToContents = process(queries);
// 通过提示词增强器,为每一个问题设置优先级,参照 [内容聚合器]
List<Content> contents = contentAggregator.aggregate(queryToContents);
log(queryToContents, contents);
// 使用提示词注入器,将最终生成的contents和用户消息拼接成一个增强之后的用户提问
ChatMessage augmentedChatMessage = contentInjector.inject(contents, chatMessage);
log(augmentedChatMessage);
// 将最终的结果通过AugmentationResult返回
return AugmentationResult.builder()
.chatMessage(augmentedChatMessage)
.contents(contents)
.build();
}
消息转换
此处以ExpandingQueryTransformer 扩展转换 为例
提示词模版的翻译为:生成{{n}}提供的用户查询的不同版本。 使用同义词或替代句子结构,每个版本都应以不同的方式措辞,但它们都应保留原始含义。 这些版本将用于检索相关文档。 在单独的行上提供每个查询版本非常重要,而无需枚举,连字符或任何其他格式!用户查询:{{query}}
// 默认的提示词模版
public static final PromptTemplate DEFAULT_PROMPT_TEMPLATE = PromptTemplate.from(
"""
Generate {{n}} different versions of a provided user query. \
Each version should be worded differently, using synonyms or alternative sentence structures, \
but they should all retain the original meaning. \
These versions will be used to retrieve relevant documents. \
It is very important to provide each query version on a separate line, \
without enumerations, hyphens, or any additional formatting!
User query: {{query}}"""
);
@Override
public Collection<Query> transform(Query query) {
// 将提示词中query和n填充到提示词模版中
Prompt prompt = createPrompt(query);
// 与语言模型交互
String response = chatLanguageModel.generate(prompt.text());
// 通过换行分割并返回
List<String> queries = parse(response);
return queries.stream()
.map(queryText -> query.metadata() == null
? Query.from(queryText)
: Query.from(queryText, query.metadata()))
.collect(toList());
}
protected Prompt createPrompt(Query query) {
Map<String, Object> variables = new HashMap<>();
variables.put("query", query.text());
variables.put("n", n);
return promptTemplate.apply(variables);
}
protected List<String> parse(String queries) {
return stream(queries.split("\n"))
.filter(Utils::isNotNullOrBlank)
.collect(toList());
}
查询路由
此处以LanguageModelQueryRouter为例
默认的提示词翻译:基于用户查询,确定最合适的数据源以从以下选项中检索相关信息:{{options}}}您的答案由单个数字或多个由逗号分隔的数字非常重要!用户查询:{{query}}
public static final PromptTemplate DEFAULT_PROMPT_TEMPLATE = PromptTemplate.from(
"""
Based on the user query, determine the most suitable data source(s) \
to retrieve relevant information from the following options:
{{options}}
It is very important that your answer consists of either a single number \
or multiple numbers separated by commas and nothing else!
User query: {{query}}"""
);
@Override
public Collection<ContentRetriever> route(Query query) {
Prompt prompt = createPrompt(query);
try {
// 向大模型发起请求,并返回结果
String response = chatLanguageModel.generate(prompt.text());
return parse(response);
} catch (Exception e) {
log.warn("Failed to route query '{}'", query.text(), e);
return fallback(query, e);
}
}
// 通过提示模版,并将query和options替换提示模版中的占位符
protected Prompt createPrompt(Query query) {
Map<String, Object> variables = new HashMap<>();
variables.put("query", query.text());
variables.put("options", options);
return promptTemplate.apply(variables);
}
内容检索器
把一个或多个Query分配给合适的 ContentRetriever,因为问题会有多个所以涉及并并发执行检索
最后返回查询和内容的Map
private Map<Query, Collection<List<Content>>> process(Collection<Query> queries) {
// 如果只有1个问题,获取第一个问题进行路由
if (queries.size() == 1) {
Query query = queries.iterator().next();
// 通过[查询路由],获取到符合条件的内容检索器
Collection<ContentRetriever> retrievers = queryRouter.route(query);
if (retrievers.size() == 1) {
// 只有一个返回内容
ContentRetriever contentRetriever = retrievers.iterator().next();
List<Content> contents = contentRetriever.retrieve(query);
return singletonMap(query, singletonList(contents));
} else if (retrievers.size() > 1) {
// 多个检索器,异步获取所有内容
Collection<List<Content>> contents = retrieveFromAll(retrievers, query).join();
return singletonMap(query, contents);
} else {
return emptyMap();
}
// 多个问题
} else if (queries.size() > 1) {
Map<Query, CompletableFuture<Collection<List<Content>>>> queryToFutureContents = new ConcurrentHashMap<>();
queries.forEach(query -> {
CompletableFuture<Collection<List<Content>>> futureContents =
supplyAsync(() -> {
Collection<ContentRetriever> retrievers = queryRouter.route(query);
log(query, retrievers);
return retrievers;
},
executor
// 先路由,再并行检索
).thenCompose(retrievers -> retrieveFromAll(retrievers, query));
queryToFutureContents.put(query, futureContents);
});
// 循环之后统一join
return join(queryToFutureContents);
} else {
return emptyMap();
}
}
private CompletableFuture<Collection<List<Content>>> retrieveFromAll(Collection<ContentRetriever> retrievers, Query query) {
List<CompletableFuture<List<Content>>> futureContents = retrievers.stream()
.map(retriever -> supplyAsync(() -> retrieve(retriever, query), executor))
.toList();
return allOf(futureContents.toArray(new CompletableFuture[0]))
.thenApply(ignored ->
futureContents.stream()
.map(CompletableFuture::join)
.toList());
}
private static List<Content> retrieve(ContentRetriever retriever, Query query) {
// 此处可以看到内容检索器ContentRetriever有2个实现
// 一个是基于embeddingModel的EmbeddingStoreContentRetriever,通过向量模型匹配符合的内容
// 另一个是基于webSearchEngine的WebSearchContentRetriever,在互联网上执行检索
List<Content> contents = retriever.retrieve(query);
log(query, retriever, contents);
return contents;
}
内容聚合器
此处以ReRankingContentAggregator,内部使用了scoringModel,因为不是大语言模型所以此处没有提示词
我问了一下ai,为什么不直接交给 scoringModel,以下是他的回答,我摘略了一下
最终排序确实是 scoringModel (精排)在做,但中间加一层 fuse(粗排) 有两个好处
降噪
scoringModel 是逐条打分的,如果输入全是嘈杂内容,性能和结果都会受影响
fuse 可以先做一次初步「多路检索融合」,过滤掉明显靠后的低质量结果
公平性
scoringModel 的目标是「精排」,但它需要有一个 候选集
fuse 先把不同 Query/不同 Retriever 的结果整合 → 再交给 scoringModel 精排
效率
scoringModel(通常是一个深度模型/BERT/Reranker)开销比 fuse 大很多
fuse 能帮忙缩小候选集,减少 scoringModel 的计算量
如果不使用scoringModel,使用的是DefaultContentAggregator,则只使用了粗排
private final Function<Map<Query, Collection<List<Content>>>, Query> querySelector;
@Override
public List<Content> aggregate(Map<Query, Collection<List<Content>>> queryToContents) {
if (queryToContents.isEmpty()) {
return emptyList();
}
// 先通过外部传递的lambda手动获取一个query
// 所以用querySelector来选一个最优的Query(比如内容最多/上下文最相关的)
Query query = querySelector.apply(queryToContents);
// 将queryToContents数据结构转换为Map<Query, List<Content>>
Map<Query, List<Content>> queryToFusedContents = fuse(queryToContents);
// 使用倒数融合算法将所有结果融合为List
List<Content> fusedContents = ReciprocalRankFuser.fuse(queryToFusedContents.values());
if (fusedContents.isEmpty()) {
return fusedContents;
}
// 使用scoringModel重新进行排序
return reRankAndFilter(fusedContents, query);
}
protected List<Content> reRankAndFilter(List<Content> contents, Query query) {
List<TextSegment> segments = contents.stream()
.map(Content::textSegment)
.collect(Collectors.toList());
// 使用scoringModel
List<Double> scores = scoringModel.scoreAll(segments, query.text()).content();
Map<TextSegment, Double> segmentToScore = new HashMap<>();
for (int i = 0; i < segments.size(); i++) {
segmentToScore.put(segments.get(i), scores.get(i));
}
return segmentToScore.entrySet().stream()
// 过滤过滤掉低分,按分数从高到低排序
.filter(entry -> minScore == null || entry.getValue() >= minScore)
.sorted(Map.Entry.<TextSegment, Double>comparingByValue().reversed())
// 转成带有打分信息封装到Content
.map(entry -> Content.from(entry.getKey(), Map.of(RERANKED_SCORE, entry.getValue())))
// 限制最大数量
.limit(maxResults)
.collect(Collectors.toList());
}
提示词注入器
实现很简单就是叫原本的提示词外封装了一层提示词
public static final PromptTemplate DEFAULT_PROMPT_TEMPLATE = PromptTemplate.from(
"""
{{userMessage}}
Answer using the following information:
{{contents}}"""
);
@Override
public ChatMessage inject(List<Content> contents, ChatMessage chatMessage) {
if (contents.isEmpty()) {
return chatMessage;
}
Prompt prompt = createPrompt(chatMessage, contents);
if (chatMessage instanceof UserMessage message && isNotNullOrBlank(message.name())) {
return prompt.toUserMessage(message.name());
}
return prompt.toUserMessage();
}
消息淘汰
这段代码实现了智能的聊天记忆管理机制,确保模型能够维持上下文连贯的对话
记忆存储策略
通过memoryId区分不同用户的对话上下文,确保用户间的对话不会相互干扰
将系统消息和用户消息添加到对应用户的聊天记忆中
当消息数量超过限制时会自动淘汰旧消息
消息列表构建
如果配置了聊天记忆,直接从记忆中获取完整的消息历史
如果没有配置记忆功能,创建临时消息列表,只包含当前对话的系统消息和用户消息
上下文连续性
通过维护消息历史,AI能够理解对话的上下文,提供更加连贯和相关的回复
系统消息通常包含AI的角色定义和行为指导,确保AI始终按照预期的方式回应
// newProxyInstance的后续代码
// 通过不同的MemoryId获取不同用户的消息
if (context.hasChatMemory()) {
ChatMemory chatMemory = context.chatMemory(memoryId);
systemMessage.ifPresent(chatMemory::add);
// 消息淘汰
chatMemory.add(userMessage);
}
// 构建消息
List<ChatMessage> messages;
// 通过memoryId获取所有消息
if (context.hasChatMemory()) {
messages = context.chatMemory(memoryId).messages();
} else {
messages = new ArrayList<>();
// 将系统提示词添加到messages中
systemMessage.ifPresent(messages::add);
// 将用户提示词也添加到messages中
messages.add(userMessage);
}
// AiServiceContext
// 指定了store就算memoryId存储在磁盘中,langchain4j依然会将数据存储在内存中
public ChatMemory chatMemory(Object memoryId) {
// 执行构建AiServices的ChatMemoryProvider
return chatMemories.computeIfAbsent(memoryId, ignored -> chatMemoryProvider.get(memoryId));
}
实现ChatMemoryStore可以自定义增删改查消息,实现getMessages、updateMessages、deleteMessages
@Bean
public AssistantUnique assistantUniqueStore(@Qualifier("ollamaChatModel") ChatLanguageModel languageModel,
@Qualifier("ollamaStreamingChatModel") StreamingChatLanguageModel streamingChatLanguageModel,
// 从ioc容器中获取
PersistentChatMemoryStore store,
ToolService toolService) {
return AiServices.builder(AssistantUnique.class)
.chatLanguageModel(languageModel)
.streamingChatLanguageModel(streamingChatLanguageModel)
// .tools(toolService)
.chatMemoryProvider(memoryId -> MessageWindowChatMemory.builder()
.id(memoryId)
.maxMessages(100)
.chatMemoryStore(store)
.build())
.build();
}
// langchain4j的提供的一个默认实现
// key是memoryId,value就是存储的消息
public class InMemoryChatMemoryStore implements ChatMemoryStore {
// 内部使用十分简单,将所有消息都存储在ConcurrentHashMap中
private final Map<Object, List<ChatMessage>> messagesByMemoryId = new ConcurrentHashMap<>();
public InMemoryChatMemoryStore() {}
@Override
public List<ChatMessage> getMessages(Object memoryId) {
return messagesByMemoryId.computeIfAbsent(memoryId, ignored -> new ArrayList<>());
}
@Override
// 此处就是消息淘汰后调用,存储剩余的消息
public void updateMessages(Object memoryId, List<ChatMessage> messages) {
messagesByMemoryId.put(memoryId, messages);
}
@Override
public void deleteMessages(Object memoryId) {
messagesByMemoryId.remove(memoryId);
}
}
构建不同的ChatMemory,langchain4j执行不同的淘汰方式
MessageWindowChatMemory 通过消息的数量进行淘汰
TokenWindowChatMemory 通过token进行淘汰
public class MessageWindowChatMemory implements ChatMemory {
private static final Logger log = LoggerFactory.getLogger(MessageWindowChatMemory.class);
private final Object id;
private final Integer maxMessages;
private final ChatMemoryStore store; // 默认使用了InMemoryChatMemoryStore
private MessageWindowChatMemory(Builder builder) {
this.id = ensureNotNull(builder.id, "id");
this.maxMessages = ensureGreaterThanZero(builder.maxMessages, "maxMessages");
this.store = ensureNotNull(builder.store, "store");
}
@Override
public Object id() {
return id;
}
@Override
// 添加一条消息
public void add(ChatMessage message) {
// 获取所有的历史记录
List<ChatMessage> messages = messages();
if (message instanceof SystemMessage) {
// 过滤系统提示词
Optional<SystemMessage> systemMessage = findSystemMessage(messages);
if (systemMessage.isPresent()) {
if (systemMessage.get().equals(message)) {
return;
} else {
messages.remove(systemMessage.get());
}
}
}
// 加到list最后
messages.add(message);
// 消息淘汰
ensureCapacity(messages, maxMessages);
// 将淘汰后剩余消息放入对应的store中
store.updateMessages(id, messages);
}
private static Optional<SystemMessage> findSystemMessage(List<ChatMessage> messages) {
return messages.stream()
.filter(message -> message instanceof SystemMessage)
.map(message -> (SystemMessage) message)
.findAny();
}
@Override
public List<ChatMessage> messages() {
List<ChatMessage> messages = new LinkedList<>(store.getMessages(id));
ensureCapacity(messages, maxMessages);
return messages;
}
// 消息淘汰
private static void ensureCapacity(List<ChatMessage> messages, int maxMessages) {
// 消息总量若大于maxMessages则进行淘汰
while (messages.size() > maxMessages) {
// 从0开始判断,如果第0个是系统提示此则从第1个开始
int messageToEvictIndex = 0;
if (messages.get(0) instanceof SystemMessage) {
messageToEvictIndex = 1;
}
// 淘汰这个下标的消息
ChatMessage evictedMessage = messages.remove(messageToEvictIndex);
log.trace("Evicting the following message to comply with the capacity requirement: {}", evictedMessage);
// 移除的是否是AiMessage并且是否包含了工具执行结果
if (evictedMessage instanceof AiMessage && ((AiMessage) evictedMessage).hasToolExecutionRequests()) {
// 继续移除其中的工具执行结果
while (messages.size() > messageToEvictIndex
&& messages.get(messageToEvictIndex) instanceof ToolExecutionResultMessage) {
ChatMessage orphanToolExecutionResultMessage = messages.remove(messageToEvictIndex);
log.trace("Evicting orphan {}", orphanToolExecutionResultMessage);
}
}
}
}
// 工具,设置最大的消息配置
public static MessageWindowChatMemory withMaxMessages(int maxMessages) {
return builder().maxMessages(maxMessages).build();
}
}
TokenWindowChatMemory的淘汰机制
private static void ensureCapacity(List<ChatMessage> messages, int maxTokens, Tokenizer tokenizer) {
if (messages.isEmpty()) {
return;
}
// 计算消息的token数
int currentTokenCount = tokenizer.estimateTokenCountInMessages(messages);
while (currentTokenCount > maxTokens) {
int messageToEvictIndex = 0;
if (messages.get(0) instanceof SystemMessage) {
messageToEvictIndex = 1;
}
ChatMessage evictedMessage = messages.remove(messageToEvictIndex);
int tokenCountOfEvictedMessage = tokenizer.estimateTokenCountInMessage(evictedMessage);
log.trace("Evicting the following message ({} tokens) to comply with the capacity requirement: {}",
tokenCountOfEvictedMessage, evictedMessage);
currentTokenCount -= tokenCountOfEvictedMessage;
if (evictedMessage instanceof AiMessage && ((AiMessage) evictedMessage).hasToolExecutionRequests()) {
while (messages.size() > messageToEvictIndex
&& messages.get(messageToEvictIndex) instanceof ToolExecutionResultMessage) {
ChatMessage orphanToolExecutionResultMessage = messages.remove(messageToEvictIndex);
log.trace("Evicting orphan {}", orphanToolExecutionResultMessage);
currentTokenCount -= tokenizer.estimateTokenCountInMessage(orphanToolExecutionResultMessage);
}
}
}
}
Tools
首先看构建AiServices的tools方法,正对工具进行解析,并且设置到AiService代理对象中
public AiServices<T> tools(Collection<Object> objectsWithTools) {
context.toolService.tools(objectsWithTools);
return this;
}
public void tools(Collection<Object> objectsWithTools) {
if (toolProvider != null) {
throw new IllegalArgumentException(
"Either the tools or the tool provider can be configured, but not both!");
}
// 初始化toolExecutors和toolSpecifications
initTools();
// 循环每一个tools,外部可以传递多个tools
for (Object objectWithTool : objectsWithTools) {
if (objectWithTool instanceof Class) {
throw illegalConfiguration("Tool '%s' must be an object, not a class", objectWithTool);
}
// 获取每一个tools的方法,拿到所有标注@Tool的方法
for (Method method : objectWithTool.getClass().getDeclaredMethods()) {
if (method.isAnnotationPresent(Tool.class)) {
ToolSpecification toolSpecification = toolSpecificationFrom(method);
// @Tool的name不能重复
if (toolExecutors.containsKey(toolSpecification.name())) {
throw new IllegalConfigurationException(
"Duplicated definition for tool: " + toolSpecification.name());
}
toolExecutors.put(toolSpecification.name(), new DefaultToolExecutor(objectWithTool, method));
toolSpecifications.add(toolSpecification);
}
}
}
}
public static ToolSpecification toolSpecificationFrom(Method method) {
Tool annotation = method.getAnnotation(Tool.class);
// 如果注解没有name则使用方法的名称
String name = isNullOrBlank(annotation.name()) ? method.getName() : annotation.name();
String description = String.join("\n", annotation.value());
if (description.isEmpty()) {
description = null;
}
JsonObjectSchema parameters = parametersFrom(method.getParameters());
return ToolSpecification.builder()
.name(name)
.description(description)
.parameters(parameters)
.build();
}
private static JsonObjectSchema parametersFrom(Parameter[] parameters) {
Map<String, JsonSchemaElement> properties = new LinkedHashMap<>();
List<String> required = new ArrayList<>();
Map<Class<?>, JsonSchemaElementHelper.VisitedClassMetadata> visited = new LinkedHashMap<>();
for (Parameter parameter : parameters) {
// 如果包含了@ToolMemoryId跳过,和@MemoryId类似可以区分不同用户的同一个@Tool
if (parameter.isAnnotationPresent(ToolMemoryId.class)) {
continue;
}
// 判断是否包含@P,包含了给模型针对这个参数的描述
boolean isRequired = Optional.ofNullable(parameter.getAnnotation(P.class))
.map(P::required)
.orElse(true);
properties.put(parameter.getName(), jsonSchemaElementFrom(parameter, visited));
if (isRequired) {
required.add(parameter.getName());
}
}
Map<String, JsonSchemaElement> definitions = new LinkedHashMap<>();
visited.forEach((clazz, visitedClassMetadata) -> {
if (visitedClassMetadata.recursionDetected) {
definitions.put(visitedClassMetadata.reference, visitedClassMetadata.jsonSchemaElement);
}
});
if (properties.isEmpty()) {
return null;
}
return JsonObjectSchema.builder()
.addProperties(properties)
.required(required)
.definitions(definitions.isEmpty() ? null : definitions)
.build();
}
后续代码
// 构建Chat请求参数
ChatRequestParameters parameters = ChatRequestParameters.builder()
// 将调用构建AiService设置的工具赋值给parameters
// 参照 [Tools]
.toolSpecifications(toolExecutionContext.toolSpecifications())
.responseFormat(responseFormat)
.build();
ChatRequest chatRequest = ChatRequest.builder()
.messages(messages)
.parameters(parameters)
.build();
// 向Ai发起chat请求
ChatResponse chatResponse = context.chatModel.chat(chatRequest);
verifyModerationIfNeeded(moderationFuture);
// 通过chatResponse判断模型是否有工具执行请求,不断循环执行,直到模型不在发起工具执行请求
ToolExecutionResult toolExecutionResult = context.toolService.executeInferenceAndToolsLoop(
chatResponse,
parameters,
messages,
context.chatModel,
context.hasChatMemory() ? context.chatMemory(memoryId) : null,
memoryId,
toolExecutionContext.toolExecutors());
chatResponse = toolExecutionResult.chatResponse();
FinishReason finishReason = chatResponse.metadata().finishReason();
Response<AiMessage> response = Response.from(
chatResponse.aiMessage(), toolExecutionResult.tokenUsageAccumulator(), finishReason);
Object parsedResponse = serviceOutputParser.parse(response, returnType);
if (typeHasRawClass(returnType, Result.class)) {
return Result.builder()
.content(parsedResponse)
// tokenUsage是模型token的调用情况,这里只计算了tool的token
// chat模 型的token在对应的大模型调用时计算,最终封装在ChatResponseMetadata内
.tokenUsage(toolExecutionResult.tokenUsageAccumulator())
.sources(augmentationResult == null ? null : augmentationResult.contents())
.finishReason(finishReason)
.toolExecutions(toolExecutionResult.toolExecutions())
.build();
} else {
return parsedResponse;
}
}