0%

uoj#310 题解 - 【UNR#2】黎明前的巧克力

题目链接

简要题意:

给一个集合 $S$,求多少个有序集合对 $(A,B)$,使得 $A\cup B\neq \varnothing,A\cap B=\varnothing$ 且 $\bigoplus_{x\in A}x=\bigoplus_{x\in B}x$。

显然,由于 $A\cap B=\varnothing$,所以

于是只需要求出

满足 $\bigoplus x\in A=0$ 的非空集合的 2 的集合大小次方之和

设 $f_{i,j}$ 是考虑前 $i$ 个元素,异或值为 $j$ 的答案。显然有

设 $F_i$ 是 $f_{i,j}$ 的集合幂级数,则

其中 $\times$ 是异或卷积。

那么现在我们考虑如何求 $\text{FWT}(x^{\varnothing}+2x^{a_i})$。

我们都知道 $\text{FWT}$ 是一个线性算子,从而它应该等于 $\text{FWT}(x^{\varnothing})+2\text{FWT}(x^{a_i})$。打个表可以发现,$\text{FWT}(x^{d})_i=(-1)^{|i\cap d|}$。

从而

于是,只要算出 $\sum_{i}\big[|a_i\cap j|\bmod 2=1\big]$ 即可得到 $\text{FWT}(\text{ans})_j$。

设 $b_i$ 是 $a_j=i$ 的 $j$ 个数。于是所求即可转化为

这恰好又变回了 $\text{FWT}$。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
#include<bits/stdc++.h>
using namespace std;

const int p=998244353,maxN=1<<20,maxK=20,g=3,ig=332748118,inv2=(p+1)/2;
int calc(int n){
int x=1;while(x<n) x<<=1;
return x;
}
struct Z{
int x;
Z(int x0=0):x(x0){}
};
int inline check(int x){return x>=p?x-p:x;}
Z operator +(const Z a,const Z b){return check(a.x+b.x);}
Z operator -(const Z a,const Z b){return check(a.x-b.x+p);}
Z operator *(const Z a,const Z b){return 1LL*a.x*b.x%p;}
Z operator -(const Z a){return check(p-a.x);}
Z& operator +=(Z &a,const Z b){return a=a+b;}
Z& operator -=(Z &a,const Z b){return a=a-b;}
Z& operator *=(Z &a,const Z b){return a=a*b;}
Z fac[maxN],ifac[maxN],inv[maxN];
Z qpow(Z a,int k){
Z ans=1;
while(k){
if(k&1) ans*=a;
a*=a;
k>>=1;
}
return ans;
}
Z pow3[1<<20];
void init(){
fac[0]=ifac[0]=fac[1]=ifac[1]=inv[1]=1;
for(int i=2;i<maxN;i++)
fac[i]=fac[i-1]*i,
inv[i]=(p-p/i)*inv[p%i],
ifac[i]=ifac[i-1]*inv[i];
pow3[0]=1;
for(int i=1;i<maxN;i++) pow3[i]=pow3[i-1]*3;
}

int M,N,tot;
void FWT(Z d[],bool flg){
for(int n=1;n<N;n*=2)
for(Z *j=d;j<d+N;j+=n*2)
for(Z *k0=j,*k1=j+n,tmp;k0<j+n;k0++,k1++)
tmp=*k1,*k1=*k0-*k1,*k0+=tmp;
if(flg){
Z invN=qpow(N,p-2);
for(Z *k=d;k<d+N;k++) *k*=invN;
}
}

int inline read(){
int num=0,neg=1;char c=getchar();
while(!isdigit(c)){if(c=='-') neg=-1;c=getchar();}
while(isdigit(c)) num=num*10+c-'0',c=getchar();
return num*neg;
}
Z B[1<<20];
Z ANS[1<<20];

int main(){
init();
M=20;N=1<<M;
scanf("%d",&tot);
for(int i=1;i<=tot;i++) B[read()].x++;
FWT(B,0);
for(int i=0;i<N;i++){
int cnt=((B[i].x+tot)%p)/2; //-1 个数
ANS[i]=(tot-cnt)&1?p-pow3[cnt]:pow3[cnt];
}
FWT(ANS,1);
printf("%d\n",(ANS[0].x-1+p)%p);
}