快速开始¶
本指南将帮助您快速上手 genrec 框架。
前置要求¶
- Python 3.8 或更高版本
- CUDA 11.0+ (如果使用 GPU)
- 8GB+ GPU 显存(推荐)
安装¶
1. 克隆仓库¶
2. 安装依赖¶
3. 准备数据¶
下载 P5 Amazon 数据集:
第一个实验:训练 RQVAE¶
1. 查看配置文件¶
主要配置参数:
- train.iterations=400000: 训练迭代次数
- train.batch_size=64: 批量大小
- train.learning_rate=0.0005: 学习率
- train.dataset_folder="dataset/amazon": 数据集路径
2. 开始训练¶
训练过程中您会看到: - 自动数据下载和预处理 - 文本特征编码进度 - 训练损失和指标 - 模型检查点保存
3. 监控训练¶
如果启用了 Weights & Biases 日志:
访问 wandb.ai 查看训练进度。
第二个实验:训练 TIGER¶
1. 确保 RQVAE 已训练¶
TIGER 需要预训练的 RQVAE 模型来生成语义 ID:
2. 配置 TIGER¶
编辑 config/tiger/p5_amazon.gin:
3. 开始训练¶
理解框架结构¶
数据处理流水线¶
graph TD
A[原始数据] --> B[数据下载]
B --> C[预处理]
C --> D[文本编码]
D --> E[序列生成]
E --> F[数据集]
模型训练流程¶
graph TD
A[配置文件] --> B[数据集加载]
B --> C[模型初始化]
C --> D[训练循环]
D --> E[评估]
E --> F[检查点保存]
F --> D
自定义配置¶
创建自定义配置¶
# my_config.gin
import genrec.data.p5_amazon
import genrec.models.rqvae
# 自定义参数
train.batch_size=32
train.learning_rate=0.001
train.vae_hidden_dims=[256, 128, 64]
# 使用自定义数据路径
train.dataset_folder="path/to/my/data"
使用自定义配置¶
评估模型¶
RQVAE 评估¶
from genrec.models.rqvae import RqVae
from genrec.data.p5_amazon import P5AmazonItemDataset
# 加载模型
model = RqVae.load_from_checkpoint("path/to/checkpoint.pt")
# 加载测试数据
test_dataset = P5AmazonItemDataset(
root="dataset/amazon",
train_test_split="eval"
)
# 评估重构质量
reconstruction_loss = model.evaluate(test_dataset)
TIGER 评估¶
from genrec.models.tiger import Tiger
from genrec.modules.metrics import TopKAccumulator
# 加载模型
model = Tiger.load_from_checkpoint("path/to/checkpoint.pt")
# 计算 Recall@K
metrics = TopKAccumulator(k=10)
recall = metrics.compute_recall(model, test_dataloader)
常见问题¶
Q: 内存不足怎么办?¶
A: 调整以下参数:
Q: 训练速度慢?¶
A: 优化建议: - 使用更大的批量大小 - 启用混合精度训练 - 使用多 GPU 训练
Q: 如何添加新数据集?¶
A: 参考自定义数据集指南