详解Megatron中的数据混合算法(BlendableDataset)

🧑‍💻 本文主要讲解Megatron早期版本中的数据混合算法。

目录

  • 1. 数据混合
  • 2. 源码解析
  • 3. 证明部分&讨论
  • 4. 进一步优化

1. 数据混合

在谈源码之前,我们有必要先了解一下Megatron中的数据混合思想。

给定 n n n 个数据集 D 1 , D 2 , ⋯ , D n \mathcal{D}_1,\mathcal{D}_2,\cdots,\mathcal{D}_n D1,D2,,Dn 和对应的 n n n 个权重 w 1 , w 2 , ⋯ , w n w_1,w_2,\cdots,w_n w1,w2,,wn,我们要按照这些权重去混合 n n n 个数据集,设混合后的数据集为 D \mathcal{D} D

Megatron假定:

  • ∣ D ∣ = ∑ i = 1 n ∣ D i ∣ |\mathcal{D}|=\sum_{i=1}^n|\mathcal{D}_i| D=i=1nDi。即混合后的数据集大小等于混合前的各数据集大小之和。
  • D \mathcal{D} D 中有 ∣ D ∣ ⋅ w i |\mathcal{D}|\cdot w_i Dwi 个样本来自 D i \mathcal{D}_i Di

那如何确定 D \mathcal{D} D 中到底有多少个样本是来自 D i \mathcal{D}_i Di 的呢?一种最直观的做法是,计算 ∣ D ∣ ⋅ w i |\mathcal{D}|\cdot w_i Dwi,然后进行取整,但这种操作无法保证所有取整后的 ∣ D ∣ ⋅ w i |\mathcal{D}|\cdot w_i Dwi 相加起来恰好是 ∣ D ∣ |\mathcal{D}| D 如果总和大于 ∣ D ∣ |\mathcal{D}| D,说明某些数据集被过采样了,应当减少相应数据集的采样数;如果总和小于 ∣ D ∣ |\mathcal{D}| D,说明某些数据集被欠采样了,应当增加相应数据集的采样数。可问题是,如何确定这些被过采样/欠采样的数据集呢?显然我们需要一个更加公平的算法。

我们可以把获取数据集 D \mathcal{D} D 看作是一个采样过程:一开始有 n n n 个数据源 { D i } i = 1 n \{\mathcal{D}_i\}_{i=1}^n {Di}i=1n,每一轮迭代,我们需要先从这 n n n 个数据源中选出一个数据源 D i \mathcal{D}_i Di,然后再从这个数据源中选出一个样本 S \mathcal{S} S 由于每一轮迭代只会选出一个样本,因此 ∣ D ∣ |\mathcal{D}| D 轮迭代结束后,我们便得到了 ∣ D ∣ |\mathcal{D}| D 个样本,这些样本构成了混合后的数据集 D \mathcal{D} D

每一轮迭代都会产生两个信息:要选取的数据源 D i \mathcal{D}_i Di,要从 D i \mathcal{D}_i Di 中选取的样本。我们可以考虑构造两个整数序列 P , S \mathcal{P},\mathcal{S} P,S,它们的长度均为 ∣ D ∣ |\mathcal{D}| D,含义如下:

  • P j \mathcal{P}_j Pj 代表的是第 j j j 轮迭代时,选取的数据源的下标。例如 P 10 = 3 \mathcal{P}_{10}=3 P10=3 意味着第 10 10 10 轮迭代选取的数据源是 D 3 \mathcal{D}_3 D3
  • S j \mathcal{S}_j Sj 代表的是第 j j j 轮迭代时,从数据源 D P j \mathcal{D}_{\mathcal{P}_j} DPj 选取的样本的下标。

由以上定义知, ∀ j \forall j j,都有 1 ≤ P j ≤ n 1\leq \mathcal{P}_j\leq n 1Pjn 1 ≤ S j ≤ ∣ D P j ⁣ ∣ 1\leq \mathcal{S}_j\leq|\mathcal{D}_{\mathcal{P}_j}\!| 1SjDPj(下标均从 1 1 1 开始)。

接下来的问题是,如何确定每一轮的 P j \mathcal{P}_j Pj S j \mathcal{S}_j Sj 呢?

先谈 P j \mathcal{P}_j Pj。因为是一个从 1 1 1 ∣ D ∣ |\mathcal{D}| D 的一个逐步采样过程,在第 j j j 轮迭代时,我们已经抽取了 j − 1 j-1 j1 个样本,接下来要确定第 j j j 个样本。根据Megatron的假定,在确定下来第 j j j 个样本后,这 j j j 个样本中应当有约 j ⋅ w i j\cdot w_i jwi 个样本是来自 D i \mathcal{D}_i Di 的。

考虑构造一个长度为 n n n 的序列 C \mathcal{C} C,该序列随着迭代不断更新。 C i \mathcal{C}_i Ci 代表当前已经从 D i \mathcal{D}_i Di 抽取了多少个样本。显然可知,第一轮迭代开始时,有 C i = 0 , i = 1 , 2 , ⋯ , n \mathcal{C}_i=0,\,i=1,2,\cdots,n Ci=0,i=1,2,,n。最后一轮迭代结束后,有 ∑ i = 1 n C i = ∣ D ∣ \sum_{i=1}^n\mathcal{C}_i=|\mathcal{D}| i=1nCi=D,并且

C i = { ∑ t = 1 j − 1 I ( P t = i ) , P j 确定前 ∑ t = 1 j I ( P t = i ) , P j 确定后 , ∀ i \mathcal{C}_i=\begin{cases} \sum_{t=1}^{j-1} I(\mathcal{P}_t=i),&\text{$\mathcal{P}_j$确定前} \\ \sum_{t=1}^{j} I(\mathcal{P}_t=i),&\text{$\mathcal{P}_j$确定后} \\ \end{cases},\quad \forall i Ci={t=1j1I(Pt=i),t=1jI(Pt=i),Pj确定前Pj确定后,i

回到对 P j \mathcal{P}_j Pj 的讨论中。假设在确定第 j j j 个样本前已经从 D i \mathcal{D}_i Di 中抽取了 C i \mathcal{C}_i Ci 个样本,在确定第 j j j 个样本后,诸 C i \mathcal{C}_i Ci有且仅有一个的值会增加 1 1 1,不妨记为 C k \mathcal{C}_k Ck,这个过程可以形容为

[ C 1 , ⋯ , C k , ⋯ , C n ] ⏟ 第 j 轮迭代开始时 → [ C 1 , ⋯ , C k + 1 , ⋯ , C n ] ⏟ 第 j 轮迭代结束时 [ j ⋅ w 1 , j ⋅ w 2 , ⋯ , j ⋅ w n ] ⏟ 理论值 \underbrace{[\mathcal{C}_1,\cdots,\mathcal{C}_k,\cdots,\mathcal{C}_n]}_{第j轮迭代开始时}\to\underbrace{[\mathcal{C}_1,\cdots,\mathcal{C}_{k}+1,\cdots,\mathcal{C}_n]}_{第j轮迭代结束时}\qquad \underbrace{[j\cdot w_1,j\cdot w_2,\cdots,j\cdot w_n]}_{理论值} j轮迭代开始时 [C1,,Ck,,Cn]j轮迭代结束时 [C1,,Ck+1,,Cn]理论值 [jw1,jw2,,jwn]

我们期望第 j j j 轮迭代结束时,诸 C i \mathcal{C}_i Ci 应当尽可能地接近理论值(在MSE下)。由于只能让其中一个 C k \mathcal{C}_k Ck 自增 1 1 1,显然有 k = arg max ⁡ i ( j ⋅ w i − C i ) k=\argmax_i(j\cdot w_i-\mathcal{C}_i) k=argmaxi(jwiCi)

再谈 S j \mathcal{S}_j Sj。在确定了数据源是 D k \mathcal{D}_k Dk 后,为了避免重复,我们应当做到不放回、随机地从中采样。如何做到这两点呢?我们可以在一开始就对 n n n 个数据源进行打乱,然后在采样的时候只需要从前往后进行,就可以做到以上两点。注意到 C i \mathcal{C}_i Ci 的值是从 0 0 0 开始,以步长为 1 1 1 依次递增,所以我们可以用每次更新完的 C i \mathcal{C}_i Ci 赋值给相应的 S j \mathcal{S}_j Sj,即 S j = 第 j 轮迭代结束时的 C i \mathcal{S}_j=第j轮迭代结束时的\mathcal{C}_i Sj=j轮迭代结束时的Ci

由此我们可以得到整个算法的伪代码:

2. 源码解析

Python部分:

class BlendableDataset(torch.utils.data.Dataset):def __init__(self, datasets, weights):self.datasets = datasetsnum_datasets = len(datasets)assert num_datasets == len(weights), "The number of datasets and weights must match."self.size = sum(len(dataset) for dataset in self.datasets)# Normalize weights.weights = np.array(weights, dtype=np.float64)sum_weights = np.sum(weights)assert sum_weights > 0.0, "Sum of weights must be positive."weights /= sum_weights# Build indices.start_time = time.time()assert num_datasets < 255, "Number of datasets must be less than 255."self.dataset_index = np.zeros(self.size, dtype=np.uint8)self.dataset_sample_index = np.zeros(self.size, dtype=np.int64)helpers.build_blending_indices(self.dataset_index,self.dataset_sample_index,weights,num_datasets,self.size,torch.distributed.get_rank() == 0,)print_rank_0(f'> elapsed time for building blendable dataset indices: 'f'{time.time() - start_time:.2f} sec')def __len__(self):return self.sizedef __getitem__(self, idx):dataset_idx = self.dataset_index[idx]sample_idx = self.dataset_sample_index[idx]return {"dataset_idx": dataset_idx,**self.datasets[dataset_idx][sample_idx],}

C++部分:

void build_blending_indices(py::array_t<uint8_t> &dataset_index,py::array_t<int64_t> &dataset_sample_index,const py::array_t<double> &weights,const int32_t num_datasets,const int64_t size,const bool verbose
) {/* Given multiple datasets and a weighting array, build samplessuch that it follows those weights. */if (verbose) {std::cout << "> building indices for blendable datasets ..." << std::endl;}// Get the pointer access without the checks.auto dataset_index_ptr = dataset_index.mutable_unchecked<1>();auto dataset_sample_index_ptr = dataset_sample_index.mutable_unchecked<1>();auto weights_ptr = weights.unchecked<1>();// Initialize buffer for number of samples used for each dataset.int64_t current_samples[num_datasets];for (int64_t i = 0; i < num_datasets; ++i) {current_samples[i] = 0;}// For each sample:for (int64_t sample_idx = 0; sample_idx < size; ++sample_idx) {// Determine where the max error in sampling is happening.auto sample_idx_double = std::max(static_cast<double>(sample_idx), 1.0);int64_t max_error_index = 0;double max_error = weights_ptr[0] * sample_idx_double - static_cast<double>(current_samples[0]);for (int64_t dataset_idx = 1; dataset_idx < num_datasets; ++dataset_idx) {double error = weights_ptr[dataset_idx] * sample_idx_double - static_cast<double>(current_samples[dataset_idx]);if (error > max_error) {max_error = error;max_error_index = dataset_idx;}}// Populate the indices.dataset_index_ptr[sample_idx] = static_cast<uint8_t>(max_error_index);dataset_sample_index_ptr[sample_idx] = current_samples[max_error_index];// Update the total samples.current_samples[max_error_index] += 1;}// Print infoif (verbose) {std::cout << " > sample ratios:" << std::endl;for (int64_t dataset_idx = 0; dataset_idx < num_datasets; ++dataset_idx) {auto ratio = static_cast<double>(current_samples[dataset_idx]) / static_cast<double>(size);std::cout << "   dataset " << dataset_idx << ", input: " << weights_ptr[dataset_idx] << ", achieved: " << ratio << std::endl;}}
}

具体的算法实现是在C++的函数中,我们先来看Python部分。

self.size 实际上就是 ∣ D ∣ |\mathcal{D}| D,即混合后的数据集大小(从后面的 __len__ 也能看出)。在构造函数中,首先会对 weights 进行归一化,然后声明 P , S \mathcal{P},\mathcal{S} P,S 两个数组。注意 self.dataset_index 实际上就是 P \mathcal{P} Pself.dataset_sample_index 实际上就是 S \mathcal{S} S。由于 P \mathcal{P} P 的数据类型是 uint8,这表明其中元素的范围是 [ 0 , 2 8 − 1 = 255 ] [0,2^8-1=255] [0,281=255],故 P \mathcal{P} P 最多能表示 256 256 256 个数据集,而源码中规定了参与混合的数据集个数必须严格少于 255 255 255博主不是很懂这一点,看懂的小伙伴可以在评论区留言)。

再来看C++部分。前五个形参分别是 P , S , { w i } i , n , ∣ D ∣ \mathcal{P},\mathcal{S},\{w_i\}_i,n,|\mathcal{D}| P,S,{wi}i,n,D

C \mathcal{C} C 数组会在该函数中进行声明并初始化。随后的两个嵌套 for 循环则是整个算法的核心流程,注意到这里的实现中,sample_idx(即 j j j)是从 0 0 0 开始的,而算法伪代码中的 j j j 是从 1 1 1 开始的,所以一开始要执行 j = max ⁡ ( j , 1 ) j=\max(j,1) j=max(j,1) 以确保 j j j 至少是 1 1 1但这样做有一个弊端就是前两轮的循环里, j j j 的值是相同的,和我们期望的每一轮里 j j j 值不同相违背,这是源码中的一个缺陷,实际上应该计算 ( j + 1 ) ⋅ w i − C i (j+1)\cdot w_i-\mathcal{C}_i (j+1)wiCi)。内层循环中的 error 实际上就是 j ⋅ w i − C i j\cdot w_i-\mathcal{C}_i jwiCi。此外,由于 j j j 是从 0 0 0 开始的,所以 C P j \mathcal{C}_{\mathcal{P}_j} CPj 的更新要放到最后执行。

一言以蔽之, j j j 1 1 1 开始,更新顺序为 P → C → S \mathcal{P}\to\mathcal{C}\to\mathcal{S} PCS j j j 0 0 0 开始,更新顺序为 P → S → C \mathcal{P}\to\mathcal{S}\to\mathcal{C} PSC

得到了 P , S \mathcal{P},\mathcal{S} P,S 数组后,我们便可得到混合后的数据集 D \mathcal{D} D

D j = D P j [ S j ] , j = 1 , 2 , ⋯ , ∣ D ∣ \mathcal{D}_j=\mathcal{D}_{\mathcal{P}_j}[\mathcal{S}_j],\quad j=1,2,\cdots,|\mathcal{D}| Dj=DPj[Sj],j=1,2,,D

其中 D i [ j ] \mathcal{D}_i[j] Di[j] 代表数据集 D i \mathcal{D}_i Di 中的第 j j j 个样本。

回到Python部分,__getitem__ 中传入的 idx 实际上就是 j j jself.datasets[dataset_idx][sample_idx] 实际上就是上述的 D P j [ S j ] \mathcal{D}_{\mathcal{P}_j}[\mathcal{S}_j] DPj[Sj]

3. 证明部分&讨论


Prop 1. \text{Prop} \;1.\, Prop1. 每一轮循环开始时所有误差加和为 1 1 1,即 ∑ i = 1 n e i = 1 \sum_{i=1}^n e_i=1 i=1nei=1,其中 e i ≜ j ⋅ w i − C i e_i\triangleq j\cdot w_i-\mathcal{C}_i eijwiCi

P r o o f . Proof.\; Proof. 注意到第 j j j 轮循环开始时,此时一共只采样了 j − 1 j-1 j1 个样本,所以 ∑ i = 1 n C i = j − 1 \sum_{i=1}^n\mathcal{C}_i=j-1 i=1nCi=j1,从而

∑ i = 1 n e i = ∑ i = 1 n ( j ⋅ w i − C i ) = j ⋅ ∑ i = 1 n w i − ∑ i = 1 n C i = j − ∑ i = 1 n C i = j − ( j − 1 ) = 1 \sum_{i=1}^n e_i=\sum_{i=1}^n (j\cdot w_i-\mathcal{C}_i)=j\cdot\sum_{i=1}^n w_i-\sum_{i=1}^n\mathcal{C}_i=j-\sum_{i=1}^n\mathcal{C}_i=j-(j-1)=1 i=1nei=i=1n(jwiCi)=ji=1nwii=1nCi=ji=1nCi=j(j1)=1

进一步可知,每一轮循环结束时所有误差加和为 0 0 0


Prop 2. \text{Prop} \;2.\, Prop2. 假定下标从 1 1 1 开始,且 n = 2 n=2 n=2(即只有两个数据源)。若 e 1 ≥ 0.5 e_1\geq 0.5 e10.5,则 P j = 1 \mathcal{P}_j=1 Pj=1,否则 P j = 2 \mathcal{P}_j=2 Pj=2

P r o o f . Proof.\; Proof. e 1 > 0.5 e_1>0.5 e1>0.5 的情况显然。当 e 1 = e 2 = 0.5 e_1=e_2=0.5 e1=e2=0.5 时, arg max ⁡ \argmax argmax 会优先挑选下标最小的,故此时 P j \mathcal{P}_j Pj 仍是 1 1 1


Prop 3. \text{Prop} \;3.\, Prop3. 假定下标从 1 1 1 开始。可能存在一组 { D i } i \{\mathcal{D}_i\}_i {Di}i { w i } i \{w_i\}_i {wi}i,使得经由上述算法得到的 P , S \mathcal{P},\mathcal{S} P,S 数组, ∃ j , s.t. S j > ∣ D P j ∣ \exists \,j,\,\text{s.t.}\;\,\mathcal{S}_j>|\mathcal{D}_{\mathcal{P}_j}| j,s.t.Sj>DPj,意味着 __getitem__ 会出现下标越界的错误。

P r o o f . Proof.\; Proof. 构造特殊情形即可。令 n = 2 n=2 n=2 ∣ D 1 ∣ = ∣ D 2 ∣ = 2 |\mathcal{D}_1|=|\mathcal{D}_2|=2 D1=D2=2 w 1 = 0.1 , w 2 = 0.9 w_1=0.1,\,w_2=0.9 w1=0.1,w2=0.9

∣ D 1 ∣ + ∣ D 2 ∣ = 4 |\mathcal{D}_1|+|\mathcal{D}_2|=4 D1+D2=4 可知,总共会有 4 4 4 轮循环。且理应有 1 ≤ P j , S j ≤ 2 , j = 1 , 2 , 3 , 4 1\leq \mathcal{P}_j,\mathcal{S}_j\leq 2,\,j=1,2,3,4 1Pj,Sj2,j=1,2,3,4

利用 Prop 2 \text{Prop} \;2 Prop2 快速计算:

  • 第一轮循环,计算误差 e 1 = 1 ⋅ w 1 − 0 = 0.1 < 0.5 e_1=1\cdot w_1-0=0.1<0.5 e1=1w10=0.1<0.5,故 P 1 = 2 \mathcal{P}_1=2 P1=2 C = { 0 , 1 } \mathcal{C}=\{0,1\} C={0,1} S 1 = C 2 = 1 \mathcal{S}_1=\mathcal{C}_2=1 S1=C2=1

  • 第二轮循环,计算误差 e 1 = 2 ⋅ w 1 − 0 = 0.2 < 0.5 e_1=2\cdot w_1-0=0.2<0.5 e1=2w10=0.2<0.5,故 P 2 = 2 \mathcal{P}_2=2 P2=2 C = { 0 , 2 } \mathcal{C}=\{0,2\} C={0,2} S 2 = C 2 = 2 \mathcal{S}_2=\mathcal{C}_2=2 S2=C2=2

  • 第三轮循环,计算误差 e 1 = 3 ⋅ w 1 − 0 = 0.3 < 0.5 e_1=3\cdot w_1-0=0.3<0.5 e1=3w10=0.3<0.5,故 P 3 = 2 \mathcal{P}_3=2 P3=2 C = { 0 , 3 } \mathcal{C}=\{0,3\} C={0,3} S 3 = C 2 = 3 \mathcal{S}_3=\mathcal{C}_2=3 S3=C2=3

  • 第四轮循环,计算误差 e 1 = 4 ⋅ w 1 − 0 = 0.4 < 0.5 e_1=4\cdot w_1-0=0.4<0.5 e1=4w10=0.4<0.5,故 P 4 = 2 \mathcal{P}_4=2 P4=2 C = { 0 , 4 } \mathcal{C}=\{0,4\} C={0,4} S 4 = C 2 = 4 \mathcal{S}_4=\mathcal{C}_2=4 S4=C2=4

由上可知 j = 3 , 4 j=3,4 j=3,4 满足要求。


Prop 4. \text{Prop} \;4.\, Prop4. 在MSE下,要使 [ C 1 , ⋯ , C k + 1 , ⋯ , C n ] [\mathcal{C}_1,\cdots,\mathcal{C}_{k}+1,\cdots,\mathcal{C}_n] [C1,,Ck+1,,Cn] 尽可能接近 [ j ⋅ w 1 , j ⋅ w 2 , ⋯ , j ⋅ w n ] [j\cdot w_1,j\cdot w_2,\cdots,j\cdot w_n] [jw1,jw2,,jwn],应当有 k = arg max ⁡ i ( j ⋅ w i − C i ) k=\argmax_i(j\cdot w_i-\mathcal{C}_i) k=argmaxi(jwiCi)

P r o o f . Proof.\; Proof. 注意到

Δ MSE = MSE a f t e r − MSE b e f o r e = 1 n [ ( j ⋅ w k − C k − 1 ) 2 − ( j ⋅ w k − C k ) 2 ] = 1 n [ 1 − 2 ( j ⋅ w k − C k ) ] \begin{aligned} \Delta \text{MSE}=\text{MSE}_{after}-\text{MSE}_{before}&=\frac1n[(j\cdot w_k-\mathcal{C}_k-1)^2-(j\cdot w_k-\mathcal{C}_k)^2] \\ &=\frac1n[1-2(j\cdot w_k-\mathcal{C}_k)] \end{aligned} ΔMSE=MSEafterMSEbefore=n1[(jwkCk1)2(jwkCk)2]=n1[12(jwkCk)]

由上式可知,要使 Δ MSE \Delta \text{MSE} ΔMSE 越小,应使 j ⋅ w k − C k j\cdot w_k-\mathcal{C}_k jwkCk 越大,故 k = arg max ⁡ i ( j ⋅ w i − C i ) k=\argmax_i(j\cdot w_i-\mathcal{C}_i) k=argmaxi(jwiCi)


Prop 5. \text{Prop} \;5.\, Prop5. 假定下标从 1 1 1 开始。若 w 1 = w 2 = ⋯ = w n = 1 / n w_1=w_2=\cdots=w_n=1/n w1=w2==wn=1/n,令 ∣ D ∣ = q ⋅ n + r |\mathcal{D}|=q\cdot n+r D=qn+r,其中 q q q 是商, r r r 是余数,则有

P = [ 1 , 2 , ⋯ , n ] ∗ q + [ 1 , 2 , ⋯ , r ] S = [ 1 , 1 , ⋯ , 1 ] + [ 2 , 2 , ⋯ , 2 ] + ⋯ + [ q , q , ⋯ , q ] ⏟ 每个列表的长度均为 n + [ q + 1 , q + 1 , ⋯ , q + 1 ] ⏟ 长度为 r C = [ q + 1 , q + 1 , ⋯ , q + 1 ⏟ r 个 , q , q , ⋯ , q ⏟ n − r 个 ] \begin{aligned} \mathcal{P}&=[1,2,\cdots,n] * q + [1,2,\cdots,r] \\ \mathcal{S}&=\underbrace{[1,1,\cdots,1] + [2,2,\cdots,2] + \cdots+[q,q,\cdots,q]}_{每个列表的长度均为 n}+\underbrace{[q+1,q+1,\cdots,q+1]}_{长度为r} \\ \mathcal{C}&=[\underbrace{q+1,q+1,\cdots,q+1}_{r个},\underbrace{q,q,\cdots,q}_{n-r个}] \end{aligned} PSC=[1,2,,n]q+[1,2,,r]=每个列表的长度均为n [1,1,,1]+[2,2,,2]++[q,q,,q]+长度为r [q+1,q+1,,q+1]=[r q+1,q+1,,q+1,nr q,q,,q]

上述的 ∗ * + + + 均是列表运算符

P r o o f . Proof.\; Proof. 证明留给读者。


讨论:

Prop 3. \text{Prop} \;3.\, Prop3. 中提到了可能会出现下标越界的错误,为了避免这个错误,我们可以在得到 P , S \mathcal{P},\mathcal{S} P,S 数组后,对 S \mathcal{S} S 进行更新(假定下标从 1 1 1 开始):

S j = S j mod ( ∣ D P j ∣ + 1 ) , j = 1 , 2 , ⋯ , ∣ D ∣ \mathcal{S}_j=\mathcal{S}_j\;\text{mod}\; (|\mathcal{D}_{\mathcal{P}_j}|+1),\quad j=1,2,\cdots,|\mathcal{D}| Sj=Sjmod(DPj+1),j=1,2,,D

例如某个数据集是 [ 1 , 2 , 3 , 4 , 5 ] [1,2,3,4,5] [1,2,3,4,5],如果要从这个数据集采样 8 8 8 个样本,则原先的算法会在采样第 6 6 6 个样本时抛出下标越界错误,修正后的算法的采样结果为 [ 1 , 2 , 3 , 4 , 5 , 1 , 2 , 3 ] [1,2,3,4,5,1,2,3] [1,2,3,4,5,1,2,3]

为什么Megatron源码里没有规避这个错误但在使用的过程中却好像并没有遇到bug呢?注意到 self.datasets[dataset_idx] 实际上指向的是 megatron/data/gpt_dataset.py 中的 GPTDataset 类,在混合数据集场景下,Megatron会预先根据权重计算每个数据集所需要的样本数,然后根据这个样本数构建 GPTDataset,而非根据document数去构建。所以,即使对于两个完全相同的数据集,当赋予它们的权重不同时,所得到的 GPTDataset 的长度也不同,这一点可以通过向 BlendableDataset 源码中加入以下代码来验证:

for i, dataset in enumerate(self.datasets):print(f"dataset {i}: {len(dataset)}")

因为 GPTDataset 的长度已经根据权重做出了相应的调整,所以绝大部分时候是不会出现bug的,但我们依然可以构造极端情形来触发bug。

考虑在训练脚本中提供两个完全相同的路径,但却赋予它们不同的权重,如下:

--train-data-path 0.001 /path/to/your/data_text_document 0.999 /path/to/your/data_text_document

然后在 BlendableDataset 源码中的 __getitem__ 方法中固定索引,即:

def __getitem__(self, idx):idx = self.size - 1  # 意味着我们总是取BlendableDataset的最后一个样本dataset_idx = self.dataset_index[idx]sample_idx = self.dataset_sample_index[idx]return {"dataset_idx": dataset_idx,**self.datasets[dataset_idx][sample_idx],}

这样就可以稳定的触发下标越界的bug。

📝 注意到从 GPTDataset 中取出来的是sample,所以Megatron的混合算法实际上是以sample为单位的,而非以document为单位。

4. 进一步优化

根据 Prop 3. \text{Prop} \;3.\, Prop3. Prop 5. \text{Prop} \;5.\, Prop5. 以及其他细节,我们有以下几个优化方向:

  • 修复可能会出现的下标越界错误(可通过取余来实现)。
  • 在等权重情形下加速混合(利用 numpy)。
  • 支持更多数据集进行混合(修改 uint8 为其他类型)。

假设相应的接口名为 make_blendable_dataset,它接收两个形参:datasetsweights。前者是一个二维列表,包含了要进行混合的数据集(每个数据集是一个一维列表),后者是一个一维列表,包含了每个数据集的权重。

使用Python进行实现:

from typing import List, Any, Union
import numpy as np
import random
from tqdm import tqdmdef make_blendable_dataset(datasets: List[List[Any]], weights: List[Union[float, int]]) -> List[Any]:num_datasets = len(datasets)assert num_datasets == len(weights), "The number of datasets must match the number of weights."# Shufflesize = 0for dataset in datasets:size += len(dataset)random.shuffle(dataset)# Normalize weightsweights = np.array(weights, dtype=np.float64)assert np.all(weights > 0), "All weights must be positive."weights /= weights.sum()# Determine if all weights are equalif np.ptp(weights) < 1e-5:q, r = divmod(size, num_datasets)dataset_index = np.concatenate([np.tile(np.arange(num_datasets, dtype=np.int16), q),np.arange(r, dtype=np.int16)])dataset_sample_index = np.concatenate([np.repeat(np.arange(q, dtype=np.int64), num_datasets),np.full(r, q, dtype=np.int64)])current_samples = np.full(num_datasets, q, dtype=np.int64)current_samples[:r] += 1else:dataset_index = np.zeros(size, dtype=np.int16)dataset_sample_index = np.zeros(size, dtype=np.int64)current_samples = np.zeros(num_datasets, dtype=np.int64)for sample_idx in tqdm(range(size), desc="Calculating error"):errors = weights * (sample_idx + 1) - current_samplesmax_error_index = np.argmax(errors)dataset_index[sample_idx] = max_error_indexdataset_sample_index[sample_idx] = current_samples[max_error_index]current_samples[max_error_index] += 1print(f"Ratios:")for i in range(num_datasets):print(f"input: {weights[i]}, achieved: {current_samples[i] / size}")# Blendres = []for i in tqdm(range(size), desc="Blending"):dataset_idx = dataset_index[i]sample_idx = dataset_sample_index[i] % len(datasets[dataset_idx])res.append(datasets[dataset_idx][sample_idx])return res

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/491099.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【统计分析数学模型】聚类分析: 系统聚类法

【统计分析数学模型】聚类分析&#xff1a; 系统聚类法 一、聚类分析1. 基本原理2. 距离的度量&#xff08;1&#xff09;变量的测量尺度&#xff08;2&#xff09;距离&#xff08;3&#xff09;R语言计算距离 三、聚类方法1. 系统聚类法2. K均值法 三、示例1. Q型聚类&#x…

微服务知识02

1、九大高并发解决方案 2、系统架构图​​​​​​​ 3、分布式事务 本地事务、分布式事务 操作不同服务器的数据库&#xff08;垂直分库&#xff09; 4、分布式事务解决方案&#xff08;没有seata之前&#xff09; &#xff08;1&#xff09;XA协议&#xff08;强一致性&a…

12. Springboot集成Dubbo3(三)Dubbo-Admin

目录 1、前言 2、安装 2.1、下载Dubbo-admin 2.2、修改配置 2.3、编译前端 2.4、访问 2.5、加载自己的服务 2.6、服务测试 2.7、其他 3、小结 1、前言 Dubbo Admin是用于管理Dubbo服务的基于Web的管理工具。Dubbo Admin提供了一个用户友好的界面&#xff0c;用于在分…

在IDEA中创建vue hello-world项目

工作中最近在接触vue前端项目,记录一下从0搭建一个vue hello world项目的步骤 1、本地电脑安装配置node、npm D:\Project\vue\hello-world>node -v v14.21.3 D:\Project\vue\hello-world>npm -v 6.14.18 D:\Project\vue\hello-world> 2、设置npm国内淘宝的景象 …

MySQL-基本使用,数据类型,简单操作

1. 数据库概述 1.1 数据库(DatBase) 数据库&#xff0c;就是遵循一定数据格式的数据集合&#xff0c;可以认为他是对文件系统的改进。它解决了不同操作系统之间&#xff0c;数据格式的兼容性问题。也就是说&#xff0c;只要是同一个数据库的数据文件&#xff0c;即使从windows迁…

基于springboot+vue的租房管理系统(前后端分离)

博主主页&#xff1a;猫头鹰源码 博主简介&#xff1a;Java领域优质创作者、CSDN博客专家、阿里云专家博主、公司架构师、全网粉丝5万、专注Java技术领域和毕业设计项目实战&#xff0c;欢迎高校老师\讲师\同行交流合作 ​主要内容&#xff1a;毕业设计(Javaweb项目|小程序|Pyt…

第三方认证中心跳转

一、业务需求 由第三方认证中心将 token 放在 header (请求头) 中跳转系统&#xff0c;前端获取到第三方系统携带 header 中的 token。 二、 业务流程 模拟第三方应用 CUSTOM-USERTOKEN 是第三方的 tokenproxy_pass 是我们的前端地址 server {listen 12345;server_name loca…

前端导出下载文件后提示无法打开文件

问题 项目中的导出文件功能&#xff0c;导出下载后的文件打开提示如下&#xff1a; 原因 对返回的响应数据进行打印&#xff0c;发现响应数据为字符串格式&#xff0c;前期规划的后端返回数据应该 blob 对象的。后经排查后发现是请求头缺少了响应数据格式的配置&#xff0c;应…

python 循环语句 while 循环

while循环 Python 编程中 while 语句用于循环执行程序&#xff0c;即在某条件下&#xff0c;循环执行某段程序&#xff0c;以处理需要重复处理的相同任务。其基本形式为&#xff1a; while 判断条件(condition)&#xff1a; 执行语句(statements)…… 执行语句可以是单个语句…

【GameFramework框架内置模块】3、数据表(Data Table)

推荐阅读 CSDN主页GitHub开源地址Unity3D插件分享简书地址 大家好&#xff0c;我是佛系工程师☆恬静的小魔龙☆&#xff0c;不定时更新Unity开发技巧&#xff0c;觉得有用记得一键三连哦。 一、前言 【GameFramework框架】系列教程目录&#xff1a; https://blog.csdn.net/q7…

Python爬虫实战:图片爬取与保存

引言&#xff1a; 在本文中&#xff0c;我们将学习如何使用Python创建一个简单的图片爬虫。 我们将利用requests库来发送HTTP请求&#xff0c;BeautifulSoup库来解析HTML页面&#xff0c;以及os和shutil库来下载和保存图片。通过这个教程&#xff0c;你将学会如何爬取网…

16.沙箱里的秩序——外观模式

然后&#xff0c;冯诺伊曼又用三名士兵构建了与非门、或非门、异或门、同或门和三态门&#xff0c;最后只用两名士兵构建了最简单的非门&#xff0c;出总是举与入颜色相反的旗。 冯:诺伊曼对皇帝鞠躬说:“现在&#xff0c;陛下&#xff0c;所有的门部件都已演示完毕&#xff0c…