Data Warehouse Integration for AI

Executive Summary

Data warehouses and lakehouse platforms (Snowflake, BigQuery, Databricks) are the primary stores for the historical clinical and operational data that analytical AI use cases depend on: population health risk scoring, quality measure reporting, cost analysis, and retrospective clinical research. Integrating AI with these platforms requires understanding both the data access patterns (batch extraction, in-warehouse inference, vector search extensions) and the governance requirements that apply when AI accesses large-scale patient datasets. This chapter covers the integration patterns for connecting AI systems to enterprise data warehouse infrastructure.

Learning Objectives

  • Design data extraction pipelines from Snowflake, BigQuery, and Databricks for AI processing
  • Apply in-warehouse AI inference using platform-native ML functions to avoid data movement
  • Implement row-level security and data masking for PHI in warehoused clinical datasets
  • Choose between data movement to AI and AI movement to data based on scale and compliance constraints

Business Problem

A Reference Healthcare Organization's data warehouse contains 10 years of de-identified clinical encounter data for hundreds of thousands of patients — the dataset that powers population health analysis, quality measure calculation, and clinical research. An AI system that can query this dataset to identify patients at risk of readmission, predict ED surge demand, or surface care gaps across the patient population provides value that real-time EHR integration cannot — because it operates across the entire patient population, not just the current encounter.

The integration challenge is accessing this data at AI-relevant scale: extracting 100,000 patient records for a population health model cannot use the same FHIR API patterns as real-time clinical decision support.

Architecture

Snowflake Integration

python
import snowflake.connector
from contextlib import contextmanager
from typing import Iterator
import pandas as pd

# Educational example — not for clinical use

@contextmanager
def snowflake_connection(
    account: str,
    user: str,
    private_key_path: str,      # Key-pair auth (no passwords in config)
    warehouse: str,
    database: str,
    schema: str,
    role: str,
) -> Iterator[snowflake.connector.SnowflakeConnection]:
    """
    Create a Snowflake connection with key-pair authentication.
    
    Key-pair auth is preferred over password auth for service accounts:
    no password is stored in configuration, and Snowflake lets a user
    register two public keys (RSA_PUBLIC_KEY and RSA_PUBLIC_KEY_2), so
    keys can be rotated without downtime.
    
    Educational example — not for clinical use.
    """
    from cryptography.hazmat.primitives import serialization
    
    with open(private_key_path, "rb") as key_file:
        private_key = serialization.load_pem_private_key(
            key_file.read(),
            # password=None assumes an unencrypted key file; for an encrypted
            # key, load the passphrase from a secrets manager instead.
            password=None,
        )
    
    private_key_bytes = private_key.private_bytes(
        encoding=serialization.Encoding.DER,
        format=serialization.PrivateFormat.PKCS8,
        encryption_algorithm=serialization.NoEncryption(),
    )
    
    conn = snowflake.connector.connect(
        account=account,
        user=user,
        private_key=private_key_bytes,
        warehouse=warehouse,
        database=database,
        schema=schema,
        role=role,
    )
    
    try:
        yield conn
    finally:
        conn.close()


def extract_population_for_risk_scoring(
    conn: snowflake.connector.SnowflakeConnection,
    target_date: str,
) -> pd.DataFrame:
    """
    Extract one day's inpatient discharges for readmission risk scoring.
    
    Uses parameterized queries to prevent SQL injection.
    Relies on a ROW ACCESS POLICY attached to patient_encounter_facts
    (configured by warehouse administrators) to enforce data access
    restrictions inside the warehouse, without application-layer filtering.
    
    Educational example — not for clinical use.
    """
    query = """
        SELECT
            p.patient_id,
            p.encounter_id,          -- Required to key scores on write-back
            p.age_at_encounter,
            p.primary_diagnosis_code,
            p.comorbidity_count,
            p.prior_admissions_12_months,
            p.length_of_stay_days,
            p.discharge_disposition_code,
            -- De-identification flag: verify with privacy officer before use
            p.is_deidentified
        FROM patient_encounter_facts p
        WHERE
            p.discharge_date = %(target_date)s
            AND p.encounter_class = 'inpatient'
            AND p.is_deidentified = TRUE   -- Only de-identified records for AI batch scoring
        ORDER BY p.patient_id
    """
    
    # A single day's discharges is small enough to load at once; for larger
    # cohorts, iterate with cursor.fetchmany() instead of fetchall().
    with conn.cursor() as cursor:
        cursor.execute(query, {"target_date": target_date})
        columns = [desc[0].lower() for desc in cursor.description]
        rows = cursor.fetchall()
    
    return pd.DataFrame(rows, columns=columns)


def write_risk_scores_to_snowflake(
    conn: snowflake.connector.SnowflakeConnection,
    risk_scores: list[dict],
    model_version: str,
) -> int:
    """
    Write AI-generated risk scores back to Snowflake for downstream use.
    Includes model_version for reproducibility and audit.
    
    Educational example — not for clinical use.
    """
    if not risk_scores:
        return 0
    
    with conn.cursor() as cursor:
        # Recreate the session-scoped staging table on every call so rows
        # from an earlier batch in the same session cannot leak into this MERGE.
        cursor.execute(
            "CREATE OR REPLACE TEMPORARY TABLE risk_score_staging LIKE ai_risk_scores"
        )
        
        cursor.executemany(
            """
            INSERT INTO risk_score_staging
                (patient_id, encounter_id, risk_score, risk_tier, model_version, scored_at)
            VALUES (%(patient_id)s, %(encounter_id)s, %(risk_score)s, %(risk_tier)s,
                    %(model_version)s, CURRENT_TIMESTAMP())
            """,
            [{**score, "model_version": model_version} for score in risk_scores],
        )
        
        # Staging table + MERGE makes re-runs idempotent. model_version is part
        # of the match key, so a retrained model adds new rows instead of
        # overwriting the previous version's scores.
        cursor.execute("""
            MERGE INTO ai_risk_scores AS target
            USING risk_score_staging AS source
                ON target.patient_id = source.patient_id
                AND target.encounter_id = source.encounter_id
                AND target.model_version = source.model_version
            WHEN MATCHED THEN UPDATE SET
                risk_score = source.risk_score,
                risk_tier = source.risk_tier,
                scored_at = source.scored_at
            WHEN NOT MATCHED THEN INSERT
                (patient_id, encounter_id, risk_score, risk_tier, model_version, scored_at)
                VALUES (source.patient_id, source.encounter_id, source.risk_score,
                        source.risk_tier, source.model_version, source.scored_at)
        """)
    
    conn.commit()
    return len(risk_scores)
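
The idempotency claim behind the staging-plus-MERGE write-back is easy to check on a local database. The sketch below uses Python's built-in sqlite3 as a stand-in for Snowflake: SQLite's `INSERT ... ON CONFLICT DO UPDATE` upsert plays the role of `MERGE`, and the minimal `ai_risk_scores` schema, the `write_scores` helper, and the version labels are hypothetical. The sketch keys the upsert on (patient_id, encounter_id, model_version), so re-running a batch leaves the row count unchanged while a retrained model adds rows.

```python
import sqlite3

# Hypothetical minimal schema; model_version is part of the primary key so
# scores from different model versions can coexist.
SCHEMA = """
CREATE TABLE ai_risk_scores (
    patient_id    TEXT NOT NULL,
    encounter_id  TEXT NOT NULL,
    risk_score    REAL NOT NULL,
    risk_tier     TEXT NOT NULL,
    model_version TEXT NOT NULL,
    scored_at     TEXT NOT NULL DEFAULT CURRENT_TIMESTAMP,
    PRIMARY KEY (patient_id, encounter_id, model_version)
)
"""

# SQLite upsert: insert new keys, update existing ones (the MERGE analogue).
UPSERT = """
INSERT INTO ai_risk_scores
    (patient_id, encounter_id, risk_score, risk_tier, model_version)
VALUES (:patient_id, :encounter_id, :risk_score, :risk_tier, :model_version)
ON CONFLICT (patient_id, encounter_id, model_version) DO UPDATE SET
    risk_score = excluded.risk_score,
    risk_tier  = excluded.risk_tier,
    scored_at  = CURRENT_TIMESTAMP
"""


def write_scores(conn: sqlite3.Connection, scores: list[dict], model_version: str) -> int:
    """Upsert one batch of scores; safe to re-run with the same batch."""
    with conn:  # Commits on success, rolls back on error
        conn.executemany(UPSERT, [{**s, "model_version": model_version} for s in scores])
    return len(scores)


def row_count(conn: sqlite3.Connection) -> int:
    return conn.execute("SELECT COUNT(*) FROM ai_risk_scores").fetchone()[0]


conn = sqlite3.connect(":memory:")
conn.execute(SCHEMA)
batch = [
    {"patient_id": "P1", "encounter_id": "E1", "risk_score": 0.82, "risk_tier": "high"},
    {"patient_id": "P2", "encounter_id": "E2", "risk_score": 0.15, "risk_tier": "low"},
]
write_scores(conn, batch, "readmit-v1")
write_scores(conn, batch, "readmit-v1")   # Re-run of the same batch: no duplicates
print(row_count(conn))                    # 2
write_scores(conn, batch, "readmit-v2")   # Retrained model: rows added, v1 kept
print(row_count(conn))                    # 4
```

The same property holds in Snowflake only if the MERGE match key includes model_version; matching on patient and encounter alone would let a new model version overwrite the old one.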

BigQuery Integration

python
from google.cloud import bigquery
from typing import Iterator

# Educational example — not for clinical use

def stream_patient_cohort_for_quality_measures(
    project_id: str,
    dataset_id: str,
    measurement_year: int,
) -> Iterator[dict]:
    """
    Stream patient cohort from BigQuery for HEDIS quality measure AI extraction.
    
    Iterating the QueryJob returns a RowIterator that fetches results one
    page at a time, so the full cohort is never held in memory at once.
    
    Educational example — not for clinical use.
    """
    client = bigquery.Client(project=project_id)
    
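    # Filter values go through BigQuery query parameters. Project and dataset
    # names are identifiers, which cannot be parameterized, so they are
    # formatted in and must come from trusted configuration.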
    query = """
        SELECT
            patient_key,
            age_in_years,
            gender_code,
            primary_care_provider_npi,
            continuous_enrollment_flag,
            diabetes_flag,
            hypertension_flag
            -- Fields that could re-identify patients are deliberately not
            -- selected, even though this dataset is nominally de-identified.
        FROM `{project}.{dataset}.member_eligibility`
        WHERE
            measurement_year = @measurement_year
            AND continuous_enrollment_flag = TRUE
        ORDER BY patient_key
    """.format(project=project_id, dataset=dataset_id)
    
    job_config = bigquery.QueryJobConfig(
        query_parameters=[
            bigquery.ScalarQueryParameter("measurement_year", "INT64", measurement_year),
        ]
    )
    
    query_job = client.query(query, job_config=job_config)
    
    for row in query_job:
        yield dict(row)


def use_bigquery_ml_for_classification(
    project_id: str,
    dataset_id: str,
    patient_features_table: str,
    model_name: str,
) -> str:
    """
    Use BigQuery ML for in-warehouse inference.
    
    Avoids moving patient data out of BigQuery for scoring —
    the model runs inside BigQuery, results land in a BigQuery table.
    This is the preferred pattern when data cannot leave the warehouse
    for compliance reasons.
    
    Educational example — not for clinical use.
    """
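    import re
    
    # Identifiers cannot be query parameters, so validate them before they
    # are formatted into the SQL text (prevents injection via names).
    for name in (dataset_id, patient_features_table, model_name):
        if not re.fullmatch(r"[A-Za-z_][A-Za-z0-9_]*", name):
            raise ValueError(f"Invalid BigQuery identifier: {name!r}")
    
    client = bigquery.Client(project=project_id)
    
    # ML.PREDICT names its outputs after the model's label column. This sketch
    # assumes a logistic regression model trained on an INT64 label column
    # named `label` (1 = readmitted), which yields `predicted_label` and
    # `predicted_label_probs` (ARRAY<STRUCT<label, prob>>).
    # INSERT (not CREATE OR REPLACE) keeps earlier model versions' predictions
    # queryable; the ai_predictions table is assumed to exist already.
    query = f"""
        INSERT INTO `{project_id}.{dataset_id}.ai_predictions`
            (patient_key, predicted_readmission_probability, predicted_label,
             prediction_timestamp, model_version)
        SELECT
            patient_key,
            (SELECT p.prob FROM UNNEST(predicted_label_probs) AS p
             WHERE p.label = 1) AS predicted_readmission_probability,
            predicted_label,
            CURRENT_TIMESTAMP() AS prediction_timestamp,
            '{model_name}' AS model_version  -- Model name assumed to encode the version
        FROM ML.PREDICT(
            MODEL `{project_id}.{dataset_id}.{model_name}`,
            TABLE `{project_id}.{dataset_id}.{patient_features_table}`
        )
    """
    
    job = client.query(query)
    job.result()    # Wait for completion; raises if the query failed
    
    return f"{project_id}.{dataset_id}.ai_predictions"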
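
Because `stream_patient_cohort_for_quality_measures` yields rows one at a time, a consumer can group them into fixed-size batches so that only one batch is in memory while it is scored. The sketch below is one way to do that, assuming a hypothetical `score_batch` callable standing in for the model; `batched_rows` and `score_cohort` are illustrative names, and any iterator of dicts (including the BigQuery generator) can be passed in.

```python
from itertools import islice
from typing import Callable, Iterable, Iterator


def batched_rows(rows: Iterable[dict], batch_size: int) -> Iterator[list[dict]]:
    """Group an iterator of rows into lists of at most batch_size rows."""
    if batch_size < 1:
        raise ValueError("batch_size must be >= 1")
    it = iter(rows)
    while batch := list(islice(it, batch_size)):
        yield batch


def score_cohort(
    rows: Iterable[dict],
    score_batch: Callable[[list[dict]], list[float]],  # Hypothetical model call
    batch_size: int = 1000,
) -> Iterator[dict]:
    """Score rows batch by batch, yielding one small result dict per row."""
    for batch in batched_rows(rows, batch_size):
        scores = score_batch(batch)
        if len(scores) != len(batch):
            raise ValueError("model returned a different number of scores than rows")
        for row, score in zip(batch, scores):
            # Keep only the key and the score, not the full feature row.
            yield {"patient_key": row["patient_key"], "score": score}


def toy_model(batch: list[dict]) -> list[float]:
    """Stand-in for a real model: a fixed score per diabetes flag."""
    return [0.9 if r["diabetes_flag"] else 0.1 for r in batch]


# Usage with an in-memory generator standing in for the BigQuery rows:
fake_rows = ({"patient_key": f"K{i}", "diabetes_flag": i % 2 == 0} for i in range(5))
results = list(score_cohort(fake_rows, toy_model, batch_size=2))
print(len(results), results[0])   # 5 {'patient_key': 'K0', 'score': 0.9}
```

In a real pipeline the results would be written back in batches as well rather than collected with `list()`, which is done here only to inspect them.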

Enterprise Considerations

In-warehouse inference vs. data extraction: Moving 1M patient records out of the data warehouse for AI scoring is slow, expensive, and creates additional PHI exposure. Where possible, use in-warehouse AI capabilities (Snowflake Cortex, BigQuery ML, Databricks MLflow serving) to run inference inside the warehouse and land results in a warehouse table. The PHI never leaves the warehouse's access control boundary.
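
The shape of in-warehouse inference is a single `INSERT ... SELECT` that applies the model inside the SQL engine and lands results in a table, so the application never pulls feature rows out. The sketch below imitates that shape with sqlite3 by registering a toy Python scoring function as a SQL function through `Connection.create_function`; the table names, features, coefficients, and version label are hypothetical. SQLite runs in-process, so this illustrates only the query pattern, not the access-control boundary a real warehouse provides.

```python
import math
import sqlite3

MODEL_VERSION = "readmit-demo-v1"  # Hypothetical version label


def readmission_score(prior_admissions: int, comorbidities: int) -> float:
    """Toy logistic model with made-up coefficients (illustration only)."""
    z = -2.0 + 0.6 * prior_admissions + 0.3 * comorbidities
    return 1.0 / (1.0 + math.exp(-z))


conn = sqlite3.connect(":memory:")
# Expose the model as a SQL function so scoring happens inside the query.
conn.create_function("readmission_score", 2, readmission_score, deterministic=True)

conn.executescript("""
CREATE TABLE patient_features (
    patient_key TEXT PRIMARY KEY,
    prior_admissions_12_months INTEGER,
    comorbidity_count INTEGER
);
CREATE TABLE ai_predictions (
    patient_key TEXT,
    predicted_readmission_probability REAL,
    model_version TEXT
);
INSERT INTO patient_features VALUES ('K1', 0, 1), ('K2', 3, 4);
""")

# One statement reads features, scores them, and writes the results; the
# application code never fetches the feature rows itself.
conn.execute(
    """
    INSERT INTO ai_predictions
    SELECT patient_key,
           readmission_score(prior_admissions_12_months, comorbidity_count),
           ?
    FROM patient_features
    """,
    (MODEL_VERSION,),
)
conn.commit()

for row in conn.execute("SELECT * FROM ai_predictions ORDER BY patient_key"):
    print(row)
```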

Row-level security for de-identified datasets: Even "de-identified" data in a clinical warehouse requires access controls. Snowflake Row Access Policies and BigQuery row-level security allow the AI platform to access only the specific cohort it is authorized to score. Never grant the AI service account unrestricted access to the entire clinical warehouse.
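
A Snowflake row access policy can tie the AI service role to an explicit list of authorized cohorts through a mapping table. The sketch below holds the DDL as strings and applies it through any DB-API style cursor, such as one opened from `snowflake_connection`. The policy, the `ai_cohort_authorizations` mapping table, the `cohort_id` column on `patient_encounter_facts`, and the role and cohort names are all hypothetical, and the statements would normally be run once by a warehouse administrator rather than by the AI service itself (attaching a policy to a table that already has one fails).

```python
from typing import Mapping, Optional, Protocol

# Hypothetical mapping table: one row per (role, cohort) the role may read.
# Role names are stored as CURRENT_ROLE() returns them (normally upper case).
CREATE_MAPPING_TABLE = """
CREATE TABLE IF NOT EXISTS ai_cohort_authorizations (
    role_name VARCHAR NOT NULL,
    cohort_id VARCHAR NOT NULL
)
"""

# The policy returns TRUE only for rows whose cohort is authorized for the
# querying role; all other rows are filtered out of every query result.
CREATE_POLICY = """
CREATE ROW ACCESS POLICY IF NOT EXISTS ai_cohort_policy
AS (row_cohort_id VARCHAR) RETURNS BOOLEAN ->
    EXISTS (
        SELECT 1
        FROM ai_cohort_authorizations a
        WHERE a.role_name = CURRENT_ROLE()
          AND a.cohort_id = row_cohort_id
    )
"""

ATTACH_POLICY = """
ALTER TABLE patient_encounter_facts
    ADD ROW ACCESS POLICY ai_cohort_policy ON (cohort_id)
"""

# pyformat placeholders, the Snowflake connector's default paramstyle.
GRANT_COHORT = """
INSERT INTO ai_cohort_authorizations (role_name, cohort_id)
VALUES (%(role)s, %(cohort)s)
"""


class Cursor(Protocol):
    """Any DB-API style cursor, such as a Snowflake connector cursor."""

    def execute(self, sql: str, params: Optional[Mapping[str, object]] = None) -> object: ...


def apply_cohort_policy(cursor: Cursor, role: str, cohort: str) -> None:
    """Create the mapping table and policy, attach it, and authorize one cohort."""
    for statement in (CREATE_MAPPING_TABLE, CREATE_POLICY, ATTACH_POLICY):
        cursor.execute(statement)
    cursor.execute(GRANT_COHORT, {"role": role, "cohort": cohort})
```

Authorizing another cohort later is a single insert into the mapping table, with no change to the policy or the AI service's queries.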

Model versioning and reproducibility: AI risk scores written to the data warehouse must include the model version that produced them. When a model is retrained, historical scores from the old model version must remain queryable for comparison. Use a model_version column on all AI output tables.
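
With a model_version column, old and new scores sit side by side and can be compared directly. The sketch below uses sqlite3 as a stand-in for the warehouse; the table layout, version labels, and scores are hypothetical. The query pairs each encounter's score under two versions and flags encounters whose risk tier changed, the kind of check a team might run before promoting a retrained model.

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.executescript("""
CREATE TABLE ai_risk_scores (
    patient_id TEXT, encounter_id TEXT, risk_score REAL,
    risk_tier TEXT, model_version TEXT,
    PRIMARY KEY (patient_id, encounter_id, model_version)
);
INSERT INTO ai_risk_scores VALUES
    ('P1', 'E1', 0.82, 'high', 'v1'), ('P1', 'E1', 0.64, 'medium', 'v2'),
    ('P2', 'E2', 0.10, 'low',  'v1'), ('P2', 'E2', 0.12, 'low',    'v2');
""")

# Self-join on the encounter key, one side per model version.
COMPARE_VERSIONS = """
SELECT prev.encounter_id,
       prev.risk_score AS prev_score,
       curr.risk_score AS curr_score,
       prev.risk_tier <> curr.risk_tier AS tier_changed
FROM ai_risk_scores AS prev
JOIN ai_risk_scores AS curr
  ON curr.patient_id = prev.patient_id
 AND curr.encounter_id = prev.encounter_id
WHERE prev.model_version = ? AND curr.model_version = ?
ORDER BY prev.encounter_id
"""

rows = conn.execute(COMPARE_VERSIONS, ("v1", "v2")).fetchall()
for encounter_id, prev_score, curr_score, tier_changed in rows:
    print(encounter_id, prev_score, curr_score, bool(tier_changed))
# E1 0.82 0.64 True
# E2 0.1 0.12 False
```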

Common Mistakes

1. Loading the entire dataset into a pandas DataFrame. Ten years of encounter-level data for 500,000 patients, with dozens of columns per encounter, can easily exceed a typical application server's memory once loaded as a DataFrame. Use server-side SQL aggregation, BigQuery's paged row iterators, or Snowflake result set iteration instead.

2. Using SELECT * on clinical warehouse tables. Clinical warehouse tables often contain dozens of columns, many of which the AI model does not need. Always SELECT only the columns required by the model; this reduces data transfer, query cost, and PHI exposure surface.

3. Not including model_version in AI output tables. When the risk model is retrained (for example, on a monthly schedule), scores from the new model are written alongside scores from the old model. Without a model_version column, there is no way to distinguish them or to audit which score drove a clinical decision.
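
For mistake 1, the DB-API `fetchmany` call, which the Snowflake connector's cursor also implements, lets a pipeline pull a large result set in bounded chunks instead of calling `fetchall`. The sketch below runs against sqlite3 so it is self-contained; the `iter_rows` helper, the `encounters` table, and the chunk size are illustrative. When pandas output is wanted, the Snowflake connector also offers `fetch_pandas_batches()`.

```python
import sqlite3
from typing import Iterator, Mapping, Sequence, Union


def iter_rows(
    cursor,
    query: str,
    params: Union[Sequence, Mapping] = (),
    chunk_size: int = 10_000,
) -> Iterator[tuple]:
    """Yield result rows while holding at most chunk_size rows in memory."""
    cursor.execute(query, params)
    while True:
        chunk = cursor.fetchmany(chunk_size)
        if not chunk:
            break
        yield from chunk


conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE encounters (patient_id TEXT, length_of_stay_days INTEGER)")
conn.executemany(
    "INSERT INTO encounters VALUES (?, ?)",
    [(f"P{i}", i % 7) for i in range(25_000)],
)

# Process rows one at a time; only one chunk is materialized at any moment.
total_days = 0
n_rows = 0
for _, los in iter_rows(
    conn.cursor(),
    "SELECT patient_id, length_of_stay_days FROM encounters",
    chunk_size=5_000,
):
    total_days += los
    n_rows += 1
print(n_rows, total_days / n_rows)
```

The running total only keeps the example short; when an aggregate is all that is needed, computing it in SQL (for example with `AVG`) avoids transferring the rows at all.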

Key Takeaways

  • In-warehouse inference (Snowflake Cortex, BigQuery ML) is preferred over data extraction when data cannot leave the warehouse
  • Row-level security must be applied to clinical data warehouse tables accessed by AI services
  • All AI output tables must include model_version for reproducibility and audit
  • Stream large datasets row by row or in bounded batches rather than loading them whole into DataFrames
  • PHI in warehouse datasets requires the same HIPAA controls as PHI in transactional systems

Further Reading