알고리즘/PST 세그먼트트리

Persistent Segment Tree(PST) - 영속성 세그먼트 트리

kyjdummy 2025. 4. 6. 16:07

[ 글을 쓰게 된 계기 ]

- PST(Persistent Segment Tree) 알고리즘을 공부하면서 느낀 점은, 개념에 대한 대략적인 설명은 많지만, 원리나 코드 적용 방법, 구현 과정에서의 인사이트를 초보자가 쉽게 이해할 수 있도록 정리한 자료는 거의 없다는 것이었습니다. 그래서 제가 문제를 풀며 정리한 내용을, 같은 고민을 하는 분들과 공유하고자 이 글을 작성하게 되었습니다. 이 글은 기본적인 세그먼트 트리 구현 방법을 알고 있는 분을 대상으로 작성되었습니다. 세그먼트 트리의 작동 원리를 모른다면, 본 내용을 이해하기는 조금 어려울 수 있습니다.

 

[ 영속성(Persistence)이란? ]

- 영속성이란, 이전 상태를 보존하면서 새로운 상태를 만들어내는 것을 의미합니다.

- 일반적인 세그먼트 트리처럼, 기존의 값을 덮어쓰지 않고, 변경 이력을 '버전별'로 관리합니다.

 

 

[ 일반적인 세그먼트 트리와의 차이 ]

구분 일반 세그먼트 트리 PST(영속성) 세그먼트 트리
업데이트 덮어쓰기 버전별로 새로운 트리 생성
과거 데이터 조회 불가능 가능
메모리 효율적 업데이트마다 노드가 늘어남

 

 

[ Persistent Segment Tree의 시각적 자료 ]

- 회색으로 칠해진 원래 트리에서, 오른쪽 가장 아래 값을 수정한다고 했을 때, 수정되는 파랑, 빨강 노드들만 새로 생성하고, 나머지는 주솟값을 연동하는 방식으로 구현합니다.

 

 

 

[ 구현하는 방법 ]

영속성 세그먼트 트리를 구현하기 위해 알아야 할 몇 가지가 있습니다. 

먼저 일반 세그먼트 트리는 아래 코드와 같이 1차원 배열로 간단하게 만들 수 있습니다.

 

int tree[] = new int[H];

 

 

 

하지만 영속성 세그먼트 트리를 구현할 때는 위와 같이 1차원 배열로 선언하는 것은 좋지 않습니다.(노드 업데이트가 적다면 상관없음) 영속성 세그먼트 트리는 값을 업데이트할 때마다 노드가 지속적으로 늘어나기 때문에, 노드를 많이 생성할 경우 언젠가 배열을 초과할 것입니다. 그래서 안전하게 List로 세그먼트 트리를 표현합니다.

 

그리고 필요한 노드만 생성하고, 나머지 노드들은 기존 노드의 주솟값(정확히는 list에서의 노드의 위치)을 갖고 있어야 하기 때문에 별도의 Node 자료구조를 생성합니다. 아래 코드는 Node 객체와 노드를 담을 리스트를 선언한 것입니다. 아래 코드에서 sum은 누적합 세그먼트 트리를 구현한다고 가정했을 때 누적합을 담을 변수입니다.

 

class Node{
	int left, right, sum;
	Node(int left, int right, int sum){
		this.left = left;
		this.right = right;
		this.sum = sum;
	}
}
List<Node> nodes = new ArrayList<>();

 

 

또한, 위 그림에서 보듯이, 오른쪽 아래 값을 수정할 때마다 루트 노드가 계속해서 추가로 늘어나며 생성됨을 알 수 있습니다. 이렇게 업데이트가 될 때마다 root 노드가 계속 생성되며 늘어나기 때문에 이 늘어나는 root 노드를 계속 저장해 줄 배열이나 리스트를 따로 두어야 합니다. 이렇게 저장해 놓으면, root 노드마다 어떻게 업데이트 되었는지 알 수 있는 것입니다. 이렇게 관리하기 때문에 버전별로 관리한다고 표현합니다. root 노드의 값을 저장할 때 저는 int형 배열로 생성했습니다. 이 배열안에 root 노드의 값이 들어가있습니다. 이 root 노드의 값은, 세그먼트 트리를 표현하는 list에서의 root노드 위치입니다.

 

 

int[] roots = new int[업데이트를 하는 횟수 + 1];
// +1을 하는 이유는, 최초 세그먼트트리 init시 루트노드를 담아야 하기 때문

 

 

위와 같이 기본적인 개념을 알고, 누적합을 구하는 세그먼트 트리를 만들되, 업데이트 시점마다 버전(새 루트 노드)을 저장하도록 하는 영속성 세그먼트 트리 코드를 보여드리겠습니다.

 

코드의 전체 흐름은, 최초 세그먼트 트리를 만들고(init), 값을 2번 업데이트하고, 각 버전 마다의 누적합을 출력하는 것입니다. 배열의 길이는 5이고, 업데이트는 2번 한다고 가정했습니다.

 

import java.util.ArrayList;
import java.util.List;

class Node{
	int left, right, sum;
	Node(int left, int right, int sum){
		this.left = left;
		this.right = right;
		this.sum = sum;
	}
}

class Main{
	
	static final int LEN = 5;
	static List<Node>	nodes = new ArrayList<>();
	static int[] roots = new int[3];
	
	public static void main(String[] args)throws Exception{
		
		roots[0] = init(1, LEN);	// 최초 세그먼트 트리 생성
		
		roots[1] = update(roots[0], 1, LEN, 5, 100);// 최초 root값 = root[0]
		
		roots[2] = update(roots[1], 1, LEN, 1, 100);// 첫번 째 갱신 후 root 값 = root[1]
		// roots[2]에는 두번 째 갱신 후 루트의 위치가 들어가 있습니다.
        
		int version1 = sum(roots[1], 1, LEN, 1, LEN);// 첫번 째 갱신 후의 sum
		
		int version2 = sum(roots[2], 1, LEN, 1, LEN);// 두번 째 갱신 후의 sum
		
		System.out.println(version1);// 출력 결과 : 100
		System.out.println(version2);// 출력 결과 : 200
	}
	public static int sum(int nowNode, int s, int e, int left, int right) {
		if(e < left || right < s)// 유효하지 않은 범위일 때 0 반환
			return 0;
		
		Node now = nodes.get(nowNode);
		
		if(left<=s && e<= right)
			return now.sum;	// 유효한 범위일 때 값 반환
		
		int mid			= (s + e) >> 1;
		int leftNode	= now.left;
		int rightNode	= now.right;
		
		return sum(leftNode, s, mid, left, right)
				+ sum(rightNode, mid + 1, e, left, right);
	}
	public static int update(int nowNode, int s, int e, int targetIdx, int value) {
		Node now = nodes.get(nowNode);
		if(s == e)
		{
			nodes.add(new Node(-1, -1, now.sum + value));// 기존 노드의 sum에서 value를 추가한 노드를 새롭게 생성
		}
		else
		{
			int mid		= (s + e) >> 1;
			int left	= now.left;
			int right	= now.right;
			// 업데이트할 위치(targetIdx)가 mid보다 왼쪽에 있다면 왼쪽으로 내려감
			if(targetIdx <= mid)
			{
				left = update(now.left, s, mid, targetIdx, value);
			}
			// 업데이트할 위치(targetIdx)가 mid보다 오른쪽에 있다면 오른쪽으로 내려감
			else
			{
				right = update(now.right, mid + 1, e, targetIdx, value);
			}
			// 새 노드 생성
			nodes.add(new Node(left, right, now.sum + value));
		}
		// 새롭게 만든 노드의 위치 반환
		return nodes.size() - 1;
	}
	public static int init(int s, int e) {
		if(s == e)// 루트 노드일 경우 left, right를 -1로 하여 유효하지 않은 값으로 세팅, 최초 sum값은 0으로 세팅
		{
			nodes.add(new Node(-1,-1, 0));
		}
		else
		{
			// 루트노드가 아닌 경우, 왼쪽과 오른쪽을 각각 탐색하여 그 노드의 위치를 가져옴
			int mid = (s + e) >> 1;
			int left = init(s, mid);
			int right = init(mid + 1, e);
			
			nodes.add(new Node(left, right, 0));
		}
		// 현재까지 노드에 들어간 위치를 반환함으로 써 nodes 리스트에서 현재 노드의 위치를 알 수 있게 한다.
		return nodes.size() - 1;
	}
}

 

 

[ 변수 설명 ]

구현의 편의성을 위해 세그먼트 트리의 node를 담을 리스트(nodes)와, 루트 노드의 정보를 담을 배열(roots)을 전역변수(static)으로 선언했습니다. 

 

[ init 함수 설명 ]

세그먼트 트리 생성은 후위 순회 식으로 값을 추가합니다. 루트 노드일 때와, 아닐 때로 나누어 각각 노드를 생성해 넣습니다. 그리고 반환 값은 새롭게 생성한 노드의 위치(nodes.size() - 1)를 반환해야 합니다. init 함수에서 무조건 하나의 node를 생성하기 때문에 항상 새롭게 생성한 노드의 위치를 반환해 주어야 합니다. 

 

[ update 함수 설명 ]

update 함수도 init 함수와 크게 다르지 않습니다. 무조건 함수안에서 하나의 노드를 새롭게 생성합니다. 다만 누적합(sum)을 기존 노드의 값 + 업데이트하려는 값(value)을 해주는 것이 다릅니다. 추가로 타겟팅하는 위치의 인덱스(targetIdx)를 잘 찾아가도록 mid 값으로 분기를 해주면 됩니다.

 

[ sum 함수 설명 ]

sum 함수는 일반적인 세그먼트 트리의 누적합을 구하는 것과 크게 다르지 않아 생략합니다. 

 

[ roots 저장 방식 ]

코드를 보면 update할 때마다 roots 배열에 값을 저장하는 것을 볼 수 있습니다. node가 후위 순회 식으로 저장되기 때문에 roots는 결국 그 당시 nodes 리스트의 size() - 1 값이 됩니다.