关于BGE-M3接入LangChain时遇到的问题与解决方法

news/2024/10/6 4:26:51/文章来源:https://www.cnblogs.com/tarorat/p/18286378

本文基于https://github.com/datawhalechina/self-llm/blob/master/GLM-4/02-GLM-4-9B-chat%20langchain%20%E6%8E%A5%E5%85%A5.md提供的教程。由于使用本地部署的大模型,在继承LangChain中的LLM类时需要重写几个函数。

但是在具体测试的时候出现了以下的错误

/root/miniconda3/lib/python3.12/site-packages/transformers/generation/utils.py:1659: UserWarning: You are calling .generate() with the `input_ids` being on a device type different than your model's device. `input_ids` is on cpu, whereas the model is on cuda. You may experience unexpected behaviors or slower generation. Please make sure that you have put `input_ids` to the correct device by calling for example input_ids = input_ids.to('cuda') before running `.generate()`.warnings.warn(
Traceback (most recent call last):File "/root/autodl-tmp/glm4LLM.py", line 63, in <module>print(llm.invoke("你是谁"))^^^^^^^^^^^^^^^^^^^^File "/root/miniconda3/lib/python3.12/site-packages/langchain_core/language_models/llms.py", line 276, in invokeself.generate_prompt(File "/root/miniconda3/lib/python3.12/site-packages/langchain_core/language_models/llms.py", line 633, in generate_promptreturn self.generate(prompt_strings, stop=stop, callbacks=callbacks, **kwargs)^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^File "/root/miniconda3/lib/python3.12/site-packages/langchain_core/language_models/llms.py", line 803, in generateoutput = self._generate_helper(^^^^^^^^^^^^^^^^^^^^^^File "/root/miniconda3/lib/python3.12/site-packages/langchain_core/language_models/llms.py", line 670, in _generate_helperraise eFile "/root/miniconda3/lib/python3.12/site-packages/langchain_core/language_models/llms.py", line 657, in _generate_helperself._generate(File "/root/miniconda3/lib/python3.12/site-packages/langchain_core/language_models/llms.py", line 1322, in _generateself._call(prompt, stop=stop, run_manager=run_manager, **kwargs)File "/root/autodl-tmp/glm4LLM.py", line 40, in _callgenerated_ids = self.model.generate(**model_inputs, **self.gen_kwargs)^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^File "/root/miniconda3/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 115, in decorate_contextreturn func(*args, **kwargs)^^^^^^^^^^^^^^^^^^^^^File "/root/miniconda3/lib/python3.12/site-packages/transformers/generation/utils.py", line 1758, in generateresult = self._sample(^^^^^^^^^^^^^File "/root/miniconda3/lib/python3.12/site-packages/transformers/generation/utils.py", line 2397, in _sampleoutputs = self(^^^^^File "/root/miniconda3/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_implreturn self._call_impl(*args, **kwargs)^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^File "/root/miniconda3/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1541, in _call_implreturn forward_call(*args, **kwargs)^^^^^^^^^^^^^^^^^^^^^^^^^^^^^File "/root/.cache/huggingface/modules/transformers_modules/glm-4-9b-chat/modeling_chatglm.py", line 1005, in forwardtransformer_outputs = self.transformer(^^^^^^^^^^^^^^^^^File "/root/miniconda3/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_implreturn self._call_impl(*args, **kwargs)^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^File "/root/miniconda3/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1541, in _call_implreturn forward_call(*args, **kwargs)^^^^^^^^^^^^^^^^^^^^^^^^^^^^^File "/root/.cache/huggingface/modules/transformers_modules/glm-4-9b-chat/modeling_chatglm.py", line 887, in forwardinputs_embeds = self.embedding(input_ids)^^^^^^^^^^^^^^^^^^^^^^^^^File "/root/miniconda3/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_implreturn self._call_impl(*args, **kwargs)^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^File "/root/miniconda3/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1541, in _call_implreturn forward_call(*args, **kwargs)^^^^^^^^^^^^^^^^^^^^^^^^^^^^^File "/root/.cache/huggingface/modules/transformers_modules/glm-4-9b-chat/modeling_chatglm.py", line 823, in forwardwords_embeddings = self.word_embeddings(input_ids)^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^File "/root/miniconda3/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_implreturn self._call_impl(*args, **kwargs)^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^File "/root/miniconda3/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1541, in _call_implreturn forward_call(*args, **kwargs)^^^^^^^^^^^^^^^^^^^^^^^^^^^^^File "/root/miniconda3/lib/python3.12/site-packages/torch/nn/modules/sparse.py", line 163, in forwardreturn F.embedding(^^^^^^^^^^^^File "/root/miniconda3/lib/python3.12/site-packages/torch/nn/functional.py", line 2264, in embeddingreturn torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu! (when checking argument for argument index in method wrapper_CUDA__index_select)

错误原因主要是因为input_ids(输入数据)与model(模型)所在设备不一致。

经过修改成下面的代码可以成功运行,主要修改了input_ids对应语句。

from langchain.llms.base import LLM
from typing import Any, List, Optional, Dict
from langchain.callbacks.manager import CallbackManagerForLLMRun
from transformers import AutoTokenizer, AutoModelForCausalLM
import torchclass ChatGLM4_LLM(LLM):# 基于本地 ChatGLM4 自定义 LLM 类tokenizer: AutoTokenizer = Nonemodel: AutoModelForCausalLM = Nonegen_kwargs: dict = Nonedef __init__(self, model_name_or_path: str, gen_kwargs: dict = None):super().__init__()print("正在从本地加载模型...")self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, trust_remote_code=True)self.model = AutoModelForCausalLM.from_pretrained(model_name_or_path,torch_dtype=torch.bfloat16,trust_remote_code=True,device_map="auto").eval()print("完成本地模型的加载")if gen_kwargs is None:gen_kwargs = {"max_length": 2500, "do_sample": True, "top_k": 1}self.gen_kwargs = gen_kwargsdef _call(self, prompt: str, stop: Optional[List[str]] = None,run_manager: Optional[CallbackManagerForLLMRun] = None,**kwargs: Any) -> str:messages = [{"role": "user", "content": prompt}]model_inputs = self.tokenizer.apply_chat_template(messages, tokenize=True, return_tensors="pt", return_dict=True, add_generation_prompt=True)# 将input_ids移动到与模型相同的设备device = next(self.model.parameters()).devicemodel_inputs = {key: value.to(device) for key, value in model_inputs.items()}generated_ids = self.model.generate(**model_inputs, **self.gen_kwargs)generated_ids = [output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs['input_ids'], generated_ids)]response = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]return response@propertydef _identifying_params(self) -> Dict[str, Any]:"""返回用于识别LLM的字典,这对于缓存和跟踪目的至关重要。"""return {"model_name": "glm-4-9b-chat","max_length": self.gen_kwargs.get("max_length"),"do_sample": self.gen_kwargs.get("do_sample"),"top_k": self.gen_kwargs.get("top_k"),}@propertydef _llm_type(self) -> str:return "glm-4-9b-chat"

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/738887.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Mybatis PageHelper编译SQL引发的一次性能问题.18286262

起源 最近一直在跟大佬们做公司项目的性能优化,我这种小卡乐咪基本上负责的就是慢接口优化,但实际上只有以下几种情况需要进行接口代码级别的改造:循环查库、RPC 数据库设计不合理 业务流程太长,代码耦合性太高等随着对接口分析的深入,我们越来越发现系统中有很多拖后腿的…

蓝牙音箱App设计总结

前言 最近做了一个关于带Sound bar的智能电视的蓝牙项目,就是将电视Sound bar当作蓝牙音箱,将手机、电脑等设备的声音传输到电视,通过电视Soundbar播放声音。做这个项目的时候遇到了各种大大小小的问题,好在都解决了。本篇文章总结了在设计蓝牙相关的项目时需要了解的小知识…

设计模式学习(二)工厂模式——抽象工厂模式

介绍抽象工厂模式,并说明其优缺点目录背景抽象工厂模式优点与缺点 背景 现在我需要开发一个相机操作模块,它可能在Windows下运行,也可能在Linux下运行。由于在厂家提供的SDK中,Windows下的SDK和Linux下的SDK是有区别的,因此我们要创建两个类去封装这两个不同平台下的API。…

aippt 实现原理 AI生成PPT开源项目

AI生成PPT原理与代码实现通过 AI 生成 PPT 火了好长一段时间了,该类型产品也越来越多,我分析了几个主流的 aippt 产品,其中有一家公司的技术原理让我眼前一亮:文多多 AI 生成 PPT,官网: https://docmee.cn 该产品在 github 上有对应开源项目:https://github.com/veasion…

自定义流程表单开发优势体现在什么地方?

一起来了解自自定义流程表单开发的优势特点。提质、增效、降本,应该是很多职场办公需要实现的发展目标。那么,应用什么样的软件平台可以实现?低代码技术平台、自定义流程表单开发是目前流行于职场行业中的软件产品,可视化操作界面、够灵活、易维护等优势特点明显,在推进企…

Matlab马尔可夫链蒙特卡罗法(MCMC)估计随机波动率(SV,Stochastic Volatility) 模型|附代码数据

全文下载链接:http://tecdat.cn/?p=16708 最近我们被客户要求撰写关于随机波动率的研究报告,包括一些图形和统计输出。 波动率是一个重要的概念,在金融和交易中有许多应用。它是期权定价的基础。波动率还可以让您确定资产配置并计算投资组合的风险价值 (VaR) 甚至波动率本身…

2024.7.5 鲜花

菜就多练空白とカタルシス——TOGENASHI TOGEARI。震惊,K某He 强推竟然是这首歌,三天重复上百遍…… どれだけ手に入れても どれだけ自分のものにしてもしてもしても 追いつけないな 高望みしすぎなんて 腐ったような言葉 誰しも誰よりも優れて欲しくはないんだよ 理由はただ…

泛娱乐出海新风口,视频云技术需要怎样的融合创新?

泛娱乐的音视频技术随着出海在演进,交互和内容的技术是内核,也在融合。 泛娱乐的音视频技术随着出海在演进,交互和内容的技术是内核,也在融合。 面向出海,虽然娱乐社交这个行业由来已久,但近几年的商业模式发生了巨大变化,比如行业刚兴起时,大家要先把DAU做大之后再…

米尔瑞米派集聚5种操作系统,兼顾学习开发和项目产品需要的派

米尔电子发布的瑞萨第一款MPU生态板卡-瑞米派(Remi Pi),采用瑞萨RZ/G2L双核A55芯片,接口丰富,全面兼容树莓派的扩展模块。瑞米派支持五种系统,兼顾学习开发和项目产品需要。软件提供五种软件系统分别为:基于Yocto构建的两种系统,一种是支持通用功能的精简型系统,另一种…

echarts中Label标签与数据项颜色设置为同一种颜色

echarts5中默认标签颜色不会跟数据项颜色保持一致,而是全都是黑色。想要实现label颜色和它的数据项颜色一致,需要手动继承颜色,设置label{ color: inherit}即可解决label标签颜色与数据项颜色一致。 https://echarts.apache.org/examples/zh/editor.html?c=pie-simple 注意…

GaussDB AI新特性:gs_index_advise推荐索引

GaussDB的AI新特性,可以把SQL文本嵌套在函数中,数据库会返回一个创建合适索引的列gs_index_advise(text) 描述:针对单条查询语句推荐索引。 参数: SQL语句字符串 返回值类型: record 一、通常的SQL优化会通过参数调优的方式进行调整,例如如下参数set enable_fast_query_s…

Packing Python to exe(打包Python成EXE文件)

Python文件要执行需要Python环境,如果package成EXE文件则可以随意放在任意主机上去执行。package步骤如下: 1. 安装pythoninstaller (pip install pyinstaller) 2.安装auto-py-to-exe(有UI界面,很方便使用)(pip install auto-py-to-exe) 3.然后直接运行命令auto-py-to-e…