P3295 萌萌哒
题面
题目描述
一个长度为 的大数,用 表示,其中 表示数的第 位, 是数的最高位。告诉你一些限制条件,每个条件表示为四个数,,即两个长度相同的区间,表示子串 与 完全相同。
比如 时,某限制条件 ,那么 , 均满足条件,但是 , 不满足条件,前者数的长度不为 ,后者第二位与第五位不同。问满足以上所有条件的数有多少个。
输入格式
第一行两个数 和 ,分别表示大数的长度,以及限制条件的个数。
接下来 行,对于第 行,有 个数 ,分别表示该限制条件对应的两个区间。
,, ;并且保证 $ r_{i_1}-l_{i_1}=r_{i_2}-l_{i_2}$ 。
输出格式
一个数,表示满足所有条件且长度为 的大数的个数,答案可能很大,因此输出答案模 $ 10^9+7 $ 的结果即可。
样例 #1
样例输入 #1
4 2
1 2 3 4
3 3 3 3
样例输出 #1
90
思路
对于两段连续的区间相等,可以转化为对应的点相等,然后相等的点可以划分为同一类,这个可以用并查集维护。
然后并查集合并完以后,数一下总共有几类数(一开始每个数自成一类),每一类数都有十种可能性(除了第一个数不能是 0),所以答案就是 。
但是会超时,只能得到 30 分的好成绩:
#include <cstdio>
#include <iostream>
#include <algorithm>
using namespace std;
const int N = 1e5 + 7;
const int MOD = 1e9 + 7;
int p[N];
int find(int x)
{
if(x != p[x]) p[x] = find(p[x]);
return p[x];
}
int main()
{
int n, m;
scanf("%d%d", &n, &m);
for (int i = 1; i <= n; i++)
{
p[i] = i;
}
for (int i = 1; i <= m; i++)
{
int l1, r1, l2, r2;
scanf("%d%d%d%d", &l1, &r1, &l2, &r2);
for (int j = l1; j <= r1; j++)
{
int a = find(j), b = find(j + l2 - l1);
if (a != b)
{
p[a] = b;
}
}
}
int cnt = 0;
for (int i = 1; i <= n; i++)
if (p[i] == i) cnt++;
int res = 1;
for (int i = 1; i <= cnt - 1; i++)
res = 1ll * res * 10 % MOD;
res = 1ll * res * 9 % MOD;
cout << res << endl;
return 0;
}
怎么优化呢?
一开始我们是点对点地合并并查集,为了提高效率我们可以区间对应区间地合并并查集,为了控制区间长度的一致性,都以 长度进行划分,利用类似 ST 表的思路即可。
然后我们最后还是要查询每个点的父亲是不是它自己,但是我们的信息维护仅停留在区间这个维度,所以我们要下放信息。
即 。
我们记 代表以 为左端点,长度为 的区间的父亲的左端点(初始化为 )。
在这里 可以理解为一个属性,因为只有区间长度相同的才可以合并(两段相等的区间),在 函数和 函数里也别忘了传进这个属性。
然后枚举每个区间下放信息,统计一下有多少种并查集处理答案即可。
代码
#include <cstdio>
#include <iostream>
#include <algorithm>
using namespace std;
const int N = 1e5 + 7;
const int MOD = 1e9 + 7;
int p[N];
int f[N][18];
int Log[N];
int find(int x, int y)
{
if(x != f[x][y]) f[x][y] = find(f[x][y], y);
return f[x][y];
}
void merge(int x, int y, int len)
{
int a = find(f[x][len], len), b = find(f[y][len], len);
f[f[a][len]][len] = b;
}
int main()
{
int n, m;
scanf("%d%d", &n, &m);
for (int i = 1; i <= n; i++)
{
for (int j = 0; j <= 22; j++)
f[i][j] = i;
}
Log[0] = -1;
for (int i = 1; i <= n; i++)
{
Log[i] = Log[i / 2] + 1;
}
for (int i = 1; i <= m; i++)
{
int l1, r1, l2, r2;
scanf("%d%d%d%d", &l1, &r1, &l2, &r2);
int len = Log[r1 - l1 + 1];
merge(l1, l2, len); // 合并四个区间
l1 = r1 - (1 << len) + 1, l2 = r2 - (1 << len) + 1;
merge(l1, l2, len);
}
int cnt = 0;
for (int l = 22; l >= 1; l--)
for (int i = 1; i + (1 << l) - 1 <= n; i++)
{
int fa = find(i, l);
if (fa != i)
{
merge(i, fa, l - 1);
merge(i + (1 << (l - 1)),
fa + (1 << (l - 1)),
l - 1); // 分裂
}
}
for (int i = 1; i <= n; i++)
if (f[i][0] == i) cnt++;
int res = 1;
for (int i = 1; i <= cnt - 1; i++)
res = 1ll * res * 10 % MOD;
res = 1ll * res * 9 % MOD;
cout << res << endl;
return 0;
}