
How to use the MNIST dataset in the latest sklearn

xingxiliang 2023-06-03 20:00:03
from sklearn.datasets import fetch_openml
mnist = fetch_openml('mnist_784')
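
In recent scikit-learn releases, fetch_openml returns pandas objects for this dataset by default, which is why the cells below use X.loc[...] and .values. If plain NumPy arrays are preferred, passing as_frame=False should return them directly. A minimal sketch; the helper name load_mnist_arrays is hypothetical, and version=1 is an assumption about which OpenML version is wanted:

```python
from sklearn.datasets import fetch_openml


def load_mnist_arrays():
    """Download MNIST and return (X, y) as NumPy arrays rather than pandas objects.

    Hypothetical helper for illustration; it needs network access on first call.
    """
    # as_frame=False asks for ndarrays; version=1 is an assumed dataset version
    mnist = fetch_openml("mnist_784", version=1, as_frame=False)
    return mnist["data"], mnist["target"]
```

With arrays, rows are selected positionally (X[36011]) and the .values calls further down become unnecessary.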

In [6]:

X, y = mnist['data'], mnist['target']
X.shape

Out[6]:

(70000, 784)

In [10]:

y.shape

X.index

Out[10]:

RangeIndex(start=0, stop=70000, step=1)

In [30]:

%matplotlib inline
import matplotlib
import matplotlib.pyplot as plt
some_digit = X.loc[36011].values
print(type(some_digit))
some_digit_image = some_digit.reshape(28, 28)
plt.imshow(some_digit_image, cmap="binary", interpolation="nearest")
plt.axis("off")
plt.show()
<class 'numpy.ndarray'>

In [32]:

y[36011]

Out[32]:

'5'
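
The label comes back as the string '5' rather than the integer 5, because the target column is stored as a categorical of strings. The rest of this notebook compares against '5' directly, which works, but casting to integers once is a common alternative. A small sketch on a toy Series (y_toy is a made-up stand-in for the real target column):

```python
import numpy as np
import pandas as pd

# Toy stand-in for the MNIST target: string labels stored as a categorical Series
y_toy = pd.Series(["5", "0", "4", "1"], dtype="category")

# Cast the string labels to small unsigned integers
y_int = y_toy.astype(np.uint8)

print(y_int.tolist())  # [5, 0, 4, 1]
```

After such a cast, the comparison used later would be written y_train == 5 instead of y_train == '5'.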

In [45]:

X_train, X_test, y_train, y_test = X[:60000].values, X[60000:].values, y[:60000].values, y[60000:].values

In [46]:

import numpy as np

shuffle_index = np.random.permutation(60000)
X_train, y_train = X_train[shuffle_index], y_train[shuffle_index]
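
The shuffle above only works because both arrays are indexed with the same permutation, so every image keeps its own label. The toy sketch below makes that visible and uses a seeded generator so the order is reproducible; the seed 42 and the toy arrays are assumptions for illustration only:

```python
import numpy as np

# Toy data: row i of X_toy holds the value i, and y_toy[i] == i
X_toy = np.arange(10).reshape(10, 1)
y_toy = np.arange(10)

# Seeded generator, so the permutation is the same on every run
rng = np.random.default_rng(42)
shuffle_index = rng.permutation(len(X_toy))

# Index both arrays with the SAME permutation so each row keeps its label
X_shuf, y_shuf = X_toy[shuffle_index], y_toy[shuffle_index]

# Every shuffled row still matches its label
print(bool((X_shuf[:, 0] == y_shuf).all()))  # True
```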

In [53]:

X_train.shape

y_train

Out[53]:

['4', '4', '1', '8', '4', ..., '8', '5', '4', '7', '5']
Length: 60000
Categories (10, object): ['0', '1', '2', '3', ..., '6', '7', '8', '9']

In [54]:

y_train_5 = (y_train == '5')
y_test_5 = (y_test == '5')

y_train_5

Out[54]:

array([False, False, False, ..., False, False,  True])

In [56]:

from sklearn.linear_model import SGDClassifier
sgd_clf = SGDClassifier(random_state=42)
sgd_clf.fit(X_train, y_train_5)

Out[56]:

SGDClassifier

SGDClassifier(random_state=42)

In [61]:

sgd_clf.predict([X.loc[36011], X.loc[36010]])

Out[61]:

array([ True, False])

In [63]:

from sklearn.model_selection import cross_val_score

cross_val_score(sgd_clf, X_train, y_train_5, cv=3, scoring="accuracy")

Out[63]:

array([0.96975, 0.96335, 0.9521 ])
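
To see what cross_val_score does under the hood, the loop below reproduces it by hand: for a classifier and an integer cv, scikit-learn splits with StratifiedKFold, fits a fresh clone on each training fold, and scores on the held-out fold. This is only a sketch; it uses the small digits dataset bundled with scikit-learn (load_digits, 8x8 images) as an offline stand-in for MNIST, so the numbers will differ from the ones above.

```python
import numpy as np
from sklearn.base import clone
from sklearn.datasets import load_digits
from sklearn.linear_model import SGDClassifier
from sklearn.model_selection import StratifiedKFold, cross_val_score

# Small bundled dataset used as a stand-in for MNIST (an assumption of this sketch)
X_small, y_small = load_digits(return_X_y=True)
y_small_5 = (y_small == 5)

clf = SGDClassifier(random_state=42)
skfolds = StratifiedKFold(n_splits=3)

manual_scores = []
for train_idx, test_idx in skfolds.split(X_small, y_small_5):
    fold_clf = clone(clf)                      # fresh, unfitted copy for each fold
    fold_clf.fit(X_small[train_idx], y_small_5[train_idx])
    y_pred = fold_clf.predict(X_small[test_idx])
    # Accuracy = fraction of correct predictions on the held-out fold
    manual_scores.append(np.mean(y_pred == y_small_5[test_idx]))

library_scores = cross_val_score(clf, X_small, y_small_5, cv=3, scoring="accuracy")
```

Writing the loop out is mainly useful when more control is needed than cross_val_score offers, for example to inspect each fold's predictions.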

In [65]:

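from sklearn.base import BaseEstimator
class Never5Classifier(BaseEstimator):
    # Baseline that ignores its input and always predicts "not 5"
    def fit(self, X, y=None):
        return self  # nothing to learn; return self per scikit-learn convention
    def predict(self, X):
        # 1-D boolean array with one prediction per sample
        return np.zeros(len(X), dtype=bool)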

In [67]:

never_5_clf = Never5Classifier()
cross_val_score(never_5_clf, X_train, y_train_5, cv=3, scoring="accuracy")

Out[67]:

array([0.91025, 0.90975, 0.90895])
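
A classifier that never predicts 5 still scores about 91% accuracy because only about one image in eleven is a 5. That is why accuracy alone says little on a skewed dataset. The sketch below uses made-up labels (10 positives, 90 negatives) to show that such a classifier has high accuracy but zero recall:

```python
import numpy as np
from sklearn.metrics import accuracy_score, precision_score, recall_score

# Hypothetical labels: 10 positives ("is a 5") and 90 negatives
y_true = np.array([True] * 10 + [False] * 90)

# A "never 5" classifier predicts False for every sample
y_pred = np.zeros(len(y_true), dtype=bool)

acc = accuracy_score(y_true, y_pred)     # 90 of 100 correct
rec = recall_score(y_true, y_pred)       # none of the 10 positives found
# No positive predictions at all, so precision is undefined; report 0 instead
prec = precision_score(y_true, y_pred, zero_division=0)

print(acc, rec, prec)
```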

In [70]:

y_scores = sgd_clf.decision_function([X.loc[36011], X.loc[36010]])
y_scores

Out[70]:

array([ 3211.94412614, -4886.85783489])

In [72]:

from sklearn.model_selection import cross_val_predict

y_scores = cross_val_predict(sgd_clf, X_train, y_train_5, cv=3, 
                            method="decision_function")
y_scores

Out[72]:

array([-10306.99957365,  -6898.31615637,  -7299.4630608 , ...,
        -5958.06340458, -18084.1859439 ,   6322.98986326])

In [74]:

y_scores.shape

Out[74]:

(60000,)
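
With one decision score per training sample, a threshold other than the default 0 can be chosen to trade recall for precision. One possible approach is precision_recall_curve from sklearn.metrics. The sketch below uses six made-up scores and labels, not the MNIST scores above, and the 90% precision target is an arbitrary assumption:

```python
import numpy as np
from sklearn.metrics import precision_recall_curve, precision_score, recall_score

# Toy labels and decision scores (hypothetical values for illustration)
y_true = np.array([False, False, True, False, True, True])
scores = np.array([-3.0, -1.0, 0.5, 1.0, 2.0, 4.0])

# precisions and recalls have one more entry than thresholds
precisions, recalls, thresholds = precision_recall_curve(y_true, scores)

# Lowest threshold whose precision reaches 90%
idx = np.argmax(precisions >= 0.90)
threshold_90 = thresholds[idx]

# Predict "positive" only for scores at or above that threshold
y_pred_90 = scores >= threshold_90

print(threshold_90)                        # 2.0
print(precision_score(y_true, y_pred_90))  # 1.0
print(recall_score(y_true, y_pred_90))     # 2 of 3 positives found
```

On the real data the same calls would take y_train_5 and y_scores as inputs.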

In [76]:

from sklearn.ensemble import RandomForestClassifier
forest_clf = RandomForestClassifier(random_state=42)
y_probas_forest = cross_val_predict(forest_clf, X_train, y_train_5, cv=3,
                                    method="predict_proba")

In [78]:

y_probas_forest[:4]

Out[78]:

array([[0.99, 0.01],
       [1.  , 0.  ],
       [1.  , 0.  ],
       [0.96, 0.04]])
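
RandomForestClassifier has no decision_function, so predict_proba is used instead. Each row holds one probability per class, in the order of the classifier's classes_ attribute; for boolean labels that is [False, True], so column 1 is the estimated probability that the image is a 5. A sketch of turning these rows into scores, using only the four rows printed above (the 0.5 cut-off is an assumption):

```python
import numpy as np

# First four rows of the predict_proba output shown above:
# column 0 is P(not 5), column 1 is P(5)
y_probas_sample = np.array([[0.99, 0.01],
                            [1.00, 0.00],
                            [1.00, 0.00],
                            [0.96, 0.04]])

# Use the positive-class probability as the score
y_scores_sample = y_probas_sample[:, 1]

# Thresholding at 0.5 turns the scores into class predictions
y_pred_sample = y_scores_sample >= 0.5

print(y_scores_sample.tolist())  # [0.01, 0.0, 0.0, 0.04]
print(y_pred_sample.tolist())    # [False, False, False, False]
```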

 
