Model and Dataset Versioning Strategy

When working with machine learning models, especially large language models, proper versioning is crucial for reproducibility and tracking changes over time. Here's how to implement an effective versioning strategy:

Setting Up a Model Registry

Create a structured model registry to track all your models and their versions:

#!/usr/bin/env python3

"""

Model registry system for AI projects.

This script provides utilities for versioning and tracking models

and datasets across your AI projects.

"""

import os

import json

import shutil

import hashlib

import datetime

import argparse

from typing import Dict, List, Optional, Any

class ModelRegistry:
    """
    Model registry for versioning and tracking ML models.

    This class provides functionality to register, version, and
    track machine learning models and datasets together with their
    metadata. Registered artifacts are copied to
    ``<registry_path>/models/<name>/<version>`` (or ``datasets/...``) and
    described in ``<registry_path>/registry_index.json``.
    """

    # Directory datasets larger than this many bytes are recorded by
    # reference (a dataset_reference.txt file) instead of being copied.
    LARGE_DATASET_BYTES = 1_000_000_000

    # Read size used when hashing; large chunks keep multi-GB weight files
    # from being read in millions of tiny reads.
    HASH_CHUNK_BYTES = 1024 * 1024

    def __init__(self, registry_path: str = "model_registry"):
        """
        Initialize the model registry.

        Args:
            registry_path (str): Path to the registry directory. It is created
                if missing, and an existing registry_index.json is loaded.
        """
        self.registry_path = registry_path
        self.index_file = os.path.join(registry_path, "registry_index.json")

        # Create registry directory if it doesn't exist
        os.makedirs(registry_path, exist_ok=True)

        # Initialize or load the registry index
        if os.path.exists(self.index_file):
            with open(self.index_file, 'r') as f:
                self.registry_index = json.load(f)
        else:
            self.registry_index = {
                "models": {},
                "datasets": {},
                "last_updated": datetime.datetime.now().isoformat()
            }
            self._save_index()

    def _save_index(self):
        """
        Save the registry index to disk atomically.

        The JSON is written to a temporary sibling file and moved into place
        with ``os.replace``, so an interrupted write cannot leave a truncated,
        unparseable registry_index.json behind.
        """
        self.registry_index["last_updated"] = datetime.datetime.now().isoformat()
        tmp_file = self.index_file + ".tmp"
        with open(tmp_file, 'w') as f:
            json.dump(self.registry_index, f, indent=2)
        os.replace(tmp_file, self.index_file)

    def _ensure_entry(self, kind: str, name: str) -> Dict[str, Any]:
        """Return the index entry for ``name`` in section ``kind``, creating it if absent."""
        entries = self.registry_index[kind]
        if name not in entries:
            entries[name] = {
                "versions": {},
                "latest_version": None,
                "created": datetime.datetime.now().isoformat()
            }
        return entries[name]

    def _next_free_version(self, taken) -> str:
        """
        Return an auto-generated version string that is not in ``taken``.

        _generate_version has one-minute resolution, so two registrations of
        the same name within a minute would otherwise share a version and the
        second copy would be merged on top of the first. A "-2", "-3", ...
        suffix keeps them apart.
        """
        base = self._generate_version()
        candidate, counter = base, 1
        while candidate in taken:
            counter += 1
            candidate = f"{base}-{counter}"
        return candidate

    @staticmethod
    def _copy_into(src_path: str, dest_dir: str) -> None:
        """Copy a single file, or the contents of a directory, into ``dest_dir``."""
        if os.path.isdir(src_path):
            for item in os.listdir(src_path):
                s = os.path.join(src_path, item)
                d = os.path.join(dest_dir, item)
                if os.path.isdir(s):
                    shutil.copytree(s, d, dirs_exist_ok=True)
                else:
                    shutil.copy2(s, d)
        else:
            shutil.copy2(src_path, dest_dir)

    @staticmethod
    def _directory_size(path: str) -> int:
        """Return the total size in bytes of every file below ``path``."""
        return sum(os.path.getsize(os.path.join(dirpath, filename))
                   for dirpath, _, filenames in os.walk(path)
                   for filename in filenames)

    def register_model(self, model_name: str, model_path: str,
                       version: Optional[str] = None,
                       metadata: Optional[Dict] = None) -> str:
        """
        Register a model in the registry.

        The model file (or the contents of the model directory) is copied to
        ``<registry_path>/models/<model_name>/<version>``. Registering an
        explicit version that already exists copies over the previous files.

        Args:
            model_name (str): Name of the model
            model_path (str): Path to the model directory or file
            version (str, optional): Version string (auto-generated if None)
            metadata (dict, optional): Additional metadata. The caller's dict
                is not modified.

        Returns:
            str: The model version

        Raises:
            FileNotFoundError: If ``model_path`` does not exist.
        """
        if not os.path.exists(model_path):
            raise FileNotFoundError(f"Model path does not exist: {model_path}")

        entry = self._ensure_entry("models", model_name)
        if version is None:
            version = self._next_free_version(entry["versions"])

        model_dir = os.path.join(self.registry_path, "models", model_name, version)
        os.makedirs(model_dir, exist_ok=True)
        self._copy_into(model_path, model_dir)

        # Work on a copy so the caller's dict is not mutated.
        metadata = dict(metadata) if metadata else {}
        metadata.update({
            "registered_at": datetime.datetime.now().isoformat(),
            "original_path": model_path,
            "registry_path": model_dir
        })

        if os.path.isdir(model_path):
            metadata["model_size_bytes"] = self._directory_size(model_path)
        else:
            metadata["model_size_bytes"] = os.path.getsize(model_path)
            # Hashes are only computed for single-file models.
            metadata["file_hash"] = self._calculate_file_hash(model_path)

        entry["versions"][version] = metadata
        entry["latest_version"] = version
        self._save_index()

        print(f"Model {model_name} version {version} registered successfully.")
        return version

    def register_dataset(self, dataset_name: str, dataset_path: str,
                         version: Optional[str] = None,
                         metadata: Optional[Dict] = None) -> str:
        """
        Register a dataset in the registry.

        Files are copied into ``<registry_path>/datasets/<name>/<version>``
        unless ``dataset_path`` is a directory larger than
        LARGE_DATASET_BYTES. Such datasets are recorded by reference only: a
        dataset_reference.txt holds the absolute source path and size, and
        the metadata flag ``is_reference_only`` is True.

        Args:
            dataset_name (str): Name of the dataset
            dataset_path (str): Path to the dataset directory or file
            version (str, optional): Version string (auto-generated if None)
            metadata (dict, optional): Additional metadata. The caller's dict
                is not modified.

        Returns:
            str: The dataset version

        Raises:
            FileNotFoundError: If ``dataset_path`` does not exist.
        """
        if not os.path.exists(dataset_path):
            raise FileNotFoundError(f"Dataset path does not exist: {dataset_path}")

        entry = self._ensure_entry("datasets", dataset_name)
        if version is None:
            version = self._next_free_version(entry["versions"])

        dataset_dir = os.path.join(self.registry_path, "datasets", dataset_name, version)
        os.makedirs(dataset_dir, exist_ok=True)

        is_dir = os.path.isdir(dataset_path)
        total_size = (self._directory_size(dataset_path) if is_dir
                      else os.path.getsize(dataset_path))
        # Only directories are eligible for reference-only storage.
        is_large_dataset = is_dir and total_size > self.LARGE_DATASET_BYTES

        if is_large_dataset:
            with open(os.path.join(dataset_dir, "dataset_reference.txt"), 'w') as f:
                f.write(f"Original dataset path: {os.path.abspath(dataset_path)}\n")
                f.write(f"Dataset size: {total_size / (1024**2):.2f} MB\n")
        else:
            self._copy_into(dataset_path, dataset_dir)

        # Work on a copy so the caller's dict is not mutated.
        metadata = dict(metadata) if metadata else {}
        metadata.update({
            "registered_at": datetime.datetime.now().isoformat(),
            "original_path": dataset_path,
            "registry_path": dataset_dir,
            "is_reference_only": is_large_dataset
        })
        metadata["dataset_size_bytes"] = total_size
        if not is_dir:
            metadata["file_hash"] = self._calculate_file_hash(dataset_path)

        entry["versions"][version] = metadata
        entry["latest_version"] = version
        self._save_index()

        print(f"Dataset {dataset_name} version {version} registered successfully.")
        return version

    def _get_version_info(self, kind: str, label: str, name: str,
                          version: str) -> Optional[Dict]:
        """
        Look up one version of a registered model or dataset.

        Args:
            kind (str): Index section, "models" or "datasets"
            label (str): Lower-case noun used in messages ("model"/"dataset")
            name (str): Registered name
            version (str): Version string or "latest"

        Returns:
            dict or None: A copy of the version metadata plus "name" and
            "version" keys, or None (after printing a message) if not found.
        """
        entries = self.registry_index[kind]
        if name not in entries:
            print(f"{label.capitalize()} {name} not found in registry.")
            return None

        if version == "latest":
            version = entries[name]["latest_version"]

        if version not in entries[name]["versions"]:
            print(f"Version {version} of {label} {name} not found.")
            return None

        info = dict(entries[name]["versions"][version])
        info["name"] = name
        info["version"] = version
        return info

    def get_model(self, model_name: str, version: str = "latest") -> Optional[Dict]:
        """
        Get information about a registered model.

        Args:
            model_name (str): Name of the model
            version (str): Model version or "latest"

        Returns:
            dict or None: Model information
        """
        return self._get_version_info("models", "model", model_name, version)

    def get_dataset(self, dataset_name: str, version: str = "latest") -> Optional[Dict]:
        """
        Get information about a registered dataset.

        Args:
            dataset_name (str): Name of the dataset
            version (str): Dataset version or "latest"

        Returns:
            dict or None: Dataset information
        """
        return self._get_version_info("datasets", "dataset", dataset_name, version)

    def _list_latest(self, kind: str, getter) -> List[Dict]:
        """Collect ``getter(name, latest_version)`` for every entry in ``kind`` that has one."""
        results = []
        for name, entry in self.registry_index[kind].items():
            latest = entry["latest_version"]
            if latest:
                info = getter(name, latest)
                # Skip entries whose index points at a missing version.
                if info is not None:
                    results.append(info)
        return results

    def list_models(self) -> List[Dict]:
        """
        List the latest version of every registered model.

        Returns:
            list: List of model information dictionaries
        """
        return self._list_latest("models", self.get_model)

    def list_datasets(self) -> List[Dict]:
        """
        List the latest version of every registered dataset.

        Returns:
            list: List of dataset information dictionaries
        """
        return self._list_latest("datasets", self.get_dataset)

    def _generate_version(self) -> str:
        """
        Generate a version string based on the current local date and time.

        Returns:
            str: Version string (format: v{year}.{month}.{day}.{hour}{minute})
        """
        now = datetime.datetime.now()
        return f"v{now.year}.{now.month:02d}.{now.day:02d}.{now.hour:02d}{now.minute:02d}"

    def _calculate_file_hash(self, file_path: str) -> str:
        """
        Calculate the SHA-256 hash of a file.

        Args:
            file_path (str): Path to the file

        Returns:
            str: SHA-256 hash hexadecimal string
        """
        sha256_hash = hashlib.sha256()
        with open(file_path, "rb") as f:
            # Read in fixed-size chunks so large files are never fully in memory.
            for chunk in iter(lambda: f.read(self.HASH_CHUNK_BYTES), b""):
                sha256_hash.update(chunk)
        return sha256_hash.hexdigest()

def main():
    """
    Command-line entry point for the model registry utility.

    Parses a sub-command and runs it against a ModelRegistry in the default
    "model_registry" directory. Malformed --metadata JSON is reported as a
    usage error. A missing path or an unknown model/dataset exits with
    status 1.
    """

    def json_object(text):
        """argparse ``type=`` converter: parse a JSON object or raise a usage error."""
        try:
            value = json.loads(text)
        except json.JSONDecodeError as e:
            raise argparse.ArgumentTypeError(f"invalid JSON: {e}") from e
        if not isinstance(value, dict):
            raise argparse.ArgumentTypeError("expected a JSON object")
        return value

    parser = argparse.ArgumentParser(description="Model and dataset versioning utility")
    subparsers = parser.add_subparsers(dest="command", help="Command to execute")

    # Register model command
    register_model_parser = subparsers.add_parser("register-model", help="Register a model")
    register_model_parser.add_argument("--name", required=True, help="Model name")
    register_model_parser.add_argument("--path", required=True, help="Path to model directory or file")
    register_model_parser.add_argument("--version", help="Version string (optional)")
    register_model_parser.add_argument("--metadata", type=json_object,
                                       help="JSON metadata string (optional)")

    # Register dataset command
    register_dataset_parser = subparsers.add_parser("register-dataset", help="Register a dataset")
    register_dataset_parser.add_argument("--name", required=True, help="Dataset name")
    register_dataset_parser.add_argument("--path", required=True, help="Path to dataset directory or file")
    register_dataset_parser.add_argument("--version", help="Version string (optional)")
    register_dataset_parser.add_argument("--metadata", type=json_object,
                                         help="JSON metadata string (optional)")

    # List commands
    subparsers.add_parser("list-models", help="List all registered models")
    subparsers.add_parser("list-datasets", help="List all registered datasets")

    # Get model command
    get_model_parser = subparsers.add_parser("get-model", help="Get model information")
    get_model_parser.add_argument("--name", required=True, help="Model name")
    get_model_parser.add_argument("--version", default="latest", help="Model version (default: latest)")

    # Get dataset command
    get_dataset_parser = subparsers.add_parser("get-dataset", help="Get dataset information")
    get_dataset_parser.add_argument("--name", required=True, help="Dataset name")
    get_dataset_parser.add_argument("--version", default="latest", help="Dataset version (default: latest)")

    args = parser.parse_args()

    # No sub-command: show help without creating a registry directory.
    if args.command is None:
        parser.print_help()
        return

    registry = ModelRegistry()

    try:
        if args.command == "register-model":
            registry.register_model(args.name, args.path, args.version, args.metadata)
        elif args.command == "register-dataset":
            registry.register_dataset(args.name, args.path, args.version, args.metadata)
        elif args.command == "list-models":
            print(json.dumps(registry.list_models(), indent=2))
        elif args.command == "list-datasets":
            print(json.dumps(registry.list_datasets(), indent=2))
        elif args.command == "get-model":
            model_info = registry.get_model(args.name, args.version)
            if not model_info:
                # get_model already printed why; signal failure to the shell.
                parser.exit(1)
            print(json.dumps(model_info, indent=2))
        elif args.command == "get-dataset":
            dataset_info = registry.get_dataset(args.name, args.version)
            if not dataset_info:
                parser.exit(1)
            print(json.dumps(dataset_info, indent=2))
        else:
            parser.print_help()
    except FileNotFoundError as e:
        parser.exit(1, f"error: {e}\n")

# Run the command-line interface only when this file is executed directly,
# not when ModelRegistry is imported from another module.
if __name__ == "__main__":
    main()

Creating a Git-based Version Control System for Models

You can also use Git LFS (Large File Storage) to version your models directly in Git:

# Setup Git LFS repository
cat > setup_git_lfs.sh << 'EOF'
#!/bin/bash
# Script to set up Git LFS for model versioning.
# Run it from inside an existing Git repository; it is safe to re-run.

# Abort on the first failing command, on unset variables, and on failures
# inside pipelines, so a half-configured repository is never committed.
set -euo pipefail

# .gitattributes and the commits below need a repository to live in.
if ! git rev-parse --is-inside-work-tree > /dev/null 2>&1; then
    echo "Not inside a Git repository. Run 'git init' first." >&2
    exit 1
fi

# Check if Git LFS is installed; only use Homebrew if it is available.
if ! command -v git-lfs &> /dev/null; then
    if command -v brew &> /dev/null; then
        echo "Git LFS is not installed. Installing with Homebrew..."
        brew install git-lfs
    else
        echo "Git LFS is not installed and Homebrew was not found." >&2
        echo "Install Git LFS (https://git-lfs.com) and re-run this script." >&2
        exit 1
    fi
fi

# Initialize Git LFS
git lfs install

# Track model files
git lfs track "*.bin"
git lfs track "*.gguf"
git lfs track "*.pt"
git lfs track "*.pth"
git lfs track "*.onnx"
# A .mlpackage is a directory bundle. gitattributes patterns that match a
# directory do not apply to the files inside it, so track its contents.
git lfs track "*.mlpackage/**"
git lfs track "models/**/*"

# Keep the models README as a regular Git file rather than an LFS object.
# A later .gitattributes line overrides the earlier "models/**/*" rule.
README_ATTR='models/README.md !filter !diff !merge text'
grep -qxF "$README_ATTR" .gitattributes || echo "$README_ATTR" >> .gitattributes

# Commit only .gitattributes, and skip the commit when nothing changed.
git add .gitattributes
git diff --cached --quiet -- .gitattributes || \
    git commit -m "Initialize Git LFS for model versioning" -- .gitattributes

# Create directory structure (matches the layout documented in the README)
mkdir -p models/checkpoints models/production models/quantized \
         datasets/processed datasets/raw

# Create a README for the models directory
cat > models/README.md << 'README'
# Model Directory

This directory contains versioned ML models.

## Structure

- `checkpoints/`: Training checkpoints
- `production/`: Production-ready models
- `quantized/`: Quantized versions of models

## Versioning

Models are versioned using semantic versioning:

- Major version: Architectural changes
- Minor version: Training improvements
- Patch version: Bug fixes or small adjustments

Example: `chatbot-v1.2.3` represents:

- Version 1 architecture
- Training improvement iteration 2
- Bug fix or adjustment 3

## Usage

Each model directory contains:

- Model weights
- `config.json` with hyperparameters
- `metadata.json` with version info
- `metrics.json` with performance metrics
README

git add models/README.md
git diff --cached --quiet -- models/README.md || \
    git commit -m "Add model directory structure and documentation" -- models/README.md

echo "Git LFS setup complete for model versioning!"
EOF

chmod +x setup_git_lfs.sh

Semantic Versioning for Models

Create a utility to manage semantic versioning for your models:

#!/usr/bin/env python3

"""

Semantic versioning utility for ML models.

This script helps manage semantic versioning for machine learning models,

following the MAJOR.MINOR.PATCH convention with specific meaning for ML.

"""

import os

import re

import json

import argparse

from datetime import datetime

class MLVersioning:
    """
    Semantic versioning for machine learning models.

    MAJOR: Architecture changes
    MINOR: Training improvements
    PATCH: Bug fixes and minor adjustments

    Version strings have the form ``v{major}.{minor}.{patch}``. A model's
    current version is stored under the "version" key of
    ``<model_dir>/metadata.json``.
    """

    # Version assumed for a model with no recorded version, so its first
    # increment yields v1.0.0 (major), v0.1.0 (minor) or v0.0.1 (patch).
    INITIAL_VERSION = "v0.0.0"

    @staticmethod
    def parse_version(version_str):
        """
        Parse a version string into integer components.

        The leading "v" is optional. Only the leading MAJOR.MINOR.PATCH
        triple is read; any trailing text (for example "-rc1") is ignored.

        Args:
            version_str (str): Version string such as "v1.2.3" or "1.2.3"

        Returns:
            tuple: (major, minor, patch) as ints

        Raises:
            ValueError: If the string does not start with a version triple.
        """
        match = re.match(r'v?(\d+)\.(\d+)\.(\d+)', version_str)
        if not match:
            raise ValueError(f"Invalid version format: {version_str}")
        major, minor, patch = map(int, match.groups())
        return major, minor, patch

    @staticmethod
    def format_version(major, minor, patch):
        """Format version components as "v{major}.{minor}.{patch}"."""
        return f"v{major}.{minor}.{patch}"

    @staticmethod
    def increment_version(version_str, level="patch"):
        """
        Increment a version at the specified level.

        Incrementing a level resets every lower level to zero.

        Args:
            version_str (str): Version string
            level (str): Level to increment (major, minor, patch)

        Returns:
            str: New version string

        Raises:
            ValueError: If the version string or the level is invalid.
        """
        major, minor, patch = MLVersioning.parse_version(version_str)
        if level == "major":
            return MLVersioning.format_version(major + 1, 0, 0)
        if level == "minor":
            return MLVersioning.format_version(major, minor + 1, 0)
        if level == "patch":
            return MLVersioning.format_version(major, minor, patch + 1)
        raise ValueError(f"Invalid increment level: {level}")

    @staticmethod
    def update_model_version(model_dir, level="patch", metadata=None):
        """
        Increment the version recorded in ``<model_dir>/metadata.json``.

        A missing metadata file (or missing "version" key) is treated as
        INITIAL_VERSION. Caller-supplied metadata is merged first, then
        "version" and "last_updated" are set. The file therefore always
        agrees with the returned version.

        Args:
            model_dir (str): Model directory (must already exist)
            level (str): Level to increment (major, minor, patch)
            metadata (dict, optional): Additional metadata to merge in

        Returns:
            str: New version string

        Raises:
            ValueError: If ``level`` or the stored version is invalid.
        """
        metadata_file = os.path.join(model_dir, "metadata.json")

        # Load existing metadata or start empty
        if os.path.exists(metadata_file):
            with open(metadata_file, 'r') as f:
                model_metadata = json.load(f)
        else:
            model_metadata = {}

        current_version = model_metadata.get("version", MLVersioning.INITIAL_VERSION)
        # Compute (and validate the level) before anything is written to disk.
        new_version = MLVersioning.increment_version(current_version, level)

        if metadata:
            model_metadata.update(metadata)

        # Set after the merge so caller metadata cannot clobber the version.
        model_metadata["version"] = new_version
        model_metadata["last_updated"] = datetime.now().isoformat()

        with open(metadata_file, 'w') as f:
            json.dump(model_metadata, f, indent=2)

        return new_version

    @staticmethod
    def tag_model_version(model_dir, version=None):
        """
        Create an annotated Git tag ``model-<name>-<version>`` for a model.

        The model name is the last component of ``model_dir``. The version
        defaults to the "version" key in ``<model_dir>/metadata.json``. git
        runs in the current working directory.

        Args:
            model_dir (str): Model directory
            version (str, optional): Version to tag

        Returns:
            bool: True if git created the tag, False otherwise (a message is
            printed explaining why).
        """
        import subprocess  # only needed for tagging

        try:
            # normpath drops a trailing separator, so "models/foo/" -> "foo"
            # instead of an empty name.
            model_name = os.path.basename(os.path.normpath(model_dir))

            # Get version from metadata if not provided
            if not version:
                metadata_file = os.path.join(model_dir, "metadata.json")
                if not os.path.exists(metadata_file):
                    print(f"Metadata file not found: {metadata_file}")
                    return False

                with open(metadata_file, 'r') as f:
                    metadata = json.load(f)

                version = metadata.get("version")
                if not version:
                    print("Version not found in metadata")
                    return False

            tag_name = f"model-{model_name}-{version}"
            commit_msg = f"Model {model_name} version {version}"

            # Pass an argument list (no shell) so names containing quotes or
            # shell metacharacters cannot break or inject into the command.
            result = subprocess.run(["git", "tag", "-a", tag_name, "-m", commit_msg])
            if result.returncode != 0:
                print(f"Error creating git tag: git exited with status {result.returncode}")
                return False

            print(f"Created git tag: {tag_name}")
            return True

        except Exception as e:
            # Best-effort CLI helper: report any failure (missing git binary,
            # unreadable metadata, ...) and return False instead of raising.
            print(f"Error creating git tag: {e}")
            return False

def main():
    """
    Command-line entry point for the semantic versioning utility.

    "increment" bumps the version in a model directory's metadata.json.
    "tag" creates a Git tag for it and exits with status 1 if tagging fails.
    Malformed --metadata JSON is reported as a usage error.
    """

    def json_object(text):
        """argparse ``type=`` converter: parse a JSON object or raise a usage error."""
        try:
            value = json.loads(text)
        except json.JSONDecodeError as e:
            raise argparse.ArgumentTypeError(f"invalid JSON: {e}") from e
        if not isinstance(value, dict):
            raise argparse.ArgumentTypeError("expected a JSON object")
        return value

    parser = argparse.ArgumentParser(description="Semantic versioning for ML models")
    subparsers = parser.add_subparsers(dest="command", help="Command to execute")

    # Increment version command
    increment_parser = subparsers.add_parser("increment", help="Increment model version")
    increment_parser.add_argument("--model-dir", required=True, help="Model directory")
    increment_parser.add_argument("--level", choices=["major", "minor", "patch"],
                                  default="patch", help="Version level to increment")
    increment_parser.add_argument("--metadata", type=json_object,
                                  help="Additional metadata as JSON string")

    # Tag version command
    tag_parser = subparsers.add_parser("tag", help="Create git tag for model version")
    tag_parser.add_argument("--model-dir", required=True, help="Model directory")
    tag_parser.add_argument("--version", help="Version string (read from metadata if not provided)")

    args = parser.parse_args()

    if args.command == "increment":
        new_version = MLVersioning.update_model_version(args.model_dir, args.level, args.metadata)
        print(f"Updated model version to {new_version}")
    elif args.command == "tag":
        if not MLVersioning.tag_model_version(args.model_dir, args.version):
            # The failure reason was already printed; report it to the shell.
            parser.exit(1)
    else:
        parser.print_help()

# Run the command-line interface only when this file is executed directly,
# not when MLVersioning is imported from another module.
if __name__ == "__main__":
    main()

Tracking Experiments and Results

Create a utility to track your model training experiments and results:

#!/usr/bin/env python3

"""

Experiment tracking utility for ML models.

This script provides a simple system for tracking ML experiments

without requiring external dependencies or cloud services.

"""

import os

import json

import argparse

import uuid

import platform

import psutil

from datetime import datetime

class ExperimentTracker:
    """
    Simple experiment tracker for machine learning models.

    Everything is stored as JSON under ``experiments_dir``:

    - ``experiments_index.json``: all experiments, their run IDs and best run
    - ``<experiment_id>/experiment.json``: a copy of one experiment's entry
    - ``runs/<run_id>/run.json``: a run's params, latest metrics and artifacts
    - ``runs/<run_id>/metrics.json``: every ``log_metrics`` call for a run
    - ``runs/<run_id>/artifacts/``: copies of logged artifact files
    """

    def __init__(self, experiments_dir="experiments"):
        """
        Initialize the experiment tracker.

        Args:
            experiments_dir (str): Directory to store experiment data. It and
                its "runs" subdirectory are created if missing.
        """
        self.experiments_dir = experiments_dir
        os.makedirs(experiments_dir, exist_ok=True)

        self.runs_dir = os.path.join(experiments_dir, "runs")
        os.makedirs(self.runs_dir, exist_ok=True)

        # Load or create experiments index
        self.index_file = os.path.join(experiments_dir, "experiments_index.json")
        if os.path.exists(self.index_file):
            self.index = self._read_json(self.index_file)
        else:
            self.index = {
                "experiments": {},
                "last_updated": datetime.now().isoformat()
            }
            self._save_index()

    @staticmethod
    def _read_json(path):
        """Load and return the JSON document stored at ``path``."""
        with open(path, 'r') as f:
            return json.load(f)

    @staticmethod
    def _write_json(path, data):
        """
        Write ``data`` as indented JSON to ``path`` atomically.

        The data goes to a temporary sibling file first and is then moved into
        place, so a crash mid-write never leaves a truncated JSON file.
        """
        tmp_path = path + ".tmp"
        with open(tmp_path, 'w') as f:
            json.dump(data, f, indent=2)
        os.replace(tmp_path, path)

    def _save_index(self):
        """Save the experiments index to disk."""
        self.index["last_updated"] = datetime.now().isoformat()
        self._write_json(self.index_file, self.index)

    def _save_experiment(self, experiment_id):
        """Mirror an experiment's index entry into its experiment.json file."""
        experiment_dir = os.path.join(self.experiments_dir, experiment_id)
        os.makedirs(experiment_dir, exist_ok=True)
        self._write_json(os.path.join(experiment_dir, "experiment.json"),
                         self.index["experiments"][experiment_id])

    def _run_file(self, run_id):
        """Return the path of a run's run.json, raising ValueError if it is missing."""
        run_file = os.path.join(self.runs_dir, run_id, "run.json")
        if not os.path.exists(run_file):
            raise ValueError(f"Run not found: {run_id}")
        return run_file

    def create_experiment(self, name, description=None, metadata=None):
        """
        Create a new experiment.

        Args:
            name (str): Experiment name
            description (str, optional): Experiment description
            metadata (dict, optional): Additional metadata

        Returns:
            str: Experiment ID (a random UUID4 string)
        """
        experiment_id = str(uuid.uuid4())

        experiment = {
            "id": experiment_id,
            "name": name,
            "description": description or "",
            "created_at": datetime.now().isoformat(),
            "runs": [],
            "best_run": None,
            "metadata": metadata or {}
        }

        self.index["experiments"][experiment_id] = experiment
        self._save_index()
        self._save_experiment(experiment_id)

        print(f"Created experiment: {name} (ID: {experiment_id})")
        return experiment_id

    def start_run(self, experiment_id, run_name=None, metadata=None):
        """
        Start a new run for an experiment.

        Args:
            experiment_id (str): Experiment ID
            run_name (str, optional): Run name; defaults to
                "<experiment name>_<YYYYmmdd_HHMMSS>"
            metadata (dict, optional): Additional metadata

        Returns:
            str: Run ID

        Raises:
            ValueError: If the experiment does not exist.
        """
        if experiment_id not in self.index["experiments"]:
            raise ValueError(f"Experiment not found: {experiment_id}")

        run_id = str(uuid.uuid4())
        if run_name is None:
            experiment_name = self.index["experiments"][experiment_id]["name"]
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            run_name = f"{experiment_name}_{timestamp}"

        # Host description recorded with the run for reproducibility.
        system_info = {
            "platform": platform.platform(),
            "python_version": platform.python_version(),
            "processor": platform.processor(),
            "memory_gb": psutil.virtual_memory().total / (1024**3)
        }

        run = {
            "id": run_id,
            "experiment_id": experiment_id,
            "name": run_name,
            "status": "running",
            "started_at": datetime.now().isoformat(),
            "completed_at": None,
            "duration_seconds": None,
            "params": {},
            "metrics": {},
            "artifacts": [],
            "system_info": system_info,
            "metadata": metadata or {}
        }

        run_dir = os.path.join(self.runs_dir, run_id)
        os.makedirs(os.path.join(run_dir, "artifacts"), exist_ok=True)
        self._write_json(os.path.join(run_dir, "run.json"), run)

        self.index["experiments"][experiment_id]["runs"].append(run_id)
        self._save_index()
        self._save_experiment(experiment_id)

        print(f"Started run: {run_name} (ID: {run_id})")
        return run_id

    def log_params(self, run_id, params):
        """
        Log parameters for a run (merged into any previously logged ones).

        Args:
            run_id (str): Run ID
            params (dict): Parameters to log

        Raises:
            ValueError: If the run does not exist.
        """
        run_file = self._run_file(run_id)
        run = self._read_json(run_file)
        run["params"].update(params)
        self._write_json(run_file, run)

        print(f"Logged {len(params)} parameters for run: {run_id}")

    def log_metrics(self, run_id, metrics, step=None):
        """
        Log metrics for a run.

        Each call is appended to the run's metrics.json history. run.json
        keeps only the most recent value of each metric.

        Args:
            run_id (str): Run ID
            metrics (dict): Metrics to log
            step (int, optional): Step number

        Raises:
            ValueError: If the run does not exist.
        """
        run_file = self._run_file(run_id)
        run = self._read_json(run_file)

        metrics_file = os.path.join(self.runs_dir, run_id, "metrics.json")
        history = self._read_json(metrics_file) if os.path.exists(metrics_file) else []

        entry = {
            "timestamp": datetime.now().isoformat(),
            "metrics": metrics
        }
        if step is not None:
            entry["step"] = step
        history.append(entry)
        self._write_json(metrics_file, history)

        run["metrics"].update(metrics)
        self._write_json(run_file, run)

        print(f"Logged {len(metrics)} metrics for run: {run_id}")

    def log_artifact(self, run_id, artifact_path, name=None):
        """
        Copy a file or directory into a run's artifacts folder and record it.

        Args:
            run_id (str): Run ID
            artifact_path (str): Path to artifact file or directory
            name (str, optional): Artifact name (defaults to the basename)

        Raises:
            ValueError: If the run or the artifact path does not exist.
        """
        # This script's top-level imports do not include shutil.
        import shutil

        run_file = self._run_file(run_id)
        if not os.path.exists(artifact_path):
            raise ValueError(f"Artifact path does not exist: {artifact_path}")

        run = self._read_json(run_file)

        artifacts_dir = os.path.join(self.runs_dir, run_id, "artifacts")
        artifact_name = name or os.path.basename(artifact_path)
        artifact_dest = os.path.join(artifacts_dir, artifact_name)

        is_dir = os.path.isdir(artifact_path)
        if is_dir:
            shutil.copytree(artifact_path, artifact_dest, dirs_exist_ok=True)
        else:
            shutil.copy2(artifact_path, artifact_dest)

        run["artifacts"].append({
            "name": artifact_name,
            "path": artifact_dest,
            "type": "directory" if is_dir else "file",
            "added_at": datetime.now().isoformat()
        })
        self._write_json(run_file, run)

        print(f"Logged artifact: {artifact_name} for run: {run_id}")

    def complete_run(self, run_id, status="completed", metric="accuracy",
                     higher_is_better=True):
        """
        Mark a run as finished and update its experiment's best run.

        Only runs finishing with status "completed" that logged ``metric`` can
        become the experiment's best run. A failed or aborted run never
        replaces a successful one, whatever its metrics.

        Args:
            run_id (str): Run ID
            status (str): Completion status (completed, failed, etc.)
            metric (str): Metric used to rank runs (default "accuracy")
            higher_is_better (bool): Whether larger values of ``metric`` are
                better (default True)

        Raises:
            ValueError: If the run does not exist.
        """
        run_file = self._run_file(run_id)
        run = self._read_json(run_file)

        started_at = datetime.fromisoformat(run["started_at"])
        completed_at = datetime.now()
        duration_seconds = (completed_at - started_at).total_seconds()

        run["status"] = status
        run["completed_at"] = completed_at.isoformat()
        run["duration_seconds"] = duration_seconds
        self._write_json(run_file, run)

        experiment_id = run["experiment_id"]
        experiment = self.index["experiments"].get(experiment_id)
        if experiment is not None and status == "completed" and metric in run["metrics"]:
            value = run["metrics"][metric]
            current_best = self._best_metric_value(experiment, metric)
            if current_best is None:
                is_better = True
            elif higher_is_better:
                is_better = value > current_best
            else:
                is_better = value < current_best
            if is_better:
                experiment["best_run"] = run_id
                self._save_index()
                self._save_experiment(experiment_id)

        print(f"Completed run: {run_id} with status: {status} (duration: {duration_seconds:.2f}s)")

    def _best_metric_value(self, experiment, metric):
        """Return ``metric`` from the experiment's current best run, or None if unavailable."""
        best_run_id = experiment["best_run"]
        if not best_run_id:
            return None
        best_run_file = os.path.join(self.runs_dir, best_run_id, "run.json")
        if not os.path.exists(best_run_file):
            return None
        return self._read_json(best_run_file)["metrics"].get(metric)

def main():
    """
    Command-line entry point for the experiment tracking utility.

    Runs one sub-command against an ExperimentTracker in the default
    "experiments" directory. Malformed JSON arguments are reported as usage
    errors. Unknown experiment/run IDs and missing artifact paths exit with
    status 1 and a message instead of a traceback.
    """

    def json_object(text):
        """argparse ``type=`` converter: parse a JSON object or raise a usage error."""
        try:
            value = json.loads(text)
        except json.JSONDecodeError as e:
            raise argparse.ArgumentTypeError(f"invalid JSON: {e}") from e
        if not isinstance(value, dict):
            raise argparse.ArgumentTypeError("expected a JSON object")
        return value

    parser = argparse.ArgumentParser(description="Track ML experiments and results")
    subparsers = parser.add_subparsers(dest="command", help="Command to execute")

    # Create experiment command
    create_exp_parser = subparsers.add_parser("create-experiment", help="Create a new experiment")
    create_exp_parser.add_argument("--name", required=True, help="Experiment name")
    create_exp_parser.add_argument("--description", help="Experiment description")
    create_exp_parser.add_argument("--metadata", type=json_object,
                                   help="Additional metadata as JSON string")

    # Start run command
    start_run_parser = subparsers.add_parser("start-run", help="Start a new run")
    start_run_parser.add_argument("--experiment", required=True, help="Experiment ID")
    start_run_parser.add_argument("--name", help="Run name")
    start_run_parser.add_argument("--metadata", type=json_object,
                                  help="Additional metadata as JSON string")

    # Log params command
    log_params_parser = subparsers.add_parser("log-params", help="Log parameters")
    log_params_parser.add_argument("--run", required=True, help="Run ID")
    log_params_parser.add_argument("--params", required=True, type=json_object,
                                   help="Parameters as JSON string")

    # Log metrics command
    log_metrics_parser = subparsers.add_parser("log-metrics", help="Log metrics")
    log_metrics_parser.add_argument("--run", required=True, help="Run ID")
    log_metrics_parser.add_argument("--metrics", required=True, type=json_object,
                                    help="Metrics as JSON string")
    log_metrics_parser.add_argument("--step", type=int, help="Step number")

    # Log artifact command
    log_artifact_parser = subparsers.add_parser("log-artifact", help="Log an artifact")
    log_artifact_parser.add_argument("--run", required=True, help="Run ID")
    log_artifact_parser.add_argument("--path", required=True, help="Path to artifact file")
    log_artifact_parser.add_argument("--name", help="Artifact name")

    # Complete run command
    complete_run_parser = subparsers.add_parser("complete-run", help="Complete a run")
    complete_run_parser.add_argument("--run", required=True, help="Run ID")
    complete_run_parser.add_argument("--status", default="completed",
                                     choices=["completed", "failed", "aborted"],
                                     help="Completion status")

    args = parser.parse_args()

    # No sub-command: show help without creating the experiments directory.
    if args.command is None:
        parser.print_help()
        return

    tracker = ExperimentTracker()

    try:
        if args.command == "create-experiment":
            tracker.create_experiment(args.name, args.description, args.metadata)
        elif args.command == "start-run":
            tracker.start_run(args.experiment, args.name, args.metadata)
        elif args.command == "log-params":
            tracker.log_params(args.run, args.params)
        elif args.command == "log-metrics":
            tracker.log_metrics(args.run, args.metrics, args.step)
        elif args.command == "log-artifact":
            tracker.log_artifact(args.run, args.path, args.name)
        elif args.command == "complete-run":
            tracker.complete_run(args.run, args.status)
        else:
            parser.print_help()
    except ValueError as e:
        # The tracker reports unknown IDs and missing artifact paths as ValueError.
        parser.exit(1, f"error: {e}\n")

# Run the command-line interface only when this file is executed directly,
# not when ExperimentTracker is imported from another module.
if __name__ == "__main__":
    main()

By implementing these versioning strategies, you'll ensure that your models and datasets are properly tracked and versioned, making your ML projects more reproducible and maintainable over time. This approach is particularly valuable when working with large models on high-end Mac Studio and Mac Pro systems where you might be managing multiple versions of several substantial models.

The key benefits of this versioning strategy include:

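  1. Reproducibility: You can recover any registered model version exactly as it was registered (datasets larger than 1 GB are stored only as references, so keep those originals unchanged)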
  2. Traceability: Every model change is documented and tracked
  3. Collaboration: Multiple team members can work on models with clear version history
  4. Rollback Capability: You can easily revert to previous model versions if needed
  5. Systematic Experimentation: Each training run is tracked with its parameters and results

These tools form a solid foundation for professional ML development on your Apple Silicon Mac.

#!/usr/bin/env python3

"""

Simple web server for AI with Mac demonstration.

This script creates a Flask web application that showcases both MLX (for text summarization) and PyTorch (for image captioning) in a single interface, demonstrating how to combine different frameworks based on their strengths. """ import os import uuid from flask import Flask, render_template, request, redirect, url_for from summarize import Summarizer from caption import ImageCaptioner

app = Flask(name, template_folder="../templates", static_folder="../static")

Initialize models

summarizer = Summarizer() captioner = ImageCaptioner()

Create upload directories

os.makedirs("../uploads", exist_ok=True)

@app.route("/") def index(): """ Render main page.

Returns:

rendered template: The index.html template

"""

return render_template("index.html")

@app.route("/summarize", methods=["POST"]) def summarize_text(): """ Summarize text input using MLX.

Processes text submitted via form and generates a summary using

the MLX-based summarizer.

Returns:

rendered template: The result.html template with summary results

"""

text = request.form.get("text", "")

if not text:

return render_template("index.html", error="Please enter some text to summarize.")

summary = summarizer.summarize(text)

return render_template("result.html",

result_type="summary",

original=text,

result=summary,

framework="MLX")

@app.route("/caption", methods=["POST"]) def caption_image(): """ Caption uploaded image using PyTorch with Metal.

Processes an uploaded image and generates a descriptive caption

using the PyTorch-based image captioner.

Returns:

rendered template: The result.html template with captioning results

"""

if "image" not in request.files:

return render_template("index.html", error="No image uploaded.")

file = request.files["image"]

if file.filename == "":

return render_template("index.html", error="No image selected.")

# Save uploaded image

filename = str(uuid.uuid4()) + os.path.splitext(file.filename)[1]

filepath = os.path.join("../uploads", filename)

file.save(filepath)

# Generate caption

caption = captioner.generate_caption(filepath)

return render_template("result.html",

result_type="caption",

image_path="../uploads/" + filename,

result=caption,

framework="PyTorch with Metal")

if name == "main": app.run(debug=True)

![Decision Tree](https://via.placeholder.com/1200x400)

*This is the final part of our "AI with Mac" series. Check out the previous installments: [Part 1: Introduction](link-to-part1), [Part 2: Python and Git Setup](link-to-part2), [Part 3: Running LLMs](link-to-part3), and [Part 4: Comparing MLX and PyTorch](link-to-part4).*

## Introduction

Throughout this series, we've explored how Apple Silicon has revolutionized machine learning on Mac. We've set up our development environment, run large language models, and compared Apple's MLX framework with the more established PyTorch. Now, in our final installment, we'll bring everything together to help you make an informed decision about which approach to use for your specific ML needs.

![AI on Mac Decision Guide](https://via.placeholder.com/800x500?text=AI+on+Mac+Decision+Guide)

*Figure 1: Comprehensive decision tree mapping different AI workloads, Mac hardware configurations, and use cases to the optimal framework and approach, providing a visual reference for the entire article.*

Rather than giving a one-size-fits-all answer, we'll explore different use cases and scenarios to help you choose the right tool for the job. We'll also build a practical application that combines our knowledge from previous posts, giving you a template for your own machine learning projects on Apple Silicon.

## Decision Framework: Choosing Your ML Path

Let's start by establishing a framework for deciding which approach to use based on your specific requirements:

![Hardware-Based Decision Tree](https://via.placeholder.com/800x600?text=Hardware-Based+Decision+Tree)

*Figure 2: Flowchart showing decision paths based on Mac hardware type, ranging from MacBook Air to Mac Pro, with recommended frameworks, model sizes, and quantization strategies for each hardware tier.*

### Key Decision Factors

1. **Application Type**: What kind of ML task are you tackling?

2. **Hardware Capabilities**: What Mac hardware do you have? (MacBook, iMac, Mac Studio, Mac Pro)

3. **Memory Availability**: How much unified memory is available? (8GB to 512GB)

4. **Production Requirements**: Local-only or eventual deployment to other platforms?

5. **Development Timeline**: Quick prototype or long-term project?

6. **Team Experience**: Familiarity with specific frameworks?

7. **Performance Requirements**: Speed, memory, power constraints?

8. **Scale of Models**: Are you working with small (2-7B), medium (7-30B), or large (70B+) models?

### Decision Tree for Different Mac Configurations

```text
What Mac hardware are you using?
├── MacBook Air/Pro (8-16GB RAM):
│   ├── Need maximum efficiency and small models
│   │   ├── Yes: MLX with 4-bit quantization
│   │   └── No: Either framework (consider battery life)
│
├── MacBook Pro/iMac (32-64GB RAM):
│   ├── Running medium-sized models (13-30B params)
│   │   ├── Inference focus: MLX
│   │   └── Training focus: PyTorch
│
└── Mac Studio/Mac Pro (128-512GB RAM):
    ├── Running largest models (70B+)
    │   ├── Single model inference: MLX
    │   ├── Multiple concurrent models: MLX
    │   └── Fine-tuning or training: PyTorch
    │
    └── Research or production environment?
        ├── Research: Consider both frameworks
        └── Production: Depends on deployment target
```

The high-end Mac Studio and Mac Pro configurations with Ultra-class chips can handle workloads that previously required specialized server hardware or expensive GPUs. These chips fuse two Max-class dies into one package. The M2 Ultra supports up to 192GB of unified memory, and the M3 Ultra in the Mac Studio supports up to 512GB. That makes these machines excellent choices for serious AI development.

Let's explore each branch of this decision tree in more detail.

## Use Case 1: Large Language Models (LLMs)

If you're primarily working with LLMs, your decision largely depends on your deployment requirements and performance needs.

### When to Choose MLX for LLMs

✅ **Recommended when**:

- You need maximum performance on Apple Silicon

- Local inference is your primary goal

- Privacy and offline operation are critical

- You're working with quantized models (4-bit, 8-bit)

- Battery efficiency matters (for MacBooks)

- You're using high-end Mac Studio/Pro with large unified memory (128-512GB)

⚠️ **Considerations**:

- Limited to Apple Silicon devices

- Smaller ecosystem of pre-built tools

- Fewer reference implementations

**Mac Studio/Pro Advantage**: The Ultra chips in Mac Studio and Mac Pro (the M2 Ultra, and the M3 Ultra in the Mac Studio) fuse two Max-class dies and pair them with a very large unified memory pool: up to 192GB on the M2 Ultra and up to 512GB on the M3 Ultra. This lets you run 70B+ parameter models at 16-bit precision (roughly 140GB of weights for a 70B model) or several large models simultaneously - something previously only possible with specialized server hardware.

### When to Choose PyTorch for LLMs

✅ **Recommended when**:

- Cross-platform compatibility is essential

- You need to use specific models not yet supported in MLX

- You're integrating with existing PyTorch pipelines

- Team familiarity with PyTorch is high

- You need advanced functionality beyond basic inference

⚠️ **Considerations**:

- Generally slower LLM inference compared to MLX

- Higher memory usage

- More complex setup for optimal performance

### LLM Application Example: Enhanced Document Q&A

Let's enhance our document Q&A system from Part 3 with more advanced features. This implementation using MLX demonstrates a practical application:

```python

#!/usr/bin/env python3

"""

Enhanced document Q&A system with vector search using MLX.

"""

import os

import re

import sys

import argparse

import numpy as np

import pypdf

from mlx_lm import generate, load

from mlx.core import array

class EnhancedDocumentQA:
    """Question answering over one PDF: lexical retrieval plus an MLX language model.

    The document is split into chunks. Every chunk, and every question, is
    mapped to an L2-normalised bag-of-tokens vector built with the model's
    tokenizer (the language model itself is not used for retrieval). The
    chunks with the highest cosine similarity to the question become the
    context of the prompt that is sent to the model.
    """

    # Zero-width split points: before a line that starts with a capital letter
    # and contains no lowercase letters (a header-like line), or before a blank
    # line (a paragraph break).
    _SECTION_BREAK = re.compile(r'(?=\n\s*[A-Z][^a-z]*\n)|(?=\n\n)')

    def __init__(self, model_path):
        """Load the MLX model and tokenizer.

        Args:
            model_path (str): Model location, passed straight to ``mlx_lm.load``.
        """
        print(f"Loading model from {model_path}, please wait...")
        self.model, self.tokenizer = load(model_path)
        print("Model loaded successfully!")
        self.document_chunks = []
        self.chunk_embeddings = []

    def load_document(self, pdf_path):
        """Read a PDF, split it into chunks and build one vector per chunk.

        Args:
            pdf_path (str): Path to the PDF file.

        Returns:
            bool: True on success; False when no usable text could be extracted.
        """
        print(f"Reading document: {pdf_path}")
        document_text = self._extract_text_from_pdf(pdf_path)

        # A scanned/image-only PDF yields only whitespace: treat it as a failure
        # instead of "successfully" loading zero chunks.
        if not document_text or not document_text.strip():
            print("Failed to extract text from the document.")
            return False

        self.document_chunks = self._split_into_semantic_chunks(document_text)
        print(f"Document split into {len(self.document_chunks)} chunks")

        self._create_embeddings()
        return True

    def answer_question(self, question, top_k=3, max_tokens=500,
                        temperature=0.2, top_p=0.9):
        """Answer *question* using only the most relevant document chunks.

        Args:
            question (str): The user's question.
            top_k (int): Number of chunks to include as context.
            max_tokens (int): Generation length limit.
            temperature (float): Sampling temperature.
            top_p (float): Nucleus-sampling threshold.

        Returns:
            str | None: The model's answer, or None if no document is loaded.
        """
        if not self.document_chunks:
            print("No document loaded. Please load a document first.")
            return None

        question_embedding = self._embed_text(question)
        relevant_chunks = self._find_relevant_chunks(question_embedding, top_k=top_k)
        context = "\n\n".join(self.document_chunks[idx] for idx in relevant_chunks)

        prompt = f"""Answer the question based ONLY on the following context:

Context:

{context}

Question: {question}

Answer:"""

        answer = self._generate(prompt, max_tokens=max_tokens,
                                temperature=temperature, top_p=top_p)
        return answer.strip()

    def _generate(self, prompt, max_tokens, temperature, top_p):
        """Run ``mlx_lm.generate`` on *prompt* and return the generated text.

        ``mlx_lm.generate`` takes the prompt string plus keyword options and
        returns the decoded completion, so no manual token slicing is needed.
        NOTE(review): assumes newer mlx-lm releases configure sampling through
        ``sample_utils.make_sampler`` and older ones accept ``temp``/``top_p``
        keywords directly; verify against the installed mlx-lm version.
        """
        try:
            from mlx_lm.sample_utils import make_sampler
        except ImportError:
            return generate(self.model, self.tokenizer, prompt=prompt,
                            max_tokens=max_tokens, temp=temperature, top_p=top_p)
        return generate(self.model, self.tokenizer, prompt=prompt,
                        max_tokens=max_tokens,
                        sampler=make_sampler(temp=temperature, top_p=top_p))

    def _extract_text_from_pdf(self, pdf_path):
        """Return the PDF's text (one newline after each page), or None on error."""
        try:
            with open(pdf_path, "rb") as file:
                reader = pypdf.PdfReader(file)
                # extract_text() may yield None/"" for pages without a text layer.
                return "".join((page.extract_text() or "") + "\n" for page in reader.pages)
        except Exception as e:
            print(f"Error extracting text from PDF: {e}")
            return None

    def _split_into_semantic_chunks(self, text, max_chunk_size=1500):
        """Group header/paragraph sections into chunks of about *max_chunk_size* chars.

        Sections are joined with a blank line until adding the next one would
        exceed *max_chunk_size*. A single section longer than the limit is kept
        whole as its own (oversized) chunk.
        """
        chunks = []
        current_chunk = ""

        for raw_section in self._SECTION_BREAK.split(text):
            section = raw_section.strip()
            if not section:
                continue

            if len(current_chunk) + len(section) > max_chunk_size:
                if current_chunk:
                    chunks.append(current_chunk)
                current_chunk = section
            elif current_chunk:
                current_chunk += "\n\n" + section
            else:
                current_chunk = section

        if current_chunk:
            chunks.append(current_chunk)

        return chunks

    def _create_embeddings(self):
        """Compute and store one vector per chunk, reporting progress every 10 chunks."""
        print("Creating embeddings for document chunks...")
        self.chunk_embeddings = []
        total = len(self.document_chunks)
        for count, chunk in enumerate(self.document_chunks, start=1):
            self.chunk_embeddings.append(self._embed_text(chunk))
            if count % 10 == 0:
                print(f"Processed {count}/{total} chunks")

    def _embed_text(self, text):
        """Return an L2-normalised token-count vector of length ``vocab_size``.

        This is a lexical representation built from the tokenizer alone.
        Special tokens (e.g. BOS) are excluded: otherwise every text shares
        them, which adds a spurious overlap that favours short chunks.
        Token ids outside ``[0, vocab_size)`` are ignored. float32 halves the
        memory of these dense, vocabulary-sized vectors.
        """
        token_ids = np.asarray(
            self.tokenizer.encode(text, add_special_tokens=False), dtype=np.int64
        )
        vocab_size = self.tokenizer.vocab_size
        embedding = np.bincount(token_ids, minlength=vocab_size)[:vocab_size].astype(np.float32)

        norm = np.linalg.norm(embedding)
        if norm > 0:
            embedding /= norm
        return embedding

    def _find_relevant_chunks(self, query_embedding, top_k=3):
        """Return indices of the *top_k* chunks most similar to *query_embedding*.

        Vectors are unit-length, so the dot product is the cosine similarity.
        Ties are broken in document order; ``top_k <= 0`` yields an empty list.
        """
        if not self.chunk_embeddings or top_k <= 0:
            return []
        similarities = np.array(
            [np.dot(query_embedding, chunk_embedding) for chunk_embedding in self.chunk_embeddings]
        )
        return np.argsort(-similarities, kind="stable")[:top_k].tolist()

def main():
    """Parse CLI arguments, index the PDF, then answer questions interactively.

    Exits with status 1 if the PDF is missing or yields no text. The loop ends
    on "exit"/"quit", on end of input (Ctrl-D / piped input), or on Ctrl-C.
    """
    parser = argparse.ArgumentParser(description="Enhanced document Q&A system with vector search")
    parser.add_argument("--model", type=str, default="models/gemma-2b-it-4bit",
                        help="Path to the model directory")
    parser.add_argument("--pdf", type=str, required=True,
                        help="Path to the PDF document")
    args = parser.parse_args()

    if not os.path.exists(args.pdf):
        print(f"Error: PDF file not found at {args.pdf}")
        sys.exit(1)

    qa_system = EnhancedDocumentQA(args.model)

    if not qa_system.load_document(args.pdf):
        print("Failed to load document.")
        sys.exit(1)

    print("\nEnhanced Document Q&A System (type 'exit' to quit)")

    while True:
        try:
            question = input("\nYour question: ").strip()
        except (EOFError, KeyboardInterrupt):
            # End of input or Ctrl-C: leave cleanly instead of with a traceback.
            print()
            break

        if question.lower() in ["exit", "quit"]:
            break
        if not question:
            # Don't spend a generation on an empty prompt.
            continue

        print("\nSearching document and generating answer...")
        answer = qa_system.answer_question(question)
        print(f"\nAnswer: {answer}")


if __name__ == "__main__":
    main()

```

This enhanced document Q&A system demonstrates:

Use Case 2: Computer Vision

For computer vision tasks, your decision depends on the specific application, model availability, and performance requirements.

When to Choose MLX for Computer Vision

Recommended when:

⚠️ Considerations:

When to Choose PyTorch for Computer Vision

Recommended when:

⚠️ Considerations:

Computer Vision Example: Image Classification App

Let's create a practical image classification application using PyTorch with Metal acceleration. This example demonstrates how to load an efficient pre-trained model and use it for real-time classification:

#!/usr/bin/env python3

"""

Real-time image classification using PyTorch with Metal acceleration.

"""

import os

import time

import argparse

import torch

import torchvision.transforms as transforms

import torchvision.models as models

from PIL import Image

class ImageClassifier:
    """Top-k ImageNet classifier on torchvision models, using Metal (MPS) when available."""

    def __init__(self, model_name="mobilenet_v3_small",
                 labels_path="data/imagenet_classes.txt"):
        """Pick a device, load pretrained weights, class labels and preprocessing.

        Args:
            model_name (str): One of "mobilenet_v3_small", "resnet18", "efficientnet_b0".
            labels_path (str): Text file with one class label per line, in class-index order.

        Raises:
            ValueError: If *model_name* is not supported.
        """
        # Check if Metal is available
        if torch.backends.mps.is_available():
            self.device = torch.device("mps")
            print("Using Metal Performance Shaders (MPS)")
        else:
            self.device = torch.device("cpu")
            print("Metal not available, using CPU")

        print(f"Loading {model_name} model...")
        # Built lazily here (not at class level) so the torchvision constructors
        # are only looked up when a classifier is actually created.
        factories = {
            "mobilenet_v3_small": models.mobilenet_v3_small,
            "resnet18": models.resnet18,
            "efficientnet_b0": models.efficientnet_b0,
        }
        if model_name not in factories:
            raise ValueError(f"Unsupported model: {model_name}")

        self.model = factories[model_name](weights="DEFAULT").to(self.device)
        self.model.eval()

        with open(labels_path, "r", encoding="utf-8") as f:
            self.class_names = [line.strip() for line in f]

        # Resize(256), 224 centre crop, tensor conversion, per-channel normalisation.
        self.transform = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225]
            )
        ])

    def classify_image(self, image_path, top_k=5):
        """Classify the image at *image_path*.

        Returns:
            tuple: ``(results, inference_seconds)`` where *results* is a list of
            ``(class_name, probability)`` pairs, best first. If the image cannot
            be loaded, returns ``(None, 0.0)`` so callers can always unpack two
            values.
        """
        try:
            # Context manager closes the underlying file once pixels are converted.
            with Image.open(image_path) as img:
                image = img.convert("RGB")
        except Exception as e:
            print(f"Error loading image: {e}")
            return None, 0.0

        input_tensor = self.transform(image).unsqueeze(0).to(self.device)
        return self._predict(input_tensor, top_k=top_k)

    def _predict(self, input_tensor, top_k=5):
        """Run the model on a preprocessed batch of one and return top-k labels and time."""
        self._synchronize()
        start_time = time.perf_counter()
        with torch.no_grad():
            output = self.model(input_tensor)
        # MPS executes asynchronously: wait for the GPU so the measured time
        # covers the computation rather than only the kernel launch.
        self._synchronize()
        inference_time = time.perf_counter() - start_time

        probabilities = torch.nn.functional.softmax(output[0], dim=0)
        k = max(0, min(top_k, probabilities.numel()))
        top_prob, top_indices = torch.topk(probabilities, k)

        results = [
            (self.class_names[idx], prob)
            for prob, idx in zip(top_prob.cpu().tolist(), top_indices.cpu().tolist())
        ]
        return results, inference_time

    def _synchronize(self):
        """Block until queued MPS work finishes (no-op on other devices)."""
        if self.device.type == "mps":
            torch.mps.synchronize()

def ensure_imagenet_labels(labels_file="data/imagenet_classes.txt"):
    """Make sure the ImageNet class-label file exists, downloading it if needed.

    The download goes to a temporary ``.part`` file that is renamed into place
    only once complete. An interrupted download therefore never leaves a
    truncated labels file that later runs would silently reuse.

    Args:
        labels_file (str): Where the labels file should live.

    Returns:
        str: *labels_file*.
    """
    directory = os.path.dirname(labels_file)
    if directory:
        os.makedirs(directory, exist_ok=True)

    if os.path.exists(labels_file):
        return labels_file

    import urllib.request

    print("Downloading ImageNet class labels...")
    # Labels file used by the PyTorch Hub image-classification examples.
    url = "https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt"
    partial_file = labels_file + ".part"
    try:
        urllib.request.urlretrieve(url, partial_file)
        os.replace(partial_file, labels_file)
    finally:
        if os.path.exists(partial_file):
            os.remove(partial_file)
    return labels_file

def main():
    """Classify one image from the command line and print the top predictions.

    Exits with status 1 if the image is missing or cannot be classified.
    """
    parser = argparse.ArgumentParser(description="Real-time image classification using PyTorch with Metal")
    parser.add_argument("--model", type=str, default="mobilenet_v3_small",
                        choices=["mobilenet_v3_small", "resnet18", "efficientnet_b0"],
                        help="Model architecture to use")
    parser.add_argument("--image", type=str, required=True,
                        help="Path to image for classification")
    args = parser.parse_args()

    # Fail fast, before spending time loading model weights.
    if not os.path.exists(args.image):
        print(f"Error: image not found at {args.image}")
        raise SystemExit(1)

    ensure_imagenet_labels()
    classifier = ImageClassifier(args.model)

    print(f"\nClassifying image: {args.image}")
    outcome = classifier.classify_image(args.image)

    # Treat a None return, or an empty/None result list, as a failure.
    results, inference_time = outcome if outcome else (None, None)
    if not results:
        print("Failed to classify image.")
        raise SystemExit(1)

    print(f"\nResults (inference time: {inference_time*1000:.2f} ms):")
    for rank, (class_name, probability) in enumerate(results, start=1):
        print(f"{rank}. {class_name}: {probability*100:.2f}%")


if __name__ == "__main__":
    main()

This example demonstrates:

Use Case 3: Custom ML Models

For research projects and custom model development, your choice depends on your specific requirements and Mac configuration.

When to Choose MLX for Custom Models

Recommended when:

⚠️ Considerations:

When to Choose PyTorch for Custom Models

Recommended when:

⚠️ Considerations:

Mac Studio/Pro Training Advantage

One of the most compelling use cases for high-end Mac Studio and Mac Pro systems is training custom models:

While these systems still can't match dedicated training clusters, they provide a surprisingly powerful local alternative for medium-sized model training and fine-tuning.

Custom Model Example: Time Series Forecasting

Let's create a simple time series forecasting model for stock price prediction using MLX. This demonstrates how to build a custom recurrent neural network:

#!/usr/bin/env python3

"""

Time series forecasting with MLX.

This script demonstrates how to build a custom recurrent neural network

for time series forecasting using Apple's MLX framework, optimized for

Apple Silicon hardware.

"""

import os

import argparse

import csv

import numpy as np

import matplotlib.pyplot as plt

import mlx.core as mx

import mlx.nn as nn

import mlx.optimizers as optim

class SimpleRNN(nn.Module):
    """
    Simple recurrent neural network for time series forecasting.

    A single Elman recurrent layer reads the whole input sequence; the hidden
    state after the last time step is fed to a linear layer that predicts the
    next value.
    """

    def __init__(self, input_size, hidden_size, output_size):
        """
        Initialize the RNN model.

        Args:
            input_size (int): Size of input features (typically 1 for univariate time series)
            hidden_size (int): Size of the hidden state
            output_size (int): Size of the output (typically 1 for single-value prediction)
        """
        super().__init__()
        self.hidden_size = hidden_size
        # mlx.nn.RNN processes the full sequence and returns every hidden state.
        self.rnn = nn.RNN(input_size, hidden_size)
        self.linear = nn.Linear(hidden_size, output_size)

    def __call__(self, x, hidden=None):
        """
        Forward pass through the network.

        Args:
            x (mx.array): Input of shape (batch_size, sequence_length, input_size)
            hidden (mx.array, optional): Initial hidden state of shape
                (batch_size, hidden_size); zeros when omitted.

        Returns:
            tuple: (output, hidden_state) where output has shape
            (batch_size, output_size) and hidden_state is the final hidden
            state, shape (batch_size, hidden_size).
        """
        hidden_states = self.rnn(x, hidden)      # (batch, seq_len, hidden_size)
        last_hidden = hidden_states[:, -1, :]    # state after the final step
        output = self.linear(last_hidden)
        return output, last_hidden

def load_stock_data(csv_file, target_column="Close", date_column="Date"):
    """
    Load stock price data from CSV.

    Rows whose target value is missing or not a finite number (for example
    the "null" rows some exporters emit for non-trading days) are skipped,
    so *dates* and *prices* always stay aligned.

    Args:
        csv_file (str): Path to the CSV file
        target_column (str): Column name for the target price data
        date_column (str): Column name for the row date

    Returns:
        tuple: (dates, prices) as lists

    Raises:
        KeyError: If the date or target column is not in the CSV header.
    """
    dates = []
    prices = []

    # newline="" is what the csv module requires; utf-8-sig strips a leading
    # BOM (common in spreadsheet exports) that would otherwise corrupt "Date".
    with open(csv_file, "r", newline="", encoding="utf-8-sig") as f:
        reader = csv.DictReader(f)
        fieldnames = reader.fieldnames or []
        for column in (date_column, target_column):
            if column not in fieldnames:
                raise KeyError(
                    f"Column {column!r} not found in {csv_file}; available columns: {fieldnames}"
                )

        for row in reader:
            raw_value = (row[target_column] or "").strip()
            try:
                value = float(raw_value)
            except ValueError:
                continue
            if not np.isfinite(value):
                continue
            dates.append(row[date_column])
            prices.append(value)

    return dates, prices

def prepare_timeseries_data(data, sequence_length=10):
    """
    Prepare time series data for training.

    Builds sliding windows ``X[i] = data[i:i + sequence_length]`` with target
    ``y[i] = data[i + sequence_length]``, then z-normalises X and y with
    their own (global) mean and standard deviation.

    Args:
        data (list): List of price values
        sequence_length (int): Length of input sequences

    Returns:
        tuple: (X, y, X_mean, X_std, y_mean, y_std). X has shape
        (num_samples, sequence_length, 1), y has shape (num_samples, 1), and
        the statistics are scalar MLX arrays.

    Raises:
        ValueError: If *sequence_length* < 1 or *data* has no more than
            *sequence_length* points (no training sample can be formed).
    """
    if sequence_length < 1:
        raise ValueError("sequence_length must be at least 1")
    if len(data) <= sequence_length:
        raise ValueError(
            f"Need more than {sequence_length} data points to build a training sample; got {len(data)}"
        )

    values = np.asarray(data, dtype=np.float32)
    num_samples = len(values) - sequence_length
    X_np = np.stack(
        [values[i:i + sequence_length] for i in range(num_samples)]
    ).reshape(-1, sequence_length, 1)
    y_np = values[sequence_length:].reshape(-1, 1)

    X_mean, X_std = float(X_np.mean()), float(X_np.std())
    y_mean, y_std = float(y_np.mean()), float(y_np.std())

    # A constant series has zero spread; use 1 so normalisation does not
    # divide by zero and fill the arrays with NaN.
    if X_std == 0:
        X_std = 1.0
    if y_std == 0:
        y_std = 1.0

    X = mx.array(((X_np - X_mean) / X_std).astype(np.float32))
    y = mx.array(((y_np - y_mean) / y_std).astype(np.float32))

    return X, y, mx.array(X_mean), mx.array(X_std), mx.array(y_mean), mx.array(y_std)

def train_model(model, X, y, epochs=100, batch_size=32, learning_rate=0.01):
    """
    Train the model with Adam on mean-squared error.

    Every epoch visits all samples in a fresh random order. The final batch
    may be smaller than *batch_size*, so no sample is dropped and datasets
    smaller than one batch still train.

    Args:
        model (SimpleRNN): The model to train
        X (mx.array): Input data
        y (mx.array): Target data
        epochs (int): Number of training epochs
        batch_size (int): Batch size for training
        learning_rate (float): Learning rate for optimizer

    Returns:
        list: Mean batch loss per epoch

    Raises:
        ValueError: If X is empty or *batch_size* < 1.
    """
    num_samples = X.shape[0]
    if num_samples == 0:
        raise ValueError("Cannot train on an empty dataset")
    if batch_size < 1:
        raise ValueError("batch_size must be at least 1")

    optimizer = optim.Adam(learning_rate=learning_rate)

    def loss_fn(model, X_batch, y_batch):
        pred, _ = model(X_batch)
        return mx.mean(mx.square(pred - y_batch))

    # nn.value_and_grad differentiates loss_fn w.r.t. the model's trainable parameters.
    loss_and_grad_fn = nn.value_and_grad(model, loss_fn)

    losses = []
    for epoch in range(epochs):
        permutation = mx.array(np.random.permutation(num_samples))
        batch_losses = []

        for start in range(0, num_samples, batch_size):
            batch_indices = permutation[start:start + batch_size]
            loss, grads = loss_and_grad_fn(model, X[batch_indices], y[batch_indices])
            optimizer.update(model, grads)
            # MLX is lazy: evaluate the updated parameters and optimizer state
            # every step so the computation graph does not keep growing.
            mx.eval(model.parameters(), optimizer.state)
            batch_losses.append(loss.item())

        epoch_loss = sum(batch_losses) / len(batch_losses)
        losses.append(epoch_loss)

        if (epoch + 1) % 10 == 0:
            print(f"Epoch {epoch + 1}/{epochs}, Loss: {epoch_loss:.6f}")

    return losses

def main():
    """
    Train the RNN on a CSV price series and save/print a multi-step forecast.
    """
    parser = argparse.ArgumentParser(description="Time series forecasting with MLX")
    parser.add_argument("--csv", type=str, required=True,
                        help="Path to CSV file with stock price data")
    parser.add_argument("--column", type=str, default="Close",
                        help="Column name for price data")
    parser.add_argument("--epochs", type=int, default=100,
                        help="Number of training epochs")
    parser.add_argument("--sequence", type=int, default=10,
                        help="Sequence length for time series")
    parser.add_argument("--predict", type=int, default=30,
                        help="Number of days to predict")
    args = parser.parse_args()

    # Load and prepare data
    _, prices = load_stock_data(args.csv, args.column)
    X, y, X_mean, X_std, y_mean, y_std = prepare_timeseries_data(prices, args.sequence)

    # Create and train model
    model = SimpleRNN(input_size=1, hidden_size=32, output_size=1)
    train_model(model, X, y, epochs=args.epochs)

    # Normalisation constants as plain floats for the feedback loop.
    x_mean, x_std = X_mean.item(), X_std.item()
    target_mean, target_std = y_mean.item(), y_std.item()

    # Recursive forecast: predict the next value, slide it into the window
    # (normalised with the X statistics the model was trained on), repeat.
    window = [(price - x_mean) / x_std for price in prices[-args.sequence:]]
    predictions = []
    for _ in range(args.predict):
        batch = mx.array(np.asarray(window, dtype=np.float32).reshape(1, args.sequence, 1))
        # Always start from a zero hidden state, exactly as in training. The
        # window already holds the history; carrying the previous step's
        # state over would count the overlapping values twice.
        pred, _ = model(batch)
        pred_value = pred.item() * target_std + target_mean
        predictions.append(pred_value)
        window = window[1:] + [(pred_value - x_mean) / x_std]

    # Plot results
    plt.figure(figsize=(12, 6))
    plt.plot(range(len(prices)), prices, label="Historical Data")

    last_index = len(prices) - 1
    # Prediction k (1-based) belongs to index last_index + k. The last observed
    # point is prepended so the forecast line starts where the history ends.
    plt.plot(range(last_index, last_index + args.predict + 1),
             [prices[-1]] + predictions, label="Forecast", color="red")
    plt.axvline(x=last_index, color="gray", linestyle="--")

    plt.title(f"Stock Price Forecast ({args.column})")
    plt.xlabel("Time")
    plt.ylabel("Price")
    plt.legend()
    plt.tight_layout()

    # splitext keeps dotted names like "aapl.2024.csv" intact ("aapl.2024").
    stem = os.path.splitext(os.path.basename(args.csv))[0]
    output_file = f"stock_forecast_{stem}.png"
    plt.savefig(output_file)
    plt.close()
    print(f"Forecast saved to {output_file}")

    print("\nPredictions for the next", args.predict, "days:")
    for day, value in enumerate(predictions, start=1):
        print(f"Day {day}: {value:.2f}")


if __name__ == "__main__":
    main()

This example demonstrates:

Building a Comprehensive Project

Now, let's combine our knowledge to build a more comprehensive project that demonstrates how to choose between MLX and PyTorch based on the specific task. This project will include:

Comprehensive AI Project Architecture

*Figure 3: System architecture diagram showing how text summarization (MLX) and image captioning (PyTorch) components work together in our sample application, including data flow, API endpoints, and user interface components.*

  1. Text summarization (using MLX)
  2. Image captioning (using PyTorch)
  3. A simple web interface to interact with both models

Project Structure

Create the following files and directories:

ai-with-mac/

├── models/

├── data/

├── static/

│ └── css/

│ └── style.css

├── templates/

│ ├── index.html

│ └── result.html

├── scripts/

│ ├── summarize.py

│ ├── caption.py

│ └── server.py

└── requirements.txt

1. Text Summarization with MLX

Create the file scripts/summarize.py:

#!/usr/bin/env python3

"""

Text summarization using MLX.

This script provides a text summarization class that uses Apple's MLX framework

to run language models efficiently on Apple Silicon for generating concise summaries.

"""

import os

import mlx.core as mx

from mlx_lm import generate, load

class Summarizer:
    """
    Text summarizer using MLX.

    Loads a language model with ``mlx_lm.load`` and prompts it to produce
    concise summaries of input text.
    """

    def __init__(self, model_path="models/gemma-2b-it-4bit"):
        """
        Initialize the summarizer with an MLX model.

        If *model_path* does not exist, instructions for creating it are
        printed and the instance is left without a model. ``summarize`` then
        raises ``RuntimeError`` with a clear message instead of an
        AttributeError.

        Args:
            model_path (str): Path to the directory containing the model files
        """
        self.model_path = model_path
        self.model = None
        self.tokenizer = None

        # Ensure the parent directory exists. A bare name such as "my-model"
        # has no directory part, and os.makedirs("") would raise.
        parent_dir = os.path.dirname(model_path)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)

        if not os.path.exists(model_path):
            print(f"Model not found at {model_path}. Please download it first using:")
            print(f"python -m mlx_lm.convert --hf-path google/gemma-2b-it -q --mlx-path {model_path}")
            return

        print(f"Loading summarization model from {model_path}...")
        self.model, self.tokenizer = load(model_path)
        print("Summarization model loaded successfully!")

    def summarize(self, text, max_length=200, temperature=0.3, top_p=0.9):
        """
        Summarize the given text.

        Args:
            text (str): Text to summarize
            max_length (int): Maximum length of the summary in tokens
            temperature (float): Temperature parameter for generation (0.0-1.0)
            top_p (float): Nucleus-sampling threshold

        Returns:
            str: Generated summary

        Raises:
            RuntimeError: If no model was loaded at construction time.
        """
        if self.model is None:
            raise RuntimeError(
                f"No summarization model is loaded (expected one at {self.model_path!r})."
            )

        prompt = f"""Please summarize the following text concisely:

Text: {text}

Summary:"""

        return self._generate(prompt, max_length, temperature, top_p).strip()

    def _generate(self, prompt, max_tokens, temperature, top_p):
        """Run ``mlx_lm.generate`` on *prompt* and return the generated text.

        ``mlx_lm.generate`` takes the prompt string plus keyword options and
        returns the decoded completion, so no manual token slicing is needed.
        NOTE(review): assumes newer mlx-lm releases configure sampling through
        ``sample_utils.make_sampler`` and older ones accept ``temp``/``top_p``
        keywords directly; verify against the installed mlx-lm version.
        """
        try:
            from mlx_lm.sample_utils import make_sampler
        except ImportError:
            return generate(self.model, self.tokenizer, prompt=prompt,
                            max_tokens=max_tokens, temp=temperature, top_p=top_p)
        return generate(self.model, self.tokenizer, prompt=prompt,
                        max_tokens=max_tokens,
                        sampler=make_sampler(temp=temperature, top_p=top_p))

# Example usage: summarize a short built-in passage when run as a script.
if __name__ == "__main__":
    demo_summarizer = Summarizer()

    # Example text
    sample_text = """

Apple Silicon has fundamentally changed what's possible for machine learning on

consumer-grade hardware. With the M1, M2, M3, and now M4 chips, Mac users can run

sophisticated AI models locally with impressive performance. This unified memory

architecture eliminates the traditional bottleneck of data transfers between CPU and

GPU memory, allowing both processors to access the same physical memory seamlessly.

Whether you're a data scientist, AI researcher, app developer, or just an enthusiast

exploring machine learning, understanding your options on Apple Silicon is crucial for

maximizing performance and efficiency.

"""

    # The summary is computed before anything is printed, then the header
    # and the summary go out on separate lines.
    print("\nSummary:", demo_summarizer.summarize(sample_text), sep="\n")

2. Image Captioning with PyTorch

Create the file scripts/caption.py:

#!/usr/bin/env python3

"""

Image captioning using PyTorch with Metal acceleration.

This script provides an image captioning class that uses PyTorch with

Apple's Metal API to efficiently run image captioning models on Apple Silicon.

"""

import os

import torch

import torchvision.transforms as transforms

from PIL import Image

class ImageCaptioner:
    """
    Image captioner using PyTorch with Metal acceleration.

    Loads a pre-trained captioning model through torch.hub and runs it on the
    "mps" device when torch reports Metal support, otherwise on the CPU.
    """

    def __init__(self):
        """
        Initialize the captioner with a pre-trained model.

        Selects the device, loads the 'v3' entry point of the
        'saahiluppal/catr' torch.hub repository, switches it to eval mode,
        and builds the image preprocessing pipeline.
        """
        # Check if Metal is available
        if torch.backends.mps.is_available():
            self.device = torch.device("mps")
            print("Using Metal Performance Shaders (MPS)")
        else:
            self.device = torch.device("cpu")
            print("Metal not available, using CPU")

        # Load pre-trained model.
        # NOTE(review): torch.hub presumably fetches the repository on first
        # use, so the first run likely needs network access; confirm.
        print("Loading image captioning model...")
        self.model = torch.hub.load('saahiluppal/catr', 'v3', pretrained=True).to(self.device)
        self.model.eval()

        # Set up image transformation: Resize(256), 224 centre crop, tensor
        # conversion and per-channel normalisation with the constants below.
        self.transform = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225]
            )
        ])

    def generate_caption(self, image_path):
        """
        Generate a caption for the image.

        Args:
            image_path (str): Path to the image file

        Returns:
            str: Generated caption, or the literal string "Error loading image"
            if the file cannot be opened/converted to RGB.
        """
        # Load and transform image
        try:
            image = Image.open(image_path).convert("RGB")
        except Exception as e:
            print(f"Error loading image: {e}")
            return "Error loading image"

        # Add a leading batch dimension of 1 and move to the selected device.
        input_tensor = self.transform(image).unsqueeze(0).to(self.device)

        # Generate caption
        # NOTE(review): this call pattern looks unverified. CATR's own
        # predict.py appears to decode autoregressively, calling
        # model(image, caption, caption_mask) step by step and turning ids
        # into text with a BERT tokenizer. There seems to be no
        # `caption_generator` attribute on the hub model. Confirm against the
        # saahiluppal/catr repository before relying on this method.
        with torch.no_grad():
            output = self.model(input_tensor)
            caption = self.model.caption_generator.decode(output[0])

        return caption

# Example usage: caption data/sample.jpg when run as a script.
if __name__ == "__main__":
    demo_captioner = ImageCaptioner()

    # Example image
    sample_image = "data/sample.jpg"

    if not os.path.exists(sample_image):
        print(f"Image not found at {sample_image}")
    else:
        # The caption is generated first, then the header and caption are printed.
        print("\nCaption:", demo_captioner.generate_caption(sample_image), sep="\n")

# --- end of scripts/caption.py ---

### 3. Simple Web Server

Create the file `scripts/server.py`:

```python

#!/usr/bin/env python3

"""

Simple web server for AI with Mac demonstration.

"""

import os

import uuid

from flask import Flask, render_template, request, redirect, url_for

from summarize import Summarizer

from caption import ImageCaptioner

app = Flask(
    __name__,
    static_folder="../static",
    template_folder="../templates",
)

# Both models are built once, at import time, and shared by every request
# (the summarizer is constructed first, then the captioner).
summarizer, captioner = Summarizer(), ImageCaptioner()

# Create ../uploads up front; exist_ok makes this a no-op on later restarts.
os.makedirs("../uploads", exist_ok=True)

@app.route("/")
def index():
    """Serve the landing page that hosts the summarization and captioning forms."""
    landing_template = "index.html"
    return render_template(landing_template)

@app.route("/summarize", methods=["POST"])
def summarize_text():
    """Summarize the submitted form text with the MLX summarizer.

    Re-renders the index page with an error when the text is missing or
    contains only whitespace; otherwise renders result.html with the summary.
    """
    text = request.form.get("text", "")

    # Whitespace-only input would otherwise reach the model as an empty prompt.
    if not text.strip():
        return render_template("index.html", error="Please enter some text to summarize.")

    summary = summarizer.summarize(text)

    return render_template("result.html",
                           result_type="summary",
                           original=text,
                           result=summary,
                           framework="MLX")

@app.route("/caption", methods=["POST"])
def caption_image():
    """Caption an image uploaded from the index page form.

    The upload is stored as ``<static folder>/uploads/<uuid><ext>`` so the
    result page can display it through Flask's ``/static`` route, then passed
    to the module-level captioner. Only common raster-image extensions are
    accepted; anything else re-renders ``index.html`` with an error.
    """
    # Files saved here are served back from /static, so only accept raster
    # image types; HTML/SVG/JS uploads would otherwise become same-origin
    # content.
    allowed_extensions = {".jpg", ".jpeg", ".png", ".gif", ".bmp", ".webp"}

    if "image" not in request.files:
        return render_template("index.html", error="No image uploaded.")

    upload = request.files["image"]
    if upload.filename == "":
        return render_template("index.html", error="No image selected.")

    extension = os.path.splitext(upload.filename)[1].lower()
    if extension not in allowed_extensions:
        return render_template(
            "index.html",
            error="Unsupported file type. Please upload a JPEG, PNG, GIF, BMP or WebP image.",
        )

    # Flask only serves files from its static folder, so an image saved
    # elsewhere (e.g. a sibling "uploads" directory) cannot be shown on the
    # result page. app.static_folder is an absolute path resolved against the
    # app root, so this also does not depend on the current working directory.
    # NOTE: nothing in this handler deletes old uploads; they accumulate.
    upload_dir = os.path.join(app.static_folder, "uploads")
    os.makedirs(upload_dir, exist_ok=True)

    # Random name: never trust the client-supplied filename for the path.
    filename = f"{uuid.uuid4()}{extension}"
    filepath = os.path.join(upload_dir, filename)
    upload.save(filepath)

    caption = captioner.generate_caption(filepath)

    return render_template(
        "result.html",
        result_type="caption",
        image_path=url_for("static", filename=f"uploads/{filename}"),
        result=caption,
        framework="PyTorch with Metal",
    )

if __name__ == "__main__":
    # debug=True enables the interactive Werkzeug debugger, which can execute
    # arbitrary code, so keep this server on localhost (Flask's default host).
    # The auto-reloader (on by default in debug mode) re-runs this module in a
    # child process, which would construct both models a second time and
    # double their memory use, so it is switched off.
    app.run(debug=True, use_reloader=False)

```

### 4. HTML Templates

Create the file `templates/index.html`:

<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>AI with Mac Demonstration</title>
    <!-- url_for yields an absolute /static/... URL, so the stylesheet loads
         correctly whichever route renders this template (/, /summarize, /caption). -->
    <link rel="stylesheet" href="{{ url_for('static', filename='css/style.css') }}">
</head>
<body>
    <header>
        <h1>AI with Mac Demonstration</h1>
        <p>Example project combining MLX and PyTorch</p>
    </header>
    <main>
        {# `error` is passed by the server when a submission is rejected #}
        {% if error %}
        <div class="error" role="alert">{{ error }}</div>
        {% endif %}
        <div class="cards">
            <div class="card">
                <h2>Text Summarization</h2>
                <p>Using MLX for efficient inference</p>
                <form action="/summarize" method="post">
                    <textarea name="text" rows="10" placeholder="Enter text to summarize..."
                              aria-label="Text to summarize" required></textarea>
                    <button type="submit">Summarize</button>
                </form>
            </div>
            <div class="card">
                <h2>Image Captioning</h2>
                <p>Using PyTorch with Metal acceleration</p>
                <form action="/caption" method="post" enctype="multipart/form-data">
                    <div class="file-upload">
                        <input type="file" name="image" accept="image/*"
                               aria-label="Image to caption" required>
                    </div>
                    <button type="submit">Generate Caption</button>
                </form>
            </div>
        </div>
    </main>
    <footer>
        <p>Part of the "AI with Mac" series</p>
    </footer>
</body>
</html>

Create the file `templates/result.html`:

<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>AI with Mac Demonstration - Result</title>
    <!-- Absolute /static/... URL so the stylesheet loads from any route;
         this page is rendered by both /summarize and /caption. -->
    <link rel="stylesheet" href="{{ url_for('static', filename='css/style.css') }}">
</head>
<body>
    <header>
        <h1>AI with Mac Demonstration</h1>
        <p>Example project combining MLX and PyTorch</p>
    </header>
    <main>
        <div class="result-container">
            <h2>{{ result_type|capitalize }} Result (using {{ framework }})</h2>
            {% if result_type == "summary" %}
            <div class="result-pair">
                <div class="original">
                    <h3>Original Text</h3>
                    <div class="content">{{ original }}</div>
                </div>
                <div class="result">
                    <h3>Summary</h3>
                    <div class="content">{{ result }}</div>
                </div>
            </div>
            {% elif result_type == "caption" %}
            <div class="result-image">
                {# image_path is the URL supplied by the /caption handler; the
                   generated caption doubles as the image's alt text #}
                <img src="{{ image_path }}" alt="{{ result }}">
                <div class="caption">{{ result }}</div>
            </div>
            {% endif %}
        </div>
        <div class="back-link">
            <a href="{{ url_for('index') }}">⬅ Back to demo</a>
        </div>
    </main>
    <footer>
        <p>Part of the "AI with Mac" series</p>
    </footer>
</body>
</html>

### 5. CSS Styling

Create the file `static/css/style.css`:

/* Basic styling for the AI with Mac demo */

body {
    font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, Helvetica, Arial, sans-serif;
    line-height: 1.6;
    color: #333;
    max-width: 1200px;
    margin: 0 auto;
    padding: 20px;
    background-color: #f5f5f7;
}

header {
    text-align: center;
    margin-bottom: 40px;
}

h1 {
    color: #000;
}

/* Feature cards sit side by side and wrap onto new rows on narrow screens. */
.cards {
    display: flex;
    gap: 20px;
    flex-wrap: wrap;
}

.card {
    flex: 1;
    min-width: 300px;
    background-color: white;
    border-radius: 12px;
    box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
    padding: 20px;
    margin-bottom: 20px;
}

textarea {
    /* border-box keeps padding and border inside the 100% width; with the
       default content-box the textarea overflows its card by 22px. */
    box-sizing: border-box;
    width: 100%;
    padding: 10px;
    border: 1px solid #ddd;
    border-radius: 6px;
    font-family: inherit;
    margin-bottom: 15px;
    /* Horizontal resizing would push the textarea out of the card again. */
    resize: vertical;
}

.file-upload {
    border: 2px dashed #ddd;
    padding: 20px;
    text-align: center;
    margin-bottom: 15px;
    border-radius: 6px;
}

button {
    background-color: #0071e3;
    color: white;
    border: none;
    padding: 10px 20px;
    border-radius: 6px;
    cursor: pointer;
    font-weight: 500;
}

button:hover {
    background-color: #0062c2;
}

.error {
    background-color: #ffebee;
    color: #c62828;
    padding: 10px;
    border-radius: 6px;
    margin-bottom: 20px;
}

.result-container {
    background-color: white;
    border-radius: 12px;
    box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
    padding: 20px;
    margin-bottom: 30px;
}

/* Original text and summary are shown side by side, wrapping when narrow. */
.result-pair {
    display: flex;
    gap: 20px;
    flex-wrap: wrap;
}

.original, .result {
    flex: 1;
    min-width: 300px;
}

.content {
    background-color: #f9f9f9;
    padding: 15px;
    border-radius: 6px;
    /* Preserve the line breaks of the submitted text and the summary. */
    white-space: pre-wrap;
}

.result-image {
    text-align: center;
}

.result-image img {
    max-width: 100%;
    max-height: 500px;
    border-radius: 8px;
    margin-bottom: 15px;
}

.caption {
    font-size: 1.2em;
    font-weight: 500;
    margin-top: 10px;
}

.back-link {
    margin-top: 20px;
}

.back-link a {
    color: #0071e3;
    text-decoration: none;
}

footer {
    text-align: center;
    margin-top: 40px;
    padding-top: 20px;
    border-top: 1px solid #ddd;
    color: #888;
}

### 6. Requirements

Create the file `requirements.txt`:

flask

mlx

mlx-lm

torch

torchvision

pillow

Making Informed Decisions

After exploring different approaches and building practical applications, here are some guidelines to help you make informed decisions about which framework to use for your ML projects on Apple Silicon:

Decision Checklist

Ask yourself these questions when starting a new ML project:

  1. What is my primary task?
    • LLM inference → MLX
    • Computer vision with pre-trained models → PyTorch
    • Custom research models → Either (depending on other factors)
  2. What are my hardware constraints?
    • Limited RAM → MLX (better memory efficiency)
    • Older Apple Silicon → Depends on specific task
    • Newest Apple Silicon → Either (both perform well)
  3. What is my deployment target?
    • Apple Silicon only → MLX
    • Cross-platform → PyTorch
    • Mix of platforms → Consider a hybrid approach
  4. What is my timeline?
    • Quick prototype → Use the framework you're most familiar with
    • Long-term project → Worth investing time in learning MLX if appropriate
  5. What is my team's expertise?
    • Strong PyTorch background → Start with PyTorch, explore MLX gradually
    • New to both → MLX has a simpler API for certain tasks

A Balanced Approach

For many projects, a balanced approach works best:

  1. Evaluate task-specific performance: Run benchmarks for your specific task
  2. Consider implementation effort: Weigh development time vs. runtime performance
  3. Think about future maintenance: Consider documentation and community support
  4. Start small: Begin with a proof of concept in both frameworks if feasible
  5. Be flexible: Be willing to switch frameworks if needs change

Conclusion

Throughout this series, we've explored the exciting possibilities that Apple Silicon brings to machine learning on Mac. We've set up development environments, run large language models locally, compared frameworks, and built practical applications.

The key takeaways from our journey:

  1. Apple Silicon has democratized AI: From MacBooks to the powerful Mac Studio and Mac Pro, sophisticated models can now run on Apple hardware without cloud dependencies
  2. High-end Mac configurations enable professional workloads: the Mac Studio (up to 512GB of unified memory with M3 Ultra) and the Mac Pro (up to 192GB) can run workloads that previously required specialized servers
  3. MLX and PyTorch offer different advantages: Each framework has strengths for different use cases
  4. Scale your approach to your hardware: Choose models and quantization based on available memory
  5. The right approach depends on your specific needs: Consider your task, hardware, and requirements
  6. Practical applications are now possible: From language models to computer vision, Apple Silicon supports diverse AI workloads
  7. This field is rapidly evolving: Both frameworks continue to improve and add capabilities

As you embark on your own AI projects on Apple Silicon, remember that there's no one-size-fits-all solution. The best approach is to understand your specific requirements, experiment with different options, and choose the tools that best fit your needs.

If you have a high-end Mac Studio or Mac Pro, you can also run very large open-weight models (70B+ parameters, usually quantized) or fine-tune custom models locally - capabilities previously limited to specialized infrastructure.

We hope this series has provided a solid foundation for your machine learning journey on Apple Silicon. The code examples and practical applications should serve as useful starting points for your own projects, whether you're using a MacBook Air or a maxed-out Mac Pro.

Happy coding, and enjoy exploring the possibilities of AI on your Mac!

This concludes our "AI with Mac" series. All code examples from this series are available in our GitHub repository.