从零开发短视频电商 在AWS上SageMaker部署模型自定义日志输入和输出示例

从零开发短视频电商 在AWS上SageMaker部署模型自定义日志输入和输出示例

怎么部署自定义模型请看:从零开发短视频电商 在AWS上用SageMaker部署自定义模型

  • 都是huaggingface上的模型或者fine-tune后的。

为了适配jumpstart上部署的模型的http输入输出,我在自定义模型中自定义了适配的输入输出,可以做到兼容适配

code/inference.py

  • 容器的原始代码入口:https://github.com/aws/sagemaker-huggingface-inference-toolkit/blob/80634b30703e8e9525db8b7128b05f713f42f9dc/src/sagemaker_huggingface_inference_toolkit/handler_service.py
  • 默认支持的decode和encode:https://github.com/aws/sagemaker-huggingface-inference-toolkit/blob/80634b30703e8e9525db8b7128b05f713f42f9dc/src/sagemaker_huggingface_inference_toolkit/decoder_encoder.py
  • 可以用这个在sagemaker上使用jupyterlab:https://github.com/huggingface/notebooks/blob/main/sagemaker/17_custom_inference_script/sagemaker-notebook.ipynb

我们自定义的逻辑如下

from transformers import AutoTokenizer, AutoModel
import torch
import torch.nn.functional as F
import json
import logging
// --------- 这块
logger = logging.getLogger()
logger.setLevel(logging.INFO)
// 自定义http输入,可以适配不同的content_type ,打印输入的日志
// 源码参见下面的 preprocess
def input_fn(input_data, content_type):logger.info(f"laker input_data {input_data} and content_type {content_type}")if content_type == "application/json":request = json.loads(input_data)elif content_type == "application/x-text":request = {"inputs": input_data.decode('utf-8')}else:request = {"inputs": input_data} logger.info(f"laker input_fn request {request} ")return request
// 自定义输出
def output_fn(prediction, accept):return encode_json(prediction)  // 来自https://github.com/aws/sagemaker-huggingface-inference-toolkit/blob/80634b30703e8e9525db8b7128b05f713f42f9dc/src/sagemaker_huggingface_inference_toolkit/decoder_encoder.py#L102C1-L113C6class _JSONEncoder(json.JSONEncoder):def default(self, obj):if isinstance(obj, np.integer):return int(obj)elif isinstance(obj, np.floating):return float(obj)elif hasattr(obj, "tolist"):return obj.tolist()elif isinstance(obj, datetime.datetime):return obj.__str__()elif isinstance(obj, Image.Image):with BytesIO() as out:obj.save(out, format="PNG")png_string = out.getvalue()return base64.b64encode(png_string).decode("utf-8")else:return super(_JSONEncoder, self).default(obj)def encode_json(content):"""encodes json with custom `JSONEncoder`"""return json.dumps(content,ensure_ascii=False,allow_nan=False,indent=None,cls=_JSONEncoder,separators=(",", ":"),)
// --------- 这块  end ---# Helper: Mean Pooling - Take attention mask into account for correct averaging
def mean_pooling(model_output, attention_mask):token_embeddings = model_output[0] #First element of model_output contains all token embeddingsinput_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)def model_fn(model_dir):# Load model from HuggingFace Hubtokenizer = AutoTokenizer.from_pretrained(model_dir)model = AutoModel.from_pretrained(model_dir)return model, tokenizerdef predict_fn(data, model_and_tokenizer):# destruct model and tokenizermodel, tokenizer = model_and_tokenizer# Tokenize sentencessentences = data.pop("inputs", data)encoded_input = tokenizer(sentences, padding=True, truncation=True, return_tensors='pt')# Compute token embeddingswith torch.no_grad():model_output = model(**encoded_input)# Perform poolingsentence_embeddings = mean_pooling(model_output, encoded_input['attention_mask'])# Normalize embeddingssentence_embeddings = F.normalize(sentence_embeddings, p=2, dim=1)# return dictonary, which will be json serializablereturn {"embedding": sentence_embeddings[0].tolist()}
import logging
from sagemaker_huggingface_inference_toolkit import content_types, decoder_encoderlogger = logging.getLogger(__name__)def preprocess(self, input_data, content_type, context=None):"""The preprocess handler is responsible for deserializing the input data intoan object for prediction, can handle JSON.The preprocess handler can be overridden for data or feature transformation.Args:input_data: the request payload serialized in the content_type format.content_type: the request content_type.context (obj): metadata on the incoming request data (default: None).Returns:decoded_input_data (dict): deserialized input_data into a Python dictonary."""# raises en error when using zero-shot-classification or table-question-answering, not possible due to nested propertiesif (os.environ.get("HF_TASK", None) == "zero-shot-classification"or os.environ.get("HF_TASK", None) == "table-question-answering") and content_type == content_types.CSV:raise PredictionException(f"content type {content_type} not support with {os.environ.get('HF_TASK', 'unknown task')}, use different content_type",400,)decoded_input_data = decoder_encoder.decode(input_data, content_type)return decoded_input_datalogger.info(f"param1 {batch_size} and param2 {sequence_length}")def predict(self, data, model, context=None):"""The predict handler is responsible for model predictions. Calls the `__call__` method of the provided `Pipeline`on decoded_input_data deserialized in input_fn. Runs prediction on GPU if is available.The predict handler can be overridden to implement the model inference.Args:data (dict): deserialized decoded_input_data returned by the input_fnmodel : Model returned by the `load` method or if it is a custom module `model_fn`.context (obj): metadata on the incoming request data (default: None).Returns:obj (dict): prediction result."""# pop inputs for pipelineinputs = data.pop("inputs", data)parameters = data.pop("parameters", None)# pass inputs with all kwargs in dataif parameters is not None:prediction = model(inputs, **parameters)else:prediction = model(inputs)return predictiondef postprocess(self, prediction, accept, context=None):"""The postprocess handler is responsible for serializing the prediction result tothe desired accept type, can handle JSON.The postprocess handler can be overridden for inference response transformation.Args:prediction (dict): a prediction result from predict.accept (str): type which the output data needs to be serialized.context (obj): metadata on the incoming request data (default: None).Returns: output data serialized"""return decoder_encoder.encode(prediction, accept)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/293917.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

华为 WATCH GT 4 跨越想象的边界,打造智慧生活新体验

颜值新高度,健康更全面!华为 WATCH GT 4 颜值超能打,表盘随心定义,健康管理再升级身体状况更有数,超长续航给足安全感。跨越想象的边界,打造智慧生活新体验!

ModuleNotFoundError: No module named ‘tensorflow‘

直接运行pip install tensorflow安装成功之后,发现版本是tensorflow2.15.0 python的版本是3.9版本 导入包:import tensorflow 打包xxx.exe,调用之后提示错误 ModuleNotFoundError: No module named tensorflow 最后发现特定的python的版本对应特定的t…

生物系统学中的进化树构建和分析R工具包V.PhyloMaker2的介绍和详细使用

V.PhyloMaker2是一个R语言的工具包,专门用于构建和分析生物系统学中的进化树(也称为系统发育树或phylogenetic tree)。以下是对V.PhyloMaker2的一些基本介绍和使用说明: 论文介绍:V.PhyloMaker2: An updated and enla…

Ubuntu及Docker 安装rabbitmq

安装ubuntu 前 先暴露端口: 5672 用于与mq服务器通信用 15672 管理界面使用的端口 docker命令:docker run -itd --name ubuntu -p 5672:5672 -p 15672:15672 ubuntu 进入docker : docker exec -it ubuntu /bin/bash 步骤: 1. 更新安装源…

大数据Doris(三十九):Duplicate 模型中的 ROLLUP

文章目录 Duplicate 模型中的 ROLLUP 一、前缀索引

Chrome浏览器http自动跳https问题

现象: Chrome浏览器访问http页面时有时会自动跳转https,导致一些问题。比如: 开发阶段访问dev环境网址跳https,后端还是http,导致接口跨域。 复现: 先访问http网址,再改成https访问&#xf…

算法练习Day19 (Leetcode/Python-二叉树)

108. Convert Sorted Array to Binary Search Tree Given an integer array nums where the elements are sorted in ascending order, convert it to a height-balanced binary search tree. 思路: 一个高度平衡二叉树是指一个二叉树每个节点 的左右两个子树的…

conda环境下更改虚拟环境安装路径

1 引言 在Anaconda中如果没有指定路径,虚拟环境会默认安装在anaconda所安装的目录下,但如果默认环境的磁盘空间不足,无法满足大量安装虚拟环境的需求,此时我们需要更改虚拟环境的安装路径,有以下两种方案: 方案1: 每次…

【Amazon 实验②】使用Amazon WAF做基础 Web Service 防护之自定义规则

文章目录 1. 自定义规则1.1 介绍 2. 实验步骤2.1 测试2.2 输出 上一篇章介绍了使用Amazon WAF做基础 Web Service 防护中的Web ACLs 配置 & AWS 托管规则的介绍和演示操作 【Amazon 实验①】使用Amazon WAF做基础 Web Service 防护,本篇章将继续介绍关于自定义…

微软的word文档中内置背景音乐步骤(打开自动播放)

目录 一、前言 二、操作步骤 一、前言 有时候需要在word文档里面打开的时候就自动播放音乐或者音频,那么可以用微软的word来按照操作步骤去这样完成。 如果没有微软office的,可以下载这个是2021专业版的。因为office只能免费使用一段时间&#xff0c…

使用 Elasticsearch 检测抄袭 (一)

作者:Priscilla Parodi 抄袭可以是直接的,涉及复制部分或全部内容,也可以是释义的,即通过更改一些单词或短语来重新表述作者的作品。 灵感和释义之间是有区别的。 即使你得出类似的结论,也可以阅读内容,获得…

【机器学习】【线性回归】梯度下降

文章目录 [toc]数据集实际值估计值估计误差代价函数学习率参数更新Python实现线性拟合结果代价结果 数据集 ( x ( i ) , y ( i ) ) , i 1 , 2 , ⋯ , m \left(x^{(i)} , y^{(i)}\right) , i 1 , 2 , \cdots , m (x(i),y(i)),i1,2,⋯,m 实际值 y ( i ) y^{(i)} y(i) 估计值 h …