P4211 LNOI2014 LCA
2024-03-13 12:16:51 # OI # Problem

题面

题目描述

给出一个 nn 个节点的有根树(编号为 00n1n-1,根节点为 00)。

一个点的深度定义为这个节点到根的距离 +1+1

dep[i]dep[i] 表示点 ii 的深度,LCA(i,j)\operatorname{LCA}(i, j) 表示 iijj 的最近公共祖先。

mm 次询问,每次询问给出 l,r,zl, r, z,求 i=lrdep[LCA(i,z)]\sum_{i=l}^r dep[\operatorname{LCA}(i,z)]

输入格式

第一行 22 个整数,n,mn, m

接下来 n1n-1 行,分别表示点 11 到点 n1n-1 的父节点编号。

接下来 mm 行,每行 33 个整数,l,r,zl, r, z

输出格式

输出 qq 行,每行表示一个询问的答案。每个答案对 201314201314 取模输出。

样例 #1

样例输入 #1

5 2
0
0
1
1
1 4 3
1 4 2

样例输出 #1

8
5

提示

对于 20%20\% 的数据,n10000,m10000n\le 10000,m\le 10000

对于 40%40\% 的数据,n20000,m20000n\le 20000,m\le 20000

对于 60%60\% 的数据,n30000,m30000n\le 30000,m\le 30000

对于 80%80\% 的数据,n40000,m40000n\le 40000,m\le 40000

对于 100%100\% 的数据,1n50000,1m500001\le n\le 50000,1\le m\le 50000

思路

看上去很水的一道题。

实际上需要 动动脑子

先来看问题的简化版:求两点的 lca 的深度。

第一种方法可以是求出这两点的 lca,然后得到它的深度,单词复杂度为 O(logn)O(logn),但极端情况下可能是 O(n)O(n),无优化空间。

第二种方法:设两点分别为 (s,t)(s, t),那么先把 ss 到根的路径上的点权全部加一,然后统计 tt 到根路径的点权和即为 lca 的深度。

想一想,为什么?

可以画张图来理解,我曾在SNOIP听课简记中提到,这里不多赘述。

这样一次处理,树剖+线段树区间加的复杂度是 O(nlog2n)O(nlog^2n) 的,处理 qq 次,复杂度 O(qnlog2n)O(qnlog^2n),不理想。

考虑把所有的询问离线,然后发现它们用到的其实都是一个东西:一段路径和的前缀和,每次询问的答案其实就是 sumrsuml1sum_r-sum_{l-1}sumi=j=1idep[lca(j,z)]sum_i=\sum_{j=1}^i{dep[lca(j,z)]},于是我们就可以把所有的询问编号并按照端点排序并处理答案,由于 l,rnl,r\in n,所以复杂度优化为 O(nlog2n)O(nlog^2n)

代码

#include <bits/stdc++.h>
using namespace std;
const int N = 5e4 + 7;
const int MOD = 201314;
#define ls(x) tr[u << 1]
#define rs(x) tr[u << 1 | 1]

namespace SegmentTree
{
    struct Node
    {
        int l, r;
        int sum;
        int add;
        Node()
        {
            sum = add = 0;
        }
    }tr[N << 2];
    void pushup(int u)
    {
        tr[u].sum = ls(u).sum + rs(u).sum; 
    }
    void pushdown(int u)
    {
        Node &root = tr[u];
        if(root.add)
        {
            ls(u).add += root.add;
            rs(u).add += root.add;
            ls(u).sum = (ls(u).sum + (ls(u).r - ls(u).l + 1) * root.add) % MOD;
            rs(u).sum = (rs(u).sum + (rs(u).r - rs(u).l + 1) * root.add) % MOD;
            root.add = 0;
        }
    }
    void build(int u, int l, int r)
    {
        tr[u].l = l, tr[u].r = r;
        if(l == r) return;
        int mid = (l + r) >> 1;
        build(u << 1, l, mid);
        build(u << 1 | 1, mid + 1, r);
    }
    void update(int u, int l, int r, int k)
    {
        if(l <= tr[u].l && tr[u].r <= r)
        {
            tr[u].sum = (tr[u].sum + tr[u].r - tr[u].l + 1) * k % MOD;
            tr[u].add += k;
            return;
        }
        pushdown(u);
        int mid = (tr[u].l + tr[u].r) >> 1;
        if(l <= mid) update(u << 1, l, r, k);
        if(r > mid) update(u << 1 | 1, l, r, k);
        pushup(u);
    }
    int query(int u, int l, int r)
    {
        if(l <= tr[u].l && tr[u].r <= r) return tr[u].sum;
        pushdown(u);
        int mid = (tr[u].l + tr[u].r) >> 1;
        int res = 0;
        if(l <= mid) res = (res + query(u << 1, l, r)) % MOD;
        if(r > mid) res = (res + query(u << 1 | 1, l, r)) % MOD;
        return res; 
    }
}

int n, m;
int idx, cnt;
int fir[N], fa[N], dep[N], sz[N], son[N], id[N], top[N];
int ans1[N], ans2[N];
struct Edge
{
    int nxt, to;
}e[N << 1];
struct Request
{
    int id, pos, p;
    bool bj; // 0 - ans1, 1 - ans2
    bool operator < (const Request &t) const 
    {
        return pos < t.pos;
    }
}q[N << 1];

void add(int u, int v)
{
    e[++cnt].to = v;
    e[cnt].nxt = fir[u];
    fir[u] = cnt;
}
void dfs1(int u)
{
    sz[u] = 1;
    for(int i = fir[u]; i; i = e[i].nxt)
    {
        int v = e[i].to;
        if(v == fa[u]) continue;
        dep[v] = dep[u] + 1;
        dfs1(v);
        sz[u] += sz[v];
        if(sz[son[u]] < sz[v]) son[u] = v;
    }
}
void dfs2(int u, int t)
{
    top[u] = t;
    id[u] = ++idx;
    if(!son[u]) return;
    dfs2(son[u], t);
    for(int i = fir[u]; i; i = e[i].nxt)
    {
        int v = e[i].to;
        if(v == fa[u] || v == son[u]) continue;
        dfs2(v, v);
    }
}
void update(int u)
{
    while(top[u] != 1)
    {
        SegmentTree::update(1, id[top[u]], id[u], 1);
        u = fa[top[u]];
    }
    SegmentTree::update(1, 1, id[u], 1);
}
int query(int u)
{
    int res = 0;
    while(top[u] != 1)
    {
        res = (res + SegmentTree::query(1, id[top[u]], id[u])) % MOD;
        u = fa[top[u]];
    }
    res = (res + SegmentTree::query(1, 1, id[u])) % MOD;
    return res;
}

int main()
{
    scanf("%d%d", &n, &m);
    for(int i = 2; i <= n; i++)
    {
        int f;
        scanf("%d", &f);
        fa[i] = ++f;
        add(f, i);
        add(i, f);
    }
    for(int i = 1, j = 1; i <= m; i++, j += 2)
    {
        int l, r, z;
        scanf("%d%d%d", &l, &r, &z);
        q[j] = {i, l, z + 1, false};
        q[j + 1] = {i, r + 1, z + 1, true};
        // cout << q[2].pos << endl;
    }
    sort(q + 1, q + m * 2 + 1);
    dfs1(1);
    dfs2(1, 1);
    SegmentTree::build(1, 1, n);
    for(int i = 1, j = 1; i <= 2 * m; i++)
    {
        while(j <= q[i].pos) update(j++);
        if(!q[i].bj) ans1[q[i].id] = query(q[i].p);
        else ans2[q[i].id] = query(q[i].p);
    }
    for(int i = 1; i <= m; i++)
    {
        // cout << ans2[i] << endl << ans1[i] << endl;
        printf("%d\n", ((ans2[i] - ans1[i]) % MOD + MOD) % MOD);
    }
    return 0;
}