Configuration Management API Reference

Detailed documentation of the configuration management classes used to manage data processing and model training parameters.

Base Configuration Classes

DatasetConfig

Base configuration class for datasets.

# Module-level imports used by the configuration classes and utilities in this reference
import json
import random
import re
from dataclasses import asdict, dataclass
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

import torch

@dataclass
class DatasetConfig:
    root_dir: str
    split: str = "default"
    force_reload: bool = False
    text_config: Optional[TextEncodingConfig] = None
    sequence_config: Optional[SequenceConfig] = None
    processing_config: Optional[DataProcessingConfig] = None

    def __post_init__(self):
        """Post-initialization: fill in default sub-configurations"""
        if self.text_config is None:
            self.text_config = TextEncodingConfig()
        if self.sequence_config is None:
            self.sequence_config = SequenceConfig()
        if self.processing_config is None:
            self.processing_config = DataProcessingConfig()

Parameters:
- root_dir: root directory of the dataset
- split: data split identifier
- force_reload: whether to force a reload
- text_config: text encoding configuration
- sequence_config: sequence processing configuration
- processing_config: data processing configuration
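
For illustration, a minimal sketch (assuming the import path used in the usage examples below) shows how __post_init__ fills in default sub-configurations:

from genrec.data.configs import DatasetConfig

# Only root_dir is required; unspecified sub-configurations start as None
config = DatasetConfig(root_dir="dataset/amazon")

# __post_init__ replaces the None values with default instances
assert config.text_config is not None
assert config.sequence_config is not None
assert config.processing_config is not None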

Text Encoding Configuration

TextEncodingConfig

Configuration for text encoding.

@dataclass
class TextEncodingConfig:
    encoder_model: str = "sentence-transformers/all-MiniLM-L6-v2"
    template: str = "Title: {title}; Brand: {brand}; Category: {category}; Price: {price}"
    batch_size: int = 32
    max_length: int = 512
    device: str = "cuda" if torch.cuda.is_available() else "cpu"
    cache_dir: str = "cache/text_embeddings"
    normalize_embeddings: bool = True

    def __post_init__(self):
        """验证配置参数"""
        if self.batch_size <= 0:
            raise ValueError("batch_size must be positive")
        if self.max_length <= 0:
            raise ValueError("max_length must be positive")

Parameters:
- encoder_model: name of the text encoder model
- template: text template format
- batch_size: batch size
- max_length: maximum text length
- device: compute device
- cache_dir: cache directory
- normalize_embeddings: whether to normalize embeddings

Methods:

get_cache_key(split, model_name)

Generate a cache key.

def get_cache_key(self, split: str, model_name: Optional[str] = None) -> str:
    """
    Generate a cache key.

    Args:
        split: data split
        model_name: model name; defaults to encoder_model when None

    Returns:
        Cache key string
    """
    if model_name is None:
        model_name = self.encoder_model
    return f"{model_name}_{split}_{hash(self.template)}"
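
A quick sketch of how the key is composed; note the trailing component comes from Python's built-in hash() of the template, so it is only stable within a process unless hash randomization is fixed (e.g. via PYTHONHASHSEED):

from genrec.data.configs import TextEncodingConfig

config = TextEncodingConfig()
print(config.get_cache_key("train"))
# e.g. sentence-transformers/all-MiniLM-L6-v2_train_<hash(template)>
print(config.get_cache_key("val", model_name="my-encoder"))  # hypothetical model name
# e.g. my-encoder_val_<hash(template)>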

format_text(item_data)

Format item text.

def format_text(self, item_data: Dict[str, Any]) -> str:
    """
    Format item text using the template.

    Args:
        item_data: item data dictionary

    Returns:
        Formatted text
    """
    try:
        return self.template.format(**item_data)
    except KeyError:
        # Handle missing fields
        available_fields = set(item_data.keys())
        template_fields = set(re.findall(r'\{(\w+)\}', self.template))
        missing_fields = template_fields - available_fields

        # Fill missing fields with a default value
        filled_data = item_data.copy()
        for field in missing_fields:
            filled_data[field] = "Unknown"

        return self.template.format(**filled_data)
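
Continuing with the TextEncodingConfig instance from the sketch above, an item record that lacks the price field is filled with "Unknown" instead of raising a KeyError (the item values are made up for illustration):

item = {"title": "Hydrating Face Cream", "brand": "Acme", "category": "Beauty"}
print(config.format_text(item))
# Title: Hydrating Face Cream; Brand: Acme; Category: Beauty; Price: Unknown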

Sequence Processing Configuration

SequenceConfig

Configuration for sequence processing.

@dataclass
class SequenceConfig:
    max_seq_length: int = 50
    min_seq_length: int = 3
    padding_token: int = 0
    truncate_strategy: str = "recent"  # "recent", "random", "oldest"
    sequence_stride: int = 1
    target_offset: int = 1
    include_timestamps: bool = False
    time_encoding_dim: int = 32

    def __post_init__(self):
        """验证配置参数"""
        if self.max_seq_length <= self.min_seq_length:
            raise ValueError("max_seq_length must be greater than min_seq_length")
        if self.truncate_strategy not in ["recent", "random", "oldest"]:
            raise ValueError("Invalid truncate_strategy")
        if self.target_offset <= 0:
            raise ValueError("target_offset must be positive")

Parameters:
- max_seq_length: maximum sequence length
- min_seq_length: minimum sequence length
- padding_token: padding token
- truncate_strategy: truncation strategy
- sequence_stride: sequence stride
- target_offset: target offset
- include_timestamps: whether to include timestamps
- time_encoding_dim: time encoding dimension

Methods:

truncate_sequence(sequence, strategy)

Truncate a sequence.

def truncate_sequence(
    self,
    sequence: List[Any],
    strategy: Optional[str] = None
) -> List[Any]:
    """
    Truncate a sequence according to the given strategy.

    Args:
        sequence: input sequence
        strategy: truncation strategy; if None, the strategy from the config is used

    Returns:
        Truncated sequence
    """
    if len(sequence) <= self.max_seq_length:
        return sequence

    strategy = strategy or self.truncate_strategy

    if strategy == "recent":
        return sequence[-self.max_seq_length:]
    elif strategy == "oldest":
        return sequence[:self.max_seq_length]
    elif strategy == "random":
        start_idx = random.randint(0, len(sequence) - self.max_seq_length)
        return sequence[start_idx:start_idx + self.max_seq_length]
    else:
        raise ValueError(f"Unknown truncate strategy: {strategy}")
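
For instance, with the default max_seq_length of 50, a 60-item history keeps the 50 most recent items under the default "recent" strategy (a minimal sketch):

from genrec.data.configs import SequenceConfig

seq_config = SequenceConfig()
history = list(range(60))

assert seq_config.truncate_sequence(history) == list(range(10, 60))        # "recent": keep the last 50
assert seq_config.truncate_sequence(history, "oldest") == list(range(50))  # "oldest": keep the first 50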

pad_sequence(sequence)

Pad a sequence.

def pad_sequence(self, sequence: List[Any]) -> List[Any]:
    """
    Pad a sequence to the maximum length.

    Args:
        sequence: input sequence

    Returns:
        Padded sequence
    """
    if len(sequence) >= self.max_seq_length:
        return sequence[:self.max_seq_length]

    pad_length = self.max_seq_length - len(sequence)
    return sequence + [self.padding_token] * pad_length
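
Padding appends padding_token (0 by default) on the right until max_seq_length is reached; continuing the sketch above:

padded = seq_config.pad_sequence([7, 8, 9])
assert len(padded) == 50
assert padded[:3] == [7, 8, 9] and padded[3:] == [0] * 47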

Data Processing Configuration

DataProcessingConfig

Configuration for data processing.

@dataclass
class DataProcessingConfig:
    min_user_interactions: int = 5
    min_item_interactions: int = 5
    remove_duplicates: bool = True
    normalize_ratings: bool = False
    rating_scale: Tuple[float, float] = (1.0, 5.0)
    train_ratio: float = 0.7
    val_ratio: float = 0.15
    test_ratio: float = 0.15
    random_seed: int = 42

    def __post_init__(self):
        """验证配置参数"""
        if abs(self.train_ratio + self.val_ratio + self.test_ratio - 1.0) > 1e-6:
            raise ValueError("train_ratio + val_ratio + test_ratio must equal 1.0")
        if any(ratio <= 0 for ratio in [self.train_ratio, self.val_ratio, self.test_ratio]):
            raise ValueError("All ratios must be positive")
        if self.min_user_interactions <= 0 or self.min_item_interactions <= 0:
            raise ValueError("Minimum interactions must be positive")

Parameters:
- min_user_interactions: minimum number of interactions per user
- min_item_interactions: minimum number of interactions per item
- remove_duplicates: whether to remove duplicate interactions
- normalize_ratings: whether to normalize ratings
- rating_scale: rating range
- train_ratio: training set ratio
- val_ratio: validation set ratio
- test_ratio: test set ratio
- random_seed: random seed

Methods:

get_split_indices(total_size)

Get data split indices.

def get_split_indices(self, total_size: int) -> Tuple[List[int], List[int], List[int]]:
    """
    Get data split indices according to the configured ratios.

    Args:
        total_size: total number of samples

    Returns:
        (train_indices, val_indices, test_indices): lists of split indices
    """
    indices = list(range(total_size))
    random.Random(self.random_seed).shuffle(indices)

    train_size = int(total_size * self.train_ratio)
    val_size = int(total_size * self.val_ratio)

    train_indices = indices[:train_size]
    val_indices = indices[train_size:train_size + val_size]
    test_indices = indices[train_size + val_size:]

    return train_indices, val_indices, test_indices
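
With the default ratios (0.7 / 0.15 / 0.15) and a fixed random_seed, the split is deterministic and the sizes follow the ratios; a minimal sketch:

from genrec.data.configs import DataProcessingConfig

proc_config = DataProcessingConfig()
train_idx, val_idx, test_idx = proc_config.get_split_indices(1000)
print(len(train_idx), len(val_idx), len(test_idx))  # 700 150 150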

normalize_rating(rating)

Normalize a rating.

def normalize_rating(self, rating: float) -> float:
    """
    Normalize a rating to the [0, 1] range.

    Args:
        rating: raw rating

    Returns:
        Normalized rating
    """
    if not self.normalize_ratings:
        return rating

    min_rating, max_rating = self.rating_scale
    return (rating - min_rating) / (max_rating - min_rating)
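
On the default (1.0, 5.0) scale a rating of 4.0 maps to 0.75; note that the method returns the rating unchanged unless normalize_ratings is enabled:

proc_config = DataProcessingConfig(normalize_ratings=True)
print(proc_config.normalize_rating(4.0))  # 0.75
print(proc_config.normalize_rating(1.0))  # 0.0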

Dataset-Specific Configurations

P5AmazonConfig

Dedicated configuration for the P5 Amazon dataset.

@dataclass
class P5AmazonConfig(DatasetConfig):
    category: str = "beauty"
    min_rating: float = 4.0
    include_price: bool = True
    include_brand: bool = True
    download_url: str = "https://datarepo.eng.ucsd.edu/mcauley_group/data/amazon_v2/"

    def __post_init__(self):
        super().__post_init__()

        # Set the dataset-specific text template
        if self.include_price and self.include_brand:
            template = "Title: {title}; Brand: {brand}; Category: {category}; Price: {price}"
        elif self.include_brand:
            template = "Title: {title}; Brand: {brand}; Category: {category}"
        else:
            template = "Title: {title}; Category: {category}"

        self.text_config.template = template

    def get_category_url(self) -> str:
        """获取特定类别的下载 URL"""
        return f"{self.download_url}{self.category}.json.gz"

Additional parameters:
- category: product category
- min_rating: minimum rating threshold
- include_price: whether to include price information
- include_brand: whether to include brand information
- download_url: base download URL
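
With the default category ("beauty"), get_category_url resolves to the per-category archive under the base download URL:

from genrec.data.configs import P5AmazonConfig

config = P5AmazonConfig(root_dir="dataset/amazon")
print(config.get_category_url())
# https://datarepo.eng.ucsd.edu/mcauley_group/data/amazon_v2/beauty.json.gz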

Configuration Validation and Utilities

validate_config(config)

Validate configuration integrity.

def validate_config(config: DatasetConfig) -> List[str]:
    """
    Validate that a configuration is well formed.

    Args:
        config: dataset configuration

    Returns:
        List of error messages; an empty list means the configuration is valid
    """
    errors = []

    # Check the root directory
    if not config.root_dir:
        errors.append("root_dir cannot be empty")

    # Check the text configuration
    if config.text_config:
        if not config.text_config.encoder_model:
            errors.append("encoder_model cannot be empty")
        if config.text_config.batch_size <= 0:
            errors.append("batch_size must be positive")

    # Check the sequence configuration
    if config.sequence_config:
        if config.sequence_config.max_seq_length <= config.sequence_config.min_seq_length:
            errors.append("max_seq_length must be greater than min_seq_length")

    # Check the processing configuration
    if config.processing_config:
        ratios_sum = (
            config.processing_config.train_ratio + 
            config.processing_config.val_ratio + 
            config.processing_config.test_ratio
        )
        if abs(ratios_sum - 1.0) > 1e-6:
            errors.append("Split ratios must sum to 1.0")

    return errors

load_config_from_file(config_path)

Load a configuration from a file.

def load_config_from_file(config_path: str) -> DatasetConfig:
    """
    Load a configuration from a YAML or JSON file.

    Args:
        config_path: path to the configuration file

    Returns:
        Dataset configuration object
    """
    config_path = Path(config_path)

    if config_path.suffix.lower() in ['.yaml', '.yml']:
        import yaml
        with open(config_path, 'r') as f:
            config_dict = yaml.safe_load(f)
    elif config_path.suffix.lower() == '.json':
        with open(config_path, 'r') as f:
            config_dict = json.load(f)
    else:
        raise ValueError(f"Unsupported config file format: {config_path.suffix}")

    # Create the matching config object based on the config type
    config_type = config_dict.pop('config_type', 'DatasetConfig')

    if config_type == 'P5AmazonConfig':
        return P5AmazonConfig(**config_dict)
    else:
        return DatasetConfig(**config_dict)
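
To illustrate the expected file layout, a flat YAML file such as the one written below round-trips through this loader; note that the parsed dictionary is passed directly to the dataclass constructor, so nested sub-configuration sections would remain plain dictionaries rather than being converted into config objects. The file name and field values here are only an example:

import yaml

config_dict = {
    "config_type": "P5AmazonConfig",
    "root_dir": "dataset/amazon",
    "split": "beauty",
    "category": "beauty",
    "min_rating": 4.0,
}
with open("dataset_config.yaml", "w") as f:
    yaml.dump(config_dict, f, default_flow_style=False)

loaded = load_config_from_file("dataset_config.yaml")
assert isinstance(loaded, P5AmazonConfig) and loaded.category == "beauty"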

save_config_to_file(config, config_path)

Save a configuration to a file.

def save_config_to_file(config: DatasetConfig, config_path: str) -> None:
    """
    Save a configuration to a YAML or JSON file.

    Args:
        config: dataset configuration object
        config_path: path to the configuration file
    """
    config_path = Path(config_path)
    config_dict = asdict(config)

    # Record the configuration type
    config_dict['config_type'] = config.__class__.__name__

    if config_path.suffix.lower() in ['.yaml', '.yml']:
        import yaml
        with open(config_path, 'w') as f:
            yaml.dump(config_dict, f, default_flow_style=False)
    elif config_path.suffix.lower() == '.json':
        with open(config_path, 'w') as f:
            json.dump(config_dict, f, indent=2)
    else:
        raise ValueError(f"Unsupported config file format: {config_path.suffix}")

Usage Examples

Basic Configuration Creation

from genrec.data.configs import (
    DatasetConfig, TextEncodingConfig, SequenceConfig, DataProcessingConfig
)

# Create a basic configuration
config = DatasetConfig(
    root_dir="dataset/amazon",
    split="beauty",
    text_config=TextEncodingConfig(
        encoder_model="sentence-transformers/all-MiniLM-L6-v2",
        batch_size=64
    ),
    sequence_config=SequenceConfig(
        max_seq_length=100,
        min_seq_length=5
    ),
    processing_config=DataProcessingConfig(
        min_user_interactions=10,
        train_ratio=0.8,
        val_ratio=0.1,
        test_ratio=0.1
    )
)

P5 Amazon Configuration

from genrec.data.configs import P5AmazonConfig

# Create a P5 Amazon configuration
config = P5AmazonConfig(
    root_dir="dataset/amazon",
    category="beauty",
    min_rating=4.0,
    include_price=True,
    include_brand=True
)

Configuration File Operations

# Save the configuration to a file
save_config_to_file(config, "config/dataset_config.yaml")

# Load the configuration from a file
loaded_config = load_config_from_file("config/dataset_config.yaml")

# Validate the configuration
errors = validate_config(loaded_config)
if errors:
    print("Configuration errors:")
    for error in errors:
        print(f"  - {error}")
else:
    print("Configuration is valid")