Loj 10172(三进制状压DP)

Loj 10172
题意:给出一个矩阵,有一行涂上了已经颜色,有1 2 3三种颜色,求涂完能有多少种方案使得没有两个相邻的格子颜色是相同的。
对于$k$行填色的限制,我们不妨把$k$行上下分开做,最后方案数乘一下即可。
这个题目是很经典的网格图状压DP,不多说了。关键说说多进制状压的做法。
这里先考虑三进制状压:
我们相当于把一个整数当做三进制数来看,也就是说$(25)_{10}={221}_3$,我们就维护这个$221$。
我们开一个数组sjc[状态][位]=位上的值 (0, 1, 2),用来找一个状态每一个位上是什么
这个数组可以用十进制转三进制的方法求得,具体看代码。
然后注意本题要特判$k=1,k=n$,并且只有一行的情况。

知识点:
1、三进制状态压缩 DP
2、数据检查要考虑只有一行一列的情况

#include<cstdio> 
#include<cstring>
#include<algorithm>
#include<iostream>
#include<vector>
#include<map>
#include<cmath>
#define ms(i, j) memset(i, j, sizeof i)
#define LL long long
#define db double
#define fir first
#define sec second
#define mp make_pair
using namespace std;

namespace flyinthesky {

    const int MO = 1000000;

    int n, m, k, klne[10], smax, sjz[300][10], st[300], cnt, dp[10000][300];
    //sjc[状态][位]=位上的值 (0, 1, 2) 
    //st[index]=可行状态, cnt为大小 
    //dp[i][j] 为前 i 行,i 行用 st[j]状态的方案数。 

    bool check(int i, int j) {// st[i] and st[j]
        for (int o = 1; o <= m; ++o) if (sjz[st[i]][o] == sjz[st[j]][o]) return false;
        return true;
    }

    void clean() {
        ms(sjz, 0);
    }
    int solve() {
        scanf("%d%d%d", &n, &m, &k);
        for (int i = 1; i <= m; ++i) scanf("%d", &klne[i]), --klne[i];
        for (int i = 2; i <= m; ++i) if (klne[i] == klne[i - 1]) return printf("0\n"), 0;
        if (n == 1) return printf("1\n"), 0;
        clean();
        smax = 1;
        for (int i = 1; i <= m; ++i) smax *= 3;
        for (int S = 0; S < smax; ++S) {
            int tmp = S, ws = 1;
            while (1) {
                sjz[S][ws] = tmp % 3;
                tmp /= 3, ++ws;
                if (tmp == 0) break ;
            }
        }//求 sij 数组

        //                                            printf("smax=%d\n", smax);
        //                                            for (int S = 0; S < smax; ++S, puts("")) 
        //                                            for (int i = 1; i <= 5; ++i) printf("%d", sjz[S][i]);

        for (int S = 0; S < smax; ++S) {
            int fl = 0;
            for (int i = 2; i <= m; ++i) {
                if (sjz[S][i] == sjz[S][i - 1]) {
                    fl = 1; break;
                }
            }
            if (!fl) st[++cnt] = S;
        }

        //                                            for (int j = 1; j <= cnt; ++j, puts("")) 
        //                                            for (int i = 1; i <= m; ++i) printf("%d", sjz[st[j]][i]);

        LL ans = 0;

        if (k != 1) {
            ms(dp, 0);
            for (int i = 1; i <= cnt; ++i) {
                if (k == 2) {
                    int fl = 1;
                    for (int o = 1; o <= m; ++o) if (sjz[st[i]][o] == klne[o]) {fl = 0; break;}
                    if (!fl) continue ;
                }
                dp[1][i] = 1;
            }

            for (int hi = 1; hi < k; ++hi) {
                for (int i = 1; i <= cnt; ++i) { // i 行状态 
                    if (hi == k - 1) {
                        int fl = 1;
                        for (int o = 1; o <= m; ++o) if (sjz[st[i]][o] == klne[o]) {fl = 0; break;}
                        if (!fl) continue ;
                    }
                    for (int j = 1; j <= cnt; ++j) { // j 行状态 
                        if (!check(i, j)) continue ;
                        if (dp[hi - 1][j] == 0) continue ;
                        dp[hi][i] = (dp[hi][i] + dp[hi - 1][j]) % MO;
                    }
                }
            }
            for (int i = 1; i <= cnt; ++i) ans = (ans + dp[k - 1][i]) % MO;
            //                                        for (int i = 1; i <= cnt; ++i) printf("%d\n", dp[k - 1][i]);
            if (k == n) return printf("%lld\n", ans), 0;
        } else ans = 1ll;

        ms(dp, 0);

        for (int i = 1; i <= cnt; ++i) {
            int fl = 1;
            for (int o = 1; o <= m; ++o) if (sjz[st[i]][o] == klne[o]) {fl = 0; break;}
            if (!fl) continue ;
            dp[k + 1][i] = 1;
        }
        for (int hi = k + 2; hi <= n; ++hi) {
            for (int i = 1; i <= cnt; ++i) { // i 行状态 
                for (int j = 1; j <= cnt; ++j) { // j 行状态 
                    if (!check(i, j)) continue ;
                    if (dp[hi - 1][j] == 0) continue ;
                    dp[hi][i] = (dp[hi][i] + dp[hi - 1][j]) % MO;
                }
            }
        }

        LL tmp = 0;
        for (int i = 1; i <= cnt; ++i) tmp = (tmp + dp[n][i]) % MO;

        printf("%lld\n", (ans * tmp) % MO);

        return 0; 
    }
}
int main() {
    flyinthesky::solve();
    return 0;
}
/*
3 3
3
1 2 1

2 2 
2
2 3
*/
------ 本文结束 ------