跳转至

模型序列化与检查点

Genesis提供了强大的模型序列化和检查点功能,用于保存和加载模型状态、优化器状态和训练进度。这对于长时间训练、模型部署和实验可重现性至关重要。

概述

Genesis中的序列化系统处理: - 模型状态字典(参数和缓冲区) - 优化器状态(动量、运行平均值等) - 训练元数据(epoch、损失、指标) - 原子写操作和安全备份

核心函数

save()

Python
import genesis

def save(state_dict, file_path):
    """
    将状态字典保存到文件,使用原子写操作。

    Args:
        state_dict (dict): 包含要保存状态的字典
        file_path (str): 保存文件的路径

    Features:
        - 原子写操作与备份
        - 成功时自动清理
        - 失败时回滚
        - 保存后内存清理
    """

load()

Python
def load(file_path):
    """
    从文件加载状态字典。

    Args:
        file_path (str): 保存文件的路径

    Returns:
        dict: 加载的状态字典

    Raises:
        FileNotFoundError: 如果文件不存在
        pickle.UnpicklingError: 如果文件损坏
    """

save_checkpoint()

Python
def save_checkpoint(model_state_dict, optimizer_state_dict, file_path):
    """
    保存模型和优化器检查点。

    Args:
        model_state_dict (dict): 模型状态字典
        optimizer_state_dict (dict): 优化器状态字典
        file_path (str): 保存检查点的路径

    创建包含以下内容的检查点:
        - model_state_dict: 模型参数和缓冲区
        - optimizer_state_dict: 优化器状态
    """

load_checkpoint()

Python
def load_checkpoint(file_path):
    """
    加载模型和优化器检查点。

    Args:
        file_path (str): 检查点文件路径

    Returns:
        tuple: (model_state_dict, optimizer_state_dict)

    Example:
        >>> model_state, optimizer_state = genesis.load_checkpoint('checkpoint.pth')
        >>> model.load_state_dict(model_state)
        >>> optimizer.load_state_dict(optimizer_state)
    """

基本用法

保存简单模型

Python
import genesis
import genesis.nn as nn

# 创建和训练模型
model = nn.Linear(784, 10)

# 保存模型状态
genesis.save(model.state_dict(), 'model.pth')

# 加载模型状态
state_dict = genesis.load('model.pth')
model.load_state_dict(state_dict)

训练检查点

Python
import genesis
import genesis.nn as nn
import genesis.optim as optim

# 设置模型和优化器
model = nn.Sequential(
    nn.Linear(784, 256),
    nn.ReLU(),
    nn.Linear(256, 10)
)

optimizer = optim.Adam(model.parameters(), lr=0.001)

# 带检查点的训练循环
for epoch in range(100):
    # 训练代码...
    train_loss = train_one_epoch(model, train_loader, optimizer)

    # 每10个epoch保存检查点
    if epoch % 10 == 0:
        genesis.save_checkpoint(
            model.state_dict(),
            optimizer.state_dict(), 
            f'checkpoint_epoch_{epoch}.pth'
        )
        print(f"检查点已保存在epoch {epoch}")

# 加载检查点恢复训练
model_state, optimizer_state = genesis.load_checkpoint('checkpoint_epoch_90.pth')
model.load_state_dict(model_state)
optimizer.load_state_dict(optimizer_state)

高级检查点

完整训练状态

Python
import genesis

def save_training_checkpoint(model, optimizer, scheduler, epoch, loss, metrics, file_path):
    """保存完整训练状态。"""
    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict() if scheduler else None,
        'loss': loss,
        'metrics': metrics,
        'model_config': {
            'input_size': model.input_size,
            'hidden_size': model.hidden_size,
            'num_classes': model.num_classes
        }
    }
    genesis.save(checkpoint, file_path)

def load_training_checkpoint(file_path):
    """加载完整训练状态。"""
    return genesis.load(file_path)

# 用法
save_training_checkpoint(
    model, optimizer, scheduler, 
    epoch=50, loss=0.234, 
    metrics={'accuracy': 0.94, 'f1': 0.91},
    file_path='complete_checkpoint.pth'
)

# 恢复训练
checkpoint = load_training_checkpoint('complete_checkpoint.pth')
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
if checkpoint['scheduler_state_dict']:
    scheduler.load_state_dict(checkpoint['scheduler_state_dict'])

start_epoch = checkpoint['epoch'] + 1
print(f"从epoch {start_epoch}恢复训练,损失: {checkpoint['loss']}")

最佳模型跟踪

Python
import genesis

class ModelCheckpointer:
    def __init__(self, save_dir='checkpoints'):
        self.save_dir = save_dir
        self.best_loss = float('inf')
        self.best_accuracy = 0.0

    def save_checkpoint(self, model, optimizer, epoch, loss, accuracy, is_best=False):
        """保存检查点并跟踪最佳模型。"""
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': loss,
            'accuracy': accuracy
        }

        # 保存常规检查点
        checkpoint_path = f'{self.save_dir}/checkpoint_epoch_{epoch}.pth'
        genesis.save(checkpoint, checkpoint_path)

        # 基于损失保存最佳模型
        if loss < self.best_loss:
            self.best_loss = loss
            best_loss_path = f'{self.save_dir}/best_loss_model.pth'
            genesis.save(checkpoint, best_loss_path)
            print(f"新的最佳损失: {loss:.4f}")

        # 基于准确率保存最佳模型
        if accuracy > self.best_accuracy:
            self.best_accuracy = accuracy
            best_acc_path = f'{self.save_dir}/best_accuracy_model.pth'
            genesis.save(checkpoint, best_acc_path)
            print(f"新的最佳准确率: {accuracy:.4f}")

    def load_best_model(self, model, metric='loss'):
        """加载最佳模型。"""
        if metric == 'loss':
            path = f'{self.save_dir}/best_loss_model.pth'
        elif metric == 'accuracy':
            path = f'{self.save_dir}/best_accuracy_model.pth'
        else:
            raise ValueError("metric必须是'loss'或'accuracy'")

        checkpoint = genesis.load(path)
        model.load_state_dict(checkpoint['model_state_dict'])
        return checkpoint

# 用法
checkpointer = ModelCheckpointer()

for epoch in range(100):
    train_loss = train_one_epoch(model, train_loader, optimizer)
    val_loss, val_accuracy = validate(model, val_loader)

    checkpointer.save_checkpoint(
        model, optimizer, epoch, val_loss, val_accuracy
    )

模型部署

推理模型保存

Python
import genesis

def save_for_inference(model, file_path, model_config=None):
    """保存优化的推理模型。"""
    model.eval()  # 设置为评估模式

    inference_state = {
        'model_state_dict': model.state_dict(),
        'model_config': model_config,
        'genesis_version': genesis.__version__,
        'inference_only': True
    }

    genesis.save(inference_state, file_path)

def load_for_inference(file_path, model_class):
    """加载推理模型。"""
    checkpoint = genesis.load(file_path)

    # 创建模型实例
    if 'model_config' in checkpoint and checkpoint['model_config']:
        model = model_class(**checkpoint['model_config'])
    else:
        model = model_class()

    # 加载状态
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()

    return model

# 保存训练好的模型用于部署
model_config = {
    'input_size': 784,
    'hidden_size': 256, 
    'num_classes': 10
}

save_for_inference(model, 'deployed_model.pth', model_config)

# 在生产环境中加载
deployed_model = load_for_inference('deployed_model.pth', MyModelClass)

模型版本管理

Python
import genesis
import time
from datetime import datetime

class VersionedCheckpoint:
    def __init__(self, base_path='models'):
        self.base_path = base_path

    def save_version(self, model, optimizer, epoch, metrics, version_name=None):
        """保存带有版本信息的模型。"""
        if version_name is None:
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            version_name = f"v_{timestamp}"

        checkpoint = {
            'version': version_name,
            'timestamp': time.time(),
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'metrics': metrics,
            'genesis_version': genesis.__version__
        }

        file_path = f'{self.base_path}/{version_name}.pth'
        genesis.save(checkpoint, file_path)

        # 更新最新版本链接
        latest_path = f'{self.base_path}/latest.pth'
        genesis.save(checkpoint, latest_path)

        return version_name

    def load_version(self, version_name='latest'):
        """加载特定版本的模型。"""
        file_path = f'{self.base_path}/{version_name}.pth'
        return genesis.load(file_path)

    def list_versions(self):
        """列出可用的模型版本。"""
        import os
        versions = []
        for file in os.listdir(self.base_path):
            if file.endswith('.pth') and file != 'latest.pth':
                versions.append(file[:-4])  # 移除.pth扩展名
        return sorted(versions)

# 用法
versioner = VersionedCheckpoint()

# 保存新版本
version = versioner.save_version(
    model, optimizer, epoch=100, 
    metrics={'accuracy': 0.95, 'loss': 0.15},
    version_name='model_v1.2'
)

# 加载特定版本
checkpoint = versioner.load_version('model_v1.2')

# 加载最新版本
latest = versioner.load_version('latest')

错误处理和安全性

稳健的检查点加载

Python
import genesis
import os

def safe_load_checkpoint(file_path, model, optimizer=None):
    """安全加载检查点,带错误处理。"""
    try:
        if not os.path.exists(file_path):
            print(f"警告: 检查点 {file_path} 未找到")
            return False

        checkpoint = genesis.load(file_path)

        # 验证检查点结构
        required_keys = ['model_state_dict']
        for key in required_keys:
            if key not in checkpoint:
                print(f"错误: 检查点中缺少键 '{key}'")
                return False

        # 加载模型状态
        try:
            model.load_state_dict(checkpoint['model_state_dict'])
            print("模型状态加载成功")
        except Exception as e:
            print(f"加载模型状态时出错: {e}")
            return False

        # 如果提供了优化器,则加载优化器状态
        if optimizer and 'optimizer_state_dict' in checkpoint:
            try:
                optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
                print("优化器状态加载成功")
            except Exception as e:
                print(f"警告: 无法加载优化器状态: {e}")

        # 返回附加信息
        epoch = checkpoint.get('epoch', 0)
        loss = checkpoint.get('loss', 'unknown')
        print(f"从epoch {epoch}加载检查点,损失: {loss}")

        return True

    except Exception as e:
        print(f"加载检查点时出错: {e}")
        return False

# 用法
success = safe_load_checkpoint('checkpoint.pth', model, optimizer)
if success:
    print("检查点加载成功")
else:
    print("加载检查点失败,从头开始")

检查点验证

Python
import genesis

def validate_checkpoint(file_path):
    """验证检查点文件完整性。"""
    try:
        checkpoint = genesis.load(file_path)

        # 基本结构验证
        if not isinstance(checkpoint, dict):
            return False, "检查点不是字典类型"

        if 'model_state_dict' not in checkpoint:
            return False, "缺少model_state_dict"

        # 检查模型状态结构
        model_state = checkpoint['model_state_dict']
        if not isinstance(model_state, dict):
            return False, "model_state_dict不是字典类型"

        # 检查空状态
        if len(model_state) == 0:
            return False, "model_state_dict为空"

        # 验证张量形状(基本检查)
        for key, tensor in model_state.items():
            if not hasattr(tensor, 'shape'):
                return False, f"键 '{key}' 的张量无效"

        return True, "检查点有效"

    except Exception as e:
        return False, f"验证检查点时出错: {e}"

# 用法
is_valid, message = validate_checkpoint('checkpoint.pth')
print(f"检查点验证: {message}")

最佳实践

1. 检查点策略

  • 定期保存检查点(每N个epoch)
  • 保留多个最近的检查点
  • 单独保存最佳模型
  • 包含训练元数据

2. 文件组织

Python
# 推荐的目录结构
checkpoints/
├── latest.pth                    # 最新检查点
├── best_model.pth               # 最佳性能模型
├── epoch_000010.pth            # 常规检查点
├── epoch_000020.pth
└── deployed/
    └── production_model.pth     # 生产就绪模型

3. 内存管理

Python
import genesis
import gc

def efficient_checkpoint_save(model, optimizer, file_path):
    """带内存优化的检查点保存。"""
    # 创建检查点字典
    checkpoint = {
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict()
    }

    # 保存检查点
    genesis.save(checkpoint, file_path)

    # 从内存中清除检查点字典
    del checkpoint
    gc.collect()

    print(f"检查点已保存到 {file_path}")

4. 跨设备兼容性

Python
def save_device_agnostic(model, file_path):
    """保存可在任何设备上加载的模型。"""
    # 保存前移动到CPU
    model.cpu()
    genesis.save(model.state_dict(), file_path)

def load_to_device(file_path, model, device):
    """将检查点加载到指定设备。"""
    # 加载检查点
    state_dict = genesis.load(file_path)

    # 加载到模型
    model.load_state_dict(state_dict)

    # 移动到目标设备
    model.to(device)

迁移和兼容性

从PyTorch迁移

Python
import genesis
import torch

def convert_pytorch_checkpoint(pytorch_file, genesis_file):
    """将PyTorch检查点转换为Genesis格式。"""
    # 加载PyTorch检查点
    torch_checkpoint = torch.load(pytorch_file, map_location='cpu')

    # 转换为Genesis格式(如需要)
    genesis_checkpoint = {
        'model_state_dict': torch_checkpoint['model_state_dict'],
        'optimizer_state_dict': torch_checkpoint.get('optimizer_state_dict', {}),
        'epoch': torch_checkpoint.get('epoch', 0),
        'converted_from_pytorch': True
    }

    # 以Genesis格式保存
    genesis.save(genesis_checkpoint, genesis_file)
    print(f"已转换 {pytorch_file} -> {genesis_file}")

Genesis序列化系统为生产级深度学习工作流程提供了强大、高效和安全的模型检查点功能。