최근접 점의 쌍 찾기(Closest Pair)
2차원 평면 상의 n개의 점이 입력으로 주어질 때, 거리가 가장 가까운 한 쌍의 점을 찾는 문제
모든 점에 대하여 각각의 두 점 사이의 거리를 계산하는 방법의 시간 복잡도는 O(n2)이다.
분할 정복을 이용하면 시간 복잡도는 O(n(logn)2)이다.
책에서는 O(n(logn)2)이라고 하는데, 아직 y좌표에 대해 반드시 정렬해야 하는건지 잘 모르겠고, 찾아보니 O(nlogn)이라는 정보도 있다.
https://en.wikipedia.org/wiki/Closest_pair_of_points_problem
분할 정복
알고리즘
2개나 3개의 점이 남을 때까지 분할한다.
최근접 쌍을 찾고 분할한 부분문제들을 합칠 때, 합치는 경계를 기준으로 좌우의 점들에 대해 다시 검증을 해야 한다.
중간 영역에 속하는 점들을 판별하는 방법은 다음과 같다.
각 부분문제의 최근접 점 사이의 거리 중 작은 값을 택하고 이를 d라고 하자.
왼쪽 부분문제의 가장 오른쪽 점의 x좌표에서 d만큼 뺀 x좌표부터, 오른쪽 부분문제의 가장 왼쪽 점의 x좌표에서 d만큼 더한 x좌표 사이에 포함되는 점들이 바로 중간 영역에 속하는 점들이라고 할 수 있다.
arr는 x좌표 기준 오름차순으로 정렬된 좌표(x, y)들의 배열이다.
Pseudo Code
ClosestPair(arr) { if (arr.size() <= 3) return (2개 또는 3개의 점들 사이의 최근접 쌍); arr를 반으로 나눠서 arr_left와 arr_right로 분할한다. closest_pair_left = ClosestPair(arr_left); closest_pair_right = ClosestPair(arr_right); d = min(closest_pair_left, closest_pair_right); 를 통해 중간 영역에 속하는 점들을 판별한다. 중간 영역에 속하는 점들 중에서 최근접 점의 쌍을 closest_pair_center로 저장한다. return (closest_pair_left, closest_pair_right, closest_pair_center 중에 최근접 점의 쌍); }
시간 복잡도
처음 배열을 정렬할 때의 시간 복잡도는 O(nlogn)이다.
분할해서 계산하는건 모두 O(1)이라고 생각되고, 층 수가 O(logn)인 것 같다.
따라서 시간 복잡도는 O(nlogn)인데.. 처음 정렬하는게 제일 영향이 크다는게 의아하다.
코드
#include <iostream>
#include <vector>
#include <cmath>
#include <limits>
class vec2;
float GetDistance(const vec2& v1, const vec2& v2);
class vec2
{
public:
int x_;
int y_;
vec2()
{
x_ = rand() % 100;
y_ = rand() % 100;
}
vec2(int x_in, int y_in)
: x_(x_in), y_(y_in)
{}
vec2 operator - (const vec2& v) const
{
return (vec2(x_ - v.x_, y_ - v.y_));
}
friend std::ostream& operator << (std::ostream& out, const vec2& v)
{
out << '(' << v.x_ << ", " << v.y_ << ')';
return out;
}
};
class ClosestPair
{
public:
vec2 v1_;
vec2 v2_;
float dist_;
ClosestPair()
{
dist_ = std::numeric_limits<float>::max();
}
ClosestPair(const vec2& v1, const vec2& v2)
{
v1_ = v1;
v2_ = v2;
dist_ = GetDistance(v1, v2);
}
friend std::ostream& operator << (std::ostream& out, const ClosestPair& cp)
{
out << "Point1 : " << cp.v1_ << "\tPoint2 : " << cp.v2_ << "\tDistance : " << cp.dist_ << std::endl;
return out;
}
};
float GetDistance(const vec2& v1, const vec2& v2)
{
return std::sqrt(std::pow((v1 - v2).x_, 2) + std::pow((v1 - v2).y_, 2));
}
void QuickSort(std::vector<vec2>& arr, int left, int right)
{
if (left >= right)
return ;
int pivot = (left + right) / 2;
int high = left + 1;
int low = right;
std::swap(arr[pivot], arr[left]);
while (high <= low)
{
while ((high <= right) && (arr[high].x_ <= arr[left].x_))
high++;
while ((low > left) && (arr[low].x_ >= arr[left].x_))
low--;
if ((high > right) || (low == left) || (high >= low))
break;
std::swap(arr[high], arr[low]);
}
std::swap(arr[left], arr[low]);
QuickSort(arr, left, low - 1);
QuickSort(arr, low + 1, right);
}
ClosestPair GetUnderThreeCase(const std::vector<vec2>& v, int left, int right)
{
if (right == left + 2)
{
float dist1 = GetDistance(v[left], v[left + 1]);
float dist2 = GetDistance(v[left], v[right]);
float dist3 = GetDistance(v[left + 1], v[right]);
if (dist1 < dist2 && dist1 < dist3)
return ClosestPair(v[left], v[left + 1]);
else if (dist2 < dist1 && dist2 < dist3)
return ClosestPair(v[left], v[right]);
else
return ClosestPair(v[left + 1], v[right]);
}
else
return ClosestPair(v[left], v[right]);
}
ClosestPair CheckCenter(const std::vector<vec2>& v, int left, int right, float d)
{
float x_leftside = v[(left + right) / 2].x_ - d;
float x_rightside = v[(left + right) / 2 + 1].x_ + d;
ClosestPair closest_pair_center;
for (int i = 0; i <= (left + right) / 2; ++i)
{
if (v[i].x_ < x_leftside)
continue ;
for (int j = (left + right) / 2 + 1; j <= right; ++j)
{
if (v[j].x_ > x_rightside)
continue ;
if (GetDistance(v[i], v[j]) < closest_pair_center.dist_)
closest_pair_center = ClosestPair(v[i], v[j]);
}
}
return closest_pair_center;
}
ClosestPair GetClosestPair(const std::vector<vec2>& v, int left, int right)
{
if (right <= left + 2)
return GetUnderThreeCase(v, left, right);
ClosestPair closest_pair_left = GetClosestPair(v, left, (left + right) / 2);
ClosestPair closest_pair_right = GetClosestPair(v, (left + right) / 2 + 1, right);
float d = std::min(closest_pair_left.dist_, closest_pair_right.dist_);
ClosestPair closest_pair_center = CheckCenter(v, left, right, d);
if (closest_pair_left.dist_ < closest_pair_right.dist_ && \
closest_pair_left.dist_ < closest_pair_center.dist_)
return closest_pair_left;
else if (closest_pair_right.dist_ < closest_pair_left.dist_ && \
closest_pair_right.dist_ < closest_pair_center.dist_)
return closest_pair_right;
else
return closest_pair_center;
}
int main()
{
using namespace std;
srand(time(NULL));
vector<vec2> v(10);
QuickSort(v, 0, v.size() - 1);
for (const auto &e : v)
cout << e << '\n';
ClosestPair closest_pair;
closest_pair = GetClosestPair(v, 0, v.size() - 1);
cout << closest_pair << endl;
}
코드를 완벽히 하려면
sqrt
를 제외하고 비교하는 등의 최적화를 하고, 예외처리도 해서 좀 더 정리해야 한다.그래픽 요소 추가하면 좋을 듯
책에서는 부분문제를 합칠 때 y좌표를 기준으로 정렬한 뒤 d와 비교하라고 하는데, 굳이 정렬해야 하는 이유를 찾지 못해서 그냥 비교했다.