# Source: Q3D-Calibration/GaussianMixture.py
# (Repository-viewer metadata at time of capture: 37 lines, 891 B, Python.)

import os
import numpy as np
from tqdm import tqdm
from matplotlib import pyplot as plt
from sklearn.mixture import GaussianMixture
# Batch job: for every text file in INPUT_DIR, fit a Gaussian mixture to its
# rows, save a scatter plot coloured by cluster to FIGURE_DIR, and write the
# rows of the most populous cluster to OUTPUT_DIR.
INPUT_DIR = 'result/bind'
FIGURE_DIR = 'result/GMM'
OUTPUT_DIR = 'result/bind-GMM'
N_COMPONENTS = 2


def largest_cluster_indices(labels):
    """Return the row indices belonging to the most populous cluster.

    Args:
        labels: 1-D array-like of per-sample cluster labels.

    Returns:
        Ascending integer index array of the samples whose label occurs
        most often. Ties go to the smallest label. Empty input yields an
        empty integer index array.
    """
    labels = np.asarray(labels)
    if labels.size == 0:
        return np.array([], dtype=np.intp)
    values, counts = np.unique(labels, return_counts=True)
    # np.unique sorts the labels and argmax returns the first maximum, so a
    # tie resolves to the smallest label.
    return np.flatnonzero(labels == values[np.argmax(counts)])


def output_paths(file_name):
    """Return ``(figure_path, cluster_path)`` for an input file name.

    Only the final extension is stripped, so ``run.1.txt`` maps to
    ``run.1.png`` / ``run.1.txt`` and cannot collide with ``run.2.txt``.
    """
    stem = os.path.splitext(file_name)[0]
    return (
        os.path.join(FIGURE_DIR, stem + '.png'),
        os.path.join(OUTPUT_DIR, stem + '.txt'),
    )


def process_file(file_name, n_components=N_COMPONENTS, random_state=None):
    """Cluster one input file, then save its plot and its largest cluster.

    Args:
        file_name: Name of a file inside INPUT_DIR.
        n_components: Number of mixture components to fit.
        random_state: Forwarded to GaussianMixture; the default of None
            leaves the fit unseeded.

    Returns:
        True if the file was processed, False if it was skipped because it
        has fewer rows than ``n_components`` or fewer than two columns.
    """
    # ndmin=2 keeps a single-row file two-dimensional so column indexing works.
    data = np.loadtxt(
        os.path.join(INPUT_DIR, file_name), dtype=np.uint16, ndmin=2
    )
    if data.shape[0] < n_components or data.shape[1] < 2:
        return False

    model = GaussianMixture(
        n_components=n_components, random_state=random_state
    )
    model.fit(data)
    labels = model.predict(data)

    figure_path, cluster_path = output_paths(file_name)
    fig = plt.figure(figsize=(8, 8))
    try:
        ax = fig.add_subplot(111)
        for cluster in np.unique(labels):
            members = np.flatnonzero(labels == cluster)
            ax.scatter(data[members, 0], data[members, 1], s=0.1)
        fig.savefig(figure_path)
    finally:
        # Close this figure even on failure so a long batch does not
        # accumulate open figures.
        plt.close(fig)

    np.savetxt(cluster_path, data[largest_cluster_indices(labels)], fmt='%d')
    return True


def main():
    """Process every regular file in INPUT_DIR, showing a progress bar."""
    for directory in (FIGURE_DIR, OUTPUT_DIR):
        os.makedirs(directory, exist_ok=True)

    # Skip sub-directories; sort so the processing order is reproducible.
    file_names = sorted(
        entry for entry in os.listdir(INPUT_DIR)
        if os.path.isfile(os.path.join(INPUT_DIR, entry))
    )
    for file_name in tqdm(file_names):
        if not process_file(file_name):
            # tqdm.write prints without corrupting the progress bar.
            tqdm.write('skipped ' + file_name + ': too few rows or columns')


if __name__ == '__main__':
    main()