はまやんはまやんはまやん

hamayanhamayan's blog

Red and Blue Tree [エクサウィザーズプログラミングコンテスト2021(AtCoder Beginner Contest 222) E]

https://atcoder.jp/contests/abc222/tasks/abc222_e

解説

https://atcoder.jp/contests/abc222/submissions/26482282

恐らく初見だとかなり難しい問題に見えると思う。

問題を簡単化する

辺の塗り方を求める問題であるが、操作によってとある辺について何回通過するかという回数については
どうなっても変化しないことが分かる。
これはとても便利な性質で、通る回数が分かれば、もともとの操作・木構造はそれほど重要でない。
よって、各辺によって、操作によって通る回数を先に求めておけば、それを赤に塗ればRがその分増えるし、
青に塗ればBが増えるという単純な事象に落とし込める。

なので、木上で通った回数分RBが増減するのではなく、単に辺について着色をした場合にRBが増減すると
もう少し問題を簡単化することにしよう。
このためには、辺を通る回数を数える必要がある。

より簡単な問題にするために辺を通る回数を数える

色々工面するかと思うが、実はDFSで比較的簡単に書ける。
DFSとだけ言われてもあまりピンとこないかもしれないが、AからBへの最短経路を求める場合にDFSを使った場合、
DFSの探索時は関係ない所への再帰で移動していくが、最終的には目的のBへ到達することができる。
到達した後に遷移元を順にたどっていくと、それがちょうど最短経路になっている。
これを利用して辺を数え上げることにする。
多分これだけ聞いてもちょっとわかりにくい。実装を見てもらうといいと思う。

エッセンスだけ再掲しておくと、AからBへの最短経路を求めたいときにDFSを使う場合、
行きに注目すると最短経路は求められないが、Bに到達した後に再帰から戻ってくる帰りの部分は丁度最短経路のみ通ってくることになる。
このようなある種のバックトラックを利用することで比較的少ない実装量で最短経路上に対して計算を行うことができる。
今回は計算量に余裕があるので、この方法でカウントしてしまおう。

残った問題は?

これで辺について何回通るかのcnt配列が計算できたことになる。
各辺について赤か青に適切に塗ることでR-B=Kを満たすような組合せを求める問題に帰着した。
mod 998244353ということもあるが、DPで解いてみよう。

dp[i][tot] := i番目の辺まで着色が終わっていてR-B=totであるような組合せ

これを解ければ、dp[N-1][K]が答えになる。
これも…実は難しい…

DPを解く

難しい点は2点ある。

  1. dp[1000][100000]で配列を取るとMLEするかも
  2. tot部分が負の数になる可能性がある

1. dp[1000][100000]で配列を取るとMLEするかも

メモリの圧縮テクを使う。今回はdp[i+1][?]の更新はdp[i][?]だけが必要でそれ以前は必要ないので、
1つ前だけを保持してDPしていくようなメモリ削減を行ったこれで、第一添え字部分が1000ではなく2で十分になるので
大幅にメモリ削減ができてMLEを回避できる。
なお、同じような実装であればメモリ使用量が少ない方が高速になる傾向がある(要出典)。
自分の実装を見てもらえば何となく何をしているかは分かると思うがもう少し細かい説明が必要な場合は
https://qiita.com/drken/items/68b8503ad4ffb469624c#%E6%B3%A8%E6%84%8F%E7%82%B92-%E3%82%88%E3%82%8A%E6%B1%8E%E7%94%A8%E7%9A%84%E3%81%AB%E4%BD%BF%E3%81%88%E3%82%8B%E3%83%A1%E3%83%A2%E3%83%AA%E7%AF%80%E7%B4%84%E3%83%86%E3%82%AF%E3%83%8B%E3%83%83%E3%82%AF
とやってることは同じなので、見てみてるといいと思う(ここに書かれている他のテクは汎用テクで必見です)

2. tot部分が負の数になる可能性がある

C言語の配列は添え字に負の数が来ると(動くけど)めちゃくちゃになっちゃうので中心をずらして全部非負にする必要がある。
自分はいつもBASE変数として中央の数を指定して入れている。
0はdp[BASE]だし、-5はdp[-5 + BASE]だし、100はdp[100 + BASE]だしといった感じ。
なので、初期値もdp[0][BASE] = 1で、答えもdp[N-1][K + BASE]である。

int N, M, K, A[101];
vector<pair<int, int>> E[1010];
int cnt[1010];
mint dp[2][201010];
const int BASE = 100005;
//---------------------------------------------------------------------------------------------------
bool dfs(int cu, int pa, int goal) {
    if (cu == goal) return true;

    fore(p, E[cu]) if (p.first != pa) {
        bool res = dfs(p.first, cu, goal);
        if (res) {
            cnt[p.second]++;
            return true;
        }
    }

    return false;
}
//---------------------------------------------------------------------------------------------------
void _main() {
    cin >> N >> M >> K;
    rep(i, 0, M) cin >> A[i];
    rep(i, 0, N - 1) {
        int u, v; cin >> u >> v;
        E[u].push_back({ v, i });
        E[v].push_back({ u, i });
    }

    rep(i, 0, M - 1) dfs(A[i], -1, A[i + 1]);

    dp[0][BASE] = 1;
    rep(i, 0, N - 1) {
        int cu = i % 2;
        int nxt = 1 - cu;
        rep(tot, 0, 201010) dp[nxt][tot] = 0;
        rep(tot, 0, 201010) if (0 < dp[cu][tot].get()) {
            dp[nxt][tot - cnt[i]] += dp[cu][tot];
            dp[nxt][tot + cnt[i]] += dp[cu][tot];
        }
    }
    cout << dp[(N - 1) % 2][K + BASE] << endl;
}