# Dataset API Reference

Detailed API documentation for the `genrec` dataset module.

## Base Dataset Classes

### BaseRecommenderDataset

Abstract base class for all recommender system datasets.
```python
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Dict

import pandas as pd


class BaseRecommenderDataset(ABC):
    """Base class for recommender system datasets."""

    def __init__(self, config: DatasetConfig):
        self.config = config
        self.root_dir = Path(config.root_dir)
        self._items_df = None
        self._interactions_df = None

    @abstractmethod
    def download(self) -> None:
        """Download the dataset."""

    @abstractmethod
    def load_raw_data(self) -> Dict[str, pd.DataFrame]:
        """Load raw data."""

    @abstractmethod
    def preprocess_data(self, raw_data: Dict[str, pd.DataFrame]) -> Dict[str, pd.DataFrame]:
        """Preprocess raw data."""
```
**Main Methods:**

#### load_dataset()

Load the dataset, preparing items and interactions for access.

```python
def load_dataset(self, force_reload: bool = False) -> None:
    """
    Load the dataset.

    Args:
        force_reload: If True, reload even if cached data exists.
    """
```
#### get_items()

Return the item data.

#### get_interactions()

Return the user-item interaction data.

```python
def get_interactions(self) -> pd.DataFrame:
    """
    Get user-item interaction data.

    Returns:
        Interactions DataFrame.
    """
```
## Item Dataset Class

### ItemDataset

Dataset class for item encoding and feature learning.
```python
from torch.utils.data import Dataset


class ItemDataset(Dataset):
    """Item dataset class."""

    def __init__(
        self,
        base_dataset: BaseRecommenderDataset,
        split: str = "all",
        return_text: bool = False,
    ):
        self.base_dataset = base_dataset
        self.split = split
        self.return_text = return_text
```

**Parameters:**

- `base_dataset`: Base dataset instance
- `split`: Data split (`"train"`, `"val"`, `"test"`, or `"all"`)
- `return_text`: Whether to return text features
**Methods:**

#### `__getitem__(idx)`

Get the data item at the specified index.

```python
def __getitem__(self, idx: int) -> Dict[str, Any]:
    """
    Get the data item at the specified index.

    Args:
        idx: Data index.

    Returns:
        Dictionary containing item information.
    """
```
## Sequence Dataset Class

### SequenceDataset

Dataset class for sequence generation training.
```python
from typing import Optional

import torch.nn as nn
from torch.utils.data import Dataset


class SequenceDataset(Dataset):
    """Sequence dataset class."""

    def __init__(
        self,
        base_dataset: BaseRecommenderDataset,
        split: str = "train",
        semantic_encoder: Optional[nn.Module] = None,
    ):
        self.base_dataset = base_dataset
        self.split = split
        self.semantic_encoder = semantic_encoder
```

**Parameters:**

- `base_dataset`: Base dataset instance
- `split`: Data split
- `semantic_encoder`: Semantic encoder (e.g., an RQ-VAE)
**Methods:**

#### create_sequences()

Create per-user interaction sequences.

```python
def create_sequences(self) -> List[Dict[str, Any]]:
    """
    Create user interaction sequences.

    Returns:
        List of sequences.
    """
```
#### encode_sequence()

Encode an item ID sequence into its semantic representation.

```python
def encode_sequence(self, item_ids: List[int]) -> torch.Tensor:
    """
    Encode an item ID sequence into a semantic representation.

    Args:
        item_ids: List of item IDs.

    Returns:
        Encoded sequence tensor.
    """
```
## Concrete Dataset Implementations

### P5AmazonDataset

P5 Amazon dataset implementation.

```python
import gin


@gin.configurable
class P5AmazonDataset(BaseRecommenderDataset):
    """P5 Amazon dataset."""

    def __init__(self, config: P5AmazonConfig):
        super().__init__(config)
        self.category = config.category
        self.min_rating = config.min_rating
```

**Key Features:**

- Supports multiple product categories
- Automatic download and preprocessing
- Text feature extraction
- Rating filtering (see the sketch below)
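Rating filtering plausibly drops interactions below `min_rating` during preprocessing. This sketch is an assumption; both the helper name and the `rating` column name are hypothetical:

```python
import pandas as pd


# Hypothetical helper on P5AmazonDataset, applied during preprocessing.
def _filter_by_rating(self, interactions: pd.DataFrame) -> pd.DataFrame:
    return interactions[interactions["rating"] >= self.min_rating]
```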
### P5AmazonItemDataset

P5 Amazon item dataset wrapper.

```python
@gin.configurable
class P5AmazonItemDataset(ItemDataset):
    """P5 Amazon item dataset."""

    def __init__(
        self,
        root: str,
        split: str = "beauty",
        train_test_split: str = "all",
        return_text: bool = False,
        **kwargs,
    ):
        ...
```

Note that `split` names the Amazon product category (e.g., `"beauty"`), while `train_test_split` selects the train/val/test portion.
### P5AmazonSequenceDataset

P5 Amazon sequence dataset wrapper.

```python
@gin.configurable
class P5AmazonSequenceDataset(SequenceDataset):
    """P5 Amazon sequence dataset."""

    def __init__(
        self,
        root: str,
        split: str = "beauty",
        train_test_split: str = "train",
        pretrained_rqvae_path: Optional[str] = None,
        **kwargs,
    ):
        ...
```
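Because these classes are decorated with `@gin.configurable`, constructor arguments can also be bound from a gin config. The bindings below are illustrative values, not shipped defaults:

```python
import gin

# Hypothetical bindings; values are examples only.
gin.parse_config("""
P5AmazonSequenceDataset.root = "data/amazon"
P5AmazonSequenceDataset.split = "beauty"
P5AmazonSequenceDataset.pretrained_rqvae_path = "checkpoints/rqvae.ckpt"
""")

seq_dataset = P5AmazonSequenceDataset()  # arguments supplied by gin
```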
## Dataset Factory

### DatasetFactory

Dataset factory class for unified dataset registration and creation.

```python
from typing import Type


class DatasetFactory:
    """Dataset factory."""

    _registered_datasets = {}

    @classmethod
    def register_dataset(
        cls,
        name: str,
        base_class: Type[BaseRecommenderDataset],
        item_class: Type[ItemDataset],
        sequence_class: Type[SequenceDataset],
    ) -> None:
        """Register dataset classes under a name."""
```
**Usage Example:**

```python
# Register the dataset classes under a name
DatasetFactory.register_dataset(
    "p5_amazon",
    P5AmazonDataset,
    P5AmazonItemDataset,
    P5AmazonSequenceDataset,
)

# Create an item dataset by name
item_dataset = DatasetFactory.create_item_dataset(
    "p5_amazon",
    root="data/amazon",
    split="beauty",
)
```
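`create_item_dataset` plausibly reduces to a registry lookup followed by construction. This sketch assumes the registry stores a `(base_class, item_class, sequence_class)` tuple per name; it is not the verbatim implementation:

```python
@classmethod
def create_item_dataset(cls, name: str, **kwargs) -> ItemDataset:
    # Hypothetical sketch: look up the registered classes by name.
    if name not in cls._registered_datasets:
        raise KeyError(f"Unknown dataset: {name!r}")
    _, item_class, _ = cls._registered_datasets[name]
    return item_class(**kwargs)
```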
## Data Processors

### TextProcessor

Text processor for item text feature encoding.

```python
from sentence_transformers import SentenceTransformer


class TextProcessor:
    """Text processor."""

    def __init__(self, config: TextEncodingConfig):
        self.config = config
        self.encoder = SentenceTransformer(config.encoder_model)
```
**Methods:**

#### encode_item_features()

Encode item text features, optionally caching the result.

```python
def encode_item_features(
    self,
    items_df: pd.DataFrame,
    cache_key: Optional[str] = None,
    force_reload: bool = False,
) -> torch.Tensor:
    """
    Encode item text features.

    Args:
        items_df: Items DataFrame.
        cache_key: Key under which to cache the computed encodings.
        force_reload: If True, recompute even if a cached result exists.

    Returns:
        Item text encoding tensor.
    """
```
### SequenceProcessor

Sequence processor for sequence data preprocessing.

```python
class SequenceProcessor:
    """Sequence processor."""

    def __init__(self, config: SequenceConfig):
        self.config = config
```
**Methods:**

#### process_user_sequence()

Process a single user's interaction sequence into model-ready tensors.

```python
def process_user_sequence(
    self,
    sequence: List[int],
    target_offset: int = 1,
) -> Dict[str, torch.Tensor]:
    """
    Process a user interaction sequence.

    Args:
        sequence: Raw item ID sequence.
        target_offset: Offset between input and target positions.

    Returns:
        Processed sequence data.
    """
```
## Usage Examples

### Basic Usage

```python
from genrec.data import P5AmazonDataset, P5AmazonConfig

# Create the configuration
config = P5AmazonConfig(
    root_dir="data/amazon",
    category="beauty",
)

# Create and load the dataset
dataset = P5AmazonDataset(config)
dataset.load_dataset()

# Access the data
items = dataset.get_items()
interactions = dataset.get_interactions()
```
### Item Dataset Usage

```python
from torch.utils.data import DataLoader

from genrec.data import P5AmazonItemDataset

# Create the item dataset
item_dataset = P5AmazonItemDataset(
    root="data/amazon",
    split="beauty",
    return_text=True,
)

# Iterate with a DataLoader
dataloader = DataLoader(item_dataset, batch_size=32, shuffle=True)

for batch in dataloader:
    item_ids = batch["item_id"]
    text_features = batch["text_features"]
    # Train the item encoder...
```
### Sequence Dataset Usage

```python
from torch.utils.data import DataLoader

from genrec.data import P5AmazonSequenceDataset

# Create the sequence dataset; the pretrained RQ-VAE is loaded
# internally from the given checkpoint path
seq_dataset = P5AmazonSequenceDataset(
    root="data/amazon",
    split="beauty",
    train_test_split="train",
    pretrained_rqvae_path="checkpoints/rqvae.ckpt",
)

# Iterate with a DataLoader
dataloader = DataLoader(seq_dataset, batch_size=16, shuffle=True)

for batch in dataloader:
    input_ids = batch["input_ids"]
    target_ids = batch["target_ids"]
    # Train the sequence generation model...
```
## Related Links
- Configurations - Dataset configuration system
- Processors - Data processing utilities
- Dataset Factory - Factory pattern for dataset creation
- Trainers - Model training utilities