从零构建深度学习推理框架-7 计算图的表达式

什么是表达式

表达式就是一个计算过程,类似于如下:

output_mid = input1 + input2
output = output_mid * input3

用图形来表达就是这样的。

但是在PNNX的表达式(Experssion Layer)中不是这个样子,而是以一种抽象得方式,替换掉输入张量改为@1,@2等等

所以上面的计算图也就变成了

add(@0,mul(@1,@2))

我们是希望把这个抽象的表达式变回到一个方便后端执行的计算过程(抽象的语法树来表达,在推理的时候转为逆波兰式)。

其中addmul表示我们上一节中说到的RuntimeOperator, @0@1表示我们上一节课中说道的RuntimeOperand. 这个抽象表达式看起来比较简单,但是实际上情况会非常复杂,我们给出一个复杂的例子:

add(add(mul(@0,@1),mul(@2,add(add(add(@0,@2),@3),@4))),@5)

这就要求我们需要一个鲁棒的表达式解析和语法树构建功能。

词法解析:

词法解析的目的就是将add(@0,mul(@1,@2))拆分为多个token,token依次为add ( @0 , mul等.代码如下:

enum class TokenType {TokenUnknown = -1,TokenInputNumber = 0,TokenComma = 1,TokenAdd = 2,TokenMul = 3,TokenLeftBracket = 4,TokenRightBracket = 5,
};struct Token {TokenType token_type = TokenType::TokenUnknown;int32_t start_pos = 0; //词语开始的位置int32_t end_pos = 0; // 词语结束的位置//比如add就是 start_pos = 0 , end_pos = 2Token(TokenType token_type, int32_t start_pos, int32_t end_pos): token_type(token_type), start_pos(start_pos), end_pos(end_pos) {}
};

我们在TokenType中规定了Token的类型,类型有输入、加法、乘法以及左右括号等.Token类中记录了类型以及Token在字符串的起始和结束位置.

这样就把表达式变成了多个token的一个数组。

如下的代码是具体的解析过程,我们将输入(也就是:add(@0,mul(@1,@2)))存放在statement_中,首先是判断statement_是否为空, 随后删除表达式中的所有空格和制表符。

if (!need_retoken && !this->tokens_.empty()) {return;}CHECK(!statement_.empty()) << "The input statement is empty!";statement_.erase(std::remove_if(statement_.begin(), statement_.end(), [](char c) {return std::isspace(c);}), statement_.end());CHECK(!statement_.empty()) << "The input statement is empty!";

然后对于statement,我们遍历所有的表达式,要开始将这个statement拆成多个token啦!

for (int32_t i = 0; i < statement_.size();) {char c = statement_.at(i);if (c == 'a') {CHECK(i + 1 < statement_.size() && statement_.at(i + 1) == 'd')<< "Parse add token failed, illegal character: " << c;CHECK(i + 2 < statement_.size() && statement_.at(i + 2) == 'd')<< "Parse add token failed, illegal character: " << c;Token token(TokenType::TokenAdd, i, i + 3);tokens_.push_back(token);std::string token_operation = std::string(statement_.begin() + i, statement_.begin() + i + 3);token_strs_.push_back(token_operation);i = i + 3;} }

char c就是当前的字符 如果这个等于a的话,那么由于我们的词法规定了以a开头的只有add,所以我们必须判断接下来的两个字符是不是'd','d',如果不是的话就报错,如果是的话就初始化一个新token保存。

同理:

else if (c == 'm') {CHECK(i + 1 < statement_.size() && statement_.at(i + 1) == 'u')<< "Parse add token failed, illegal character: " << c;CHECK(i + 2 < statement_.size() && statement_.at(i + 2) == 'l')<< "Parse add token failed, illegal character: " << c;Token token(TokenType::TokenMul, i, i + 3);tokens_.push_back(token);std::string token_operation = std::string(statement_.begin() + i, statement_.begin() + i + 3);token_strs_.push_back(token_operation);i = i + 3;} 

也只有mul这一种可能。

如果是一个操作数的话:

else if (c == '@') {CHECK(i + 1 < statement_.size() && std::isdigit(statement_.at(i + 1)))<< "Parse number token failed, illegal character: " << c;int32_t j = i + 1;for (; j < statement_.size(); ++j) {if (!std::isdigit(statement_.at(j))) {break;}}Token token(TokenType::TokenInputNumber, i, j);CHECK(token.start_pos < token.end_pos);tokens_.push_back(token);std::string token_input_number = std::string(statement_.begin() + i, statement_.begin() + j);token_strs_.push_back(token_input_number);i = j;} 

那就是在@后只要是数字就一直读。读完之后组成新的token。

else if (c == ',') {Token token(TokenType::TokenComma, i, i + 1);tokens_.push_back(token);std::string token_comma = std::string(statement_.begin() + i, statement_.begin() + i + 1);token_strs_.push_back(token_comma);i += 1;} else if (c == '(') {Token token(TokenType::TokenLeftBracket, i, i + 1);tokens_.push_back(token);std::string token_left_bracket = std::string(statement_.begin() + i, statement_.begin() + i + 1);token_strs_.push_back(token_left_bracket);i += 1;} else if (c == ')') {Token token(TokenType::TokenRightBracket, i, i + 1);tokens_.push_back(token);std::string token_right_bracket = std::string(statement_.begin() + i, statement_.begin() + i + 1);token_strs_.push_back(token_right_bracket);i += 1;} else {LOG(FATAL) << "Unknown  illegal character: " << c;}

其他输入符也是如此。要是不再我们所列的项当中就报错。

这样我们就可以得到一个抽象的语法树。

语法解析:

语法解析的过程是递归向下的,定义在Generate_函数中.

 通过这个语法树中序遍历left、right就可以得到具体的一个计算的过程。0 mul 1 add 0 mul 1

我们这里用一个例子来讲解:

add(@0,@1)这个例子.输入到Generate_函数中, 是一个token数组.

Generate_数组首先检查第一个输入是否为add,mul或者是input number中的一种.

CHECK(current_token.token_type == TokenType::TokenInputNumber|| 
current_token.token_type == TokenType::TokenAdd || current_token.token_type == TokenType::TokenMul);

那这里为什么不判断第一个不是left bracket token(左括号)或)(右括号)呢?

因为这个一般只会是以add,mul或者光一个数字@0。

第一个输入add,所以我们需要判断其后是否是left bracket来判断合法性, 如果合法则构建左子树.

else if (current_token.token_type == TokenType::TokenMul || current_token.token_type == TokenType::TokenAdd) {std::shared_ptr<TokenNode> current_node = std::make_shared<TokenNode>();//组枝起来一个节点current_node->num_index = -int(current_token.token_type);index += 1;//到左括号 因为add之后的的token一定到左括号 不对就报错CHECK(index < this->tokens_.size());CHECK(this->tokens_.at(index).token_type == TokenType::TokenLeftBracket);index += 1;//左括号之后一定是一个操作数CHECK(index < this->tokens_.size());const auto left_token = this->tokens_.at(index);//token当前是@0这个tokenif (left_token.token_type == TokenType::TokenInputNumber|| left_token.token_type == TokenType::TokenAdd || left_token.token_type == TokenType::TokenMul) {
//递归调用current_node->left = Generate_(index);}

处理下一个token, 构建左子树.

if (current_token.token_type == TokenType::TokenInputNumber) {uint32_t start_pos = current_token.start_pos + 1;uint32_t end_pos = current_token.end_pos;CHECK(end_pos > start_pos);CHECK(end_pos <= this->statement_.length());const std::string &str_number =std::string(this->statement_.begin() + start_pos, this->statement_.begin() + end_pos);return std::make_shared<TokenNode>(std::stoi(str_number), nullptr, nullptr);}

递归进入左子树后,判断是TokenType::TokenInputNumber则返回一个新的TokenNode到add token成为左子树.

检查下一个token是否为逗号,也就是在add(@0,@1)的@0是否为,

CHECK(this->tokens_.at(index).token_type == TokenType::TokenComma);index += 1;CHECK(index < this->tokens_.size());

下一步是构建add token的右子树

index += 1;CHECK(index < this->tokens_.size());const auto right_token = this->tokens_.at(index);if (right_token.token_type == TokenType::TokenInputNumber|| right_token.token_type == TokenType::TokenAdd || right_token.token_type == TokenType::TokenMul) {current_node->right = Generate_(index);} else {LOG(FATAL) << "Unknown token type: " << int(left_token.token_type);}index += 1;CHECK(index < this->tokens_.size());CHECK(this->tokens_.at(index).token_type == TokenType::TokenRightBracket);return current_node;
current_node->right = Generate_(index); /// 构建add(@0,@1)中的右子树

Generate_(index)递归进入后遇到的token是@1 token,因为是Input Number类型所在构造TokenNode后返回.

if (current_token.token_type == TokenType::TokenInputNumber) {uint32_t start_pos = current_token.start_pos + 1;uint32_t end_pos = current_token.end_pos;CHECK(end_pos > start_pos);CHECK(end_pos <= this->statement_.length());const std::string &str_number =std::string(this->statement_.begin() + start_pos, this->statement_.begin() + end_pos);return std::make_shared<TokenNode>(std::stoi(str_number), nullptr, nullptr);}

之后检查右括号在不在:

index += 1;CHECK(index < this->tokens_.size());CHECK(this->tokens_.at(index).token_type == TokenType::TokenRightBracket);return current_node;} else {LOG(FATAL) << "Unknown token type: " << int(current_token.token_type);}

至此, add语句的抽象语法树构建完成.

struct TokenNode {int32_t num_index = -1;std::shared_ptr<TokenNode> left = nullptr;std::shared_ptr<TokenNode> right = nullptr;TokenNode(int32_t num_index, std::shared_ptr<TokenNode> left, std::shared_ptr<TokenNode> right);TokenNode() = default;
};

在上述结构中, left存放的是@0表示的节点, right存放的是@1表示的节点

我们再弄一个更复杂一些的例子:

add(mul(@0,@1),@2)

  • add
  • (
  • mul
  • (
  • @0
  • ,
  • @1
  • )
  • ,
  • @2
  • )
  • index = 0, 当前遇到的tokenadd, 调用层为1
  • index = 1, 根据以上的流程,我们期待add token之后的tokenleft bracket, 否则就报错. 调用层为1
  • 开始递归调用,构建add的左子树.从层1进入层2
  • index = 2, 遇到了mul token. 调用层为2.
  • index = 3, 根据以上的流程,我们期待mul token之后的token是第二个left bracket. 调用层为2.
  • 开始递归调用用来构建mul token的左子树.
  • index = 4, 遇到@0,进入递归调用,进入层3, 但是因为操作数都是叶子节点,构建好之后就直接返回了,得到mul token的左子节点.放在mul tokenleft指针上.
  • index = 5, 我们希望遇到一个逗号,否则就报错mul(@0,@1)中中间的逗号.调用层为2.
  • index = 6, 遇到@2,进入递归调用,进入层3, 但是因为操作数是叶子节点, 构建好之后就直接返回到2,得到mul token的右子节点.
  • index = 7, 我们希望遇到一个右括号,就是mul(@1,@2)中的右括号.调用层为2.
  • 到现在为止mul token已经构建完毕,返回形成add token的左子节点,add token的left指针指向构建完毕的mul树. 返回到调用层1.
    ...
  • add token开始构建right token,但是因为@2是一个输入操作数,所以直接递归就返回了,至此得到add的右子树,并用right指针指向.

这个东西最厉害的地方就在于,括号里面一定是一个新的节点!

Experssion Layer的实现(如何实现@0 + @1):

Expression Operator的定义

class ExpressionOp : public Operator {public:explicit ExpressionOp(const std::string &expr);std::vector<std::shared_ptr<TokenNode>> Generate();private:std::unique_ptr<ExpressionParser> parser_;std::vector<std::shared_ptr<TokenNode>> nodes_;std::string expr_;
};

其中expr_表示表达式字符串, nodes_表示经过逆波兰变换之后得到的节点.

Expression Layer的定义

class ExpressionLayer : public Layer {public:explicit ExpressionLayer(const std::shared_ptr<Operator> &op);void Forwards(const std::vector<std::shared_ptr<Tensor<float>>> &inputs,std::vector<std::shared_ptr<Tensor<float>>> &outputs) override;private:std::unique_ptr<ExpressionOp> op_;
};

初始化Expression Layer

ExpressionLayer::ExpressionLayer(const std::shared_ptr<Operator> &op) : Layer("Expression") {CHECK(op != nullptr && op->op_type_ == OpType::kOperatorExpression);ExpressionOp *expression_op = dynamic_cast<ExpressionOp *>(op.get());CHECK(expression_op != nullptr) << "Expression operator is empty";this->op_ = std::make_unique<ExpressionOp>(*expression_op);
}

 

Expression Layer中的输入排布

 

Expression Layer的输入中, 多个输入依次排布. 如果batch_size的大小为4, 则上图中input1中的元素数量为4, input2的元素数量也为4. 换句话说, input1中的数据都来源于操作数1(operand 1), input2中的数据都来源于操作数2(operand 2).

将数据存放到input1input2的实现如下:

int batch_size = 4;for (int i = 0; i < batch_size; ++i) {std::shared_ptr<ftensor> input = std::make_shared<ftensor>(3, 224, 224);input->Fill(1.f);inputs.push_back(input);}for (int i = 0; i < batch_size; ++i) {std::shared_ptr<ftensor> input = std::make_shared<ftensor>(3, 224, 224);input->Fill(2.f);inputs.push_back(input);}

inputs被分为两段, 前半段存放input1, 前半段的长度为4. 后半段存放input2, 后半段的长度为4.

计算的结果存放在outputs, 8个输入数据两两相加, 最后的输出数据大小等于4.

Expression Layer的计算过程

数据排布

 

第一个例子

 

已知有如上的数据存储排布, 在本节中我们将讨论如何根据现有的数据完成add(@0,@1)计算. 可以看到每一次计算的时候, 都以此从input1input2中取得一个数据进行加法操作, 并存放在对应的输出位置.

第二个例子

下图的例子展示了对于三个输入,mul(add(@0,@1),@2)的情况:

 

每次计算的时候依次从input1, input2input3中取出数据, 并作出相应的运算, 并将结果数据存放于对应的output中.

操作数处理的代码实现

ExpressionLayer::Forward函数中, 首先检查输入是否为空, 并初始化outputs数组中的元素.

CHECK(!inputs.empty());const uint32_t batch_size = outputs.size();CHECK(batch_size != 0);for (uint32_t i = 0; i < batch_size; ++i) {CHECK(outputs.at(i) != nullptr && !outputs.at(i)->empty());outputs.at(i)->Fill(0.f);}CHECK(this->op_ != nullptr && this->op_->op_type_ == OpType::kOperatorExpression);std::stack<std::vector<std::shared_ptr<Tensor<float>>>> op_stack;const std::vector<std::shared_ptr<TokenNode>> &token_nodes = this->op_->Generate();

this->op_->Generate(); 获得的是逆波兰表达式.

for (const auto &token_node : token_nodes) {if (token_node->num_index >= 0) {uint32_t start_pos = token_node->num_index * batch_size;std::vector<std::shared_ptr<Tensor<float>>> input_token_nodes;for (uint32_t i = 0; i < batch_size; ++i) {CHECK(i + start_pos < inputs.size());input_token_nodes.push_back(inputs.at(i + start_pos));}op_stack.push(input_token_nodes);}}

依次遍历逆波兰表达式, 如果当前的op遇到的是一个操作数, 例如@0或者@1. 就将他们一个批次的数据(input_token_nodes)全部读取出来, 并临时存放到栈op_stack中.

 

举个例子, 对于input1就将input1中所有的数据读取出来并存放到input_token_nodes中, 再将input_token_nodes这一个批次的数据放入到栈中.

根据输入的逆波兰式@0,@1,add,遇到的第一个节点是操作数是@0, 所以栈op_stack内的内存布局如下:

 

当根据顺序遇到第二个节点(op)的时候, 操作数@1的时候, 再将inputs中的操作数读取出来并存放到input_token_nodes中, 再将input_token_nodes这一个批次的数据放入到栈中.

 

运算符处理的代码实现

const int32_t op = token_node->num_index;CHECK(op_stack.size() >= 2) << "The number of operand is less than two";std::vector<std::shared_ptr<Tensor<float>>> input_node1 = op_stack.top();CHECK(input_node1.size() == batch_size);op_stack.pop();std::vector<std::shared_ptr<Tensor<float>>> input_node2 = op_stack.top();CHECK(input_node2.size() == batch_size);op_stack.pop();

当节点(op)类型为操作符号的时候, 首先弹出栈(op_stack)内的两个批次操作数, 对于如上的情况input_node1分别存放input1...4, input_node2分别存放input5...8.

CHECK(input_node1.size() == input_node2.size());std::vector<std::shared_ptr<Tensor<float>>> output_token_nodes(batch_size);for (uint32_t i = 0; i < batch_size; ++i) {if (op == -int(TokenType::TokenAdd)) {output_token_nodes.at(i) = ftensor::ElementAdd(input_node1.at(i), input_node2.at(i));} else if (op == -int(TokenType::TokenMul)) {output_token_nodes.at(i) = ftensor::ElementMultiply(input_node1.at(i), input_node2.at(i));} else {LOG(FATAL) << "Unknown operator type: " << op;}}op_stack.push(output_token_nodes);

当获取大小长度为batch_sizeinput_node1input_node2后, 流程在for(int i = 0...batch_size)中对两个输入进行两两操作, 操作类型定义于当前的op中. 对于逆波兰式@0,@1,add, 在如上处理完两个输入节点之后,当前的节点类型是add.

 

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/63426.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Active Directory安全和风险状况管理

风险评估和管理 风险评估和管理是主动安全性和合规性管理不可或缺的一部分。 发现关键基础设施组件中的风险行为和配置对于阻止网络入侵和预防网络攻击至关重要。帐户泄露和配置错误漏洞是用于破坏网络的常见技术。当评估、监控和降低 Active Directory 基础架构的风险时&…

HCIP 链路聚合技术

1、链路聚合概述 为了保证网络的稳定性&#xff0c;仅仅是设备进行备份还不够&#xff0c;我们需要针对我们的链路进行备份&#xff0c;同时也增加了链路的利用率&#xff0c;提高带宽。避免一条链路出现故障&#xff0c;导致网络无法正常通信。这就可以使用链路聚合技术。 以…

图片预览插件vue-photo-preview的使用

移动端项目中需要图片预览的功能&#xff0c;但本身使用mintui&#xff0c;vantui中虽然也有&#xff0c;但是为了一个组件安装这个有点儿多余&#xff0c;就选用了vue-photo-preview插件实现&#xff08;其实偷懒也不想自己写&#xff09;。 1、安装 npm i vue-photo-preview…

计算机竞赛 LSTM的预测算法 - 股票预测 天气预测 房价预测

0 简介 今天学长向大家介绍LSTM基础 基于LSTM的预测算法 - 股票预测 天气预测 房价预测 这是一个较为新颖的竞赛课题方向&#xff0c;学长非常推荐&#xff01; &#x1f9ff; 更多资料, 项目分享&#xff1a; https://gitee.com/dancheng-senior/postgraduate 1 基于 Ke…

怎么做Tik Tok海外娱乐公会呢?新加坡市场怎么样?

一、为什么选择TikTok直播 1. 海外市场潜力巨大 • 自2016年始&#xff0c;多家直播平台陆续拓展至东南亚、中东、俄罗斯、日韩、欧美、拉美等地区。 • 海外市场作为直播发展新蓝海&#xff0c;2021年直播行业整申请cmxyci体规模达百亿美元&#xff0c;并维持高速增长。 &a…

新一代构建工具 maven-mvnd

新一代构建工具 maven-mvnd mvnd的前世今生下载安装 mvndIDEA集成 mvnd的前世今生 maven 作为一代经典的构建工具&#xff0c;流行了很多年&#xff0c;知道现在依然是大部分Java项目的构建工具的首选&#xff1b;但随着项目复杂度提高&#xff0c;代码量及依赖库的增多使得ma…

接口测试自动化:简化测试流程,提升效率

接口测试自动化&#xff1a;简化测试流程&#xff0c;提升效率 什么是接口测试自动化&#xff1f; 接口测试自动化是指使用特定的工具和技术来自动化执行接口测试的过程。通过编写脚本&#xff0c;自动化工具可以模拟用户与软件系统的交互&#xff0c;验证接口的功能和性能。…

树莓派RP2040 用Arduino IDE安装和编译

目录 1 Arduino IDE 1.1 IDE下载 1.2 安装 arduino mbed os rp2040 boards 2 编程-烧录固件 2.1 打开点灯示例程序 2.2 选择Raspberry Pi Pico开发板 2.3 编译程序 2.4 烧录程序 2.4.1 Raspberry Pi Pico开发板首次烧录提示失败 2.4.2 解决首次下载失败问题 2.4.2.1…

AIGC技术揭秘:探索火热背后的原因与案例

文章目录 什么是AIGC技术&#xff1f;为何AIGC技术如此火热&#xff1f;1. 提高效率与创造力的完美结合2. 拓展应用领域&#xff0c;创造商业价值3. 推动技术创新和发展 AIGC技术案例解析1. 艺术创作&#xff1a;生成独特的艺术作品2. 内容创作&#xff1a;实时生成各类内容3. …

备战秋招012(20230808)

文章目录 前言一、今天学习了什么&#xff1f;二、动态规划1.概念2.题目 总结 前言 提示&#xff1a;这里为每天自己的学习内容心情总结&#xff1b; Learn By Doing&#xff0c;Now or Never&#xff0c;Writing is organized thinking. 提示&#xff1a;以下是本篇文章正文…

日常BUG —— Java判空注解

&#x1f61c;作 者&#xff1a;是江迪呀✒️本文关键词&#xff1a;日常BUG、BUG、问题分析☀️每日 一言 &#xff1a;存在错误说明你在进步&#xff01; 一. 问题描述 问题一&#xff1a; 在使用Java自带的注解NotNull、NotEmpty、NotBlank时报错&#xff0c;…

开封Geotrust单域名https证书推荐

Geotrust作为全球领先的数字证书颁发机构之一&#xff0c;拥有多年的数字证书颁发经验&#xff0c;其数字证书被广泛应用于电子商务、在线支付、企业通讯、云计算等领域&#xff0c;为用户提供了安全可靠的保障。而Geotrust旗下的单域名https证书是大多数客户创建网站时的选择之…