CatBoost: A Hands-On Example

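A key feature of CatBoost is that it can use categorical features directly, without manual preprocessing such as one-hot or label encoding.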


Training the model

Data preparation

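import pandas as pd

transaction = pd.read_csv('./train_transaction.csv')
identity = pd.read_csv('./train_identity.csv')

# Inner join (the pandas default): only transactions with a matching identity row are kept
df = pd.merge(transaction, identity, on='TransactionID')

# Per-class subsets; note that the training code below uses df, not these two frames
df_fraud = df[df['isFraud'] == 1]
df_normal = df[df['isFraud'] == 0].sample(100000)

float_features = ['C1', 'C2']
category_features = ['ProductCD', 'DeviceType', 'DeviceInfo']

# CatBoost expects categorical values to be strings or integers, so fill missing values explicitly
df = df.fillna({'ProductCD': 'NA', 'DeviceType': 'NA', 'DeviceInfo': 'NA', 'C1': 0, 'C2': 0})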
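
The snippet above splits the merged frame into df_fraud and a 100,000-row sample of normal rows. Below is a minimal sketch of how such subsets might be combined into a smaller, less imbalanced training frame. It runs on made-up toy data; the helper downsample_majority, its parameters, and the seed are hypothetical names introduced for illustration, not part of the original code.

```python
import pandas as pd

# Toy stand-in for the merged frame; column names mirror the post, values are made up.
toy = pd.DataFrame({
    "TransactionID": range(1, 13),
    "isFraud": [1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0],
})


def downsample_majority(frame, label="isFraud", n_normal=4, seed=0):
    """Keep every positive row and a random sample of n_normal negative rows."""
    fraud = frame[frame[label] == 1]
    normal = frame[frame[label] == 0]
    # Never request more rows than exist, otherwise sample() raises ValueError.
    normal = normal.sample(n=min(n_normal, len(normal)), random_state=seed)
    # Concatenate, then shuffle so the two classes are not grouped together.
    combined = pd.concat([fraud, normal])
    return combined.sample(frac=1, random_state=seed).reset_index(drop=True)


balanced = downsample_majority(toy)
print(balanced["isFraud"].value_counts().to_dict())
```

Downsampling changes the class ratio the model sees, so metrics measured on a downsampled test set are not directly comparable with metrics on the full data.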

Training

from catboost import CatBoostClassifier
from sklearn.model_selection import train_test_split

# Numeric features first, then categorical ones
feature_names = float_features + category_features

Y = df['isFraud']
X = df[feature_names]

# Hold out 20% of the rows for evaluation
X_train, X_test, Y_train, Y_test = train_test_split(X, Y, test_size=0.2)

model = CatBoostClassifier(iterations=50,
                           depth=5,
                           learning_rate=0.1,
                           cat_features=category_features,  # columns treated as categorical
                           loss_function='Logloss',
                           verbose=True)
model.fit(X_train, Y_train)
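
The split above is purely random. Because fraud is the minority class here, one option is a stratified split, which keeps the class ratio the same in the training and test sets. The following is a minimal sketch on toy lists; the stratify and random_state arguments are additions that do not appear in the original code, and the data is made up.

```python
from sklearn.model_selection import train_test_split

# Toy data: 100 rows, 10 of them positive (a 10% minority class).
X_toy = [[i] for i in range(100)]
y_toy = [1] * 10 + [0] * 90

# stratify=y_toy keeps the 10% positive rate in both parts;
# random_state makes the split repeatable.
X_tr, X_te, y_tr, y_te = train_test_split(
    X_toy, y_toy, test_size=0.2, stratify=y_toy, random_state=0
)

print(len(y_te), sum(y_te))  # test size and number of positives in it
print(len(y_tr), sum(y_tr))  # train size and number of positives in it
```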

Metrics

from sklearn.metrics import accuracy_score, classification_report, confusion_matrix

# Predict class labels for the held-out rows
predictions = model.predict(X_test)
total_accuracy = accuracy_score(Y_test, predictions)
print("Total Accuracy: %.2f%%" % (total_accuracy * 100.0))

# Per-class precision, recall and F1, followed by the confusion matrix
print(classification_report(Y_test, predictions))
print(confusion_matrix(Y_test, predictions))

Output:

Total Accuracy: 94.44%
              precision    recall  f1-score   support

           0       0.95      0.99      0.97     26595
           1       0.79      0.39      0.52      2252

    accuracy                           0.94     28847
   macro avg       0.87      0.69      0.75     28847
weighted avg       0.94      0.94      0.94     28847

[[26363   232]
 [ 1371   881]]
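
The headline accuracy hides how the model does on the fraud class. As a worked check, the sketch below recomputes the class-1 metrics by hand from the confusion matrix printed above, assuming scikit-learn's convention that rows are true labels and columns are predicted labels. The variable names are illustrative only.

```python
# Confusion matrix from the report: rows = true label, columns = predicted label.
tn, fp = 26363, 232   # true class 0: correctly kept / wrongly flagged
fn, tp = 1371, 881    # true class 1: missed fraud / caught fraud

total = tn + fp + fn + tp
accuracy = (tp + tn) / total
precision = tp / (tp + fp)   # of the rows flagged as fraud, how many were fraud
recall = tp / (tp + fn)      # of the actual fraud rows, how many were flagged
f1 = 2 * precision * recall / (precision + recall)

# Accuracy of a trivial model that always predicts "not fraud".
baseline = (tn + fp) / total

print(f"accuracy={accuracy:.4f} precision={precision:.2f} "
      f"recall={recall:.2f} f1={f1:.2f} baseline={baseline:.4f}")
```

The always-negative baseline already reaches roughly 92% accuracy on this test set, so the recall of 0.39 for class 1 is the more informative number: about six in ten fraudulent transactions are missed.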

Feature importance

model.get_feature_importance(prettified=True)


Saving the model

model.save_model('model.cbm')


Loading the model

from catboost import CatBoostClassifier
model_v1 = CatBoostClassifier()
model_v1.load_model('model.cbm')

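About the dataset

The task is to predict whether an online transaction is fraudulent. The data comes in two parts, identity (entity information) and transaction (transaction information), which can be joined on TransactionID. The categorical features are listed below.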

Categorical Features - Transaction

  • ProductCD

  • card1 - card6

  • addr1, addr2

  • P_emaildomain

  • R_emaildomain

  • M1 - M9

Categorical Features - Identity

  • DeviceType

  • DeviceInfo

  • id_12 - id_38
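
The lists above use range shorthand such as "card1 - card6". The following is a minimal sketch of how that shorthand could be expanded into an explicit column list, for example to pass as cat_features. It assumes the CSV column names match the spellings listed above; the helper names expand_range and expand_all are hypothetical.

```python
import re


def expand_range(spec):
    """Expand 'card1 - card6' into ['card1', ..., 'card6']; return single names as-is."""
    parts = [p.strip() for p in spec.split(" - ")]
    if len(parts) == 1:
        return parts
    start, end = parts
    m_start = re.fullmatch(r"(.*?)(\d+)", start)
    m_end = re.fullmatch(r"(.*?)(\d+)", end)
    if not m_start or not m_end or m_start.group(1) != m_end.group(1):
        raise ValueError(f"cannot expand range: {spec!r}")
    prefix = m_start.group(1)
    width = len(m_start.group(2))  # preserve zero padding, if any
    lo, hi = int(m_start.group(2)), int(m_end.group(2))
    return [f"{prefix}{i:0{width}d}" for i in range(lo, hi + 1)]


def expand_all(specs):
    """Expand every entry; entries may also be comma-separated, like 'addr1, addr2'."""
    names = []
    for spec in specs:
        for item in spec.split(","):
            names.extend(expand_range(item.strip()))
    return names


transaction_cats = expand_all(
    ["ProductCD", "card1 - card6", "addr1, addr2",
     "P_emaildomain", "R_emaildomain", "M1 - M9"]
)
identity_cats = expand_all(["DeviceType", "DeviceInfo", "id_12 - id_38"])

print(len(transaction_cats), len(identity_cats))
```

Any column passed as a categorical feature still needs its missing values filled with a string or integer first, as done for ProductCD, DeviceType and DeviceInfo in the data preparation step.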
