奇怪的字符串:需要一点观察的 SAM 小清新题。
观察
我们首先观察什么样的字符串才是奇怪的,可以发现,首先类似 AAAAAAA
之类全部相等的字符串是奇怪的。
继续观察,如果字符种类变为两种或者三种能不能是奇怪的。显然,类似 AAABBBBCCDDDEEEEEE
之类有三种及以上的且每个字母都只在一个连续段出现的串都不是奇怪的,因为我们一定可以选不相邻的两个连续段出来,是他不是原串的子串。因此,类似 AAAAABBBBBB
的就可以了。
那么一个字母有多个连续段的是不是奇怪的呢?考虑像 AAABBBAA
样子的串,我们显然可以把这些连续段的字母全部选出,例如选出 AAAAA
,这样一定不是原来的子串。
因此,一个字符串奇怪,当且仅当它满足类似 AAAAAAAA
或 AAAAAABBBBBBB
的形态。
实现
暴力枚举做法
枚举第二种形态的两个字符,线性扫一遍统计即可。
时间 \(O(n|\sum|^2)\),能过。
SAM 做法
考虑建出 SAM,枚举字符 \(c\),先求出从根节点到每个节点是否有只存在字符 \(c\) 的路径,这个可以通过正向拓扑一遍实现。然后再反向拓扑一遍,记录下每个节点后面最多能接多少个 \(c\),答案统计的时候先统计第一种形态的答案,再统计第二种,把后缀最多能接的字符数加上即可。
具体看代码吧,时间复杂度 \(O(n|\sum|)\),但我实现得很烂,还没暴力枚举跑得快。
代码
#include <bits/stdc++.h>
#define fi first
#define se second
#define lc (p<<1)
#define rc ((p<<1)|1)
#define eb(x) emplace_back(x)
#define pb(x) push_back(x)
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
typedef long double ldb;
using pi=pair<int,int>;
char s[200005];
ll ans;
int n,np=1,tot=1,ch[400005][26],fa[400005],len[400005],rd[400005];
ll pre[400005][26],suf[400005][26],sm[400005];
vector<pi>g1[400005],g2[400005];
void extend(int c)
{int p=np;np=++tot;len[np]=len[p]+1;for(;p&&ch[p][c]==0;p=fa[p])ch[p][c]=np;if(p==0)fa[np]=1;else{int q=ch[p][c];if(len[q]==len[p]+1)fa[np]=q;else{int nq=++tot;len[nq]=len[p]+1;fa[nq]=fa[q],fa[q]=nq,fa[np]=nq;for(;p&&ch[p][c]==q;p=fa[p])ch[p][c]=nq;memcpy(ch[nq],ch[q],sizeof(ch[q]));}}
}
void topo1(vector<pi>*g,ll dp[400005][26])
{memset(rd,0,sizeof(rd));for(int i=1;i<=tot;i++){for(auto ed:g[i]){int v=ed.fi;rd[v]++;}}queue<int>q;for(int i=1;i<=tot;i++){if(rd[i]==0)q.push(i);}while(!q.empty()){int u=q.front();q.pop();for(auto ed:g[u]){int v=ed.fi,c=ed.se;rd[v]--;if(rd[v]==0)q.push(v);if(dp[u][c]>=len[u])dp[v][c]+=dp[u][c]+1;}}
}
void topo2(vector<pi>*g,ll dp[400005][26])
{memset(rd,0,sizeof(rd));for(int i=1;i<=tot;i++){for(auto ed:g[i]){int v=ed.fi;rd[v]++;}}queue<int>q;for(int i=1;i<=tot;i++){if(rd[i]==0)q.push(i);}while(!q.empty()){int u=q.front();q.pop();for(auto ed:g[u]){int v=ed.fi,c=ed.se;rd[v]--;if(rd[v]==0)q.push(v);dp[v][c]+=dp[u][c]+1;}}
}
void cal()
{for(int i=1;i<=tot;i++){for(int j=0;j<26;j++){pre[i][j]=pre[i][j];ans+=(pre[i][j]>0);sm[i]+=pre[i][j];}}for(int i=1;i<=tot;i++){for(int j=0;j<26;j++){ans+=suf[i][j]*((sm[i]-pre[i][j])>0);}}
}
int main()
{ios::sync_with_stdio(0);cin.tie(0);cout.tie(0);cin>>s+1;n=strlen(s+1);for(int i=1;i<=n;i++)extend(s[i]-'a');for(int i=1;i<=tot;i++){for(int j=0;j<26;j++){int v=ch[i][j];if(v){g1[i].push_back({v,j});g2[v].push_back({i,j});}}}topo1(g1,pre);topo2(g2,suf);cal();cout<<ans;return 0;
}