题目:http://www.lydsy.com/JudgeOnline/problem.php?id=2599
这是一道裸的点分治题。由于 k 比较小，其实还有别的做法（例如用以距离为下标的数组代替平衡树，记录每个距离对应的最少边数）；我偷懒直接用 SBT 维护，复杂度 O(n log^2 n)，也能通过。
代码:
#include <cstdio>
#include <algorithm>
#include <cstring>
using namespace std ;
// Adds an undirected edge as two directed adjacency-list entries (s->t, t->s).
#define AddEdge( s , t , d ) Add( s , t , d ) , Add( t , s , d )
// Capacity of every per-vertex array and of the SBT node pool.
// Vertices are numbered 1..n (main shifts input ids by +1); assumes n < MAXN
// — TODO confirm against the problem limits.
#define MAXN 200100
// Field accessors for the SBT node pool. They expand to the member arrays
// left / right / key / size and are intended for use inside struct SBT.
#define L( t ) left[ t ]
#define R( t ) right[ t ]
#define K( t ) key[ t ]
#define S( t ) size[ t ]
// INT_MAX. Returned by SBT::query when a key is absent; also the initial
// value of ans, meaning "no path found yet".
#define inf 0x7fffffff
// Entry stored in the SBT: k is the ordering key and v is a payload.
// Every comparison operator looks at k only; v never affects ordering.
struct node {
	int k , v ;
	bool operator < ( const node &other ) const { return other.k > k ; }
	bool operator == ( const node &other ) const { return ! ( k < other.k ) && ! ( other.k < k ) ; }
	bool operator > ( const node &other ) const { return other.k < k ; }
};
// Builds a node with key _k and payload _v.
node make( int _k , int _v ) {
	node res = { _k , _v } ;
	return res ;
}
struct SBT {
int left[ MAXN ] , right[ MAXN ] , size[ MAXN ] , V , roof ;
node key[ MAXN ] ;
void Init( ) {
V = roof = 0 ;
L( 0 ) = R( 0 ) = S( 0 ) = 0 ;
}
void update( int t ) {
S( t ) = S( L( t ) ) + S( R( t ) ) + 1 ;
}
void Left( int &t ) {
int k = R( t ) ;
R( t ) = L( k ) ; update( t ) ;
L( k ) = t ; update( k ) ;
t = k ;
}
void Right( int &t ) {
int k = L( t ) ;
L( t ) = R( k ) ; update( t ) ;
R( k ) = t ; update( k ) ;
t = k ;
}
void maintain( int &t ) {
if ( S( L( L( t ) ) ) > S( R( t ) ) ) {
Right( t ) ;
maintain( R( t ) ) ; maintain( t ) ;
return ;
}
if ( S( R( L( t ) ) ) > S( R( t ) ) ) {
Left( L( t ) ) ; Right( t ) ;
maintain( L( t ) ) , maintain( R( t ) ) ; maintain( t ) ;
return ;
}
if ( S( R( R( t ) ) ) > S( L( t ) ) ) {
Left( t ) ;
maintain( L( t ) ) ; maintain( t ) ;
return ;
}
if ( S( L( R( t ) ) ) > S( L( t ) ) ) {
Right( R( t ) ) ; Left( t ) ;
maintain( L( t ) ) , maintain( R( t ) ) ; maintain( t ) ;
return ;
}
}
int search( node k , int t ) {
if ( ! t ) return 0 ;
if ( k == K( t ) ) return t ;
return search( k , k < K( t ) ? L( t ) : R( t ) ) ;
}
void Insert( node k , int &t ) {
if ( ! t ) {
t = ++ V ;
L( t ) = R( t ) = 0 , S( t ) = 1 , K( t ) = k ;
return ;
}
Insert( k , k < K( t ) ? L( t ) : R( t ) ) ;
update( t ) ; maintain( t ) ;
}
int query( int k ) {
int t = search( make( k , 0 ) , roof ) ;
if ( ! t ) return inf ;
return K( t ).v ;
}
void Push( node k ) {
int t = search( k , roof ) ;
if ( ! t ) Insert( k , roof ) ; else {
if ( k.v < K( t ).v ) K( t ).v = k.v ;
}
}
} sbt ;
// Adjacency-list entry. Each vertex owns a singly linked list of its
// outgoing edges.
struct edge {
	edge *next ;	// next edge leaving the same vertex (null ends the list)
	int t ;			// target vertex
	int d ;			// edge weight
} ;
// head[ v ]: first edge leaving v, or null when v has no edges.
edge *head[ MAXN ] ;
void Add( int s , int t , int d ) {
edge *p = new( edge ) ;
p -> t = t , p -> d = d , p -> next = head[ s ] ;
head[ s ] = p ;
}
// n: vertex count; len: required total path weight;
// ans: fewest edges on a path of weight len found so far (inf = none yet).
int n ;
int len ;
int ans = inf ;
// Per-vertex scratch arrays, refilled for every component processed by Solve:
//   size - subtree size inside the current component (dfs0)
//   h    - number of edges from the current centroid (dfs2)
//   dep  - weighted distance from the current centroid (dfs2)
int size[ MAXN ] ;
int h[ MAXN ] ;
int dep[ MAXN ] ;
// rt: root of the component being decomposed; roof: its centroid (dfs1).
int rt ;
int roof ;
// b[ 1..bn ]: vertices collected by dfs3 from one child subtree.
int b[ MAXN ] ;
int bn ;
// f[ v ] stays true until v has been removed as a centroid.
bool f[ MAXN ] ;
void dfs0( int v , int u ) {
size[ v ] = 1 ;
for ( edge *p = head[ v ] ; p ; p = p -> next ) if ( f[ p -> t ] && p -> t != u ) {
dfs0( p -> t , v ) ;
size[ v ] += size[ p -> t ] ;
}
}
void dfs1( int v , int u ) {
if ( roof ) return ;
bool flag = true ;
if ( size[ rt ] - size[ v ] > size[ rt ] / 2 ) flag = false ;
for ( edge *p = head[ v ] ; p ; p = p -> next ) if ( f[ p -> t ] && p -> t != u ) {
dfs1( p -> t , v ) ;
if ( size[ p -> t ] > size[ rt ] / 2 ) flag = false ;
}
if ( flag ) roof = v ;
}
void dfs2( int v , int u ) {
for ( edge *p = head[ v ] ; p ; p = p -> next ) if ( f[ p -> t ] && p -> t != u ) {
h[ p -> t ] = h[ v ] + 1 , dep[ p -> t ] = dep[ v ] + p -> d ;
dfs2( p -> t , v ) ;
}
}
void dfs3( int v , int u ) {
b[ ++ bn ] = v ;
for ( edge *p = head[ v ] ; p ; p = p -> next ) if ( f[ p -> t ] && p -> t != u ) {
dfs3( p -> t , v ) ;
}
}
void Solve( int v ) {
dfs0( v , 0 ) ;
roof = 0 , rt = v ;
dfs1( v , 0 ) ;
h[ roof ] = dep[ roof ] = 0 ;
dfs2( roof , 0 ) ;
sbt.Init( ) ;
sbt.Push( make( 0 , 0 ) ) ;
for ( edge *p = head[ roof ] ; p ; p = p -> next ) if ( f[ p -> t ] ) {
bn = 0 ;
dfs3( p -> t , roof ) ;
for ( int i = 0 ; i ++ < bn ; ) {
int temp = sbt.query( len - dep[ b[ i ] ] ) ;
if ( temp < inf ) ans = min( ans , h[ b[ i ] ] + temp ) ;
}
for ( int i = 0 ; i ++ < bn ; ) {
sbt.Push( make( dep[ b[ i ] ] , h[ b[ i ] ] ) ) ;
}
}
f[ roof ] = false ;
for ( edge *p = head[ roof ] ; p ; p = p -> next ) if ( f[ p -> t ] ) Solve( p -> t ) ;
}
// Reads n and len, then n - 1 weighted edges whose endpoints are given
// 0-indexed. Endpoints are shifted to 1..n, so vertex 0 never occurs and the
// DFS helpers can use 0 as "no parent". Prints the fewest edges on a path of
// total weight len, or -1 if no such path exists.
int main( ) {
	scanf( "%d%d" , &n , &len ) ;
	memset( head , 0 , sizeof( head ) ) ;
	for ( int read = 0 ; read + 1 < n ; ++ read ) {
		int s , t , d ;
		scanf( "%d%d%d" , &s , &t , &d ) ;
		AddEdge( s + 1 , t + 1 , d ) ;
	}
	memset( f , true , sizeof( f ) ) ;
	Solve( 1 ) ;
	if ( ans < inf ) printf( "%d\n" , ans ) ;
	else printf( "%d\n" , - 1 ) ;
	return 0 ;
}