TIGER API Reference¶
Detailed API documentation for the Transformer-based generative retrieval model (TIGER).
Core Classes¶
Tiger¶
Main TIGER model class.
class Tiger(LightningModule):
    def __init__(
        self,
        vocab_size: int,
        embedding_dim: int = 512,
        num_heads: int = 8,
        num_layers: int = 6,
        attn_dim: int = 2048,
        dropout: float = 0.1,
        max_seq_length: int = 1024,
        learning_rate: float = 1e-4
    )
Parameters:
- vocab_size: Size of the token (semantic ID) vocabulary
- embedding_dim: Token embedding dimension (default 512)
- num_heads: Number of attention heads per layer (default 8)
- num_layers: Number of Transformer layers (default 6)
- attn_dim: Attention hidden dimension (default 2048)
- dropout: Dropout probability (default 0.1)
- max_seq_length: Maximum input sequence length (default 1024)
- learning_rate: Learning rate for the optimizer (default 1e-4)
Methods:
forward(input_ids, attention_mask=None)¶
Forward pass computation.
def forward(
    self,
    input_ids: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None
) -> torch.Tensor:
    """
    Args:
        input_ids: Input sequence (batch_size, seq_len)
        attention_mask: Attention mask (batch_size, seq_len)
    Returns:
        logits: Output logits (batch_size, seq_len, vocab_size)
    """
generate(input_ids, max_length=50, temperature=1.0, top_k=None, top_p=None)¶
Generate recommendation sequences.
def generate(
    self,
    input_ids: torch.Tensor,
    max_length: int = 50,
    temperature: float = 1.0,
    top_k: Optional[int] = None,
    top_p: Optional[float] = None
) -> torch.Tensor:
    """
    Args:
        input_ids: Input sequence (batch_size, seq_len)
        max_length: Maximum generation length
        temperature: Softmax temperature applied to the logits before sampling
        top_k: If set, sample only from the k most probable tokens
        top_p: If set, sample only from the smallest token set whose cumulative
            probability exceeds top_p (nucleus sampling)
    Returns:
        generated: Generated sequence
    """
generate_with_trie(input_ids, trie, max_length=50)¶
Generate with Trie constraints.
def generate_with_trie(
    self,
    input_ids: torch.Tensor,
    trie: TrieNode,
    max_length: int = 50
) -> torch.Tensor:
    """
    Args:
        input_ids: Input sequence
        trie: Trie constraint structure
        max_length: Maximum generation length
    Returns:
        generated: Constrained generated sequence
    """
Component Classes¶
TransformerBlock¶
Transformer block implementation.
class TransformerBlock(nn.Module):
    def __init__(
        self,
        embedding_dim: int,
        num_heads: int,
        attn_dim: int,
        dropout: float = 0.1
    )
MultiHeadAttention¶
Multi-head attention mechanism.
class MultiHeadAttention(nn.Module):
    def __init__(
        self,
        embedding_dim: int,
        num_heads: int,
        dropout: float = 0.1
    )
PositionalEncoding¶
Positional encoding.
class PositionalEncoding(nn.Module):
    def __init__(
        self,
        embedding_dim: int,
        max_seq_length: int = 5000
    )
Data Structures¶
TrieNode¶
Trie node for constrained generation.
from collections import defaultdict
from typing import List

class TrieNode(defaultdict):
    def __init__(self):
        # Missing children are created lazily as new TrieNode instances
        super().__init__(TrieNode)
        self.is_end = False

    def add_sequence(self, sequence: List[int]):
        """Add a token sequence to the Trie"""
        node = self
        for token in sequence:
            node = node[token]
        node.is_end = True

    def get_valid_tokens(self) -> List[int]:
        """Get the valid next tokens at the current node"""
        return list(self.keys())
Build Trie¶
def build_trie(valid_sequences: List[List[int]]) -> TrieNode:
    """Build Trie of valid sequences"""
    root = TrieNode()
    for sequence in valid_sequences:
        root.add_sequence(sequence)
    return root
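For instance, a small Trie over three hypothetical 3-token semantic ID sequences (the IDs are placeholders):

valid_sequences = [
    [5, 12, 88],
    [5, 12, 90],
    [7, 44, 3],
]
trie = build_trie(valid_sequences)

node = trie
print(node.get_valid_tokens())   # [5, 7] -> valid first tokens
node = node[5]
print(node.get_valid_tokens())   # [12]
node = node[12]
print(node.get_valid_tokens())   # [88, 90]
print(node[88].is_end)           # True -> a complete valid sequence

Note that because TrieNode subclasses defaultdict, indexing a token that is not in the Trie silently creates an empty child; during decoding, check get_valid_tokens() (or use the in operator) before indexing.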
Training Interface¶
Training Step¶
def training_step(self, batch, batch_idx):
    """Training step"""
    input_ids = batch['input_ids']
    labels = batch['labels']
    attention_mask = batch.get('attention_mask', None)
    # Forward pass
    logits = self(input_ids, attention_mask)
    # Next-token prediction: tokens up to position t predict the label at t+1
    shift_logits = logits[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous()
    loss_fn = nn.CrossEntropyLoss(ignore_index=-100)
    loss = loss_fn(
        shift_logits.view(-1, shift_logits.size(-1)),
        shift_labels.view(-1)
    )
    # Log metrics
    self.log('train_loss', loss)
    return loss
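To make the shift-by-one explicit, here is a hypothetical batch in the expected format, assuming pad token ID 0 and labels masked with -100 (the ignore_index used above); the values are placeholders:

import torch

pad_token_id = 0
input_ids = torch.tensor([[11, 42, 7, 93, 0, 0]])   # right-padded user history
labels = input_ids.clone()
labels[input_ids == pad_token_id] = -100            # ignored by CrossEntropyLoss
attention_mask = (input_ids != pad_token_id).long()

batch = {"input_ids": input_ids, "labels": labels, "attention_mask": attention_mask}
# training_step compares logits[:, :-1] against labels[:, 1:]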
Validation Step¶
def validation_step(self, batch, batch_idx):
    """Validation step"""
    input_ids = batch['input_ids']
    labels = batch['labels']
    attention_mask = batch.get('attention_mask', None)
    # Forward pass
    logits = self(input_ids, attention_mask)
    # Compute loss
    shift_logits = logits[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous()
    loss_fn = nn.CrossEntropyLoss(ignore_index=-100)
    loss = loss_fn(
        shift_logits.view(-1, shift_logits.size(-1)),
        shift_labels.view(-1)
    )
    # Log metrics
    self.log('val_loss', loss)
    return loss
Inference Interface¶
Batch Generation¶
def batch_generate(
    model: Tiger,
    input_sequences: List[torch.Tensor],
    max_length: int = 50,
    device: str = 'cuda'
) -> List[torch.Tensor]:
    """Batch generation for recommendations"""
    model.eval()
    model.to(device)
    results = []
    with torch.no_grad():
        for input_seq in input_sequences:
            input_seq = input_seq.to(device)
            generated = model.generate(input_seq, max_length=max_length)
            results.append(generated.cpu())
    return results
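Example usage, assuming a trained model and a few arbitrary user histories, each shaped (1, seq_len) to match the (batch_size, seq_len) convention of forward(); pass device='cpu' if no GPU is available:

import torch

user_histories = [
    torch.tensor([[3, 18, 42]]),        # placeholder semantic ID history
    torch.tensor([[7, 7, 91, 12]]),
]
recommendations = batch_generate(model, user_histories, max_length=20, device='cuda')
for rec in recommendations:
    print(rec.squeeze().tolist())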
Constrained Generation¶
def constrained_generate(
    model: Tiger,
    input_ids: torch.Tensor,
    valid_item_sequences: List[List[int]],
    max_length: int = 50
) -> torch.Tensor:
    """Constrained generation for recommendations"""
    # Build Trie
    trie = build_trie(valid_item_sequences)
    # Constrained generation
    return model.generate_with_trie(input_ids, trie, max_length)
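Example usage, with placeholder semantic ID sequences standing in for the items that are allowed to be recommended:

import torch

valid_item_sequences = [
    [101, 7, 23],   # semantic ID triple of item A (placeholder values)
    [101, 7, 25],   # item B shares a prefix with item A
    [240, 3, 64],   # item C
]
user_history = torch.tensor([[10, 25, 67, 89]])
recommended = constrained_generate(model, user_history, valid_item_sequences, max_length=12)
print(recommended.squeeze().tolist())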
Evaluation Interface¶
Top-K Recommendation Evaluation¶
def evaluate_recommendation(
    model: Tiger,
    test_dataloader: DataLoader,
    k_values: List[int] = [5, 10, 20],
    device: str = 'cuda'
) -> Dict[str, float]:
    """Evaluate recommendation performance"""
    model.eval()
    model.to(device)
    all_predictions = []
    all_targets = []
    with torch.no_grad():
        for batch in test_dataloader:
            input_ids = batch['input_ids'].to(device)
            targets = batch['targets']
            # Generate recommendations
            generated = model.generate(input_ids, max_length=50)
            all_predictions.extend(generated.cpu().tolist())
            all_targets.extend(targets.tolist())
    # Compute metrics
    metrics = {}
    for k in k_values:
        recall_k = compute_recall_at_k(all_predictions, all_targets, k)
        ndcg_k = compute_ndcg_at_k(all_predictions, all_targets, k)
        metrics[f'recall@{k}'] = recall_k
        metrics[f'ndcg@{k}'] = ndcg_k
    return metrics
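The helpers compute_recall_at_k and compute_ndcg_at_k are referenced above but not shown. One possible sketch, assuming each target is a single ground-truth item ID and each prediction list is rank-ordered; the helpers shipped with the codebase may differ (for example, by matching multi-token semantic IDs):

import math
from typing import List

def compute_recall_at_k(predictions: List[List[int]], targets: List[int], k: int) -> float:
    """Fraction of users whose ground-truth item appears in the top-k predictions."""
    hits = sum(1 for preds, target in zip(predictions, targets) if target in preds[:k])
    return hits / len(targets)

def compute_ndcg_at_k(predictions: List[List[int]], targets: List[int], k: int) -> float:
    """NDCG@k with a single relevant item: 1/log2(rank+2) if the item is in the top-k, else 0."""
    total = 0.0
    for preds, target in zip(predictions, targets):
        if target in preds[:k]:
            rank = preds[:k].index(target)   # 0-based rank
            total += 1.0 / math.log2(rank + 2)
    return total / len(targets)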
Perplexity Evaluation¶
import math

def evaluate_perplexity(
    model: Tiger,
    test_dataloader: DataLoader,
    device: str = 'cuda'
) -> float:
    """Evaluate perplexity"""
    model.eval()
    model.to(device)
    total_loss = 0
    total_tokens = 0
    with torch.no_grad():
        for batch in test_dataloader:
            input_ids = batch['input_ids'].to(device)
            labels = batch['labels'].to(device)
            attention_mask = batch.get('attention_mask', None)
            if attention_mask is not None:
                attention_mask = attention_mask.to(device)
            logits = model(input_ids, attention_mask)
            # Compute summed loss
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss_fn = nn.CrossEntropyLoss(ignore_index=-100, reduction='sum')
            loss = loss_fn(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1)
            )
            # Count valid (non-ignored) tokens
            valid_tokens = (shift_labels != -100).sum()
            total_loss += loss.item()
            total_tokens += valid_tokens.item()
    avg_loss = total_loss / total_tokens
    perplexity = math.exp(avg_loss)
    return perplexity
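Example usage, assuming a test DataLoader that yields input_ids/labels batches in the same format as the training loader:

ppl = evaluate_perplexity(model, test_dataloader, device='cuda')
print(f"Perplexity: {ppl:.2f}")   # exp of the average per-token cross-entropy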
Utility Functions¶
Sequence Processing¶
def pad_sequences(
    sequences: List[torch.Tensor],
    pad_token_id: int = 0,
    max_length: Optional[int] = None
) -> torch.Tensor:
    """Pad sequences to the same length"""
    if max_length is None:
        max_length = max(len(seq) for seq in sequences)
    padded = []
    for seq in sequences:
        if len(seq) < max_length:
            pad_length = max_length - len(seq)
            padded_seq = torch.cat([
                seq,
                torch.full((pad_length,), pad_token_id, dtype=seq.dtype)
            ])
        else:
            padded_seq = seq[:max_length]
        padded.append(padded_seq)
    return torch.stack(padded)
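For example, padding variable-length histories and deriving the attention mask expected by forward(), assuming pad token ID 0 (the values are placeholders):

import torch

sequences = [
    torch.tensor([4, 17, 9]),
    torch.tensor([22, 5]),
    torch.tensor([8, 31, 2, 40, 11]),
]
input_ids = pad_sequences(sequences, pad_token_id=0)
attention_mask = (input_ids != 0).long()   # 1 for real tokens, 0 for padding
print(input_ids.shape)                     # torch.Size([3, 5])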
Sampling Strategies¶
import torch
import torch.nn.functional as F
from typing import Optional

def top_k_top_p_sampling(
    logits: torch.Tensor,
    top_k: Optional[int] = None,
    top_p: Optional[float] = None,
    temperature: float = 1.0
) -> torch.Tensor:
    """Top-k and top-p (nucleus) sampling"""
    logits = logits / temperature
    # Top-k filtering
    if top_k is not None:
        top_k = min(top_k, logits.size(-1))
        values, indices = torch.topk(logits, top_k)
        logits[logits < values[..., [-1]]] = float('-inf')
    # Top-p (nucleus) filtering
    if top_p is not None:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
        # Mark positions where cumulative probability exceeds top_p
        sorted_indices_to_remove = cumulative_probs > top_p
        # Shift the mask right so the first token above the threshold is kept
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = False
        # Scatter the mask back to the original vocabulary order (works for batched logits)
        indices_to_remove = sorted_indices_to_remove.scatter(-1, sorted_indices, sorted_indices_to_remove)
        logits[indices_to_remove] = float('-inf')
    # Sample from the filtered distribution
    probs = F.softmax(logits, dim=-1)
    next_token = torch.multinomial(probs, 1)
    return next_token
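A single decoding step using the sampler, assuming a model instance and an input_ids prompt as in the earlier examples (logits shaped (batch_size, seq_len, vocab_size)):

import torch

with torch.no_grad():
    logits = model(input_ids)             # (batch_size, seq_len, vocab_size)
    next_token_logits = logits[:, -1, :]  # distribution over the next token
    next_token = top_k_top_p_sampling(next_token_logits, top_k=50, top_p=0.9, temperature=0.8)
    input_ids = torch.cat([input_ids, next_token], dim=-1)   # append and continue decoding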
Usage Examples¶
Basic Training¶
from genrec.models.tiger import Tiger
from genrec.data.p5_amazon import P5AmazonSequenceDataset
from torch.utils.data import DataLoader
import pytorch_lightning as pl

# Create dataset
dataset = P5AmazonSequenceDataset(
    root="dataset/amazon",
    split="beauty",
    train_test_split="train",
    pretrained_rqvae_path="checkpoints/rqvae.ckpt"
)
dataloader = DataLoader(dataset, batch_size=16, shuffle=True)

# Create model
model = Tiger(
    vocab_size=1024,
    embedding_dim=512,
    num_heads=8,
    num_layers=6,
    learning_rate=1e-4
)

# Train model
trainer = pl.Trainer(max_epochs=50, accelerator="gpu", devices=1)
trainer.fit(model, dataloader)
Recommendation Generation¶
import torch

# Load trained model
model = Tiger.load_from_checkpoint("checkpoints/tiger.ckpt")
model.eval()

# User history sequence
user_sequence = torch.tensor([10, 25, 67, 89])  # Semantic ID sequence

# Generate recommendations
with torch.no_grad():
    recommendations = model.generate(
        user_sequence.unsqueeze(0),  # add a batch dimension
        max_length=20,
        temperature=0.8,
        top_k=50
    )
print(f"Recommendations: {recommendations.squeeze().tolist()}")