はまやんはまやんはまやん

hamayanhamayan's blog

Tree Patrolling [AtCoder Beginner Contest 207 F]

https://atcoder.jp/contests/abc207/tasks/abc207_f

前提知識

解説

https://atcoder.jp/contests/abc207/submissions/23806639

二乗の木DPという手法を用いる。
知らないと解けない気もするが、計算量的なセンスがあれば自ら生み出せるかもしれない。
(念のため補足しておくと、生み出せたならあなたは天才です)

木DP

二乗の木DPについての前に木DPが今回は思いつかないと先には進めない。
木DPが分からない場合はどこかで勉強してきてほしい。
上のリンクにも練習問題なら用意している。

木上での数え上げといえば木DPなので、とりあえず木DPでやれないか検討してみる。
親から見た子について必要な情報を考えてみると、以下のようなDPが立つ。

dp[cu][cnt][placed][colored] :=
頂点cuを親とした部分木において、
警備されている頂点数がcntで
頂点cuに高橋君がいるかの情報がplaced(0/1)で、
頂点cuが警備されているかの情報がcolored(0/1)のときの組合せ

これを計算して、木全体の根を0とすると、警備された頂点数がちょうどK個の時の組み合わせはdp[0][K][any][any]の総和となる。
まずはこのDPテーブルの定義に納得することが第一段階。
とある頂点cuとその子供の頂点toについて、toに高橋君がもしいれば頂点cuは警備されるので、子供について高橋君がいるかどうかの情報は必要だし、
cuに高橋君を配置すれば頂点toが警備されていないなら警備されている状態になるので、子供について子供の頂点が警備されているかの情報は必要である。
だが、この警備関係は直接の子供についてのみ作用するので、木DPで根について以上の情報を持っておけば、他のことは抽象化しても問題ない。
よって、こんな感じの木DPかなという感じ。

遷移について

定義は多次元DPに慣れていればそれほど難しくないが、遷移が難しい。
今回の木DPは子供の情報を単に足し合わせるだけでなく、根の情報をうまく組合せながらまとめ上げていく必要がある。
こういう場合は、DPをマージしていくような感じで遷移を進めていく。

初期状態について

全く子供を持っていない葉の状態では、単に高橋君を置くかおかないかという感じになるので、

dp[cu][0][0][0] = 1 (おかない)
dp[cu][1][1][1] = 1 (おく)

という感じになる。

子を親にマージするとどうなるか

実装を見てもらうといいが、本当にマージしてる感じがでている。
親も上の初期状態DPから始まって子のDPをマージして新しく親のDPを再構築するような形で計算を進めていく。

rep(c0, 0, size0) rep(placed0, 0, 2) rep(colored0, 0, 2) rep(c1, 0, size1) rep(placed1, 0, 2) rep(colored1, 0, 2) {  
    int c2 = c0 + c1;  
    int placed2 = placed0;  
    int colored2 = colored0 | placed1;  
    if (colored0 == 0 && placed1 == 1) c2++;  
    if (colored1 == 0 && placed0 == 1) c2++;  
  
    if (c2 < size0 + size1 - 1 && 0 < (dp[id0][c0][placed0][colored0] * dp[id1][c1][placed1][colored1]).get()) {  
        nxt[c2][placed2][colored2] += dp[id0][c0][placed0][colored0] * dp[id1][c1][placed1][colored1];  
    }  
}  

重要な遷移部分を以上に再掲した。
親dp[cu]と子dp[to]をdp'cuとして再構築している感じになる。

基本的にはnxt[cnt0+cnt1] += dp[cu][cnt0] * dp[to][cnt1]をしていく単純な遷移なのだが、
親と子のplacedとcoloredの関係により、警備されている個数が増加したりするのでその辺をまとめるのに6重ループになってしまっている。
説明するより実装を見てもらう方がDP構造が分かっている場合は早いと思う。
c2(=dp'のcnt)をインクリメントしている部分は、さっき話した

とある頂点cuとその子供の頂点toについて、toに高橋君がもしいれば頂点cuは警備されるので、子供について高橋君がいるかどうかの情報は必要だし、
cuに高橋君を配置すれば頂点toが警備されていないなら警備されている状態になるので、子供について子供の頂点が警備されているかの情報は必要である。

の部分であり、更新前にif文を用意しているのは更新時にDPテーブルをはみ出してしまう(はみ出ている部分はありえない場面であるため無視していい)ためである。
これでマージしていけば最終的に親のDPが答えになっている。

二乗の木DPはどこへ?

ここまで言われていることが分かっていれば問題は解けている。
理解が難しい場合は簡単な二乗の木DPを解いて戻ってくれば恐らく理解できるようになるだろう。

そういえば二乗の木DPの説明をしていなかったが、これは枝刈り木DPのことである。
通常のDPではdp[cu][cnt]みたいにして、どちらも103が上限だとすると、dp[cu][cnt]の全ての要素を使って計算が行われる。
だが、一部の木DP、今回のようなDPの場合にはcnt≦|cuを根とした部分木の要素数|が成り立つ場合がある。
この場合にはdp[103][103]のように定義したとしても、部分木の要素数以上は使用されないので大部分は計算に用いられないことになる。
この用いられない部分に対して明確に計算しないように計算を進める、つまり、適切に枝刈りを行って計算を進めることで、
通常はO(N3)になるはず(雰囲気的にはパッと見てO(N3)っぽい)計算量をO(N2)に抑えることができる。
これが二乗の木DPである。

今回は二乗の木DPにできそうなcnt≦|cuを根とした部分木の要素数|があるので、通常の木DPでは見られにくいmergeを行うことで適切に枝刈りをして計算量を落としている。

実装

自分の実装では部分木の個数で要素数が変わるので毎回vectorでDPを作り直して、merge関数でゴリゴリ更新を行っている。
vector使うと生成コストが気になってしまうとも思うので、他の人の実装も参考にしてみるといいかもしれない。
自分はまだ間に合わなかったことがないので、とりあえず馴染みのこんな実装で通した。

int N;
vector<int> E[2010];
vector<vector<vector<mint>>> dp[2010];
//---------------------------------------------------------------------------------------------------
void merge(int id0, int id1) {
    int size0 = dp[id0].size();
    int size1 = dp[id1].size();
    vector<vector<vector<mint>>> nxt(size0 + size1 - 1, vector<vector<mint>>(2, vector<mint>(2, 0)));

    rep(c0, 0, size0) rep(placed0, 0, 2) rep(colored0, 0, 2) rep(c1, 0, size1) rep(placed1, 0, 2) rep(colored1, 0, 2) {
        int c2 = c0 + c1;
        int placed2 = placed0;
        int colored2 = colored0 | placed1;
        if (colored0 == 0 && placed1 == 1) c2++;
        if (colored1 == 0 && placed0 == 1) c2++;

        if (c2 < size0 + size1 - 1 && 0 < (dp[id0][c0][placed0][colored0] * dp[id1][c1][placed1][colored1]).get()) {
            nxt[c2][placed2][colored2] += dp[id0][c0][placed0][colored0] * dp[id1][c1][placed1][colored1];
        }
    }

    swap(dp[id0], nxt);
}
//---------------------------------------------------------------------------------------------------
void dfs(int cu, int pa = -1) {
    dp[cu] = vector<vector<vector<mint>>>(2, vector<vector<mint>>(2, vector<mint>(2, 0)));
    dp[cu][0][0][0] = 1;
    dp[cu][1][1][1] = 1;

    fore(to, E[cu]) if (to != pa) {
        dfs(to, cu);
        merge(cu, to);
    }
}
//---------------------------------------------------------------------------------------------------
void _main() {
    cin >> N;
    rep(i, 0, N - 1) {
        int a, b; cin >> a >> b;
        a--; b--;
        E[a].push_back(b);
        E[b].push_back(a);
    }

    dfs(0);
    rep(k, 0, N + 1) printf("%d\n", (dp[0][k][0][0] + dp[0][k][0][1] + dp[0][k][1][1]).get());
}