langchain4j源码阅读

38

这篇博客主要针对langchain4j源码的阅读,框架的使用请移步上一篇博客

以下是我阅读源码的版本

<!-- <langchain4j.version>1.0.0-beta1</langchain4j.version> -->
<dependency>
    <groupId>dev.langchain4j</groupId>
    <artifactId>langchain4j-cohere</artifactId>
    <version>${langchain4j.version}</version>
</dependency>

<dependency>
    <groupId>dev.langchain4j</groupId>
    <artifactId>langchain4j-cohere</artifactId>
    <version>${langchain4j.version}</version>
</dependency>

当我们使用langchain4j就是在构建一个AiServices,最后内部会返回一个自定义的基于jdk代理接口

通过建造者模式创建一个AI服务实例,具体包含以下几个关键配置

  • chatLanguageModel 配置底层的语言模型,负责实际的AI对话生成

  • chatMemory 设置默认对话记忆功能

  • chatMemoryProvider 自定义对话记忆功能

  • retrievalAugmentor 配置RAG的功能,可以从外部知识库检索相关信息

  • build 最终构建并返回一个实现了AiCustomer接口的代理对象

interface AiCustomer {
    // String call(String query);
	String chat(@MemoryId String memoryId, @UserMessage String userMessage);
}

String userMessage = "余额提现什么时候到账?";
AiCustomer aiCustomer = AiServices.builder(AiCustomer.class)
        .chatLanguageModel(chatModel)
		// memoryId固定为default
        // .chatMemory(MessageWindowChatMemory.withMaxMessages(10))
        .retrievalAugmentor(retrievalAugmentor)
		// 动态的memoryId
		.chatMemoryProvider(memoryId -> MessageWindowChatMemory.builder()
	                        .id(memoryId)
	                        .maxMessages(100)
	                        .chatMemoryStore(store)
	                        .build())
        .build();
String result = aiCustomer.call(userMessage);

构建

进入build方法,可以看到内部是如何实现的

我这里没有使用到SPI机制,所以执行的是默认的实现,也就是DefaultAiServices

public static <T> AiServices<T> builder(Class<T> aiService) {
    AiServiceContext context = new AiServiceContext(aiService);
    // 使用jdk自带的SPI机制获取resource中配置的AiServicesFactory
    for (AiServicesFactory factory : loadFactories(AiServicesFactory.class)) {
        return factory.create(context);
    }
    // 否则使用默认的AiServices
    return new DefaultAiServices<>(context);
}

进行构建,build方法很长,所以后续代码块中的代码会被截断成多份

验证

build方法会进行验证,在创建代理对象之前进行各种配置的合法性检查,如果不合法则抛出异常,并且结束执行

  • 检查是否配置了大语言模型

  • 如果方法标注了@Moderate注解,必须配置相应的审核模型moderationModel,否则抛出异常

  • 对于返回Result、List或Set类型的方法,验证泛型参数是否正确配置

  • 方法参数使用了@MemoryId注解但没有配置chatMemoryProvider,抛出异常

public T build() {
    // 进行验证,是否配置了Model或StreamModel
    performBasicValidation();
​
    for (Method method : context.aiServiceClass.getMethods()) {
        // 循环判断方法上是否添加了@Moderate,如果设置了@Moderate则必须配置识别敏感词的模型
        if (method.isAnnotationPresent(Moderate.class) && context.moderationModel == null) {
            throw illegalConfiguration(
                    "The @Moderate annotation is present, but the moderationModel is not set up. Please ensure a valid moderationModel is configured before using the @Moderate annotation.");
        }
        // 判断返回值类型
        if (method.getReturnType() == Result.class
                || method.getReturnType() == List.class
                || method.getReturnType() == Set.class) {
            TypeUtils.validateReturnTypesAreProperlyParametrized(method.getName(), method.getGenericReturnType());
        }
      
        // 如果未配置了chatMemoryProvider,但配置了@MemoryId抛出异常,AiServices的builder方法可以构建一个chatMemoryProvider
        if (context.chatMemoryProvider == null) {
            for (Parameter parameter : method.getParameters()) {
                if (parameter.isAnnotationPresent(MemoryId.class)) {
                    throw illegalConfiguration(
                            "In order to use @MemoryId, please configure the ChatMemoryProvider on the '%s'.",
                            context.aiServiceClass.getName());
                }
            }
        }
    }
// 截断后续build方法代码...

protected void performBasicValidation() {
    if (context.chatModel == null && context.streamingChatModel == null) {
        throw illegalConfiguration("Please specify either chatLanguageModel or streamingChatLanguageModel");
    }
}

// TypeUtils
public static void validateReturnTypesAreProperlyParametrized(String methodName, Type type) {
    TypeUtils.validateReturnTypesAreProperlyParametrized(methodName, type, new ArrayList<>());
}
private static void validateReturnTypesAreProperlyParametrized(String methodName, Type type, List<Type> typeChain) {
    if (type instanceof ParameterizedType parameterizedType) {
        // 递归调用,解析范型中的范型
        for (Type actualTypeArgument : parameterizedType.getActualTypeArguments()) {
            typeChain.add(parameterizedType);
            validateReturnTypesAreProperlyParametrized(methodName, actualTypeArgument, typeChain);
        }
    } else if (type instanceof WildcardType) {
        // 禁止范型为<?>
        typeChain.add(type);
        throw genericNotProperlySpecifiedException(methodName, typeChain);
    } else if (type instanceof TypeVariable) {
        // 禁止范型为<T>
        typeChain.add(type);
        throw genericNotProperlySpecifiedException(methodName, typeChain);
    } else if (type instanceof Class<?> clazz && clazz.getTypeParameters().length > 0) {
        //  禁止类本身有泛型参数,但是方法返回没有指定范型
        typeChain.add(type);
        throw genericNotProperlySpecifiedException(methodName, typeChain);
    }
}

创建代理对象

这段代码是langchain4j框架的核心部分,使用JDK动态代理技术创建AiService的实例对象

  • newProxyInstance创建代理对象,指定要代理的接口,定义InvocationHandler

    • 这个线程池用于处理后续异步操作

    • 校验接口参数、获取memoryId、处理系统和用户提示词

// 创建代理对象
	Object proxyInstance = Proxy.newProxyInstance(
        context.aiServiceClass.getClassLoader(),
        new Class<?>[] {context.aiServiceClass},
        new InvocationHandler() {
            private final ExecutorService executor = Executors.newCachedThreadPool();
            @Override
            public Object invoke(Object proxy, Method method, Object[] args) throws Exception {
                // 如果是调用的方法是Object类,例如Object的hashCode等方法则直接调用不进行代理
                if (method.getDeclaringClass() == Object.class) {
                    return method.invoke(this, args);
                }
                // 校验接口的参数
                validateParameters(method);
                // 获取@MemoryId的值,没有标注@MemoryId则使用默认的
                Object memoryId = findMemoryId(method, args).orElse(DEFAULT);
                // 解析系统提示词
                Optional<SystemMessage> systemMessage = prepareSystemMessage(memoryId, method, args);
                // 解析用户提示词
                UserMessage userMessage = prepareUserMessage(method, args);
                // 截断后续newProxyInstance方法代码...
            }
        });


private static UserMessage prepareUserMessage(Method method, Object[] args) {
    // 判断方法上或者方法参数上是否添加了@UserMessage
    String template = getUserMessageTemplate(method, args);
    // 获取参数列表
    Map<String, Object> variables = findTemplateVariables(template, method, args);
    Prompt prompt = PromptTemplate.from(template).apply(variables);
    // 判断是否有参数使用了@UserName(对话中 区分不同用户发言)
    Optional<String> maybeUserName = findUserName(method.getParameters(), args);
    // 将占位符填充到用户提示词
    return maybeUserName
            .map(userName -> UserMessage.from(userName, prompt.text()))
            .orElseGet(prompt::toUserMessage);
}

校验参数

如果方法参数大于1个,并且@UserMessage、@V、@MemoryId、@UserName注解都为null抛出异常

  • @UserMessage 标记某个参数是用户输入的消息

  • @V 标记某个参数是 提示词Prompt模板变量

  • @MemoryId 对话上下文的唯一标识

  • @UserName 用户名身份信息

static void validateParameters(Method method) {
    // 获取方法的参数列表
    Parameter[] parameters = method.getParameters();
    // 如果参数小于2不进行校验
    if (parameters == null || parameters.length < 2) {
        return;

    for (Parameter parameter : parameters) {
        V v = parameter.getAnnotation(V.class);
        dev.langchain4j.service.UserMessage userMessage =
                parameter.getAnnotation(dev.langchain4j.service.UserMessage.class);
        MemoryId memoryId = parameter.getAnnotation(MemoryId.class);
        UserName userName = parameter.getAnnotation(UserName.class);
        // 如果这4个注解都为空则抛出异常
        if (v == null && userMessage == null && memoryId == null && userName == null) {
            throw illegalConfiguration(
                    "Parameter '%s' of method '%s' should be annotated with @V or @UserMessage or @UserName or @MemoryId",
                    parameter.getName(), method.getName());
        }
    }
}

系统提示词

private Optional<SystemMessage> prepareSystemMessage(Object memoryId, Method method, Object[] args) {
    // 获取系统提示词并将@V注解填充到占位符中
    return findSystemMessageTemplate(memoryId, method)
            .map(systemMessageTemplate -> PromptTemplate.from(systemMessageTemplate)
                    .apply(findTemplateVariables(systemMessageTemplate, method, args))
                    .toSystemMessage());
}
​
private Optional<String> findSystemMessageTemplate(Object memoryId, Method method) {
    // 获取方法上的@SystemMessage注解
    dev.langchain4j.service.SystemMessage annotation =
            method.getAnnotation(dev.langchain4j.service.SystemMessage.class);
    if (annotation != null) {
        // 获取@SystemMessage的值,系统提示词可以传递多个,所以内部会将系统提示词value拼接,并且通过delimiter分隔,得到最终的系统提示词.系统提示词也可以以文件的信息通过fromResource指定路径
        return Optional.of(getTemplate(
                method, "System", annotation.fromResource(), annotation.value(), annotation.delimiter()));
    }
    return context.systemMessageProvider.apply(memoryId);
}
​
private static Map<String, Object> findTemplateVariables(String template, Method method, Object[] args) {
    Parameter[] parameters = method.getParameters();
    // 获取方法上的参数,key是@V的value,value是对应的参数值
    Map<String, Object> variables = new HashMap<>();
    for (int i = 0; i < parameters.length; i++) {
        String variableName = getVariableName(parameters[i]);
        Object variableValue = args[i];
        variables.put(variableName, variableValue);
    }
    // 如果包含it占位符,就将方法参数第1个进行填充
    if (template.contains("{{it}}") && !variables.containsKey("it")) {
        String itValue = getValueOfVariableIt(parameters, args);
        variables.put("it", itValue);
    }
    return variables;
}

用户提示词

和系统提示词的构建类似

private static UserMessage prepareUserMessage(Method method, Object[] args) {
    // 判断方法上或者方法参数上是否添加了@UserMessage
    String template = getUserMessageTemplate(method, args);
    // 获取参数列表
    Map<String, Object> variables = findTemplateVariables(template, method, args);
    Prompt prompt = PromptTemplate.from(template).apply(variables);
    // 判断是否有参数使用了@UserName(对话中 区分不同用户发言)
    Optional<String> maybeUserName = findUserName(method.getParameters(), args);
    // 将占位符填充到用户提示词
    return maybeUserName
            .map(userName -> UserMessage.from(userName, prompt.text()))
            .orElseGet(prompt::toUserMessage);
}

检索增强

RAG(Retrieval-Augmented Generation)是现代AI应用的核心技术之一,它允许模型从外部知识库中检索信息,从而提供更准确、更相关、更丰富的回答

检索增强器内可以包含一下组件,具体使用在RAG那篇博客中已经详细写了,我这里就不重复了

  • ContentRetriever

  • QueryRouter

  • QueryTransfomer

  • ContentAggregator

  • Contentinjector

// 创建一个检索增强器,内部挺简单就是将以下所有组件都封装到一个类中
RetrievalAugmentor retrievalAugmentor =  DefaultRetrievalAugmentor.builder()
        .queryRouter(queryRouter)           // 查询路由
        .queryTransformer(queryTransformer) // 内容转换
        .contentRetriever(contentRetriever) // 内容检索器
        .contentInjector(contentInjector) // 提示词注入器
        .build();
// 注入AiServices代理对象中
AiCustomer aiCustomer = AiServices.builder(AiCustomer.class)
        .chatLanguageModel(chatModel)
        .chatMemory(MessageWindowChatMemory.withMaxMessages(10))
        .retrievalAugmentor(retrievalAugmentor)
        .build();
  • 检索增强器检查

    • 首先检查是否配置了retrievalAugmentor

  • 上下文信息收集

    • 如果配置了聊天记忆功能,获取当前用户的历史对话消息

    • 检索提供上下文信息,创建包含用户消息、memoryId和聊天历史的元数据对象

  • 检索增强请求

    • 创建包含原始用户消息和相关元数据的AugmentationRequest对象

    • 调用检索增强器的augment()方法进行处理

  • 消息增强

    • 检索增强器会根据用户消息从知识库中检索相关信息

    • 将检索到的信息与原始用户消息结合,生成增强后的消息

    • 用增强后的消息替换原始用户消息,提供更丰富的上下文信息

// newProxyInstance的后续代码

	AugmentationResult augmentationResult = null;
	
	// 是否添加了检索增强器
	if (context.retrievalAugmentor != null) {
		// 这部分移步 [消息淘汰]
	    List<ChatMessage> chatMemory = context.hasChatMemory()
	            ? context.chatMemory(memoryId).messages()
	            : null;
	    Metadata metadata = Metadata.from(userMessage, memoryId, chatMemory);
	    AugmentationRequest augmentationRequest = new AugmentationRequest(userMessage, metadata);
	    // 检索增强
	    augmentationResult = context.retrievalAugmentor.augment(augmentationRequest);
	    // 获取到最终增强的用户消息
	    userMessage = (UserMessage) augmentationResult.chatMessage();

		// 截断后续newProxyInstance方法代码...
	}

执行augment方法

用户发起一个提问,RAG会将用户已有的聊天内容进行封装到augmentationRequest中并将用户的消息

  • 通过queryTransformer转换为多个问题,通过合适的转换可以在知识库或搜索引擎中匹配度更高

    • queryTransformer分为以下几种,内部原理很简单就是通过PromptTemplate(内置的提示词)向LLM发起提问,然后将结果返回并封装到Content

      • ExpandingQueryTransformer 扩展转换 将一个问题转换为多个

      • CompressingQueryTransformer 压缩转换 将一个问题进行压缩

  • 通过process将每一个问题进行路由,对每个扩展后的查询执行路由和检索,拿到相关内容

    • 通过queryRouter.route(query)根据查询内容选择合适的检索器

      • 如果只有一个检索器,直接调用检索并返回结果

      • 如果有多个检索器使用retrieveFromAll方法并发检索

    • 检索器也分为2种

      • EmbeddingStoreContentRetriever 基于向量数据库进行检索

      • WebSearchContentRetriever 基于互联网进行检索

  • 通过contentAggregator 把不同查询返回的内容进行聚合排序

    • contentAggregator分为2种

      • DefaultContentAggregator 内部使用倒数融合算法将多个List<Content>融合到单个List<Content>中

      • ReRankingContentAggregator 使用scoringModel(评分模型)进行打分

  • 最后通过contentInjector将返回的contents和用户消息进行提示词增强

@Override
public AugmentationResult augment(AugmentationRequest augmentationRequest) {
    ChatMessage chatMessage = augmentationRequest.chatMessage();
    Metadata metadata = augmentationRequest.metadata();
  	// 封装一个Query对象
    Query originalQuery = Query.from(chatMessage.text(), metadata);
  	// 调用transform,将一个问题进行转换, 参照 [消息转换]
    Collection<Query> queries = queryTransformer.transform(originalQuery);
    logQueries(originalQuery, queries);
  	// 循环遍历每个问题并且进行路由, 参照 [内容检索器]
    Map<Query, Collection<List<Content>>> queryToContents = process(queries);
  	// 通过提示词增强器,为每一个问题设置优先级,参照 [内容聚合器]
    List<Content> contents = contentAggregator.aggregate(queryToContents);
    log(queryToContents, contents);
  	// 使用提示词注入器,将最终生成的contents和用户消息拼接成一个增强之后的用户提问
    ChatMessage augmentedChatMessage = contentInjector.inject(contents, chatMessage);
    log(augmentedChatMessage);
	// 将最终的结果通过AugmentationResult返回
    return AugmentationResult.builder()
        .chatMessage(augmentedChatMessage)
        .contents(contents)
        .build();
}

消息转换

此处以ExpandingQueryTransformer 扩展转换 为例

提示词模版的翻译为:生成{{n}}提供的用户查询的不同版本。 使用同义词或替代句子结构,每个版本都应以不同的方式措辞,但它们都应保留原始含义。 这些版本将用于检索相关文档。 在单独的行上提供每个查询版本非常重要,而无需枚举,连字符或任何其他格式!用户查询:{{query}}

// 默认的提示词模版
public static final PromptTemplate DEFAULT_PROMPT_TEMPLATE = PromptTemplate.from(
        """
                Generate {{n}} different versions of a provided user query. \
                Each version should be worded differently, using synonyms or alternative sentence structures, \
                but they should all retain the original meaning. \
                These versions will be used to retrieve relevant documents. \
                It is very important to provide each query version on a separate line, \
                without enumerations, hyphens, or any additional formatting!
                User query: {{query}}"""
);

@Override
public Collection<Query> transform(Query query) {
	// 将提示词中query和n填充到提示词模版中
    Prompt prompt = createPrompt(query);
	// 与语言模型交互
    String response = chatLanguageModel.generate(prompt.text());
	// 通过换行分割并返回
    List<String> queries = parse(response);
    return queries.stream()
            .map(queryText -> query.metadata() == null
                    ? Query.from(queryText)
                    : Query.from(queryText, query.metadata()))
            .collect(toList());
}

protected Prompt createPrompt(Query query) {
    Map<String, Object> variables = new HashMap<>();
    variables.put("query", query.text());
    variables.put("n", n);
    return promptTemplate.apply(variables);
}

protected List<String> parse(String queries) {
    return stream(queries.split("\n"))
            .filter(Utils::isNotNullOrBlank)
            .collect(toList());
}

查询路由

此处以LanguageModelQueryRouter为例

默认的提示词翻译:基于用户查询,确定最合适的数据源以从以下选项中检索相关信息:{{options}}}您的答案由单个数字或多个由逗号分隔的数字非常重要!用户查询:{{query}}

public static final PromptTemplate DEFAULT_PROMPT_TEMPLATE = PromptTemplate.from(
        """
                Based on the user query, determine the most suitable data source(s) \
                to retrieve relevant information from the following options:
                {{options}}
                It is very important that your answer consists of either a single number \
                or multiple numbers separated by commas and nothing else!
                User query: {{query}}"""
);

@Override
public Collection<ContentRetriever> route(Query query) {
    Prompt prompt = createPrompt(query);
    try {
        // 向大模型发起请求,并返回结果
        String response = chatLanguageModel.generate(prompt.text());
        return parse(response);
    } catch (Exception e) {
        log.warn("Failed to route query '{}'", query.text(), e);
        return fallback(query, e);
    }
}
​
// 通过提示模版,并将query和options替换提示模版中的占位符
protected Prompt createPrompt(Query query) {
    Map<String, Object> variables = new HashMap<>();
    variables.put("query", query.text());
    variables.put("options", options);
    return promptTemplate.apply(variables);
}

内容检索器

把一个或多个Query分配给合适的 ContentRetriever,因为问题会有多个所以涉及并并发执行检索

最后返回查询和内容的Map

private Map<Query, Collection<List<Content>>> process(Collection<Query> queries) {
    // 如果只有1个问题,获取第一个问题进行路由
    if (queries.size() == 1) {
        Query query = queries.iterator().next();
        // 通过[查询路由],获取到符合条件的内容检索器
        Collection<ContentRetriever> retrievers = queryRouter.route(query);
        if (retrievers.size() == 1) {
            // 只有一个返回内容
            ContentRetriever contentRetriever = retrievers.iterator().next();
            List<Content> contents = contentRetriever.retrieve(query);
            return singletonMap(query, singletonList(contents));
        } else if (retrievers.size() > 1) {
            // 多个检索器,异步获取所有内容
            Collection<List<Content>> contents = retrieveFromAll(retrievers, query).join();
            return singletonMap(query, contents);
        } else {
            return emptyMap();
        }
      // 多个问题
    } else if (queries.size() > 1) {
        Map<Query, CompletableFuture<Collection<List<Content>>>> queryToFutureContents = new ConcurrentHashMap<>();
        queries.forEach(query -> {
            CompletableFuture<Collection<List<Content>>> futureContents =
                supplyAsync(() -> {
                        Collection<ContentRetriever> retrievers = queryRouter.route(query);
                        log(query, retrievers);
                        return retrievers;
                    },
                    executor
                // 先路由,再并行检索
                ).thenCompose(retrievers -> retrieveFromAll(retrievers, query));
            queryToFutureContents.put(query, futureContents);
        });
        // 循环之后统一join
        return join(queryToFutureContents);
    } else {
        return emptyMap();
    }
}
​
private CompletableFuture<Collection<List<Content>>> retrieveFromAll(Collection<ContentRetriever> retrievers, Query query) {
    List<CompletableFuture<List<Content>>> futureContents = retrievers.stream()
        .map(retriever -> supplyAsync(() -> retrieve(retriever, query), executor))
        .toList();
    return allOf(futureContents.toArray(new CompletableFuture[0]))
        .thenApply(ignored ->
            futureContents.stream()
                .map(CompletableFuture::join)
                .toList());
}
​
private static List<Content> retrieve(ContentRetriever retriever, Query query) {
    // 此处可以看到内容检索器ContentRetriever有2个实现
    // 一个是基于embeddingModel的EmbeddingStoreContentRetriever,通过向量模型匹配符合的内容
    // 另一个是基于webSearchEngine的WebSearchContentRetriever,在互联网上执行检索
    List<Content> contents = retriever.retrieve(query);
    log(query, retriever, contents);
    return contents;
}

内容聚合器

此处以ReRankingContentAggregator,内部使用了scoringModel,因为不是大语言模型所以此处没有提示词

我问了一下ai,为什么不直接交给 scoringModel,以下是他的回答,我摘略了一下

最终排序确实是 scoringModel (精排)在做,但中间加一层 fuse(粗排) 有两个好处

  • 降噪

    • scoringModel 是逐条打分的,如果输入全是嘈杂内容,性能和结果都会受影响

    • fuse 可以先做一次初步「多路检索融合」,过滤掉明显靠后的低质量结果

  • 公平性

    • scoringModel 的目标是「精排」,但它需要有一个 候选集

    • fuse 先把不同 Query/不同 Retriever 的结果整合 → 再交给 scoringModel 精排

  • 效率

    • scoringModel(通常是一个深度模型/BERT/Reranker)开销比 fuse 大很多

    • fuse 能帮忙缩小候选集,减少 scoringModel 的计算量

如果不使用scoringModel,使用的是DefaultContentAggregator,则只使用了粗排

private final Function<Map<Query, Collection<List<Content>>>, Query> querySelector;

@Override
public List<Content> aggregate(Map<Query, Collection<List<Content>>> queryToContents) {
    if (queryToContents.isEmpty()) {
        return emptyList();
    }
    // 先通过外部传递的lambda手动获取一个query
	// 所以用querySelector来选一个最优的Query(比如内容最多/上下文最相关的)
    Query query = querySelector.apply(queryToContents);
    // 将queryToContents数据结构转换为Map<Query, List<Content>>
    Map<Query, List<Content>> queryToFusedContents = fuse(queryToContents);
    // 使用倒数融合算法将所有结果融合为List
    List<Content> fusedContents = ReciprocalRankFuser.fuse(queryToFusedContents.values());
    if (fusedContents.isEmpty()) {
        return fusedContents;
    }
    // 使用scoringModel重新进行排序
    return reRankAndFilter(fusedContents, query);
}

protected List<Content> reRankAndFilter(List<Content> contents, Query query) {
    List<TextSegment> segments = contents.stream()
            .map(Content::textSegment)
            .collect(Collectors.toList());
	// 使用scoringModel
    List<Double> scores = scoringModel.scoreAll(segments, query.text()).content();
    Map<TextSegment, Double> segmentToScore = new HashMap<>();
    for (int i = 0; i < segments.size(); i++) {
        segmentToScore.put(segments.get(i), scores.get(i));
    }
	
    return segmentToScore.entrySet().stream()
			// 过滤过滤掉低分,按分数从高到低排序
            .filter(entry -> minScore == null || entry.getValue() >= minScore)
            .sorted(Map.Entry.<TextSegment, Double>comparingByValue().reversed())
			// 转成带有打分信息封装到Content
            .map(entry ->  Content.from(entry.getKey(), Map.of(RERANKED_SCORE, entry.getValue())))
			// 限制最大数量
            .limit(maxResults)
            .collect(Collectors.toList());
}

提示词注入器

实现很简单就是叫原本的提示词外封装了一层提示词

public static final PromptTemplate DEFAULT_PROMPT_TEMPLATE = PromptTemplate.from(
        """
                {{userMessage}}
                
                Answer using the following information:
                {{contents}}"""
);

@Override
public ChatMessage inject(List<Content> contents, ChatMessage chatMessage) {
    if (contents.isEmpty()) {
        return chatMessage;
    }
    Prompt prompt = createPrompt(chatMessage, contents);
    if (chatMessage instanceof UserMessage message && isNotNullOrBlank(message.name())) {
        return prompt.toUserMessage(message.name());
    }
    return prompt.toUserMessage();
}

消息淘汰

这段代码实现了智能的聊天记忆管理机制,确保模型能够维持上下文连贯的对话

  • 记忆存储策略

    • 通过memoryId区分不同用户的对话上下文,确保用户间的对话不会相互干扰

    • 将系统消息和用户消息添加到对应用户的聊天记忆中

    • 当消息数量超过限制时会自动淘汰旧消息

  • 消息列表构建

    • 如果配置了聊天记忆,直接从记忆中获取完整的消息历史

    • 如果没有配置记忆功能,创建临时消息列表,只包含当前对话的系统消息和用户消息

  • 上下文连续性

    • 通过维护消息历史,AI能够理解对话的上下文,提供更加连贯和相关的回复

    • 系统消息通常包含AI的角色定义和行为指导,确保AI始终按照预期的方式回应

// newProxyInstance的后续代码
	
	// 通过不同的MemoryId获取不同用户的消息
	if (context.hasChatMemory()) {
	    ChatMemory chatMemory = context.chatMemory(memoryId);
	    systemMessage.ifPresent(chatMemory::add);
	    // 消息淘汰
	    chatMemory.add(userMessage);
	}
	// 构建消息
	List<ChatMessage> messages;
	// 通过memoryId获取所有消息
	if (context.hasChatMemory()) {
	    messages = context.chatMemory(memoryId).messages();
	} else {
	    messages = new ArrayList<>();
	    // 将系统提示词添加到messages中
	    systemMessage.ifPresent(messages::add);
	    // 将用户提示词也添加到messages中
	    messages.add(userMessage);
	}


// AiServiceContext
// 指定了store就算memoryId存储在磁盘中,langchain4j依然会将数据存储在内存中
public ChatMemory chatMemory(Object memoryId) {
	// 执行构建AiServices的ChatMemoryProvider
    return chatMemories.computeIfAbsent(memoryId, ignored -> chatMemoryProvider.get(memoryId));
}

实现ChatMemoryStore可以自定义增删改查消息,实现getMessages、updateMessages、deleteMessages

@Bean
public AssistantUnique assistantUniqueStore(@Qualifier("ollamaChatModel") ChatLanguageModel languageModel,
                                            @Qualifier("ollamaStreamingChatModel") StreamingChatLanguageModel streamingChatLanguageModel,
											// 从ioc容器中获取
                                            PersistentChatMemoryStore store,
                                            ToolService toolService) {
    return AiServices.builder(AssistantUnique.class)
            .chatLanguageModel(languageModel)
            .streamingChatLanguageModel(streamingChatLanguageModel)
            // .tools(toolService)
            .chatMemoryProvider(memoryId -> MessageWindowChatMemory.builder()
                    .id(memoryId)
                    .maxMessages(100)
                    .chatMemoryStore(store)
                    .build())
            .build();
}

// langchain4j的提供的一个默认实现
// key是memoryId,value就是存储的消息
public class InMemoryChatMemoryStore implements ChatMemoryStore {
    // 内部使用十分简单,将所有消息都存储在ConcurrentHashMap中
    private final Map<Object, List<ChatMessage>> messagesByMemoryId = new ConcurrentHashMap<>();
​
    public InMemoryChatMemoryStore() {}
​
    @Override
    public List<ChatMessage> getMessages(Object memoryId) {
        return messagesByMemoryId.computeIfAbsent(memoryId, ignored -> new ArrayList<>());
    }
​
    @Override
    // 此处就是消息淘汰后调用,存储剩余的消息
    public void updateMessages(Object memoryId, List<ChatMessage> messages) {
        messagesByMemoryId.put(memoryId, messages);
    }
​
    @Override
    public void deleteMessages(Object memoryId) {
        messagesByMemoryId.remove(memoryId);
    }
}

构建不同的ChatMemory,langchain4j执行不同的淘汰方式

  • MessageWindowChatMemory 通过消息的数量进行淘汰

  • TokenWindowChatMemory 通过token进行淘汰

public class MessageWindowChatMemory implements ChatMemory {
​
    private static final Logger log = LoggerFactory.getLogger(MessageWindowChatMemory.class);
​
    private final Object id;
    private final Integer maxMessages;
    private final ChatMemoryStore store; // 默认使用了InMemoryChatMemoryStore
​
    private MessageWindowChatMemory(Builder builder) {
        this.id = ensureNotNull(builder.id, "id");
        this.maxMessages = ensureGreaterThanZero(builder.maxMessages, "maxMessages");
        this.store = ensureNotNull(builder.store, "store");
    }
​
    @Override
    public Object id() {
        return id;
    }
​
    @Override
    // 添加一条消息
    public void add(ChatMessage message) {
        // 获取所有的历史记录
        List<ChatMessage> messages = messages();
        if (message instanceof SystemMessage) {
            // 过滤系统提示词
            Optional<SystemMessage> systemMessage = findSystemMessage(messages);
            if (systemMessage.isPresent()) {
                if (systemMessage.get().equals(message)) {
                    return;
                } else {
                    messages.remove(systemMessage.get());
                }
            }
        }
        // 加到list最后
        messages.add(message);
        // 消息淘汰
        ensureCapacity(messages, maxMessages);
        // 将淘汰后剩余消息放入对应的store中
        store.updateMessages(id, messages);
    }
​
    private static Optional<SystemMessage> findSystemMessage(List<ChatMessage> messages) {
        return messages.stream()
                .filter(message -> message instanceof SystemMessage)
                .map(message -> (SystemMessage) message)
                .findAny();
    }
​
    @Override
    public List<ChatMessage> messages() {
        List<ChatMessage> messages = new LinkedList<>(store.getMessages(id));
        ensureCapacity(messages, maxMessages);
        return messages;
    }
​
    // 消息淘汰
    private static void ensureCapacity(List<ChatMessage> messages, int maxMessages) {
        // 消息总量若大于maxMessages则进行淘汰
        while (messages.size() > maxMessages) {
            // 从0开始判断,如果第0个是系统提示此则从第1个开始 
            int messageToEvictIndex = 0;
            if (messages.get(0) instanceof SystemMessage) {
                messageToEvictIndex = 1;
            }
            // 淘汰这个下标的消息
            ChatMessage evictedMessage = messages.remove(messageToEvictIndex);
            log.trace("Evicting the following message to comply with the capacity requirement: {}", evictedMessage);
            // 移除的是否是AiMessage并且是否包含了工具执行结果
            if (evictedMessage instanceof AiMessage && ((AiMessage) evictedMessage).hasToolExecutionRequests()) {
                // 继续移除其中的工具执行结果
                while (messages.size() > messageToEvictIndex
                        && messages.get(messageToEvictIndex) instanceof ToolExecutionResultMessage) {
                    ChatMessage orphanToolExecutionResultMessage = messages.remove(messageToEvictIndex);
                    log.trace("Evicting orphan {}", orphanToolExecutionResultMessage);
                }
            }
        }
    }
    // 工具,设置最大的消息配置
    public static MessageWindowChatMemory withMaxMessages(int maxMessages) {
        return builder().maxMessages(maxMessages).build();
    }
}

TokenWindowChatMemory的淘汰机制

private static void ensureCapacity(List<ChatMessage> messages, int maxTokens, Tokenizer tokenizer) {
    if (messages.isEmpty()) {
        return;
    }
    // 计算消息的token数
    int currentTokenCount = tokenizer.estimateTokenCountInMessages(messages);
    while (currentTokenCount > maxTokens) {
        int messageToEvictIndex = 0;
        if (messages.get(0) instanceof SystemMessage) {
            messageToEvictIndex = 1;
        }
        ChatMessage evictedMessage = messages.remove(messageToEvictIndex);
        int tokenCountOfEvictedMessage = tokenizer.estimateTokenCountInMessage(evictedMessage);
        log.trace("Evicting the following message ({} tokens) to comply with the capacity requirement: {}",
                tokenCountOfEvictedMessage, evictedMessage);
        currentTokenCount -= tokenCountOfEvictedMessage;
        if (evictedMessage instanceof AiMessage && ((AiMessage) evictedMessage).hasToolExecutionRequests()) {
            while (messages.size() > messageToEvictIndex
                    && messages.get(messageToEvictIndex) instanceof ToolExecutionResultMessage) {
                ChatMessage orphanToolExecutionResultMessage = messages.remove(messageToEvictIndex);
                log.trace("Evicting orphan {}", orphanToolExecutionResultMessage);
                currentTokenCount -= tokenizer.estimateTokenCountInMessage(orphanToolExecutionResultMessage);
            }
        }
    }
}

Tools

首先看构建AiServices的tools方法,正对工具进行解析,并且设置到AiService代理对象中

public AiServices<T> tools(Collection<Object> objectsWithTools) {
    context.toolService.tools(objectsWithTools);
    return this;
}
​
public void tools(Collection<Object> objectsWithTools) {
    if (toolProvider != null) {
        throw new IllegalArgumentException(
                "Either the tools or the tool provider can be configured, but not both!");
    }
    // 初始化toolExecutors和toolSpecifications
    initTools();
    // 循环每一个tools,外部可以传递多个tools
    for (Object objectWithTool : objectsWithTools) {
        if (objectWithTool instanceof Class) {
            throw illegalConfiguration("Tool '%s' must be an object, not a class", objectWithTool);
        }
        // 获取每一个tools的方法,拿到所有标注@Tool的方法
        for (Method method : objectWithTool.getClass().getDeclaredMethods()) {
            if (method.isAnnotationPresent(Tool.class)) {
                ToolSpecification toolSpecification = toolSpecificationFrom(method);
                // @Tool的name不能重复
                if (toolExecutors.containsKey(toolSpecification.name())) {
                    throw new IllegalConfigurationException(
                            "Duplicated definition for tool: " + toolSpecification.name());
                }
                toolExecutors.put(toolSpecification.name(), new DefaultToolExecutor(objectWithTool, method));
                toolSpecifications.add(toolSpecification);
            }
        }
    }
}
​
public static ToolSpecification toolSpecificationFrom(Method method) {
    Tool annotation = method.getAnnotation(Tool.class);
    // 如果注解没有name则使用方法的名称
    String name = isNullOrBlank(annotation.name()) ? method.getName() : annotation.name();
    String description = String.join("\n", annotation.value());
    if (description.isEmpty()) {
        description = null;
    }
    JsonObjectSchema parameters = parametersFrom(method.getParameters());
    return ToolSpecification.builder()
            .name(name)
            .description(description)
            .parameters(parameters)
            .build();
}
​
private static JsonObjectSchema parametersFrom(Parameter[] parameters) {
    Map<String, JsonSchemaElement> properties = new LinkedHashMap<>();
    List<String> required = new ArrayList<>();
    Map<Class<?>, JsonSchemaElementHelper.VisitedClassMetadata> visited = new LinkedHashMap<>();
    for (Parameter parameter : parameters) {
        // 如果包含了@ToolMemoryId跳过,和@MemoryId类似可以区分不同用户的同一个@Tool
        if (parameter.isAnnotationPresent(ToolMemoryId.class)) {
            continue;
        }
        // 判断是否包含@P,包含了给模型针对这个参数的描述
        boolean isRequired = Optional.ofNullable(parameter.getAnnotation(P.class))
                .map(P::required)
                .orElse(true);
        properties.put(parameter.getName(), jsonSchemaElementFrom(parameter, visited));
        if (isRequired) {
            required.add(parameter.getName());
        }
    }
    Map<String, JsonSchemaElement> definitions = new LinkedHashMap<>();
    visited.forEach((clazz, visitedClassMetadata) -> {
        if (visitedClassMetadata.recursionDetected) {
            definitions.put(visitedClassMetadata.reference, visitedClassMetadata.jsonSchemaElement);
        }
    });
    if (properties.isEmpty()) {
        return null;
    }
    return JsonObjectSchema.builder()
            .addProperties(properties)
            .required(required)
            .definitions(definitions.isEmpty() ? null : definitions)
            .build();
}

后续代码

                    // 构建Chat请求参数
                    ChatRequestParameters parameters = ChatRequestParameters.builder()
                      			// 将调用构建AiService设置的工具赋值给parameters
                      			// 参照 [Tools]
                            .toolSpecifications(toolExecutionContext.toolSpecifications())
                            .responseFormat(responseFormat)
                            .build();
                    ChatRequest chatRequest = ChatRequest.builder()
                            .messages(messages)
                            .parameters(parameters)
                            .build();
                  	// 向Ai发起chat请求
                    ChatResponse chatResponse = context.chatModel.chat(chatRequest);
                    verifyModerationIfNeeded(moderationFuture);
                  	// 通过chatResponse判断模型是否有工具执行请求,不断循环执行,直到模型不在发起工具执行请求
                    ToolExecutionResult toolExecutionResult = context.toolService.executeInferenceAndToolsLoop(
                            chatResponse,
                            parameters,
                            messages,
                            context.chatModel,
                            context.hasChatMemory() ? context.chatMemory(memoryId) : null,
                            memoryId,
                            toolExecutionContext.toolExecutors());
                    chatResponse = toolExecutionResult.chatResponse();
                    FinishReason finishReason = chatResponse.metadata().finishReason();
                    Response<AiMessage> response = Response.from(
                            chatResponse.aiMessage(), toolExecutionResult.tokenUsageAccumulator(), finishReason);
                    Object parsedResponse = serviceOutputParser.parse(response, returnType);
                    if (typeHasRawClass(returnType, Result.class)) {
                        return Result.builder()
                                .content(parsedResponse)
                          			// tokenUsage是模型token的调用情况,这里只计算了tool的token
                          			// chat模 型的token在对应的大模型调用时计算,最终封装在ChatResponseMetadata内
                                .tokenUsage(toolExecutionResult.tokenUsageAccumulator())
                                .sources(augmentationResult == null ? null : augmentationResult.contents())
                                .finishReason(finishReason)
                                .toolExecutions(toolExecutionResult.toolExecutions())
                                .build();
                    } else {
                        return parsedResponse;
                    }
                }