原题链接:https://www.luogu.com.cn/problem/P3373
题意解读:对于序列a[n],支持三种操作:1.对区间每个数乘上一个数 2.对区间每个数加上一个数 3.求区间和
解题思路:由于支持乘、加两种区间修改操作,是线段树的另一种典型应用:多个懒标记
显然,这里需要两个懒标记,mul表示对子节点区间每个数乘mul,add表示对子节点区间每个数加上add,节点定义如下:
struct Node
{int l, r;LL sum; //区间和LL mul; //懒标记,子节点区间每个数乘上mul,默认值为1LL add; //懒标记,子节点区间每个数加上add,默认值为0
} tr[N * 4];
下面就要考虑sum、mul、add如何修改的问题
对于一个节点u,
如果要对其区间每个数乘mul,则有tr[u].sum = tr[u].sum * mul
如果要对其区间每个数加add,则有tr[u].sum = tr[u].sum + (tr[u].r - tr[u].l + 1) * add
再区间更新时,可以把乘和加统一成一个操作:tr[u].sum = tr[u].sum * mul + (tr[u].r - tr[u].l + 1) * add(加操作时mul设置为1,乘操作时add设置为0)
上面解决了sum修改的问题,接下来,就要看mul、add如何修改,关键在于要考虑mul、add的优先级?
1、先加后乘
假设先执行加法,后执行乘法,那么对于懒标记mul,add,意味着对其区间每一个数x都执行(x + add) * mul,
如果再来一个加add'操作,区间每一个数变成(x + add) * mul + add',不难分析,无法通过将add、mul进行更新得到形如(x + add) * mul的形式,
所以先加后乘不可行。
2、先乘后加
假设先执行乘法,后执行加法,那么对于懒标记mul,add,意味着对其区间每一个数x都执行x * mul + add,
如果再来一个加add'操作,区间每一个数变成x * mul + add + add',显然通过将add += add',即可以通过x * mul + add得到正确的结果;
如果再来一个乘mul'操作,区间每一个数变成(x * mul + add) * mul' = x * mul * mul' + add * mul',显然通过将mul *= mul', add * mul',即可以通过x * mul + add得到正确的结果。
确定了操作优先级,也就确定了懒标记的更新方式,可以将乘和加统一处理:
对于一个节点u,对其区间每个数乘mul,加add,如果只加则mul=1,如果只乘则add=0,懒标记更新方式为:
tr[u].mul = tr[u].mul * mul
void addtag(int u, LL mul, LL add)
{tr[u].sum = (tr[u].sum * mul + (tr[u].r - tr[u].l + 1) * add) % m;tr[u].mul = tr[u].mul * mul % m;tr[u].add = (tr[u].add * mul + add) % m;
}
最后要注意的还是开long long。
100分代码:
#include <bits/stdc++.h>
using namespace std;typedef long long LL;const int N = 100005;struct Node
{int l, r;LL sum; //区间和LL mul; //懒标记,子节点区间每个数乘上mul,默认值为1LL add; //懒标记,子节点区间每个数加上add,默认值为0
} tr[N * 4];
LL a[N];
int n, q, m;void pushup(int u)
{tr[u].sum = (tr[u << 1].sum + tr[u << 1 | 1].sum) % m;
}void build(int u, int l, int r)
{tr[u] = {l, r, 0, 1, 0};if(l == r) tr[u].sum = a[l];else{int mid = l + r >> 1;build(u << 1, l, mid);build(u << 1 | 1, mid + 1, r);pushup(u);}
}void addtag(int u, LL mul, LL add)
{tr[u].sum = (tr[u].sum * mul + (tr[u].r - tr[u].l + 1) * add) % m;tr[u].mul = tr[u].mul * mul % m;tr[u].add = (tr[u].add * mul + add) % m;
}void pushdown(int u)
{addtag(u << 1, tr[u].mul, tr[u].add);addtag(u << 1 | 1, tr[u].mul, tr[u].add);tr[u].mul = 1;tr[u].add = 0;
}LL query(int u, int l, int r)
{if(tr[u].l >= l && tr[u].r <= r) return tr[u].sum;else if(tr[u].l > r || tr[u].r < l) return 0;else{pushdown(u);return (query(u << 1, l, r) + query(u << 1 | 1, l, r)) % m;}
}void update(int u, int l, int r, LL mul, LL add)
{if(tr[u].l >= l && tr[u].r <= r) addtag(u, mul, add);else if(tr[u].l > r || tr[u].r < l) return;else {pushdown(u);update(u << 1, l, r, mul, add);update(u << 1 | 1, l, r, mul, add);pushup(u);}
}int main()
{cin >> n >> q >> m;for(int i = 1; i <= n; i++) cin >> a[i];build(1, 1, n);int op, x, y, k;while(q--){cin >> op >> x >> y;if(op == 1) {cin >> k;update(1, x, y, k, 0); //乘k加0}else if(op == 2){cin >> k;update(1, x, y, 1, k); //乘1加k}else cout << query(1, x, y) << endl;}
}