這個題目去年就過了,用得是狀態(tài)壓縮dp,不過沒用dfs預(yù)處理,當(dāng)時做得不是很明白,還是參考網(wǎng)上的一個代碼做的。
現(xiàn)在重新做了一下這個題目,請教了icecream,學(xué)會了一個很簡練的做法,而且比較好理解還好寫。
首先還是狀態(tài)的表示,用0表示沒有放木塊,用1表示放了木塊。此外,對于一個橫放的木塊,對應(yīng)的兩位都用1表示;對于一個豎放的木塊,第一行用1表示,第二行用0表示。這只是一種設(shè)狀態(tài)的方式,當(dāng)然還有別的設(shè)法,但是這種方法在接下來就會發(fā)現(xiàn)優(yōu)點。
狀態(tài)表示完就要處理轉(zhuǎn)移了,如何判斷一個轉(zhuǎn)移是否合法比較難辦,用一個dfs卻可以簡潔的解決這個問題。
對于上一行到下一行的轉(zhuǎn)移,規(guī)定上一行一定填滿,這樣有三種方式:
dfs(col + 1, (s1 << 1) | 1, s2 << 1, n);
dfs(col + 1, s1 << 1, (s2 << 1) | 1, n);
dfs(col + 2, (s1 << 2) | 3, (s2 << 2) | 3, n);
第一種上面是1,那么下面一定是0,表示是一個豎放的木塊。
第二種上面是0,就是說這個位置一定是一個豎放木塊的下半截,那么下一行肯定是要另起一行了,放一個豎放或者橫放的木塊都必須是1。
第三種相當(dāng)于上下兩行各放一個橫木塊。
實現(xiàn)的時候我用了一個vector記錄每個狀態(tài)所有可行的轉(zhuǎn)移,這樣在dp的時候可以加快一些效率。
還有一個問題需要考慮,那就是初值和最終的結(jié)果。如果尋找合法狀態(tài),依然比較麻煩,假設(shè)共有n行,可以分別在這n行上下新加兩行。下面一行都是1,由于第n行肯定要填滿,這樣新加的全1的行就相當(dāng)于頂住了第n行使其沒有凸出(有凸出那么第n+1行有0),而通過第n行到第n+1行轉(zhuǎn)移保留了所有合法狀態(tài);同理最上面加的那行保證第一行沒有凸出。最后第n+1行對應(yīng)全1的狀態(tài)就是最終的結(jié)果了。通過新加行巧妙地解決了初值和終值。
實現(xiàn)的時候也需要注意一下,在TSP問題中,外層循環(huán)是狀態(tài),內(nèi)層是點,之所以這樣寫因為在枚舉點的時候,可能會從比當(dāng)前編號大的點轉(zhuǎn)移,但是由于無論怎樣轉(zhuǎn)移過來的狀態(tài)肯定比當(dāng)前狀態(tài)小(去掉了1),所以先從小到大枚舉狀態(tài)就保證轉(zhuǎn)移過來的狀態(tài)一定是算過的。而這個題目里面正好反過來,因為狀態(tài)可能從比當(dāng)前狀態(tài)大的狀態(tài)轉(zhuǎn)移過來,而行數(shù)肯定是從編號小的行轉(zhuǎn)移,因此先枚舉行就能保證轉(zhuǎn)移過來的狀態(tài)一定是更新過的。
#include <cstdio>
#include <vector>
#include <algorithm>
using namespace std;
const int N = 11;
vector<int> g[1<<N];
long long dp[N+2][1<<N];
void dfs(int col, int s1, int s2, int n)
{
if (col >= n)
{
if (s1 < (1 << n) && s2 < (1 << n))
g[s2].push_back(s1);
return;
}
dfs(col + 1, (s1 << 1) | 1, s2 << 1, n);
dfs(col + 1, s1 << 1, (s2 << 1) | 1, n);
dfs(col + 2, (s1 << 2) | 3, (s2 << 2) | 3, n);
}
long long calc(int m, int n)
{
if (m < n) swap(m, n);
dfs(0, 0, 0, n);
int state = 1 << n;
dp[0][0] = 1;
for (int i = 1; i <= m + 1; i++)
for (int s = 0; s < state; s++)
for (int j = 0; j < g[s].size(); j++)
dp[i][s] += dp[i-1][g[s][j]];
return dp[m+1][state-1];
}
int main()
{
int m, n;
while (scanf("%d %d", &m, &n) == 2)
{
if (m == 0 && n == 0)
break;
for (int i = 0; i < (1 << N); i++)
g[i].clear();
memset(dp, 0, sizeof(dp));
printf("%I64d\n", calc(m, n));
}
return 0;
}