はまやんはまやんはまやん

hamayanhamayan's blog

Find Path Union [CSAcademy #56 D]

https://csacademy.com/contest/round-56/task/find-path-union/

無限に続く完全二分木がある。
N個の数がある。
この数全てと頂点1を結ぶパスを着色する。
着色される辺は何本か。
 
ただし、TLとMLが少し厳しい

前提知識

解法1 本番通したLCA木解

全ての頂点とそのLCAを集めたLCA木を作って、その木上の距離を全て集めると答え。
LCA木が作れないとまず厳しいので、必要な関数を用意する。

count(x) := 数xが書いてある頂点の深さ(根1を0とする)
sort() := 頂点をpre-order順でソートする
lca(a,b) := aとbのlcaを求める

sort関数の比較関数内でcountを呼んでしまうとO(NlogNlog64)となってしまうので、事前計算するとO(N(logN + log64))とできる。
lca関数はダブリングを使ういつものアレ。
あとは、LCA木の隣接頂点間の距離を測って総和を答える。

typedef long long ll;
int N;
//---------------------------------------------------------------------------------------------------
int count(ll x) {
    int ok = 0, ng = 64;
    while (ok + 1 != ng) {
        int y = (ok + ng) / 2;
        if ((1LL << y) <= x) ok = y;
        else ng = y;
    }
    return ok;
}
ll lca(ll a, ll b) {
    ll x = a, y = b;

    int xx = count(x);
    int yy = count(y);

    if (xx < yy) y >>= yy - xx;
    if (xx > yy) x >>= xx - yy;

    int ng = -1, ok = 64;
    while (ng + 1 != ok) {
        int z = (ng + ok) / 2;
        if ((x >> z) == (y >> z)) ok = z;
        else ng = z;
    }
    return x >> ok;
}
vector<ll> v;
void sort() {
    vector<int> vv, v2;
    rep(i, 0, v.size()) {
        vv.push_back(i);
        v2.push_back(count(v[i]));
    }

    sort(vv.begin(), vv.end(), [&](int _a, int _b) {
        ll a = v[_a], b = v[_b];

        int aa = v2[_a];
        int bb = v2[_b];

        if (aa == bb) return a < b;
        if (aa < bb) {
            ll bbb = b >> (bb - aa);
            if (a != bbb) return a < bbb;
            else return true;
        }
        if (aa > bb) {
            ll aaa = a >> (aa - bb);
            if (aaa != b) return aaa < b;
            else return false;
        }
    });

    vector<ll> v3;
    rep(i, 0, v.size()) v3.push_back(v[vv[i]]);
    swap(v, v3);
}
//---------------------------------------------------------------------------------------------------
void _main() {
    cin >> N;

    
    int ok = 0;
    rep(i, 0, N) {
        ll x; cin >> x;
        if (x == 1) ok = 1;
        v.push_back(x);
    }
    if (!ok) {
        v.push_back(1);
        N++;
    }

    sort();
    rep(i, 0, N - 1) v.push_back(lca(v[i], v[i + 1]));
    sort();

    N = v.size();
    int ans = 0;
    rep(i, 0, N - 1) {
        ll x = v[i + 1];
        ll y = lca(v[i], v[i + 1]);
        ans += count(x) - count(y);
    }
    cout << ans << endl;

}

解法2 マージソートっぽくやる

多分以下の解法はマージソートのマージ部分を知っていないと理解できない。
先にそちらを理解しておく事をおすすめする(そんなに難しくないので)。
 
愚直にsetに突っ込んでいくと間に合わないが、2つのソート済み集合のソートとすれば線形でできるようになる。
最初に深さ毎に別々に格納しておき、深さ毎にソートする。
答えは頂点1以外の経由する頂点数と考えることができるため、各深さ毎に現れる頂点の数を足していけばいい。
 
深さが深い順に考えていく。
深さiをカウントしたら、全ての要素を切り捨てで半分にした集合preをまず作る。
これは既にソートされている配列に対して行うため、順番にやれば自然とpreもソート済み配列となる。
半分にすると、全て深さがi-1の数となるため、次の深さi-1の頂点を数える前に、集合preとv[i-1]をマージしてやる必要がある。
これは2つともソート済みであるため、マージソートのマージっぽく、尺取りっぽくやっていくと線形でマージできる。
この時についでに重複している数は省いておこう。
 
これを繰り返すことで答えが得られる。
古い数は消していかないとMLEするが、clear関数を呼ぶだけではfreeされないので、shrink_to_fit関数も使う。

typedef long long ll;
int N;
//---------------------------------------------------------------------------------------------------
int count(ll x) {
    int ok = 0, ng = 64;
    while (ok + 1 != ng) {
        int y = (ok + ng) / 2;
        if ((1LL << y) <= x) ok = y;
        else ng = y;
    }
    return ok;
}
//---------------------------------------------------------------------------------------------------
vector<ll> merge(vector<ll> &a, vector<ll> &b) {
    vector<ll> res;

    int aa = 0, bb = 0;
    int na = a.size(), nb = b.size();
    while (aa < na or bb < nb) {
        if (bb == nb) {
            if (res.empty()) res.push_back(a[aa]);
            else if (res.back() != a[aa]) res.push_back(a[aa]);
            aa++;
        } else if (aa == na) {
            if (res.empty()) res.push_back(b[bb]);
            else if (res.back() != b[bb]) res.push_back(b[bb]);
            bb++;
        } else {
            if (a[aa] < b[bb]) {
                if (res.empty()) res.push_back(a[aa]);
                else if (res.back() != a[aa]) res.push_back(a[aa]);
                aa++;
            } else {
                if (res.empty()) res.push_back(b[bb]);
                else if (res.back() != b[bb]) res.push_back(b[bb]);
                bb++;
            }
        }
    }

    return res;
}
//---------------------------------------------------------------------------------------------------
vector<ll> v[64];
void _main() {
    cin >> N;
    rep(i, 0, N) {
        ll x; cin >> x;
        v[count(x)].push_back(x);
    }
    
    rep(i, 0, 64) sort(v[i].begin(), v[i].end());

    int ans = 0;
    vector<ll> pre;
    rrep(i, 63, 1) {
        vector<ll> vv(v[i].begin(), v[i].end());
        v[i] = merge(pre, vv);
        ans += v[i].size();

        pre.clear();
        fore(x, v[i]) pre.push_back(x / 2);
        
        v[i].clear();
        v[i].shrink_to_fit();
    }
    cout << ans << endl;
}