はまやんはまやんはまやん

hamayanhamayan's blog

Equal Queries [第六回 アルゴリズム実技検定 M]

https://atcoder.jp/contests/past202104-open/tasks/past202104_m

解説

https://atcoder.jp/contests/past202104-open/submissions/22660709

最初に書いておくと実装がとてもしんどいのでデバッグ用出力を駆使しながら実装していく必要がある。
さて、頑張っていこう。

どこから考え始めるか

似たような問題を解いたことがあるので、なんとか考え始めることができたが、そうでないと難しいだろう。
この問題のポイントは区間を1つの値に置き換えるという部分であり、1つの値にそれぞれ置き換えるのではなく、
「その値の区間」として保持しておくことで情報を圧縮して処理を進めていく。
これで処理済みの部分は区間として情報が圧縮されることで数が減っていくし、かつ、追加される区間の数はクエリの個数が上限なので、
区間を処理していく個数は慣らしでN+Q回くらいで済むことになる。
この辺の考察をひねりだすのが難しい。

区間を管理する

この区間を管理するというのは競技プログラミングでは良く出てくるテーマである。
良く出てくるし、毎回、実装で大爆発する。
自分はsetで実装することが多いし、他の人のコードもsetを使っているやつを見たことがある。
自分の今回の実装ではsetに{ {区間の左端, 区間の右端}, 区間の値 }を入れている。
これを入れておけばset内部では自動でソートされて、lower_boundで目的の区間を高速に取り出すことができる。
lower_boundをうまく使い、クエリの[L,R]を含む区間を抽出してきて、処理を行い、新たな区間を入れなおす。

クエリが無い場合

さて、発想のベースが整った所で、本題に入っていこう。
クエリが無い場合の整数の組を計算していこう。
cnt[x] := A[i]=xであるiの個数
これを更新していきながら、整数の組を計算していく。
このアルゴリズムが構築できないと、解ききるのは難しいかもしれない。

クエリ計算

差分計算していく。
作業としては、以下の通り。
1. クエリの区間[L,R]に含まれる区間を全部抽出してくる(この時、両端の半端な区間は分割して、setに戻しておく)
2. 抽出してきた区間をsetから消して、「その区間が関連する整数の組を引く」
3. クエリで追加したい区間をsetに追加して、「その区間が関連する整数の組を足す」

その区間が関連する整数の組を計算する方法について考える。
その区間の数がxで、区間の長さをlenとする。
「その区間以外に数xがある個数と区間の長さ(=区間にある数xの個数)の積」と「区間内での整数の組」がそれになる。
「その区間以外に数xがある個数と区間の長さ(=区間にある数xの個数)の積」は(cnt[x]-len)lenであり、
区間内での整数の組」はlen
(len-1)/2である。
これを足したり、引いたりして差分計算する。

デバッグが大変なので、デバッグ用の出力関数を用意しておくといい。

最後に座標圧縮

数の個数を数えるにあたって入力の最大が109なのは具合が悪い。
数の大小は特に問題を解くうえでは必要ないので、クエリを先読みして、Aとxをすべて座標圧縮しておこう。

あとは頑張って…

int N, A[201010], Q, l[201010], r[201010], x[201010];
int cnt[401010];
//---------------------------------------------------------------------------------------------------
void debug(set<pair<pair<int, int>, int>>& se) {
    printf("======================\n");
    fore(p, se) {
        printf("[%d %d) -> %d\n", p.first.first, p.first.second, p.second);
    }
}
//---------------------------------------------------------------------------------------------------
void _main() {
    cin >> N;
    rep(i, 0, N) cin >> A[i];
    cin >> Q;
    rep(i, 0, Q) cin >> l[i] >> r[i] >> x[i], l[i]--;

    vector<int> dic;
    rep(i, 0, N) dic.push_back(A[i]);
    rep(i, 0, Q) dic.push_back(x[i]);
    sort(all(dic));
    dic.erase(unique(all(dic)), dic.end());
    rep(i, 0, N) A[i] = lower_bound(all(dic), A[i]) - dic.begin();
    rep(i, 0, Q) x[i] = lower_bound(all(dic), x[i]) - dic.begin();
    
    set<pair<pair<int,int>, int>> ranges;
    rep(i, 0, N) ranges.insert({ {i, i + 1}, A[i] });
    ranges.insert({ {N, inf}, inf });

    ll ans = 0;
    rep(i, 0, N) {
        ans += cnt[A[i]];
        cnt[A[i]]++;
    }

    //debug(ranges);

    rep(i, 0, Q) {
        vector< pair<pair<int, int>, int> > nxtInsert;

        while (1) {
            auto ite = ranges.lower_bound({ {l[i], inf}, inf });
            if (r[i] <= ite->first.first) break;

            int a = ite->first.first;
            int b = ite->first.second;

            if (r[i] <= b) {
                nxtInsert.push_back({ {r[i], b}, ite->second });
                b = r[i];
            }

            int cn = b - a;
            cnt[ite->second] -= cn;
            ans -= 1LL * cnt[ite->second] * cn;
            ans -= 1LL * cn * (cn - 1) / 2;

            ranges.erase(ite);
        }

        {
            auto ite = ranges.lower_bound({ {l[i], inf}, inf });
            ite--;

            int a = ite->first.first;
            int b = ite->first.second;

            if (a < l[i]) {
                nxtInsert.push_back({ {a, l[i]}, ite->second });
                a = l[i];
            }

            if (r[i] <= b) {
                nxtInsert.push_back({ {r[i], b}, ite->second });
                b = r[i];
            }

            int cn = b - a;
            cnt[ite->second] -= cn;
            ans -= 1LL * cnt[ite->second] * cn;
            ans -= 1LL * cn * (cn - 1) / 2;

            ranges.erase(ite);
        }

        fore(p, nxtInsert) if (p.first.first < p.first.second) ranges.insert(p);
        ranges.insert({ {l[i], r[i]}, x[i] });

        
        ans += 1LL * cnt[x[i]] * (r[i] - l[i]);
        ans += 1LL * (r[i] - l[i]) * (r[i] - l[i] - 1) / 2;
        cnt[x[i]] += r[i] - l[i];

        printf("%lld\n", ans);
        //debug(ranges);
    }
}