本文基于 https://github.com/datawhalechina/self-llm/blob/master/GLM-4/02-GLM-4-9B-chat%20langchain%20%E6%8E%A5%E5%85%A5.md 提供的教程。由于使用的是本地部署的大模型，在继承 LangChain 中的 LLM 类时需要重写几个函数（_call、_identifying_params 和 _llm_type）。
但在实际测试时出现了以下警告和错误：
/root/miniconda3/lib/python3.12/site-packages/transformers/generation/utils.py:1659: UserWarning: You are calling .generate() with the `input_ids` being on a device type different than your model's device. `input_ids` is on cpu, whereas the model is on cuda. You may experience unexpected behaviors or slower generation. Please make sure that you have put `input_ids` to the correct device by calling for example input_ids = input_ids.to('cuda') before running `.generate()`.
  warnings.warn(
Traceback (most recent call last):
  File "/root/autodl-tmp/glm4LLM.py", line 63, in <module>
    print(llm.invoke("你是谁"))
          ^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/lib/python3.12/site-packages/langchain_core/language_models/llms.py", line 276, in invoke
    self.generate_prompt(
  File "/root/miniconda3/lib/python3.12/site-packages/langchain_core/language_models/llms.py", line 633, in generate_prompt
    return self.generate(prompt_strings, stop=stop, callbacks=callbacks, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/lib/python3.12/site-packages/langchain_core/language_models/llms.py", line 803, in generate
    output = self._generate_helper(
             ^^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/lib/python3.12/site-packages/langchain_core/language_models/llms.py", line 670, in _generate_helper
    raise e
  File "/root/miniconda3/lib/python3.12/site-packages/langchain_core/language_models/llms.py", line 657, in _generate_helper
    self._generate(
  File "/root/miniconda3/lib/python3.12/site-packages/langchain_core/language_models/llms.py", line 1322, in _generate
    self._call(prompt, stop=stop, run_manager=run_manager, **kwargs)
  File "/root/autodl-tmp/glm4LLM.py", line 40, in _call
    generated_ids = self.model.generate(**model_inputs, **self.gen_kwargs)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/lib/python3.12/site-packages/transformers/generation/utils.py", line 1758, in generate
    result = self._sample(
             ^^^^^^^^^^^^^
  File "/root/miniconda3/lib/python3.12/site-packages/transformers/generation/utils.py", line 2397, in _sample
    outputs = self(
              ^^^^^
  File "/root/miniconda3/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/.cache/huggingface/modules/transformers_modules/glm-4-9b-chat/modeling_chatglm.py", line 1005, in forward
    transformer_outputs = self.transformer(
                          ^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/.cache/huggingface/modules/transformers_modules/glm-4-9b-chat/modeling_chatglm.py", line 887, in forward
    inputs_embeds = self.embedding(input_ids)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/.cache/huggingface/modules/transformers_modules/glm-4-9b-chat/modeling_chatglm.py", line 823, in forward
    words_embeddings = self.word_embeddings(input_ids)
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/lib/python3.12/site-packages/torch/nn/modules/sparse.py", line 163, in forward
    return F.embedding(
           ^^^^^^^^^^^^
  File "/root/miniconda3/lib/python3.12/site-packages/torch/nn/functional.py", line 2264, in embedding
    return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu! (when checking argument for argument index in method wrapper_CUDA__index_select)
错误原因在于 input_ids（输入数据）与 model（模型）不在同一设备上：由上面的警告可知，input_ids 位于 CPU，而模型位于 GPU（cuda）。
将代码修改为如下形式后即可成功运行。主要改动是在调用 generate() 之前，把 input_ids 等输入张量移动到模型所在的设备上。
from langchain.llms.base import LLM
from typing import Any, List, Optional, Dict
from langchain.callbacks.manager import CallbackManagerForLLMRun
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch


def _truncate_at_stop(text: str, stop: Optional[List[str]]) -> str:
    """Cut *text* at the earliest occurrence of any stop substring.

    ``model.generate`` knows nothing about LangChain's ``stop`` strings, so
    they are enforced on the decoded text after generation.

    Args:
        text: Decoded model output.
        stop: Stop substrings. ``None`` or an empty list returns *text*
            unchanged. Empty strings inside the list are ignored.

    Returns:
        *text* truncated just before the first stop substring found, or
        *text* itself when none of them occurs.
    """
    if not stop:
        return text
    hits = [pos for pos in (text.find(s) for s in stop if s) if pos != -1]
    return text[: min(hits)] if hits else text


class ChatGLM4_LLM(LLM):
    """Custom LangChain LLM backed by a locally stored GLM-4 chat model."""

    # Populated in __init__. They are declared here so the (pydantic-based)
    # LLM base class accepts them as fields.
    tokenizer: AutoTokenizer = None
    model: AutoModelForCausalLM = None
    gen_kwargs: dict = None

    def __init__(self, model_name_or_path: str, gen_kwargs: dict = None):
        """Load the tokenizer and model from *model_name_or_path*.

        Args:
            model_name_or_path: Path (or hub id) of the GLM-4 checkpoint. It is
                passed to ``from_pretrained`` with ``trust_remote_code=True``.
            gen_kwargs: Keyword arguments forwarded to ``model.generate`` on
                every call. Defaults to
                ``{"max_length": 2500, "do_sample": True, "top_k": 1}``.
        """
        super().__init__()
        print("正在从本地加载模型...")
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_name_or_path, trust_remote_code=True
        )
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name_or_path,
            torch_dtype=torch.bfloat16,
            trust_remote_code=True,
            device_map="auto",
        ).eval()
        print("完成本地模型的加载")
        if gen_kwargs is None:
            # top_k=1 restricts sampling to the single most likely token, so
            # with these defaults generation is effectively greedy.
            gen_kwargs = {"max_length": 2500, "do_sample": True, "top_k": 1}
        self.gen_kwargs = gen_kwargs

    def _call(self, prompt: str, stop: Optional[List[str]] = None,
              run_manager: Optional[CallbackManagerForLLMRun] = None,
              **kwargs: Any) -> str:
        """Run one single-turn chat completion for *prompt*.

        Args:
            prompt: User message, sent as a single ``{"role": "user"}`` turn.
            stop: Optional stop substrings. The returned text is cut at the
                first occurrence of any of them.
            run_manager: LangChain callback manager (unused here).
            **kwargs: Extra LangChain call arguments (unused here).

        Returns:
            The newly generated text, without the prompt tokens and with
            special tokens stripped.
        """
        messages = [{"role": "user", "content": prompt}]
        model_inputs = self.tokenizer.apply_chat_template(
            messages,
            tokenize=True,
            return_tensors="pt",
            return_dict=True,
            add_generation_prompt=True,
        )
        # The tokenizer produces CPU tensors while the model sits on cuda.
        # Feeding them as-is fails with "Expected all tensors to be on the
        # same device", so every input tensor (input_ids, attention_mask, ...)
        # is moved to the device of the model's first parameter.
        # NOTE(review): with a multi-GPU device_map this assumes the first
        # parameter lives on the same device as the embedding layer.
        device = next(self.model.parameters()).device
        model_inputs = {key: value.to(device) for key, value in model_inputs.items()}
        generated_ids = self.model.generate(**model_inputs, **self.gen_kwargs)
        # generate() returns prompt + completion; keep only the new tokens.
        generated_ids = [
            output_ids[len(input_ids):]
            for input_ids, output_ids in zip(model_inputs["input_ids"], generated_ids)
        ]
        response = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
        # Honor LangChain's stop contract, which the original silently ignored.
        return _truncate_at_stop(response, stop)

    @property
    def _identifying_params(self) -> Dict[str, Any]:
        """Return the parameters identifying this LLM (used for caching and tracing)."""
        return {
            "model_name": "glm-4-9b-chat",
            "max_length": self.gen_kwargs.get("max_length"),
            "do_sample": self.gen_kwargs.get("do_sample"),
            "top_k": self.gen_kwargs.get("top_k"),
        }

    @property
    def _llm_type(self) -> str:
        """Return the LangChain type tag of this LLM."""
        return "glm-4-9b-chat"