Vamos encontrar o número de sequências ruins: Sequências de tamanho que não passam por nenhuma aresta preta. Então, a resposta será o número total de sequências de tamanho menos a quantidade de sequências ruins.
Ideia Principal: Para contar tais sequências, podemos remover todas as arestas pretas da árvore!
Após removê-las, a árvore se separa em diversas componentes conexas. A resposta de cada componente não depende de outras componentes, já que não podemos escolher um caminho que saia de uma componente e vai para outra, pois ele, necessariamente, passa por uma aresta preta. Para cada uma dessas componentes, devemos contar qual o número de vértices que estão presentes nela. Para isso, basta rodar uma DFS para cada vértice ainda não visitado por ela e incrementar uma variável em para cada novo nó desta componente. Então, se é o número final de vértices de uma componente, a resposta desta componente, ou seja, o número se sequências ruins, é . Seja a soma de todas as respostas de todas as componentes, módulo .
Concluindo, a resposta final é - .
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <bits/stdc++.h> | |
using namespace std; | |
const int MAXN = 1e5 + 100; | |
const long long mod = 1e9+7; | |
vector<int> grafo[MAXN]; | |
bool mark[MAXN]; | |
long long p; | |
int dfs(int u){ | |
mark[u]=1; | |
for(int i=0; i<grafo[u].size(); i++){ | |
int v = grafo[u][i]; | |
if(!mark[v]){ | |
p++; | |
dfs(v); | |
} | |
} | |
} | |
long long fast_exponentiation(long long b, long long exp){ | |
long long ans = 1; | |
for(; exp > 0; exp /= 2){ | |
if(exp & 1) ans = ans * b % mod; | |
b = b * b % mod; | |
} | |
return ans; | |
} | |
int main(){ | |
ios_base::sync_with_stdio(false);cin.tie(NULL); | |
long long n, k; | |
cin >> n >> k; | |
for(int i=1; i<n; i++){ | |
int x, y, z; | |
cin >> x >> y >> z; | |
if(z==0){ | |
grafo[x].push_back(y); | |
grafo[y].push_back(x); | |
} | |
} | |
long long ans = fast_exponentiation(n, k); | |
for(int i=1; i<=n; i++){ | |
if(!mark[i]){ | |
p=1; | |
dfs(i); | |
ans = (mod+ans-fast_exponentiation(p, k))%mod; | |
} | |
} | |
cout << (mod+ans)%mod << endl; | |
return 0; | |
} |