Configuration Management API Reference¶
Detailed documentation for configuration management classes used to manage data processing and model training parameters.
Base Configuration Classes¶
DatasetConfig¶
Base dataset configuration class.
@dataclass
class DatasetConfig:
root_dir: str
split: str = "default"
force_reload: bool = False
text_config: Optional[TextEncodingConfig] = None
sequence_config: Optional[SequenceConfig] = None
processing_config: Optional[DataProcessingConfig] = None
def __post_init__(self):
"""Post-initialization processing"""
if self.text_config is None:
self.text_config = TextEncodingConfig()
if self.sequence_config is None:
self.sequence_config = SequenceConfig()
if self.processing_config is None:
self.processing_config = DataProcessingConfig()
Parameters:
- root_dir: Dataset root directory
- split: Data split identifier
- force_reload: Whether to force reload
- text_config: Text encoding configuration
- sequence_config: Sequence processing configuration
- processing_config: Data processing configuration
Text Encoding Configuration¶
TextEncodingConfig¶
Text encoding related configuration.
@dataclass
class TextEncodingConfig:
    """Configuration for encoding item text into embeddings."""

    encoder_model: str = "sentence-transformers/all-MiniLM-L6-v2"
    template: str = "Title: {title}; Brand: {brand}; Category: {category}; Price: {price}"
    batch_size: int = 32
    max_length: int = 512
    # NOTE: the default device is chosen once, when the class body is
    # executed at import time.
    device: str = "cuda" if torch.cuda.is_available() else "cpu"
    cache_dir: str = "cache/text_embeddings"
    normalize_embeddings: bool = True

    def __post_init__(self):
        """Reject non-positive ``batch_size`` / ``max_length`` values."""
        for name, value in (("batch_size", self.batch_size),
                            ("max_length", self.max_length)):
            if value <= 0:
                raise ValueError(f"{name} must be positive")
Parameters:
- encoder_model: Text encoder model name
- template: Text template format
- batch_size: Batch processing size
- max_length: Maximum text length
- device: Computing device
- cache_dir: Cache directory
- normalize_embeddings: Whether to normalize embeddings
Methods:
get_cache_key(split, model_name)¶
Generate a cache key identifying a (model, split, template) combination.
def get_cache_key(self, split: str, model_name: str = None) -> str:
    """Build a deterministic cache key for a (model, split, template) combo.

    Uses a SHA-1 digest of the template instead of the builtin ``hash()``:
    string hashes are salted per process (PYTHONHASHSEED), so ``hash()``
    produced a different key on every run and silently defeated the cache.
    Existing caches keyed with the old scheme will be regenerated once.

    Args:
        split: Data split identifier.
        model_name: Model name; defaults to ``self.encoder_model``.

    Returns:
        Cache key string of the form ``"{model}_{split}_{digest}"``.
    """
    import hashlib  # local import keeps the documented snippet self-contained
    if model_name is None:
        model_name = self.encoder_model
    template_digest = hashlib.sha1(self.template.encode("utf-8")).hexdigest()[:16]
    return f"{model_name}_{split}_{template_digest}"
format_text(item_data)¶
Format item text.
def format_text(self, item_data: Dict[str, Any]) -> str:
    """Render *item_data* through the configured template.

    Template fields that are absent from *item_data* are substituted
    with the literal string "Unknown".

    Args:
        item_data: Item data dictionary.

    Returns:
        Formatted text.
    """
    try:
        return self.template.format(**item_data)
    except KeyError:
        pass
    # Retry with every template placeholder defaulted to "Unknown";
    # real values from item_data take precedence in the merge.
    placeholders = re.findall(r'\{(\w+)\}', self.template)
    defaults = {name: "Unknown" for name in placeholders}
    return self.template.format(**{**defaults, **item_data})
Sequence Processing Configuration¶
SequenceConfig¶
Sequence processing related configuration.
@dataclass
class SequenceConfig:
    """Interaction-sequence processing configuration."""

    max_seq_length: int = 50
    min_seq_length: int = 3
    padding_token: int = 0
    truncate_strategy: str = "recent"  # one of: "recent", "random", "oldest"
    sequence_stride: int = 1
    target_offset: int = 1
    include_timestamps: bool = False
    time_encoding_dim: int = 32

    def __post_init__(self):
        """Validate length bounds, truncation strategy and target offset."""
        checks = (
            (self.max_seq_length <= self.min_seq_length,
             "max_seq_length must be greater than min_seq_length"),
            (self.truncate_strategy not in ("recent", "random", "oldest"),
             "Invalid truncate_strategy"),
            (self.target_offset <= 0,
             "target_offset must be positive"),
        )
        for failed, message in checks:
            if failed:
                raise ValueError(message)
Parameters:
- max_seq_length: Maximum sequence length
- min_seq_length: Minimum sequence length
- padding_token: Padding token
- truncate_strategy: Truncation strategy
- sequence_stride: Sequence stride
- target_offset: Target offset
- include_timestamps: Whether to include timestamps
- time_encoding_dim: Time encoding dimension
Methods:
truncate_sequence(sequence, strategy)¶
Truncate a sequence that exceeds the configured maximum length.
def truncate_sequence(
    self,
    sequence: List[Any],
    strategy: str = None
) -> List[Any]:
    """Shorten *sequence* to at most ``max_seq_length`` items.

    Args:
        sequence: Input sequence.
        strategy: Truncation strategy; falls back to the configured
            ``truncate_strategy`` when not given.

    Returns:
        Truncated sequence (the original object if already short enough).

    Raises:
        ValueError: If the strategy is not recognized.
    """
    window = self.max_seq_length
    if len(sequence) <= window:
        return sequence
    chosen = strategy or self.truncate_strategy
    if chosen == "recent":
        # keep the newest items
        return sequence[len(sequence) - window:]
    if chosen == "oldest":
        # keep the earliest items
        return sequence[:window]
    if chosen == "random":
        # keep a random contiguous window
        offset = random.randint(0, len(sequence) - window)
        return sequence[offset:offset + window]
    raise ValueError(f"Unknown truncate strategy: {chosen}")
pad_sequence(sequence)¶
Pad a sequence up to the configured maximum length.
def pad_sequence(self, sequence: List[Any]) -> List[Any]:
    """Right-pad (or clip) *sequence* to exactly ``max_seq_length`` items.

    Args:
        sequence: Input sequence.

    Returns:
        A sequence of length ``max_seq_length``: clipped when too long,
        padded on the right with ``padding_token`` when too short.
    """
    limit = self.max_seq_length
    if len(sequence) >= limit:
        return sequence[:limit]
    padding = [self.padding_token] * (limit - len(sequence))
    return sequence + padding
Data Processing Configuration¶
DataProcessingConfig¶
Data processing related configuration.
@dataclass
class DataProcessingConfig:
    """Data filtering, rating and train/val/test split configuration."""

    min_user_interactions: int = 5
    min_item_interactions: int = 5
    remove_duplicates: bool = True
    normalize_ratings: bool = False
    rating_scale: Tuple[float, float] = (1.0, 5.0)
    train_ratio: float = 0.7
    val_ratio: float = 0.15
    test_ratio: float = 0.15
    random_seed: int = 42

    def __post_init__(self):
        """Validate split ratios and interaction thresholds."""
        ratios = (self.train_ratio, self.val_ratio, self.test_ratio)
        # 1e-6 tolerance absorbs float rounding in the ratio sum
        if abs(sum(ratios) - 1.0) > 1e-6:
            raise ValueError("train_ratio + val_ratio + test_ratio must equal 1.0")
        if min(ratios) <= 0:
            raise ValueError("All ratios must be positive")
        if min(self.min_user_interactions, self.min_item_interactions) <= 0:
            raise ValueError("Minimum interactions must be positive")
Parameters:
- min_user_interactions: Minimum user interactions
- min_item_interactions: Minimum item interactions
- remove_duplicates: Whether to remove duplicate interactions
- normalize_ratings: Whether to normalize ratings
- rating_scale: Rating range
- train_ratio: Training set ratio
- val_ratio: Validation set ratio
- test_ratio: Test set ratio
- random_seed: Random seed
Methods:
get_split_indices(total_size)¶
Get data split indices.
def get_split_indices(self, total_size: int) -> Tuple[List[int], List[int], List[int]]:
    """Split ``range(total_size)`` into train/val/test index lists.

    The shuffle uses a private ``random.Random(self.random_seed)``
    instance, so the split is reproducible and does not disturb the
    global random state.

    Args:
        total_size: Total data size.

    Returns:
        ``(train_indices, val_indices, test_indices)`` index lists.
    """
    shuffled = list(range(total_size))
    random.Random(self.random_seed).shuffle(shuffled)
    n_train = int(total_size * self.train_ratio)
    n_val = int(total_size * self.val_ratio)
    return (
        shuffled[:n_train],
        shuffled[n_train:n_train + n_val],
        shuffled[n_train + n_val:],
    )
normalize_rating(rating)¶
Normalize rating.
def normalize_rating(self, rating: float) -> float:
    """Map *rating* onto [0, 1] using the configured ``rating_scale``.

    Args:
        rating: Original rating.

    Returns:
        The normalized rating, or *rating* unchanged when
        ``normalize_ratings`` is disabled.
    """
    if not self.normalize_ratings:
        return rating
    low, high = self.rating_scale
    return (rating - low) / (high - low)
Specific Dataset Configurations¶
P5AmazonConfig¶
P5 Amazon dataset specific configuration.
@dataclass
class P5AmazonConfig(DatasetConfig):
    """P5 Amazon dataset specific configuration."""

    category: str = "beauty"          # product category
    min_rating: float = 4.0           # minimum rating threshold
    include_price: bool = True        # include price in the text template
    include_brand: bool = True        # include brand in the text template
    download_url: str = "https://datarepo.eng.ucsd.edu/mcauley_group/data/amazon_v2/"

    def __post_init__(self):
        """Derive the text template from the include_* flags.

        The template is assembled field by field, which fixes the
        previous branch logic that dropped the price field whenever
        ``include_price=True`` was combined with ``include_brand=False``.
        All previously handled flag combinations produce identical
        templates.
        """
        super().__post_init__()
        parts = ["Title: {title}"]
        if self.include_brand:
            parts.append("Brand: {brand}")
        parts.append("Category: {category}")
        if self.include_price:
            parts.append("Price: {price}")
        # NOTE(review): this overwrites the template of any user-supplied
        # text_config — confirm that clobbering is intended.
        self.text_config.template = "; ".join(parts)

    def get_category_url(self) -> str:
        """Return the download URL for this category's review file."""
        return f"{self.download_url}{self.category}.json.gz"
Additional Parameters:
- category: Product category
- min_rating: Minimum rating threshold
- include_price: Whether to include price information
- include_brand: Whether to include brand information
- download_url: Download base URL
Configuration Validation and Tools¶
validate_config(config)¶
Validate configuration integrity.
def validate_config(config: DatasetConfig) -> List[str]:
    """Collect validation problems for *config*.

    Args:
        config: Dataset configuration.

    Returns:
        List of error messages; an empty list means the configuration
        is valid.
    """
    errors: List[str] = []

    if not config.root_dir:
        errors.append("root_dir cannot be empty")

    text = config.text_config
    if text:
        if not text.encoder_model:
            errors.append("encoder_model cannot be empty")
        if text.batch_size <= 0:
            errors.append("batch_size must be positive")

    seq = config.sequence_config
    if seq and seq.max_seq_length <= seq.min_seq_length:
        errors.append("max_seq_length must be greater than min_seq_length")

    proc = config.processing_config
    if proc:
        total = proc.train_ratio + proc.val_ratio + proc.test_ratio
        if abs(total - 1.0) > 1e-6:
            errors.append("Split ratios must sum to 1.0")

    return errors
load_config_from_file(config_path)¶
Load configuration from file.
def load_config_from_file(config_path: str) -> DatasetConfig:
    """Load a dataset configuration from a YAML or JSON file.

    The concrete class is selected by the ``config_type`` entry stored
    in the file (written by ``save_config_to_file``); it defaults to
    ``DatasetConfig`` when absent.

    Args:
        config_path: Configuration file path (``.yaml``/``.yml``/``.json``).

    Returns:
        Dataset configuration object.

    Raises:
        ValueError: If the file extension is not a supported format.
    """
    config_path = Path(config_path)
    suffix = config_path.suffix.lower()
    if suffix in ('.yaml', '.yml'):
        import yaml  # optional dependency, imported lazily
        # explicit utf-8: the platform-default encoding is not portable
        with open(config_path, 'r', encoding='utf-8') as f:
            config_dict = yaml.safe_load(f)
    elif suffix == '.json':
        with open(config_path, 'r', encoding='utf-8') as f:
            config_dict = json.load(f)
    else:
        raise ValueError(f"Unsupported config file format: {config_path.suffix}")
    # Pop the type tag so it is not passed to the dataclass constructor.
    config_type = config_dict.pop('config_type', 'DatasetConfig')
    if config_type == 'P5AmazonConfig':
        return P5AmazonConfig(**config_dict)
    return DatasetConfig(**config_dict)
save_config_to_file(config, config_path)¶
Save configuration to file.
def save_config_to_file(config: DatasetConfig, config_path: str) -> None:
"""
Save configuration to YAML or JSON file
Args:
config: Dataset configuration object
config_path: Configuration file path
"""
config_path = Path(config_path)
config_dict = asdict(config)
# Add configuration type information
config_dict['config_type'] = config.__class__.__name__
if config_path.suffix.lower() in ['.yaml', '.yml']:
import yaml
with open(config_path, 'w') as f:
yaml.dump(config_dict, f, default_flow_style=False)
elif config_path.suffix.lower() == '.json':
with open(config_path, 'w') as f:
json.dump(config_dict, f, indent=2)
else:
raise ValueError(f"Unsupported config file format: {config_path.suffix}")
Usage Examples¶
Basic Configuration Creation¶
from genrec.data.configs import (
DatasetConfig, TextEncodingConfig, SequenceConfig, DataProcessingConfig
)
# Create a basic configuration, overriding selected defaults in each
# sub-configuration; any sub-config omitted here is filled with defaults
# by DatasetConfig.__post_init__.
config = DatasetConfig(
    root_dir="dataset/amazon",
    split="beauty",
    text_config=TextEncodingConfig(
        encoder_model="sentence-transformers/all-MiniLM-L6-v2",
        batch_size=64
    ),
    sequence_config=SequenceConfig(
        max_seq_length=100,
        min_seq_length=5
    ),
    processing_config=DataProcessingConfig(
        min_user_interactions=10,
        train_ratio=0.8,
        val_ratio=0.1,
        test_ratio=0.1
    )
)
P5 Amazon Configuration¶
from genrec.data.configs import P5AmazonConfig
# Create a P5 Amazon configuration; the text template is derived from
# the include_price / include_brand flags in __post_init__.
config = P5AmazonConfig(
    root_dir="dataset/amazon",
    category="beauty",
    min_rating=4.0,
    include_price=True,
    include_brand=True
)
Configuration File Operations¶
# Save configuration to file (format chosen by the file extension)
save_config_to_file(config, "config/dataset_config.yaml")

# Load configuration from file
loaded_config = load_config_from_file("config/dataset_config.yaml")

# Validate the loaded configuration and report any problems
errors = validate_config(loaded_config)
if errors:
    print("Configuration errors:")
    for error in errors:
        print(f" - {error}")
else:
    print("Configuration is valid")