はまやんはまやんはまやん

hamayanhamayan's blog

門松宝くじ [yukicoder 335]

問題

http://yukicoder.me/problems/no/335

宝くじがM枚あり、以下のように当選金額を決める。
何枚目の宝くじの期待値が最も大きいか答える(複数ある場合は最小の番号)

長さNの1~Nが1つずつある数列Eがある。この時、

  • 数列中から2つの数が指定される
  • 3つの数が門松列となるように、もう1つの数を好きに選んでも良い
  • この時の当選金額は3つの数の最大値になる

門松列とは、真ん中の数が最も大きいか、最も小さいかである数列のこと

3 <= N <= 800
2 <= M <= 3

考察

1. 愚直解を考えてみる。全通りチェックして、当選金額とその場合の数を数える
2. ある宝くじについて、指定される全ての場合の数は N^2 = 640000 通りある
3. 全ての場合の数を検証しても間に合う解法がありそう
4. 数が2つ指定されたときに、自分が選ぶ数は N-2 通りあるので、愚直解で作ると O(N^3)

これはだめだけど、ここから計算量を落とすように頑張る

5. 数が2つ指定されれば、もう1つの選択は貪欲に取れる
6. 貪欲に取るには6通りの場合分けをする必要がある

[A](小)[B](大)[C] -> (大)を門松の最大にする -> [C]内での最小が(大)より小さい
                -> (小)を門松の最小にする -> [A]内での最大が(小)より大きい
[A](大)[B](小)[C] -> (大)を門松の最大にする -> [A]内での最小が(大)より小さい
                -> (小)を門松の最小にする -> [C]内での最大が(小)より大きい
[A](??)[B](??)[C] -> [B]を門松の最大にする -> [B]内での最大が(??)より大さい
                  -> [B]を門松の最小にする -> [B]内での最小が(??)より小さい

7. 区画の最大最小はセグメントツリーでやりましょう

実装

http://yukicoder.me/submissions/102717

typedef long long ll;
template<class V, int NV> class SegTreeMax {
public:
    static V const def = -(1LL << 60);
    V comp(V l, V r) { return max(l, r); };

    vector<V> val;
    SegTreeMax() { val = vector<V>(NV * 2, def); }

    V getval(int l, int r) { //[l,r]
        l += NV; r += NV + 1;
        V ret = def;
        while (l < r) {
            if (l & 1) ret = comp(ret, val[l++]);
            if (r & 1) ret = comp(ret, val[--r]);
            l /= 2; r /= 2;
        }
        return ret;
    }
    void update(int i, V v) {
        i += NV;
        val[i] = v;
        while (i>1) i >>= 1, val[i] = comp(val[i * 2], val[i * 2 + 1]);
    }
};
template<class V, int NV> class SegTreeMin {
public:
    static V const def = (1LL << 31) - 1;
    V comp(V l, V r) { return min(l, r); };

    vector<V> val;
    SegTreeMin() { val = vector<V>(NV * 2, def); }

    V getval(int l, int r) { //[l,r]
        l += NV; r += NV + 1;
        V ret = def;
        while (l < r) {
            if (l & 1) ret = comp(ret, val[l++]);
            if (r & 1) ret = comp(ret, val[--r]);
            l /= 2; r /= 2;
        }
        return ret;
    }
    void update(int i, V v) {
        i += NV;
        val[i] = v;
        while (i>1) i >>= 1, val[i] = comp(val[i * 2], val[i * 2 + 1]);
    }
};
//-----------------------------------------------------------------
int N, M;
int E[3][800];
//-----------------------------------------------------------------
double calc(int m) {
    SegTreeMax<ll, 1 << 10> st_max;
    SegTreeMin<ll, 1 << 10> st_min;

    rep(i, 0, N) st_max.update(i, E[m][i]);
    rep(i, 0, N) st_min.update(i, E[m][i]);

    int cnt[805];
    rep(i, 0, N + 1) cnt[i] = 0;
    rep(i, 0, N) rep(j, i + 1, N) {
        int _max = max(E[m][i], E[m][j]);
        int _min = min(E[m][i], E[m][j]);

        int idx = 0;

        if (j != N - 1) {
            if (E[m][i] < E[m][j]) {
                int _min_st = st_min.getval(j + 1, N - 1);
                if (_min_st < E[m][j]) idx = max(idx, E[m][j]);
            }
            else {
                int _max_st = st_max.getval(j + 1, N - 1);
                if (E[m][j] < _max_st) idx = max(idx, max(E[m][i], _max_st));
            }
        }

        if (i != 0) {
            if (E[m][i] < E[m][j]) {
                int _max_st = st_max.getval(0, i - 1);
                if (E[m][i] < _max_st) idx = max(idx, max(E[m][j], _max_st));
            }
            else {
                int _min_st = st_min.getval(0, i - 1);
                if (_min_st < E[m][i]) idx = max(idx, E[m][i]);
            }
        }

        if (j - i != 1) {
            int _max_st = st_max.getval(i + 1, j - 1);
            int _min_st = st_max.getval(i + 1, j - 1);

            if (_max < _max_st) idx = max(idx, _max_st);
            if (_min_st < _min) idx = max(idx, _max);
        }

        cnt[idx]++;
    }

    int sum = 0;
    rep(i, 0, N + 1) sum += cnt[i];

    double ret = 0;
    rep(i, 0, N + 1) ret += (double)i * ((double)cnt[i] / (double)sum);
    return ret;
}
//-----------------------------------------------------------------
int main()
{
    scanf("%d %d", &N, &M);
    rep(i, 0, M) rep(j, 0, N) scanf("%d", &E[i][j]);

    int ans = 0;
    double ans_p = -1;
    rep(i, 0, M) {
        float p = calc(i);
        if (ans_p < p) {
            ans = i;
            ans_p = p;
        }
    }

    printf("%d\n", ans);
}