题目
时间限制:24000ms
单点时限:2000ms
内存限制:256MB
描述
有一些人在玩一个游戏。游戏的舞台发生在一个 n 个点的树上。
这个游戏分为很多轮,每一轮都有一些玩家参与,每个玩家都会降落在一条给定的边上(不同玩家的边不同)。之后这 n 个点上都会随机出现一个0或者1作为权值。
我们说这一轮游戏是公平的,当且仅当这一轮中,对于每个玩家,如果将她所在的边删除,那么两边对应的两个子树的点权和是相等的。
对于每一轮,我们给出每个玩家的位置,你需要计算出该轮游戏是公平的概率 p。为了保证输出是整数,你只需要输出 p × 2n % (109+7) 就可以了。
输入
树的点从1开始标号。
第一行两个数 n 和 m 分别表示树的点数和游戏的轮数。
接下来 n-1行每行两个数 a 和 b 表示一条边。
接下来 m 行每行表示一轮游戏。
其中的第 i 行由一个数字 ti 开头,表示这轮游戏有 ti 个玩家,
接下来 ti 个数对,其中第 j 个数对 ai,j 和 bi,j 表示第 j 个玩家所在的边的两个端点。
n, m ≤ 100000
所有 ti 的和 ≤ 1000000
输出
输出 m 行,每行一个数表示答案。
样例输入
5 5
1 2
1 3
3 4
3 5
4 1 2 1 3 3 4 3 5
1 3 4
2 3 4 1 2
1 3 5
2 3 4 3 5
样例输出
1
5
2
5
2
分析
- 对于每轮游戏,t个边,把树分为t+1部分,每部分作为一个新的结点,由这t个边连成一棵新的树。
- 如果这棵树不是一条链,只有全0符合要求。
- 如果这棵树是一条链,只考虑两端的结点包含的原图中的结点数,设为x,y,易证结果为Cxx+y。
- 深度遍历原树,得到每个结点的开始时间和结束时间,可以再O(t)时间构造新树,判断树是否为链,可以根据结点的度数判断。
代码
#include <cstdlib>
#include <cassert>
#include <map>
#include <set>
#include <iostream>
#include <algorithm>
#include <string>
#include <sstream>
#include <vector>
#include <queue>
#include <stdint.h>
#include <string.h>
#include <stdio.h>
#include <stdlib.h>
#include <limits.h>
#include <math.h>
#include <time.h>
using namespace std;
typedef uint8_t byte;
typedef int64_t illong;
typedef uint64_t ullong;
typedef uint32_t uint;
#define CLEAN(x) memset(x,0,sizeof(x))
#define TR(i,obj) for(__typeof(obj.begin()) i=obj.begin();i!=obj.end();++i)
const int tsize=1e5+3;
const ullong MOD=1e9+7;
struct Node{
int dst;
Node *next;
Node *set(int dst,Node *next){
this->dst=dst;
this->next=next;
return this;
}
};
int n,m;
Node *base[tsize];
Node nodes[tsize*2];
int size[tsize],lev[tsize];
bool visited[tsize];
int st[tsize],en[tsize];
int deg[tsize];
int nodec=0;
struct SA{
bool st;
int ind;
void set(bool st,int ind){
this->st=st;
this->ind=ind;
}
};
SA sas[tsize*2];
int sac=0;
vector<int> vec,stk,rs;
ullong fact[tsize],fact_inv[tsize];
ullong pow(ullong base,ullong ind){
ullong rs=1;
while(ind){
if(ind&1){
rs=(ullong)rs*base%MOD;
}
base=(ullong)base*base%MOD;
ind>>=1;
}
return rs;
}
void dfs(int ind){
visited[ind]=true;
size[ind]=1;
st[ind]=sac;
sas[sac++].set(true,ind);
for(Node *it=base[ind];it;it=it->next){
if(!visited[it->dst]){
lev[it->dst]=lev[ind]+1;
dfs(it->dst);
size[ind]+=size[it->dst];
}
}
en[ind]=sac;
sas[sac++].set(false,ind);
assert(sac<tsize*2);
}
int c(int n,int k){
return (ullong)fact[n]*fact_inv[n-k]%MOD*fact_inv[k]%MOD;
}
int cal(){
stk.clear();
rs.clear();
for(uint i=0;i<vec.size();++i){
int cur=vec[i];
bool st=sas[cur].st;
int ind=sas[cur].ind;
if(st){
if(stk.size()){
deg[ind]=1;
deg[stk.back()]++;
}else{
deg[ind]=0;
}
stk.push_back(ind);
}else{
assert(ind==stk.back());
assert(deg[ind]);
if(deg[ind]==1){
if(ind){
rs.push_back(size[ind]);
}else{
rs.push_back(n-size[sas[vec[1]].ind]);
}
}else if(deg[ind]>2){
return 1;
}
stk.pop_back();
}
}
assert(stk.empty());
assert(rs.size()==2);
assert(rs[0]+rs[1]<=n);
return c(rs[0]+rs[1],rs[0]);
}
int main() {
scanf("%d%d",&n,&m);
assert(n<tsize);
for(int i=1;i<n;++i){
int u,v;
scanf("%d%d",&u,&v);
--u;--v;
base[u]=nodes[nodec++].set(v,base[u]);
base[v]=nodes[nodec++].set(u,base[v]);
assert(nodec<tsize*2);
}
dfs(0);
fact[0]=1;
fact_inv[0]=1;
for(int i=1;i<=n;++i){
fact[i]=(ullong)fact[i-1]*i%MOD;
fact_inv[i]=pow(fact[i],MOD-2);
}
for(int i=0;i<m;++i){
int t;
scanf("%d",&t);
vec.clear();
vec.push_back(st[0]);
vec.push_back(en[0]);
for(int j=0;j<t;++j){
int u,v;
scanf("%d%d",&u,&v);
--u;--v;
assert(lev[u]!=lev[v]);
if(lev[u]>lev[v]){
vec.push_back(st[u]);
vec.push_back(en[u]);
}else{
vec.push_back(st[v]);
vec.push_back(en[v]);
}
}
assert(vec.size()==(t*2+2));
sort(vec.begin(),vec.end());
cout<<cal()<<endl;
}
return 0;
}