题目来源
Say you have an array for which the ith element is the price of a given stock on day i.
Design an algorithm to find the maximum profit. You may complete at most k transactions.
Note:
You may not engage in multiple transactions at the same time (ie, you must sell the stock before you buy again).
这道题呢,大家肯定一看也知道应该用DP,但是具体应该怎么用呢,肯定是个二维的DP。假设dp[k][i]
表示前i天里交易k次最大盈利,那么dp[k][i]可以有两种可能,一种是在第i天的时候没有交易则dp[k][i] = dp[k][i-1]
,另一种是在第i天的时候有交易则dp[k][i] = dp[k-1][j-1] + price[i] - price[j]
,j表示上一次买入的时候的那一天。
那么dp[k][i] = max(dp[k][i-1], dp[k-1][j-1] + price[i] - price[j])
。
dp[0][i] = 0
因为0次交易的话盈利肯定是0。
dp[k][0] = 0
因为第0天肯定也没有盈利。
代码如下所示:
class Solution {
public:
int maxProfit(int k, vector<int>& prices) {
int n = prices.size();
vector<vector<int>> dp(k+1, vector<int>(n+1, 0));
for (int i=1; i<=k; i++) {
for (int j=1; j<=n; j++) {
int tmp = 0;
for (int jj=1; jj<j; jj++)
tmp = max(tmp, dp[i-1][jj-1] + prices[j-1] - prices[jj-1]);
dp[i][j] = max(dp[i][j-1], tmp);
}
}
return dp[k][n];
}
};
然后结果怎么样,没错,又是TLE!又超时了!时间复杂度有点高,O(n^3)。
然后参考了下大神的,修改了一下代码,改成O(n^2)的。
class Solution {
public:
int maxProfit(int k, vector<int>& prices) {
int n = prices.size();
if (n < 2)
return 0;
int maxPro = 0;
vector<vector<int>> dp(k+1, vector<int>(n+1, 0));
for (int i=1; i<=k; i++) {
int tmpMax = dp[i-1][1] - prices[0];
for (int j=1; j<=n; j++) {
dp[i][j] = max(dp[i][j-1], tmpMax + prices[j-1]);
tmpMax = max(tmpMax, dp[i-1][j] - prices[j-1]);
maxPro = max(maxPro, dp[i][j]);
}
}
return maxPro;
}
};
巧妙的引入了一个tmpMax来存储以前dp[i-1][jj-1] + prices[j-1] - prices[jj-1]
中的dp[i-1][jj-1] - prices[jj-1]
。
但是还是A不了,当k或者prices非常大的时候,内存溢出了。
继续看看怎么改进,可以采取一些小技巧来防止TLE或者溢出。
比如说假如交易次数达到天数的一半,那么最大收益可以直接算出来,不用管交易次数的限制。
class Solution {
public:
int maxProfit(int k, vector<int>& prices) {
int n = prices.size();
if (n < 2)
return 0;
if (k > n / 2)
return quickSolver(prices);
int maxPro = 0;
vector<vector<int>> dp(k+1, vector<int>(n+1, 0));
for (int i=1; i<=k; i++) {
int tmpMax = dp[i-1][1] - prices[0];
for (int j=1; j<=n; j++) {
dp[i][j] = max(dp[i][j-1], tmpMax + prices[j-1]);
tmpMax = max(tmpMax, dp[i-1][j] - prices[j-1]);
maxPro = max(maxPro, dp[i][j]);
}
}
return maxPro;
}
int quickSolver(vector<int> prices)
{
int n = prices.size();
int res = 0;
for (int i=1; i<n; i++) {
if (prices[i] > prices[i-1])
res += prices[i] - prices[i-1];
}
return res;
}
};