当前位置:首页 > Python > 正文

Python实现决策树详解(手把手教你用Scikit-learn构建机器学习决策树模型)

机器学习入门的过程中,决策树算法因其直观、易于理解而成为初学者的首选模型之一。本文将带你从零开始,使用Python决策树库——Scikit-learn,一步步构建一个完整的决策树分类器。即使你是编程小白,也能轻松上手!

Python实现决策树详解(手把手教你用Scikit-learn构建机器学习决策树模型) Python决策树  决策树算法 机器学习入门 Scikit-learn决策树 第1张

什么是决策树?

决策树是一种树形结构的监督学习算法,用于分类和回归任务。它通过一系列“是/否”问题对数据进行分割,最终将样本归入某一类别或预测一个数值。例如:判断一个水果是否为苹果,可以通过颜色、形状、大小等特征逐步判断。

准备工作:安装必要库

要使用Scikit-learn决策树,你需要先安装以下Python库:

pip install scikit-learn pandas matplotlib

步骤一:导入所需模块

import pandas as pdfrom sklearn.datasets import load_irisfrom sklearn.model_selection import train_test_splitfrom sklearn.tree import DecisionTreeClassifier, plot_treeimport matplotlib.pyplot as plt

步骤二:加载数据集

我们使用经典的鸢尾花(Iris)数据集,它包含150个样本,每个样本有4个特征(花萼长度、花萼宽度、花瓣长度、花瓣宽度),目标是预测花的种类(共3类)。

# 加载数据iris = load_iris()X = iris.data  # 特征y = iris.target  # 标签# 转换为DataFrame便于查看df = pd.DataFrame(X, columns=iris.feature_names)df['target'] = yprint(df.head())

步骤三:划分训练集与测试集

X_train, X_test, y_train, y_test = train_test_split(    X, y, test_size=0.3, random_state=42)

步骤四:创建并训练决策树模型

# 创建决策树分类器clf = DecisionTreeClassifier(    criterion='gini',        # 分割标准:基尼不纯度    max_depth=3,            # 最大深度,防止过拟合    random_state=42)# 训练模型clf.fit(X_train, y_train)

步骤五:评估模型性能

# 在测试集上预测y_pred = clf.predict(X_test)# 计算准确率accuracy = clf.score(X_test, y_test)print(f"模型准确率: {accuracy:.2f}")

步骤六:可视化决策树

plt.figure(figsize=(12, 8))plot_tree(    clf,    feature_names=iris.feature_names,    class_names=iris.target_names,    filled=True)plt.title("决策树可视化")plt.show()

小贴士:避免过拟合

决策树容易过拟合,可通过以下参数控制:

  • max_depth:限制树的最大深度
  • min_samples_split:内部节点再划分所需最小样本数
  • min_samples_leaf:叶子节点最少样本数

总结

通过本教程,你已经掌握了如何使用Python决策树进行分类任务。无论是机器学习入门还是实际项目开发,Scikit-learn决策树都是一个强大而简洁的工具。记住,理解决策树算法的核心逻辑比代码本身更重要——它模拟了人类做决策的过程。

动手试试吧!修改参数、更换数据集,你会发现更多乐趣。