Model Registry and Versioning

Centralized Model Lifecycle Management


1. Model Registry: A Centralized Model Catalog

What Is a Model Registry?

A model registry is a centralized repository that acts as the single "source of truth" for all of an organization's ML models. It provides:

  • Single catalog: every registered model lives in one place
  • Versioning: a complete history of how each model has evolved
  • Metadata: metrics, hyperparameters, artifacts, and data lineage
  • Lifecycle management: stages (staging → production → archived)
  • Access control: who may modify or deploy models
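To make these five responsibilities concrete, here is a deliberately minimal in-memory sketch. `ToyModelRegistry`, its methods, and the `deployers` permission set are hypothetical illustrations, not MLflow's API; the stage names mirror MLflow's four stages, and the rule "promoting a version archives the previous Production version" is an assumption of this sketch. A real registry persists this state in a database and enforces permissions on the server.

```python
from dataclasses import dataclass, field

STAGES = ("None", "Staging", "Production", "Archived")


@dataclass
class ModelVersion:
    version: int
    run_id: str
    metrics: dict
    stage: str = "None"


@dataclass
class ToyModelRegistry:
    """Hypothetical in-memory registry; illustrates the concepts, not MLflow's API."""
    deployers: set = field(default_factory=set)  # access control: who may change stages
    models: dict = field(default_factory=dict)   # catalog: model name -> list of versions

    def register(self, name: str, run_id: str, metrics: dict) -> ModelVersion:
        # Versioning + metadata: every registration appends a new numbered version
        versions = self.models.setdefault(name, [])
        mv = ModelVersion(version=len(versions) + 1, run_id=run_id, metrics=dict(metrics))
        versions.append(mv)
        return mv

    def transition(self, user: str, name: str, version: int, stage: str) -> None:
        # Access control: only authorized users may promote or archive versions
        if user not in self.deployers:
            raise PermissionError(f"{user!r} may not change stages")
        if stage not in STAGES:
            raise ValueError(f"unknown stage: {stage!r}")
        target = next((mv for mv in self.models[name] if mv.version == version), None)
        if target is None:
            raise LookupError(f"{name} has no version {version}")
        if stage == "Production":
            # Lifecycle (assumed rule): keep one Production version, archive the old one
            for mv in self.models[name]:
                if mv.stage == "Production" and mv is not target:
                    mv.stage = "Archived"
        target.stage = stage

    def latest(self, name: str, stage: str) -> ModelVersion:
        # Lookup by stage, e.g. what a serving system would load
        candidates = [mv for mv in self.models.get(name, []) if mv.stage == stage]
        if not candidates:
            raise LookupError(f"no {stage} version of {name}")
        return max(candidates, key=lambda mv: mv.version)


registry = ToyModelRegistry(deployers={"ml-lead"})
for run_id, auc in [("run-1", 0.73), ("run-2", 0.81), ("run-3", 0.85)]:
    registry.register("churn_predictor", run_id, {"auc": auc})

registry.transition("ml-lead", "churn_predictor", 2, "Production")
registry.transition("ml-lead", "churn_predictor", 3, "Production")
print(registry.latest("churn_predictor", "Production").version)  # 3
```

After the second promotion, version 2 is archived and version 3 is the one a serving system would load; a user outside `deployers` gets a `PermissionError`.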

Model Registry Architecture

graph TD
    A[Data Scientists] -->|Train & Log| B[MLflow Tracking]
    B -->|Register Model| C[Model Registry]

    C --> D[Model: churn_predictor]
    D --> E["Version 1<br/>Stage: Archived<br/>AUC: 0.73"]
    D --> F["Version 2<br/>Stage: Staging<br/>AUC: 0.81"]
    D --> G["Version 3<br/>Stage: Production<br/>AUC: 0.85"]
    G -->|Load for Serving| H[Production System]
    F -->|A/B Testing| I[Canary Deployment]
    C --> J["Artifact Store<br/>S3/HDFS"]
    J --> K[model.pkl]
    J --> L[requirements.txt]
    J --> M[config.yaml]

    style C fill:#06b6d4,stroke:#22d3ee,color:#fff
    style G fill:#10b981,stroke:#34d399,color:#fff
    style H fill:#ef4444,stroke:#f87171,color:#fff
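The diagram implies three ways of addressing a model: a pinned version, a stage, or the run that produced it. The tiny resolver below sketches that lookup. The `models:/<name>/<version>`, `models:/<name>/<stage>` and `runs:/<run_id>/<path>` URI shapes are the ones MLflow accepts; the `REGISTRY` dict, the run IDs, the `s3://mlflow-artifacts` root, and the simplified path layout are hypothetical (real MLflow artifact paths also encode the experiment).

```python
ARTIFACT_ROOT = "s3://mlflow-artifacts"  # hypothetical artifact store bucket

# Mirrors the diagram: version -> (stage, source run_id); run IDs are made up
REGISTRY = {
    "churn_predictor": {
        1: ("Archived", "run001"),
        2: ("Staging", "run002"),
        3: ("Production", "run003"),
    }
}


def resolve(uri: str, artifact_path: str = "churn_pipeline") -> str:
    """Translate a model URI into an artifact location (simplified layout)."""
    if uri.startswith("runs:/"):
        # runs:/<run_id>/<path>: points straight at a run's artifacts, no registry
        run_id, _, path = uri[len("runs:/"):].partition("/")
        return f"{ARTIFACT_ROOT}/{run_id}/artifacts/{path}"
    if uri.startswith("models:/"):
        name, _, ref = uri[len("models:/"):].partition("/")
        versions = REGISTRY[name]
        if ref.isdigit():
            # models:/<name>/<version>: pinned, always the same artifact
            _, run_id = versions[int(ref)]
        else:
            # models:/<name>/<stage>: follows whichever version holds that stage
            matches = [v for v, (stage, _) in versions.items() if stage == ref]
            if not matches:
                raise LookupError(f"no version of {name} in stage {ref}")
            _, run_id = versions[max(matches)]
        return f"{ARTIFACT_ROOT}/{run_id}/artifacts/{artifact_path}"
    raise ValueError(f"unsupported URI scheme: {uri}")


print(resolve("models:/churn_predictor/Production"))
# s3://mlflow-artifacts/run003/artifacts/churn_pipeline
```

The design point is that a stage URI is an indirection: promoting a different version changes what `models:/churn_predictor/Production` resolves to without touching the serving code.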

MLflow Model Registry

MLflow is a widely used open-source choice for a model registry. Its main components are:

1. Tracking Server

Records experiments, runs, metrics, and parameters during training.

2. Model Registry

A centralized catalog with versioning and lifecycle management.

3. Artifact Store

Stores binaries (models, datasets, plots) in a backend such as S3 or HDFS.

Code: Registering a Model with MLflow

import mlflow
import mlflow.spark
from mlflow.tracking import MlflowClient
from pyspark.ml import PipelineModel

# ============================================================
# STEP 1: Configure the MLflow Tracking Server
# ============================================================
mlflow.set_tracking_uri("http://mlflow-server:5000")  # Central tracking server
mlflow.set_experiment("/ml-experiments/churn-prediction")

# ============================================================
# STEP 2: Start a run and train the model
# ============================================================
with mlflow.start_run(run_name="churn_gbt_v3") as run:

    # Keep the run_id for later reference
    run_id = run.info.run_id
    print(f"Run ID: {run_id}")

    # Train the model (assumes `pipeline`, `df_train` and `df_val` were built earlier)
    pipeline_model = pipeline.fit(df_train)

    # ============================================================
    # STEP 3: Log hyperparameters
    # ============================================================
    # Extract hyperparameters from the fitted model
    gbt = pipeline_model.stages[-1]  # The last stage is the GBT model

    # getNumTrees is a property (not a method) on PySpark tree-ensemble models
    mlflow.log_param("num_trees", gbt.getNumTrees)
    mlflow.log_param("max_depth", gbt.getMaxDepth())
    mlflow.log_param("step_size", gbt.getStepSize())
    mlflow.log_param("subsampling_rate", gbt.getSubsamplingRate())

    # Log the full configuration as an artifact
    import json
    config = {
        "model_type": "GradientBoostedTrees",
        "num_trees": gbt.getNumTrees,
        "max_depth": gbt.getMaxDepth(),
        "step_size": gbt.getStepSize(),
        "subsampling_rate": gbt.getSubsamplingRate(),
        "max_bins": gbt.getMaxBins()
    }
    with open("model_config.json", "w") as f:
        json.dump(config, f, indent=2)
    mlflow.log_artifact("model_config.json")

    # ============================================================
    # STEP 4: Log evaluation metrics
    # ============================================================
    from pyspark.ml.evaluation import BinaryClassificationEvaluator

    predictions_val = pipeline_model.transform(df_val)

    evaluator_auc = BinaryClassificationEvaluator(
        labelCol="churn",
        metricName="areaUnderROC"
    )
    auc = evaluator_auc.evaluate(predictions_val)

    evaluator_pr = BinaryClassificationEvaluator(
        labelCol="churn",
        metricName="areaUnderPR"
    )
    auc_pr = evaluator_pr.evaluate(predictions_val)

    mlflow.log_metric("auc_roc", auc)
    mlflow.log_metric("auc_pr", auc_pr)
    mlflow.log_metric("train_size", df_train.count())
    mlflow.log_metric("val_size", df_val.count())

    print(f"AUC-ROC: {auc:.4f}")
    print(f"AUC-PR: {auc_pr:.4f}")

    # ============================================================
    # STEP 5: Log the Spark ML model
    # ============================================================
    # MLflow has a built-in flavor for Spark ML models
    mlflow.spark.log_model(
        spark_model=pipeline_model,
        artifact_path="churn_pipeline",  # Path inside the run's artifacts
        registered_model_name="churn_predictor"  # Registry name; each call adds a new version
    )

    print("✓ Model registered as 'churn_predictor'")

    # ============================================================
    # STEP 6: Log additional artifacts
    # ============================================================
    # Feature importance (if the model supports it)
    if hasattr(gbt, "featureImportances"):
        import pandas as pd
        # Placeholder names; substitute the real input column names if available
        feature_names = ["feature_" + str(i) for i in range(len(gbt.featureImportances))]
        importance_df = pd.DataFrame({
            "feature": feature_names,
            "importance": gbt.featureImportances.toArray()
        }).sort_values("importance", ascending=False)

        importance_df.to_csv("feature_importance.csv", index=False)
        mlflow.log_artifact("feature_importance.csv")

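    # Confusion matrix (labels=[0, 1] guarantees a 2x2 matrix even if a class is absent)
    from sklearn.metrics import confusion_matrix
    import matplotlib.pyplot as plt

    pred_pd = predictions_val.select("churn", "prediction").toPandas()
    # Spark ML stores labels and predictions as doubles; cast them to int
    cm = confusion_matrix(
        pred_pd["churn"].astype(int), pred_pd["prediction"].astype(int), labels=[0, 1]
    )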

    plt.figure(figsize=(8, 6))
    plt.imshow(cm, cmap="Blues")
    plt.colorbar()
    plt.title("Confusion Matrix")
    plt.xlabel("Predicted")
    plt.ylabel("Actual")
    for i in range(2):
        for j in range(2):
            plt.text(j, i, str(cm[i, j]), ha="center", va="center")
    plt.savefig("confusion_matrix.png")
    mlflow.log_artifact("confusion_matrix.png")
    plt.close()

    # ============================================================
    # STEP 7: Tags for search and filtering
    # ============================================================
    mlflow.set_tag("model_type", "GBT")
    mlflow.set_tag("problem_type", "binary_classification")
    mlflow.set_tag("dataset", "churn_v2024_01")
    mlflow.set_tag("production_ready", "true")

print("\n" + "="*60)
print("MODEL REGISTERED SUCCESSFULLY")
print("="*60)
print(f"Run ID: {run_id}")
print(f"Model URI: runs:/{run_id}/churn_pipeline")
print(f"Registry Name: churn_predictor")
print(f"AUC-ROC: {auc:.4f}")

2. Semantic Versioning for ML Models

The Major.Minor.Patch Scheme for Models

An adaptation of semantic versioning (SemVer) to the ML context:

$$ \text{Version} = \text{MAJOR}.\text{MINOR}.\text{PATCH} $$
  • MAJOR: a change to the model architecture or features that breaks compatibility (e.g., switching from GBT to a neural network, adding or removing features)
  • MINOR: a hyperparameter change or retraining on new data (backward compatible, same input schema)
  • PATCH: bug fixes, minor optimizations, or post-processing changes that do not significantly affect predictions
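The increment rule is mechanical: bumping a component resets every component to its right. A minimal sketch follows; `parse_version` and `bump_version` are hypothetical helper names. Parsing to integer tuples also matters when sorting, because plain string comparison would place "2.10.0" before "2.9.0".

```python
def parse_version(version: str) -> tuple:
    """'v2.1.0' or '2.1.0' -> (2, 1, 0); tuples compare numerically."""
    major, minor, patch = (int(part) for part in version.lstrip("v").split("."))
    return major, minor, patch


def bump_version(version: str, part: str) -> str:
    """Increment one SemVer component and reset the ones to its right."""
    major, minor, patch = parse_version(version)
    part = part.upper()
    if part == "MAJOR":  # breaking change: new inference contract
        return f"{major + 1}.0.0"
    if part == "MINOR":  # compatible retraining or hyperparameter change
        return f"{major}.{minor + 1}.0"
    if part == "PATCH":  # fix that keeps inputs and outputs unchanged
        return f"{major}.{minor}.{patch + 1}"
    raise ValueError(f"unknown version part: {part!r}")


print(bump_version("1.0.0", "major"))  # 2.0.0
print(bump_version("2.0.0", "minor"))  # 2.1.0
print(bump_version("2.1.0", "patch"))  # 2.1.1
print(sorted(["2.10.0", "2.9.0", "10.0.0"], key=parse_version))  # ['2.9.0', '2.10.0', '10.0.0']
```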

Versioning Examples

v1.0.0 → v2.0.0 (MAJOR)

  • Algorithm change: Random Forest → Neural Network
  • Feature change: 10 new features added
  • Output change: from a binary probability to multi-class
  • Impact: clients must update their inference pipelines

v2.0.0 → v2.1.0 (MINOR)

  • Retrained with 6 additional months of data
  • Hyperparameter tuning: maxDepth 10 → 15
  • AUC improvement: 0.82 → 0.85
  • Impact: drop-in replacement, same schema

v2.1.0 → v2.1.1 (PATCH)

  • Fixed a bug in age normalization (it divided by 10 instead of 100)
  • Code optimization (latency reduced from 20 ms to 15 ms)
  • Impact: transparent, no API changes
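The MAJOR/MINOR/PATCH rules and the examples above can be encoded as a small decision function, so the bump is derived from what changed rather than chosen by hand. This is a sketch under a hypothetical descriptor format: the keys `algorithm`, `features`, `output`, `hyperparams`, and `training_data` are illustrative, and a team would populate them from its own run parameters and tags. Feature lists are compared in order, on the assumption that column order is part of the input schema.

```python
MAJOR_KEYS = ("algorithm", "features", "output")  # the inference contract
MINOR_KEYS = ("hyperparams", "training_data")     # same contract, new fit


def classify_change(old: dict, new: dict) -> str:
    """Return the SemVer part to bump for a model change (hypothetical descriptor format)."""
    # MAJOR: new algorithm, a different feature set, or a different output
    if any(old[key] != new[key] for key in MAJOR_KEYS):
        return "MAJOR"
    # MINOR: same input/output schema, but new hyperparameters or training data
    if any(old[key] != new[key] for key in MINOR_KEYS):
        return "MINOR"
    # PATCH: descriptor unchanged, so only code-level fixes or optimizations remain
    return "PATCH"


v1_0_0 = {"algorithm": "RandomForest", "features": ["age", "tenure"], "output": "binary_proba",
          "hyperparams": {"max_depth": 10}, "training_data": "2023_h2"}
v2_0_0 = {**v1_0_0, "algorithm": "NeuralNetwork", "features": ["age", "tenure", "plan"],
          "output": "multiclass_proba", "hyperparams": {"hidden_units": 64}}
v2_1_0 = {**v2_0_0, "hyperparams": {"hidden_units": 128}, "training_data": "2024_h1"}
v2_1_1 = dict(v2_1_0)  # only the code was fixed; the descriptor is identical

print(classify_change(v1_0_0, v2_0_0))  # MAJOR
print(classify_change(v2_0_0, v2_1_0))  # MINOR
print(classify_change(v2_1_0, v2_1_1))  # PATCH
```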

Code: Managing Versions in MLflow

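import mlflow
import mlflow.spark
from mlflow.tracking import MlflowClient

# Assumes the tracking URI is already configured (e.g., via mlflow.set_tracking_uri)
client = MlflowClient()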

# ============================================================
# LIST ALL VERSIONS OF A MODEL
# ============================================================
model_name = "churn_predictor"

versions = client.search_model_versions(f"name='{model_name}'")

print(f"Model: {model_name}")
print(f"Total versions: {len(versions)}\n")

for mv in versions:
    print(f"Version {mv.version}:")
    print(f"  - Stage: {mv.current_stage}")
    print(f"  - Run ID: {mv.run_id}")
    print(f"  - Created: {mv.creation_timestamp}")
    print(f"  - Description: {mv.description}")
    print()

# Output:
# Model: churn_predictor
# Total versions: 5
#
# Version 5:
#   - Stage: Production
#   - Run ID: abc123...
#   - Created: 1704067200000
#   - Description: v2.1.0 - Retrained with Q4 2024 data, AUC 0.85
#
# Version 4:
#   - Stage: Archived
#   - Run ID: def456...
#   - Created: 1701388800000
#   - Description: v2.0.0 - New architecture with feature expansion

# ============================================================
# TRANSITION A MODEL TO PRODUCTION
# ============================================================
# Promote version 5 to Production.
# Note: model stages are deprecated in recent MLflow releases (2.9+) in favor of aliases.
client.transition_model_version_stage(
    name=model_name,
    version=5,
    stage="Production",
    archive_existing_versions=True  # Archive any version currently in Production
)

print("✓ Version 5 promoted to Production")
print("✓ Previous Production versions archived")

# ============================================================
# ADD A DESCRIPTION AND TAGS
# ============================================================
client.update_model_version(
    name=model_name,
    version=5,
    description="v2.1.0 - Significant performance improvement. "
                "Retrained with 6 additional months of data (Q4 2024). "
                "AUC improved from 0.82 to 0.85. "
                "Backward compatible with v2.0.x."
)

# Add tags for additional metadata
client.set_model_version_tag(
    name=model_name,
    version=5,
    key="semantic_version",
    value="2.1.0"
)

client.set_model_version_tag(
    name=model_name,
    version=5,
    key="training_data_version",
    value="churn_2024_q4"
)

client.set_model_version_tag(
    name=model_name,
    version=5,
    key="auc_roc",
    value="0.8523"
)

# ============================================================
# LOAD A SPECIFIC MODEL VERSION
# ============================================================
# By explicit version
model_v5 = mlflow.spark.load_model(f"models:/{model_name}/5")

# By stage (always loads whichever version is currently in Production)
model_prod = mlflow.spark.load_model(f"models:/{model_name}/Production")

# By run_id (bypasses the registry and points directly at the run's artifact)
model_by_run = mlflow.spark.load_model("runs:/abc123.../churn_pipeline")

# ============================================================
# COMPARAR VERSIONES
# ============================================================
def compare_model_versions(model_name, version1, version2):
    """Compare the metrics logged by the source runs of two model versions."""
    mv1 = client.get_model_version(model_name, version1)
    mv2 = client.get_model_version(model_name, version2)

    # Fetch the metrics from the run that produced each version
    run1 = client.get_run(mv1.run_id)
    run2 = client.get_run(mv2.run_id)

    metrics1 = run1.data.metrics
    metrics2 = run2.data.metrics

    print(f"Comparison: v{version1} vs v{version2}")
    print("="*50)
    for metric in metrics1.keys():
        if metric in metrics2:
            diff = metrics2[metric] - metrics1[metric]
            # Avoid division by zero when the baseline metric is 0
            pct = f"{diff / metrics1[metric] * 100:+.1f}%" if metrics1[metric] else "n/a"
            print(f"{metric}:")
            print(f"  v{version1}: {metrics1[metric]:.4f}")
            print(f"  v{version2}: {metrics2[metric]:.4f}")
            print(f"  Change: {diff:+.4f} ({pct})")
            print()

compare_model_versions("churn_predictor", 4, 5)

# Output:
# Comparison: v4 vs v5
# ==================================================
# auc_roc:
#   v4: 0.8234
#   v5: 0.8523
#   Change: +0.0289 (+3.5%)
#
# auc_pr:
#   v4: 0.7845
#   v5: 0.8012
#   Change: +0.0167 (+2.1%)
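The `creation_timestamp` values printed earlier (e.g., 1704067200000) are Unix epoch times in milliseconds, which MLflow uses for registry timestamps. The MLflow-independent helpers below turn them into readable UTC dates and list versions newest first. The names `ms_to_utc` and `summarize_versions` are hypothetical; the sort converts versions with `int()` so that version 10 sorts after version 9 whether versions arrive as strings or integers.

```python
from datetime import datetime, timezone


def ms_to_utc(ms: int) -> str:
    """Render an epoch-milliseconds timestamp as a UTC date string."""
    return datetime.fromtimestamp(ms / 1000, tz=timezone.utc).strftime("%Y-%m-%d %H:%M UTC")


def summarize_versions(rows):
    """rows: (version, stage, creation_timestamp_ms) tuples; returns lines, newest first."""
    ordered = sorted(rows, key=lambda row: int(row[0]), reverse=True)  # numeric order
    return [f"v{version}  {stage:<11} {ms_to_utc(ms)}" for version, stage, ms in ordered]


# Values copied from the search_model_versions output shown earlier
for line in summarize_versions([("4", "Archived", 1701388800000),
                                ("5", "Production", 1704067200000)]):
    print(line)
```

With a live client, the same helper could be fed with `(mv.version, mv.current_stage, mv.creation_timestamp) for mv in versions`.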