# Reference: example call site (main) followed by the tsne() plotting helper.
# main: example call site for one few-shot episode. It builds support
# embeddings and per-class prototypes, then plots them with tsne().
# NOTE(review): batch, p_support, way, shot, model and torch come from the
# enclosing scope (not shown here); tsne() must already be defined when this
# runs, so in a real script its definition has to come first.
# Move every element of the batch to the GPU. The batch must contain exactly
# two elements; the second is discarded (presumably labels -- confirm).
data, _ = [_.cuda() for _ in batch] # iterate over the batch
# First p_support samples are the support set, the remainder the query set.
data_support, data_query = data[:p_support], data[p_support:] # e.g. [150,3,84,84] in the original setup
# Episode-local pseudo labels [0..way-1, 0..way-1, ...]: class index cycles fastest.
labels_support = torch.arange(way).repeat(shot)
emb_support = model(data_support) # e.g. [150,1600] in the original setup
# View embeddings as (shot, way, D) and average over shots -> one prototype per
# class, shape (way, D). Relies on the class-fastest ordering of labels_support.
proto_support = emb_support.reshape(shot, way, -1).mean(0)
labels_proto = torch.arange(way) # one label per prototype: [0, 1, ..., way-1]
tsne(emb_support, proto_support, labels_support, labels_proto)
# tsne: t-SNE visualisation of few-shot support embeddings together with their
# class prototypes. The underscore-prefixed helpers are private to tsne().


def _tsne_to_numpy(x):
    """Return *x* as a NumPy array.

    torch tensors are detached from the autograd graph and copied to the CPU
    first, so GPU tensors and tensors that require grad are both accepted.
    """
    import numpy as np

    if hasattr(x, 'detach'):  # torch.Tensor
        x = x.detach().cpu().numpy()
    return np.asarray(x)


def _tsne_embed_2d(points, perplexity):
    """Fit a 2-D t-SNE on the (n, d) array *points* and return the (n, 2) embedding."""
    import inspect

    from sklearn.manifold import TSNE

    n_points = points.shape[0]
    if n_points < 2:
        raise ValueError('t-SNE needs at least 2 points, got %d' % n_points)
    # scikit-learn rejects perplexity >= n_samples; small episodes
    # (e.g. 5-way 5-shot = 30 points) would otherwise fail with perplexity=50.
    effective_perplexity = min(float(perplexity), float(n_points - 1))
    # scikit-learn 1.5 renamed n_iter -> max_iter (and later removed n_iter);
    # use whichever keyword the installed version accepts.
    ctor_params = inspect.signature(TSNE).parameters
    iter_kw = 'max_iter' if 'max_iter' in ctor_params else 'n_iter'
    model = TSNE(n_components=2, perplexity=effective_perplexity,
                 learning_rate=200, n_iter_without_progress=10,
                 **{iter_kw: 1000})
    return model.fit_transform(points)


def _tsne_class_colors(n_classes):
    """Return *n_classes* matplotlib colours, starting with 'r', 'g', 'b'."""
    base = ['r', 'g', 'b', 'c', 'm', 'y', 'k', 'orange', 'purple', 'brown']
    if n_classes <= len(base):
        return base[:n_classes]
    from matplotlib import pyplot as plt

    # More classes than named colours: sample the remainder from a colormap.
    cmap = plt.get_cmap('hsv')
    extra = n_classes - len(base)
    return base + [cmap(i / extra) for i in range(extra)]


def tsne(training_feature, proto_feature, train_label, proto_label,
         title='title', save_path='./home/...', perplexity=50, filename=None):
    """Project support features and prototypes to 2-D with t-SNE and save a PNG.

    Support points and prototypes are embedded jointly (a single t-SNE fit),
    so both live in the same 2-D coordinate system. Support points are drawn
    as '+' and prototypes as 'o', coloured by class.

    :param training_feature: support embeddings, shape (N, D), e.g.
        [shot*way, 1600]; a torch tensor (any device) or array-like.
    :param proto_feature: prototype embeddings, shape (P, D).
    :param train_label: class id of each support row (length N).
    :param proto_label: class id of each prototype row (length P).
    :param title: figure title (placeholder default kept from the original).
    :param save_path: output directory, created if missing (placeholder
        default kept from the original).
    :param perplexity: requested t-SNE perplexity; clamped to (#points - 1)
        because scikit-learn requires perplexity < n_samples.
    :param filename: PNG base name without extension; when omitted a random
        UUID hex name is used, so repeated calls do not overwrite each other.
    :return: path of the written PNG file.
    :raises ValueError: if fewer than 2 points are given in total.
    """
    import os
    import uuid

    import numpy as np
    from matplotlib import pyplot as plt

    train_np = _tsne_to_numpy(training_feature)
    proto_np = _tsne_to_numpy(proto_feature)
    train_lbl = _tsne_to_numpy(train_label).reshape(-1)
    proto_lbl = _tsne_to_numpy(proto_label).reshape(-1)

    # One joint fit so support points and prototypes share coordinates.
    size_train = train_np.shape[0]
    points_2d = _tsne_embed_2d(np.concatenate((train_np, proto_np)), perplexity)
    train_2d, proto_2d = points_2d[:size_train], points_2d[size_train:]

    classes = np.unique(train_lbl)  # sorted unique class ids
    # One colour per class, so every class is plotted whatever `way` is.
    colors = _tsne_class_colors(len(classes))

    fig, ax = plt.subplots(figsize=(4, 4))
    try:
        for class_ix, color in zip(classes, colors):
            mask = train_lbl == class_ix
            ax.scatter(train_2d[mask, 0], train_2d[mask, 1], marker='+',
                       color=color, linewidth=1, alpha=0.9,
                       label=str(class_ix))
        for class_ix, color in zip(classes, colors):
            mask = proto_lbl == class_ix
            ax.scatter(proto_2d[mask, 0], proto_2d[mask, 1], marker='o',
                       color=color, linewidth=5, alpha=0.9,
                       label=str(class_ix))
        ax.set_title(title)

        os.makedirs(save_path, exist_ok=True)
        if filename is None:
            filename = uuid.uuid4().hex
        out_path = os.path.join(save_path, '%s.png' % filename)
        fig.savefig(out_path)
    finally:
        plt.close(fig)  # always release the figure, even if saving fails
    return out_path