数位dp入门详解
- 开源代码
- 2025-09-08 05:36:01

1. 介绍
数位 d p dp dp一般出现在来求一个范围 [ a , b ] [a, b] [a,b]内满足条件的数有多少。数位 d p dp dp的解决比较公式化,考虑每一位对最终答案的影响。
2. 案例Luogu P2602: 求给定范围 [ a , b ] [a,b] [a,b]各个数位 k k k出现了多少次。
考虑 n n n的 10 10 10进制表示 n = ∑ i = 0 m a i 1 0 i n=\sum_{i=0}^{m}a_i 10^{i}\\ n=i=0∑mai10i 我们令 n n n的最高位数的数字为 a m a_m am,最高幂次为 m m m m = ⌊ lg n ⌋ a m = ⌊ n / ( 1 0 m ) ⌋ m = \lfloor \lg n \rfloor \\ a_m =\lfloor n /( 10^{m})\rfloor m=⌊lgn⌋am=⌊n/(10m)⌋ 一个数 n n n中 k k k出现的次数 d i g i t C n t ( n , k ) = ∑ i = 0 m ϕ ( a i ) ϕ ( x ) = { 1 , a i = k 0 , a i ≠ k digitCnt(n, k) =\sum_{i=0}^m \phi(a_i)\\ \phi(x) = \begin{cases} 1, \quad a_i=k\\ 0, \quad a_i \ne k \end{cases} \\ digitCnt(n,k)=i=0∑mϕ(ai)ϕ(x)={1,ai=k0,ai=k 最终我们需要求 a n s ( a , b ) = ∑ x = a b d i g i t C n t ( x ) \\ ans(a,b) = \sum_{x=a}^{b} digitCnt(x) ans(a,b)=x=a∑bdigitCnt(x) 如果我们直接进行遍历,时间复杂度为 O ( n log n ) O(n \log n) O(nlogn),对于稍稍大一点的情况是处理不了的,这时候我们的数位dp出场了。 我们定义 d p [ n ] [ k ] dp[n][k] dp[n][k]为 n n n以内的 k k k位出现的次数,那么我们可以把 n n n以内的数分为两个部分 d p [ n ] = a n s ( 1 , a m 1 0 m − 1 ) + a n s ( a m 1 0 m , n ) dp[n] =ans(1,a_{m}10^{m}-1) +\\ans(a_{m}10^{m}, n) dp[n]=ans(1,am10m−1)+ans(am10m,n) 考虑最高位对最终增加的数为 M s b C n t ( n , k ) = { 1 0 m , a m > k 0 , k = 0 ∨ a m < k n − a m 1 0 m + 1 , a m = k MsbCnt(n,k) = \begin{cases} 10^{m} \quad , a_{m} > k \\ 0\quad, k=0 \vee a_{m} <k\\ n - a_{m}10^{m}+1 \quad,a_{m} =k \end{cases} MsbCnt(n,k)=⎩ ⎨ ⎧10m,am>k0,k=0∨am<kn−am10m+1,am=k 我们先忽略前导 0 0 0, 除去最高位的贡献后可以得到 d p [ n ] = M s b C n t ( n ) + a m T [ 1 0 m − 1 ] + T [ n − a m 1 0 m ] dp[n]=MsbCnt(n)+a_m T[10^{m} -1]\\+T[n-a_m10^{m}] dp[n]=MsbCnt(n)+amT[10m−1]+T[n−am10m] 这里的 T [ n ] T[n] T[n]定义为 T [ n ] = d i g i t C n t ( S { n } ) S { n } : = { s 0 s 1 ⋯ s m − 1 , 0 ≤ s i ≤ 9 , S ≤ n − a m 1 0 m } T[n]= digitCnt(S\{n\})\\ \quad S\{n\}:= \{s_0s_1\cdots s_{m-1} \quad ,0 \le s_i \le 9\\ ,S \le n- a_m 10^{m}\} T[n]=digitCnt(S{n})S{n}:={s0s1⋯sm−1,0≤si≤9,S≤n−am10m} 通俗一点就是比如 n = 123 n=123 n=123
那么 S { 123 } = { 000 , 001 , ⋯ , 123 } S\{123\} = \{000,001,\cdots,123\} S{123}={000,001,⋯,123}。
也就是把 n n n之前的数列出,并加上前导 0 0 0使之与 n n n的位数对齐。
前导 0 0 0的添加并不会影响非 0 0 0的其他数位置的计数,因此可以得到
d p [ n ] [ k ] = T [ n ] [ k ] , k ≠ 0 dp[n][k] =T[n][k], k\ne0 dp[n][k]=T[n][k],k=0
我们考虑 T [ 1 0 d − 1 ] [ k ] T[10^{d} -1][k] T[10d−1][k]的值, S { 1 0 d − 1 } S\{10^{d}-1\} S{10d−1}集合中总共有 1 0 d 10^{d} 10d个数,每个数有 d d d位,而每个数字在排列中等可能出现,因此 T [ 1 0 d − 1 ] [ 1 ] = ⋯ = T [ 1 0 d − 1 ] [ 9 ] = d × 1 0 d − 1 T[10^{d}-1][1]=\cdots=T[10^{d}-1][9] \\=d \times10^{d-1} T[10d−1][1]=⋯=T[10d−1][9]=d×10d−1 我们再考虑减去前导 0 0 0, 容易得到 1 0 d − 1 10^{d}-1 10d−1内的数在 S { 1 0 d − 1 } S\{10^{d}-1\} S{10d−1}集合中的前导 0 0 0的个数为
F r o n t Z e r o ( S { 1 0 d − 1 } ) = ∑ i = 1 d − 1 1 0 i FrontZero(S\{10^{d}-1\}) =\sum_{i=1}^{d-1}10^{i} FrontZero(S{10d−1})=i=1∑d−110i
因此 T [ 1 0 d − 1 ] [ 0 ] = ∑ i = 1 d − 1 ( d − 1 ) 1 0 i + d T[10^d-1][0] =\sum_{i=1}^{d-1} (d-1)10^{i} +d T[10d−1][0]=i=1∑d−1(d−1)10i+d
对于一个数 n n n而言, 在它之前有 k ≠ 0 k \ne 0 k=0的个数为 d p [ n ] [ k ] = M o s t C n t ( n , k ) + m a m 1 0 m − 1 + d p [ n − a m 1 0 m ] [ k ] dp[n][k] = MostCnt(n,k)+ma_{m}10^{m-1}+\\ dp[n-a_m10^{m}][k ] dp[n][k]=MostCnt(n,k)+mam10m−1+dp[n−am10m][k] 如果 k = 0 k =0 k=0, 还需要减去前导 0 0 0 d p [ n ] [ 0 ] = T [ n ] [ 0 ] − ∑ i = 1 m 1 0 m dp[n][0] =T[n][0]-\sum_{i=1}^{m}10^{m} dp[n][0]=T[n][0]−i=1∑m10m
代码一 #include <iostream> #include <vector> #include <functional> #include <unordered_set> constexpr static int BASE = 10; constexpr static int MAX_POW = 12; unsigned long long typeVal; using int_type = decltype(typeVal); int MaxPowNotGreater(int_type BASE, int_type v) { int ans = 1; auto tb = BASE; for ( ;tb <= v; tb *= BASE, ans++) { } return ans - 1; } int_type getDigitCntUntil(int_type val, int_type k) { int_type v = val; int digitCnt = MaxPowNotGreater(BASE, v); int_type mod = val % BASE; int_type ans = ((mod >= k))? 1 : 0; v /= BASE; int_type cpow = BASE; for (int d = 1; d < digitCnt + 1; d++, v/= BASE, cpow *= BASE) { int_type m = v % BASE; if ( m > k) ans += cpow; else if ( m == k) ans += mod + 1; else { } ans += m * (cpow / BASE) * d; mod += cpow * m; } if ( k == 0) { cpow /= BASE; while (cpow >= 1) { ans -= cpow; cpow /= BASE; } } return ans; } int main() { int_type a = 1; int_type b = 99; std::cin >> a >> b; std::vector<int_type> cal(BASE, 0); for (int i = 0;i < 10;i++) { cal[i] = getDigitCntUntil( b,i) - getDigitCntUntil(a - 1, i); } for (auto num:cal) { std::cout << num << " "; } std::cout << std::endl; return 0; }我们可以将整次幂的数位置个数存起来
T C n t [ d ] [ k ] = T [ 1 0 d − 1 ] [ k ] TCnt[d][k] = T[10^{d}-1][k] TCnt[d][k]=T[10d−1][k] 容易得到得到递推关系式 T C n t [ d ] [ k ] = { 1 0 d − 1 + T C n t [ d − 1 ] [ k ] , k ≠ 0 9 T C n t [ d − 1 ] [ 1 ] + T C n t [ d − 1 ] [ 0 ] , k = 0 TCnt[d][k] = \\ \begin{cases} 10^{d-1} +TCnt[d-1][k],\quad k \ne 0 \\ 9TCnt[d-1][1] + TCnt[d-1][0],\quad k =0 \end{cases} TCnt[d][k]={10d−1+TCnt[d−1][k],k=09TCnt[d−1][1]+TCnt[d−1][0],k=0
代码2 #include <iostream> #include <vector> #include <functional> #include <unordered_set> constexpr static int BASE = 10; constexpr static int MAX_POW = 12; unsigned long long typeVal; using int_type = decltype(typeVal); auto fpow = [](int_type base, int_type cnt) { int_type ans = 1; while (cnt) { if (cnt & 1) ans *= base; base *= base; cnt = cnt >> 1; } return ans; }; auto LogFloor = [](int_type base, int_type v) { int_type m = base; int_type kdigits = 1; while ( m <= v) { m *= base; kdigits++; } kdigits--; return kdigits; }; template<typename T> void TEST_EQ( T a, T b) { bool ret = (a == b); if (!ret) { std::cout << a << " NOT EQUAL " << b << '\n'; } else { std::cout << a << " EQUAL " << b << '\n'; } } void testEqual(int_type a, int_type b, const std::vector<int_type>& cal) { std::vector<int> tmpCnt(BASE, 0); for (int_type i = a; i <= b; i++) { auto ti = i; while (ti) { tmpCnt[ti % BASE]++; ti /= BASE; } } bool ok = true; for (int i = 0; i < BASE; i++) { if (cal[i] != tmpCnt[i]) { std::cout << i << "failed: Real=" << tmpCnt[i] << " ;Cal= " << cal[i] <<'\n'; ok = false; } } if (ok) { std::cout << "ok fine result!!!\n"; } } std::vector<std::vector<int_type>> FCnt( MAX_POW + 1, std::vector<int_type>(BASE, 0)); int_type getDigitCntUntil(int_type val, int_type k) { int_type v = val; int digitCnt = LogFloor(BASE, v); int_type mod = val % BASE; int_type ans = ((mod >= k) && k)? 1 : 0; v /= BASE; for (int d = 1; d < digitCnt + 1;d++, v /= BASE) { int_type lsb = v % BASE ; if (d != digitCnt) { ans += lsb * fpow(BASE, d - 1) * d; if ( lsb > k) ans += fpow(BASE, d); else if (lsb == k) ans += mod + 1; } else { ans += FCnt[digitCnt][k]; ans += (lsb - 1) * fpow(BASE, d - 1) * d; if (0 != k) { if ( lsb > k) ans += fpow(BASE, d); else if (k == lsb) ans += mod + 1; else ans += 0; } } mod += lsb * fpow(BASE, d); } return ans; } int main() { for (int i = 0;i < BASE;i++) { FCnt[1][i] = 1; } for (int i = 2;i <= MAX_POW;i++) { for (int d = 0; d < BASE;d++) { if (d == 0) FCnt[i][d] = 9 * FCnt[i - 1][1] + FCnt[i - 1][0]; else { FCnt[i][d] = fpow(10, i - 1) + 10 * FCnt[i - 1][d]; } } } int_type a = 1; int_type b = 99; std::cin >> a >> b; std::vector<int_type> cal(BASE, 0); for (int i = 0;i < 10;i++) { cal[i] = getDigitCntUntil( b,i) - getDigitCntUntil(a - 1, i); } for (auto num:cal) { std::cout << num << " "; } std::cout << std::endl; return 0; }