Skip to content

Modules API Reference

Detailed documentation for core building blocks including encoders, loss functions, metrics, and utilities.

Encoder Modules

TransformerEncoder

Transformer-based encoder implementation.

class TransformerEncoder(nn.Module):
    """
    Token embedding + sinusoidal positional encoding + a stack of
    batch-first ``nn.TransformerEncoderLayer`` blocks.

    Args:
        vocab_size: Number of token IDs covered by the embedding table.
        embedding_dim: Model width (``d_model``) used by every layer.
        num_heads: Number of attention heads per layer.
        num_layers: Number of stacked encoder layers.
        attn_dim: Hidden width of each layer's feed-forward sublayer
            (forwarded as ``dim_feedforward``; despite the name it is not an
            attention dimension).
        dropout: Dropout rate used inside each layer and on the input embeddings.
        max_seq_length: Number of positions precomputed by the positional table.
    """

    def __init__(
        self,
        vocab_size: int,
        embedding_dim: int,
        num_heads: int,
        num_layers: int,
        attn_dim: int,
        dropout: float = 0.1,
        max_seq_length: int = 1024
    ):
        super().__init__()

        # Creation order matters: it fixes parameter registration order and
        # the RNG draws consumed by each sub-module's initializer.
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.pos_encoding = PositionalEncoding(embedding_dim, max_seq_length)

        layer_config = dict(
            d_model=embedding_dim,
            nhead=num_heads,
            dim_feedforward=attn_dim,
            dropout=dropout,
            batch_first=True,
        )
        # nn.TransformerEncoder clones the template layer num_layers times.
        self.transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(**layer_config), num_layers
        )
        self.dropout = nn.Dropout(dropout)

Methods:

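forward(input_ids, attention_mask=None)

Embeds the token IDs, adds positional encoding and dropout, then runs the Transformer encoder stack. Positions where attention_mask is 0 are treated as padding and ignored by attention.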

def forward(
    self,
    input_ids: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None
) -> torch.Tensor:
    """
    Encode a batch of token IDs.

    Args:
        input_ids: Token IDs (batch_size, seq_len)
        attention_mask: Optional mask (batch_size, seq_len); positions equal
            to 0 are treated as padding and excluded from attention.

    Returns:
        Encoder output (batch_size, seq_len, embedding_dim)
    """
    # token embeddings -> add positions -> dropout
    hidden = self.dropout(self.pos_encoding(self.embedding(input_ids)))

    # The Transformer's key-padding mask is True where a position must be
    # ignored, i.e. the inverse of attention_mask's 1 = keep / 0 = pad.
    padding_mask = None if attention_mask is None else attention_mask == 0

    return self.transformer(hidden, src_key_padding_mask=padding_mask)

PositionalEncoding

Positional encoding module.

class PositionalEncoding(nn.Module):
    """
    Sinusoidal positional encoding.

    Precomputes a table ``pe`` of shape (1, max_seq_length, embedding_dim).
    Even columns hold ``sin(pos * w_i)`` and odd columns hold
    ``cos(pos * w_i)``, with ``w_i = 10000 ** (-2i / embedding_dim)``.

    Args:
        embedding_dim: Width of the embeddings the table is added to. Odd
            widths are supported; the last column is then a sine term.
        max_seq_length: Number of positions precomputed.
    """

    def __init__(self, embedding_dim: int, max_seq_length: int = 5000):
        super().__init__()

        pe = torch.zeros(max_seq_length, embedding_dim)
        position = torch.arange(0, max_seq_length, dtype=torch.float).unsqueeze(1)

        # One frequency per sine column: ceil(embedding_dim / 2) of them.
        div_term = torch.exp(
            torch.arange(0, embedding_dim, 2).float()
            * (-math.log(10000.0) / embedding_dim)
        )

        pe[:, 0::2] = torch.sin(position * div_term)
        # With an odd embedding_dim there is one fewer cosine column than
        # frequencies, so the last frequency is dropped for the cosine half.
        pe[:, 1::2] = torch.cos(position * div_term[: embedding_dim // 2])

        # Buffer: follows .to()/.cuda() and is saved in state_dict, but is not trained.
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Add positional encoding.

        Args:
            x: Input embeddings (batch_size, seq_len, embedding_dim);
                seq_len must not exceed max_seq_length.

        Returns:
            ``x`` plus the encodings for positions ``0 .. seq_len - 1``.
        """
        return x + self.pe[:, :x.size(1)]

MultiHeadAttention

Multi-head attention mechanism.

class MultiHeadAttention(nn.Module):
    """
    Scaled dot-product multi-head attention.

    Args:
        embedding_dim: Model width; must be divisible by ``num_heads``.
        num_heads: Number of attention heads.
        dropout: Dropout applied to the attention weights.
    """

    def __init__(
        self,
        embedding_dim: int,
        num_heads: int,
        dropout: float = 0.1
    ):
        super().__init__()
        assert embedding_dim % num_heads == 0, (
            f"embedding_dim ({embedding_dim}) must be divisible by "
            f"num_heads ({num_heads})"
        )

        self.embedding_dim = embedding_dim
        self.num_heads = num_heads
        self.head_dim = embedding_dim // num_heads

        self.w_q = nn.Linear(embedding_dim, embedding_dim)
        self.w_k = nn.Linear(embedding_dim, embedding_dim)
        self.w_v = nn.Linear(embedding_dim, embedding_dim)
        self.w_o = nn.Linear(embedding_dim, embedding_dim)

        self.dropout = nn.Dropout(dropout)
        self.scale = math.sqrt(self.head_dim)

    def _split_heads(self, x: torch.Tensor) -> torch.Tensor:
        """Reshape (batch, length, embedding_dim) -> (batch, num_heads, length, head_dim)."""
        batch_size, length, _ = x.size()
        return x.view(batch_size, length, self.num_heads, self.head_dim).transpose(1, 2)

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        mask: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Multi-head attention computation.

        Args:
            query: Query vectors (batch_size, q_len, embedding_dim)
            key: Key vectors (batch_size, k_len, embedding_dim)
            value: Value vectors (batch_size, k_len, embedding_dim)
            mask: Optional mask where 0 marks a disallowed query/key pair.
                Accepted shapes are (batch_size, k_len) (key padding),
                (batch_size, q_len, k_len), or any 4-D shape broadcastable to
                (batch_size, num_heads, q_len, k_len).

        Returns:
            (attention_output, attention_weights): output of shape
            (batch_size, q_len, embedding_dim) and post-dropout weights of
            shape (batch_size, num_heads, q_len, k_len).
        """
        batch_size, q_len, _ = query.size()

        # Each projection is split using its own length, so key/value may be
        # longer or shorter than the query (cross-attention).
        Q = self._split_heads(self.w_q(query))
        K = self._split_heads(self.w_k(key))
        V = self._split_heads(self.w_v(value))

        scores = torch.matmul(Q, K.transpose(-2, -1)) / self.scale

        if mask is not None:
            if mask.dim() == 2:
                # (batch, k_len) -> broadcast over heads and query positions
                mask = mask[:, None, None, :]
            elif mask.dim() == 3:
                # (batch, q_len, k_len) -> broadcast over heads
                mask = mask.unsqueeze(1)
            # The most negative finite value of the score dtype; a fixed -1e9
            # overflows (and errors) in float16.
            scores = scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min)

        attention_weights = F.softmax(scores, dim=-1)
        attention_weights = self.dropout(attention_weights)

        attention_output = torch.matmul(attention_weights, V)
        attention_output = attention_output.transpose(1, 2).contiguous().view(
            batch_size, q_len, self.embedding_dim
        )

        output = self.w_o(attention_output)

        return output, attention_weights

Loss Functions

VQVAELoss

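VQ-VAE loss: mean-squared reconstruction error plus the commitment and embedding losses, weighted by `commitment_cost` and `beta` respectively.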

class VQVAELoss(nn.Module):
    """
    Weighted sum of VQ-VAE loss terms.

        total = MSE(x_recon, x)
                + commitment_cost * commitment_loss
                + beta * embedding_loss

    Args:
        commitment_cost: Weight on the commitment term.
        beta: Weight on the embedding term.
    """

    def __init__(self, commitment_cost: float = 0.25, beta: float = 1.0):
        super().__init__()
        self.commitment_cost = commitment_cost
        self.beta = beta

    def forward(
        self,
        x: torch.Tensor,
        x_recon: torch.Tensor,
        commitment_loss: torch.Tensor,
        embedding_loss: torch.Tensor
    ) -> Dict[str, torch.Tensor]:
        """
        Combine the reconstruction error with the two quantizer losses.

        Args:
            x: Original input
            x_recon: Reconstruction of ``x`` (same shape)
            commitment_loss: Precomputed commitment term
            embedding_loss: Precomputed embedding term

        Returns:
            Dict with 'total_loss', 'reconstruction_loss', 'commitment_loss'
            and 'embedding_loss' (the last two passed through unweighted).
        """
        reconstruction = F.mse_loss(x_recon, x, reduction='mean')
        weighted_commitment = self.commitment_cost * commitment_loss
        weighted_embedding = self.beta * embedding_loss

        return {
            'total_loss': reconstruction + weighted_commitment + weighted_embedding,
            'reconstruction_loss': reconstruction,
            'commitment_loss': commitment_loss,
            'embedding_loss': embedding_loss,
        }

SequenceLoss

Token-level cross-entropy loss for sequence modeling, which also reports accuracy over the non-ignored positions.

class SequenceLoss(nn.Module):
    """
    Token-level cross-entropy loss plus accuracy for sequence models.

    Args:
        vocab_size: Size of the output vocabulary (last dim of the logits).
        ignore_index: Label value excluded from both loss and accuracy.
        label_smoothing: Label smoothing passed to ``nn.CrossEntropyLoss``.
    """

    def __init__(
        self,
        vocab_size: int,
        ignore_index: int = -100,
        label_smoothing: float = 0.0
    ):
        super().__init__()
        self.vocab_size = vocab_size
        self.ignore_index = ignore_index
        self.label_smoothing = label_smoothing

        # label_smoothing=0.0 is CrossEntropyLoss's own default, so a single
        # construction covers both the smoothed and unsmoothed cases.
        self.criterion = nn.CrossEntropyLoss(
            ignore_index=ignore_index,
            label_smoothing=label_smoothing
        )

    def forward(
        self,
        logits: torch.Tensor,
        labels: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None
    ) -> Dict[str, torch.Tensor]:
        """
        Compute sequence modeling loss.

        Args:
            logits: Model output (batch_size, seq_len, vocab_size)
            labels: Target labels (batch_size, seq_len)
            attention_mask: Optional mask with the same shape as ``labels``.
                Positions where it is 0 are treated like ``ignore_index``:
                they count towards neither the loss nor the accuracy.

        Returns:
            Dict with 'loss' (mean cross-entropy over kept positions) and
            'accuracy' (fraction of kept positions whose argmax matches the
            label; 0.0 when no position is kept).
        """
        # reshape rather than view: logits are frequently non-contiguous
        # slices (e.g. logits[:, :-1]), which view() cannot flatten.
        flat_logits = logits.reshape(-1, self.vocab_size)
        flat_labels = labels.reshape(-1)

        if attention_mask is not None:
            flat_labels = flat_labels.masked_fill(
                attention_mask.reshape(-1) == 0, self.ignore_index
            )

        loss = self.criterion(flat_logits, flat_labels)

        with torch.no_grad():
            valid = flat_labels != self.ignore_index
            correct = (flat_logits.argmax(dim=-1) == flat_labels) & valid
            # clamp keeps the denominator >= 1 so an all-ignored batch
            # yields 0.0 instead of 0/0 = NaN.
            accuracy = correct.sum().float() / valid.sum().clamp(min=1).float()

        return {
            'loss': loss,
            'accuracy': accuracy
        }

ContrastiveLoss

InfoNCE-style contrastive loss: cross-entropy over temperature-scaled cosine similarities, with the positive as the target class.

class ContrastiveLoss(nn.Module):
    """
    InfoNCE-style contrastive loss with one positive and several negatives
    per anchor.

    All embeddings are L2-normalized, so every logit is a cosine similarity
    divided by ``temperature``. Each row of logits is ``[positive, negatives...]``
    and cross-entropy is taken against target index 0.

    Args:
        temperature: Divisor applied to the cosine similarities.
        margin: Stored on the module; ``forward`` does not use it.
    """

    def __init__(self, temperature: float = 0.1, margin: float = 0.2):
        super().__init__()
        self.temperature = temperature
        self.margin = margin

    def forward(
        self,
        anchor: torch.Tensor,
        positive: torch.Tensor,
        negative: torch.Tensor
    ) -> torch.Tensor:
        """
        Compute contrastive loss.

        Args:
            anchor: Anchor embeddings (batch_size, embedding_dim)
            positive: Positive embeddings (batch_size, embedding_dim)
            negative: Negative embeddings (batch_size, num_negatives, embedding_dim)

        Returns:
            Scalar loss averaged over the batch.
        """
        a = F.normalize(anchor, dim=-1)
        p = F.normalize(positive, dim=-1)
        n = F.normalize(negative, dim=-1)

        # (batch, 1): similarity of each anchor to its positive
        pos_logit = (a * p).sum(dim=-1, keepdim=True) / self.temperature
        # (batch, num_negatives): similarity of each anchor to its own negatives
        neg_logits = torch.bmm(n, a.unsqueeze(-1)).squeeze(-1) / self.temperature

        logits = torch.cat((pos_logit, neg_logits), dim=1)
        # Column 0 holds the positive, so the target class is 0 for every row.
        target = logits.new_zeros(logits.size(0), dtype=torch.long)
        return F.cross_entropy(logits, target)

Evaluation Metrics

RecommendationMetrics

Recommendation system evaluation metrics.

class RecommendationMetrics:
    """
    Top-K evaluation metrics for ranked recommendation lists.

    ``predictions`` holds one ranked list of item IDs per user. ``targets``
    holds that user's relevant item IDs; the two are paired by position.
    Degenerate inputs (``k <= 0``, nothing to average) yield ``0.0``.
    """

    @staticmethod
    def recall_at_k(predictions: List[List[int]], targets: List[List[int]], k: int) -> float:
        """
        Compute Recall@K: |top-K ∩ targets| / |targets|, averaged over users.

        Args:
            predictions: Ranked prediction lists
            targets: Relevant-item lists (duplicates are ignored)
            k: Cutoff

        Returns:
            Mean Recall@K over users with a non-empty target list.
        """
        if k <= 0:
            return 0.0

        recall_scores = []
        for pred, target in zip(predictions, targets):
            target_set = set(target)
            if not target_set:
                continue
            recall_scores.append(len(set(pred[:k]) & target_set) / len(target_set))

        return float(np.mean(recall_scores)) if recall_scores else 0.0

    @staticmethod
    def precision_at_k(predictions: List[List[int]], targets: List[List[int]], k: int) -> float:
        """
        Compute Precision@K: |top-K ∩ targets| / K, averaged over users.

        Unlike the other metrics, users with an empty target list are kept
        in the average (each contributes 0).
        """
        if k <= 0:
            return 0.0

        precision_scores = [
            len(set(pred[:k]) & set(target)) / k
            for pred, target in zip(predictions, targets)
        ]
        return float(np.mean(precision_scores)) if precision_scores else 0.0

    @staticmethod
    def ndcg_at_k(predictions: List[List[int]], targets: List[List[int]], k: int) -> float:
        """
        Compute NDCG@K with binary relevance, averaged over users.

        A predicted item earns gain only the first time it appears, and the
        ideal DCG is sized by the number of distinct targets. Together these
        keep every per-user score in [0, 1]. Users with an empty target list
        are skipped.
        """
        if k <= 0:
            return 0.0

        ndcg_scores = []
        for pred, target in zip(predictions, targets):
            target_set = set(target)
            if not target_set:
                continue

            dcg = 0.0
            credited = set()
            for rank, item in enumerate(pred[:k]):
                if item in target_set and item not in credited:
                    credited.add(item)
                    dcg += 1 / np.log2(rank + 2)

            idcg = sum(1 / np.log2(rank + 2) for rank in range(min(len(target_set), k)))
            # idcg > 0 here: target_set is non-empty and k >= 1.
            ndcg_scores.append(dcg / idcg)

        return float(np.mean(ndcg_scores)) if ndcg_scores else 0.0

    @staticmethod
    def hit_rate_at_k(predictions: List[List[int]], targets: List[List[int]], k: int) -> float:
        """
        Compute Hit Rate@K: fraction of users with at least one target in the top K.

        Users with an empty target list are skipped.
        """
        if k <= 0:
            return 0.0

        hits = 0
        total = 0
        for pred, target in zip(predictions, targets):
            target_set = set(target)
            if not target_set:
                continue
            total += 1
            if not target_set.isdisjoint(pred[:k]):
                hits += 1

        return hits / total if total > 0 else 0.0

    @staticmethod
    def coverage(predictions: List[List[int]], total_items: int) -> float:
        """
        Compute item coverage: distinct recommended items / ``total_items``.

        Every item in every prediction list is counted (no top-K cutoff).
        Returns 0.0 when ``total_items`` is not positive.
        """
        if total_items <= 0:
            return 0.0

        recommended_items = set().union(*predictions)
        return len(recommended_items) / total_items

    @staticmethod
    def diversity(predictions: List[List[int]]) -> float:
        """
        Compute recommendation diversity as the mean pairwise Jaccard distance.

        Pairs whose union is empty (both lists empty) are skipped. Returns
        0.0 for fewer than two lists.
        """
        if len(predictions) < 2:
            return 0.0

        # Build each set once instead of twice per pair.
        item_sets = [set(pred) for pred in predictions]

        distances = []
        for i, set_i in enumerate(item_sets):
            for set_j in item_sets[i + 1:]:
                union = set_i | set_j
                if union:
                    distances.append(1 - len(set_i & set_j) / len(union))

        return float(np.mean(distances)) if distances else 0.0

Utility Modules

AttentionVisualization

Attention visualization tools.

class AttentionVisualization:
    """Plotting helpers for attention weights."""

    @staticmethod
    def plot_attention_heatmap(
        attention_weights: torch.Tensor,
        input_tokens: List[str],
        output_tokens: List[str],
        save_path: Optional[str] = None
    ) -> None:
        """
        Plot a single attention matrix as an annotated heatmap.

        Args:
            attention_weights: 2-D attention weights (seq_len_out, seq_len_in).
                Tensors that require grad, live on an accelerator, or use a
                reduced-precision float dtype are accepted.
            input_tokens: Column labels, one per input position
            output_tokens: Row labels, one per output position
            save_path: If given, the figure is also saved here at 300 dpi
        """
        # Imported lazily so that importing this module does not require
        # matplotlib/seaborn.
        import matplotlib.pyplot as plt
        import seaborn as sns

        # detach(): .numpy() raises on tensors that require grad (e.g. weights
        # returned straight from a training-mode forward pass).
        # float(): NumPy has no bfloat16 dtype.
        weights = attention_weights.detach().float().cpu().numpy()

        plt.figure(figsize=(10, 8))

        # Create heatmap
        sns.heatmap(
            weights,
            xticklabels=input_tokens,
            yticklabels=output_tokens,
            cmap='Blues',
            annot=True,
            fmt='.2f'
        )

        plt.title('Attention Weights')
        plt.xlabel('Input Tokens')
        plt.ylabel('Output Tokens')
        plt.xticks(rotation=45)
        plt.yticks(rotation=0)

        if save_path:
            plt.savefig(save_path, dpi=300, bbox_inches='tight')

        plt.show()

ModelUtils

Model utility functions.

class ModelUtils:
    """Static helpers for inspecting and preparing ``nn.Module`` instances."""

    @staticmethod
    def count_parameters(model: nn.Module, trainable_only: bool = True) -> int:
        """
        Count model parameters (total number of scalar elements).

        Args:
            model: Module to inspect.
            trainable_only: If True (the default), count only parameters with
                ``requires_grad=True``; if False, count every parameter.
        """
        return sum(
            p.numel() for p in model.parameters()
            if p.requires_grad or not trainable_only
        )

    @staticmethod
    def get_model_size(model: nn.Module) -> str:
        """
        Return the memory held by parameters and buffers as ``"<size> MB"``.

        1 MB is taken as 1024 * 1024 bytes; the value has two decimals.
        """
        param_size = sum(p.nelement() * p.element_size() for p in model.parameters())
        buffer_size = sum(b.nelement() * b.element_size() for b in model.buffers())

        size_mb = (param_size + buffer_size) / 1024 / 1024
        return f"{size_mb:.2f} MB"

    @staticmethod
    def _set_requires_grad(model: nn.Module, layer_names: List[str], requires_grad: bool) -> None:
        """Set ``requires_grad`` on every parameter whose name contains any of ``layer_names``."""
        for name, param in model.named_parameters():
            if any(layer_name in name for layer_name in layer_names):
                param.requires_grad = requires_grad

    @staticmethod
    def freeze_layers(model: nn.Module, layer_names: List[str]) -> None:
        """Freeze parameters whose qualified name contains any of ``layer_names`` (substring match)."""
        ModelUtils._set_requires_grad(model, layer_names, False)

    @staticmethod
    def unfreeze_layers(model: nn.Module, layer_names: List[str]) -> None:
        """Unfreeze parameters whose qualified name contains any of ``layer_names`` (substring match)."""
        ModelUtils._set_requires_grad(model, layer_names, True)

    @staticmethod
    def initialize_weights(model: nn.Module, init_type: str = 'xavier') -> None:
        """
        Re-initialize model weights in place.

        Parameters whose name contains 'weight' and that have at least two
        dimensions get the chosen initializer. Parameters whose name
        contains 'bias' are set to 0.

        One-dimensional weights (e.g. ``LayerNorm.weight``) are left at their
        module defaults. xavier/kaiming cannot compute fan-in/fan-out for
        them, and a near-zero normal init would collapse a normalization scale.

        Args:
            model: Module whose parameters are initialized.
            init_type: One of 'xavier', 'kaiming', 'normal' (std 0.02).

        Raises:
            ValueError: if ``init_type`` is not recognised. This is checked
                before any parameter is modified.
        """
        initializers = {
            'xavier': nn.init.xavier_uniform_,
            'kaiming': nn.init.kaiming_uniform_,
            'normal': lambda t: nn.init.normal_(t, mean=0, std=0.02),
        }
        if init_type not in initializers:
            raise ValueError(
                f"Unknown init_type {init_type!r}; expected one of {sorted(initializers)}"
            )
        init_fn = initializers[init_type]

        for name, param in model.named_parameters():
            if 'weight' in name:
                if param.dim() >= 2:
                    init_fn(param)
            elif 'bias' in name:
                nn.init.constant_(param, 0)

Usage Examples

Using Encoders

import torch

from genrec.modules import TransformerEncoder

# Create encoder
encoder = TransformerEncoder(
    vocab_size=1000,
    embedding_dim=512,
    num_heads=8,
    num_layers=6,
    attn_dim=2048,  # feed-forward hidden size (passed as dim_feedforward)
)

# Encode a batch of random token IDs
input_ids = torch.randint(0, 1000, (32, 50))  # (batch_size, seq_len)
attention_mask = torch.ones_like(input_ids)   # 1 = real token, 0 = padding

encoded = encoder(input_ids, attention_mask)
print(f"Encoded shape: {encoded.shape}")  # torch.Size([32, 50, 512])

Using Loss Functions

import torch

from genrec.modules import VQVAELoss, SequenceLoss

# VQVAE loss. In practice x_recon and the two quantizer losses come from the
# model's forward pass; random placeholders keep this example runnable.
x = torch.randn(8, 64)
x_recon = torch.randn(8, 64)
commitment_loss = torch.tensor(0.1)
embedding_loss = torch.tensor(0.2)

vqvae_loss = VQVAELoss(commitment_cost=0.25)
vq_losses = vqvae_loss(x, x_recon, commitment_loss, embedding_loss)
print(f"VQVAE total loss: {vq_losses['total_loss'].item():.4f}")

# Sequence loss
logits = torch.randn(4, 20, 1000)                     # (batch_size, seq_len, vocab_size)
labels = torch.randint(0, 1000, (4, 20))              # (batch_size, seq_len)
attention_mask = torch.ones(4, 20, dtype=torch.long)  # 1 = real token, 0 = padding

seq_loss = SequenceLoss(vocab_size=1000, label_smoothing=0.1)
seq_losses = seq_loss(logits, labels, attention_mask)
print(f"Sequence loss: {seq_losses['loss'].item():.4f}, accuracy: {seq_losses['accuracy'].item():.4f}")

Computing Evaluation Metrics

from genrec.modules import RecommendationMetrics

# Two users: each has a ranked top-5 list and a set of ground-truth items
predictions = [[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]]
targets = [[1, 3, 5], [7, 9]]
k = 5

# Compute metrics at cutoff k
recall_5 = RecommendationMetrics.recall_at_k(predictions, targets, k)
ndcg_5 = RecommendationMetrics.ndcg_at_k(predictions, targets, k)
hit_rate = RecommendationMetrics.hit_rate_at_k(predictions, targets, k)

for line in (
    f"Recall@5: {recall_5:.4f}",
    f"NDCG@5: {ndcg_5:.4f}",
    f"Hit Rate@5: {hit_rate:.4f}",
):
    print(line)

Model Tools

from genrec.modules import ModelUtils

# `model` is assumed to be an existing nn.Module defined beforehand
# (for example, a TransformerEncoder as in "Using Encoders").

# Model information
param_count = ModelUtils.count_parameters(model)  # counts parameters with requires_grad=True
model_size = ModelUtils.get_model_size(model)  # parameters + buffers, formatted as "X.XX MB"

print(f"Parameters: {param_count:,}")
print(f"Model size: {model_size}")

# Freeze/unfreeze layers: a parameter matches when any given string is a
# substring of its qualified name. For a TransformerEncoder, 'pos_encoding'
# owns only the `pe` buffer and no parameters, so that entry matches nothing.
ModelUtils.freeze_layers(model, ['embedding', 'pos_encoding'])
ModelUtils.unfreeze_layers(model, ['transformer'])

# Weight initialization
ModelUtils.initialize_weights(model, init_type='xavier')