Python机器学习StandardScaler类
StandardScaler
是 Scikit-learn 库中用于标准化数据的一种预处理工具。标准化(Standardization)是一种数据预处理技术,通过移除均值并缩放到单位方差,使数据符合标准正态分布。StandardScaler
类实现了这一功能,通常用于机器学习管道中,以确保特征具有相似的尺度,从而使模型训练过程更加稳定和有效。
基本用法
导入 StandardScaler
类:
1 | from sklearn.preprocessing import StandardScaler |
创建 StandardScaler
实例:
1 | scaler = StandardScaler() |
拟合和转换数据:
- 对于训练数据,使用
fit_transform()
方法,这将计算均值和标准差并立即对数据进行标准化。 - 对于测试数据或新数据,使用
transform()
方法,这将使用在训练数据上计算的均值和标准差进行标准化。
1 | # 假设有训练数据 X_train 和测试数据 X_test |
属性和方法
属性:
mean_
:每个特征的均值。var_
:每个特征的方差。scale_
:每个特征的标准差。
1 | scaler = StandardScaler() |
方法:
fit(X)
:计算数据的均值和标准差。transform(X)
:使用之前计算的均值和标准差对数据进行标准化。fit_transform(X)
:组合fit
和transform
,先计算均值和标准差再进行标准化。inverse_transform(X)
:将标准化的数据还原为原始数据。
1 | scaler = StandardScaler() |
标准化公式
标准化公式如下:
\(X_{\text{scaled}} = \frac{X - \mu}{\sigma}\)
其中:
- \(X\) 是原始数据。
- \(\mu\) 是均值(
mean_
)。 - \(\sigma\) 是标准差(
scale_
)。
使用场景
- 线性模型:例如线性回归、逻辑回归和支持向量机等模型对特征的尺度敏感,标准化有助于提高模型性能和训练速度。
- 距离度量:例如 K 近邻、K 均值聚类,这些算法依赖于欧氏距离,标准化可以防止某些特征对距离计算的影响过大。
- 神经网络:标准化有助于加快神经网络的收敛速度。
transform()与fit_transform()的区别
scaler.transform()
和 scaler.fit_transform()
的区别:
scaler.fit_transform():
fit_transform()
方法是fit()
和transform()
方法的组合。- 首先,它会根据训练数据计算均值和标准差(
fit()
)。 - 然后,它使用计算出的均值和标准差对数据进行标准化(
transform()
)。 - 这个方法适用于训练集数据的处理,因为它在标准化数据之前需要先计算训练集的统计信息。
1
2scaler = StandardScaler()
scaled_data = scaler.fit_transform(training_data)scaler.transform():
transform()
方法仅对数据进行标准化,而不会计算新的均值和标准差。- 它使用之前通过
fit()
方法计算出的均值和标准差对数据进行转换。 - 这个方法通常用于测试集或新数据的处理,因为测试集或新数据应该使用与训练集相同的统计信息进行标准化。
1
2
3scaler = StandardScaler()
scaler.fit(training_data) # 计算均值和标准差
scaled_data = scaler.transform(new_data) # 使用计算出的均值和标准差对新数据进行标准化
举例说明
假设你有一个训练数据集 X_train
和一个测试数据集 X_test
:
1 | from sklearn.preprocessing import StandardScaler |
总结
fit_transform()
:适用于训练数据,先计算均值和标准差再进行标准化。transform()
:适用于测试数据或新数据,使用已经计算出的均值和标准差进行标准化。
本博客所有文章除特别声明外,均采用 CC BY-NC-SA 4.0 许可协议。转载请注明来自 Vincent's Blog!