How to use the MNIST dataset in the latest sklearn
from sklearn.datasets import fetch_openml

# In recent scikit-learn versions this returns pandas objects by default.
mnist = fetch_openml('mnist_784')
In [6]:
X, y = mnist['data'], mnist['target']
X.shape
Out[6]:
(70000, 784)
In [10]:
y.shape
X.index
Out[10]:
RangeIndex(start=0, stop=70000, step=1)
In [30]:
%matplotlib inline
import matplotlib
import matplotlib.pyplot as plt

some_digit = X.loc[36011].values  # one DataFrame row as a NumPy array
print(type(some_digit))
some_digit_image = some_digit.reshape(28, 28)  # 784 pixels -> 28x28 image
plt.imshow(some_digit_image, cmap=matplotlib.cm.binary, interpolation="nearest")
plt.axis("off")
plt.show()
<class 'numpy.ndarray'>
In [32]:
y[36011]
Out[32]:
'5'
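Note that the label is the string '5', not the integer 5, which is why the later comparisons use '5'. If integer labels are preferred, one possible conversion is sketched below. `y_demo` is a small made-up stand-in for the real `y`, not data from MNIST; the same loop should also work on a pandas Series, since it only iterates over the values.

```python
import numpy as np

# Hypothetical stand-in for the string labels returned by fetch_openml.
y_demo = np.array(["5", "0", "4", "5"])

# Convert each string label to a small unsigned integer.
y_int = np.array([int(label) for label in y_demo], dtype=np.uint8)

# With integer labels, the "is it a 5?" target compares against 5 rather than '5'.
is_five = (y_int == 5)
print(y_int, is_five)
```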
In [45]:
# The first 60,000 rows form the training set; the remaining 10,000 form the test set.
X_train, X_test = X[:60000].values, X[60000:].values
y_train, y_test = y[:60000].values, y[60000:].values
In [46]:
import numpy as np

# Shuffle the training set so that cross-validation folds are not ordered.
shuffle_index = np.random.permutation(60000)
X_train, y_train = X_train[shuffle_index], y_train[shuffle_index]
In [53]:
X_train.shape
y_train
Out[53]:
['4', '4', '1', '8', '4', ..., '8', '5', '4', '7', '5']
Length: 60000
Categories (10, object): ['0', '1', '2', '3', ..., '6', '7', '8', '9']
In [54]:
y_train_5 = (y_train == '5')  # labels are strings, so compare against '5'
y_test_5 = (y_test == '5')
y_train_5
Out[54]:
array([False, False, False, ..., False, False, True])
In [56]:
from sklearn.linear_model import SGDClassifier

sgd_clf = SGDClassifier(random_state=42)
sgd_clf.fit(X_train, y_train_5)
Out[56]:
SGDClassifier(random_state=42)
In [61]:
sgd_clf.predict([X.loc[36011], X.loc[36010]])
Out[61]:
array([ True, False])
In [63]:
from sklearn.model_selection import cross_val_score

cross_val_score(sgd_clf, X_train, y_train_5, cv=3, scoring="accuracy")
Out[63]:
array([0.96975, 0.96335, 0.9521 ])
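To see roughly what `cross_val_score` does with `cv=3` for a classifier, the loop below splits the data into stratified folds, fits a fresh clone on each training part, and scores it on the held-out part. This is only a sketch: it uses scikit-learn's small bundled `load_digits` data (8x8 images) as a stand-in so that it runs without downloading MNIST, and the names `X_demo`, `y_demo_5`, and `manual_scores` are illustrative.

```python
from sklearn.base import clone
from sklearn.datasets import load_digits
from sklearn.linear_model import SGDClassifier
from sklearn.model_selection import StratifiedKFold

# Stand-in data (assumption): scikit-learn's bundled 8x8 digits, not MNIST.
X_demo, y_demo = load_digits(return_X_y=True)
y_demo_5 = (y_demo == 5)  # integer labels here, unlike the string labels above

base_clf = SGDClassifier(random_state=42)
skfolds = StratifiedKFold(n_splits=3)
manual_scores = []
for train_idx, test_idx in skfolds.split(X_demo, y_demo_5):
    fold_clf = clone(base_clf)  # fresh, unfitted copy for each fold
    fold_clf.fit(X_demo[train_idx], y_demo_5[train_idx])
    y_pred = fold_clf.predict(X_demo[test_idx])
    # Accuracy on the held-out fold: fraction of correct predictions.
    manual_scores.append((y_pred == y_demo_5[test_idx]).mean())

print(manual_scores)
```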
In [65]:
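from sklearn.base import BaseEstimator

class Never5Classifier(BaseEstimator):
    """Baseline that predicts "not 5" for every image."""
    def fit(self, X, y=None):
        return self  # nothing to learn
    def predict(self, X):
        # One boolean per sample, as a 1-D array
        return np.zeros(len(X), dtype=bool)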
In [67]:
never_5_clf = Never5Classifier()
cross_val_score(never_5_clf, X_train, y_train_5, cv=3, scoring="accuracy")
Out[67]:
array([0.91025, 0.90975, 0.90895])
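Scores of about 0.91 for a classifier that never predicts 5 suggest that only around 9% of the training images are 5s, so accuracy alone says little on a dataset this imbalanced. A minimal sketch with made-up labels shows the relationship; the 10% positive rate is an illustrative assumption, not the MNIST figure.

```python
import numpy as np

# Hypothetical labels: 100 positives ("is a 5") out of 1000 samples.
y_true = np.zeros(1000, dtype=bool)
y_true[:100] = True

# A classifier that always answers "not 5".
y_pred = np.zeros(len(y_true), dtype=bool)

accuracy = (y_pred == y_true).mean()
positive_rate = y_true.mean()
print(accuracy, 1 - positive_rate)  # accuracy equals 1 minus the positive rate
```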
In [70]:
y_scores = sgd_clf.decision_function([X.loc[36011], X.loc[36010]])
y_scores
Out[70]:
array([ 3211.94412614, -4886.85783489])
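These scores are what `predict` thresholds. For a binary linear classifier such as `SGDClassifier`, a positive decision score corresponds to the positive class, so comparing the two scores against 0 reproduces the result in Out[61]. The sketch below also tries a stricter threshold; the value 5000 is arbitrary and chosen only for illustration.

```python
import numpy as np

# Decision scores copied from Out[70].
y_scores_demo = np.array([3211.94412614, -4886.85783489])

threshold = 0
print(y_scores_demo > threshold)   # [ True False], matching Out[61]

threshold = 5000  # arbitrary, stricter threshold
print(y_scores_demo > threshold)   # [False False]
```

Raising the threshold makes the classifier more conservative about predicting a 5.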
In [72]:
from sklearn.model_selection import cross_val_predict

y_scores = cross_val_predict(sgd_clf, X_train, y_train_5, cv=3,
                             method="decision_function")
y_scores
Out[72]:
array([-10306.99957365, -6898.31615637, -7299.4630608 , ..., -5958.06340458, -18084.1859439 , 6322.98986326])
In [74]:
y_scores.shape
Out[74]:
(60000,)
In [76]:
from sklearn.ensemble import RandomForestClassifier

forest_clf = RandomForestClassifier(random_state=42)
y_probas_forest = cross_val_predict(forest_clf, X_train, y_train_5, cv=3,
                                    method="predict_proba")
In [78]:
y_probas_forest[:4]
Out[78]:
array([[0.99, 0.01],
       [1.  , 0.  ],
       [1.  , 0.  ],
       [0.96, 0.04]])
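Each row of `y_probas_forest` holds two probabilities, ordered as in the classifier's `classes_` attribute. With boolean targets that order is `[False, True]`, so the second column is the estimated probability that the image is a 5. That column can serve as a score in the same way as the `decision_function` output. A small sketch using the four rows printed above (the `_demo` names are illustrative):

```python
import numpy as np

# Rows copied from Out[78]; columns are P(not 5), P(5).
y_probas_demo = np.array([[0.99, 0.01],
                          [1.00, 0.00],
                          [1.00, 0.00],
                          [0.96, 0.04]])

y_scores_forest_demo = y_probas_demo[:, 1]  # probability of the positive class
print(y_scores_forest_demo)

# Each row is a probability distribution over the two classes.
print(y_probas_demo.sum(axis=1))
```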
风语者: enjoys exploring all kinds of technology, currently works in backend development, and loves both life and work.