https://atcoder.jp/contests/abc201/tasks/abc201_e
前提知識
解説
https://atcoder.jp/contests/abc201/submissions/22639728
まず、方針を色々考えていく上でビット毎に処理していくという発想を出すところがまず難しい。
ビット毎に処理?
XOR問題への典型テクとしてビット毎に処理していくというのがある。
例えば下位2ビット目についてだけを考えることにする。
他のビットは考慮しない、全部0であると考えることにする。
こうするとすべての辺の重みは...00100か...00000のどちらかとなる。
全ての頂点の組の間での重みのxorも同様に...00100か...00000のどちらかとなる。
つまり、下位2ビット目だけを考えると、答えとして計上されるのは...00100×(重みのxorが...00100となる組合せ)となる。
これが求まって何になるのかという話であるが、xorはビット毎に独立した計算であり、かつ総和もビット毎にある程度独立に計算することが可能である。
全ての頂点の重みのxorの総和 = (下位1ビット目のみ考慮した場合の重みの総和) + (下位2ビット目のみ考慮した場合の重みの総和) + ... + (下位60ビット目のみ考慮した場合の重みの総和)
以上の等式が成り立つ。この関係性が分からないと答えまで行くのは難しいと思う。良く考えてみてほしい。
次はどうする?
ビット毎に処理することにした場合に何がいいかというと、xorをした結果が全部0かそのビットだけ1であるかの2択に状態が圧縮できる所にある。
bitビット目を処理しているとする。
始点と終点を決めたときにxorした結果がbitビット目だけ1であるような組合せを高速にどうやって求めていこうか。
これを求める方法として、二通りの方法を紹介しよう。
想定解
問題の解法を理解するのに自分が通した解法よりも想定解の方が事前要求知識が少ないのでこちらを先に紹介しておく。
ここで、XORでのもう一つのテクを利用する。
(頂点iから頂点jへのXOR和) = (木の根から頂点iへのXOR和) ^ (木の根から頂点jへのXOR和)
これが成り立つ。
つまり、木の根から全ての頂点に対してdfsでXOR和を計算しておく。
すると、条件を満たす「始点と終点を決めたときにxorした結果がbitビット目だけ1であるような組合せ」というのは、
(始点と終点を決めたときにxorした結果がbitビット目だけ1であるような組合せ) = (木の根から頂点へのXOR和が0である頂点数) × (木の根から頂点へのXOR和が0じゃない頂点数)
となる。
左辺は、dfsをしてxor和を計算しておき、bitビット目が1かどうかを見ることで木の根から頂点へのXOR和が0か0じゃないかが分かるので、個数をそれぞれ計算できる。
よって、この問題は解くことができた。
注意点としてやや計算時間が厳しい部分があげられる。
自分の実装では先にdfsをして全部の頂点の根からのxorを計算しておき、その後、ビット毎に計算しておいたxor和を見ながら個数を集計している。
ビット毎にdfsをしても計算量的には同じO((ビット数)*N)で等しいのだが、ビット毎にdfsをすると定数倍分の計算が重くなるので注意。
特にdfsをすると関数呼び出し分のオーバーヘッドが発生するので、計算量を気にするなら、公式解説にもあるようにbfsする方がいい。
本番通した解法
本番は全方位木DPを使った解法で2992msで通しました。
https://atcoder.jp/contests/abc201/submissions/22626100
全方位木DPについて知らない方はどこかの記事で勉強してきてほしい。
「(頂点iから頂点jへのXOR和) = (木の根から頂点iへのXOR和) ^ (木の根から頂点jへのXOR和)」を使えば簡単な実装で済むのだが、完全に失念していた。
この解法では、ビット毎にやるという方針は一緒なのだが、とある頂点iを始点としてXOR和が0以外となる終点の組み合わせを高速に求めることを目標とする。
以下の木DPを定義する。
dp[cu][xo] := 頂点cuを根とした部分木において、頂点cuからのXORがxo(=0or1)であるような頂点数
これを普通に木DPで作って、全方位木DPで根をずらしながら計算していく。
この時の根が先ほどの「とある頂点i」に対応していて、再構築された木DPを使えばXOR和が0以外となる終点も高速に数え上げることができる。
なお、こっちだと実装量が2倍くらいになる。
int N; vector<pair<int, ll>> E[201010]; ll xo[201010]; //--------------------------------------------------------------------------------------------------- void dfs(int cu, int pa, ll x) { xo[cu] = x; fore(p, E[cu]) if (pa != p.first) dfs(p.first, cu, x ^ p.second); } //--------------------------------------------------------------------------------------------------- void _main() { cin >> N; rep(i, 0, N - 1) { int a, b; ll c; cin >> a >> b >> c; a--; b--; E[a].push_back({ b, c }); E[b].push_back({ a, c }); } dfs(0, -1, 0); mint ans = 0; rep(bit, 0, 60) { ll msk = 1LL << bit; ll cnt[2] = { 0, 0 }; rep(i, 0, N) cnt[(xo[i] & msk) != 0]++; ans += mint(msk) * mint(cnt[0]) * mint(cnt[1]); } cout << ans << endl; }