
feat: integration lightgbm #746

Merged
merged 2 commits into from
Dec 2, 2024
Conversation

Zeyi-Lin
Member

@Zeyi-Lin Zeyi-Lin commented Dec 2, 2024

Description

Test code:

```python
import lightgbm as lgb
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
import swanlab
from swanlab.integration.lightgbm import SwanLabCallback

# Step 1: Initialize swanlab
swanlab.init(project="lightgbm-example", name="breast-cancer-classification")

# Step 2: Load the dataset
data = load_breast_cancer()
X = data.data
y = data.target

# Step 3: Split the data
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# Step 4: Create LightGBM datasets
train_data = lgb.Dataset(X_train, label=y_train)
test_data = lgb.Dataset(X_test, label=y_test, reference=train_data)

# Step 5: Set parameters
params = {
    'objective': 'binary',
    'metric': 'binary_logloss',
    'boosting_type': 'gbdt',
    'num_leaves': 31,
    'learning_rate': 0.05,
    'feature_fraction': 0.9
}

# Step 6: Train the model with swanlab callback
num_round = 100
gbm = lgb.train(
    params,
    train_data,
    num_round,
    valid_sets=[test_data],
    callbacks=[SwanLabCallback()]
)

# Step 7: Make predictions
y_pred = gbm.predict(X_test, num_iteration=gbm.best_iteration)
y_pred_binary = [1 if p >= 0.5 else 0 for p in y_pred]

# Step 8: Evaluate the model
accuracy = accuracy_score(y_test, y_pred_binary)
print(f"Model accuracy: {accuracy:.4f}")
swanlab.log({"accuracy": accuracy})

# Step 9: Save the model locally
gbm.save_model('lightgbm_model.txt')

# Step 10: Load the model and predict again
bst_loaded = lgb.Booster(model_file='lightgbm_model.txt')
y_pred_loaded = bst_loaded.predict(X_test)
y_pred_binary_loaded = [1 if p >= 0.5 else 0 for p in y_pred_loaded]

# Step 11: Evaluate the loaded model
accuracy_loaded = accuracy_score(y_test, y_pred_binary_loaded)
print(f"Accuracy after loading the model: {accuracy_loaded:.4f}")
swanlab.log({"accuracy_loaded": accuracy_loaded})

# Step 12: Finish the swanlab run
swanlab.finish()
```
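For context on what the integration plugs into: LightGBM's `callbacks=` hook passes a `CallbackEnv` namedtuple to each callback once per boosting round, and a metrics logger like `SwanLabCallback` presumably reads `env.evaluation_result_list` and forwards the values to `swanlab.log`. A minimal sketch of that mechanism, using a stand-in namedtuple instead of a real training run (the names and field set here are illustrative, not SwanLab's actual implementation):

```python
from collections import namedtuple

# Stand-in for lightgbm.callback.CallbackEnv (illustrative; the real namedtuple
# has more fields, e.g. model, params, begin_iteration, end_iteration)
CallbackEnv = namedtuple("CallbackEnv", ["iteration", "evaluation_result_list"])

def metric_logging_callback(env):
    """Collect (dataset, metric, value) for the current boosting round,
    the way a SwanLab-style callback might before calling swanlab.log()."""
    logged = {}
    # Each entry is a tuple: (dataset_name, metric_name, value, is_higher_better)
    for dataset_name, metric_name, value, _is_higher_better in env.evaluation_result_list:
        logged[f"{dataset_name}/{metric_name}"] = value
    return logged

# Simulate what lgb.train() does each round: build an env and invoke the callback
env = CallbackEnv(
    iteration=0,
    evaluation_result_list=[("valid_0", "binary_logloss", 0.42, False)],
)
print(metric_logging_callback(env))  # {'valid_0/binary_logloss': 0.42}
```

Because the callback protocol is just "a callable taking one `env` argument", any object with a matching `__call__` can be dropped into the `callbacks=` list, which is how third-party loggers integrate without LightGBM knowing about them.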

Closes: #744

@Zeyi-Lin Zeyi-Lin requested a review from SAKURA-CAT December 2, 2024 10:50
@Zeyi-Lin Zeyi-Lin self-assigned this Dec 2, 2024
@Zeyi-Lin Zeyi-Lin added the 💪 enhancement New feature or request label Dec 2, 2024
@SAKURA-CAT SAKURA-CAT merged commit c02471a into main Dec 2, 2024
5 checks passed
@SAKURA-CAT SAKURA-CAT deleted the feat-integration-lightgbm branch December 2, 2024 10:57
Successfully merging this pull request may close these issues.

[REQUEST] Integrate LightGBM and XGBoost