跳转至

训练器 API 参考

训练工具和脚本的详细文档。

基础训练器

BaseTrainer

所有训练器的基类。

class BaseTrainer:
    """Common base for all trainers.

    Holds the LightningModule being trained, its training configuration and
    a logger. The Lightning ``Trainer`` itself is created lazily and starts
    out as ``None``.
    """

    def __init__(
        self,
        model: LightningModule,
        config: TrainingConfig,
        logger: Optional[Logger] = None
    ):
        # `config` must be assigned before resolving the logger: the default
        # logger factory reads its settings from `self.config`.
        self.model = model
        self.config = config
        self.logger = logger if logger else self._create_default_logger()
        self.trainer = None

    def _create_default_logger(self) -> Logger:
        """Build a TensorBoard logger from the log settings in ``self.config``."""
        from pytorch_lightning.loggers import TensorBoardLogger

        logger_kwargs = {
            "save_dir": self.config.log_dir,
            "name": self.config.experiment_name,
            "version": self.config.version,
        }
        return TensorBoardLogger(**logger_kwargs)

方法（以下代码均为 BaseTrainer 的方法，为便于阅读省略了类体缩进）：

setup_trainer()

设置 PyTorch Lightning 训练器。

def setup_trainer(self) -> None:
    """Create the PyTorch Lightning ``Trainer`` and store it on ``self.trainer``.

    Always registers a ``ModelCheckpoint`` callback that monitors ``val_loss``
    (lower is better). Adds an ``EarlyStopping`` callback on the same metric
    when ``config.early_stopping_patience > 0``.

    The legacy ``config.gpus`` value is translated into the
    ``accelerator``/``devices`` arguments. ``Trainer(gpus=...)`` was
    deprecated in Lightning 1.7 and removed in 2.0, while
    ``accelerator``/``devices`` exist since 1.5. That is the same minimum
    version already required by ``enable_progress_bar``.
    """
    from pytorch_lightning import Trainer
    from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping

    callbacks = []

    # Checkpointing: keep the `save_top_k` best checkpoints by val_loss, plus the last one.
    checkpoint_callback = ModelCheckpoint(
        dirpath=self.config.checkpoint_dir,
        filename='{epoch}-{val_loss:.2f}',
        monitor='val_loss',
        mode='min',
        save_top_k=self.config.save_top_k,
        save_last=True
    )
    callbacks.append(checkpoint_callback)

    # Early stopping is disabled by a non-positive patience.
    if self.config.early_stopping_patience > 0:
        early_stop_callback = EarlyStopping(
            monitor='val_loss',
            patience=self.config.early_stopping_patience,
            mode='min'
        )
        callbacks.append(early_stop_callback)

    # Map the legacy `gpus` setting onto the accelerator/devices API.
    # A count, -1 or a device list is passed through as `devices`.
    # A falsy value (0 / None / empty) means a single CPU process,
    # which is what `gpus=0` used to mean.
    gpus = self.config.gpus
    if gpus:
        accelerator, devices = "gpu", gpus
    else:
        accelerator, devices = "cpu", 1

    self.trainer = Trainer(
        max_epochs=self.config.max_epochs,
        accelerator=accelerator,
        devices=devices,
        precision=self.config.precision,
        gradient_clip_val=self.config.gradient_clip_val,
        accumulate_grad_batches=self.config.accumulate_grad_batches,
        val_check_interval=self.config.val_check_interval,
        callbacks=callbacks,
        logger=self.logger,
        deterministic=self.config.deterministic,
        enable_progress_bar=self.config.progress_bar
    )

train(train_dataloader, val_dataloader=None)

执行训练。若训练器尚未创建，会先调用 setup_trainer()。

def train(
    self,
    train_dataloader: DataLoader,
    val_dataloader: Optional[DataLoader] = None
) -> None:
    """
    Fit ``self.model`` with the Lightning trainer.

    If no trainer exists yet, one is built first via ``setup_trainer()``.

    Args:
        train_dataloader: loader yielding training batches
        val_dataloader: optional loader yielding validation batches
    """
    trainer = self.trainer
    if trainer is None:
        # Lazily construct the trainer the first time training is requested.
        self.setup_trainer()
        trainer = self.trainer

    trainer.fit(self.model, train_dataloader, val_dataloader)

test(test_dataloader)

执行测试并返回测试结果字典（无结果时返回空字典）。若训练器尚未创建，会先调用 setup_trainer()。

def test(self, test_dataloader: DataLoader) -> Dict[str, float]:
    """
    执行模型测试

    Args:
        test_dataloader: 测试数据加载器

    Returns:
        测试结果字典
    """
    if self.trainer is None:
        self.setup_trainer()

    results = self.trainer.test(self.model, test_dataloader)
    return results[0] if results else {}

RQVAE 训练器

RQVAETrainer

专门用于训练 RQVAE 模型的训练器。

class RQVAETrainer(BaseTrainer):
    """Trainer specialised for RQVAE models.

    Owns the item dataset and splits it into train/val/test DataLoaders.
    """

    def __init__(
        self,
        model: RqVae,
        config: RQVAETrainingConfig,
        dataset: ItemDataset,
        logger: Optional[Logger] = None
    ):
        super().__init__(model, config, logger)
        self.dataset = dataset

    def _make_loader(self, subset, shuffle: bool) -> DataLoader:
        """Wrap one dataset split in a DataLoader using the shared config settings."""
        return DataLoader(
            subset,
            batch_size=self.config.batch_size,
            shuffle=shuffle,
            num_workers=self.config.num_workers,
            # Pinned host memory only speeds up host->GPU copies. Without CUDA,
            # PyTorch warns and ignores it, so only request it when CUDA exists.
            pin_memory=torch.cuda.is_available(),
        )

    def create_dataloaders(self) -> Tuple[DataLoader, DataLoader, DataLoader]:
        """
        Split ``self.dataset`` 80/10/10 and build a DataLoader for each split.

        The split is reproducible: ``random_split`` uses a generator seeded
        with ``config.random_seed``. Split sizes are truncated with ``int()``,
        so any rounding remainder goes to the test split. Only the training
        loader shuffles.

        Returns:
            (train_loader, val_loader, test_loader)
        """
        n_items = len(self.dataset)
        train_size = int(0.8 * n_items)
        val_size = int(0.1 * n_items)
        test_size = n_items - train_size - val_size

        train_dataset, val_dataset, test_dataset = random_split(
            self.dataset,
            [train_size, val_size, test_size],
            generator=torch.Generator().manual_seed(self.config.random_seed)
        )

        return (
            self._make_loader(train_dataset, shuffle=True),
            self._make_loader(val_dataset, shuffle=False),
            self._make_loader(test_dataset, shuffle=False),
        )

train_model()

训练 RQVAE 模型：按 8:1:1 随机划分数据集后进行训练；当 config.run_test 为真时，还会在测试集上评估。

def train_model(self) -> RqVae:
    """
    Run the full RQVAE training pipeline.

    Splits the dataset into train/val/test loaders and fits the model. When
    ``config.run_test`` is true, it also runs the test loop and prints the
    results.

    Returns:
        ``self.model`` after training
    """
    loaders = self.create_dataloaders()
    train_loader, val_loader, test_loader = loaders

    self.train(train_loader, val_loader)

    if not self.config.run_test:
        return self.model

    test_results = self.test(test_loader)
    print(f"Test results: {test_results}")
    return self.model

evaluate_reconstruction(test_dataloader)

评估重构质量，返回 MSE、RMSE 与平均余弦相似度。

def evaluate_reconstruction(self, test_dataloader: DataLoader) -> Dict[str, float]:
    """
    Measure how well the model reconstructs its inputs.

    Each batch is either a feature tensor or a dict holding it under
    ``'features'``. The model's forward pass must return a 4-tuple whose
    first element is the reconstruction. The model is put in eval mode for
    the pass, and its previous train/eval mode is restored afterwards.

    Args:
        test_dataloader: loader yielding feature batches

    Returns:
        Dict with the following keys:
        ``'mse'``: mean squared error per feature element (the same
        normalisation as ``F.mse_loss``'s default ``'mean'`` reduction).
        ``'rmse'``: square root of ``'mse'``.
        ``'cosine_similarity'``: per-sample cosine similarity between the
        reconstruction and the input, computed along dim 1 and averaged over
        samples. This assumes 2-D ``(batch, dim)`` features; confirm for
        other shapes.
        All values are NaN when the loader yields no samples.
    """
    was_training = self.model.training
    self.model.eval()
    device = next(self.model.parameters()).device

    total_sq_error = 0.0
    total_cosine_sim = 0.0
    total_elements = 0
    total_samples = 0

    try:
        with torch.no_grad():
            for batch in test_dataloader:
                features = batch['features'] if isinstance(batch, dict) else batch
                features = features.to(device)

                reconstructed, _, _, _ = self.model(features)

                total_sq_error += F.mse_loss(reconstructed, features, reduction='sum').item()
                total_cosine_sim += F.cosine_similarity(reconstructed, features, dim=1).sum().item()
                # Normalise the squared error by element count (not sample count)
                # so 'mse' is a true mean squared error.
                total_elements += features.numel()
                total_samples += features.size(0)
    finally:
        # Restore the caller's mode so evaluating mid-training does not leave
        # dropout / batch-norm layers stuck in eval mode.
        self.model.train(was_training)

    if total_samples == 0 or total_elements == 0:
        # Nothing was evaluated. NaN avoids ZeroDivisionError and does not
        # masquerade as a perfect (0.0) reconstruction.
        nan = float('nan')
        return {'mse': nan, 'rmse': nan, 'cosine_similarity': nan}

    mse = total_sq_error / total_elements
    return {
        'mse': mse,
        'rmse': mse ** 0.5,
        'cosine_similarity': total_cosine_sim / total_samples
    }

generate_semantic_ids(dataloader)

为数据集生成语义 ID。

def generate_semantic_ids(self, dataloader: DataLoader) -> torch.Tensor:
    """
    Produce semantic IDs for every sample yielded by ``dataloader``.

    Each batch (a tensor, or a dict holding it under ``'features'``) is fed
    to ``self.model.generate_semantic_ids``. The per-batch results are moved
    to the CPU and concatenated along dim 0. The model's previous train/eval
    mode is restored afterwards.

    Args:
        dataloader: loader yielding feature batches

    Returns:
        Concatenated semantic IDs with samples along dim 0. The original docs
        state shape ``(num_samples,)``; any trailing dims depend on the model
        (NOTE(review): confirm). An empty ``long`` tensor is returned when the
        loader yields nothing.
    """
    was_training = self.model.training
    self.model.eval()
    device = next(self.model.parameters()).device

    all_semantic_ids = []
    try:
        with torch.no_grad():
            for batch in dataloader:
                features = batch['features'] if isinstance(batch, dict) else batch
                semantic_ids = self.model.generate_semantic_ids(features.to(device))
                all_semantic_ids.append(semantic_ids.cpu())
    finally:
        # Do not leave the model in eval mode if the caller was training it.
        self.model.train(was_training)

    if not all_semantic_ids:
        # torch.cat raises on an empty list; return an empty ID tensor instead.
        return torch.empty(0, dtype=torch.long)
    return torch.cat(all_semantic_ids, dim=0)

TIGER 训练器

TIGERTrainer

专门用于训练 TIGER 模型的训练器。

class TIGERTrainer(BaseTrainer):
    """Trainer specialised for TIGER sequence models.

    Owns the sequence dataset and a collate function that pads variable-length
    ``input_ids`` / ``labels`` into batch tensors.
    """

    def __init__(
        self,
        model: Tiger,
        config: TIGERTrainingConfig,
        dataset: SequenceDataset,
        logger: Optional[Logger] = None
    ):
        super().__init__(model, config, logger)
        self.dataset = dataset
        # Built once and shared by every DataLoader this trainer creates.
        self.collate_fn = self._create_collate_fn()

    def _create_collate_fn(self) -> Callable:
        """Build the collate function used by the TIGER DataLoaders.

        The returned function takes a list of dicts holding ``'input_ids'``
        and ``'labels'`` tensors. It assumes 1-D sequences (TODO confirm) and
        returns a dict with these entries:

        * ``input_ids``: right-padded with 0.
        * ``labels``: right-padded with -100 (presumably the loss's
          ``ignore_index``; confirm).
        * ``attention_mask``: float tensor of shape ``(batch, max_len)``,
          1.0 on real positions and 0.0 on padding.
        """
        def collate_fn(batch):
            input_seqs = [item['input_ids'] for item in batch]
            label_seqs = [item['labels'] for item in batch]

            # Record true lengths before padding. The mask must come from them,
            # not from `input_ids != 0`, because 0 can be a genuine token id
            # and would otherwise be masked out as padding.
            lengths = torch.tensor([len(seq) for seq in input_seqs], dtype=torch.long)

            input_ids = pad_sequence(input_seqs, batch_first=True, padding_value=0)
            labels = pad_sequence(label_seqs, batch_first=True, padding_value=-100)

            positions = torch.arange(input_ids.size(1))
            attention_mask = (positions.unsqueeze(0) < lengths.unsqueeze(1)).float()

            return {
                'input_ids': input_ids,
                'labels': labels,
                'attention_mask': attention_mask
            }

        return collate_fn

create_dataloaders()

创建 TIGER 数据加载器。

def create_dataloaders(self) -> Tuple[DataLoader, DataLoader, DataLoader]:
    """
    Split ``self.dataset`` 80/10/10 and build a DataLoader for each split.

    The split is reproducible: ``random_split`` uses a generator seeded with
    ``config.random_seed``. Split sizes are truncated with ``int()``, so any
    rounding remainder goes to the test split. Every loader batches with
    ``self.collate_fn``; only the training loader shuffles.

    Returns:
        (train_loader, val_loader, test_loader)
    """
    n_samples = len(self.dataset)
    train_size = int(0.8 * n_samples)
    val_size = int(0.1 * n_samples)
    test_size = n_samples - train_size - val_size

    train_dataset, val_dataset, test_dataset = random_split(
        self.dataset,
        [train_size, val_size, test_size],
        generator=torch.Generator().manual_seed(self.config.random_seed)
    )

    # Pinned host memory only speeds up host->GPU copies. Without CUDA,
    # PyTorch warns and ignores it, so only request it when CUDA is present.
    pin_memory = torch.cuda.is_available()

    def make_loader(subset, shuffle):
        # All splits share batch size, worker count and collate function.
        return DataLoader(
            subset,
            batch_size=self.config.batch_size,
            shuffle=shuffle,
            num_workers=self.config.num_workers,
            collate_fn=self.collate_fn,
            pin_memory=pin_memory
        )

    return (
        make_loader(train_dataset, shuffle=True),
        make_loader(val_dataset, shuffle=False),
        make_loader(test_dataset, shuffle=False),
    )

evaluate_generation(test_dataloader, k_values)

评估生成质量，计算每个 K 值下的 Recall@K 与 NDCG@K（k_values 默认为 [5, 10, 20]）。

def evaluate_generation(
    self,
    test_dataloader: DataLoader,
    k_values: Optional[List[int]] = None
) -> Dict[str, float]:
    """
    Evaluate generated recommendations with Recall@K and NDCG@K.

    For every batch, the model generates sequences from ``input_ids`` using
    the generation settings in ``self.config`` (max length, temperature,
    top-k). Targets are the ``labels`` with the -100 padding removed. The
    model's previous train/eval mode is restored afterwards.

    Args:
        test_dataloader: loader yielding dicts with ``'input_ids'`` and
            ``'labels'`` tensors (as produced by the TIGER collate function)
        k_values: cut-offs to evaluate; defaults to ``[5, 10, 20]``. ``None``
            is used as the default instead of a list literal so that no
            mutable default object is shared between calls.

    Returns:
        Dict with ``'recall@{k}'`` and ``'ndcg@{k}'`` for every ``k``.
    """
    if k_values is None:
        k_values = [5, 10, 20]

    was_training = self.model.training
    self.model.eval()
    device = next(self.model.parameters()).device

    all_predictions = []
    all_targets = []

    try:
        with torch.no_grad():
            for batch in test_dataloader:
                input_ids = batch['input_ids'].to(device)
                # Labels are only used to build Python target lists, so they
                # do not need to visit the model's device.
                labels = batch['labels']

                generated = self.model.generate(
                    input_ids,
                    max_length=self.config.max_generation_length,
                    temperature=self.config.generation_temperature,
                    top_k=self.config.generation_top_k
                )

                all_predictions.extend(generated.cpu().tolist())
                # Strip the -100 padding to recover each variable-length target.
                all_targets.extend(
                    label_seq[label_seq != -100].cpu().tolist() for label_seq in labels
                )
    finally:
        # Do not leave the model in eval mode if the caller was training it.
        self.model.train(was_training)

    metrics = {}
    for k in k_values:
        metrics[f'recall@{k}'] = self._compute_recall_at_k(all_predictions, all_targets, k)
        metrics[f'ndcg@{k}'] = self._compute_ndcg_at_k(all_predictions, all_targets, k)

    return metrics

_compute_recall_at_k(predictions, targets, k)

计算 Recall@K。

def _compute_recall_at_k(
    self,
    predictions: List[List[int]],
    targets: List[List[int]],
    k: int
) -> float:
    """计算 Recall@K"""
    recall_scores = []

    for pred, target in zip(predictions, targets):
        if len(target) == 0:
            continue

        top_k_pred = set(pred[:k])
        target_set = set(target)

        recall = len(top_k_pred & target_set) / len(target_set)
        recall_scores.append(recall)

    return np.mean(recall_scores) if recall_scores else 0.0

_compute_ndcg_at_k(predictions, targets, k)

计算 NDCG@K。

def _compute_ndcg_at_k(
    self,
    predictions: List[List[int]],
    targets: List[List[int]],
    k: int
) -> float:
    """计算 NDCG@K"""
    ndcg_scores = []

    for pred, target in zip(predictions, targets):
        if len(target) == 0:
            continue

        # 计算 DCG
        dcg = 0
        for i, item in enumerate(pred[:k]):
            if item in target:
                dcg += 1 / np.log2(i + 2)

        # 计算 IDCG
        idcg = sum(1 / np.log2(i + 2) for i in range(min(len(target), k)))

        # 计算 NDCG
        ndcg = dcg / idcg if idcg > 0 else 0
        ndcg_scores.append(ndcg)

    return np.mean(ndcg_scores) if ndcg_scores else 0.0

训练配置

TrainingConfig

基础训练配置。

@dataclass
class TrainingConfig:
    """Base training configuration shared by all trainers.

    Every field has a default, so subclasses can add fields or override
    defaults (see ``RQVAETrainingConfig`` and ``TIGERTrainingConfig``).
    """

    # --- Basic settings ---
    max_epochs: int = 100  # upper bound on training epochs for the Lightning Trainer
    batch_size: int = 32  # batch size of every DataLoader the trainers build
    # NOTE(review): not read by the trainer code shown here; the training
    # scripts forward it to the model constructor instead.
    learning_rate: float = 1e-3

    # --- Hardware settings ---
    # Evaluated once, when the class is defined (module import): 1 if CUDA is
    # visible, else 0. setup_trainer derives the trainer's device setup from it.
    gpus: int = 1 if torch.cuda.is_available() else 0
    precision: int = 32  # forwarded to the Lightning Trainer's `precision`
    num_workers: int = 4  # DataLoader worker processes

    # --- Training strategy ---
    gradient_clip_val: float = 1.0
    accumulate_grad_batches: int = 1
    val_check_interval: float = 1.0

    # --- Checkpointing and logging ---
    checkpoint_dir: str = "checkpoints"  # ModelCheckpoint output directory
    log_dir: str = "logs"  # save_dir of the default TensorBoard logger
    experiment_name: str = "experiment"  # name of the default TensorBoard logger
    # Version of the default TensorBoard logger.
    # NOTE(review): None presumably lets the logger pick the next free version.
    version: Optional[str] = None
    save_top_k: int = 3  # number of best checkpoints (by val_loss) to keep

    # --- Early stopping ---
    early_stopping_patience: int = 10  # values <= 0 disable the EarlyStopping callback

    # --- Miscellaneous ---
    deterministic: bool = True  # forwarded to the Lightning Trainer's `deterministic`
    random_seed: int = 42  # seeds the generator driving the 80/10/10 random_split
    progress_bar: bool = True  # forwarded as the Trainer's `enable_progress_bar`
    run_test: bool = True  # RQVAETrainer.train_model runs the test loop when True

RQVAETrainingConfig

RQVAE 训练配置。

@dataclass
class RQVAETrainingConfig(TrainingConfig):
    """Training configuration for RQVAE models.

    Extends ``TrainingConfig`` with model-architecture, dataset and evaluation
    settings, and overrides some training defaults.
    """

    # --- Model parameters (passed to RqVae by the RQVAE training script) ---
    input_dim: int = 768
    hidden_dim: int = 512
    latent_dim: int = 256
    num_embeddings: int = 1024
    commitment_cost: float = 0.25

    # --- Training parameters (overrides of TrainingConfig defaults) ---
    learning_rate: float = 1e-3  # same value as the base default
    batch_size: int = 64
    max_epochs: int = 100  # same value as the base default

    # --- Dataset parameters ---
    # The training script records its CLI choices here, but builds the dataset
    # from the CLI arguments directly.
    dataset_name: str = "p5_amazon"
    dataset_split: str = "beauty"

    # --- Evaluation flags ---
    # NOTE(review): neither flag is read by the trainer code shown here;
    # confirm where they are consumed.
    eval_reconstruction: bool = True
    eval_quantization: bool = True

TIGERTrainingConfig

TIGER 训练配置。

@dataclass
class TIGERTrainingConfig(TrainingConfig):
    """Training configuration for TIGER sequence models.

    Extends ``TrainingConfig`` with model-architecture, generation, dataset
    and evaluation settings, and overrides some training defaults.
    """

    # --- Model parameters ---
    # NOTE(review): the TIGER training script passes only vocab_size,
    # embedding_dim, num_heads and num_layers to Tiger. attn_dim, dropout and
    # max_seq_length are not forwarded there; confirm where they are used.
    vocab_size: int = 1024
    embedding_dim: int = 512
    num_heads: int = 8
    num_layers: int = 6
    attn_dim: int = 2048
    dropout: float = 0.1
    max_seq_length: int = 100

    # --- Training parameters (overrides of TrainingConfig defaults) ---
    learning_rate: float = 1e-4
    batch_size: int = 16
    max_epochs: int = 50

    # --- Generation parameters (read by TIGERTrainer.evaluate_generation) ---
    max_generation_length: int = 50
    generation_temperature: float = 1.0
    generation_top_k: int = 50
    # NOTE(review): not forwarded to model.generate by evaluate_generation;
    # confirm whether it is used anywhere.
    generation_top_p: float = 0.9

    # --- Dataset parameters ---
    dataset_name: str = "p5_amazon"
    dataset_split: str = "beauty"
    pretrained_rqvae_path: str = "checkpoints/rqvae.ckpt"  # set from --rqvae_path by the training script

    # --- Evaluation parameters ---
    eval_generation: bool = True
    # default_factory gives each instance its own list; dataclasses reject a
    # plain list literal as a field default.
    eval_k_values: List[int] = field(default_factory=lambda: [5, 10, 20])

训练脚本

训练 RQVAE

#!/usr/bin/env python3
"""训练 RQVAE 模型的脚本"""

import argparse
from pathlib import Path

from genrec.models.rqvae import RqVae
from genrec.data.dataset_factory import DatasetFactory
from genrec.trainers import RQVAETrainer, RQVAETrainingConfig

def main():
    """Parse CLI arguments, build the item dataset, model and trainer, and train an RQVAE."""
    parser = argparse.ArgumentParser(description="Train RQVAE model")
    parser.add_argument("--dataset", default="p5_amazon", help="Dataset name")
    parser.add_argument("--split", default="beauty", help="Dataset split")
    parser.add_argument("--root", default="dataset", help="Dataset root directory")
    parser.add_argument("--batch_size", type=int, default=64, help="Batch size")
    parser.add_argument("--max_epochs", type=int, default=100, help="Max epochs")
    parser.add_argument("--learning_rate", type=float, default=1e-3, help="Learning rate")
    parser.add_argument("--checkpoint_dir", default="checkpoints", help="Checkpoint directory")

    args = parser.parse_args()

    # Build the training config from the CLI arguments.
    config = RQVAETrainingConfig(
        batch_size=args.batch_size,
        max_epochs=args.max_epochs,
        learning_rate=args.learning_rate,
        checkpoint_dir=args.checkpoint_dir,
        dataset_name=args.dataset,
        dataset_split=args.split
    )

    # Load all items; the trainer performs its own train/val/test split.
    dataset = DatasetFactory.create_item_dataset(
        args.dataset,
        root=args.root,
        split=args.split,
        train_test_split="all"
    )

    # Build the model; learning_rate is handed to the model itself.
    model = RqVae(
        input_dim=config.input_dim,
        hidden_dim=config.hidden_dim,
        latent_dim=config.latent_dim,
        num_embeddings=config.num_embeddings,
        commitment_cost=config.commitment_cost,
        learning_rate=config.learning_rate
    )

    trainer = RQVAETrainer(model, config, dataset)

    # train_model() returns the trained model, but checkpoints are written to
    # config.checkpoint_dir during training, so the return value is not needed.
    trainer.train_model()

    print(f"Training completed. Model saved to {config.checkpoint_dir}")


if __name__ == "__main__":
    main()

训练 TIGER

#!/usr/bin/env python3
"""训练 TIGER 模型的脚本"""

import argparse
from pathlib import Path

from genrec.models.tiger import Tiger
from genrec.data.dataset_factory import DatasetFactory
from genrec.trainers import TIGERTrainer, TIGERTrainingConfig

def main():
    """Parse CLI arguments, build the sequence dataset, model and trainer, and train TIGER."""
    parser = argparse.ArgumentParser(description="Train TIGER model")
    parser.add_argument("--dataset", default="p5_amazon", help="Dataset name")
    parser.add_argument("--split", default="beauty", help="Dataset split")
    parser.add_argument("--root", default="dataset", help="Dataset root directory")
    parser.add_argument("--rqvae_path", required=True, help="Pretrained RQVAE checkpoint path")
    parser.add_argument("--batch_size", type=int, default=16, help="Batch size")
    parser.add_argument("--max_epochs", type=int, default=50, help="Max epochs")
    parser.add_argument("--learning_rate", type=float, default=1e-4, help="Learning rate")
    parser.add_argument("--checkpoint_dir", default="checkpoints", help="Checkpoint directory")

    args = parser.parse_args()

    # Build the training config from the CLI arguments.
    config = TIGERTrainingConfig(
        batch_size=args.batch_size,
        max_epochs=args.max_epochs,
        learning_rate=args.learning_rate,
        checkpoint_dir=args.checkpoint_dir,
        dataset_name=args.dataset,
        dataset_split=args.split,
        pretrained_rqvae_path=args.rqvae_path
    )

    # Build the sequence dataset; the pretrained RQVAE checkpoint path is passed through.
    dataset = DatasetFactory.create_sequence_dataset(
        args.dataset,
        root=args.root,
        split=args.split,
        train_test_split="train",
        pretrained_rqvae_path=args.rqvae_path
    )

    # Build the model.
    # NOTE(review): config.attn_dim, config.dropout and config.max_seq_length
    # are not forwarded here; confirm whether Tiger accepts them.
    model = Tiger(
        vocab_size=config.vocab_size,
        embedding_dim=config.embedding_dim,
        num_heads=config.num_heads,
        num_layers=config.num_layers,
        learning_rate=config.learning_rate
    )

    trainer = TIGERTrainer(model, config, dataset)

    # The documented TIGERTrainer defines no train_model() (only RQVAETrainer
    # does), so calling it would raise AttributeError. Run the equivalent
    # steps explicitly: split, fit, and optionally test.
    train_loader, val_loader, test_loader = trainer.create_dataloaders()
    trainer.train(train_loader, val_loader)

    if config.run_test:
        test_results = trainer.test(test_loader)
        print(f"Test results: {test_results}")

    print(f"Training completed. Model saved to {config.checkpoint_dir}")


if __name__ == "__main__":
    main()

使用示例

基本训练

from genrec.trainers import RQVAETrainer, RQVAETrainingConfig
from genrec.models.rqvae import RqVae
from genrec.data.dataset_factory import DatasetFactory

# Create the config. The model below is built from this same config, so that
# hyper-parameters such as learning_rate actually reach the model (the model
# receives learning_rate in its constructor, as in the training script).
config = RQVAETrainingConfig(
    batch_size=64,
    max_epochs=100,
    learning_rate=1e-3
)

# Create the dataset
dataset = DatasetFactory.create_item_dataset(
    "p5_amazon",
    root="dataset/amazon",
    split="beauty"
)

# Create the model from the config (mirrors the RQVAE training script)
model = RqVae(
    input_dim=config.input_dim,
    hidden_dim=config.hidden_dim,
    latent_dim=config.latent_dim,
    num_embeddings=config.num_embeddings,
    commitment_cost=config.commitment_cost,
    learning_rate=config.learning_rate
)

# Create the trainer and train
trainer = RQVAETrainer(model, config, dataset)
trained_model = trainer.train_model()