位运算进阶:位掩码、子集与状态压缩
medium位运算位掩码状态压缩DP子集枚举
位掩码:用整数表示集合
一个 n 元素集合的所有子集可以用 0 到 2^n - 1 的整数表示。第 i 位为 1 表示选了第 i 个元素。
集合 {A, B, C} (n=3)
掩码 000 = 空集
001 = {A}
010 = {B}
011 = {A,B}
100 = {C}
101 = {A,C}
110 = {B,C}
111 = {A,B,C}
枚举所有子集
void enumerate(int[] nums) {
int n = nums.length;
for (int mask = 0; mask < (1 << n); mask++) {
// mask 代表一个子集
for (int i = 0; i < n; i++) {
if ((mask >> i & 1) == 1) {
// 第 i 个元素被选中
System.out.print(nums[i] + " ");
}
}
System.out.println();
}
}
枚举某个掩码的所有子集(子集的子集)
// 枚举 mask 的所有子集(包括空集)
for (int sub = mask; sub > 0; sub = (sub - 1) & mask) {
// sub 是 mask 的一个非空子集
}
// 注意:sub=0(空集)不在循环内,需单独处理
为什么 (sub - 1) & mask 能枚举子集:
sub - 1会把 sub 最低为的 1 变 0,并把下面所有位变 1& mask再限制只保留 mask 包含的位- 这样每次迭代都是 mask 的下一个更小的子集
状态压缩 DP
当 DP 的状态是一个"集合"时,用位掩码表示状态,这就是状态压缩 DP。核心是把 O(2^n) 的集合遍历编码成整数运算。
旅行商问题(TSP):最短路经过所有城市
dp[mask][i]:已经访问了 mask 中所有城市,当前在城市 i,的最短距离。
int tsp(int[][] dist) {
int n = dist.length;
int FULL = (1 << n) - 1;
int[][] dp = new int[1 << n][n];
for (int[] row : dp) Arrays.fill(row, Integer.MAX_VALUE / 2);
dp[1][0] = 0; // 从城市0出发,只访问了城市0(mask=1)
for (int mask = 1; mask <= FULL; mask++) {
for (int u = 0; u < n; u++) {
if ((mask >> u & 1) == 0) continue; // u 不在当前集合中
if (dp[mask][u] == Integer.MAX_VALUE / 2) continue;
for (int v = 0; v < n; v++) {
if ((mask >> v & 1) == 1) continue; // v 已访问
int nextMask = mask | (1 << v);
dp[nextMask][v] = Math.min(dp[nextMask][v], dp[mask][u] + dist[u][v]);
}
}
}
int ans = Integer.MAX_VALUE;
for (int u = 1; u < n; u++)
ans = Math.min(ans, dp[FULL][u] + dist[u][0]); // 回到起点
return ans;
}
状态数:2^n × n,转移:枚举下一个城市 O(n),总时间 O(2^n × n²)。
划分为 K 个等和子集(LeetCode 698)
将数组划分为 k 个子集使每个子集的和相等。
boolean canPartitionKSubsets(int[] nums, int k) {
int sum = Arrays.stream(nums).sum();
if (sum % k != 0) return false;
int target = sum / k;
int n = nums.length;
int[] dp = new int[1 << n]; // dp[mask] = 当前剩余目标量(在下一个桶中)
Arrays.fill(dp, -1);
dp[0] = 0;
Arrays.sort(nums); // 剪枝:从小到大排列
for (int mask = 0; mask < (1 << n); mask++) {
if (dp[mask] == -1) continue;
for (int i = 0; i < n; i++) {
if ((mask >> i & 1) == 1) continue; // 已使用
if (dp[mask] + nums[i] > target) break; // 超过目标,剪枝(数组已排序)
int nextMask = mask | (1 << i);
dp[nextMask] = (dp[mask] + nums[i]) % target; // 余数进入下一个桶
}
}
return dp[(1 << n) - 1] == 0; // 所有数都用了且每桶恰好满
}
最小的必要团队(LeetCode 1125)
给定所需技能列表和员工的技能列表,求人数最少的、覆盖所有所需技能的团队。
将技能用位掩码表示,dp[mask] = 覆盖技能集 mask 所需的最小人数:
int[] smallestSufficientTeam(String[] req_skills, List<List<String>> people) {
int n = req_skills.length;
Map<String, Integer> skillIdx = new HashMap<>();
for (int i = 0; i < n; i++) skillIdx.put(req_skills[i], i);
int FULL = (1 << n) - 1;
int[] dp = new int[1 << n]; // dp[mask] = 达到 mask 的最少人数
int[] parent = new int[1 << n]; // 记录路径(从哪个 mask 转移来的)
int[] personAdded = new int[1 << n]; // 本步加入的人的索引
Arrays.fill(dp, Integer.MAX_VALUE);
dp[0] = 0;
for (int mask = 0; mask <= FULL; mask++) {
if (dp[mask] == Integer.MAX_VALUE) continue;
for (int p = 0; p < people.size(); p++) {
int personSkill = 0;
for (String s : people.get(p))
if (skillIdx.containsKey(s))
personSkill |= (1 << skillIdx.get(s));
int nextMask = mask | personSkill;
if (dp[mask] + 1 < dp[nextMask]) {
dp[nextMask] = dp[mask] + 1;
parent[nextMask] = mask;
personAdded[nextMask] = p;
}
}
}
// 回溯找到团队成员
List<Integer> team = new ArrayList<>();
int cur = FULL;
while (cur != 0) {
team.add(personAdded[cur]);
cur = parent[cur];
}
return team.stream().mapToInt(i -> i).toArray();
}
子集异或和(LeetCode 1863)
所有子集的 XOR 值的总和:
// 方法1:枚举所有子集,O(n * 2^n)
int subsetXORSum(int[] nums) {
int n = nums.length, res = 0;
for (int mask = 1; mask < (1 << n); mask++) {
int xor = 0;
for (int i = 0; i < n; i++)
if ((mask >> i & 1) == 1) xor ^= nums[i];
res += xor;
}
return res;
}
// 方法2:数学规律,O(n)
// 结论:答案 = (所有元素 OR 的结果) << (n-1)
int subsetXORSum2(int[] nums) {
int or = 0;
for (int n : nums) or |= n;
return or << (nums.length - 1);
}
状态压缩 DP 模板
// 通用框架
int n = ...; // 元素/节点数量(通常 n <= 20)
int FULL = (1 << n) - 1;
int[] dp = new int[1 << n];
// 从小状态推大状态
for (int mask = 0; mask <= FULL; mask++) {
// 对每个状态,枚举可以加入的元素(或城市、任务等)
for (int i = 0; i < n; i++) {
if ((mask >> i & 1) == 0) continue; // i 已在 mask 中,跳过
// 从 mask 转移到 mask | (1 << i),或从 mask 去掉 i 等
}
}
适用场景:
- 任务分配(n 个任务/人,n ≤ 20)
- 图上路径(TSP,n ≤ 20 城市)
- 集合覆盖(n ≤ 20 个属性/技能)
- 排列计数(集合中所有元素的某种排列)
小结
| 操作 | 位掩码写法 |
|---|---|
| 选入第 i 个元素 | mask | (1 << i) |
| 移除第 i 个元素 | mask & ~(1 << i) |
| 判断第 i 个是否被选 | (mask >> i) & 1 |
| 全集(n 个元素) | (1 << n) - 1 |
| 枚举子集的子集 | for (sub=mask; sub>0; sub=(sub-1)&mask) |
| 统计集合大小 | Integer.bitCount(mask) |