Deployment Guide¶
This guide covers deploying genrec models in production environments.
Production Deployment¶
Model Serving with FastAPI¶
Create a REST API server for model inference:
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from typing import List, Optional
import torch
from genrec.models.tiger import Tiger
from genrec.models.rqvae import RqVae
app = FastAPI(title="genrec API", version="1.0.0")
class RecommendationRequest(BaseModel):
user_id: int
user_history: List[int]
num_recommendations: int = 10
class RecommendationResponse(BaseModel):
user_id: int
recommendations: List[int]
scores: Optional[List[float]] = None
class ModelService:
def __init__(self, rqvae_path: str, tiger_path: str):
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Load models
self.rqvae = RqVae.load_from_checkpoint(rqvae_path)
self.rqvae.to(self.device)
self.rqvae.eval()
self.tiger = Tiger.load_from_checkpoint(tiger_path)
self.tiger.to(self.device)
self.tiger.eval()
def get_recommendations(self, user_history: List[int], k: int) -> List[int]:
"""Generate recommendations for user"""
with torch.no_grad():
# Convert item IDs to semantic IDs
semantic_sequence = self.items_to_semantic_ids(user_history)
# Generate recommendations
input_seq = torch.tensor(semantic_sequence).unsqueeze(0).to(self.device)
generated = self.tiger.generate(input_seq, max_length=k*3) # Generate more to account for duplicates
# Convert back to item IDs and deduplicate
recommendations = self.semantic_ids_to_items(generated.squeeze().tolist())
# Remove items already in user history
recommendations = [item for item in recommendations if item not in user_history]
return recommendations[:k]
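    def items_to_semantic_ids(self, items: List[int]) -> List[int]:
        # Placeholder: map raw item IDs to the flattened semantic-ID sequence,
        # assuming a lookup table precomputed offline with the trained RQ-VAE.
        # Replace this sketch with your own mapping logic.
        raise NotImplementedError

    def semantic_ids_to_items(self, semantic_ids: List[int]) -> List[int]:
        # Placeholder: inverse lookup from generated semantic IDs back to item IDs.
        raise NotImplementedError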
# Initialize model service
model_service = ModelService(
rqvae_path="checkpoints/rqvae.ckpt",
tiger_path="checkpoints/tiger.ckpt"
)
@app.post("/recommend", response_model=RecommendationResponse)
async def recommend(request: RecommendationRequest):
"""Generate recommendations for a user"""
try:
recommendations = model_service.get_recommendations(
request.user_history,
request.num_recommendations
)
return RecommendationResponse(
user_id=request.user_id,
recommendations=recommendations
)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.get("/health")
async def health_check():
"""Health check endpoint"""
return {"status": "healthy"}
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=8000)
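Once the server is running, a quick smoke test can be issued from a small client script (the host, port, and payload values below are illustrative):
import requests

payload = {
    "user_id": 42,
    "user_history": [101, 205, 317],
    "num_recommendations": 5,
}
response = requests.post("http://localhost:8000/recommend", json=payload)
response.raise_for_status()
print(response.json())  # {"user_id": 42, "recommendations": [...], "scores": null}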
Docker Deployment¶
Create a Dockerfile:
FROM python:3.9-slim
WORKDIR /app
# Install dependencies
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt
# Copy application code
COPY . .
# Expose port
EXPOSE 8000
# Run the application
CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "8000"]
Build and run the container:
# Build the image
docker build -t generative-recommenders:latest .
# Run the container
docker run -d -p 8000:8000 \
-v /path/to/checkpoints:/app/checkpoints \
generative-recommenders:latest
Kubernetes Deployment¶
Create Kubernetes manifests:
# deployment.yaml
apiVersion: apps/v1
kind: Deployment
metadata:
name: generative-recommenders
spec:
replicas: 3
selector:
matchLabels:
app: generative-recommenders
template:
metadata:
labels:
app: generative-recommenders
spec:
containers:
- name: api
image: generative-recommenders:latest
ports:
- containerPort: 8000
env:
- name: MODEL_PATH
value: "/models"
volumeMounts:
- name: model-storage
mountPath: /models
resources:
requests:
memory: "2Gi"
cpu: "1000m"
limits:
memory: "4Gi"
cpu: "2000m"
volumes:
- name: model-storage
persistentVolumeClaim:
claimName: model-pvc
---
# service.yaml
apiVersion: v1
kind: Service
metadata:
name: generative-recommenders-service
spec:
selector:
app: generative-recommenders
ports:
- protocol: TCP
port: 80
targetPort: 8000
type: LoadBalancer
Deploy to Kubernetes:
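# Apply the manifests (file names match the comments above)
kubectl apply -f deployment.yaml
kubectl apply -f service.yaml

# Confirm the pods are running
kubectl get pods -l app=generative-recommenders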
Batch Processing¶
Apache Spark Integration¶
Process large datasets with Spark:
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, udf, explode
from pyspark.sql.types import ArrayType, IntegerType
import torch
from genrec.models.tiger import Tiger
def create_spark_session():
return SparkSession.builder \
.appName("genrec") \
.config("spark.sql.adaptive.enabled", "true") \
.config("spark.sql.adaptive.coalescePartitions.enabled", "true") \
.getOrCreate()
def broadcast_model(spark, model_path):
"""Broadcast model to all workers"""
model = Tiger.load_from_checkpoint(model_path)
model.eval()
return spark.sparkContext.broadcast(model)
def batch_recommend_udf(broadcast_model):
"""UDF for batch recommendations"""
@udf(returnType=ArrayType(IntegerType()))
def recommend(user_history):
model = broadcast_model.value
with torch.no_grad():
# Convert to tensor
input_seq = torch.tensor(user_history).unsqueeze(0)
# Generate recommendations
recommendations = model.generate(input_seq, max_length=20)
return recommendations.squeeze().tolist()
return recommend
# Main processing
spark = create_spark_session()
model_broadcast = broadcast_model(spark, "checkpoints/tiger.ckpt")
# Load user data
user_data = spark.read.parquet("s3://data/user_interactions")
# Generate recommendations
recommend_func = batch_recommend_udf(model_broadcast)
recommendations = user_data.withColumn(
"recommendations",
recommend_func(col("interaction_history"))
)
# Save results
recommendations.write.mode("overwrite").parquet("s3://output/recommendations")
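Downstream consumers often expect one row per (user, item) pair; the explode function imported above can flatten the recommendation arrays. The column names and output path in this sketch are assumptions about your schema:
flat = recommendations.select(
    col("user_id"),
    explode(col("recommendations")).alias("item_id")
)
flat.write.mode("overwrite").parquet("s3://output/recommendations_flat")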
Apache Airflow Pipeline¶
Create a recommendation pipeline:
from airflow import DAG
from airflow.operators.python import PythonOperator
from datetime import datetime, timedelta
default_args = {
'owner': 'data-team',
'depends_on_past': False,
'start_date': datetime(2024, 1, 1),
'email_on_failure': True,
'email_on_retry': False,
'retries': 1,
'retry_delay': timedelta(minutes=5)
}
dag = DAG(
'genrec_pipeline',
default_args=default_args,
description='Daily recommendation generation',
schedule_interval='@daily',
catchup=False
)
def extract_user_data(**context):
"""Extract user interaction data"""
# Implementation here
pass
def generate_recommendations(**context):
"""Generate recommendations using TIGER model"""
# Implementation here
pass
def upload_recommendations(**context):
"""Upload recommendations to recommendation service"""
# Implementation here
pass
# Define tasks
extract_task = PythonOperator(
task_id='extract_user_data',
python_callable=extract_user_data,
dag=dag
)
recommend_task = PythonOperator(
task_id='generate_recommendations',
python_callable=generate_recommendations,
dag=dag
)
upload_task = PythonOperator(
task_id='upload_recommendations',
python_callable=upload_recommendations,
dag=dag
)
# Set dependencies
extract_task >> recommend_task >> upload_task
Monitoring and Observability¶
Prometheus Metrics¶
Add metrics to your FastAPI application:
from prometheus_client import Counter, Histogram, generate_latest, CONTENT_TYPE_LATEST
from fastapi import Response
import time
# Metrics
REQUEST_COUNT = Counter('recommendations_requests_total', 'Total recommendation requests')
REQUEST_LATENCY = Histogram('recommendations_request_duration_seconds', 'Request latency')
ERROR_COUNT = Counter('recommendations_errors_total', 'Total errors')
@app.middleware("http")
async def add_metrics(request, call_next):
start_time = time.time()
REQUEST_COUNT.inc()
try:
response = await call_next(request)
return response
except Exception as e:
ERROR_COUNT.inc()
raise
finally:
REQUEST_LATENCY.observe(time.time() - start_time)
@app.get("/metrics")
async def metrics():
"""Prometheus metrics endpoint"""
    return Response(generate_latest(), media_type=CONTENT_TYPE_LATEST)
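To have Prometheus collect these metrics, point a scrape job at the /metrics endpoint. A minimal scrape configuration might look like the following; the target address is an assumption about where the API is reachable (here, the Kubernetes service defined earlier):
scrape_configs:
  - job_name: "genrec-api"
    metrics_path: /metrics
    static_configs:
      - targets: ["generative-recommenders-service:80"]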
Logging Configuration¶
Set up structured logging:
import logging
import json
from datetime import datetime
class JSONFormatter(logging.Formatter):
def format(self, record):
log_entry = {
"timestamp": datetime.utcnow().isoformat(),
"level": record.levelname,
"message": record.getMessage(),
"module": record.module,
"function": record.funcName,
"line": record.lineno
}
if hasattr(record, 'user_id'):
log_entry['user_id'] = record.user_id
if hasattr(record, 'request_id'):
log_entry['request_id'] = record.request_id
return json.dumps(log_entry)
# Configure the root logger to emit JSON
handler = logging.StreamHandler()
handler.setFormatter(JSONFormatter())
logging.basicConfig(level=logging.INFO, handlers=[handler])
logger = logging.getLogger(__name__)
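With this in place, contextual fields can be attached to individual log calls via extra, which the formatter picks up (the values here are purely illustrative):
logger.info(
    "Recommendation served",
    extra={"user_id": 42, "request_id": "req-123"}
)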
Performance Optimization¶
Model Quantization¶
Reduce model size and inference time:
import torch
import torch.quantization as quantization
def quantize_model(model, example_inputs):
"""Quantize model for faster inference"""
# Prepare model for quantization
model.qconfig = quantization.get_default_qconfig('fbgemm')
model_prepared = quantization.prepare(model, inplace=False)
# Calibrate with example inputs
model_prepared.eval()
with torch.no_grad():
model_prepared(example_inputs)
# Convert to quantized model
model_quantized = quantization.convert(model_prepared, inplace=False)
return model_quantized
# Usage
example_input = torch.randint(0, 1000, (1, 50))
quantized_tiger = quantize_model(tiger_model, example_input)
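Static quantization requires the calibration pass shown above. For transformer-style models, post-training dynamic quantization of the linear layers is often a simpler starting point; a minimal sketch:
def quantize_dynamic_model(model):
    """Quantize linear layers to int8 weights at load time"""
    return torch.quantization.quantize_dynamic(
        model, {torch.nn.Linear}, dtype=torch.qint8
    )

dynamically_quantized_tiger = quantize_dynamic_model(tiger_model)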
ONNX Export¶
Export models to ONNX for cross-platform deployment:
def export_to_onnx(model, example_input, output_path):
"""Export PyTorch model to ONNX"""
model.eval()
torch.onnx.export(
model,
example_input,
output_path,
export_params=True,
opset_version=11,
do_constant_folding=True,
input_names=['input_ids'],
output_names=['output'],
dynamic_axes={
'input_ids': {0: 'batch_size', 1: 'sequence'},
'output': {0: 'batch_size', 1: 'sequence'}
}
)
# Export models
export_to_onnx(tiger_model, example_input, "models/tiger.onnx")
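The exported model can then be served with ONNX Runtime. A minimal inference sketch, assuming onnxruntime is installed and the model takes a batch of token IDs named input_ids as exported above:
import numpy as np
import onnxruntime as ort

session = ort.InferenceSession("models/tiger.onnx", providers=["CPUExecutionProvider"])
input_ids = np.random.randint(0, 1000, size=(1, 50), dtype=np.int64)
outputs = session.run(["output"], {"input_ids": input_ids})
print(outputs[0].shape)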
TensorRT Optimization¶
Optimize models for NVIDIA GPUs:
import tensorrt as trt
def convert_onnx_to_tensorrt(onnx_path, engine_path, max_batch_size=32):
"""Convert ONNX model to TensorRT engine"""
logger = trt.Logger(trt.Logger.WARNING)
builder = trt.Builder(logger)
network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
parser = trt.OnnxParser(network, logger)
# Parse ONNX model
with open(onnx_path, 'rb') as model:
if not parser.parse(model.read()):
for error in range(parser.num_errors):
print(parser.get_error(error))
return None
# Build engine
config = builder.create_builder_config()
config.max_workspace_size = 1 << 30 # 1GB
profile = builder.create_optimization_profile()
profile.set_shape("input_ids", (1, 1), (max_batch_size, 512), (max_batch_size, 512))
config.add_optimization_profile(profile)
engine = builder.build_engine(network, config)
# Save engine
with open(engine_path, 'wb') as f:
f.write(engine.serialize())
return engine
A/B Testing Framework¶
Experiment Configuration¶
Set up A/B testing for model comparisons:
import hashlib
from typing import Dict, Any
class ABTestFramework:
def __init__(self, experiments: Dict[str, Any]):
self.experiments = experiments
def get_variant(self, user_id: int, experiment_name: str) -> str:
"""Get user's variant for an experiment"""
if experiment_name not in self.experiments:
return "control"
experiment = self.experiments[experiment_name]
# Use consistent hashing for user assignment
hash_input = f"{user_id}_{experiment_name}_{experiment['salt']}"
hash_value = int(hashlib.md5(hash_input.encode()).hexdigest(), 16)
bucket = hash_value % 100
cumulative_traffic = 0
for variant, traffic in experiment['variants'].items():
cumulative_traffic += traffic
if bucket < cumulative_traffic:
return variant
return "control"
def is_user_in_experiment(self, user_id: int, experiment_name: str) -> bool:
"""Check if user is in experiment"""
if experiment_name not in self.experiments:
return False
experiment = self.experiments[experiment_name]
if not experiment.get('active', False):
return False
# Check eligibility criteria
if 'eligibility' in experiment:
# Implement eligibility logic
pass
return True
# Example experiment configuration
experiments_config = {
"model_comparison": {
"active": True,
"salt": "experiment_salt_123",
"variants": {
"control": 50, # 50% traffic
"new_model": 50 # 50% traffic
},
"eligibility": {
"min_interactions": 10
}
}
}
ab_tester = ABTestFramework(experiments_config)
@app.post("/recommend")
async def recommend_with_ab_test(request: RecommendationRequest):
"""Generate recommendations with A/B testing"""
variant = ab_tester.get_variant(request.user_id, "model_comparison")
    if variant == "new_model":
        # new_model_service is assumed to be a second ModelService loaded with the candidate checkpoints
        recommendations = new_model_service.get_recommendations(
            request.user_history, request.num_recommendations
        )
else:
# Use control model
recommendations = model_service.get_recommendations(
request.user_history, request.num_recommendations
)
# Log experiment data
logger.info("Recommendation served", extra={
"user_id": request.user_id,
"variant": variant,
"experiment": "model_comparison"
})
return RecommendationResponse(
user_id=request.user_id,
recommendations=recommendations
)
Security Considerations¶
API Authentication¶
Add JWT authentication:
from fastapi import Depends, HTTPException, status
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
import jwt  # PyJWT
import os

# Token signing secret; the environment variable name here is illustrative
SECRET_KEY = os.environ["JWT_SECRET_KEY"]

security = HTTPBearer()
def verify_token(credentials: HTTPAuthorizationCredentials = Depends(security)):
"""Verify JWT token"""
try:
payload = jwt.decode(
credentials.credentials,
SECRET_KEY,
algorithms=["HS256"]
)
return payload
except jwt.PyJWTError:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid authentication credentials"
)
@app.post("/recommend")
async def recommend(
request: RecommendationRequest,
token_data: dict = Depends(verify_token)
):
"""Protected recommendation endpoint"""
# Verify user access
if token_data.get("user_id") != request.user_id:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Access denied"
)
return await model_service.get_recommendations(request)
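For local testing, a token can be minted with the same secret and algorithm using PyJWT; in production, tokens would come from your authentication provider:
token = jwt.encode({"user_id": 42}, SECRET_KEY, algorithm="HS256")
# Send it as a bearer token: Authorization: Bearer <token>
print(token)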
Input Validation¶
Validate and sanitize inputs:
from pydantic import validator
class RecommendationRequest(BaseModel):
user_id: int
user_history: List[int]
num_recommendations: int = 10
@validator('user_id')
def validate_user_id(cls, v):
if v <= 0:
raise ValueError('User ID must be positive')
return v
@validator('user_history')
def validate_user_history(cls, v):
if len(v) > 1000: # Limit history length
raise ValueError('User history too long')
if any(item <= 0 for item in v):
raise ValueError('Invalid item IDs in history')
return v
@validator('num_recommendations')
def validate_num_recommendations(cls, v):
if not 1 <= v <= 100:
raise ValueError('Number of recommendations must be between 1 and 100')
return v
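These validators reject malformed requests before they reach the model. As a quick sanity check (using pydantic's ValidationError), an out-of-range request fails fast:
from pydantic import ValidationError

try:
    RecommendationRequest(user_id=-1, user_history=[1, 2, 3], num_recommendations=5)
except ValidationError as exc:
    print(exc)  # "User ID must be positive"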
This deployment guide covers the essential aspects of deploying genrec models in production, from basic API serving to advanced optimization and monitoring techniques.