はまやんはまやんはまやん

hamayanhamayan's blog

Multiset Mean [AtCoder Regular Contest 104 D]

https://atcoder.jp/contests/arc104/tasks/arc104_d

解説

https://atcoder.jp/contests/arc104/submissions/17180396

何から始める?

簡潔な問題文だが、何から始めるべきだろうか。
平均の問題への取り組み方はいくつかあるが、平均を求めるには「個数」と「総和」が分かっている必要がある。
これを平均が決まっていれば「総和」のみに減らすテクがある。

「a1, a2, a3, ...の平均がxになる」というのを「a1-x, a2-x, a3-x, ...の総和が0になる」と考える
よって、平均xが固定されていれば、総和だけを考える問題に帰着させることができる

とりあえず1クエリで考えてみる

各クエリでは平均がxの時の組み合わせを答えればいい。
なので、平均がxで固定であるとしよう。
先ほどのテクを見ると、
N=8でx=3とすると、-3をすると、
-2 -1 0 1 2 3 4 5
という数になる。
負の数を扱うのは大変なので、左の負の数側の総和と右の正の数の総和が絶対値で一致する組合せを計算することにする。
よって、以下のDPが作られていればいい。
dp[i][sm] := 1..iの各整数をそれぞれK個以下使って総和がsmとなる組合せ
こうすると答えは

dp[x-1][sm] × dp[N-x][sm] × (K+1)のsm=0...MAXの総和-1

となる。K+1をかけている理由はちょうどxの数は総和に寄与しないので、0個、1個、...、K個のどれをとってもいいのでK+1通りをかけている。
-1は空でない多重集合が要求されているので、空の集合を引いている。

DPを作る

愚直にDPを作ってみよう。
dp[i][sm]からi+1を0個~K個取ったときの遷移を作ればいい。
smの最大値はN2 Kなので、状態はN3 Kとなり、遷移があるので、O(N3 K2)で計算可能。
実は最大値は正確にはN2 K/2で十分なので、これを上限とすると、愚直でも通る。
https://atcoder.jp/contests/arc104/submissions/17179387

※ 最大値はなぜN2 K/2?

クエリを見てみると、左右の個数を見て、DPテーブルを参照している。
片方が0になると意味がないので、左右の個数のうち少ない方のsmの上限が最大となる。
左右の個数のうち少ない方の最大値はN/2なので、smの上限もN2 K/2となる。

DP高速化

DP高速化は普通に配るDPを貰うDPに変換して累積和に帰着させる。

順番を適当に変えて貰うDPにしたら、こんな感じ。
https://atcoder.jp/contests/arc104/submissions/17180237

あとはqueueとmodをうまい事使って先頭K+1個の累積和をする。

int N, K, M;
const int MA = 251010;
int dp[101][MA];
//---------------------------------------------------------------------------------------------------
deque<int> que[101];
int tot[101];
void _main() {
    cin >> N >> K >> mod;
    
    dp[0][0] = 1;
    rep(i, 0, 100) {
        rep(m, 0, i + 1) {
            que[m].clear();
            tot[m] = 0;
        }
        rep(sm, 0, MA) {
            if (que[sm % (i + 1)].size() == K + 1) {
                int tp = que[sm % (i + 1)].front(); que[sm % (i + 1)].pop_front();
                chsub(tot[sm % (i + 1)], tp);
            }
            que[sm % (i + 1)].push_back(dp[i][sm]);
            chadd(tot[sm % (i + 1)], dp[i][sm]);
            chadd(dp[i + 1][sm], tot[sm % (i + 1)]);
        }
    }

    rep(x, 1, N + 1) {
        int lft = x - 1, rht = N - x;
        int ans = 0;
        rep(sm, 0, MA) chadd(ans, mul(dp[lft][sm], dp[rht][sm], K + 1));
        chsub(ans, 1);
        printf("%d\n", ans);
    }
}