题目:
感觉没紫题(上位蓝题到下位紫,考虑到考场上难度自动上升半段,给紫题也合理
首先我们考虑什么情况下会出错:
很显然,对于某个数i,如果w=2,按照贪心策略如果选中一定不会出错(在/2的情况下仍然排在前面,说明原价一定比较高)
如果w=1,选中该数可能会导致后续只能选择另一个w=1的数(这个数可能很小)而导致无法选择一个w=2的数(这个数可能大于所选的两个w=1的数的和)
所以我们考虑正难则反
也就是找出所有非法情况
我们令y=本来应该选择的数,x=贪心策略选择的数(大)z=贪心策略选择的数(小)
把a从小到大排序:
最优解为
........y.......
当前选择为
....z...x....(y)(没选)...
显然的,x>a[y]/2,所以它被选择了;
我们尝试将区间分段,考虑每个区间的取数;
Ⅰ:随便取任何数(1/2),因为z是选的最后一个数,所以该段区间的w赋值无影响,为答案提供2^z种可能性
Ⅱ:已知y没有被取,因为在给到w=2时a[y]/2<a[x]/1;那么对于无论/2还是/1都更小的Ⅱ区间内的数更不会被取,他们的性价比无论如何都低于y
Ⅲ:w=2时,他们的性价比一定比y低,不考虑,w=1时,在已经选择x的情况下,选择该数一定是最优解,而我们当前考虑的是错解,所以不考虑;
Ⅳ:w=2时性价比小于y,不选,w=1时选择,cost=0/1
Ⅴ:w=1/2都选,cost=1/2
我们考虑枚举x,y;z的范围可以根据xy的范围得出,因为x,z必选,且w都等于1,所以留给剩下选数的cost=m-2;
观察上面的图,发现了吗,只有Ⅳ,Ⅴ区间内的数才会被选择,其中Ⅴ内的数必被选中,我们可以将cost统一减去Ⅴ范围内数的个数,这样Ⅳ/Ⅴ区间内的数w就都变成了0/1;
对于每一组x,y,我们需要在(n-x-1)个数中选择(cost-(n-y))=(m-2-(n-y))个数,组合数O1搞定,总时间复杂度O2;
os:洛谷卡signed main.......这我是真没想到
code:
#include<bits/stdc++.h>
//#define int long long
#define inf 0x3f3f3f3f3f3f3f
#define GG ios::sync_with_stdio(false);cin.tie(0);cout.tie(0);
#define cnot cout<<"NO"<<"\n"
#define cyes cout<<"YES"<<"\n"
#define cans cout<<ans<<"\n"
#define pb push_back
#define x0 first
#define y0 second
#define lc p<<1
#define rc p<<1|1
#define mem(a,b) memset(a,b,sizeof(a))
#define sp(x) fixed<<setprecision(x)
#define all(v) v.begin(),v.end()
#define fr(i,st,ed) for(int i=st;i<=ed;i++)
#define ffr(i,st,ed,dt) for(int i=st;i<=ed;i+=dt)
#define all1(a) a.begin()+1,a.end()
using namespace std;
typedef pair<int,string>Pis;
typedef pair<int,int>Pii;
const int N=10005,mod=998244353,M=1e6+10;
int lowbit(int x){
return x&(-x);}
//vector<int>inv2(N);
int inv2[N];
//vector<vector<int> >C(N,vector<int>(N));
int C[N][N];
int a[N];
void P(){
inv2[0]=1;
for(int i=1;i<10001;i++){
inv2[i]=(long long)2*inv2[i-1]%mod;
}
C[0][0]=1;
for(int i=1;i<10001;i++){
C[i][0]=1;
for(int j=1;j<=i;j++){
C[i][j]=(C[i-1][j-1]+C[i-1][j])%mod;
}
}
}
void solve(){
int n,m;
cin>>n>>m;
//vector<int>a(n+1);
fr(i,1,n){
cin>>a[i];
}
//sort(all1(a));
sort(a+1,a+1+n);
int ans=0;
for(int x=1;x<=n;x++){
int pos=0;
for(int y=x+1;y<=n;y++){
if(a[x]==a[y]){
continue;
}
if((m-2-(n-y))<0){
continue;
}
if(2*a[x]<=a[y]){
break;
}
while(pos<n&&a[pos+1]+a[x]<a[y]){
pos++;
}
ans=(ans+(long long)1*C[n-x-1][m-2-(n-y)]*inv2[pos])%mod;
}
}
ans=(inv2[n]-ans+mod)%mod;
cans;
}
int main(){
GG;
int _t=1;
int __;
P();
cin>>__>>_t;
while(_t--){
solve();
}
}