
Code Examples

This page contains practical examples of using genrec.

Basic Usage Examples

Training an RQ-VAE from Scratch

import torch
from genrec.models.rqvae import RqVae, QuantizeForwardMode
from genrec.data.p5_amazon import P5AmazonItemDataset
from torch.utils.data import DataLoader

# Create the dataset
dataset = P5AmazonItemDataset(
    root="dataset/amazon",
    split="beauty",
    train_test_split="train"
)

# Create the model
model = RqVae(
    input_dim=768,
    embed_dim=32,
    hidden_dims=[512, 256, 128],
    codebook_size=256,
    n_layers=3,
    commitment_weight=0.25,
    codebook_mode=QuantizeForwardMode.ROTATION_TRICK
)

# Training loop
optimizer = torch.optim.AdamW(model.parameters(), lr=0.0005)
dataloader = DataLoader(dataset, batch_size=64, shuffle=True)

for epoch in range(100):
    for batch in dataloader:
        optimizer.zero_grad()

        outputs = model(batch)  # the DataLoader already yields tensors; no torch.tensor() needed
        loss = outputs.loss

        loss.backward()
        optimizer.step()

    print(f"Epoch {epoch}, Loss: {loss.item():.4f}")

Using the Dataset Factory

from genrec.data.dataset_factory import DatasetFactory

# Create an item dataset
item_dataset = DatasetFactory.create_item_dataset(
    "p5_amazon",
    "dataset/amazon",
    split="train"
)

# Create a sequence dataset
sequence_dataset = DatasetFactory.create_sequence_dataset(
    "p5_amazon",
    "dataset/amazon",
    split="train",
    pretrained_rqvae_path="./checkpoints/rqvae.pt"
)
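
The sequence dataset can be wrapped in a DataLoader like any other dataset; the exact batch structure depends on what the dataset yields, so treat this as a sketch:

from torch.utils.data import DataLoader

seq_loader = DataLoader(sequence_dataset, batch_size=64, shuffle=True)
for batch in seq_loader:
    pass  # feed the batch to your sequence model (e.g. TIGER)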

Custom Configuration

from genrec.data.configs import P5AmazonConfig, TextEncodingConfig

# Custom text-encoding configuration
text_config = TextEncodingConfig(
    encoder_model="sentence-transformers/all-MiniLM-L6-v2",
    template="产品: {title} | 品牌: {brand} | 类别: {categories}",
    batch_size=32
)

# Custom dataset configuration
dataset_config = P5AmazonConfig(
    root_dir="my_data",
    split="electronics",
    text_config=text_config
)
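
The template is a standard Python format string, so you can preview how an item will be serialized before running a full encoding pass (the item fields below are illustrative, and attribute access assumes the config stores its fields as attributes):

# Hypothetical item record, just to preview the template
example_item = {"title": "Lip Balm", "brand": "Acme", "categories": "Beauty"}
print(text_config.template.format(**example_item))
# -> Product: Lip Balm | Brand: Acme | Categories: Beauty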

Advanced Examples

Multi-GPU Training

from accelerate import Accelerator

def train_with_accelerate(epochs: int = 100):
    accelerator = Accelerator()

    # Model, optimizer, and data loader
    model = RqVae(...)
    optimizer = torch.optim.AdamW(model.parameters())
    dataloader = DataLoader(...)

    # Prepare everything for distributed training
    model, optimizer, dataloader = accelerator.prepare(
        model, optimizer, dataloader
    )

    for epoch in range(epochs):
        for batch in dataloader:
            optimizer.zero_grad()

            with accelerator.autocast():
                outputs = model(batch)
                loss = outputs.loss

            accelerator.backward(loss)
            optimizer.step()
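
    # Save a checkpoint from the main process only (standard Accelerate pattern);
    # unwrap the model first so the plain state dict is stored
    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
        unwrapped_model = accelerator.unwrap_model(model)
        accelerator.save(unwrapped_model.state_dict(), "./checkpoints/rqvae_ddp.pt")

Run the script with accelerate launch (after configuring your hardware once with accelerate config) to train on multiple GPUs.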

Custom Dataset Implementation

from genrec.data.base_dataset import BaseRecommenderDataset

class MyCustomDataset(BaseRecommenderDataset):
    def download(self):
        # Implement the data download logic
        pass

    def load_raw_data(self):
        # Load the raw data files; items_df / interactions_df are placeholders
        return {"items": items_df, "interactions": interactions_df}

    def preprocess_data(self, raw_data):
        # Custom preprocessing logic
        return processed_data

    def extract_items(self, processed_data):
        return processed_data["items"]

    def extract_interactions(self, processed_data):
        return processed_data["interactions"]

Integration Examples

Weights & Biases 集成

import wandb

# Initialize wandb
wandb.init(
    project="my-recommendation-project",
    config={
        "learning_rate": 0.0005,
        "batch_size": 64,
        "model_type": "rqvae"
    }
)

# Log metrics during training
for epoch in range(epochs):
    # ... training code ...

    wandb.log({
        "epoch": epoch,
        "loss": loss.item(),
        "reconstruction_loss": recon_loss.item(),
        "quantization_loss": quant_loss.item()
    })
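
wandb can also track gradients and parameter histograms for you, and runs should be closed explicitly so buffered metrics are flushed:

# Call before the training loop to track gradients/parameters
wandb.watch(model, log="all", log_freq=100)

# Close the run when training is done
wandb.finish()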

Hyperparameter Tuning

import optuna

def objective(trial):
    # Suggest hyperparameters
    lr = trial.suggest_float("lr", 1e-5, 1e-2, log=True)
    batch_size = trial.suggest_categorical("batch_size", [32, 64, 128])
    embed_dim = trial.suggest_categorical("embed_dim", [16, 32, 64])

    # Train a model with the suggested parameters
    model = RqVae(embed_dim=embed_dim, ...)
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)

    # Train and return the validation loss (train_and_evaluate is your own helper)
    val_loss = train_and_evaluate(model, optimizer, batch_size)

    return val_loss

# Run the optimization
study = optuna.create_study(direction="minimize")
study.optimize(objective, n_trials=100)
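
After the study finishes, the best configuration is available on the study object:

# Inspect the best trial
print(f"Best validation loss: {study.best_value:.4f}")
print(f"Best hyperparameters: {study.best_params}")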

Evaluation Examples

Model Evaluation

def evaluate_model(model, test_dataloader, device):
    model.eval()
    total_loss = 0
    total_samples = 0

    with torch.no_grad():
        for batch in test_dataloader:
            batch = batch.to(device)
            outputs = model(batch)

            total_loss += outputs.loss.item() * len(batch)
            total_samples += len(batch)

    return total_loss / total_samples

# Evaluate the RQ-VAE
test_loss = evaluate_model(rqvae_model, test_dataloader, device)
print(f"Test reconstruction loss: {test_loss:.4f}")

Recommendation Generation

def generate_recommendations(tiger_model, user_sequence, top_k=10):
    """为用户序列生成 Top-K 推荐"""
    tiger_model.eval()

    with torch.no_grad():
        # Generate next-item scores for the user sequence
        logits = tiger_model.generate(user_sequence, max_length=top_k)

        # Take the top-K items by score
        top_items = torch.topk(logits, top_k).indices

    return top_items.tolist()

# Generate recommendations
user_seq = [1, 5, 23, 45]  # the user's interaction history
recommendations = generate_recommendations(tiger_model, user_seq, top_k=10)
print(f"Recommended items: {recommendations}")

Utilities

Data Analysis

from genrec.data.processors.sequence_processor import SequenceStatistics

# Compute sequence statistics
stats = SequenceStatistics.compute_sequence_stats(sequence_data)
print(f"Average sequence length: {stats['avg_seq_length']:.2f}")
print(f"Number of unique items: {stats['num_unique_items']}")

Model Inspection

def inspect_codebook_usage(rqvae_model, dataloader):
    """分析码本利用率"""
    used_codes = set()

    with torch.no_grad():
        for batch in dataloader:
            outputs = rqvae_model(batch)
            semantic_ids = outputs.sem_ids
            used_codes.update(semantic_ids.flatten().tolist())

    usage_rate = len(used_codes) / rqvae_model.codebook_size
    print(f"码本利用率: {usage_rate:.2%}")

    return used_codes

used_codes = inspect_codebook_usage(model, dataloader)
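
With a multi-level RQ-VAE it is often more informative to measure utilization per quantization layer. A per-layer variant, assuming sem_ids is shaped (batch, n_layers):

def inspect_codebook_usage_per_layer(rqvae_model, dataloader, n_layers=3):
    """Per-layer codebook utilization; assumes sem_ids has shape (batch, n_layers)."""
    used = [set() for _ in range(n_layers)]

    with torch.no_grad():
        for batch in dataloader:
            sem_ids = rqvae_model(batch).sem_ids
            for layer in range(n_layers):
                used[layer].update(sem_ids[:, layer].tolist())

    for layer, codes in enumerate(used):
        print(f"Layer {layer} usage: {len(codes) / rqvae_model.codebook_size:.2%}")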

Tips and Best Practices

Memory Optimization

# Enable gradient checkpointing for large models
# (gradient_checkpointing_enable follows the Hugging Face convention and is
#  only available if the model implements it)
model.gradient_checkpointing_enable()

# Mixed-precision training (on PyTorch >= 2.4, prefer torch.amp over torch.cuda.amp)
from torch.cuda.amp import autocast, GradScaler

scaler = GradScaler()

for batch in dataloader:
    optimizer.zero_grad()

    with autocast():
        outputs = model(batch)
        loss = outputs.loss

    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
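
Gradient accumulation is another lever: run several small batches before each optimizer step to emulate a larger effective batch size without the memory cost. A minimal sketch:

accumulation_steps = 4  # effective batch size = batch_size * accumulation_steps

optimizer.zero_grad()
for step, batch in enumerate(dataloader):
    outputs = model(batch)
    # Scale the loss so accumulated gradients average rather than sum
    loss = outputs.loss / accumulation_steps
    loss.backward()

    if (step + 1) % accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad()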

Debugging

# Enable verbose logging
import logging

logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)

# Log model statistics
def log_model_stats(model):
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

    logger.info(f"总参数量: {total_params:,}")
    logger.info(f"可训练参数量: {trainable_params:,}")