ABC379 G - Count Grid 3-coloring (2nd)

(G题之间亦有差距)

相关算法 : dp,枚举,复杂度估计

Part 1.题目概述

一个 \(n \times m\) 的矩阵,里面只有 \(1,2,3,?\),其中 \(?\) 为不确定的数。问有多少种可能使矩阵相邻元素不相同。其中 \(n\cdot m <= 200\)

Part 2.分析

我们考虑从上到下从左到右逐格dp来确定每格的不同状态可能的情况。

dp设计

\(dp[i][val]\) 为填到第 \(i\) 个格子值为 \(val\) 时的状态数。 首先我们思考:如果按照上面所述的顺序填入数字,那么每格的状态和哪些格子直接相关呢?(这关乎我们状态的设计和转移的方式,dp无后效性的重要性质也在这里体现)
显然,答案是左邻格和上邻格(如果有的话)。那么我们得知这两个一起的状态时,就可以转移到目前格的状态。不过重要的是"一起的状态",也就是一个键值。这里是不能分开的(因为我们dp的状态含义本身就是填到第多少格为止的状态,分开考虑显然会将前面的部分重叠)。
但是,如何得到每个格子对应的键值来连续地去转移来dp呢?我们需要在dp时去记录一些有必要的信息。 我们是连续填入的,所以每个格子需要的dp信息要在上一个里面,而每个格子需要的最早的信息是上邻格的状态(与左邻格状态的组合)。再稍加归纳,我们需要也只需要记录一整行的状态即可。转移起来则是舍去上邻格状态,加入本格状态,转移到下一格。

dp优化

我们再考虑dp的优化。
首先,第一维是可以滚动掉的,所以我们只需要一个数组记录每行的合法状态,最多 \(9\cdot 2^m\) 个。每次转移起来则需要\(O(m)\)
1. 对于状态数,\(2^m\)可能达到一个很大的值,但是\(\min(n,m)\)是小于\(\sqrt {200}\)的。因为我们是以一行的状态进行转移,考虑到对称性,我们可以选择数值较小的一维作为行(即 \(n < m\)时进行一个转置操作),把我们每行的状态数维持在一个可控的状态。
2. 对于每次转移的开销,\(m\)虽然不大,但是\(O(m)\)的转移仍有tle的风险。虽然我们每次修改头尾,可以使用队列一类的数据结构,但是实际上我们有更好的选择——状态压缩。
由于每个格点只有\(0\)(我们保持dp状态都为一整行的状态,所以最开始全部设置为0)\(,1,2,3\)四种状态,用两位进制数即可表示,所以我们可以把一行最多14个数的状态压缩成一个unsigned int。转移和check的时候使用位运算即可很方便地实现\(O(1)\)的转移。
3. 最后,我们还需要来记录所有的状态,实现滚动。map<uint,Z>是一个很好的选择,毕竟键值是离散的。但是实际上还可以优化。
考虑我们在转移时枚举的状态。不同前状态转移后的状态相等,当且仅当这两个前状态仅有上邻格状态是不同的。如果我们外层枚举前状态,内层枚举该格的值,那么两个相同的后状态是可能由两个相同的前状态得来的。但是如果外层枚举该格的值,内层枚举前状态,那么两个相同的后状态,一定由两个相邻的前状态枚举出来! 这里比较抽象,可以类比想象对字符串做类似的操作。对一个有序的字符串集合\(S\),初始只有一个有长度的零串。如果我们有序遍历\(S\),删除前缀并按字典序添加后缀,如此循环往复几轮\(S\)也是混乱无序的。因为有序添加的后缀是在连续变动的,是内层的序,后续状态也会是不连续的。但是如果我们有序选择添加的后缀,遍历集合进行操作,那么连续操作后,\(S\)仍然是按照一个倒的字典序排列的,因为我们所期望的序在外层循环,是不连续变动的,而连续变动的是我们即将删掉的部分,不会影响。
总的来说,就是将即将删除的位置作为内层循环,离删除最远的,新的位置作为外层循环,以实现枚举状态的连续性。
因此,我们可以使用后一种枚举方式,而将map改成vector,对于得到的合法状态,我们只需判断是否等于最后一个元素,进行相应的操作即可。这样我们就把\(O(\Omega \log{\Omega})\)的枚举优化成了\(O(\Omega)\).

Part 3 实现

实际上不需要第三条优化,即可通过本题。但是第三条优化将1400+ms优化为42ms,或许值得仔细思考这种遍历顺序选取的技巧。 code:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
void solve()
{
int n, m;
cin >> n >> m;
vector<string>a(n);
for (int i = 0; i < n; i++) {
cin >> a[i];
}

if (n < m) {
vector<string>na(m, string(n, 0));
for (int i = 0; i < m; i++) {
for (int j = 0; j < n; j++) {
na[i][j] = a[j][i];
}
}
a = move(na);
swap(n, m);
}

vector<pair<uint, Z>>dp;
//map<uint, Z>dp;
int r = (m - 1) * 2;
uint inf = (uint(1) << (m * 2)) - 1;
//dp[uint(0)] = 1;
dp.emplace_back(0, 1);
for (int i = 0; i < n; i++) {
for (int j = 0; j < m; j++) {
vector<pair<uint, Z>>ndp;
//map<uint, Z>ndp;
if (a[i][j] == '?') {
for (int val = 1; val <= 3; val++) {
for (auto [v, c] : dp) {
if ((i == 0 || ((v >> r) & 3) != val) && (j == 0 || (v & 3) != val)) {
uint nv = ((v << 2) & inf) | val;
if (!ndp.empty() && ndp.back().first == nv)
ndp.back().second += c;
else
ndp.emplace_back(nv, c);
//ndp[nv] += c;
}
}
}
}
else {
int val = a[i][j] - '0';
for (auto [v, c] : dp) {
if ((i == 0 || ((v >> r) & 3) != val) && (j == 0 || (v & 3) != val)) {
uint nv = ((v << 2) & inf) | val;
if (!ndp.empty() && ndp.back().first == nv)
ndp.back().second += c;
else
ndp.emplace_back(nv, c);
//ndp[nv] += c;
}
}
}
dp = move(ndp);
}
}

Z ans = 0;
for (auto [v, c] : dp)
ans += c;
cout << ans << endl;
return;
}
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
#define ll long long
#define ull unsigned long long
#define mod 998244353
using i64 = long long;
using uint = unsigned int;

using namespace std;

template<class T>
constexpr T power(T a, i64 b) {
T res{ 1 };
for (; b; b /= 2, a *= a) {
if (b % 2) {
res *= a;
}
}
return res;
}

constexpr i64 mul(i64 a, i64 b, i64 p) {
i64 res = a * b - i64(1.L * a * b / p) * p;
res %= p;
if (res < 0) {
res += p;
}
return res;
}

template<i64 P>
struct MInt {
i64 x;
constexpr MInt() : x{ 0 } {}
constexpr MInt(i64 x) : x{ norm(x % getMod()) } {}

static i64 Mod;
constexpr static i64 getMod() {
if (P > 0) {
return P;
}
else {
return Mod;
}
}
constexpr static void setMod(i64 Mod_) {
Mod = Mod_;
}
constexpr i64 norm(i64 x) const {
if (x < 0) {
x += getMod();
}
if (x >= getMod()) {
x -= getMod();
}
return x;
}
constexpr i64 val() const {
return x;
}
constexpr MInt operator-() const {
MInt res;
res.x = norm(getMod() - x);
return res;
}
constexpr MInt inv() const {
return power(*this, getMod() - 2);
}
constexpr MInt& operator*=(MInt rhs)& {
if (getMod() < (1ULL << 31)) {
x = x * rhs.x % int(getMod());
}
else {
x = mul(x, rhs.x, getMod());
}
return *this;
}
constexpr MInt& operator+=(MInt rhs)& {
x = norm(x + rhs.x);
return *this;
}
constexpr MInt& operator-=(MInt rhs)& {
x = norm(x - rhs.x);
return *this;
}
constexpr MInt& operator/=(MInt rhs)& {
return *this *= rhs.inv();
}
friend constexpr MInt operator*(MInt lhs, MInt rhs) {
MInt res = lhs;
res *= rhs;
return res;
}
friend constexpr MInt operator+(MInt lhs, MInt rhs) {
MInt res = lhs;
res += rhs;
return res;
}
friend constexpr MInt operator-(MInt lhs, MInt rhs) {
MInt res = lhs;
res -= rhs;
return res;
}
friend constexpr MInt operator/(MInt lhs, MInt rhs) {
MInt res = lhs;
res /= rhs;
return res;
}
friend constexpr std::istream& operator>>(std::istream& is, MInt& a) {
i64 v;
is >> v;
a = MInt(v);
return is;
}
friend constexpr std::ostream& operator<<(std::ostream& os, const MInt& a) {
return os << a.val();
}
friend constexpr bool operator==(MInt lhs, MInt rhs) {
return lhs.val() == rhs.val();
}
friend constexpr bool operator!=(MInt lhs, MInt rhs) {
return lhs.val() != rhs.val();
}
friend constexpr bool operator<(MInt lhs, MInt rhs) {
return lhs.val() < rhs.val();
}
};

template<>
i64 MInt<0>::Mod = 998244353;

using Z = MInt<0>;