P3295 萌萌哒
2023-11-02 22:53:24 # OI # Problem

题面

题目描述

一个长度为 nn 的大数,用 S1S2S3SnS_1S_2S_3 \cdots S_n表示,其中 SiS_i 表示数的第 ii 位, S1S_1 是数的最高位。告诉你一些限制条件,每个条件表示为四个数,l1,r1,l2,r2l_1,r_1,l_2,r_2,即两个长度相同的区间,表示子串 Sl1Sl1+1Sl1+2Sr1S_{l_1}S_{l_1+1}S_{l_1+2} \cdots S_{r_1}Sl2Sl2+1Sl2+2Sr2S_{l_2}S_{l_2+1}S_{l_2+2} \cdots S_{r_2} 完全相同。

比如 n=6n=6 时,某限制条件 l1=1,r1=3,l2=4,r2=6l_1=1,r_1=3,l_2=4,r_2=6 ,那么 123123123123351351351351 均满足条件,但是 1201212012131141131141 不满足条件,前者数的长度不为 66 ,后者第二位与第五位不同。问满足以上所有条件的数有多少个。

输入格式

第一行两个数 nnmm,分别表示大数的长度,以及限制条件的个数。

接下来 mm 行,对于第 ii 行,有 44 个数 li1,ri1,li2,ri2l_{i_1},r_{i_1},{l_{i_2}},r_{i_2},分别表示该限制条件对应的两个区间。

1n1051\le n\le 10^51m1051\le m\le 10^51li1,ri1,li2,ri2n1\le l_{i_1},r_{i_1},{l_{i_2}},r_{i_2}\le n ;并且保证 $ r_{i_1}-l_{i_1}=r_{i_2}-l_{i_2}$ 。

输出格式

一个数,表示满足所有条件且长度为 nn 的大数的个数,答案可能很大,因此输出答案模 $ 10^9+7 $ 的结果即可。

样例 #1

样例输入 #1

4 2
1 2 3 4
3 3 3 3

样例输出 #1

90

思路

对于两段连续的区间相等,可以转化为对应的点相等,然后相等的点可以划分为同一类,这个可以用并查集维护。

然后并查集合并完以后,数一下总共有几类数(一开始每个数自成一类),每一类数都有十种可能性(除了第一个数不能是 0),所以答案就是 9×10cnt19\times10^{cnt-1}

但是会超时,只能得到 30 分的好成绩:

#include <cstdio>
#include <iostream>
#include <algorithm>
using namespace std;
const int N = 1e5 + 7;
const int MOD = 1e9 + 7;

int p[N];

int find(int x)
{
	if(x != p[x]) p[x] = find(p[x]);
	return p[x];
}

int main()
{
	int n, m;
	scanf("%d%d", &n, &m);
	for (int i = 1; i <= n; i++)
	{
		p[i] = i;
	}
	for (int i = 1; i <= m; i++)
	{
		int l1, r1, l2, r2;
		scanf("%d%d%d%d", &l1, &r1, &l2, &r2);
		for (int j = l1; j <= r1; j++)
		{
			int a = find(j), b = find(j + l2 - l1);
			if (a != b)
			{
				p[a] = b;
			}
		}
	}
	int cnt = 0;
	for (int i = 1; i <= n; i++)
		if (p[i] == i) cnt++;
	int res = 1;
	for (int i = 1; i <= cnt - 1; i++)
		res = 1ll * res * 10 % MOD;
	res = 1ll * res * 9 % MOD;
	cout << res << endl;
	return 0;
}

怎么优化呢?

一开始我们是点对点地合并并查集,为了提高效率我们可以区间对应区间地合并并查集,为了控制区间长度的一致性,都以 2k2^k 长度进行划分,利用类似 ST 表的思路即可。

然后我们最后还是要查询每个的父亲是不是它自己,但是我们的信息维护仅停留在区间这个维度,所以我们要下放信息

2k=2k1+2k12^k=2^{k-1}+2^{k-1}

我们记 fi,jf_{i,j} 代表以 ii 为左端点,长度为 2j2^j 的区间的父亲的左端点(初始化为 ii)。

在这里 jj 可以理解为一个属性,因为只有区间长度相同的才可以合并(两段相等的区间),在 findfind 函数和 mergemerge 函数里也别忘了传进这个属性。

然后枚举每个区间下放信息,统计一下有多少种并查集处理答案即可。

代码

#include <cstdio>
#include <iostream>
#include <algorithm>
using namespace std;
const int N = 1e5 + 7;
const int MOD = 1e9 + 7;

int p[N];
int f[N][18];
int Log[N];

int find(int x, int y)
{
	if(x != f[x][y]) f[x][y] = find(f[x][y], y);
	return f[x][y];
}
void merge(int x, int y, int len)
{
	int a = find(f[x][len], len), b = find(f[y][len], len);
	f[f[a][len]][len] = b;
}

int main()
{
	int n, m;
	scanf("%d%d", &n, &m);
	for (int i = 1; i <= n; i++)
	{
		for (int j = 0; j <= 22; j++)
			f[i][j] = i;
	}
	Log[0] = -1;
	for (int i = 1; i <= n; i++)
	{
		Log[i] = Log[i / 2] + 1;	
	}
	for (int i = 1; i <= m; i++)
	{
		int l1, r1, l2, r2;
		scanf("%d%d%d%d", &l1, &r1, &l2, &r2);
		int len = Log[r1 - l1 + 1];
		merge(l1, l2, len); // 合并四个区间 
		l1 = r1 - (1 << len) + 1, l2 = r2 - (1 << len) + 1;
		merge(l1, l2, len);
	}
	int cnt = 0;
	for (int l = 22; l >= 1; l--)
		for (int i = 1; i + (1 << l) - 1 <= n; i++)
		{
			int fa = find(i, l);
			if (fa != i)
			{
				merge(i, fa, l - 1);
				merge(i + (1 << (l - 1)),
					  fa + (1 << (l - 1)), 
					  l - 1); // 分裂	
			} 
		}
	for (int i = 1; i <= n; i++)
		if (f[i][0] == i) cnt++;
	int res = 1;
	for (int i = 1; i <= cnt - 1; i++)
		res = 1ll * res * 10 % MOD;
	res = 1ll * res * 9 % MOD;
	cout << res << endl;
	return 0;
}