はまやんはまやんはまやん

hamayanhamayan's blog

ファイティング・タカハシ [COLOCON -Colopl programming contest 2018- Final A]

https://beta.atcoder.jp/contests/colopl2018-final-open/tasks/colopl2018_final_a

解法

https://beta.atcoder.jp/contests/colopl2018-final-open/submissions/1996716

考え方は簡単だが、実装が難しい。
まずf(x)を定義しておこう。
f(x) := x回連続で攻撃した場合のダメージ
これは「f(x) = 1 + 2 + 3 + ... + x = x * (x + 1) / 2」である。
M=S.length()と定義しておく。
 
N回Sを結合して攻撃する場合に問題となるのが、Sの最初と最後に攻撃がある場合である。
最初と最後に攻撃があれば、連続の攻撃回数が増えるのでダメージ量が増える。
実際の計算に入る前に攻撃の連続回数を取り出しておこう。
これを配列vとする。
 
まずは各種コーナーケースを処理しておこう。
全て攻撃の場合は攻撃がMN回連続するのでf(MN)が答え。
N=1の場合はそのまま攻撃をシミュレートすればいい。つまりf(v[0])+f(v[1])+...が答え
Sの最初と最後のどちらかが攻撃ではない場合は、合わせてもダメージ量への変化が無いのでN*(f(v[0])+f(v[1])+...)が答え

最後に結合の場合を考えよう。
配列vのサイズをnとする。
int n = v.size();
最初の1つは最後の攻撃以外をまず答えに足す。
rep(i, 0, n - 1) ans += f(v[i]);
それ以降の最初の攻撃は前のSの最後の攻撃の一部と考えるために攻撃回数を増やす
v[0] += v[n - 1];
S一回分を計算して、これがN-2回繰り返す(最初と最後以外の回数)
ll sm = 0;
rep(i, 0, n - 1) sm += f(v[i]);
ans += sm * (N - 2);
最後の1回を計算して足し合わせて答え
rep(i, 0, n) ans += f(v[i]);

ll N; string S; int M;
//---------------------------------------------------------------------------------------------------
ll f(ll x) { return x * (x + 1) / 2; }
//---------------------------------------------------------------------------------------------------
void _main() {
    cin >> N >> S;
    M = S.length();
 
    if (S.find('B') == string::npos) {
        // 全てA
        ll ans = f((ll)M * N);
        cout << ans << endl;
        return;
    }
 
    vector<ll> v;
    S += "B";
    int pre = -1;
    rep(i, 0, M + 1) if (S[i] == 'B') {
        if (0 < i - pre - 1) v.push_back(i - pre - 1);
        pre = i;
    }
 
    //fore(i, v) printf("%d\n", i);
 
    ll ans = 0;
    if (S[0] == 'A' and S[M - 1] == 'A') {
        if (N == 1) {
            fore(i, v) ans += f(i);
        } else {
            int n = v.size();
            rep(i, 0, n - 1) ans += f(v[i]);
            v[0] += v[n - 1];
            ll sm = 0;
            rep(i, 0, n - 1) sm += f(v[i]);
            ans += sm * (N - 2);
            rep(i, 0, n) ans += f(v[i]);
        }
    }
    else {
        ll sm = 0;
        fore(i, v) sm += f(i);
        ans = sm * N;
    }
 
    cout << ans << endl;
}