서론
K-means clustering 포스트에서 말했지만, k-means clustering은 초기 중심점을 랜덤으로 잡았었다.
K-means는 초기 중심점이 어디에 잡히는지에 따라 clustering 결과 차이가 많이 난다.
물론 최적의 결과가 나올 수도 있지만, 최악이 나올 수도 있다.
즉, 돌릴때마다 결과가 다르게 나올 가능성이 높다.
그리고 위키피디아를 보면, 초기 중심점이 이상하게 잡히면 알고리즘 실행 시간이 길 수 있다는 것도 알 수 있다.
K-means++의 아이디어
아이디어는 이렇다.
중심점들을 최대한 퍼뜨려 놓으면.. 잘된다!
왜 잘될까?
위키피디아의 suboptimal clustering 부분을 보면, 어느정도 이해가 된다.
초기 중심점들이 모여있으면, 의도하지 않게 데이터들이 묶일 가능성이 높아진다.
따라서 초기 중심점들 서로가 최대한 멀리 있게 구성하면
1. 알고리즘 실행 시간이 크게 단축된다.
2. 잘못 clustering될 가능성이 크게 줄어든다.
라고 한다.
자세한 내용은 논문을 참고하자.
진행 과정
데이터들중에서 무작위로 하나 뽑는 것으로 시작한다.
그게 맨 처음 중심점이 되고, K개의 중심점을 찾을 때까지 다음 과정을 따른다.
N($N < K$)개의 중심점을 찾았다고 가정하자.
1. 모든 데이터들과 N개의 중심점들간의 거리를 모두 구한다.
2. 각 데이터에 대해 가장 가까운 중심점과의 거리를 찾는다.
3. 2번 과정에서 찾은 거리들중에 가장 큰 값을 갖는 데이터가 새로운 중심점이 된다.
4. 1~3번 과정을 $N==K$ 가 될때까지 진행한다.
이를 시각화 한 유튜브 영상이 있어 소개한다.
이제 이걸 구현해보자.
Numpy의 broadcasting을 적극 활용했다.
import numpy as np
import matplotlib.pyplot as plt
def get_kmeans_plus_centroids(datas:np.ndarray, centroid_count:int):
"""
[Params]
datas: 데이터셋, (data_count, 2)
centroid_count: 중심점 개수
[Return]
centroids: 초기 중심점들, (centroid_count, 2)
"""
# 무작위로 하나의 데이터를 중심점으로 정한다.
initial_centroid = datas[int(np.random.uniform(0, len(datas)))]
centroids = [initial_centroid]
for _ in range(centroid_count-1):
centroids_ = np.array(centroids, np.float32)
centroids_ = centroids_[np.newaxis,...] # (1, len(centroids), 2)
datas_ = datas[:,np.newaxis,...] # (data_count, 1, 2)
distances = np.sum((centroids_ - datas_)**2, axis=-1)**0.5 # (data_count, len(centroids))
closest_distance_by_centroids = np.min(distances, axis=-1) # (data_count,)
indices = np.array(list(range(len(datas))), np.float32) # 0~(data_count-1), (data_count,)
distances_with_indices = np.stack([closest_distance_by_centroids, indices], axis=-1) # (data_count, 2), (거리, 인덱스)
distances_with_indices = sorted(distances_with_indices, key= lambda x : -x[0]) # 거리에 대해 내림차순 정렬
new_centroid = datas[int(distances_with_indices[0][1])] # 가장 거리값이 큰 데이터 고르기
centroids.append(new_centroid)
centroids = np.stack(centroids, axis=0)
return centroids
연산 이후, 데이터의 shape을 주석으로 적었다.
이전 포스팅에 사용된 데이터셋 생성 코드를 사용하여 데이터셋을 생성하고, 중심점들을 시각화 하면 다음과 같다.
kmeans_plus_centroids = get_kmeans_plus_centroids(dataset_flatten, 5)
plt.scatter(dataset_flatten[:,0], dataset_flatten[:,1])
plt.scatter(kmeans_plus_centroids[:,0], kmeans_plus_centroids[:,1])
plt.show()
랜덤으로 생성한 것과 비교하면, 확실히 중심점들이 퍼져있는 것을 확인할 수 있다.
Clustering
Clustering 결과는 이전 포스팅의 코드를 재활용했다.
centroids = get_kmeans_plus_centroids(dataset_flatten, 5) # 초기 중심점 얻기
difference = np.sum(centroids) # 이전 중심점과의 차이값을 저장하는 변수, 초기값은 초기 중심점들 값의 합으로 했다
datas = dataset_flatten.copy() # 사용할 데이터 복사
# 초기 중심점들 위치 plot
plt.scatter(datas[:,0], datas[:,1])
plt.scatter(centroids[:,0], centroids[:,1])
plt.show()
# 알고리즘 진행
while difference != 0:
clustered_datas = clustering(centroids, datas)
new_centroid = update_centroids(clustered_datas)
difference = np.sum(centroids - new_centroid)
centroids = new_centroid
# 최종 결과 plot
for clustered_data in clustered_datas:
plt.scatter(clustered_data[:,0], clustered_data[:,1])
plt.scatter(centroids[:,0], centroids[:,1])
plt.show()
이전 포스팅 코드와 다른점은 맨 첫줄밖에 없다.
위 코드를 여러번 돌려봐도, 랜덤한 방식에 비해 위와같이 clustering 되는 경우가 압도적으로 많다는 것을 알 수 있다.
생각보다 퍼져있지는 않은데요?
중심값 생성 시, 맨 처음 랜덤으로 뽑은 데이터가 어디에 위치하는 지가 최종적인 결과에 영향을 많이 주기 때문이다.
이는 연속적으로 다음 중심점, 그 다음 중심점의 위치까지 결정하기 때문에 영향이 클 수 밖에 없다.
그러나 이것을 고려해도, 최종적으로 골라진 중심점들은 최대한 멀리 떨어져있음을 알 수 있다.
'ML' 카테고리의 다른 글
[Clustering] K-means clustering 구현해보기 (4) | 2024.05.12 |
---|