はまやんはまやんはまやん

hamayanhamayan's blog

K-th element [第八回 アルゴリズム実技検定 G]

https://atcoder.jp/contests/past202109-open/tasks/past202109_g

前提知識

解説

https://atcoder.jp/contests/past202109-open/submissions/26603092

この問題は正直典型問題として知っていないと解くのは難しいように思う。
二分探索を利用することで解くことができる。
全てそうかは分からないが、以下から類題を探せるかもしれない。
https://drken1215.hatenablog.com/archive/category/k%E7%95%AA%E7%9B%AE%E3%82%92%E6%B1%82%E3%82%81%E3%82%8B

「K番目の要素」

このK番目の要素という題名がヒントになっていて、やや典型テクとしてK番目の要素を求めるなら二分探索みたいな
テクというか、よく出る手法というかがある。
あまりピンとこないかもしれない。
二分探索の比較関数を定義しよう。

比較関数

check(lim) := 数列の要素の中でlim以下のものの個数を集計したときに、K以上個あればtrue、さもなければfalse

ちょっとわかりにくいかもしれない。
少し例を使って説明してみよう。

check(0)は、数列の要素の中で0以下のものを集計していて、そんな要素はないので、かならずfalseになる。
check(∞)は、数列の要素の中で∞以下のものを集計していて、すべての要素が含まれるため、かならずtrueになる。
これはよくって…

例えば答えがansであった場合を考えてみよう。
check(ans)は、数列の要素の中でans以下のものを集計している。
K番目がansという定義で進めているので、数列の要素の中でans以下のものはK個はあるはずである。
よって、K以上個あるのでtrueになる。

check(ans - 1)を考えてみると、K番目がansであるはずなので、数列の要素の中でans-1以下の個数は
K個未満のはずである。
そのため、check(ans - 1)はfalseとなる。

二分探索的にはちょうど境界線上に答えが浮かび上がってくることになる!
なので、比較関数を高速にさばくことができれば、二分探索で答えを導くことができる。

lim以下の要素の個数

全ての数列それぞれに対して、lim以下の要素の個数を求めてその総和を取ればいい。
なので、とある数列に対してlim以下の要素が何個あるかについて高速に計算したい。
これはmin(1LL * A[i], 1LL + (lim - B[i]) / C[i])で求めることができる。

minを使っているのは個数の最大値はA[i]個なので、その上限を決めるために使っている。
なお、1LLを使っているのはlong longに型変換させるために使っている。明示キャストでも問題ない。

実際の個数計算は1LL + (lim - B[i]) / C[i]である。
例えば3 5 7 9 11という数列があって、10以下の個数を求める場合を考えてみる。
最初の3は含まれるとカウントして、全部の数から初項を引いてみよう。
0 2 4 6 8という数列で7以下の個数を求める問題になる。
これは丁度倍数の個数を求めているような感じになるので、7から公差である2を割ったときの商を使えばいい。
7÷2の商は3なので、初項の1つと合わせて1+3の4個が答え。
これを数式に落とし込むと、1LL + (lim - B[i]) / C[i]になる。

int N; ll K;
int A[101010], B[101010], C[101010];
//---------------------------------------------------------------------------------------------------
bool check(ll lim) {
    ll cnt = 0;
    rep(i, 0, N) if (B[i] <= lim) cnt += min(1LL * A[i], 1LL + (lim - B[i]) / C[i]);
    return K <= cnt;
}
//---------------------------------------------------------------------------------------------------
void _main() {
    cin >> N >> K;
    rep(i, 0, N) cin >> A[i] >> B[i] >> C[i];

    ll ng = 0, ok = infl;
    while (ng + 1 != ok) {
        ll md = (ng + ok) / 2;
        if (check(md)) ok = md;
        else ng = md;
    }

    cout << ok << endl;
}