## Model and Dataset Versioning Strategy
When working with machine learning models, especially large language models, proper versioning is crucial for reproducibility and tracking changes over time. Here's how to implement an effective versioning strategy:
### Setting Up a Model Registry
Create a structured model registry to track all your models and their versions:

```python
#!/usr/bin/env python3
"""
Model registry system for AI projects.
This script provides utilities for versioning and tracking models
and datasets across your AI projects.
"""
import os
import json
import shutil
import hashlib
import datetime
import argparse
from typing import Dict, List, Optional, Any
class ModelRegistry:
"""
Model registry for versioning and tracking ML models.
This class provides functionality to register, version, and
track machine learning models and their associated metadata.
"""
def __init__(self, registry_path: str = "model_registry"):
"""
Initialize the model registry.
Args:
registry_path (str): Path to the registry directory
"""
self.registry_path = registry_path
self.index_file = os.path.join(registry_path, "registry_index.json")
# Create registry directory if it doesn't exist
os.makedirs(registry_path, exist_ok=True)
# Initialize or load the registry index
if os.path.exists(self.index_file):
with open(self.index_file, 'r') as f:
self.registry_index = json.load(f)
else:
self.registry_index = {
"models": {},
"datasets": {},
"last_updated": datetime.datetime.now().isoformat()
}
self._save_index()
def _save_index(self):
"""Save the registry index to disk."""
self.registry_index["last_updated"] = datetime.datetime.now().isoformat()
with open(self.index_file, 'w') as f:
json.dump(self.registry_index, f, indent=2)
    def register_model(self, model_name: str, model_path: str,
                       version: Optional[str] = None, metadata: Optional[Dict] = None) -> str:
"""
Register a model in the registry.
Args:
model_name (str): Name of the model
model_path (str): Path to the model directory or file
version (str, optional): Version string (auto-generated if None)
metadata (dict, optional): Additional metadata
Returns:
str: The model version
"""
if not os.path.exists(model_path):
raise FileNotFoundError(f"Model path does not exist: {model_path}")
# Generate version if not provided
if version is None:
version = self._generate_version()
# Initialize model entry if it doesn't exist
if model_name not in self.registry_index["models"]:
self.registry_index["models"][model_name] = {
"versions": {},
"latest_version": None,
"created": datetime.datetime.now().isoformat()
}
# Create model directory in registry
model_dir = os.path.join(self.registry_path, "models", model_name, version)
os.makedirs(model_dir, exist_ok=True)
# Copy model files
if os.path.isdir(model_path):
# Copy directory contents
for item in os.listdir(model_path):
s = os.path.join(model_path, item)
d = os.path.join(model_dir, item)
if os.path.isdir(s):
shutil.copytree(s, d, dirs_exist_ok=True)
else:
shutil.copy2(s, d)
else:
# Copy single file
shutil.copy2(model_path, model_dir)
# Prepare metadata
if metadata is None:
metadata = {}
# Add standard metadata
metadata.update({
"registered_at": datetime.datetime.now().isoformat(),
"original_path": model_path,
"registry_path": model_dir
})
# Calculate model size and file hash
if os.path.isdir(model_path):
total_size = sum(os.path.getsize(os.path.join(dirpath, filename))
for dirpath, _, filenames in os.walk(model_path)
for filename in filenames)
metadata["model_size_bytes"] = total_size
else:
metadata["model_size_bytes"] = os.path.getsize(model_path)
metadata["file_hash"] = self._calculate_file_hash(model_path)
# Update registry index
self.registry_index["models"][model_name]["versions"][version] = metadata
self.registry_index["models"][model_name]["latest_version"] = version
self._save_index()
print(f"Model {model_name} version {version} registered successfully.")
return version
    def register_dataset(self, dataset_name: str, dataset_path: str,
                         version: Optional[str] = None, metadata: Optional[Dict] = None) -> str:
"""
Register a dataset in the registry.
Args:
dataset_name (str): Name of the dataset
dataset_path (str): Path to the dataset directory or file
version (str, optional): Version string (auto-generated if None)
metadata (dict, optional): Additional metadata
Returns:
str: The dataset version
"""
if not os.path.exists(dataset_path):
raise FileNotFoundError(f"Dataset path does not exist: {dataset_path}")
# Generate version if not provided
if version is None:
version = self._generate_version()
# Initialize dataset entry if it doesn't exist
if dataset_name not in self.registry_index["datasets"]:
self.registry_index["datasets"][dataset_name] = {
"versions": {},
"latest_version": None,
"created": datetime.datetime.now().isoformat()
}
# Create dataset directory in registry
dataset_dir = os.path.join(self.registry_path, "datasets", dataset_name, version)
os.makedirs(dataset_dir, exist_ok=True)
# Copy dataset files (or reference)
is_large_dataset = False
if os.path.isdir(dataset_path):
# Check dataset size
total_size = sum(os.path.getsize(os.path.join(dirpath, filename))
for dirpath, _, filenames in os.walk(dataset_path)
for filename in filenames)
# If dataset is larger than 1GB, just store reference
if total_size > 1_000_000_000: # 1GB
is_large_dataset = True
with open(os.path.join(dataset_dir, "dataset_reference.txt"), 'w') as f:
f.write(f"Original dataset path: {os.path.abspath(dataset_path)}\n")
f.write(f"Dataset size: {total_size / (1024**2):.2f} MB\n")
else:
# Copy directory contents
for item in os.listdir(dataset_path):
s = os.path.join(dataset_path, item)
d = os.path.join(dataset_dir, item)
if os.path.isdir(s):
shutil.copytree(s, d, dirs_exist_ok=True)
else:
shutil.copy2(s, d)
else:
# Copy single file
shutil.copy2(dataset_path, dataset_dir)
# Prepare metadata
if metadata is None:
metadata = {}
# Add standard metadata
metadata.update({
"registered_at": datetime.datetime.now().isoformat(),
"original_path": dataset_path,
"registry_path": dataset_dir,
"is_reference_only": is_large_dataset
})
# Calculate dataset statistics
if not is_large_dataset:
if os.path.isdir(dataset_path):
metadata["dataset_size_bytes"] = total_size
else:
metadata["dataset_size_bytes"] = os.path.getsize(dataset_path)
metadata["file_hash"] = self._calculate_file_hash(dataset_path)
# Update registry index
self.registry_index["datasets"][dataset_name]["versions"][version] = metadata
self.registry_index["datasets"][dataset_name]["latest_version"] = version
self._save_index()
print(f"Dataset {dataset_name} version {version} registered successfully.")
return version
def get_model(self, model_name: str, version: str = "latest") -> Optional[Dict]:
"""
Get information about a registered model.
Args:
model_name (str): Name of the model
version (str): Model version or "latest"
Returns:
dict or None: Model information
"""
if model_name not in self.registry_index["models"]:
print(f"Model {model_name} not found in registry.")
return None
if version == "latest":
version = self.registry_index["models"][model_name]["latest_version"]
if version not in self.registry_index["models"][model_name]["versions"]:
print(f"Version {version} of model {model_name} not found.")
return None
model_info = self.registry_index["models"][model_name]["versions"][version].copy()
model_info["name"] = model_name
model_info["version"] = version
return model_info
def get_dataset(self, dataset_name: str, version: str = "latest") -> Optional[Dict]:
"""
Get information about a registered dataset.
Args:
dataset_name (str): Name of the dataset
version (str): Dataset version or "latest"
Returns:
dict or None: Dataset information
"""
if dataset_name not in self.registry_index["datasets"]:
print(f"Dataset {dataset_name} not found in registry.")
return None
if version == "latest":
version = self.registry_index["datasets"][dataset_name]["latest_version"]
if version not in self.registry_index["datasets"][dataset_name]["versions"]:
print(f"Version {version} of dataset {dataset_name} not found.")
return None
dataset_info = self.registry_index["datasets"][dataset_name]["versions"][version].copy()
dataset_info["name"] = dataset_name
dataset_info["version"] = version
return dataset_info
def list_models(self) -> List[Dict]:
"""
List all registered models.
Returns:
list: List of model information dictionaries
"""
models = []
for model_name, model_data in self.registry_index["models"].items():
latest_version = model_data["latest_version"]
if latest_version:
model_info = self.get_model(model_name, latest_version)
models.append(model_info)
return models
def list_datasets(self) -> List[Dict]:
"""
List all registered datasets.
Returns:
list: List of dataset information dictionaries
"""
datasets = []
for dataset_name, dataset_data in self.registry_index["datasets"].items():
latest_version = dataset_data["latest_version"]
if latest_version:
dataset_info = self.get_dataset(dataset_name, latest_version)
datasets.append(dataset_info)
return datasets
def _generate_version(self) -> str:
"""
Generate a version string based on date and time.
Returns:
str: Version string (format: v{year}.{month}.{day}.{hour}{minute})
"""
now = datetime.datetime.now()
return f"v{now.year}.{now.month:02d}.{now.day:02d}.{now.hour:02d}{now.minute:02d}"
def _calculate_file_hash(self, file_path: str) -> str:
"""
Calculate SHA-256 hash of a file.
Args:
file_path (str): Path to the file
Returns:
str: SHA-256 hash hexadecimal string
"""
sha256_hash = hashlib.sha256()
with open(file_path, "rb") as f:
# Read the file in chunks to handle large files
for chunk in iter(lambda: f.read(4096), b""):
sha256_hash.update(chunk)
return sha256_hash.hexdigest()
def main():
"""Main function for the model registry utility."""
parser = argparse.ArgumentParser(description="Model and dataset versioning utility")
subparsers = parser.add_subparsers(dest="command", help="Command to execute")
# Register model command
register_model_parser = subparsers.add_parser("register-model", help="Register a model")
register_model_parser.add_argument("--name", required=True, help="Model name")
register_model_parser.add_argument("--path", required=True, help="Path to model directory or file")
register_model_parser.add_argument("--version", help="Version string (optional)")
register_model_parser.add_argument("--metadata", help="JSON metadata string (optional)")
# Register dataset command
register_dataset_parser = subparsers.add_parser("register-dataset", help="Register a dataset")
register_dataset_parser.add_argument("--name", required=True, help="Dataset name")
register_dataset_parser.add_argument("--path", required=True, help="Path to dataset directory or file")
register_dataset_parser.add_argument("--version", help="Version string (optional)")
register_dataset_parser.add_argument("--metadata", help="JSON metadata string (optional)")
# List models command
subparsers.add_parser("list-models", help="List all registered models")
# List datasets command
subparsers.add_parser("list-datasets", help="List all registered datasets")
# Get model command
get_model_parser = subparsers.add_parser("get-model", help="Get model information")
get_model_parser.add_argument("--name", required=True, help="Model name")
get_model_parser.add_argument("--version", default="latest", help="Model version (default: latest)")
# Get dataset command
get_dataset_parser = subparsers.add_parser("get-dataset", help="Get dataset information")
get_dataset_parser.add_argument("--name", required=True, help="Dataset name")
get_dataset_parser.add_argument("--version", default="latest", help="Dataset version (default: latest)")
# Parse arguments
args = parser.parse_args()
# Create registry
registry = ModelRegistry()
# Execute command
if args.command == "register-model":
metadata = json.loads(args.metadata) if args.metadata else None
registry.register_model(args.name, args.path, args.version, metadata)
elif args.command == "register-dataset":
metadata = json.loads(args.metadata) if args.metadata else None
registry.register_dataset(args.name, args.path, args.version, metadata)
elif args.command == "list-models":
models = registry.list_models()
print(json.dumps(models, indent=2))
elif args.command == "list-datasets":
datasets = registry.list_datasets()
print(json.dumps(datasets, indent=2))
elif args.command == "get-model":
model_info = registry.get_model(args.name, args.version)
if model_info:
print(json.dumps(model_info, indent=2))
elif args.command == "get-dataset":
dataset_info = registry.get_dataset(args.name, args.version)
if dataset_info:
print(json.dumps(dataset_info, indent=2))
else:
parser.print_help()
if __name__ == "__main__":
    main()
```
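For example, assuming the script above is saved as `model_registry.py` (the file name, model names, and paths below are placeholders), a typical workflow looks like this:

```bash
# Register a model and a dataset (versions are auto-generated when omitted)
python model_registry.py register-model --name chatbot --path ./checkpoints/chatbot \
    --metadata '{"base_model": "gemma-2b", "quantization": "4-bit"}'
python model_registry.py register-dataset --name support-tickets --path ./data/tickets

# Inspect the registry
python model_registry.py list-models
python model_registry.py get-model --name chatbot --version latest
```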
### Creating a Git-based Version Control System for Models
You can also use Git LFS (Large File Storage) to version your models directly in Git:

```bash
# Setup Git LFS repository
cat > setup_git_lfs.sh << 'EOF'
#!/bin/bash
# Script to set up Git LFS for model versioning
# Check if Git LFS is installed
if ! command -v git-lfs &> /dev/null; then
echo "Git LFS is not installed. Installing..."
brew install git-lfs
fi
# Initialize Git LFS
git lfs install
# Track model files
git lfs track "*.bin"
git lfs track "*.gguf"
git lfs track "*.pt"
git lfs track "*.pth"
git lfs track "*.onnx"
git lfs track "*.mlpackage"
git lfs track "models/**/*"
# Add .gitattributes
git add .gitattributes
git commit -m "Initialize Git LFS for model versioning"
# Create directory structure
mkdir -p models/checkpoints datasets/processed datasets/raw
# Create a README for the models directory
cat > models/README.md << 'README'
# Model Directory
This directory contains versioned ML models.
## Structure
- `checkpoints/`: Training checkpoints
- `production/`: Production-ready models
- `quantized/`: Quantized versions of models
## Versioning
Models are versioned using semantic versioning:
- Major version: Architectural changes
- Minor version: Training improvements
- Patch version: Bug fixes or small adjustments
Example: `chatbot-v1.2.3` represents:
- Version 1 architecture
- Training improvement iteration 2
- Bug fix or adjustment 3
## Usage
Each model directory contains:
- Model weights
- `config.json` with hyperparameters
- `metadata.json` with version info
- `metrics.json` with performance metrics
README
git add models/README.md
git commit -m "Add model directory structure and documentation"
echo "Git LFS setup complete for model versioning!"
EOF
chmod +x setup_git_lfs.sh
```
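The script is meant to be run once from the root of the repository you want to version (the `brew install` fallback assumes Homebrew on macOS):

```bash
./setup_git_lfs.sh
```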
### Semantic Versioning for Models
Create a utility to manage semantic versioning for your models:

```python
#!/usr/bin/env python3
"""
Semantic versioning utility for ML models.
This script helps manage semantic versioning for machine learning models,
following the MAJOR.MINOR.PATCH convention with specific meaning for ML.
"""
import os
import re
import json
import argparse
from datetime import datetime
class MLVersioning:
"""
Semantic versioning for machine learning models.
MAJOR: Architecture changes
MINOR: Training improvements
PATCH: Bug fixes and minor adjustments
"""
@staticmethod
def parse_version(version_str):
"""Parse version string into components."""
match = re.match(r'v?(\d+)\.(\d+)\.(\d+)', version_str)
if not match:
raise ValueError(f"Invalid version format: {version_str}")
major, minor, patch = map(int, match.groups())
return major, minor, patch
@staticmethod
def format_version(major, minor, patch):
"""Format version components into a string."""
return f"v{major}.{minor}.{patch}"
@staticmethod
def increment_version(version_str, level="patch"):
"""
Increment version at specified level.
Args:
version_str (str): Version string
level (str): Level to increment (major, minor, patch)
Returns:
str: New version string
"""
major, minor, patch = MLVersioning.parse_version(version_str)
if level == "major":
return MLVersioning.format_version(major + 1, 0, 0)
elif level == "minor":
return MLVersioning.format_version(major, minor + 1, 0)
elif level == "patch":
return MLVersioning.format_version(major, minor, patch + 1)
else:
raise ValueError(f"Invalid increment level: {level}")
@staticmethod
def update_model_version(model_dir, level="patch", metadata=None):
"""
Update model version in metadata file.
Args:
model_dir (str): Model directory
level (str): Level to increment
metadata (dict, optional): Additional metadata
Returns:
str: New version string
"""
metadata_file = os.path.join(model_dir, "metadata.json")
# Load existing metadata or create new
if os.path.exists(metadata_file):
with open(metadata_file, 'r') as f:
model_metadata = json.load(f)
current_version = model_metadata.get("version", "v0.0.0")
new_version = MLVersioning.increment_version(current_version, level)
else:
# Start with v0.1.0 for new models
if level == "major":
new_version = "v1.0.0"
elif level == "minor":
new_version = "v0.1.0"
else:
new_version = "v0.0.1"
model_metadata = {}
# Update metadata
model_metadata["version"] = new_version
model_metadata["last_updated"] = datetime.now().isoformat()
# Add additional metadata
if metadata:
model_metadata.update(metadata)
# Save metadata
with open(metadata_file, 'w') as f:
json.dump(model_metadata, f, indent=2)
return new_version
@staticmethod
def tag_model_version(model_dir, version=None):
"""
Create a Git tag for a model version.
Args:
model_dir (str): Model directory
version (str, optional): Version to tag
Returns:
bool: True if successful
"""
try:
# Get model name from directory
model_name = os.path.basename(model_dir)
# Get version from metadata if not provided
if not version:
metadata_file = os.path.join(model_dir, "metadata.json")
if not os.path.exists(metadata_file):
print(f"Metadata file not found: {metadata_file}")
return False
with open(metadata_file, 'r') as f:
metadata = json.load(f)
version = metadata.get("version")
if not version:
print("Version not found in metadata")
return False
# Create a tag name
tag_name = f"model-{model_name}-{version}"
# Create git tag
commit_msg = f"Model {model_name} version {version}"
os.system(f'git tag -a "{tag_name}" -m "{commit_msg}"')
print(f"Created git tag: {tag_name}")
return True
except Exception as e:
print(f"Error creating git tag: {e}")
return False
def main():
"""Main function for the versioning utility."""
parser = argparse.ArgumentParser(description="Semantic versioning for ML models")
subparsers = parser.add_subparsers(dest="command", help="Command to execute")
# Increment version command
increment_parser = subparsers.add_parser("increment", help="Increment model version")
increment_parser.add_argument("--model-dir", required=True, help="Model directory")
increment_parser.add_argument("--level", choices=["major", "minor", "patch"],
default="patch", help="Version level to increment")
increment_parser.add_argument("--metadata", help="Additional metadata as JSON string")
# Tag version command
tag_parser = subparsers.add_parser("tag", help="Create git tag for model version")
tag_parser.add_argument("--model-dir", required=True, help="Model directory")
tag_parser.add_argument("--version", help="Version string (read from metadata if not provided)")
# Parse arguments
args = parser.parse_args()
# Execute command
if args.command == "increment":
metadata = json.loads(args.metadata) if args.metadata else None
new_version = MLVersioning.update_model_version(args.model_dir, args.level, metadata)
print(f"Updated model version to {new_version}")
elif args.command == "tag":
MLVersioning.tag_model_version(args.model_dir, args.version)
else:
parser.print_help()
if __name__ == "__main__":
    main()
```
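Assuming the utility above is saved as `ml_versioning.py` (a placeholder name), bumping and tagging a model version looks like this:

```bash
# Record a training improvement, e.g. v1.2.3 -> v1.3.0, in models/chatbot/metadata.json
python ml_versioning.py increment --model-dir models/chatbot --level minor \
    --metadata '{"training_data": "tickets-v2"}'

# Create a git tag (e.g. model-chatbot-v1.3.0) from the version recorded in metadata.json
python ml_versioning.py tag --model-dir models/chatbot
```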
### Tracking Experiments and Results
Create a utility to track your model training experiments and results:

```python
#!/usr/bin/env python3
"""
Experiment tracking utility for ML models.
This script provides a simple system for tracking ML experiments
without requiring external dependencies or cloud services.
"""
import os
import json
import argparse
import shutil  # used by log_artifact to copy artifact files
import uuid
import platform
import psutil
from datetime import datetime
class ExperimentTracker:
"""Simple experiment tracker for machine learning models."""
def __init__(self, experiments_dir="experiments"):
"""
Initialize the experiment tracker.
Args:
experiments_dir (str): Directory to store experiment data
"""
self.experiments_dir = experiments_dir
os.makedirs(experiments_dir, exist_ok=True)
# Create runs subdirectory
self.runs_dir = os.path.join(experiments_dir, "runs")
os.makedirs(self.runs_dir, exist_ok=True)
# Load or create experiments index
self.index_file = os.path.join(experiments_dir, "experiments_index.json")
if os.path.exists(self.index_file):
with open(self.index_file, 'r') as f:
self.index = json.load(f)
else:
self.index = {
"experiments": {},
"last_updated": datetime.now().isoformat()
}
self._save_index()
def _save_index(self):
"""Save the experiments index to disk."""
self.index["last_updated"] = datetime.now().isoformat()
with open(self.index_file, 'w') as f:
json.dump(self.index, f, indent=2)
def create_experiment(self, name, description=None, metadata=None):
"""
Create a new experiment.
Args:
name (str): Experiment name
description (str, optional): Experiment description
metadata (dict, optional): Additional metadata
Returns:
str: Experiment ID
"""
experiment_id = str(uuid.uuid4())
# Create experiment entry
experiment = {
"id": experiment_id,
"name": name,
"description": description or "",
"created_at": datetime.now().isoformat(),
"runs": [],
"best_run": None,
"metadata": metadata or {}
}
# Save to index
self.index["experiments"][experiment_id] = experiment
self._save_index()
# Create experiment directory
experiment_dir = os.path.join(self.experiments_dir, experiment_id)
os.makedirs(experiment_dir, exist_ok=True)
# Save experiment info
with open(os.path.join(experiment_dir, "experiment.json"), 'w') as f:
json.dump(experiment, f, indent=2)
print(f"Created experiment: {name} (ID: {experiment_id})")
return experiment_id
def start_run(self, experiment_id, run_name=None, metadata=None):
"""
Start a new run for an experiment.
Args:
experiment_id (str): Experiment ID
run_name (str, optional): Run name
metadata (dict, optional): Additional metadata
Returns:
str: Run ID
"""
if experiment_id not in self.index["experiments"]:
raise ValueError(f"Experiment not found: {experiment_id}")
# Generate run ID and name
run_id = str(uuid.uuid4())
if run_name is None:
experiment_name = self.index["experiments"][experiment_id]["name"]
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
run_name = f"{experiment_name}_{timestamp}"
# System info
system_info = {
"platform": platform.platform(),
"python_version": platform.python_version(),
"processor": platform.processor(),
"memory_gb": psutil.virtual_memory().total / (1024**3)
}
# Create run entry
run = {
"id": run_id,
"experiment_id": experiment_id,
"name": run_name,
"status": "running",
"started_at": datetime.now().isoformat(),
"completed_at": None,
"duration_seconds": None,
"params": {},
"metrics": {},
"artifacts": [],
"system_info": system_info,
"metadata": metadata or {}
}
# Create run directory
run_dir = os.path.join(self.runs_dir, run_id)
os.makedirs(run_dir, exist_ok=True)
os.makedirs(os.path.join(run_dir, "artifacts"), exist_ok=True)
# Save run info
with open(os.path.join(run_dir, "run.json"), 'w') as f:
json.dump(run, f, indent=2)
# Update experiment
self.index["experiments"][experiment_id]["runs"].append(run_id)
self._save_index()
print(f"Started run: {run_name} (ID: {run_id})")
return run_id
def log_params(self, run_id, params):
"""
Log parameters for a run.
Args:
run_id (str): Run ID
params (dict): Parameters to log
"""
run_file = os.path.join(self.runs_dir, run_id, "run.json")
if not os.path.exists(run_file):
raise ValueError(f"Run not found: {run_id}")
# Load run data
with open(run_file, 'r') as f:
run = json.load(f)
# Update parameters
run["params"].update(params)
# Save run data
with open(run_file, 'w') as f:
json.dump(run, f, indent=2)
print(f"Logged {len(params)} parameters for run: {run_id}")
def log_metrics(self, run_id, metrics, step=None):
"""
Log metrics for a run.
Args:
run_id (str): Run ID
metrics (dict): Metrics to log
step (int, optional): Step number
"""
run_file = os.path.join(self.runs_dir, run_id, "run.json")
if not os.path.exists(run_file):
raise ValueError(f"Run not found: {run_id}")
# Load run data
with open(run_file, 'r') as f:
run = json.load(f)
# Create metrics file if needed
metrics_file = os.path.join(self.runs_dir, run_id, "metrics.json")
if os.path.exists(metrics_file):
with open(metrics_file, 'r') as f:
metrics_data = json.load(f)
else:
metrics_data = []
# Add step information
metrics_entry = {
"timestamp": datetime.now().isoformat(),
"metrics": metrics
}
if step is not None:
metrics_entry["step"] = step
metrics_data.append(metrics_entry)
# Save metrics data
with open(metrics_file, 'w') as f:
json.dump(metrics_data, f, indent=2)
# Update latest metrics in run
for metric_name, value in metrics.items():
run["metrics"][metric_name] = value
# Save run data
with open(run_file, 'w') as f:
json.dump(run, f, indent=2)
print(f"Logged {len(metrics)} metrics for run: {run_id}")
def log_artifact(self, run_id, artifact_path, name=None):
"""
Log an artifact for a run.
Args:
run_id (str): Run ID
artifact_path (str): Path to artifact file
name (str, optional): Artifact name
"""
run_file = os.path.join(self.runs_dir, run_id, "run.json")
if not os.path.exists(run_file):
raise ValueError(f"Run not found: {run_id}")
if not os.path.exists(artifact_path):
raise ValueError(f"Artifact path does not exist: {artifact_path}")
# Load run data
with open(run_file, 'r') as f:
run = json.load(f)
# Copy artifact
artifacts_dir = os.path.join(self.runs_dir, run_id, "artifacts")
artifact_name = name or os.path.basename(artifact_path)
artifact_dest = os.path.join(artifacts_dir, artifact_name)
if os.path.isdir(artifact_path):
shutil.copytree(artifact_path, artifact_dest, dirs_exist_ok=True)
else:
shutil.copy2(artifact_path, artifact_dest)
# Add to artifacts list
artifact_info = {
"name": artifact_name,
"path": artifact_dest,
"type": "directory" if os.path.isdir(artifact_path) else "file",
"added_at": datetime.now().isoformat()
}
run["artifacts"].append(artifact_info)
# Save run data
with open(run_file, 'w') as f:
json.dump(run, f, indent=2)
print(f"Logged artifact: {artifact_name} for run: {run_id}")
def complete_run(self, run_id, status="completed"):
"""
Mark a run as completed.
Args:
run_id (str): Run ID
status (str): Completion status (completed, failed, etc.)
"""
run_file = os.path.join(self.runs_dir, run_id, "run.json")
if not os.path.exists(run_file):
raise ValueError(f"Run not found: {run_id}")
# Load run data
with open(run_file, 'r') as f:
run = json.load(f)
# Calculate duration
started_at = datetime.fromisoformat(run["started_at"])
completed_at = datetime.now()
duration_seconds = (completed_at - started_at).total_seconds()
# Update run data
run["status"] = status
run["completed_at"] = completed_at.isoformat()
run["duration_seconds"] = duration_seconds
# Save run data
with open(run_file, 'w') as f:
json.dump(run, f, indent=2)
# Check if this is the best run for the experiment
experiment_id = run["experiment_id"]
if experiment_id in self.index["experiments"]:
experiment = self.index["experiments"][experiment_id]
# Update best run if this one has better metrics
# (Customize this logic based on your primary metric)
if "accuracy" in run["metrics"]:
current_best = None
if experiment["best_run"]:
best_run_file = os.path.join(self.runs_dir, experiment["best_run"], "run.json")
if os.path.exists(best_run_file):
with open(best_run_file, 'r') as f:
best_run = json.load(f)
if "accuracy" in best_run["metrics"]:
current_best = best_run["metrics"]["accuracy"]
if current_best is None or run["metrics"]["accuracy"] > current_best:
experiment["best_run"] = run_id
self._save_index()
print(f"Completed run: {run_id} with status: {status} (duration: {duration_seconds:.2f}s)")
def main():
"""Main function for the experiment tracking utility."""
parser = argparse.ArgumentParser(description="Track ML experiments and results")
subparsers = parser.add_subparsers(dest="command", help="Command to execute")
# Create experiment command
create_exp_parser = subparsers.add_parser("create-experiment", help="Create a new experiment")
create_exp_parser.add_argument("--name", required=True, help="Experiment name")
create_exp_parser.add_argument("--description", help="Experiment description")
create_exp_parser.add_argument("--metadata", help="Additional metadata as JSON string")
# Start run command
start_run_parser = subparsers.add_parser("start-run", help="Start a new run")
start_run_parser.add_argument("--experiment", required=True, help="Experiment ID")
start_run_parser.add_argument("--name", help="Run name")
start_run_parser.add_argument("--metadata", help="Additional metadata as JSON string")
# Log params command
log_params_parser = subparsers.add_parser("log-params", help="Log parameters")
log_params_parser.add_argument("--run", required=True, help="Run ID")
log_params_parser.add_argument("--params", required=True, help="Parameters as JSON string")
# Log metrics command
log_metrics_parser = subparsers.add_parser("log-metrics", help="Log metrics")
log_metrics_parser.add_argument("--run", required=True, help="Run ID")
log_metrics_parser.add_argument("--metrics", required=True, help="Metrics as JSON string")
log_metrics_parser.add_argument("--step", type=int, help="Step number")
# Log artifact command
log_artifact_parser = subparsers.add_parser("log-artifact", help="Log an artifact")
log_artifact_parser.add_argument("--run", required=True, help="Run ID")
log_artifact_parser.add_argument("--path", required=True, help="Path to artifact file")
log_artifact_parser.add_argument("--name", help="Artifact name")
# Complete run command
complete_run_parser = subparsers.add_parser("complete-run", help="Complete a run")
complete_run_parser.add_argument("--run", required=True, help="Run ID")
complete_run_parser.add_argument("--status", default="completed",
choices=["completed", "failed", "aborted"],
help="Completion status")
# Parse arguments
args = parser.parse_args()
# Initialize tracker
tracker = ExperimentTracker()
# Execute command
if args.command == "create-experiment":
metadata = json.loads(args.metadata) if args.metadata else None
tracker.create_experiment(args.name, args.description, metadata)
elif args.command == "start-run":
metadata = json.loads(args.metadata) if args.metadata else None
tracker.start_run(args.experiment, args.name, metadata)
elif args.command == "log-params":
params = json.loads(args.params)
tracker.log_params(args.run, params)
elif args.command == "log-metrics":
metrics = json.loads(args.metrics)
tracker.log_metrics(args.run, metrics, args.step)
elif args.command == "log-artifact":
tracker.log_artifact(args.run, args.path, args.name)
elif args.command == "complete-run":
tracker.complete_run(args.run, args.status)
else:
parser.print_help()
if __name__ == "__main__":
    main()
```
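Besides the CLI, the tracker can be driven directly from a training script. A minimal sketch, assuming the class above lives in `experiment_tracker.py` and using placeholder parameter and metric values:

```python
from experiment_tracker import ExperimentTracker

tracker = ExperimentTracker()
exp_id = tracker.create_experiment("lora-finetune", description="LoRA fine-tuning runs")
run_id = tracker.start_run(exp_id)

tracker.log_params(run_id, {"learning_rate": 1e-4, "batch_size": 8, "epochs": 3})
for step in range(3):
    # ... one training epoch would go here ...
    tracker.log_metrics(run_id, {"loss": 1.0 / (step + 1), "accuracy": 0.70 + 0.05 * step}, step=step)

tracker.complete_run(run_id)  # records duration and updates the experiment's best run
```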
By implementing these versioning strategies, you'll ensure that your models and datasets are properly tracked and versioned, making your ML projects more reproducible and maintainable over time. This approach is particularly valuable when working with large models on high-end Mac Studio and Mac Pro systems where you might be managing multiple versions of several substantial models.
The key benefits of this versioning strategy include:
- **Reproducibility**: You can recreate any model version exactly as it was
- **Traceability**: Every model change is documented and tracked
- **Collaboration**: Multiple team members can work on models with clear version history
- **Rollback Capability**: You can easily revert to previous model versions if needed
- **Systematic Experimentation**: Each training run is tracked with its parameters and results
These tools form a solid foundation for professional ML development on your Apple Silicon Mac.

As a preview of where this is heading, here is `scripts/server.py` from the comprehensive project we build later in this guide; it ties an MLX-based summarizer and a PyTorch-based image captioner together behind a single Flask interface:

```python
#!/usr/bin/env python3
"""
Simple web server for AI with Mac demonstration.

This script creates a Flask web application that showcases both MLX
(for text summarization) and PyTorch (for image captioning) in a single
interface, demonstrating how to combine different frameworks based on
their strengths.
"""
import os
import uuid
from flask import Flask, render_template, request, redirect, url_for

from summarize import Summarizer
from caption import ImageCaptioner

app = Flask(__name__, template_folder="../templates", static_folder="../static")

# Initialize models
summarizer = Summarizer()
captioner = ImageCaptioner()

# Create upload directories
os.makedirs("../uploads", exist_ok=True)


@app.route("/")
def index():
    """
    Render main page.

    Returns:
        rendered template: The index.html template
    """
    return render_template("index.html")


@app.route("/summarize", methods=["POST"])
def summarize_text():
    """
    Summarize text input using MLX.

    Processes text submitted via form and generates a summary using
    the MLX-based summarizer.

    Returns:
        rendered template: The result.html template with summary results
    """
    text = request.form.get("text", "")
    if not text:
        return render_template("index.html", error="Please enter some text to summarize.")
    summary = summarizer.summarize(text)
    return render_template("result.html",
                           result_type="summary",
                           original=text,
                           result=summary,
                           framework="MLX")


@app.route("/caption", methods=["POST"])
def caption_image():
    """
    Caption uploaded image using PyTorch with Metal.

    Processes an uploaded image and generates a descriptive caption
    using the PyTorch-based image captioner.

    Returns:
        rendered template: The result.html template with captioning results
    """
    if "image" not in request.files:
        return render_template("index.html", error="No image uploaded.")
    file = request.files["image"]
    if file.filename == "":
        return render_template("index.html", error="No image selected.")
    # Save uploaded image under a unique name
    filename = str(uuid.uuid4()) + os.path.splitext(file.filename)[1]
    filepath = os.path.join("../uploads", filename)
    file.save(filepath)
    # Generate caption
    caption = captioner.generate_caption(filepath)
    return render_template("result.html",
                           result_type="caption",
                           image_path="../uploads/" + filename,
                           result=caption,
                           framework="PyTorch with Metal")


if __name__ == "__main__":
    app.run(debug=True)
```

*This is the final part of our "AI with Mac" series. Check out the previous installments: [Part 1: Introduction](link-to-part1), [Part 2: Python and Git Setup](link-to-part2), [Part 3: Running LLMs](link-to-part3), and [Part 4: Comparing MLX and PyTorch](link-to-part4).*
## Introduction
Throughout this series, we've explored how Apple Silicon has revolutionized machine learning on Mac. We've set up our development environment, run large language models, and compared Apple's MLX framework with the more established PyTorch. Now, in our final installment, we'll bring everything together to help you make an informed decision about which approach to use for your specific ML needs.

*Figure 1: Comprehensive decision tree mapping different AI workloads, Mac hardware configurations, and use cases to the optimal framework and approach, providing a visual reference for the entire article.*
Rather than giving a one-size-fits-all answer, we'll explore different use cases and scenarios to help you choose the right tool for the job. We'll also build a practical application that combines our knowledge from previous posts, giving you a template for your own machine learning projects on Apple Silicon.
## Decision Framework: Choosing Your ML Path
Let's start by establishing a framework for deciding which approach to use based on your specific requirements:

*Figure 2: Flowchart showing decision paths based on Mac hardware type, ranging from MacBook Air to Mac Pro, with recommended frameworks, model sizes, and quantization strategies for each hardware tier.*
### Key Decision Factors
1. **Application Type**: What kind of ML task are you tackling?
2. **Hardware Capabilities**: What Mac hardware do you have? (MacBook, iMac, Mac Studio, Mac Pro)
3. **Memory Availability**: How much unified memory is available? (8GB to 512GB)
4. **Production Requirements**: Local-only or eventual deployment to other platforms?
5. **Development Timeline**: Quick prototype or long-term project?
6. **Team Experience**: Familiarity with specific frameworks?
7. **Performance Requirements**: Speed, memory, power constraints?
8. **Scale of Models**: Are you working with small (2-7B), medium (7-30B), or large (70B+) models?
### Decision Tree for Different Mac Configurations
```
What Mac hardware are you using?
├── MacBook Air/Pro (8-16GB RAM):
│   ├── Need maximum efficiency and small models
│   │   ├── Yes: MLX with 4-bit quantization
│   │   └── No: Either framework (consider battery life)
│
├── MacBook Pro/iMac (32-64GB RAM):
│   ├── Running medium-sized models (13-30B params)
│   │   ├── Inference focus: MLX
│   │   └── Training focus: PyTorch
│
└── Mac Studio/Mac Pro (128-512GB RAM):
    ├── Running largest models (70B+)
    │   ├── Single model inference: MLX
    │   ├── Multiple concurrent models: MLX
    │   └── Fine-tuning or training: PyTorch
    │
    ├── Research or production environment?
    │   ├── Research: Consider both frameworks
    │   └── Production: Depends on deployment target
```
The high-end Mac Studio and Mac Pro configurations with M2/M3/M4 Ultra chips (which effectively combine two M-series chips with up to 512GB of unified memory) can handle workloads previously requiring specialized server hardware or expensive GPUs, making them excellent choices for serious AI development.
Let's explore each branch of this decision tree in more detail.
## Use Case 1: Large Language Models (LLMs)
If you're primarily working with LLMs, your decision largely depends on your deployment requirements and performance needs.
### When to Choose MLX for LLMs
✅ **Recommended when**:
- You need maximum performance on Apple Silicon
- Local inference is your primary goal
- Privacy and offline operation are critical
- You're working with quantized models (4-bit, 8-bit)
- Battery efficiency matters (for MacBooks)
- You're using high-end Mac Studio/Pro with large unified memory (128-512GB)
⚠️ **Considerations**:
- Limited to Apple Silicon devices
- Smaller ecosystem of pre-built tools
- Fewer reference implementations
**Mac Studio/Pro Advantage**: The M2/M3/M4 Ultra chips in Mac Studio and Mac Pro combine two M-series chips with a massive unified memory pool (up to 512GB in top configurations), enabling you to run 70B+ parameter models at full precision or multiple large models simultaneously - something previously only possible with specialized server hardware.
### When to Choose PyTorch for LLMs
✅ **Recommended when**:
- Cross-platform compatibility is essential
- You need to use specific models not yet supported in MLX
- You're integrating with existing PyTorch pipelines
- Team familiarity with PyTorch is high
- You need advanced functionality beyond basic inference
⚠️ **Considerations**:
- Generally slower LLM inference compared to MLX
- Higher memory usage
- More complex setup for optimal performance
### LLM Application Example: Enhanced Document Q&A
Let's enhance our document Q&A system from Part 3 with more advanced features. This implementation using MLX demonstrates a practical application:
```python
#!/usr/bin/env python3
"""
Enhanced document Q&A system with vector search using MLX.
"""
import os
import re
import sys
import argparse
import numpy as np
import pypdf
from mlx_lm import generate, load
class EnhancedDocumentQA:
"""Enhanced document Q&A system with vector search."""
def __init__(self, model_path):
"""Initialize the system."""
print(f"Loading model from {model_path}, please wait...")
self.model, self.tokenizer = load(model_path)
print("Model loaded successfully!")
self.document_chunks = []
self.chunk_embeddings = []
def load_document(self, pdf_path):
"""Load and process a document."""
print(f"Reading document: {pdf_path}")
document_text = self._extract_text_from_pdf(pdf_path)
if not document_text:
print("Failed to extract text from the document.")
return False
# Split into semantically meaningful chunks
self.document_chunks = self._split_into_semantic_chunks(document_text)
print(f"Document split into {len(self.document_chunks)} chunks")
# Create embeddings for each chunk
self._create_embeddings()
return True
def answer_question(self, question):
"""Answer a question about the loaded document."""
if not self.document_chunks:
print("No document loaded. Please load a document first.")
return None
# Create embedding for the question
question_embedding = self._embed_text(question)
# Find most relevant chunks
relevant_chunks = self._find_relevant_chunks(question_embedding, top_k=3)
# Combine relevant chunks into context
context = "\n\n".join([self.document_chunks[idx] for idx in relevant_chunks])
# Create prompt for the model
prompt = f"""Answer the question based ONLY on the following context:
Context:
{context}
Question: {question}
Answer:"""
        # Generate the answer; mlx_lm's generate() takes the prompt string
        # directly and returns the generated text (sampling options such as
        # top_p vary across mlx_lm versions)
        answer = generate(
            self.model,
            self.tokenizer,
            prompt=prompt,
            max_tokens=500,
            temp=0.2,
        )
        return answer.strip()
def _extract_text_from_pdf(self, pdf_path):
"""Extract text from a PDF file."""
try:
text = ""
with open(pdf_path, "rb") as file:
reader = pypdf.PdfReader(file)
for page in reader.pages:
text += page.extract_text() + "\n"
return text
except Exception as e:
print(f"Error extracting text from PDF: {e}")
return None
def _split_into_semantic_chunks(self, text, max_chunk_size=1500):
"""Split text into semantically meaningful chunks."""
# Split by section headers or paragraphs
sections = re.split(r'(?=\n\s*[A-Z][^a-z]*\n)|(?=\n\n)', text)
chunks = []
current_chunk = ""
for section in sections:
# Clean the section
section = section.strip()
if not section:
continue
# If adding this section exceeds max size, start a new chunk
if len(current_chunk) + len(section) > max_chunk_size:
if current_chunk:
chunks.append(current_chunk)
current_chunk = section
else:
if current_chunk:
current_chunk += "\n\n" + section
else:
current_chunk = section
# Add the last chunk
if current_chunk:
chunks.append(current_chunk)
return chunks
def _create_embeddings(self):
"""Create embeddings for all chunks."""
print("Creating embeddings for document chunks...")
self.chunk_embeddings = []
for i, chunk in enumerate(self.document_chunks):
embedding = self._embed_text(chunk)
self.chunk_embeddings.append(embedding)
if (i + 1) % 10 == 0:
print(f"Processed {i + 1}/{len(self.document_chunks)} chunks")
def _embed_text(self, text):
"""Create an embedding for a piece of text using the model."""
# For simplicity, we'll use a basic method:
# 1. Tokenize the text
# 2. Get the token IDs
# 3. Create a normalized frequency vector
tokens = self.tokenizer.encode(text)
# Create a frequency vector of token IDs
vocab_size = self.tokenizer.vocab_size
embedding = np.zeros(vocab_size)
unique_tokens, counts = np.unique(tokens, return_counts=True)
for token, count in zip(unique_tokens, counts):
if token < vocab_size:
embedding[token] = count
# Normalize the embedding
norm = np.linalg.norm(embedding)
if norm > 0:
embedding = embedding / norm
return embedding
def _find_relevant_chunks(self, query_embedding, top_k=3):
"""Find the most relevant chunks for a query embedding."""
# Calculate cosine similarity
similarities = []
for chunk_embedding in self.chunk_embeddings:
similarity = np.dot(query_embedding, chunk_embedding)
similarities.append(similarity)
# Get indices of top-k chunks
return np.argsort(similarities)[-top_k:][::-1]
def main():
"""Main function."""
parser = argparse.ArgumentParser(description="Enhanced document Q&A system with vector search")
parser.add_argument("--model", type=str, default="models/gemma-2b-it-4bit",
help="Path to the model directory")
parser.add_argument("--pdf", type=str, required=True,
help="Path to the PDF document")
args = parser.parse_args()
if not os.path.exists(args.pdf):
print(f"Error: PDF file not found at {args.pdf}")
sys.exit(1)
qa_system = EnhancedDocumentQA(args.model)
if not qa_system.load_document(args.pdf):
print("Failed to load document.")
sys.exit(1)
print("\nEnhanced Document Q&A System (type 'exit' to quit)")
while True:
question = input("\nYour question: ")
if question.lower() in ["exit", "quit"]:
break
print("\nSearching document and generating answer...")
answer = qa_system.answer_question(question)
print(f"\nAnswer: {answer}")
if __name__ == "__main__":
    main()
```
This enhanced document Q&A system demonstrates:
- Semantic chunking of documents
- Basic vector search using embeddings
- Context-aware question answering
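Assuming the script is saved as `enhanced_qa.py` (a placeholder name) and you converted the 4-bit Gemma model as in Part 3, an interactive session starts with:

```bash
python enhanced_qa.py --pdf research_paper.pdf --model models/gemma-2b-it-4bit
```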
## Use Case 2: Computer Vision
For computer vision tasks, your decision depends on the specific application, model availability, and performance requirements.
### When to Choose MLX for Computer Vision
✅ **Recommended when**:
- You're developing custom models from scratch
- You need maximum performance on Apple Silicon
- You're working with smaller datasets
- Memory efficiency is critical
- You have simple model architectures
⚠️ **Considerations**:
- Fewer pre-trained models available
- Less comprehensive documentation
- May require reimplementing existing architectures
### When to Choose PyTorch for Computer Vision
✅ **Recommended when**:
- You need pre-trained models (ResNet, YOLO, etc.)
- You're working with complex model architectures
- You need advanced data augmentation pipelines
- You want to leverage transfer learning
- You need specialized CV libraries (torchvision, etc.)
⚠️ **Considerations**:
- Slightly lower performance than MLX
- Higher memory usage
- More complex setup for optimal Metal performance
### Computer Vision Example: Image Classification App
Let's create a practical image classification application using PyTorch with Metal acceleration. This example demonstrates how to load an efficient pre-trained model and use it for real-time classification:

```python
#!/usr/bin/env python3
"""
Real-time image classification using PyTorch with Metal acceleration.
"""
import os
import time
import argparse
import torch
import torchvision.transforms as transforms
import torchvision.models as models
from PIL import Image
class ImageClassifier:
"""Real-time image classifier using PyTorch with Metal acceleration."""
def __init__(self, model_name="mobilenet_v3_small"):
"""Initialize the classifier."""
# Check if Metal is available
if torch.backends.mps.is_available():
self.device = torch.device("mps")
print("Using Metal Performance Shaders (MPS)")
else:
self.device = torch.device("cpu")
print("Metal not available, using CPU")
# Load pre-trained model
print(f"Loading {model_name} model...")
if model_name == "mobilenet_v3_small":
self.model = models.mobilenet_v3_small(weights="DEFAULT")
elif model_name == "resnet18":
self.model = models.resnet18(weights="DEFAULT")
elif model_name == "efficientnet_b0":
self.model = models.efficientnet_b0(weights="DEFAULT")
else:
raise ValueError(f"Unsupported model: {model_name}")
self.model = self.model.to(self.device)
self.model.eval()
# Load ImageNet class labels
with open("data/imagenet_classes.txt", "r") as f:
self.class_names = [line.strip() for line in f.readlines()]
# Set up image transformation
self.transform = transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]
)
])
def classify_image(self, image_path):
"""Classify an image."""
# Load and transform image
try:
image = Image.open(image_path).convert("RGB")
except Exception as e:
print(f"Error loading image: {e}")
return None
input_tensor = self.transform(image).unsqueeze(0).to(self.device)
# Perform inference
start_time = time.time()
with torch.no_grad():
output = self.model(input_tensor)
inference_time = time.time() - start_time
# Get top-5 predictions
probabilities = torch.nn.functional.softmax(output[0], dim=0)
top5_prob, top5_indices = torch.topk(probabilities, 5)
# Convert to human-readable results
results = []
for i, (prob, idx) in enumerate(zip(top5_prob, top5_indices)):
class_name = self.class_names[idx]
results.append((class_name, prob.item()))
return results, inference_time
def ensure_imagenet_labels():
"""Ensure ImageNet class labels file exists."""
labels_file = "data/imagenet_classes.txt"
os.makedirs("data", exist_ok=True)
if not os.path.exists(labels_file):
print("Downloading ImageNet class labels...")
import urllib.request
url = "https://raw.githubusercontent.com/pytorch/vision/main/torchvision/models/imagenet_classes.txt"
urllib.request.urlretrieve(url, labels_file)
def main():
"""Main function."""
parser = argparse.ArgumentParser(description="Real-time image classification using PyTorch with Metal")
parser.add_argument("--model", type=str, default="mobilenet_v3_small",
choices=["mobilenet_v3_small", "resnet18", "efficientnet_b0"],
help="Model architecture to use")
parser.add_argument("--image", type=str, required=True,
help="Path to image for classification")
args = parser.parse_args()
ensure_imagenet_labels()
classifier = ImageClassifier(args.model)
print(f"\nClassifying image: {args.image}")
results, inference_time = classifier.classify_image(args.image)
print(f"\nResults (inference time: {inference_time*1000:.2f} ms):")
if results:
for i, (class_name, probability) in enumerate(results):
print(f"{i+1}. {class_name}: {probability*100:.2f}%")
else:
print("Failed to classify image.")
if __name__ == "__main__":
    main()
```
This example demonstrates:
- Using PyTorch with Metal acceleration for efficient inference
- Loading pre-trained models for immediate use
- Real-time image classification with performance metrics
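Assuming the script is saved as `classify.py` (a placeholder name), classification runs from the command line; the ImageNet labels file is downloaded automatically on first use:

```bash
python classify.py --image samples/dog.jpg --model resnet18
```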
## Use Case 3: Custom ML Models
For research projects and custom model development, your choice depends on your specific requirements and Mac configuration.
### When to Choose MLX for Custom Models
✅ **Recommended when**:
- You're developing exclusively for Apple Silicon
- You need fine-grained control over memory usage
- You prefer a NumPy-like API with JAX-inspired functions
- You want to leverage the unified memory architecture
- You're using high-end Mac Studio/Pro and want to maximize performance
⚠️ **Considerations**:
- Smaller community for troubleshooting
- Fewer examples and tutorials
- More manual implementation of training infrastructure
### When to Choose PyTorch for Custom Models
✅ **Recommended when**:
- You need access to advanced training features (distributed training, mixed precision, etc.)
- You're implementing research papers or working with academic code
- You want extensive visualization and debugging tools
- You need broad library support and extensions
- You're working across multiple hardware platforms beyond just Mac
⚠️ **Considerations**:
- Slightly less optimized for Apple Silicon
- More complex configuration for optimal performance
- Higher overhead for simple models
### Mac Studio/Pro Training Advantage
One of the most compelling use cases for high-end Mac Studio and Mac Pro systems is training custom models:
- **Unified Memory**: The 128GB to 512GB unified memory in top configurations allows for larger batch sizes
- **Dual-Chip Advantage**: M2/M3/M4 Ultra chips effectively combine two M-series chips, providing substantially more compute
- **Power Efficiency**: Train models locally without the extreme power requirements of traditional GPU setups
- **Cost Efficiency**: For certain training tasks, Mac Studio/Pro can replace cloud compute spending
While these systems still can't match dedicated training clusters, they provide a surprisingly powerful local alternative for medium-sized model training and fine-tuning.
### Custom Model Example: Time Series Forecasting
Let's create a simple time series forecasting model for stock price prediction using MLX. This demonstrates how to build a custom recurrent neural network:

```python
#!/usr/bin/env python3
"""
Time series forecasting with MLX.
This script demonstrates how to build a custom recurrent neural network
for time series forecasting using Apple's MLX framework, optimized for
Apple Silicon hardware.
"""
import os
import argparse
import csv
import numpy as np
import matplotlib.pyplot as plt
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
class SimpleRNN(nn.Module):
"""
Simple recurrent neural network for time series forecasting.
This RNN uses a single recurrent layer followed by a linear output layer
to predict the next value in a time series sequence.
"""
def __init__(self, input_size, hidden_size, output_size):
"""
Initialize the RNN model.
Args:
input_size (int): Size of input features (typically 1 for univariate time series)
hidden_size (int): Size of the hidden state
output_size (int): Size of the output (typically 1 for single-value prediction)
"""
super().__init__()
self.hidden_size = hidden_size
self.rnn_cell = nn.RNNCell(input_size, hidden_size)
self.linear = nn.Linear(hidden_size, output_size)
def __call__(self, x, hidden=None):
"""
Forward pass through the network.
Args:
x (mx.array): Input tensor of shape (batch_size, sequence_length, input_size)
hidden (mx.array, optional): Initial hidden state
Returns:
tuple: (output, hidden_state)
"""
# x shape: (batch_size, sequence_length, input_size)
batch_size, seq_len, _ = x.shape
if hidden is None:
hidden = mx.zeros((batch_size, self.hidden_size))
outputs = []
for t in range(seq_len):
hidden = self.rnn_cell(x[:, t, :], hidden)
outputs.append(hidden)
# Use the last hidden state for prediction
output = self.linear(outputs[-1])
return output, hidden
def load_stock_data(csv_file, target_column="Close"):
"""
Load stock price data from CSV.
Args:
csv_file (str): Path to the CSV file
target_column (str): Column name for the target price data
Returns:
tuple: (dates, prices) as lists
"""
dates = []
prices = []
with open(csv_file, 'r') as f:
reader = csv.DictReader(f)
for row in reader:
dates.append(row["Date"])
prices.append(float(row[target_column]))
return dates, prices
def prepare_timeseries_data(data, sequence_length=10):
"""
Prepare time series data for training.
Args:
data (list): List of price values
sequence_length (int): Length of input sequences
Returns:
tuple: (X, y, X_mean, X_std, y_mean, y_std) - normalized data and normalization parameters
"""
X, y = [], []
for i in range(len(data) - sequence_length):
X.append(data[i:i + sequence_length])
y.append(data[i + sequence_length])
# Convert to MLX arrays
X = mx.array(np.array(X, dtype=np.float32).reshape(-1, sequence_length, 1))
y = mx.array(np.array(y, dtype=np.float32).reshape(-1, 1))
# Normalize data
X_mean = mx.mean(X)
X_std = mx.std(X)
X = (X - X_mean) / X_std
y_mean = mx.mean(y)
y_std = mx.std(y)
y = (y - y_mean) / y_std
return X, y, X_mean, X_std, y_mean, y_std
def train_model(model, X, y, epochs=100, batch_size=32, learning_rate=0.01):
"""
Train the model.
Args:
model (SimpleRNN): The model to train
X (mx.array): Input data
y (mx.array): Target data
epochs (int): Number of training epochs
batch_size (int): Batch size for training
learning_rate (float): Learning rate for optimizer
Returns:
list: Training losses per epoch
"""
optimizer = optim.Adam(learning_rate=learning_rate)
num_samples = X.shape[0]
num_batches = num_samples // batch_size
# Define loss function
def loss_fn(model, X_batch, y_batch):
pred, _ = model(X_batch)
return mx.mean(mx.square(pred - y_batch))
# Define training step
def train_step(model, X_batch, y_batch):
        # nn.value_and_grad wraps loss_fn so it also returns gradients
        # with respect to the model's parameters
        loss, grads = nn.value_and_grad(model, loss_fn)(model, X_batch, y_batch)
optimizer.update(model, grads)
return loss
# Training loop
losses = []
for epoch in range(epochs):
# Shuffle data
indices = np.random.permutation(num_samples)
epoch_loss = 0
for i in range(num_batches):
batch_indices = indices[i * batch_size:(i + 1) * batch_size]
X_batch = X[batch_indices]
y_batch = y[batch_indices]
loss = train_step(model, X_batch, y_batch)
epoch_loss += loss.item()
epoch_loss /= num_batches
losses.append(epoch_loss)
if (epoch + 1) % 10 == 0:
print(f"Epoch {epoch + 1}/{epochs}, Loss: {epoch_loss:.6f}")
return losses
def main():
"""
Main function to run the time series forecasting application.
"""
parser = argparse.ArgumentParser(description="Time series forecasting with MLX")
parser.add_argument("--csv", type=str, required=True,
help="Path to CSV file with stock price data")
parser.add_argument("--column", type=str, default="Close",
help="Column name for price data")
parser.add_argument("--epochs", type=int, default=100,
help="Number of training epochs")
parser.add_argument("--sequence", type=int, default=10,
help="Sequence length for time series")
parser.add_argument("--predict", type=int, default=30,
help="Number of days to predict")
args = parser.parse_args()
# Load and prepare data
dates, prices = load_stock_data(args.csv, args.column)
X, y, X_mean, X_std, y_mean, y_std = prepare_timeseries_data(prices, args.sequence)
# Create and train model
model = SimpleRNN(input_size=1, hidden_size=32, output_size=1)
losses = train_model(model, X, y, epochs=args.epochs)
# Make predictions
test_data = prices[-args.sequence:]
predictions = []
current_sequence = mx.array(np.array(test_data, dtype=np.float32).reshape(1, args.sequence, 1))
current_sequence = (current_sequence - X_mean) / X_std
# Generate future predictions; the full window is re-fed on every step,
# so the hidden state is rebuilt each time rather than carried across steps
for _ in range(args.predict):
# Make prediction
pred, _ = model(current_sequence)
# Denormalize
pred_value = pred.item() * y_std.item() + y_mean.item()
predictions.append(pred_value)
# Update the window: drop the oldest element and append the prediction,
# mapped from y-normalized space into X-normalized space
# ((pred * y_std + y_mean - X_mean) / X_std, expanded below)
new_seq = mx.concatenate([current_sequence[:, 1:, :],
mx.array([[[pred.item()]]]) * y_std / X_std + (y_mean - X_mean) / X_std],
axis=1)
current_sequence = new_seq
# Plot results
plt.figure(figsize=(12, 6))
# Plot historical data
plt.plot(range(len(prices)), prices, label="Historical Data")
# Plot predictions
forecast_range = range(len(prices) - 1, len(prices) + args.predict - 1)
plt.plot(forecast_range, predictions, label="Forecast", color="red")
# Add vertical line at prediction start
plt.axvline(x=len(prices) - 1, color="gray", linestyle="--")
plt.title(f"Stock Price Forecast ({args.column})")
plt.xlabel("Time")
plt.ylabel("Price")
plt.legend()
plt.tight_layout()
# Save plot
output_file = f"stock_forecast_{os.path.basename(args.csv).split('.')[0]}.png"
plt.savefig(output_file)
print(f"Forecast saved to {output_file}")
# Print predictions
print("\nPredictions for the next", args.predict, "days:")
for i, pred in enumerate(predictions):
print(f"Day {i+1}: {pred:.2f}")
if __name__ == "__main__":
main()
This example demonstrates:
- Building a custom RNN model with MLX
- Processing time series data
- Making forecasts with trained models
- Visualizing predictions
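Assuming the complete script above is saved as scripts/time_series_forecast.py (that filename and the CSV path below are placeholders; adjust them to match your project), a typical run looks like:
python scripts/time_series_forecast.py --csv data/AAPL.csv --column Close --epochs 100 --predict 30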
Building a Comprehensive Project
Now, let's combine our knowledge to build a more comprehensive project that demonstrates how to choose between MLX and PyTorch based on the specific task. This project will include:
- Text summarization (using MLX)
- Image captioning (using PyTorch)
- A simple web interface to interact with both models
Figure 3: System architecture diagram showing how text summarization (MLX) and image captioning (PyTorch) components work together in our sample application, including data flow, API endpoints, and user interface components.
Project Structure
Create the following files and directories:
ai-with-mac/
├── models/
├── data/
├── static/
│ └── css/
│ └── style.css
├── templates/
│ ├── index.html
│ └── result.html
├── scripts/
│ ├── summarize.py
│ ├── caption.py
│ └── server.py
└── requirements.txt
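You can scaffold this layout in one step (bash brace expansion; the paths mirror the tree above):
mkdir -p ai-with-mac/{models,data,static/css,templates,scripts}
touch ai-with-mac/requirements.txt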
1. Text Summarization with MLX
Create the file scripts/summarize.py:
#!/usr/bin/env python3
"""
Text summarization using MLX.
This script provides a text summarization class that uses Apple's MLX framework
to run language models efficiently on Apple Silicon for generating concise summaries.
"""
import os
import mlx.core as mx
from mlx_lm import generate, load
class Summarizer:
"""
Text summarizer using MLX.
This class handles the loading of an MLX language model and uses it
to generate summaries of input text.
"""
def __init__(self, model_path="models/gemma-2b-it-4bit"):
"""
Initialize the summarizer with an MLX model.
Args:
model_path (str): Path to the directory containing the model files
"""
# Ensure model directory exists
os.makedirs(os.path.dirname(model_path), exist_ok=True)
# Check that the model exists; fail fast with download instructions if not
if not os.path.exists(model_path):
raise FileNotFoundError(
f"Model not found at {model_path}. Download it first with: "
f"python -m mlx_lm.convert --hf-path google/gemma-2b-it -q --mlx-path {model_path}"
)
print(f"Loading summarization model from {model_path}...")
self.model, self.tokenizer = load(model_path)
print("Summarization model loaded successfully!")
def summarize(self, text, max_length=200, temperature=0.3):
"""
Summarize the given text.
Args:
text (str): Text to summarize
max_length (int): Maximum length of the summary in tokens
temperature (float): Temperature parameter for generation (0.0-1.0)
Returns:
str: Generated summary
"""
prompt = f"""Please summarize the following text concisely:
Text: {text}
Summary:"""
# Generate the summary; mlx_lm's generate() takes the prompt string and
# returns only the newly generated text (recent mlx_lm versions pass
# sampling options such as temperature via a sampler argument instead)
summary = generate(
self.model,
self.tokenizer,
prompt=prompt,
max_tokens=max_length,
temp=temperature,
)
return summary.strip()
# Example usage
if __name__ == "__main__":
summarizer = Summarizer()
# Example text
text = """
Apple Silicon has fundamentally changed what's possible for machine learning on
consumer-grade hardware. With the M1, M2, M3, and now M4 chips, Mac users can run
sophisticated AI models locally with impressive performance. This unified memory
architecture eliminates the traditional bottleneck of data transfers between CPU and
GPU memory, allowing both processors to access the same physical memory seamlessly.
Whether you're a data scientist, AI researcher, app developer, or just an enthusiast
exploring machine learning, understanding your options on Apple Silicon is crucial for
maximizing performance and efficiency.
"""
summary = summarizer.summarize(text)
print("\nSummary:")
print(summary)
2. Image Captioning with PyTorch
Create the file scripts/caption.py:
#!/usr/bin/env python3
"""
Image captioning using PyTorch with Metal acceleration.
This script provides an image captioning class that uses PyTorch with
Apple's Metal API to efficiently run image captioning models on Apple Silicon.
"""
import os
import torch
import torchvision.transforms as transforms
from PIL import Image
class ImageCaptioner:
"""
Image captioner using PyTorch with Metal acceleration.
This class handles the loading of a pre-trained image captioning model
and uses Apple's Metal API for GPU acceleration when available.
"""
def __init__(self):
"""
Initialize the captioner with a pre-trained model.
"""
# Check if Metal is available
if torch.backends.mps.is_available():
self.device = torch.device("mps")
print("Using Metal Performance Shaders (MPS)")
else:
self.device = torch.device("cpu")
print("Metal not available, using CPU")
# Load pre-trained model
print("Loading image captioning model...")
self.model = torch.hub.load('saahiluppal/catr', 'v3', pretrained=True).to(self.device)
self.model.eval()
# Set up image transformation
self.transform = transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]
)
])
def generate_caption(self, image_path):
"""
Generate a caption for the image.
Args:
image_path (str): Path to the image file
Returns:
str: Generated caption or error message
"""
# Load and transform image
try:
image = Image.open(image_path).convert("RGB")
except Exception as e:
print(f"Error loading image: {e}")
return "Error loading image"
input_tensor = self.transform(image).unsqueeze(0).to(self.device)
# Generate caption
with torch.no_grad():
output = self.model(input_tensor)
caption = self.model.caption_generator.decode(output[0])
return caption
# Example usage
if __name__ == "__main__":
captioner = ImageCaptioner()
# Example image
image_path = "data/sample.jpg"
if os.path.exists(image_path):
caption = captioner.generate_caption(image_path)
print("\nCaption:")
print(caption)
else:
print(f"Image not found at {image_path}")
print(f"Error loading image: {e}")
return "Error loading image"
input_tensor = self.transform(image).unsqueeze(0).to(self.device)
# Generate caption
with torch.no_grad():
output = self.model(input_tensor)
caption = self.model.caption_generator.decode(output[0])
return caption
Example usage
if name == "main": captioner = ImageCaptioner()
# Example image
image_path = "data/sample.jpg"
if os.path.exists(image_path):
caption = captioner.generate_caption(image_path)
print("\nCaption:")
print(caption)
else:
print(f"Image not found at {image_path}")
3. Simple Web Server
Create the file scripts/server.py:
#!/usr/bin/env python3
"""
Simple web server for AI with Mac demonstration.
"""
import os
import uuid
from flask import Flask, render_template, request, redirect, url_for
from summarize import Summarizer
from caption import ImageCaptioner
app = Flask(__name__, template_folder="../templates", static_folder="../static")
# Initialize models
summarizer = Summarizer()
captioner = ImageCaptioner()
# Create the upload directory under static/ so Flask can serve the images
os.makedirs("../static/uploads", exist_ok=True)
@app.route("/")
def index():
"""Render main page."""
return render_template("index.html")
@app.route("/summarize", methods=["POST"])
def summarize_text():
"""Summarize text input."""
text = request.form.get("text", "")
if not text:
return render_template("index.html", error="Please enter some text to summarize.")
summary = summarizer.summarize(text)
return render_template("result.html",
result_type="summary",
original=text,
result=summary,
framework="MLX")
@app.route("/caption", methods=["POST"])
def caption_image():
"""Caption uploaded image."""
if "image" not in request.files:
return render_template("index.html", error="No image uploaded.")
file = request.files["image"]
if file.filename == "":
return render_template("index.html", error="No image selected.")
# Save the uploaded image under static/ so the result page can display it
filename = str(uuid.uuid4()) + os.path.splitext(file.filename)[1]
filepath = os.path.join("../static/uploads", filename)
file.save(filepath)
# Generate caption
caption = captioner.generate_caption(filepath)
return render_template("result.html",
result_type="caption",
image_path="../uploads/" + filename,
result=caption,
framework="PyTorch with Metal")
if __name__ == "__main__":
app.run(debug=True)
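Because server.py resolves templates/, static/, and the upload directory relative to its own location, start it from inside scripts/; the Flask development server listens on http://127.0.0.1:5000 by default:
cd scripts
python server.py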
4. HTML Templates
Create the file templates/index.html:
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>AI with Mac Demonstration</title>
<link rel="stylesheet" href="../static/css/style.css">
</head>
<body>
<header>
<h1>AI with Mac Demonstration</h1>
<p>Example project combining MLX and PyTorch</p>
</header>
<main>
{% if error %}
<div class="error">{{ error }}</div>
{% endif %}
<div class="cards">
<div class="card">
<h2>Text Summarization</h2>
<p>Using MLX for efficient inference</p>
<form action="/summarize" method="post">
<textarea name="text" rows="10" placeholder="Enter text to summarize..."></textarea>
<button type="submit">Summarize</button>
</form>
</div>
<div class="card">
<h2>Image Captioning</h2>
<p>Using PyTorch with Metal acceleration</p>
<form action="/caption" method="post" enctype="multipart/form-data">
<div class="file-upload">
<input type="file" name="image" accept="image/*">
</div>
<button type="submit">Generate Caption</button>
</form>
</div>
</div>
</main>
<footer>
<p>Part of the "AI with Mac" series</p>
</footer>
</body>
</html>
Create the file templates/result.html:
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>AI with Mac Demonstration - Result</title>
<link rel="stylesheet" href="../static/css/style.css">
</head>
<body>
<header>
<h1>AI with Mac Demonstration</h1>
<p>Example project combining MLX and PyTorch</p>
</header>
<main>
<div class="result-container">
<h2>{{ result_type|capitalize }} Result (using {{ framework }})</h2>
{% if result_type == "summary" %}
<div class="result-pair">
<div class="original">
<h3>Original Text</h3>
<div class="content">{{ original }}</div>
</div>
<div class="result">
<h3>Summary</h3>
<div class="content">{{ result }}</div>
</div>
</div>
{% elif result_type == "caption" %}
<div class="result-image">
<img src="{{ image_path }}" alt="Uploaded image">
<div class="caption">{{ result }}</div>
</div>
{% endif %}
</div>
<div class="back-link">
<a href="/">⬅ Back to demo</a>
</div>
</main>
<footer>
<p>Part of the "AI with Mac" series</p>
</footer>
</body>
</html>
5. CSS Styling
Create the file static/css/style.css:
/* Basic styling for the AI with Mac demo */
body {
font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, Helvetica, Arial, sans-serif;
line-height: 1.6;
color: #333;
max-width: 1200px;
margin: 0 auto;
padding: 20px;
background-color: #f5f5f7;
}
header {
text-align: center;
margin-bottom: 40px;
}
h1 {
color: #000;
}
.cards {
display: flex;
gap: 20px;
flex-wrap: wrap;
}
.card {
flex: 1;
min-width: 300px;
background-color: white;
border-radius: 12px;
box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
padding: 20px;
margin-bottom: 20px;
}
textarea {
width: 100%;
padding: 10px;
border: 1px solid #ddd;
border-radius: 6px;
font-family: inherit;
margin-bottom: 15px;
}
.file-upload {
border: 2px dashed #ddd;
padding: 20px;
text-align: center;
margin-bottom: 15px;
border-radius: 6px;
}
button {
background-color: #0071e3;
color: white;
border: none;
padding: 10px 20px;
border-radius: 6px;
cursor: pointer;
font-weight: 500;
}
button:hover {
background-color: #0062c2;
}
.error {
background-color: #ffebee;
color: #c62828;
padding: 10px;
border-radius: 6px;
margin-bottom: 20px;
}
.result-container {
background-color: white;
border-radius: 12px;
box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
padding: 20px;
margin-bottom: 30px;
}
.result-pair {
display: flex;
gap: 20px;
flex-wrap: wrap;
}
.original, .result {
flex: 1;
min-width: 300px;
}
.content {
background-color: #f9f9f9;
padding: 15px;
border-radius: 6px;
white-space: pre-wrap;
}
.result-image {
text-align: center;
}
.result-image img {
max-width: 100%;
max-height: 500px;
border-radius: 8px;
margin-bottom: 15px;
}
.caption {
font-size: 1.2em;
font-weight: 500;
margin-top: 10px;
}
.back-link {
margin-top: 20px;
}
.back-link a {
color: #0071e3;
text-decoration: none;
}
footer {
text-align: center;
margin-top: 40px;
padding-top: 20px;
border-top: 1px solid #ddd;
color: #888;
}
6. Requirements
Create the file requirements.txt:
flask
mlx
mlx-lm
torch
torchvision
pillow
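If you are starting from a clean checkout, install these into an isolated virtual environment first (the environment name .venv is just a convention):
python3 -m venv .venv
source .venv/bin/activate
pip install -r requirements.txt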
Making Informed Decisions
After exploring different approaches and building practical applications, here are some guidelines to help you make informed decisions about which framework to use for your ML projects on Apple Silicon:
Decision Checklist
Ask yourself these questions when starting a new ML project:
- What is my primary task?
  - LLM inference → MLX
  - Computer vision with pre-trained models → PyTorch
  - Custom research models → Either (depending on other factors)
- What are my hardware constraints?
  - Limited RAM → MLX (better memory efficiency)
  - Older Apple Silicon → Depends on the specific task
  - Newest Apple Silicon → Either (both perform well)
- What is my deployment target?
  - Apple Silicon only → MLX
  - Cross-platform → PyTorch
  - Mix of platforms → Consider a hybrid approach
- What is my timeline?
  - Quick prototype → Use the framework you're most familiar with
  - Long-term project → Worth investing time in learning MLX if appropriate
- What is my team's expertise?
  - Strong PyTorch background → Start with PyTorch, explore MLX gradually
  - New to both → MLX has a simpler API for certain tasks
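If it helps to make these tradeoffs explicit, the checklist can be distilled into a small helper. The sketch below is purely illustrative; its category names and the 16 GB RAM threshold are simplifying assumptions, not official guidance:
def choose_framework(task: str, ram_gb: int, apple_silicon_only: bool) -> str:
    """Illustrative heuristic distilled from the checklist above."""
    if task == "llm_inference":
        return "MLX"
    if task == "pretrained_vision":
        return "PyTorch"
    if not apple_silicon_only:
        return "PyTorch"  # cross-platform deployment favors PyTorch
    if ram_gb <= 16:
        return "MLX"  # better memory efficiency on limited RAM
    return "either"  # both perform well; decide on other factors

# Example: a 16 GB M-series laptop running local LLM inference
print(choose_framework("llm_inference", ram_gb=16, apple_silicon_only=True))  # MLX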
A Balanced Approach
For many projects, a balanced approach works best:
- Evaluate task-specific performance: Run benchmarks for your specific task (a minimal timing sketch follows this list)
- Consider implementation effort: Weigh development time vs. runtime performance
- Think about future maintenance: Consider documentation and community support
- Start small: Begin with a proof of concept in both frameworks if feasible
- Be flexible: Be willing to switch frameworks if needs change
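Here is that timing sketch. It is a minimal example rather than a rigorous benchmark: it times a single matrix multiplication in each framework, and the size N and iteration counts are arbitrary placeholders for your real workload:
import time
import mlx.core as mx
import torch

def bench(fn, warmup=3, iters=20):
    """Average seconds per call after a few warmup runs."""
    for _ in range(warmup):
        fn()
    start = time.perf_counter()
    for _ in range(iters):
        fn()
    return (time.perf_counter() - start) / iters

N = 2048
a_mx = mx.random.normal((N, N))
b_mx = mx.random.normal((N, N))
# mx.eval forces MLX's lazy graph to execute, so we time real work
mlx_time = bench(lambda: mx.eval(a_mx @ b_mx))

device = "mps" if torch.backends.mps.is_available() else "cpu"
a_pt = torch.randn(N, N, device=device)
b_pt = torch.randn(N, N, device=device)

def matmul_torch():
    _ = a_pt @ b_pt
    if device == "mps":
        torch.mps.synchronize()  # wait for the GPU before the clock stops

pt_time = bench(matmul_torch)
print(f"MLX: {mlx_time * 1000:.2f} ms | PyTorch ({device}): {pt_time * 1000:.2f} ms")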
Conclusion
Throughout this series, we've explored the exciting possibilities that Apple Silicon brings to machine learning on Mac. We've set up development environments, run large language models locally, compared frameworks, and built practical applications.
The key takeaways from our journey:
- Apple Silicon has democratized AI: From MacBooks to the powerful Mac Studio and Mac Pro, sophisticated models can now run on Apple hardware without cloud dependencies
- High-end Mac configurations enable professional workloads: Mac Studio and Mac Pro with up to 512GB RAM can run workloads previously requiring specialized servers
- MLX and PyTorch offer different advantages: Each framework has strengths for different use cases
- Scale your approach to your hardware: Choose models and quantization based on available memory
- The right approach depends on your specific needs: Consider your task, hardware, and requirements
- Practical applications are now possible: From language models to computer vision, Apple Silicon supports diverse AI workloads
- This field is rapidly evolving: Both frameworks continue to improve and add capabilities
As you embark on your own AI projects on Apple Silicon, remember that there's no one-size-fits-all solution. The best approach is to understand your specific requirements, experiment with different options, and choose the tools that best fit your needs.
If you have a high-end Mac Studio or Mac Pro, you have the additional advantage of being able to run the largest open-source models (70B+ parameters) or train custom models, capabilities previously limited to specialized infrastructure.
We hope this series has provided a solid foundation for your machine learning journey on Apple Silicon. The code examples and practical applications should serve as useful starting points for your own projects, whether you're using a MacBook Air or a maxed-out Mac Pro.
Happy coding, and enjoy exploring the possibilities of AI on your Mac!
This concludes our "AI with Mac" series. All code examples from this series are available in our GitHub repository.