Python 機械学習 Scikit-learnによるサポートベクターマシン(SVM)による学習・分類の実践

今回はSVM(サポートベクターマシン)によるIrisの花のデータセットの分類を実践します。サンプルプログラムをコピーペーストでまずは動かしてみることを念頭にしています。

※Pythonの導入方法http://costudyinfodatabase.nagoya/2017/01/05/%e3%81%82-2/

※Iris のデータセットとは
Irisのデータセットは「Iris setosa」「 Iris versicolor」「 Iris virginica 」の三種のアヤメの花について、sepal length(がくの長さ), sepal width(がくの幅), petal length (花弁の長さ)そして petal width(花弁の幅) の測定値が特徴量として格納されている150個 のデータセットです。機械学習用の標準データセットとして、下記のサンプルコードで簡単にインストールできます。

◆SVMとは

SVMは、分類器の一つ。n次元の特徴量からなるデータ空間に、トレーニングデータで学習して、分類境界線を引きます。そして、その境界線を境にして、未知のデータがどのクラスに属するかを判定します。

線形SVM
特徴量が2次元、クラスが二つ(下では、青とオレンジで色分け)の場合のイメージを以下に示します。2つのクラスの間に境界面を引き、2つのクラスの内、この境界面に最も近いサンプルと境界との距離(マージン)が最も大きくなるようにな境界面を決定境界とするのが線形SVMです。

 

非線形SVM
上記のように、境界面が直線(平面)であるようなSVMを線形SVMと呼んでいます。それに対し、決定境界に曲面を用いる非線形SVMというものがあります。下のようなクラスの分布を持っている場合はどうでしょうか。青とオレンジの2つのクラスを直線で分離するのが困難であり、下図のような曲面の境界が有効であることがわかります。
(曲面の関数を引いてやる手法はここでは割愛。非線形SVMで検索すると大量にヒットします)

いずれの手法を用いるにしても、設定した特徴量空間でクラスが分類できるような集合を成しているか(集合がごちゃ混ぜで境界もなにも引けないような状態になっていないか)が高い識別率を左右します。

◆サンプルプログラムの大枠の流れ

1.irisのデータセットから3次元データ(3つの特徴量を有するデータ)をインポートし、学習データと検証データに切り分けます。

2.学習データを用いて、scikit-learnにある非線形SVMを適用し、分類器(どんな境界線を引くか)を学習させます。

3.検証用データを学習させた分類器で分類し、正答率を検証します。

 

 

irisのデータセットから3次元データ(3つの特徴量を有するデータ)をインポートし、学習データと検証データに切り分けます。切り分けたデータには標準化を行います。

※標準化:データの平均値=0、分散=1になるようにデータを整形すること。

#=========================================================
## data import
from distutils.version import LooseVersion as Version
from sklearn import __version__ as sklearn_version
from sklearn import datasets
import numpy as np

iris = datasets.load_iris()
data1=1
data2=2
data3=3

X = iris.data[:, [data1,data2,data3]]#sepal width/petal length/petal length
y = iris.target

print('Class labels:', np.unique(y))

if Version(sklearn_version) < '0.18':
from sklearn.cross_validation import train_test_split
else:
from sklearn.model_selection import train_test_split

##Split the data into trainning data and test data at a certain ratio
X_train, X_test, y_train, y_test = train_test_split(
X, y, test_size=0.3, random_state=0)

from sklearn.preprocessing import StandardScaler

sc = StandardScaler()
sc.fit(X_train)
X_train_std = sc.transform(X_train)
X_test_std = sc.transform(X_test)
#==========================================================

#==========================================================
#data plot
from mpl_toolkits.mplot3d import Axes3D
import matplotlib.pyplot as plt

fig = plt.figure()
ax = Axes3D(fig)
#ax.plot_wireframe(X,Y,Z)
#ax.scatter3D(np.ravel(X1),np.ravel(X2),np.ravel(X3),c='b', marker='x',label='1')
ax.scatter3D(np.ravel(X[y==0,0]),np.ravel(X[y==0,1]),np.ravel(X[y==0,2]),c='b', marker='x',label='1')
ax.scatter3D(np.ravel(X[y==1,0]),np.ravel(X[y==1,1]),np.ravel(X[y==1,2]),c='r', marker='o',label='2')
ax.scatter3D(np.ravel(X[y==2,0]),np.ravel(X[y==2,1]),np.ravel(X[y==2,2]),c='y', marker='o',label='3')
ax.set_xlabel(iris.feature_names[data1])
ax.set_ylabel(iris.feature_names[data2])
ax.set_zlabel(iris.feature_names[data3])
plt.legend(loc='upper left')

plt.show()
#========================================================

実行結果

3次元の特徴量空間に3種類の花が色分けしてプロットされています。

 

次に、scikit-learnライブラリからSVMをインポートします。

SVMとして、非線形SVMを選択します。kernel=’rbf’ : 動径基底関数カーネルを用いた非線形SVM。’linear’を選択すると線形SVMを利用できます。ハイパーパラメータであるγを小さくするとサンプルデータの寄与が大きくなり、決定境界がより学習サンプルに追従するように学習されます。詳しくは割愛。

学習させた分類器にテストデータを分類させ、クラス(花の種類)を予測させます。実際の花の種類と、予測された種類を比較し、正答率を表示します。

次に、全データと、テストデータをその予測された花の種類の色で塗りつぶしたものを大きめの〇で囲ってプロットするようにしています。

from sklearn.svm import SVC

## apply non-linear svm to iris
svm = SVC(kernel='rbf', random_state=0, gamma=0.2, C=1.0)
svm.fit(X_train_std, y_train)
y_pred_svm=svm.predict(X_test_std)

# prediction accuracy
from sklearn.metrics import accuracy_score
print('Misclassified samples: %d' % (y_test != y_pred_svm).sum())
print('Accuracy : %.2f' % accuracy_score(np.rint(y_pred_svm), y_test))

#test data plot
fig1 = plt.figure()
ax1 = Axes3D(fig1)
ax1.scatter3D(np.ravel(X_test[y_pred_svm==0,0]),np.ravel(X_test[y_pred_svm==0,1]),np.ravel(X_test[y_pred_svm==0,2]),c='b', marker='x',label='1')
ax1.scatter3D(np.ravel(X_test[y_pred_svm==1,0]),np.ravel(X_test[y_pred_svm==1,1]),np.ravel(X_test[y_pred_svm==1,2]),c='r', marker='o',label='2')
ax1.scatter3D(np.ravel(X_test[y_pred_svm==2,0]),np.ravel(X_test[y_pred_svm==2,1]),np.ravel(X_test[y_pred_svm==2,2]),c='y', marker='o',label='3')
ax1.set_xlabel(iris.feature_names[data1])
ax1.set_ylabel(iris.feature_names[data2])
ax1.set_zlabel(iris.feature_names[data3])
plt.legend(loc='upper left')

#============================================================
#all data plot
fig = plt.figure()
ax2 = Axes3D(fig)
#train data
ax2.scatter3D(np.ravel(X[y==0,0]),np.ravel(X[y==0,1]),np.ravel(X[y==0,2]),c='b', marker='x',label='1')
ax2.scatter3D(np.ravel(X[y==1,0]),np.ravel(X[y==1,1]),np.ravel(X[y==1,2]),c='r', marker='o',label='2')
ax2.scatter3D(np.ravel(X[y==2,0]),np.ravel(X[y==2,1]),np.ravel(X[y==2,2]),c='y', marker='o',label='3')
#test data
ax2.scatter3D(np.ravel(X_test[y_pred_svm==0,0]),np.ravel(X_test[y_pred_svm==0,1]),np.ravel(X_test[y_pred_svm==0,2]),c='b', marker='x',label='1')
ax2.scatter3D(np.ravel(X_test[y_pred_svm==1,0]),np.ravel(X_test[y_pred_svm==1,1]),np.ravel(X_test[y_pred_svm==1,2]),c='r', marker='o',label='2')
ax2.scatter3D(np.ravel(X_test[y_pred_svm==2,0]),np.ravel(X_test[y_pred_svm==2,1]),np.ravel(X_test[y_pred_svm==2,2]),c='y', marker='o',label='3')
#highlight the test data
ax2.scatter3D(np.ravel(X_test[y_pred_svm==0,0]),np.ravel(X_test[y_pred_svm==0,1]),np.ravel(X_test[y_pred_svm==0,2]),c='', marker='o',label='1', s=60)
ax2.scatter3D(np.ravel(X_test[y_pred_svm==1,0]),np.ravel(X_test[y_pred_svm==1,1]),np.ravel(X_test[y_pred_svm==1,2]),c='', marker='o',label='2', s=60)
ax2.scatter3D(np.ravel(X_test[y_pred_svm==2,0]),np.ravel(X_test[y_pred_svm==2,1]),np.ravel(X_test[y_pred_svm==2,2]),c='', marker='o',label='3', s=60)

ax2.set_xlabel(iris.feature_names[data1])
ax2.set_ylabel(iris.feature_names[data2])
ax2.set_zlabel(iris.feature_names[data3])
#============================================================

実行結果

 

Misclassified samples: 1
Accuracy : 0.98

以上のように、98%の正答率で分類できています。おおむねうまく分類されていますが、クラスの集合の境目付近は誤分類が発生しています。

 

コメントを残す

メールアドレスが公開されることはありません。 * が付いている欄は必須項目です