자료구조 공부/Tree 구조 알고리즘

[백준] 15681 트리와 쿼리 <Java> - 트리에서 DP

kdhoooon 2022. 1. 7. 13:30

문제


간선에 가중치와 방향성이 없는 임의의 루트 있는 트리가 주어졌을 때, 아래의 쿼리에 답해보도록 하자.

  • 정점 U를 루트로 하는 서브트리에 속한 정점의 수를 출력한다.

만약 이 문제를 해결하는 데에 어려움이 있다면, 하단의 힌트에 첨부한 문서를 참고하자.

 

 

 

풀이


풀이는 문제에서 설명하는대로 구현을 하였다 코드를 보면서 설명을 하면,

public static void makeTree(int currentNode, int parentNode) {
	
	for(int node : tree[currentNode]) {
		if(node != parentNode) {
			child[currentNode].add(node);
			parent[node] = currentNode;				
			makeTree(node, currentNode);
		}
	}
}

makeTree 메서드다.

현재 노드와 연결된 모든 노드를 돌면서 부모노드와 같은 노드가 아니라면, 자신을 부모로하는 자식노드로 추가한다.

이를 재귀적으로 다음 노드는 똑같은 방식으로 트리를 만든다.

 

위 방식이 가능한 이유는,

현재 자신과 연결 된 노드들은 부모와 자식뿐이기 때문이다.

이렇게 재귀적으로 하면 자식노드와 부모노드를 나눌 수 있다.

 

public static void countSubtreeNodes(int currentNode) {
	size[currentNode] = 1;
	for(int node : child[currentNode]) {
		countSubtreeNodes(node);
		size[currentNode] += size[node];
	}
}

countSubtreeNodes 메서드의 코드다.

현재 노드를 포함하여 subtree의 개수를 세는 것이기 때문에, size = 1 에서 더하게 된다.

이 것도 재귀적으로 자식의 subtree 개수를 모두 더하면 현재 자신의 subtree 개수가 됨을 이용하였다.

 

<전체코드>

import java.io.*;
import java.util.*;

public class Main {	

	static StringBuilder sb = new StringBuilder();
	static List<Integer>[] tree;
	static List<Integer>[] child;
	static int[] parent, size;
	static int n, r, q;

	public static void main(String[] args) throws IOException{
		BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
		
		StringTokenizer st = new StringTokenizer(br.readLine());
		n = Integer.parseInt(st.nextToken());
		r = Integer.parseInt(st.nextToken());
		q = Integer.parseInt(st.nextToken());
		
		tree = new List[n + 1];
		child = new List[n + 1];
		parent = new int[n + 1];
		size = new int[n + 1];
		for(int i = 0 ; i <= n ; i++) {
			tree[i] = new ArrayList<>();
			child[i] = new ArrayList<>();
		}
		
		
		for(int i = 0 ; i < n - 1; i++) {
			st = new StringTokenizer(br.readLine());
			
			int u = Integer.parseInt(st.nextToken());
			int v = Integer.parseInt(st.nextToken());
			
			tree[u].add(v);
			tree[v].add(u);
		}
		
		makeTree(r, -1);
		countSubtreeNodes(r);
		for(int i = 0 ; i < q ; i++) {
			sb.append(size[Integer.parseInt(br.readLine())] + "\n");
		}
		
		System.out.println(sb);
	}
	
	public static void countSubtreeNodes(int currentNode) {
	    size[currentNode] = 1;
	    for(int node : child[currentNode]) {
		     countSubtreeNodes(node);
		     size[currentNode] += size[node];
	    }
	}
	
	public static void makeTree(int currentNode, int parentNode) {
		
		for(int node : tree[currentNode]) {
			if(node != parentNode) {
				child[currentNode].add(node);
				parent[node] = currentNode;				
				makeTree(node, currentNode);
			}
		}
	}
}

 

하지만,

문제를 풀고 나서 보니까 makeTree 부분에서 자식노드와 부모노드의 분리 없이 countSubtreeNodes 메서드에서 subtree 개수를 return 해서 풀어도 되겠다 생각하여 두 메서드를 합쳐 하나의 메서드로 만들었다.

 

<2번째 방법 코드>

import java.io.*;
import java.util.*;

public class Main {	

	static StringBuilder sb = new StringBuilder();
	static List<Integer>[] tree;
	static List<Integer>[] child;
	static int[] parent, size;
	static int n, r, q;

	public static void main(String[] args) throws IOException{
		BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
		
		StringTokenizer st = new StringTokenizer(br.readLine());
		n = Integer.parseInt(st.nextToken());
		r = Integer.parseInt(st.nextToken());
		q = Integer.parseInt(st.nextToken());
		
		tree = new List[n + 1];
		size = new int[n + 1];
		for(int i = 0 ; i <= n ; i++) {
			tree[i] = new ArrayList<>();
		}
		
		
		for(int i = 0 ; i < n - 1; i++) {
			st = new StringTokenizer(br.readLine());
			
			int u = Integer.parseInt(st.nextToken());
			int v = Integer.parseInt(st.nextToken());
			
			tree[u].add(v);
			tree[v].add(u);
		}
		
		countSubtreeNodes(r, -1);
		for(int i = 0 ; i < q ; i++) {
			sb.append(size[Integer.parseInt(br.readLine())] + "\n");
		}
		
		System.out.println(sb);
	}
		
	public static int countSubtreeNodes(int currentNode, int parentNode) {
		int result  = 1;
		
		for(int node : tree[currentNode]) {
			if(node != parentNode) {				
				result += countSubtreeNodes(node, currentNode);;
			}
		}
		
		return size[currentNode] = result;
	}
}