はまやんはまやんはまやん

hamayanhamayan's blog

Count Descendants [エイシングプログラミングコンテスト2021(AtCoder Beginner Contest 202) E]

https://atcoder.jp/contests/abc202/tasks/abc202_e

前提知識

解説

https://atcoder.jp/contests/abc202/submissions/22837009

最後までの理解は難しいかもしれないが、問題の言い換え部分までは参考になるかもしれない。

問題の言い換え

少し扱いやすいように問題を言い換える。以下のようにクエリ問題を言い換えてみよう。

頂点U[i]を根とした部分木上に本来の根からの距離がD[i]である頂点がいくつあるか

このように考えると、今回の問題は部分木についてのアルゴリズムを適用することができる。

部分木

部分木といえば、オイラーツアーというものが使える。
これを使うことで、ある部分木に対する操作をとある区間の操作に言い換えることができる。
なのでクエリ的には部分木という条件を区間への条件にマッピングすることができるので、

頂点U[i]を根とした部分木上に本来の根からの距離がD[i]である頂点がいくつあるか

という条件を

ある区間[L,R)について、本来の根からの距離がD[i]である頂点がいくつあるか

と言い換えることができる。更に言うと、そのある区間を本来の根からの距離を保持する配列として考えると、

ある区間[L,R)について、値がD[i]である要素がいくつあるか

という感じに帰着させることができる。これでだいぶ解きやすくなった。

帰着問題をいかに解くか

この問題は実は頻出問題であり、いくつか解法がある。
今回はBITと水平走査を使って、これを解くことにしよう。

クエリを先読みしておき、深さ順に処理していくことにする。
深さについて0からN-1まで順に処理をしていくことにする。
処理していく過程で配列bit(実装はBITで実装されていて、区間和を取れるようにしておく)

bit[x] := オイラーツアーによって頂点cuが要素xにマッピングされているとするとき、頂点cuを訪問済み(カウント済み)であればbit[x]=1となる

ここで、以下のような処理を行う。

深さdについて処理するとする
1. D[i]=dである全てのクエリqについて、ans[q]からU[q]に対応する区間[L,R)のbitの総和を引く
2. 深さがdである頂点について対応するbitの要素に+1をする
3. D[i]=dである全てのクエリqについて、ans[q]へU[q]に対応する区間[L,R)のbitの総和を足す

こうすると手順1ではbitの状態がd-1までがカウントされている状態になっていて、手順3ではbitの状態がdまでがカウントされている状態になっている。
なので、

ans[q] = (U[q]を根とする部分木で深さがd以下の頂点数) - (U[q]を根とする部分木で深さがd-1以下の頂点数)

のように計算していることになるので、ちょうど深さがdの頂点数を求めることができている。

これは考え方的には水平走査的な考え方で、根付き木の頂点を深さが小さい順に評価して、適切に情報を保持していくことで2ベクトルの情報を加味した計算を行っている。

int N;
vector<int> E[201010];
int Q;
int U[201010], D[201010];
int ans[201010];
//---------------------------------------------------------------------------------------------------
BIT<int> bit(401010);
int L[401010], R[401010];
int idx = 0;
void euler(int cu, int pa = -1) { // [L[v],R[v])
    L[cu] = idx; idx++;
    for (int to : E[cu]) if (to != pa) euler(to, cu);
    R[cu] = idx;
}
//---------------------------------------------------------------------------------------------------
int dep[201010];
void dfs(int cu, int d = 0, int pa = -1) {
    dep[cu] = d;
    fore(to, E[cu]) if (to != pa) dfs(to, d + 1, cu);
}
//---------------------------------------------------------------------------------------------------
void _main() {
    cin >> N;
    rep(i, 1, N) {
        int P; cin >> P; P--;
        E[i].push_back(P);
        E[P].push_back(i);
    }
    euler(0);
    cin >> Q;
    rep(i, 0, Q) cin >> U[i] >> D[i], U[i]--;
    dfs(0);

    map<int, vector<int>> mapping;
    rep(i, 0, N) mapping[dep[i]].push_back(i);

    map<int, vector<int>> qs;
    rep(i, 0, Q) qs[D[i]].push_back(i);

    rep(d, 0, N) {
        fore(q, qs[d]) ans[q] -= bit.get(L[U[q]], R[U[q]]);
        fore(cu, mapping[d]) bit.add(L[cu], 1);
        fore(q, qs[d]) ans[q] += bit.get(L[U[q]], R[U[q]]);
    }

    rep(i, 0, Q) printf("%d\n", ans[i]);
}