CS지식

[알고리즘] MST 최소 신장 (스패닝) 트리

지예환 2025. 2. 22. 00:29
728x90

 

/* 최소 신장 트리(MST) */

최소 신장 트리(Minimum Spanning Tree, MST)는 가중치 그래프에서 모든 정점을 연결하는 최소 비용의 부분 그래프입니다. MST는 사이클이 없고, (V - 1)개의 간선을 가지며, 그래프의 최소 연결 비용을 찾는 데 사용됩니다.

MST의 특징

  1. 연결 그래프: 모든 정점이 연결되어 있어야 합니다.
  2. 최소 비용: 간선의 가중치 합이 최소가 되어야 합니다.
  3. 사이클이 없어야 함: 트리의 성질을 만족해야 합니다.
  4. (V - 1)개의 간선: 정점이 V개라면, MST의 간선 수는 항상 (V - 1)개입니다.

/* 첫 번째 알고리즘, 크루스칼 알고리즘*/

크루스칼 알고리즘간선을 가중치 기준으로 정렬한 후, 최소 비용의 간선을 하나씩 선택하여 MST를 구성하는 방식입니다.

시간 복잡도: O(E log E)

import java.io.BufferedReader;
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.List;
import java.util.PriorityQueue;
import java.util.StringTokenizer;

public class 크루스칼알고리즘 {
    static int[] parent;
    public static void main(String[] args) throws Exception {
        BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
        StringTokenizer st = new StringTokenizer(br.readLine());
        int V = Integer.parseInt(st.nextToken());
        parent = new int[V+1];
        int E = Integer.parseInt(st.nextToken());
        PriorityQueue<Edge> pq = new PriorityQueue<>();

        for (int i = 0; i < E; i++) {
            st = new StringTokenizer(br.readLine());
            int v1 = Integer.parseInt(st.nextToken());
            int v2 = Integer.parseInt(st.nextToken());
            int weight = Integer.parseInt(st.nextToken());
            Edge edge = new Edge(v1, v2, weight);
            pq.offer(edge);
        }

        for (int i=1; i<=V; i++) {
            parent[i] = i;
        }

        int answer = 0;

        while(!pq.isEmpty()) {
            Edge edge = pq.poll();
            int from = edge.from;
            int to = edge.to;
            int rootA = find(from);
            int rootB = find(to);
            if (rootA != rootB) {
                union(rootA, rootB);
                answer += edge.cost;
            }
        }


        System.out.println(answer);
    }
    static void union(int a, int b) {
        int rootA = find(a);
        int rootB = find(b);
        if (rootA > rootB) {
            parent[rootA] = rootB;
        }
        else {
            parent[rootB] = rootA;
        }
    }
    static int find(int n) {
        if (parent[n] == n) {
            return n;
        }
        return parent[n] = find(parent[n]);
    }

    static class Edge implements Comparable<Edge> {
        public int from;
        public int to;
        public int cost;

        public Edge(int from, int to, int cost) {
            this.from = from;
            this.to = to;
            this.cost = cost;
        }

        @Override
        public int compareTo(Edge other) {
            return Integer.compare(this.cost, other.cost);
        }
    }
}

 

/* 두 번째 알고리즘, 프림 알고리즘*/

프림 알고리즘정점 중심으로 최소 비용의 간선을 하나씩 추가하는 방식으로 동작합니다.

시간 복잡도: O(E log V)

import java.io.BufferedReader;
import java.io.InputStreamReader;
import java.util.*;

public class 프림알고리즘 {
    public static void main(String[] args) throws Exception {
        BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
        StringTokenizer st = new StringTokenizer(br.readLine());
        int V = Integer.parseInt(st.nextToken());
        int E = Integer.parseInt(st.nextToken());
        List<List<Pair>> list = new ArrayList<>();
        for (int i = 0; i <= V; i++) {
            list.add(new ArrayList<>());
        }

        for (int i = 0; i < E; i++) {
            st = new StringTokenizer(br.readLine());
            int v1 = Integer.parseInt(st.nextToken());
            int v2 = Integer.parseInt(st.nextToken());
            int weight = Integer.parseInt(st.nextToken());
            list.get(v1).add(new Pair(v2, weight));
            list.get(v2).add(new Pair(v1, weight));
        }

        PriorityQueue<Pair> pq = new PriorityQueue<>();
        boolean[] visited = new boolean[V + 1];
        pq.offer(new Pair(1, 0)); // Start from node 1
        int answer = 0;
        int count = 0; // Count of edges in MST

        while (!pq.isEmpty()) {
            Pair start = pq.poll();
            int to = start.to;
            int cost = start.cost;

            if (visited[to]) continue; // 이미 방문한 정점이면 스킵

            visited[to] = true;
            answer += cost;
            count++;

            if (count == V) break; // 모든 정점을 방문하면 종료 (V-1개의 간선 선택됨)

            for (Pair next : list.get(to)) {
                if (!visited[next.to]) { // 방문하지 않은 정점만 추가
                    pq.offer(next);
                }
            }
        }

        System.out.println(answer);
    }

    static class Pair implements Comparable<Pair> {
        public int to;
        public int cost;

        public Pair(int to, int cost) {
            this.to = to;
            this.cost = cost;
        }

        @Override
        public int compareTo(Pair other) {
            return Integer.compare(this.cost, other.cost);
        }
    }
}
728x90