メモ化DPするときは必ず値を更新しようというお話

考えてみれば当然だけど、何回かやらかしているので自戒を込めて 本記事には以下の問題のネタバレが含まれます。

競プロ典型90問 45日目 Simple Grouping(★6)

TL; DR

  • メモ化再帰で初期値から更新されない値がある場合、メモ化配列を探索済みとして更新する処理を入れる。

例題

競プロ典型90問 45日目 Simple Grouping(★6)

ユークリッド平面上に、N個の点 (X1, Y1), (X2, Y2), ⋯, (XN ,YN)があります。 これらを、次の条件を満たすようにK個のグループに分けることを考えます。
・複数のグループに入る点があってはならない。
・どのグループにも属さない点があってはならない。
・ひとつも点が属さないグループがあってはならない。
このとき、同一グループ内での2点間距離の最大値を最小化してください。 ただし、ジャッジにはこのときの同一グループ内での2点間距離の最大値の2乗を出力してください。
なお、このときの最大値を2乗した値は必ず整数になることが証明できます。

解法

方針

点の集合としてN個の点のある部分集合をとり、グループ数Kを1減らした小問題を考えます。 このような全ての小問題について、その解と「選ばなかった点」を1グループとしたときのグループ内の2点間距離の最大値を考えると、元の問題を解くことができます。 また、K=1については、全ての点同士の距離を求めることで解くことができるので、上の操作を再帰的に繰り返すことで元の問題を解くことができます。

計算量

bit DPをメモ化再帰で実装すると、状態の数が2^ N K 、状態間の遷移は各部分集合の部分集合の数で O(2^ N) なので、ぱっと見 O(4^ N K) \simeq 10^ {10} くらいで間に合わなそうです。 ただ、実際には部分集合のサイズが小さくなればそれに応じてその部分集合の数も少なくなるので、さぼって 2^ N \times 2^ Nとしている部分を真面目に計算すると、


\sum_{k=0}^ N {}_N C_k 2^ k = 3^N

となり、 O(3^ NK) \simeq 2 * 10^ 8でなんとか間に合います。 上記の導出はO(3N)で部分集合の列挙をする実装を参照しました。

実装

というわけで、とりあえず実装します。

実装例 (TLE)

#include <bits/stdc++.h>
using namespace std;

// テンプレート、マクロ
struct fast_io { fast_io() { cin.tie(nullptr); ios::sync_with_stdio(false); cout << fixed << setprecision(18); } } fastIOtmp;
using ll = long long; ll INFL = 3300300300300300491LL;
#define rep(i, n) for(int i = 0, i##_len = int(n); i < i##_len; ++i) // 0 から n-1 まで昇順
template <class T> inline bool chmax(T& M, const T& x) { if (M < x) { M = x; return true; } return false; } // 最大値を更新(更新されたら true を返す)
template <class T> inline bool chmin(T& m, const T& x) { if (m > x) { m = x; return true; } return false; } // 最小値を更新(更新されたら true を返す)

template<class T> T euclid_square(pair<T, T> s, pair<T, T> t) {
    auto square = [](T a) {return a*a;};
    return square(s.first - t.first) + square(s.second - t.second);
};

int main() {
    // 入力
    int n,k;
    cin >> n >> k;
    vector<pair<ll, ll>> p(n);
    rep(i, n) cin >> p[i].first >> p[i].second;

    // 各部分集合について、含まれる2点間の距離の最大値を前計算
    vector<ll> dist(1 << n, 0);
    rep(b, 1 << n) {
        rep(i, n) rep(j, n) {
            if (((b >> i) & 1) && ((b >> j) & 1)) chmax(dist[b], euclid_square(p[i], p[j]));
        }
    }

    // bit DP (メモ化再帰)
    vector dp(1 << n, vector<ll>(k + 1, INFL));  // メモ化配列
    auto dfs = [&] (auto self, int b, int k_) {
        if (k_ == 1) dp[b][k_] = dist[b];  // k = 1のときは計算して返す
        if (dp[b][k_] != INFL) return dp[b][k_];  // 値が更新されていれば返す
        for (ll b_from = b; b_from != 0; b_from = (b_from - 1) & b) {  // 部分集合全体に渡って繰り返し
            ll cur = max(dist[b_from], self(self, b - b_from, k_-1));
            chmin(dp[b][k_], cur);
        }
        return dp[b][k_];
    };
    cout << dfs(dfs, (1 << n) - 1, k) << endl;
}

これで、メモ化しているので、計算量は 3^ N K...と思って提出するとTLEします。

何が問題なのか

このDPでは、部分集合を管理するbit  bに含まれる1の個数がkより小さいときには、dp[b][k]の値は最大値(INFL)から更新されません。 これは、元の問題に置き換えると「ひとつも点が属さないグループがあってはならない」という条件を満たすようなグループ分けが存在しないためです。

メモ化でないDPをしているときには、特にこれは気にする必要はないのですが、メモ化DPをしている場合、このような点をきちんと探索済みとしてマークしておかなければ複数回同じ点を探索されてしまいます。

AC解答

同じ点を複数回探索することを避けるために、元のコードで探索した結果値が更新されなかった場合にも、値を更新するようにしておきます。 これは、十分に大きな値であって、INFLとは異なる値にすればよいだけなので、例えば(INFL - 1)などとしておけば良いです。

ということで、再帰に以下の1行を追加します。

実装例 (AC、再帰関数部分のみ)

    // bit DP (メモ化再帰)
    vector dp(1 << n, vector<ll>(k + 1, INFL));  // メモ化配列
    auto dfs = [&] (auto self, int b, int k_) {
        if (k_ == 1) dp[b][k_] = dist[b];  // k = 1のときは計算して返す
        if (dp[b][k_] != INFL) return dp[b][k_];  // 値が更新されていれば返す
        for (ll b_from = b; b_from != 0; b_from = (b_from - 1) & b) {  // 部分集合全体に渡って繰り返し
            ll cur = max(dist[b_from], self(self, b - b_from, k_-1));
            chmin(dp[b][k_], cur);
        }
        /* 以下の1行を追加 */
        if (dp[b][k_] == INFL) dp[b][k_]--;  // 値が更新されていなければ、-1して探索済みをマークしておく
        return dp[b][k_];
    };

上記でACできます。

まとめ

  • メモ化再帰で初期値から更新されない値がある場合、メモ化配列を探索済みとして更新する処理を入れる。
  • これは、0で初期化するような場合も同じ(初期化を-1として、探索済みを0とするなど)

遷移考えるのが面倒なので、と甘えて再帰で書かずに、ちゃんとループで書こうね、という話もあったりなかったり