알고리즘/Mo's 알고리즘

BOJ 백준 13518 트리와 쿼리 9

kyjdummy 2025. 5. 16. 23:08

 

문제 : https://www.acmicpc.net/problem/13518

 

[ 문제 요약 ]

  • 트리 간선 정보와, 각 노드당 가중치가 주어집니다.
  • 쿼리는 두 개의 노드로 주어지며 두 노드를 포함한 두 노드 사이에 있는 유일한 경로의 모든 노드들의 가중치 값의 종류의 개수를 출력합니다.

 

[ 테스트 케이스 설명 ]

8// 노드 수 2<=100,000
105 2 9 3 8 5 7 7 // 가중치 1,000,000
1 2	// 노드수 -1 개줄에 연결 간선 주어짐
1 3
1 4
3 5
3 6
3 7
4 8
2	// 쿼리 수 1<=100,000
2 5	// 두 노드가 주어짐
7 8
// 이하 답(두 노드 사이 서로다른 정점의 가중치의 개수)
4
4

 

 

[ 알고리즘 분류 ]

  • 트리
  • 오프라인 쿼리
  • 최소 공통 조상
  • 오일러 경로 테크닉
  • Mo’s

 

[ 문제 해설 ]

 

이 문제는 완전 탐색 시 시간 초과이기 때문에, 각 쿼리마다 구한 값을, 다음 값에도 활용할 수 있도록 해야 합니다. 즉, mo’s 알고리즘으로 탐색해야 합니다.

 

트리를 지나면서 경로에 대해 가중치 합, 곱 등을 빠르게 구하기 위해서는 트리를 세그먼트 트리에 대응되도록 분할해야 해서 HLD 알고리즘으로 트리를 분할합니다.

 

이와 비슷하게 트리를 탐색하는데 mo's 알고리즘을 이용하려면 트리를 1차원 배열에 평탄화 시켜야 합니다.

 

트리를 1차원 배열에 평탄화할 때는, 각 노드마다의 진입(in) 시점, 진출(out) 시점을 기록하고, 또한 [진입+진출]을 한꺼번에 1차원에 담을 배열이 필요합니다. 이를 구할 때 오일러 투어 테크닉과 비슷하게 구현합니다.

 

진입(in), 진출(out)을 기록할 배열의 크기는 노드의 수 N 만큼 선언하고, [진입+진출]을 한꺼번에 1차원에 담을 배열은 노드 수의 2배인 N * 2만큼 배열을 선언합니다.

 

그리고 mo’s 알고리즘으로 정렬하고 투 포인터로 탐색 시 투포인터는 [진입+진출]을 한꺼번에 담은 1차원 배열에서 할 것이기 때문에 주어지는 쿼리의 두 노드는 [진입+진출] 배열의 위치로 변경한 후 담아야 합니다.

 

근데 이렇게 트리를 1차원 배열로 평탄화한 것으로 어떻게 답을 구할 수 있을까요?

 

트리를 평탄화하게 되면 규칙이 생깁니다. 손으로 해보면서 따라와야 이해가 쉽습니다.

 

그림판으로 그려 양해 부탁드립니다

 

위와 같이 노드가 1번부터 7번까지 있다고 하면, DFS로 1번부터 차례로 방문한다고 할 때 방문 순서는 아래와 같다고 가정합니다.

 

1 - 2 - 3 - 4 - 5 - 7 - 6

 

그럼 그에 따른 진입, 진출 값을 트리에 적어보면 아래와 같습니다.

그림판으로 그려 양해 부탁드립니다.

 

각 노드의 왼쪽 위에 쓰여있는 게 진입 시점이고, 오른쪽에 쓰여있는 게 진출 시점입니다. mo’s 투포인터 탐색 시 바로 저 왼쪽 위, 오른쪽 위 값들로 탐색을 합니다. 저 왼쪽 위, 오른쪽 위 값들은 [진입+진출]을 1차원으로 평탄화 한 배열의 인덱스로 쓰이고, value는 해당 노드 번호들이 들어있게 됩니다.

 

[진입+진출 배열 저장 현황]

[1] : 1번 노드

[2] : 2번 노드

[3] : 2번 노드

[4] : 3번 노드

[5] : 4번 노드

[6] : 4번 노드

[7] : 5번 노드

[8] : 7번 노드

[9] : 7번 노드

[10] : 5번 노드

[11] : 6번 노드

[12] : 7번 노드

[13] : 3번 노드

[14] : 1번 노드

 

쿼리에서 만약 1번 노드부터 5번 노드까지 탐색한다 치면, 1번 노드의 진입 시간은 1이고, 5번 노드의 진입시간은 7입니다.

그럼 1부터 7까지 순서대로 [진입+진출] 배열에서 탐색합니다. 그러면 탐색하다가 두번 만나게 되는 노드가 있는데, 두번 만나는 노드들은 진입과 진출을 그 안에서 한 것이므로 없는 것과 마찬가지입니다. 한 번만 등장하는 노드들을 확인해 보면 그 노드들이 쿼리에서 주어진 두 노드 사이에 실제 존재하는 노드임을 알 수 있습니다. 이렇게 트리를 1차원으로 평탄화 시키면 그 배열로 두 노드 사이에 존재하는 노드가 뭔지 알 수 있습니다. 

 

쿼리의 두 노드의 LCA가 두 노드 중 하나라면 위와 같은 방법을 답을 구하면 됩니다. 방금 설명한 규칙은, 주어진 두 노드 중 하나가 서로의 LCA 일 때만 해당됩니다.

 

만약 4번 노드부터 6번 노드까지 탐색한다고 하면 어떻게 확인할까요?

 

4번 노드의 진입은 시간은 5이고, 6번 노드의 진입은 11입니다. 그럼 5부터 11까지 탐색할 텐데 그러면 제대로 된 결과가 나오지 않습니다.

 

일단 4번 노드의 진입 시간이 5인데, 진출 시간이 6이므로, 5부터 11까지 탐색하다 보면 자기 자신이 두 번 등장하여 없는 것이나 마찬가지가 됩니다. 또한 LCA 노드인 3번 노드의 진입번호 (4)는 숫자에 포함되지도 않습니다. 그래서 LCA가 따로 있는 쿼리 질의라면, 즉, 쿼리 질의상 주어진 두 노드의 LCA가 두 노드 중 하나가 아니라면, 진입 시간이 아니라 하나만 진출 시간으로 탐색을 시작하고, LCA 노드의 가중치는 따로 계산해 주어야 합니다.

 

그럼 4번 노드의 진출 시간인 6과, 6번 노드의 진입시간인 11을 탐색하면, 그 사이에 한 번만 등장하는 노드가 실제 쓰이는 노드가 되고, LCA를 별도로 확인해 주면 답이 됩니다.

 

정답 코드는 다음과 같은 흐름으로 되어 있습니다.

 

1 ) 노드의 가중 치 입력, 인접 리스트 생성

2 ) HLD를 위한 인접리스트 안에 자식노드들의 위치 재배열

3 ) 트리 1차원 평탄화 + 트리 분할

4 ) 쿼리를 입력받아 시작, 종료 노드를 평탄화된 트리 1차원 배열에 대응되도록 치환 및 LCA 값이 두 노드 중 하나인지 아닌지 저장

5 ) Mo’s 알고리즘 정렬 후 탐색 진행 후 결과 출력

 

풀이는 2가지입니다. 사실상 똑같지만, LCA를 HLD로 구하느냐, 아니면 dp로 구하느냐 차이만 있습니다.

 

보통 LCA만 구할 때는 dp를 많이 씁니다만, 개인적으로 HLD가 활용도가 높아 HLD를 선호하는 편이라 코드를 먼저 적습니다.

코드 지적 환영합니다.

 

[ HLD 정답 코드 ]

import java.io.BufferedReader;
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.StringTokenizer;
class Main{
	
	static final BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
	static StringTokenizer st;
	static int N, Q;
	static int sqrt;		// mo's알고리즘으로 정렬시 사용할 제곱근, N * 2의 제곱근이 된다.
	static int time;		// 오일러투어 탐색시 진입 진출시간을 표시할 변수
	static int weight[];		// 각노드의 가중치를 담을 배열
	static int weightCount[];	// 가중치가 등장한 횟수 (idx : 가중치 값) (value : 가중치 등장 횟수)
	static int nodeCount[];		// 노드가 몇번 등장했는지 카운팅 하기 위한 배열
	static int in[];		// 진입 시간을 담을 배열 (index : 노드번호) (value : 진입시간)
	static int out[];		// 진출 시간을 담을 배열 (index : 노드번호) (value : 진입시간)
	static int ett[];		// 진입+진출 시간을 1차원으로 평탄화 한것(index : 진입시간) (value : 노드번호)
	static int chainLevel[];	// LCA를 구할 때만 사용, HLD분할시 체인의 레벨 저장
	static int chainHeader[];	// LCA를 구할 때만 사용, HLD분할시 체인의 가장 첫번째 값
	static int chainParent[];	// LCA를 구할 때만 사용, HLD분할시 이전 체인으로 바로 점프할 수 있도록 이전 노드저장
	static int result[];		// 쿼리 질의 최종 결과를 담을 배열
	static List<Integer>[] adList;	// 트리 정보를 담을 인접 리스트
	static List<Query> query;	// 주어지는 쿼리 정보를 담을 리스트
	
	
	public static void main(String[] args)throws Exception{
		init();	// 해당 함수에서 배열들 초기화 및 간선의 가중치를 입력 받음 + 간선 정보도 입력 받아 인접리스트 생성
		setHLD(1, 0, new int[N + 1]);// HLD로 트리를 분할하기 위해 전처리 과정, 단순 adList의 위치만 바꾸는 용도
		dfs(1, 1);// 해당 함수에서 트리를 평탄화 시키며 동시에 HLD로 트리를 분할 한다.
		inputQuery();// 해당 함수에서 쿼리를 입력 받고, 입력된 노드들을 1차원 배열에 대응토록 변환 및 LCA구하고 mo's 정렬까지함
		solve();// 투포인터로 정답을 구한 후 최종 출력 까지 진행
	}
	static void init()throws Exception {
		N = Integer.parseInt(br.readLine());
		sqrt = (int)Math.sqrt((N * 2));
		weight = new int[N + 1];
		weightCount = new int[1_000_001];
		nodeCount = new int[N + 1];
		in = new int[N + 1];
		out = new int[N + 1];
		ett = new int[(N * 2) + 1];
		chainLevel = new int[N + 1];
		chainHeader = new int[N + 1];
		chainParent = new int[N + 1];
		adList = new ArrayList[N + 1];
		chainHeader[1] = 1;	// 체인 헤더 기본 값 세팅
		
		
		st = new StringTokenizer(br.readLine());
		for(int i=1; i<=N; i++)
		{
			weight[i] = Integer.parseInt(st.nextToken());
			adList[i] = new ArrayList<>();
		}
		for(int i=1; i<N; i++)
		{
			st = new StringTokenizer(br.readLine());
			int u = Integer.parseInt(st.nextToken());
			int v = Integer.parseInt(st.nextToken());
			adList[u].add(v);
			adList[v].add(u);
		}
	}
	static void setHLD(int nowNode, int parentNode, int[] size) {
		int heavyIdx = 0;
		int heavySize = 0;
		size[nowNode] = 1;
		for(int i=0; i<adList[nowNode].size(); i++)
		{
			int nextNode = adList[nowNode].get(i);
			if(nextNode == parentNode) // 이미 방문한 노드는 스킵
				continue;
			
			setHLD(nextNode, nowNode, size);
			
			size[nowNode] += size[nextNode]; // nowNode의 자식 노드 사이즈를 저장
			
			// 가장 무거운 자식 노드의 인덱스를 저장
			if(heavySize < size[nextNode])
			{
				heavySize = size[nextNode];
				heavyIdx = i;
			}
		}
		// 가장 무거운 노드를 리스트의 가장 앞으로 옮겨 추후 연산을 간편하게 한다.
		if(adList[nowNode].size() > 0)
			Collections.swap(adList[nowNode], 0, heavyIdx);
	}
	static void dfs(int nowNode, int level) {
		in[nowNode] = ++time;// 진입 시간 마킹
		ett[time] = nowNode;// 1차원화
		chainLevel[nowNode] = level;	// HLD를 위한 레벨 입력
		
		for(int i=0; i<adList[nowNode].size(); i++)
		{
			int nextNode = adList[nowNode].get(i);
			
			if(chainHeader[nextNode] != 0)// chainHeader가 있다면 이미 방문한 것
				continue;
			
			if(i == 0)// 무거운 자식일 때(setHLD 함수에서 무거운 자식노드를 앞으로 옮겼으므로)
			{
				chainHeader[nextNode] = chainHeader[nowNode];// 무거우면 헤더값 유지
				chainParent[nextNode] = chainParent[nowNode];// 무거우면 점프할 노드 유지
				dfs(nextNode, level);// 무거운 자식 탐색시 level 유지
				continue;
			}
			
			// 가벼운 자식 노드들은 새로운 체인을 시작하므로, header 값과 parent, level모두 변경
			chainHeader[nextNode] = nextNode;
			chainParent[nextNode] = nowNode;
			dfs(nextNode, level + 1);
		}
		
		out[nowNode] = ++time;// 진출 시간 마킹
		ett[time] = nowNode;// 1차원화
	}
	static void inputQuery()throws Exception {
		Q = Integer.parseInt(br.readLine());
		result = new int[Q + 1];
		query = new ArrayList<>();
		
		for(int i=1; i<=Q; i++)
		{
			st = new StringTokenizer(br.readLine());
			int s = Integer.parseInt(st.nextToken());
			int e = Integer.parseInt(st.nextToken());
			int lca = getLCA(s,e);
			
			if(in[s] > in[e])	// s가 더 자식이라면 s를 부모로 만듬
			{
				int tmp = s;
				s = e;
				e = tmp;
			}
			
			if(lca == s)// 두 노드 중 하나가 lca라면
			{
				query.add(new Query(in[s], in[e], i, 0, in[s] / sqrt));
				continue;
			}
			// 두 노드 모두 lca가 아니라면
			query.add(new Query(out[s], in[e], i, lca, in[s] / sqrt));
		}
		Collections.sort(query);
	}
	static void solve() {
		int s = 1;
		int e = 0;
		int cnt = 0;
		
		for(Query q : query)
		{
			while(e < q.e) cnt += plus(ett[++e]);
			while(q.s < s) cnt += plus(ett[--s]);
			while(q.e < e) cnt += minus(ett[e--]);
			while(s < q.s) cnt += minus(ett[s++]);
			
			if(q.lca != 0) cnt += plus(q.lca);	// lca값은 노드값 그대로 저장했으므로 그냥 넘겨줌
			
			result[q.idx] = cnt;
			
			if(q.lca != 0) cnt += minus(q.lca); // lca값은 노드값 그대로 저장했으므로 그냥 넘겨줌
		}
		
		// 최종적인 값 출력
		StringBuilder sb = new StringBuilder();
		
		for(int i=1; i<=Q; i++)
			sb.append(result[i]).append('\n');
		
		System.out.print(sb);
	}
	static int minus(int node) {
		int w = weight[node];
		int cnt = 0;
		--nodeCount[node];
		
		if(nodeCount[node] == 1)
		{
			if(++weightCount[w] == 1)
				++cnt;
		}
		else if(nodeCount[node] == 0)
		{
			if(--weightCount[w] == 0)
				--cnt;
		}
		return cnt;
	}
	static int plus(int node) {
		int w = weight[node];
		
		int cnt = 0;
		
		++nodeCount[node];
		
		if(nodeCount[node] == 1)
		{
			if(++weightCount[w] == 1)
				++cnt;
		}
		else if(nodeCount[node] == 2) {
			if(--weightCount[w] == 0)
				--cnt;
		}
		return cnt;
	}
	static int getLCA(int node1, int node2) {
		if(in[node1] > in[node2]) // node1이 더 자식이라면 부모로 올림
		{
			int tmp = node1;
			node1 = node2;
			node2 = tmp;
		}
		// 같은 체인이 될 때 까지 높은 레벨에 있는 체인을 낮은 레벨로 올림
		while(chainHeader[node1] != chainHeader[node2])
		{
			// node1의 레벨이 더 크다면 더 작게 올려야하기 때문에 node1을 위로 올림
			if(chainLevel[node1] > chainLevel[node2])
			{
				node1 = chainParent[node1];
				continue;
			}
			// node2가 레벨이 더크면 node2를 위로 올려 level을 작게 만듬
			node2 = chainParent[node2];
		}
		// 진입 시간이 더 작은(더 상위노드) 노드 번호를 반환하면 그게 LCA가 됨
		return in[node1] > in[node2] ? node2 : node1;
	}
	static class Query implements Comparable<Query>{
		int s, e, idx, lca, fac;
		Query(int s, int e, int idx, int lca, int fac){
			this.s=s;
			this.e=e;
			this.idx=idx;
			this.lca=lca;
			this.fac=fac;
		}
		@Override
		public int compareTo(Query o) {
			if(fac != o.fac)
				return fac - o.fac;
			// 구간이 짝수면 e기준 오름차순 정렬
			if((fac&1) == 0)
				return e - o.e;
			// 구간이 홀수면 e기준 내림차순 정렬
			return o.e - e;
		}
	}
}

 

[ DP 정답 코드 ]

import java.io.BufferedReader;
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.Collections;
import java.util.StringTokenizer;

class Main{
	
	static final BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
	static StringTokenizer st;
	static int N, Q;
	static int idx;		// ETT진행시 순차 증가할 인덱스
	static int sqrt;	// Mo's알고리즘을 위한 변수
	static int in[];	// ETT진행시 진입 인덱스	 (idx : 노드번호) (value : all배열의 인덱스)
	static int out[];	// ETT진행시 나가는 인덱스(idx : 노드번호) (value : all배열의 인덱스)
	static int all[];	// 트리를 1차원으로 표현, ETT in과 out을 모두 담음( idx : ETT변환 인덱스) (value : 노드번호)
	static int arr[];	// 각 노드마다의 가중치를 담음
	static int depth[];	// LCA구할 때 사용
	static int dp[][];	// LCA를 빠르게 구하기 위한 dp배열
	static int cnt[];	// (idx : 가중치) (value : 가중치의 당장횟수)
	static int nodeCnt[]; // 해당 노드의 등장 횟수
	static int ans;		// 최종적으로 구한 서로 다른 가중치의 개수
	static int result[];
	static ArrayList<Integer> adList[];// 인접리스트
	static ArrayList<Query> query;
	
	public static void main(String[] args)throws Exception{
		init();// 기본 입력 받기
		dfs(1, 1);// 트리 1차원화 + LCA를 위한 기본 세팅
		buildLcaTable();// LCA를 위한 dp테이블 세팅
		inputQuery();// 쿼리 입력 및 1차원 배열 인덱스로 치환
		solve();// 투포인터로 해결
	}
	static void init()throws Exception{
		N = Integer.parseInt(br.readLine());
		sqrt = (int)Math.sqrt(N<<1);
		in = new int[N + 1];
		out = new int[N + 1];
		all = new int[(N << 1) + 1];
		arr = new int[N + 1];
		adList = new ArrayList[N + 1];
		depth = new int[N + 1];
		dp = new int[N + 1][18];
		query = new ArrayList<>();
		cnt = new int[1_000_001];
		nodeCnt = new int[N + 1];
		
		st = new StringTokenizer(br.readLine());
		for(int i=1; i<=N; i++)
		{
			arr[i] = Integer.parseInt(st.nextToken());// 각 노드당 가중치를 입력받음
			adList[i] = new ArrayList<>();// 인접리스트도 같이 초기화
		}
		
		for(int i=1; i<N; i++)
		{
			st = new StringTokenizer(br.readLine());
			int a = Integer.parseInt(st.nextToken());
			int b = Integer.parseInt(st.nextToken());
			adList[a].add(b);
			adList[b].add(a);
		}
	}
	static void dfs(int now, int dep) {
		depth[now] = dep;	// LCA를 위한 깊이 저장
		in[now] = ++idx;	// 트리의 진입 정보를 담음
		all[idx] = now;	// 트리 순회를 1차원 배열로 저장
		for(int next : adList[now])
		{
			if(depth[next] != 0)// 방문한 적이 있다면 스킵
				continue;
			
			dp[next][0] = now;	// LCA를 위해 부모노드 저장
			
			dfs(next, dep + 1);
		}
		out[now] = ++idx;
		all[idx] = now;		
	}
	static void buildLcaTable()
	{
		for(int j=1; j<18; j++)
			for(int i=1; i<=N; i++)
				dp[i][j] = dp[dp[i][j-1]][j-1];
	}
	static void inputQuery()throws Exception {
		Q = Integer.parseInt(br.readLine());// 쿼리 수
		result = new int[Q + 1];
		for(int i=1; i<=Q; i++)
		{
			st = new StringTokenizer(br.readLine());
			int a = Integer.parseInt(st.nextToken());
			int b = Integer.parseInt(st.nextToken());
			int lca = getLCA(a,b);
			
			if(in[a] > in[b])
			{
				int tmp = a;
				a = b;
				b = tmp;
			}
			
			if(lca == a)
			{
				query.add(new Query(in[a], in[b], 0, in[a] / sqrt, i));
				continue;
			}
			
			query.add(new Query(out[a], in[b], lca, out[a] / sqrt, i));
		}
		
		Collections.sort(query);
	}
	static void solve() {
		int L = 1;
		int R = 0;
		for(Query q : query)
		{
			while(R < q.right) plus(all[++R]);
			while(q.left < L) plus(all[--L]);
			while(q.right < R) minus(all[R--]);
			while(L < q.left) minus(all[L++]);
			
			if(q.lca != 0) plus(q.lca);
			
			result[q.idx] = ans; 
			
			if(q.lca != 0) minus(q.lca);
		}
		
		StringBuilder sb = new StringBuilder();
		
		for(int i=1; i<=Q; i++)
			sb.append(result[i]).append('\n');
		
		System.out.print(sb);
	}
	static void plus(int node) {
		++nodeCnt[node];
		if(nodeCnt[node] == 1) {
			if(++cnt[arr[node]] == 1)
				++ans;
		}
		else if(--cnt[arr[node]] == 0)
				--ans;
	}
	static void minus(int node) {
		--nodeCnt[node];
		if(nodeCnt[node] == 0) {
			if(--cnt[arr[node]] == 0)
				--ans;
		}
		else if(++cnt[arr[node]] == 1)
			++ans;
	}

    static int getLCA(int u, int v) {
        if (depth[u] < depth[v]) {
            int temp = u;
            u = v;
            v = temp;
        }
        int diff = depth[u] - depth[v];
        for (int i = 0; diff > 0; i++) {
            if ((diff & 1) == 1) u = dp[u][i];
            diff >>= 1;
        }
        if (u != v) {
            for (int i = 17; i >= 0; i--) {
                if (dp[u][i] != dp[v][i]) {
                    u = dp[u][i];
                    v = dp[v][i];
                }
            }
            u = dp[u][0];
        }
        return u;
    }
	static class Query implements Comparable<Query>{
		int left, right, lca, fac, idx;
		Query(int left, int right, int lca, int fac, int idx){
			this.left = left;
			this.right = right;
			this.lca = lca;
			this.fac = fac;
			this.idx = idx;
		}
		@Override
		public int compareTo(Query o) {
			if(fac != o.fac)
				return fac - o.fac;
			
			if((fac&1) == 0)
				return right - o.right;
			
			return o.right - right;
		}
	}
}