Today, I worked on a backtracking algorithm problem. The problem is as follows:
Given a collection of candidate numbers candidates and a target number target, find all unique combinations in candidates where the chosen numbers sum to target. Each number in candidates may only be used once in the combination. Note: The solution set must not contain duplicate combinations.
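To pin down what "unique combinations" means, here is a brute-force reference sketch. The function name combination_sum2_bruteforce is hypothetical. It checks every subset with itertools.combinations, so it runs in exponential time and is only meant for small inputs or for cross-checking a faster solution.

```python
from itertools import combinations
from typing import List


def combination_sum2_bruteforce(candidates: List[int], target: int) -> List[List[int]]:
    """Hypothetical reference: test every subset, keep each distinct combination once."""
    seen = set()
    for size in range(1, len(candidates) + 1):
        # combinations() picks each position at most once, matching "used only once";
        # sorting first means equal combinations come out as identical tuples
        for combo in combinations(sorted(candidates), size):
            if sum(combo) == target:
                seen.add(combo)
    return sorted(list(c) for c in seen)
```

For example, candidates [10, 1, 2, 7, 6, 1, 5] with target 8 gives [[1, 1, 6], [1, 2, 5], [1, 7], [2, 6]].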
So I decided to use this problem as the basis for an article on the backtracking algorithm.
First, let's look at the most basic framework of backtracking:
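```python
class Solution:
    def combinationSum2(self):
        res = []

        def tranceback():
            for i in xxx:  # xxx: the choices available at this step
                ...  # make a choice here, then go deeper

        tranceback()  # start the search
        return res
```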
This code essentially traverses an array.
A naive search, however, tends to visit the same elements over and over: even though the result must not contain duplicates, it still walks over each element more than once. The most basic exhaustive approach to summation is a typical example, since it tries the numbers in every possible order.
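The original example for this step is not shown, so the following is only an assumed illustration, restricted to pairs for brevity: a hypothetical pairs_summing_to function whose two nested loops both start at 0.

```python
from typing import List, Tuple


def pairs_summing_to(nums: List[int], target: int) -> List[Tuple[int, int]]:
    """Naive exhaustive search: try every ordered pair of distinct positions."""
    hits = []
    for i in range(len(nums)):
        for j in range(len(nums)):  # j restarts at 0, so (i, j) and (j, i) are both tried
            if i != j and nums[i] + nums[j] == target:
                hits.append((nums[i], nums[j]))
    return hits


print(pairs_summing_to([1, 2, 3], 3))  # [(1, 2), (2, 1)]
```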
Such methods are very inefficient because they produce the same combination repeatedly, for example [1, 2] and [2, 1].
In fact, backtracking is also a form of enumeration, but one with much more room for optimization. Its core idea is trial and error: at each step of the search, if the current decision turns out to be infeasible, we backtrack to the previous step and make a different decision.
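The try-then-undo idea can be made explicit in a generic template. This is a sketch of one common style, not the code developed in this article: a hypothetical subsets_with_sum function keeps a single shared path list, appends a choice, recurses, and pops the choice again. It assumes positive numbers, so a negative remaining sum marks the branch as infeasible.

```python
from typing import List


def subsets_with_sum(nums: List[int], target: int) -> List[List[int]]:
    """Choose / explore / un-choose backtracking over index-ordered subsets."""
    results: List[List[int]] = []
    path: List[int] = []  # one shared list, mutated in place

    def backtrack(start: int, remain: int) -> None:
        if remain < 0:
            return  # infeasible: with positive numbers this branch can never recover
        if remain == 0:
            results.append(path.copy())  # copy, because path keeps changing
            return
        for i in range(start, len(nums)):
            path.append(nums[i])                # try: make a decision
            backtrack(i + 1, remain - nums[i])  # explore its consequences
            path.pop()                          # backtrack: undo the decision

    backtrack(0, target)
    return results
```

The solution built in this article passes path + [candi] instead, which creates a new list on every call, so it never needs an explicit pop.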
So we hope to "remember" the parts that have already been matched.
Let's flesh out the basic framework for this problem:
```python
from typing import List


class Solution:
    def combinationSum2(self, candidates: List[int], target: int):
        res = []

        def tranceback(index, path, remain):
            for i in range(index, len(candidates)):
                ...  # to be filled in: compare candidates[i] with remain

        tranceback(0, [], target)
        return res
```
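The backtracking function accepts three parameters: the first is the pointer (the index to start from), the second is the current combination (path), and the third, remain, is how far the current combination still is from the target.
The loop inside the function starts from the pointer's position, so numbers before it are never revisited. This is what stops the search from producing both [1, 2] and [2, 1].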
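To see what the start index buys, the hypothetical search function below runs the same recursive search in two modes: one lets every unused position be chosen at each step, the other only allows positions from the start index onward. It assumes positive numbers.

```python
from typing import FrozenSet, List


def search(nums: List[int], target: int, use_start_index: bool) -> List[List[int]]:
    """Find lists of positive numbers that sum to target, with or without a start index."""
    found: List[List[int]] = []

    def go(start: int, used: FrozenSet[int], path: List[int], remain: int) -> None:
        if remain == 0:
            found.append(path)
            return
        # With a start index only later positions are tried; without it, any
        # unused position is, so the same numbers are found in every order.
        begin = start if use_start_index else 0
        for i in range(begin, len(nums)):
            if i in used or nums[i] > remain:
                continue
            go(i + 1, used | {i}, path + [nums[i]], remain - nums[i])

    go(0, frozenset(), [], target)
    return found
```

With [1, 2, 3] and target 3, the first mode finds [1, 2], [2, 1] and [3], while the start-index mode finds only [1, 2] and [3].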
In each iteration, we compare the current number with remain.
```python
for i in range(index, len(candidates)):
    candi = candidates[i]

    # If the target is reached, no need to continue looping
    if candi == remain:
        res.append(path + [candi])
        return
    # Not enough yet, keep going
    if candi < remain:
        ...  # continued below
```
If the number is still smaller than remain, we subtract it and keep scanning the numbers after it:
```python
for i in range(index, len(candidates)):
    candi = candidates[i]

    # If the target is reached, no need to continue looping
    if candi == remain:
        res.append(path + [candi])
        return
    # Not enough yet: subtract candi and scan the numbers after it
    if candi < remain:
        new_remain = remain - candi
        for j in range(i + 1, len(candidates)):
            ...  # repeat the same checks with candidates[j] and new_remain
```
Notice that the logic is the same for every number we visit: compare it with remain, record a match, and go one level deeper if it still falls short.
So we might as well let the function call itself, simply moving the pointer to the position after i:
```python
if candi < remain:
    tranceback(i + 1, path + [candi], remain - candi)
```
Imagine the target were infinitely large: every call would call the function again, with the pointer moving forward each time until it runs off the end of the array.
Meanwhile, the loop inside each call gives every remaining number its turn as the next choice, that is, as the start of its own chain of deeper calls.
This is the core idea of backtracking: try each element as a starting point, then explore the remaining elements from there.
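To watch the recursion unfold, the sketch below records every call as (depth, path, remain) instead of collecting answers. The name trace_search is hypothetical, and, like the search described so far, it only recurses when a number is still smaller than remain.

```python
from typing import List, Tuple


def trace_search(candidates: List[int], target: int) -> List[Tuple[int, List[int], int]]:
    """Record (depth, path, remain) for every call of the recursive search."""
    calls: List[Tuple[int, List[int], int]] = []

    def backtrack(index: int, path: List[int], remain: int, depth: int) -> None:
        calls.append((depth, path, remain))
        for i in range(index, len(candidates)):
            candi = candidates[i]
            if candi < remain:
                # each number in turn starts a deeper chain of calls
                backtrack(i + 1, path + [candi], remain - candi, depth + 1)

    backtrack(0, [], target, 0)
    return calls


for depth, path, remain in trace_search([1, 2, 3], 5):
    print("  " * depth, path, "remain =", remain)
```

Each number tried at depth 0 opens its own subtree, and because the pointer only moves forward, the depth can never exceed the length of the array.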
The search can still be optimized, because in some cases there is no point in continuing the loop at all:
- The target has been reached
- The number pointed to by the current pointer exceeds the remaining sum
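So in each call we check these conditions and return early, which is called "pruning." Both checks assume the candidates are in ascending order: only then is every later number also too large once one number reaches or exceeds remain.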
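The effect of pruning can be measured by counting loop iterations. The hypothetical count_checks function below sorts its input first (the assumption the early returns rely on) and runs the search with and without the two early returns. The example uses distinct values, because without the return after a match, an input with repeated numbers could record the same combination twice.

```python
from typing import List, Tuple


def count_checks(candidates: List[int], target: int, prune: bool) -> Tuple[List[List[int]], int]:
    """Run the combination search and count how many loop iterations it performs."""
    nums = sorted(candidates)  # the early returns are only safe on ascending input
    res: List[List[int]] = []
    checks = 0

    def backtrack(index: int, path: List[int], remain: int) -> None:
        nonlocal checks
        for i in range(index, len(nums)):
            checks += 1
            candi = nums[i]
            if candi == remain:
                res.append(path + [candi])
                if prune:
                    return  # later numbers are >= remain: they only repeat or overshoot
            elif candi < remain:
                backtrack(i + 1, path + [candi], remain - candi)
            elif prune:
                return  # candi > remain, and so is every later number

    backtrack(0, [], target)
    return res, checks


with_pruning = count_checks([2, 3, 5, 7, 8, 10], 10, prune=True)
without_pruning = count_checks([2, 3, 5, 7, 8, 10], 10, prune=False)
print(with_pruning[1], "checks with pruning,", without_pruning[1], "without")
```

Both variants make exactly the same recursive calls; pruning only skips the iterations that would find numbers too large to use.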
```python
from typing import List


class Solution:
    def combinationSum2(self, candidates: List[int], target: int) -> List[List[int]]:
        res = []

        def tranceback(index, path, remain):
            for i in range(index, len(candidates)):
                candi = candidates[i]

                # Pruning: both early returns assume candidates is sorted ascending
                if candi == remain:
                    res.append(path + [candi])
                    return
                if candi < remain:
                    tranceback(i + 1, path + [candi], remain - candi)
                if candi > remain:
                    return

        tranceback(0, [], target)
        return res
```
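This is the complete backtracking structure, but two problems remain: the early returns are only correct when candidates is in ascending order, and an input with repeated numbers (such as two 1s) can still produce duplicate combinations.
First, sort the array in ascending order. This is what makes the early returns safe: once a number equals or exceeds remain, every number after it does too. Sorting also places equal numbers next to each other.
Second, when the same number shows up again within the same loop, skip it directly: it would only rebuild combinations that were already found, and the answer may not contain duplicates. After sorting, comparing candidates[i] with candidates[i - 1] is enough. The extra condition i > index limits the skip to the current loop, so a combination may still use the same value twice, as in [1, 1, 6].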
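The role of i > index is easiest to see by replacing it with the stricter i > 0. In the hypothetical combos function below, the flag same_level_only switches between the two tests; everything else follows the optimized solution.

```python
from typing import List


def combos(candidates: List[int], target: int, same_level_only: bool) -> List[List[int]]:
    """Sorted combination search; the flag picks which duplicate-skipping test is used."""
    nums = sorted(candidates)
    res: List[List[int]] = []

    def backtrack(index: int, path: List[int], remain: int) -> None:
        for i in range(index, len(nums)):
            if same_level_only:
                skip = i > index and nums[i] == nums[i - 1]  # skip repeated siblings only
            else:
                skip = i > 0 and nums[i] == nums[i - 1]  # also skips across depths: too strict
            if skip:
                continue
            if nums[i] == remain:
                res.append(path + [nums[i]])
                return
            if nums[i] > remain:
                return
            backtrack(i + 1, path + [nums[i]], remain - nums[i])

    backtrack(0, [], target)
    return res
```

On [10, 1, 2, 7, 6, 1, 5] with target 8, the i > 0 variant loses [1, 1, 6]: it refuses the second 1 even one level deeper, where that 1 is a new choice rather than a repeated sibling.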
The optimized code:
```python
from typing import List


class Solution:
    def combinationSum2(self, candidates: List[int], target: int) -> List[List[int]]:
        candidates.sort()
        res = []

        def tranceback(index, path, remain):
            for i in range(index, len(candidates)):
                candi = candidates[i]
                # Skip duplicates within the same loop (same level of the search)
                if i > index and candidates[i - 1] == candidates[i]:
                    continue
                if candi == remain:
                    res.append(path + [candi])
                    return
                if candi < remain:
                    tranceback(i + 1, path + [candi], remain - candi)
                if candi > remain:
                    return

        tranceback(0, [], target)
        return res
```
Practice: The Eight Queens Problem
The "Eight Queens Problem" is a classic backtracking algorithm problem that requires placing 8 queens on an 8x8 chessboard so that no two queens threaten each other, meaning no two queens can be in the same row, column, or diagonal.
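Because each queen is placed in its own row, only columns and diagonals need checking. For two queens in different rows $r_1 \neq r_2$ at columns $c_1$ and $c_2$, the conflict condition is:

$$
\text{conflict} \iff c_1 = c_2 \;\lor\; |c_1 - c_2| = |r_1 - r_2|
$$

With the state encoding used below, a new queen in row len(state) at column pos conflicts with the queen in row i exactly when |state[i] - pos| is 0 or len(state) - i.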
In the following source code, yield can be understood as adding an element to the result array: each solution is handed back to the caller as soon as it is found.
```python
# state is a tuple recording the x-coordinate (column) of the queen in each row,
# for example (1, 4, 6, 3, 0, 7, 5, 2)


# conflict() is not defined in the original snippet; this is a standard version of it
def conflict(state, nextX):
    nextY = len(state)  # row the new queen would go in
    for i in range(nextY):
        # same column (distance 0) or same diagonal (column distance == row distance)
        if abs(state[i] - nextX) in (0, nextY - i):
            return True
    return False


def queens(num=8, state=()):
    for pos in range(num):
        # Pruning: skip positions that conflict with queens already placed
        if not conflict(state, pos):
            # Last row: no need to recurse further, yield this position on its own
            if len(state) == num - 1:
                yield (pos,)
            else:
                # Use the current state as the starting point and call the function again
                for result in queens(num, state + (pos,)):
                    yield (pos,) + result
```
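Since yield can be read as appending to a result array, the same search can also be written with an explicit list. The hypothetical solve_queens below sketches that variant: it keeps one shared list of columns, appends a position, recurses, pops it again, and collects a copy of every complete board.

```python
from typing import List


def solve_queens(num: int = 8) -> List[List[int]]:
    """List-based variant: collect every solution instead of yielding it."""
    results: List[List[int]] = []
    cols: List[int] = []  # cols[row] = column of the queen in that row

    def safe(pos: int) -> bool:
        row = len(cols)  # row the new queen would go in
        return all(c != pos and abs(c - pos) != row - r for r, c in enumerate(cols))

    def place() -> None:
        if len(cols) == num:
            results.append(cols.copy())  # every row has a queen: record the board
            return
        for pos in range(num):
            if safe(pos):      # pruning: skip attacked squares
                cols.append(pos)
                place()
                cols.pop()     # undo and try the next column

    place()
    return results
```

solve_queens(8) returns 92 boards, the well-known number of solutions to the eight queens puzzle.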
It is the same classic pattern: each loop iteration extends the current state, which then becomes the starting point for searching the remaining part.
Thank you for reading, and feel free to follow my GitHub for more technical content.