https://atcoder.jp/contests/abc123/tasks/abc123_d
解説1:全部二分探索
https://atcoder.jp/contests/abc123/submissions/4870889
どこから初めていいか分からないと思うが、よくある問題として以下のものがある。
「ある条件でソートしたときのK番目の数を求めよ」
この問題への取り組み方として、答えの数で二分探索するというものがある。
この方針から考えてみることにする。
K番目のケーキの美味しさの合計がわかれば、それより小さい美味しさの合計はK個未満なので、
高速に列挙できそう。
なので、K番目のケーキの美味しさを求めることにしよう。
ここで二分探索する。
check(x) := 美味しさの合計がx以上の個数がK以上か
を使おう。この二分探索で得られたokが、K番目のケーキの美味しさの合計となる。
checkの実装は、2と3の形のキャンドルの全組合せを、あらかじめ、とある配列Bに入れて、ソートしておく。
チェック時はAを全探索して、lower_boundでx-A[i]を探せば、2と3の合計の境界線が分かるので、組合せが得られる。
これの総和がK以上かどうかを見ればいい。
答えを作るときは、ok未満の美味しさの合計を作り、K個に足りない分はokを入れて答える。
okも作ってしまうと、okの個数が多いときに対応できないので、ok未満のみ作る。
int N[3], K; ll A[3][1010]; //--------------------------------------------------------------------------------------------------- vector<ll> B; int BN; int check(ll x) { ll cnt = 0; rep(i, 0, N[0]) { int id = lower_bound(all(B), x - A[0][i]) - B.begin(); cnt += BN - id; } return K <= cnt; } //--------------------------------------------------------------------------------------------------- void _main() { rep(i, 0, 3) cin >> N[i]; cin >> K; rep(i, 0, 3) rep(j, 0, N[i]) cin >> A[i][j]; rep(a, 0, N[1]) rep(b, 0, N[2]) B.push_back(A[1][a] + A[2][b]); sort(all(B)); BN = B.size(); ll ok = 0, ng = infl; while (ok + 1 != ng) { ll md = (ng + ok) / 2; if (check(md)) ok = md; else ng = md; } reverse(all(B)); vector<ll> ans; rep(i, 0, N[0]) { fore(b, B) { if (ok < A[0][i] + b) ans.push_back(A[0][i] + b); else break; } } while (ans.size() < K) ans.push_back(ok); sort(all(ans), greater<ll>()); rep(i, 0, K) printf("%lld\n", ans[i]); }
解説2: 一部、枝刈り全探索
https://atcoder.jp/contests/abc123/submissions/4870985
解説を見て、面白かったので紹介しておく。
check関数は解説1だと二分探索していたが、枝刈り全探索でも間に合う。
具体的な実装は以下の通り。
1. A,B,Cを降順ソート
2. 順番に見るときに総和がx未満となったらbreak(降順ソートなので、その後は絶対x未満)
3. 個数をカウントするときにcntがK個以上になったらreturn
これだけの改善でO(N^3)がO(N^2+K)に落ちる。
なぜこのような計算量になるかというと、
A,B,Cの全探索の過程で、
- 「実装2.」はAのサイズ×Bのサイズ回数だけ起きる
- しかもそれ以外は「実装3.」のチェックが走る
- 「実装3.」のチェックは最高K回しか行われない
となるので、ちょっと枝を刈るだけで計算量改善がなされる。
これはansの構築のときでも使える。
int N[3], K; ll A[3][1010]; //--------------------------------------------------------------------------------------------------- inline ll get(int a, int b, int c) { return A[0][a] + A[1][b] + A[2][c]; } int check(ll x) { int cnt = 0; rep(a, 0, N[0]) rep(b, 0, N[1]) rep(c, 0, N[2]) { if (get(a, b, c) < x) break; cnt++; if (K <= cnt) return 1; } return 0; } //--------------------------------------------------------------------------------------------------- void _main() { rep(i, 0, 3) cin >> N[i]; cin >> K; rep(i, 0, 3) rep(j, 0, N[i]) cin >> A[i][j]; rep(i, 0, 3) sort(A[i], A[i] + N[i], greater<ll>()); ll ok = 0, ng = infl; while (ok + 1 != ng) { ll md = (ng + ok) / 2; if (check(md)) ok = md; else ng = md; } vector<ll> ans; rep(a, 0, N[0]) rep(b, 0, N[1]) rep(c, 0, N[2]) { if (get(a, b, c) < ok) break; ans.push_back(get(a, b, c)); } while (ans.size() < K) ans.push_back(ok); sort(all(ans), greater<ll>()); rep(i, 0, K) printf("%lld\n", ans[i]); }