알고리즘 공부/최단거리 알고리즘

백준 1761 (정점들의 거리) - LCA 풀이

kdhoooon 2021. 3. 29. 17:02

문제


N(2 ≤ N ≤ 40,000)개의 정점으로 이루어진 트리가 주어지고 M(1 ≤ M ≤ 10,000)개의 두 노드 쌍을 입력받을 때 두 노드 사이의 거리를 출력하라.

 

 

 

 

 

풀이


선형 시간의 LCA 알고리즘을 사용하였다.

 

해당 LCA까지의 간선 비용을 모두 합한 것이 두 정점 사이의 거리가 된다.

 

배열은 세개를 선언하였다

parents[]  : 해당 정점의 바로 위 조상을 저장 하는 배열

 

parent_len[]  : 해당 정점의 바로 위 조상과의 거리를 저장하는 배열

 

depth[]  : 1번 정점을 제일 큰 조상을 두고 그로 부터의 깊이를 저장하는 배열

 

 

거리를 구하는 코드는 아래와 같다.

public static int find_Dist(int a, int b) {
   	if(depth[a] < depth[b]) {
  		int temp = a;
   		a = b;
   		b = temp;
   	}
   	
  	int answer = 0;
   	
   	while(depth[a] > depth[b]) {
   		answer += parent_len[a];
  		a = parents[a];
   	}
   	
   	while(a != b) {
  		answer += parent_len[a] + parent_len[b];
   		a = parents[a];
   		b = parents[b];
   	}
    	
   	return answer;
    	
}

 

if(depth[a] < depth[b]) 는 a 가 b 보다 항상 깊이가 작게 두고 문제를 풀기 위함이다.

 

while(depth[a] > depth[b]) 반복문을 통해 깊이를 같게 만들어 준다.( 이 때의 거리도 더해 주어야 한다)

 

while( a != b) 반복문을 통해 서로 같은 조상이 나올 때까지 거리를 더해준다.

<전체 코드>

import java.io.*;
import java.util.*;


public class Main {
	
	public static class node{
		
		public int num, dist;
		
		public node(int num, int dist) {
			this.num = num;
			this.dist = dist;
		}
	}

	static StringBuilder sb = new StringBuilder();
	static int[] parent_len;
	static int[] depth;
	static int[] parents;
	static List<node>[] tree;
	
	
    public static void main(String[] args) throws IOException {
        BufferedReader bufferedReader = new BufferedReader(new InputStreamReader(System.in));
        
        int n = Integer.parseInt(bufferedReader.readLine());
        
        tree = new List[n + 1];
        
        for(int i = 0 ; i <= n ; i++) {
        	tree[i] = new LinkedList<node>();
        }
        
        for(int i = 0 ; i < n - 1; i++) {
        	StringTokenizer st = new StringTokenizer(bufferedReader.readLine());
        	
        	int a = Integer.parseInt(st.nextToken());
        	int b = Integer.parseInt(st.nextToken());
        	int dist = Integer.parseInt(st.nextToken());
        	
        	tree[a].add(new node(b, dist));
        	tree[b].add(new node(a, dist));
        }
        
        depth = new int[n + 1];
        parent_len = new int[n + 1];
        parents = new int[n + 1];
        
        set_Tree(1, 1, 1, 0);
       
        int m = Integer.parseInt(bufferedReader.readLine());
        
        for(int i = 0 ; i < m ; i++) {
        	StringTokenizer st = new StringTokenizer(bufferedReader.readLine());
        	
        	int a = Integer.parseInt(st.nextToken());
        	int b = Integer.parseInt(st.nextToken());
        	
        	sb.append(find_Dist(a, b) + "\n");
        }
        
        System.out.println(sb);
    }
    
    
    public static void set_Tree(int node, int pnode, int level, int len) {
    	parents[node] = pnode;
    	parent_len[node] = len;
    	depth[node] = level;
    	
    	for(node next : tree[node]) {
    		int childnode = next.num;
    		int child_len = next.dist;
    		
    		if(childnode == pnode) 
    			continue;
    		
    		set_Tree(childnode, node, level + 1, child_len);
    	}
    }
    
    public static int find_Dist(int a, int b) {
    	if(depth[a] < depth[b]) {
    		int temp = a;
    		a = b;
    		b = temp;
    	}
    	
    	int answer = 0;
    	
    	while(depth[a] > depth[b]) {
    		answer += parent_len[a];
    		a = parents[a];
    	}
    	
    	while(a != b) {
    		answer += parent_len[a] + parent_len[b];
    		a = parents[a];
    		b = parents[b];
    	}
    	
    	return answer;
    	
    }
}