Grand Diomande Research

CC-MotionGen Technical Documentation

> Version: 0.2.0
> Last Updated: December 2025
> Authors: Comp-Core ML Team

Executive Summary

CC-MotionGen is a diffusion-based model that generates temporally coherent motion trajectories conditioned on audio features. The system comprises a 116M-parameter UNet1D diffusion backbone, a 2M-parameter motion decoder, and a post-processing pipeline designed for real-time choreography synthesis.

Key Capabilities:
- Audio-synchronized motion generation at 30 fps
- 25-dimensional motion representation (position, velocity, acceleration, orientation quaternion, angular velocity, phase, style)
- End-to-end differentiable pipeline for temporal coherence
- Faster inference via DDIM sampling (20 steps instead of 1000 DDPM steps)

---

Table of Contents

1. [System Architecture](#1-system-architecture)
2. [Motion Representation Format](#2-motion-representation-format)
3. [Model Components Deep Dive](#3-model-components-deep-dive)
4. [Audio Feature Extraction](#4-audio-feature-extraction)
5. [Diffusion Process Mathematics](#5-diffusion-process-mathematics)
6. [Training Pipeline](#6-training-pipeline)
7. [End-to-End Fine-tuning](#7-end-to-end-fine-tuning)
8. [Inference & Sampling](#8-inference--sampling)
9. [Post-Processing Pipeline](#9-post-processing-pipeline)
10. [Validation & Sanity Checks](#10-validation--sanity-checks)
11. [Configuration Reference](#11-configuration-reference)
12. [API Reference](#12-api-reference)
13. [Performance & Benchmarks](#13-performance--benchmarks)
14. [Troubleshooting Guide](#14-troubleshooting-guide)
15. [Known Issues & Roadmap](#15-known-issues--roadmap)

---

1. System Architecture

1.1 High-Level Architecture

┌─────────────────────────────────────────────────────────────────────────────┐
│                           CC-MotionGen Pipeline                              │
├─────────────────────────────────────────────────────────────────────────────┤
│                                                                              │
│  ┌──────────────┐    ┌─────────────────┐    ┌──────────────┐                │
│  │    Audio     │    │   Audio Feature │    │  Conditioning │                │
│  │   Waveform   │───>│    Extractor    │───>│    Encoder    │                │
│  │   (44.1kHz)  │    │   (librosa)     │    │   (Linear)    │                │
│  └──────────────┘    └─────────────────┘    └───────┬───────┘                │
│                                                      │                        │
│                                              ┌───────▼───────┐                │
│                                              │  Audio Cond   │                │
│                                              │  [B, 163, T]  │                │
│                                              └───────┬───────┘                │
│                                                      │                        │
│  ┌──────────────┐    ┌─────────────────┐    ┌───────▼───────┐                │
│  │   Gaussian   │    │                 │    │               │                │
│  │    Noise     │───>│   UNet1D        │<───│  Cross-Attn   │                │
│  │  [B, 25, T]  │    │   (116M)        │    │  Conditioning │                │
│  └──────────────┘    └────────┬────────┘    └───────────────┘                │
│                               │                                               │
│                       ┌───────▼───────┐                                       │
│                       │  Raw Latent   │                                       │
│                       │  [B, 25, T]   │                                       │
│                       └───────┬───────┘                                       │
│                               │                                               │
│                       ┌───────▼───────┐                                       │
│                       │MotionDecoder  │                                       │
│                       │    (2M)       │                                       │
│                       └───────┬───────┘                                       │
│                               │                                               │
│                       ┌───────▼───────┐                                       │
│                       │ PostProcessor │                                       │
│                       │  (Normalize)  │                                       │
│                       └───────┬───────┘                                       │
│                               │                                               │
│                       ┌───────▼───────┐                                       │
│                       │ Motion Output │                                       │
│                       │  [B, T, 25]   │                                       │
│                       └───────────────┘                                       │
│                                                                              │
└─────────────────────────────────────────────────────────────────────────────┘

1.2 Directory Structure

cc_motiongen/
├── __init__.py
├── config.py                 # Global configuration dataclasses
├── types.py                  # Type definitions (MotionTrajectory, AudioCondition)
│
├── model/
│   ├── __init__.py
│   ├── unet.py              # UNet1D architecture (116M params)
│   │   ├── UNetConfig       # Configuration dataclass
│   │   ├── UNet1D           # Main model class
│   │   ├── ResBlock         # Residual block with time embedding
│   │   ├── AttentionBlock   # Self/cross attention
│   │   └── Downsample/Upsample
│   │
│   ├── diffusion.py         # Gaussian diffusion wrapper
│   │   ├── DiffusionConfig  # Noise schedule configuration
│   │   ├── GaussianDiffusion # Training and sampling logic
│   │   ├── q_sample()       # Forward diffusion (add noise)
│   │   ├── p_sample()       # Reverse diffusion (denoise)
│   │   └── ddim_sample()    # Accelerated sampling
│   │
│   ├── conditioning.py      # Audio conditioning modules
│   │   ├── AudioEncoder     # Project audio features
│   │   └── CrossAttention   # Audio-motion attention
│   │
│   └── decoder.py           # Motion decoder
│       ├── DecoderConfig    # Configuration
│       ├── MotionDecoder    # Latent -> motion mapping
│       └── ResidualBlock    # Decoder building block
│
├── training/
│   ├── __init__.py
│   ├── trainer.py           # Main training loop
│   │   ├── Trainer          # Training orchestrator
│   │   ├── train_epoch()    # Single epoch logic
│   │   └── validate()       # Validation loop
│   │
│   └── losses.py            # Loss functions
│       ├── DiffusionLoss    # MSE on predicted noise
│       ├── DecoderLoss      # Reconstruction losses
│       └── TemporalLoss     # Coherence losses
│
├── inference/
│   ├── __init__.py
│   ├── sampler.py           # Sampling implementations
│   │   ├── DDPMSampler      # Full 1000-step sampling
│   │   ├── DDIMSampler      # Accelerated sampling
│   │   └── sample()         # Unified interface
│   │
│   ├── postprocess.py       # Post-processing
│   │   ├── PostProcessConfig
│   │   ├── MotionPostProcessor      # NumPy version
│   │   └── MotionPostProcessorTorch # GPU version
│   │
│   ├── mpms_sampler.py      # MPMS-enhanced sampling
│   └── selection.py         # Best-of-N selection
│
├── validation/
│   ├── __init__.py
│   └── sanity.py            # Sanity check suite
│       ├── SanityChecker    # Main checker class
│       ├── CheckResult      # Individual check result
│       └── check_*()        # Individual checks
│
├── evaluation/
│   ├── __init__.py
│   ├── metrics.py           # Evaluation metrics
│   └── harness.py           # RAG++ evaluation harness
│
├── data/
│   ├── __init__.py
│   └── dataset.py           # Data loading
│       ├── MotionPhraseDataset  # Main dataset class
│       ├── GCSLoader            # Google Cloud Storage loader
│       └── collate_fn           # Batch collation
│
└── scripts/
    ├── train.py             # Main training script
    ├── finetune_e2e.py      # E2E fine-tuning script
    ├── evaluate_sanity.py   # Sanity evaluation
    ├── run_evaluation.py    # RAG++ evaluation
    └── inference.py         # Inference script

1.3 Data Flow Diagram

┌─────────────────────────────────────────────────────────────────────────────┐
│                              TRAINING FLOW                                   │
├─────────────────────────────────────────────────────────────────────────────┤
│                                                                              │
│  GCS Bucket                                                                  │
│  gs://cc-music-library/motionphrase/                                        │
│       │                                                                      │
│       ├── audio_features/{phrase_id}/features.npz                           │
│       │       │                                                              │
│       │       ├── mel: [128, T]      Mel spectrogram                        │
│       │       ├── chroma: [12, T]    Pitch class                            │
│       │       ├── mfcc: [20, T]      Timbral features                       │
│       │       ├── rms: [T]           Energy                                 │
│       │       ├── spectral_centroid: [T]                                    │
│       │       └── onset_strength: [T]                                       │
│       │                                                                      │
│       └── motion/{phrase_id}/trajectory.npy                                 │
│               │                                                              │
│               └── motion: [T, 25]    Ground truth motion                    │
│                                                                              │
│                           │                                                  │
│                           ▼                                                  │
│               ┌───────────────────────┐                                      │
│               │  MotionPhraseDataset  │                                      │
│               │  - Loads from GCS     │                                      │
│               │  - Caches in memory   │                                      │
│               │  - Handles fallbacks  │                                      │
│               └───────────┬───────────┘                                      │
│                           │                                                  │
│                           ▼                                                  │
│               ┌───────────────────────┐                                      │
│               │     DataLoader        │                                      │
│               │  - batch_size=32      │                                      │
│               │  - shuffle=True       │                                      │
│               │  - num_workers=4      │                                      │
│               └───────────┬───────────┘                                      │
│                           │                                                  │
│                           ▼                                                  │
│               ┌───────────────────────┐                                      │
│               │   Training Step       │                                      │
│               │                       │                                      │
│               │  1. Sample timestep t │                                      │
│               │  2. Add noise: x_t    │                                      │
│               │  3. Predict: ε_θ(x_t) │                                      │
│               │  4. Loss: MSE(ε, ε_θ) │                                      │
│               │  5. Backprop          │                                      │
│               └───────────────────────┘                                      │
│                                                                              │
└─────────────────────────────────────────────────────────────────────────────┘
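The six arrays in `features.npz` add up to the 163 audio conditioning channels used throughout (128 + 12 + 20 + 1 + 1 + 1). The sketch below shows one way to stack them into a `[163, T]` array. It assumes the keys listed above, concatenated in that order. The real channel order and any normalization applied by `MotionPhraseDataset` are not documented here, and `stack_audio_features` is a hypothetical helper.

```python
import numpy as np

# Feature keys in the order listed in the diagram above (ordering is an assumption).
FEATURE_KEYS = ["mel", "chroma", "mfcc", "rms", "spectral_centroid", "onset_strength"]
EXPECTED_DIMS = {"mel": 128, "chroma": 12, "mfcc": 20,
                 "rms": 1, "spectral_centroid": 1, "onset_strength": 1}


def stack_audio_features(features: dict) -> np.ndarray:
    """Concatenate per-frame audio features into a [163, T] conditioning array.

    Hypothetical helper: 1-D features of shape [T] are promoted to [1, T].
    """
    rows = []
    num_frames = None
    for key in FEATURE_KEYS:
        arr = np.asarray(features[key], dtype=np.float32)
        if arr.ndim == 1:
            arr = arr[None, :]  # [T] -> [1, T]
        if arr.shape[0] != EXPECTED_DIMS[key]:
            raise ValueError(f"{key}: expected {EXPECTED_DIMS[key]} rows, got {arr.shape[0]}")
        if num_frames is None:
            num_frames = arr.shape[1]
        elif arr.shape[1] != num_frames:
            raise ValueError(f"{key}: frame count {arr.shape[1]} != {num_frames}")
        rows.append(arr)
    return np.concatenate(rows, axis=0)  # [163, T]


# Usage with a downloaded features.npz:
# with np.load("features.npz") as data:
#     audio_cond = stack_audio_features(dict(data))
```

The helper validates every shape before concatenating. A mismatched frame count between features fails immediately rather than producing a silently misaligned conditioning tensor.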

---

2. Motion Representation Format

2.1 25-Dimensional Motion Vector

Each frame of motion is represented as a 25-dimensional vector:

Motion Vector Layout [25 dimensions]
┌─────────────────────────────────────────────────────────────────────────────┐
│ Index │ Name              │ Dims │ Description              │ Valid Range  │
├───────┼───────────────────┼──────┼──────────────────────────┼──────────────┤
│ 0-2   │ position          │  3   │ 3D world position        │ [-100, 100]  │
│ 3-5   │ velocity          │  3   │ Linear velocity (m/s)    │ [-50, 50]    │
│ 6-8   │ acceleration      │  3   │ Linear acceleration      │ [-100, 100]  │
│ 9-12  │ quaternion        │  4   │ Orientation (w,x,y,z)    │ Unit norm    │
│ 13-15 │ angular_velocity  │  3   │ Rotational velocity      │ [-10, 10]    │
│ 16    │ phase             │  1   │ Musical phase            │ [0, 1]       │
│ 17-24 │ style             │  8   │ Style embedding          │ Unbounded    │
└───────┴───────────────────┴──────┴──────────────────────────┴──────────────┘
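A small lookup of named slices makes the index layout explicit and avoids hard-coded offsets elsewhere in the code. This is an illustrative sketch: `MOTION_LAYOUT` and `split_motion` are hypothetical names, not part of `cc_motiongen.types`.

```python
import numpy as np

# Component slices for the 25-D motion vector (indices from the table above).
MOTION_LAYOUT = {
    "position": slice(0, 3),
    "velocity": slice(3, 6),
    "acceleration": slice(6, 9),
    "quaternion": slice(9, 13),
    "angular_velocity": slice(13, 16),
    "phase": slice(16, 17),
    "style": slice(17, 25),
}
MOTION_DIM = 25

# The slices tile indices 0..24 exactly once.
assert sum(sl.stop - sl.start for sl in MOTION_LAYOUT.values()) == MOTION_DIM


def split_motion(motion: np.ndarray) -> dict:
    """Split a [..., 25] motion array into named components (views, not copies)."""
    if motion.shape[-1] != MOTION_DIM:
        raise ValueError(f"expected last dim {MOTION_DIM}, got {motion.shape[-1]}")
    return {name: motion[..., sl] for name, sl in MOTION_LAYOUT.items()}
```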

2.2 Coordinate System

                    +Y (up)
                     │
                     │
                     │
                     │
                     └──────────── +X (right)
                    /
                   /
                  /
                 +Z (forward/towards camera)

Quaternion Convention: Hamilton (w, x, y, z)
- Identity rotation: (1, 0, 0, 0)
- Handedness: Right-handed
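The following NumPy sketch makes the convention concrete. It implements the Hamilton product and vector rotation for quaternions stored as (w, x, y, z). The function names are illustrative. With the right-handed axes above, a +90° rotation about +Y should map +X to -Z.

```python
import numpy as np


def quat_mul(q: np.ndarray, r: np.ndarray) -> np.ndarray:
    """Hamilton product q * r for quaternions stored as (w, x, y, z)."""
    w1, x1, y1, z1 = q
    w2, x2, y2, z2 = r
    return np.array([
        w1 * w2 - x1 * x2 - y1 * y2 - z1 * z2,
        w1 * x2 + x1 * w2 + y1 * z2 - z1 * y2,
        w1 * y2 - x1 * z2 + y1 * w2 + z1 * x2,
        w1 * z2 + x1 * y2 - y1 * x2 + z1 * w2,
    ])


def quat_rotate(q: np.ndarray, v: np.ndarray) -> np.ndarray:
    """Rotate 3-vector v by quaternion q via q * (0, v) * conj(q)."""
    q = q / np.linalg.norm(q)                        # enforce unit norm
    q_conj = q * np.array([1.0, -1.0, -1.0, -1.0])
    return quat_mul(quat_mul(q, np.concatenate([[0.0], v])), q_conj)[1:]


identity = np.array([1.0, 0.0, 0.0, 0.0])
half_angle = np.pi / 4                               # half of a 90° rotation
rot_y_90 = np.array([np.cos(half_angle), 0.0, np.sin(half_angle), 0.0])
v_rot = quat_rotate(rot_y_90, np.array([1.0, 0.0, 0.0]))  # ≈ [0, 0, -1]
```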

2.3 Temporal Relationships

Motion must satisfy physical constraints between consecutive frames:

```python
import numpy as np

# Frame rate
FPS = 30.0
DT = 1.0 / FPS  # ≈ 0.0333 seconds


def time_derivative(x: np.ndarray) -> np.ndarray:
    """Forward difference along time: (x[t+1] - x[t]) / DT, i.e. * FPS. Returns T-1 frames."""
    return (x[1:] - x[:-1]) * FPS


def kinematic_coherence(motion: np.ndarray) -> dict:
    """Coherence errors for a [T, 25] trajectory (layout from Section 2.1)."""
    position = motion[:, 0:3]
    velocity = motion[:, 3:6]
    acceleration = motion[:, 6:9]

    # Position -> Velocity (first derivative)
    velocity_derived = time_derivative(position)
    # Velocity -> Acceleration (second derivative)
    acceleration_derived = time_derivative(velocity)
    # Acceleration -> Jerk (third derivative, smoothness indicator)
    jerk = time_derivative(acceleration)

    # Coherence requirement: explicit values should match derived values.
    # Forward differences are one frame shorter, so compare frames 0..T-2.
    vel_error = np.abs(velocity[:-1] - velocity_derived)
    accel_error = np.abs(acceleration[:-1] - acceleration_derived)

    return {"vel_error": vel_error, "accel_error": accel_error, "jerk": jerk}
```

2.4 Phase Semantics

The phase dimension encodes the position within the current bar (four beats per bar):

Phase Value    Musical Position
───────────────────────────────
0.00           Beat 1 (downbeat)
0.25           Beat 2
0.50           Beat 3
0.75           Beat 4
1.00           Next bar (wraps to 0)

Phase should be:
- Monotonically increasing within a bar
- Wrapping from 1.0 back to 0.0 at bar boundaries
- Synchronized with audio beat times
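The sketch below derives the phase channel from beat annotations. It assumes four beats per bar (as the table implies), strictly increasing beat times in seconds, and that the first annotated beat is a downbeat. `phase_from_beats` is hypothetical, and the actual labeling pipeline may align bars differently.

```python
import numpy as np

FPS = 30.0
BEATS_PER_BAR = 4  # assumption: 4/4 time, matching the table above


def phase_from_beats(beat_times: np.ndarray, num_frames: int, fps: float = FPS) -> np.ndarray:
    """Per-frame bar phase in [0, 1), with 0.0 at each downbeat.

    Assumes beat_times is strictly increasing, in seconds, has at least two
    entries, and that beat_times[0] is a downbeat. Frames outside the
    annotated range are clamped to the first/last beat interval.
    """
    beat_times = np.asarray(beat_times, dtype=np.float64)
    frame_times = np.arange(num_frames) / fps

    # Index of the beat at or before each frame, kept inside a valid interval.
    idx = np.searchsorted(beat_times, frame_times, side="right") - 1
    idx = np.clip(idx, 0, len(beat_times) - 2)

    # Fractional position within the current beat interval.
    start, end = beat_times[idx], beat_times[idx + 1]
    frac = np.clip((frame_times - start) / (end - start), 0.0, 1.0)

    # Beat-in-bar plus fraction, scaled to [0, 1) and wrapped at bar boundaries.
    return ((idx % BEATS_PER_BAR) + frac) / BEATS_PER_BAR % 1.0
```

At 120 BPM (a beat every 0.5 s, or 15 frames), phase rises by 1/60 per frame and returns to 0.0 every 60 frames.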

2.5 Style Embedding

The 8-dimensional style vector captures:

Style Dimensions (learned representation):
├── dims 0-1: Movement intensity (high/low energy)
├── dims 2-3: Spatial extent (wide/narrow movements)
├── dims 4-5: Rhythmic feel (staccato/legato)
└── dims 6-7: Body attitude (tense/relaxed)

Style should be:
- Temporally consistent (similar across consecutive frames)
- Gradually changing to reflect music dynamics
- Normalized for stable training
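One simple way to quantify the temporal-consistency requirement is to compare the average frame-to-frame change in the style vector with the vector's overall spread. This is an illustrative metric, not one defined in `cc_motiongen.evaluation.metrics`, and the name `style_jitter` is hypothetical.

```python
import numpy as np


def style_jitter(motion: np.ndarray, eps: float = 1e-8) -> float:
    """Mean frame-to-frame L2 change of the style embedding (dims 17-24),
    divided by the mean L2 distance of each style vector from the sequence mean.

    Values near 0 indicate a style that changes slowly relative to its range;
    larger values indicate frame-to-frame jitter.
    """
    style = motion[:, 17:25]                                      # [T, 8]
    step = np.linalg.norm(np.diff(style, axis=0), axis=1)         # [T-1]
    spread = np.linalg.norm(style - style.mean(axis=0), axis=1).mean()
    return float(step.mean() / (spread + eps))
```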

---

3. Model Components Deep Dive

3.1 UNet1D Architecture (116M Parameters)

3.1.1 Configuration

```python
from dataclasses import dataclass
from typing import Tuple


@dataclass
class UNetConfig:
    """UNet1D configuration."""

    # Input/Output dimensions
    in_channels: int = 25              # Motion vector dimension
    out_channels: int = 25             # Same as input (predict noise)

    # Architecture
    model_channels: int = 256          # Base channel count
    num_res_blocks: int = 2            # ResBlocks per level
    attention_resolutions: Tuple[int, ...] = (4, 8, 16)  # Where to apply attention
    channel_mult: Tuple[int, ...] = (1, 2, 4, 8)  # Channel multipliers per level
    num_heads: int = 8                 # Attention heads
    num_head_channels: int = 64        # Channels per head (alternative to num_heads)
    use_scale_shift_norm: bool = True  # FiLM conditioning

    # Conditioning
    audio_cond_dim: int = 163          # Audio feature dimension
    time_embed_dim: int = 256          # Timestep embedding dimension
    context_dim: int = 256             # Cross-attention context dimension

    # Regularization
    dropout: float = 0.1
    use_checkpoint: bool = False       # Gradient checkpointing
```
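The configuration determines the channel width and temporal length at every UNet level. The sketch below derives both from `model_channels` and `channel_mult`. It assumes one 2× temporal downsample between consecutive levels, as in the architecture diagram that follows. It also checks two constraints this implies: the sequence length T must be divisible by 2^(levels - 1), and every width must be divisible by 32 for `nn.GroupNorm(32, ...)`. `level_shapes` is a hypothetical helper.

```python
from typing import List, Tuple


def level_shapes(
    seq_len: int,
    model_channels: int = 256,
    channel_mult: Tuple[int, ...] = (1, 2, 4, 8),
    num_groups: int = 32,
) -> List[Tuple[int, int]]:
    """Return (channels, frames) for each encoder level of the UNet1D."""
    num_downsamples = len(channel_mult) - 1
    if seq_len % (2 ** num_downsamples) != 0:
        raise ValueError(f"seq_len={seq_len} must be divisible by {2 ** num_downsamples}")
    shapes = []
    for level, mult in enumerate(channel_mult):
        channels = model_channels * mult
        if channels % num_groups != 0:
            raise ValueError(f"level {level}: {channels} channels not divisible by {num_groups}")
        shapes.append((channels, seq_len // (2 ** level)))
    return shapes


# A 4-second clip at 30 fps has T = 120 frames:
# level_shapes(120) -> [(256, 120), (512, 60), (1024, 30), (2048, 15)]
```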

3.1.2 Architecture Diagram

Input: x_t [B, 25, T], t [B], audio_cond [B, 163, T]
                │
                ▼
┌───────────────────────────────────────────────────────────────────┐
│                         TIME EMBEDDING                             │
│  t ──> Sinusoidal(256) ──> Linear(256) ──> SiLU ──> Linear(256)  │
│                                                 │                  │
│                                          time_emb [B, 256]        │
└───────────────────────────────────────────────────────────────────┘
                │
                ▼
┌───────────────────────────────────────────────────────────────────┐
│                      AUDIO CONDITIONING                            │
│  audio_cond ──> Linear(163, 256) ──> audio_context [B, 256, T]   │
└───────────────────────────────────────────────────────────────────┘
                │
                ▼
┌───────────────────────────────────────────────────────────────────┐
│                          ENCODER                                   │
│                                                                    │
│  Level 0: 256 channels, T frames                                  │
│  ┌─────────────────────────────────────────────────────────────┐  │
│  │ Conv1d(25, 256) ──> ResBlock ──> ResBlock ──> Downsample   │  │
│  │                      + time_emb   + time_emb    (T → T/2)  │  │
│  └─────────────────────────────────────────────────────────────┘  │
│                              │ skip_0                              │
│                              ▼                                     │
│  Level 1: 512 channels, T/2 frames                                │
│  ┌─────────────────────────────────────────────────────────────┐  │
│  │ ResBlock ──> ResBlock ──> Attention ──> Downsample         │  │
│  │ + time_emb   + time_emb   + audio_ctx    (T/2 → T/4)       │  │
│  └─────────────────────────────────────────────────────────────┘  │
│                              │ skip_1                              │
│                              ▼                                     │
│  Level 2: 1024 channels, T/4 frames                               │
│  ┌─────────────────────────────────────────────────────────────┐  │
│  │ ResBlock ──> ResBlock ──> Attention ──> Downsample         │  │
│  │ + time_emb   + time_emb   + audio_ctx    (T/4 → T/8)       │  │
│  └─────────────────────────────────────────────────────────────┘  │
│                              │ skip_2                              │
│                              ▼                                     │
│  Level 3: 2048 channels, T/8 frames                               │
│  ┌─────────────────────────────────────────────────────────────┐  │
│  │ ResBlock ──> ResBlock ──> Attention                        │  │
│  │ + time_emb   + time_emb   + audio_ctx                      │  │
│  └─────────────────────────────────────────────────────────────┘  │
│                              │ skip_3                              │
└──────────────────────────────┼────────────────────────────────────┘
                               │
                               ▼
┌───────────────────────────────────────────────────────────────────┐
│                         BOTTLENECK                                 │
│  ┌─────────────────────────────────────────────────────────────┐  │
│  │ ResBlock ──> Attention ──> ResBlock                        │  │
│  │ + time_emb   + audio_ctx   + time_emb                      │  │
│  └─────────────────────────────────────────────────────────────┘  │
└───────────────────────────────────────────────────────────────────┘
                               │
                               ▼
┌───────────────────────────────────────────────────────────────────┐
│                          DECODER                                   │
│                                                                    │
│  Level 3: 2048 → 1024 channels                                    │
│  ┌─────────────────────────────────────────────────────────────┐  │
│  │ Concat(x, skip_3) ──> ResBlock ──> ResBlock ──> Upsample   │  │
│  │                       + time_emb   + time_emb    (T/8→T/4) │  │
│  └─────────────────────────────────────────────────────────────┘  │
│                              │                                     │
│                              ▼                                     │
│  Level 2: 1024 → 512 channels                                     │
│  ┌─────────────────────────────────────────────────────────────┐  │
│  │ Concat(x, skip_2) ──> ResBlock ──> Attention ──> Upsample  │  │
│  │                       + time_emb   + audio_ctx   (T/4→T/2) │  │
│  └─────────────────────────────────────────────────────────────┘  │
│                              │                                     │
│                              ▼                                     │
│  Level 1: 512 → 256 channels                                      │
│  ┌─────────────────────────────────────────────────────────────┐  │
│  │ Concat(x, skip_1) ──> ResBlock ──> Attention ──> Upsample  │  │
│  │                       + time_emb   + audio_ctx   (T/2→T)   │  │
│  └─────────────────────────────────────────────────────────────┘  │
│                              │                                     │
│                              ▼                                     │
│  Level 0: 256 → 25 channels                                       │
│  ┌─────────────────────────────────────────────────────────────┐  │
│  │ Concat(x, skip_0) ──> ResBlock ──> ResBlock ──> Conv1d(25) │  │
│  │                       + time_emb   + time_emb              │  │
│  └─────────────────────────────────────────────────────────────┘  │
│                                                                    │
└───────────────────────────────────────────────────────────────────┘
                               │
                               ▼
                    Output: ε_θ [B, 25, T]
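The first stage of the time-embedding path, `Sinusoidal(256)`, is commonly implemented like a Transformer positional encoding: half the channels use sine and half use cosine, at geometrically spaced frequencies. The NumPy sketch below follows that assumption. The max period of 10000 and the sin-then-cos ordering follow common DDPM implementations and may differ from `unet.py`.

```python
import numpy as np


def timestep_embedding(t: np.ndarray, dim: int = 256, max_period: float = 10000.0) -> np.ndarray:
    """Map integer timesteps [B] to sinusoidal features [B, dim]."""
    half = dim // 2
    # Frequencies from 1 down toward 1/max_period, geometrically spaced.
    freqs = np.exp(-np.log(max_period) * np.arange(half) / half)
    args = np.asarray(t, dtype=np.float64)[:, None] * freqs[None, :]  # [B, half]
    emb = np.concatenate([np.sin(args), np.cos(args)], axis=-1)       # [B, 2*half]
    if dim % 2 == 1:
        emb = np.concatenate([emb, np.zeros_like(emb[:, :1])], axis=-1)  # pad odd dim
    return emb
```

The resulting `[B, 256]` vector then passes through `Linear -> SiLU -> Linear` to produce `time_emb`.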

3.1.3 ResBlock Implementation

```python
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F


class ResBlock(nn.Module):
    """
    Residual block with time embedding conditioning.

    With use_scale_shift_norm=True, the timestep embedding modulates the
    normalized features via FiLM (Feature-wise Linear Modulation):
        h = norm(h) * (1 + scale) + shift,  where scale, shift = MLP(time_emb)
    Otherwise the projected embedding is added before normalization.
    """

    def __init__(
        self,
        channels: int,
        time_emb_dim: int,
        dropout: float = 0.1,
        out_channels: Optional[int] = None,
        use_scale_shift_norm: bool = True,
    ):
        super().__init__()
        out_channels = out_channels or channels

        # First convolution (GroupNorm(32, ...) requires channels divisible by 32)
        self.norm1 = nn.GroupNorm(32, channels)
        self.conv1 = nn.Conv1d(channels, out_channels, kernel_size=3, padding=1)

        # Time embedding projection: 2*C_out values (scale, shift) for FiLM, else C_out
        self.time_mlp = nn.Sequential(
            nn.SiLU(),
            nn.Linear(time_emb_dim, out_channels * 2 if use_scale_shift_norm else out_channels),
        )

        # Second convolution
        self.norm2 = nn.GroupNorm(32, out_channels)
        self.dropout = nn.Dropout(dropout)
        self.conv2 = nn.Conv1d(out_channels, out_channels, kernel_size=3, padding=1)

        # Skip connection (1x1 conv when the channel count changes)
        if channels != out_channels:
            self.skip_conv = nn.Conv1d(channels, out_channels, kernel_size=1)
        else:
            self.skip_conv = nn.Identity()

        self.use_scale_shift_norm = use_scale_shift_norm

    def forward(self, x: torch.Tensor, time_emb: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: [B, C, T] input features
            time_emb: [B, time_emb_dim] timestep embedding

        Returns:
            [B, out_channels, T] output features
        """
        h = self.norm1(x)
        h = F.silu(h)
        h = self.conv1(h)

        # Project time embedding: [B, 2*C_out, 1] with FiLM, else [B, C_out, 1]
        time_out = self.time_mlp(time_emb)[:, :, None]

        if self.use_scale_shift_norm:
            scale, shift = time_out.chunk(2, dim=1)
            h = self.norm2(h) * (1 + scale) + shift
        else:
            h = self.norm2(h + time_out)

        h = F.silu(h)
        h = self.dropout(h)
        h = self.conv2(h)

        return h + self.skip_conv(x)
```

3.1.4 Attention Block Implementation

```python
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F


class AttentionBlock(nn.Module):
    """
    Multi-head self-attention with optional cross-attention to audio context.
    """

    def __init__(
        self,
        channels: int,
        num_heads: int = 8,
        context_dim: Optional[int] = None,
    ):
        super().__init__()
        assert channels % num_heads == 0, "channels must be divisible by num_heads"
        self.channels = channels
        self.num_heads = num_heads
        self.head_dim = channels // num_heads

        # Self-attention
        self.norm = nn.GroupNorm(32, channels)
        self.qkv = nn.Conv1d(channels, channels * 3, kernel_size=1)
        self.proj = nn.Conv1d(channels, channels, kernel_size=1)

        # Cross-attention (if context provided)
        if context_dim is not None:
            self.cross_norm = nn.GroupNorm(32, channels)
            self.q_cross = nn.Conv1d(channels, channels, kernel_size=1)
            self.kv_cross = nn.Conv1d(context_dim, channels * 2, kernel_size=1)
            self.proj_cross = nn.Conv1d(channels, channels, kernel_size=1)
        else:
            self.cross_norm = None

    def forward(
        self,
        x: torch.Tensor,
        context: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Args:
            x: [B, C, T] input features
            context: [B, context_dim, S] audio context (optional); S may differ
                from T, e.g. a full-rate audio context at a downsampled level

        Returns:
            [B, C, T] attended features
        """
        B, C, T = x.shape

        # Self-attention
        h = self.norm(x)
        qkv = self.qkv(h).reshape(B, 3, self.num_heads, self.head_dim, T)
        q, k, v = qkv[:, 0], qkv[:, 1], qkv[:, 2]

        # Scaled dot-product attention: [B, heads, T, T]
        scale = self.head_dim ** -0.5
        attn = torch.einsum('bhdt,bhds->bhts', q, k) * scale
        attn = F.softmax(attn, dim=-1)
        out = torch.einsum('bhts,bhds->bhdt', attn, v)

        out = out.reshape(B, C, T)
        x = x + self.proj(out)

        # Cross-attention to audio context
        if context is not None and self.cross_norm is not None:
            S = context.shape[-1]  # context length (not necessarily T)
            h = self.cross_norm(x)
            q = self.q_cross(h).reshape(B, self.num_heads, self.head_dim, T)
            kv = self.kv_cross(context).reshape(B, 2, self.num_heads, self.head_dim, S)
            k, v = kv[:, 0], kv[:, 1]

            attn = torch.einsum('bhdt,bhds->bhts', q, k) * scale  # [B, heads, T, S]
            attn = F.softmax(attn, dim=-1)
            out = torch.einsum('bhts,bhds->bhdt', attn, v)

            out = out.reshape(B, C, T)
            x = x + self.proj_cross(out)

        return x
```
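Deeper UNet levels run at T/2, T/4 and T/8 frames. If the audio context keeps its full length T at those levels, the cross-attention keys and values have a different length from the queries. The NumPy sketch below shows multi-head cross-attention in the same `[B, C, T]` layout. Plain matrices stand in for the 1×1 convolutions, and the weights are random placeholders rather than trained parameters.

```python
import numpy as np


def softmax(x: np.ndarray, axis: int = -1) -> np.ndarray:
    x = x - x.max(axis=axis, keepdims=True)  # subtract max for numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def cross_attention(x, context, w_q, w_kv, num_heads):
    """Multi-head cross-attention in [B, C, T] layout.

    x: [B, C, Tq] motion features; context: [B, Cc, Tk] audio context.
    w_q: [C, C] and w_kv: [2C, Cc] play the role of 1x1 Conv1d weights.
    Returns [B, C, Tq]: one output per motion frame, whatever Tk is.
    """
    B, C, Tq = x.shape
    Tk = context.shape[-1]
    d = C // num_heads
    q = np.einsum("oc,bct->bot", w_q, x).reshape(B, num_heads, d, Tq)
    kv = np.einsum("oc,bct->bot", w_kv, context).reshape(B, 2, num_heads, d, Tk)
    k, v = kv[:, 0], kv[:, 1]
    attn = softmax(np.einsum("bhdt,bhds->bhts", q, k) / np.sqrt(d))  # [B, h, Tq, Tk]
    out = np.einsum("bhts,bhds->bhdt", attn, v)                      # [B, h, d, Tq]
    return out.reshape(B, C, Tq)


rng = np.random.default_rng(0)
x = rng.standard_normal((2, 512, 60))       # Level 1 features: T/2 = 60 frames
audio = rng.standard_normal((2, 256, 120))  # full-rate audio context: T = 120
w_q = rng.standard_normal((512, 512)) / np.sqrt(512)
w_kv = rng.standard_normal((1024, 256)) / np.sqrt(256)
out = cross_attention(x, audio, w_q, w_kv, num_heads=8)  # shape (2, 512, 60)
```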

3.2 GaussianDiffusion

3.2.1 Configuration

```python
from dataclasses import dataclass


@dataclass
class DiffusionConfig:
    """Diffusion process configuration."""

    # Schedule
    timesteps: int = 1000              # Number of diffusion steps
    beta_schedule: str = "cosine"      # "linear", "cosine", "sqrt"
    beta_start: float = 0.0001         # Starting beta (linear and sqrt only)
    beta_end: float = 0.02             # Ending beta (linear and sqrt only)

    # Training
    prediction_type: str = "epsilon"   # "epsilon" (noise) or "x0" (sample)
    loss_type: str = "mse"             # "mse", "l1", "huber"
    clip_denoised: bool = True         # Clip x_0 predictions
    clip_range: float = 1.0            # Clipping range

    # Sampling
    rescale_timesteps: bool = False    # Rescale timesteps for training
    use_ddim: bool = True              # Use DDIM by default
    ddim_eta: float = 0.0              # DDIM stochasticity (0 = deterministic)
```

3.2.2 Noise Schedules

```python
import numpy as np


def get_beta_schedule(
    schedule: str,
    timesteps: int,
    beta_start: float = 0.0001,
    beta_end: float = 0.02,
) -> np.ndarray:
    """
    Get beta schedule for diffusion process.

    Linear: β_t increases linearly from β_start to β_end.
    Cosine: ᾱ_t follows a squared-cosine curve and β_t = 1 - ᾱ_t / ᾱ_{t-1};
            ᾱ_t changes slowly near both ends of the schedule.
    Sqrt:   √β_t increases linearly from √β_start to √β_end and is then squared,
            giving smaller β_t than the linear schedule at intermediate steps.

    beta_start / beta_end are used only by the linear and sqrt schedules
    (defaults match DiffusionConfig).
    """
    if schedule == "linear":
        return np.linspace(beta_start, beta_end, timesteps)

    elif schedule == "cosine":
        # From "Improved Denoising Diffusion Probabilistic Models"
        s = 0.008  # Small offset to prevent singularity
        steps = timesteps + 1
        t = np.linspace(0, timesteps, steps) / timesteps
        alphas_cumprod = np.cos((t + s) / (1 + s) * np.pi / 2) ** 2
        alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
        betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
        return np.clip(betas, 0.0001, 0.9999)

    elif schedule == "sqrt":
        return np.linspace(beta_start ** 0.5, beta_end ** 0.5, timesteps) ** 2

    raise ValueError(f"Unknown beta schedule: {schedule!r}")
```
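`q_sample` and `p_sample` below read precomputed per-timestep buffers derived from β_t: `sqrt_alphas_cumprod`, `sqrt_one_minus_alphas_cumprod`, `posterior_mean_coef1`, `posterior_mean_coef2` and `posterior_variance`. The NumPy sketch below derives them from the DDPM posterior q(x_{t-1} | x_t, x_0). The function name `schedule_buffers` is hypothetical, and the real class presumably stores these values as torch tensors.

```python
import numpy as np


def schedule_buffers(betas: np.ndarray) -> dict:
    """Precompute per-timestep coefficients used by q_sample / p_sample."""
    alphas = 1.0 - betas
    alphas_cumprod = np.cumprod(alphas)                        # ᾱ_t
    alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1])  # ᾱ_{t-1} (1.0 for the first step)
    return {
        "alphas_cumprod": alphas_cumprod,
        "alphas_cumprod_prev": alphas_cumprod_prev,
        "sqrt_alphas_cumprod": np.sqrt(alphas_cumprod),
        "sqrt_one_minus_alphas_cumprod": np.sqrt(1.0 - alphas_cumprod),
        # Posterior q(x_{t-1} | x_t, x_0) = N(coef1 * x_0 + coef2 * x_t, variance * I)
        "posterior_mean_coef1": betas * np.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod),
        "posterior_mean_coef2": (1.0 - alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - alphas_cumprod),
        "posterior_variance": betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod),
    }


betas = np.linspace(1e-4, 0.02, 1000)  # linear schedule with DiffusionConfig defaults
buffers = schedule_buffers(betas)
```

A quick sanity check follows from the formulas: a noise-free x_t = √ᾱ_t · x_0 maps to a posterior mean of √ᾱ_{t-1} · x_0. The posterior variance at the first step is exactly zero.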

3.2.3 Forward Process (Adding Noise)

```python
from typing import Optional

import torch


def q_sample(
    self,
    x_0: torch.Tensor,
    t: torch.Tensor,
    noise: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """
    Forward diffusion: add noise to x_0 at timestep t.

    q(x_t | x_0) = N(x_t; sqrt(α̅_t) * x_0, (1 - α̅_t) * I)

    Args:
        x_0: [B, D, T] clean motion
        t: [B] timesteps
        noise: [B, D, T] optional pre-sampled noise

    Returns:
        x_t: [B, D, T] noisy motion at timestep t
    """
    if noise is None:
        noise = torch.randn_like(x_0)

    # Get schedule values for timestep t, broadcast to [B, 1, 1]
    sqrt_alphas_cumprod_t = self.sqrt_alphas_cumprod[t][:, None, None]
    sqrt_one_minus_alphas_cumprod_t = self.sqrt_one_minus_alphas_cumprod[t][:, None, None]

    # x_t = sqrt(α̅_t) * x_0 + sqrt(1 - α̅_t) * ε
    x_t = sqrt_alphas_cumprod_t * x_0 + sqrt_one_minus_alphas_cumprod_t * noise

    return x_t
```
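Steps 1 to 4 of the training step in Section 1.3 combine `q_sample` with an MSE on the predicted noise. The framework-free sketch below shows that combination. `predict_noise` is a hypothetical stand-in for the UNet1D forward pass that ignores audio conditioning, and the schedule is the linear default. In the real trainer, step 5 (backprop) runs in PyTorch.

```python
import numpy as np


def diffusion_loss(x0, predict_noise, alphas_cumprod, rng):
    """Epsilon-prediction loss for one batch (steps 1-4 of the training step).

    x0: [B, D, T] clean motion; alphas_cumprod: [num_timesteps] (ᾱ_t).
    predict_noise(x_t, t) -> [B, D, T] stands in for the UNet1D forward pass.
    """
    B = x0.shape[0]
    t = rng.integers(0, len(alphas_cumprod), size=B)    # 1. sample a timestep per example
    noise = rng.standard_normal(x0.shape)               # ε ~ N(0, I)
    ab = alphas_cumprod[t][:, None, None]
    x_t = np.sqrt(ab) * x0 + np.sqrt(1.0 - ab) * noise  # 2. add noise (q_sample)
    noise_pred = predict_noise(x_t, t)                  # 3. predict ε_θ(x_t, t)
    return float(np.mean((noise - noise_pred) ** 2))    # 4. MSE(ε, ε_θ)


rng = np.random.default_rng(0)
alphas_cumprod = np.cumprod(1.0 - np.linspace(1e-4, 0.02, 1000))  # linear schedule
x0 = rng.standard_normal((4, 25, 120))
# A model that always predicts zero noise scores about 1.0, the variance of ε.
loss = diffusion_loss(x0, lambda x_t, t: np.zeros_like(x_t), alphas_cumprod, rng)
```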

3.2.4 Reverse Process (Denoising)

python
def p_sample(
    self,
    x_t: torch.Tensor,
    t: torch.Tensor,
    audio_cond: torch.Tensor,
    clip_denoised: bool = True,
) -> torch.Tensor:
    """
    Single reverse diffusion step: denoise x_t to x_{t-1}.

    p_θ(x_{t-1} | x_t) = N(x_{t-1}; μ_θ(x_t, t), σ_t² * I)

    Args:
        x_t: [B, D, T] noisy motion at timestep t
        t: [B] timesteps
        audio_cond: [B, 163, T] audio conditioning
        clip_denoised: whether to clip predicted x_0

    Returns:
        x_{t-1}: [B, D, T] denoised motion
    """
    B = x_t.shape[0]

    # Predict noise
    noise_pred = self.model(x_t, t, audio_cond)

    # Compute x_0 prediction from noise prediction
    # x_0 = (x_t - sqrt(1 - α̅_t) * ε_θ) / sqrt(α̅_t)
    sqrt_alphas_cumprod_t = self.sqrt_alphas_cumprod[t][:, None, None]
    sqrt_one_minus_alphas_cumprod_t = self.sqrt_one_minus_alphas_cumprod[t][:, None, None]

    x_0_pred = (x_t - sqrt_one_minus_alphas_cumprod_t * noise_pred) / sqrt_alphas_cumprod_t

    if clip_denoised:
        x_0_pred = x_0_pred.clamp(-self.config.clip_range, self.config.clip_range)

    # Compute posterior mean μ_θ(x_t, x_0)
    # μ_θ = (sqrt(α̅_{t-1}) * β_t * x_0 + sqrt(α_t) * (1 - α̅_{t-1}) * x_t) / (1 - α̅_t)
    posterior_mean = (
        self.posterior_mean_coef1[t][:, None, None] * x_0_pred +
        self.posterior_mean_coef2[t][:, None, None] * x_t
    )

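    # Add noise except for batch elements at t == 0; a per-element mask keeps
    # this correct when a batch mixes timesteps.
    noise = torch.randn_like(x_t)
    nonzero_mask = (t > 0).to(x_t.dtype)[:, None, None]
    posterior_variance = self.posterior_variance[t][:, None, None]
    x_prev = posterior_mean + nonzero_mask * torch.sqrt(posterior_variance) * noise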

    return x_prev
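p_sample relies on precomputed buffers (sqrt_alphas_cumprod, posterior_mean_coef1, posterior_mean_coef2, posterior_variance) whose construction is not shown in this excerpt. The sketch below derives them with NumPy from a linear β schedule running from 1e-4 to 0.02 over 1000 steps. That schedule is an assumption (the common DDPM default), not necessarily the one CC-MotionGen uses. The final check confirms that, when the true noise is plugged in, the coefficient form of the posterior mean agrees with the ε-form μ_θ given in Section 5.2.

```python
import numpy as np


def make_schedule(timesteps: int = 1000, beta_start: float = 1e-4, beta_end: float = 0.02) -> dict:
    """Precompute DDPM buffers from an (assumed) linear beta schedule."""
    betas = np.linspace(beta_start, beta_end, timesteps, dtype=np.float64)
    alphas = 1.0 - betas
    alphas_cumprod = np.cumprod(alphas)
    # ᾱ_{t-1}, with ᾱ := 1 before the first step so t = 0 has a defined posterior
    alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1])
    return {
        "betas": betas,
        "alphas": alphas,
        "alphas_cumprod": alphas_cumprod,
        "sqrt_alphas_cumprod": np.sqrt(alphas_cumprod),
        "sqrt_one_minus_alphas_cumprod": np.sqrt(1.0 - alphas_cumprod),
        # Mean of q(x_{t-1} | x_t, x_0) = coef1 * x_0 + coef2 * x_t
        "posterior_mean_coef1": betas * np.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod),
        "posterior_mean_coef2": (1.0 - alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - alphas_cumprod),
        # β̃_t = β_t (1 - ᾱ_{t-1}) / (1 - ᾱ_t); exactly 0 at t = 0
        "posterior_variance": betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod),
    }


# Sanity check: with the true noise, the coefficient form equals the ε-form mean.
s = make_schedule()
rng = np.random.default_rng(0)
x0, eps = rng.standard_normal(25), rng.standard_normal(25)
t = 500
x_t = s["sqrt_alphas_cumprod"][t] * x0 + s["sqrt_one_minus_alphas_cumprod"][t] * eps
mean_coef = s["posterior_mean_coef1"][t] * x0 + s["posterior_mean_coef2"][t] * x_t
mean_eps = (x_t - s["betas"][t] / s["sqrt_one_minus_alphas_cumprod"][t] * eps) / np.sqrt(s["alphas"][t])
print(np.allclose(mean_coef, mean_eps))  # True
```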

3.3 MotionDecoder (2M Parameters)

3.3.1 Configuration

python
from dataclasses import dataclass, field
from typing import Dict


@dataclass
class DecoderConfig:
    """Motion decoder configuration."""

    # Dimensions
    input_dim: int = 25                # Raw diffusion output
    hidden_dim: int = 256              # Hidden layer size
    output_dim: int = 25               # Motion output

    # Architecture
    num_layers: int = 3                # Number of residual blocks
    dropout: float = 0.1               # Dropout rate
    use_layer_norm: bool = True        # Use LayerNorm vs BatchNorm
    activation: str = "gelu"           # "gelu", "relu", "silu"

    # Output activation per dimension
    output_activations: Dict[str, str] = field(default_factory=lambda: {
        "position": "none",            # Unbounded
        "velocity": "tanh_scaled",     # [-50, 50]
        "acceleration": "tanh_scaled", # [-100, 100]
        "quaternion": "normalize",     # Unit norm
        "angular_velocity": "tanh_scaled",  # [-10, 10]
        "phase": "sigmoid",            # [0, 1]
        "style": "none",               # Unbounded
    })

3.3.2 Architecture

python
import torch
import torch.nn as nn
import torch.nn.functional as F


class MotionDecoder(nn.Module):
    """
    Decoder that maps raw diffusion latent to semantically meaningful motion.

    The diffusion model learns a latent representation that doesn't directly
    correspond to motion semantics. This decoder:
    1. Applies learned transformations to each component
    2. Enforces output constraints (unit quaternions, phase range, etc.)
    3. Can be trained end-to-end with temporal coherence losses
    """

    def __init__(self, config: DecoderConfig):
        super().__init__()
        self.config = config

        # Input projection
        self.input_proj = nn.Linear(config.input_dim, config.hidden_dim)

        # Residual blocks
        self.blocks = nn.ModuleList([
            ResidualBlock(config.hidden_dim, config.dropout)
            for _ in range(config.num_layers)
        ])

        # Output heads (separate for each motion component)
        self.position_head = nn.Linear(config.hidden_dim, 3)
        self.velocity_head = nn.Linear(config.hidden_dim, 3)
        self.acceleration_head = nn.Linear(config.hidden_dim, 3)
        self.quaternion_head = nn.Linear(config.hidden_dim, 4)
        self.angular_velocity_head = nn.Linear(config.hidden_dim, 3)
        self.phase_head = nn.Linear(config.hidden_dim, 1)
        self.style_head = nn.Linear(config.hidden_dim, 8)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: [B, T, 25] raw diffusion output

        Returns:
            [B, T, 25] decoded motion
        """
        B, T, D = x.shape

        # Input projection
        h = self.input_proj(x)  # [B, T, hidden_dim]

        # Residual blocks
        for block in self.blocks:
            h = block(h)

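        # Decode each component.
        # NOTE: DecoderConfig.output_activations describes velocity as tanh-scaled to
        # [-50, 50] and acceleration / angular velocity as tanh-scaled to [-100, 100]
        # and [-10, 10]; this forward pass scales velocity to [-20, 20] and leaves
        # acceleration and angular velocity unbounded.
        position = self.position_head(h)                    # [B, T, 3]
        velocity = torch.tanh(self.velocity_head(h)) * 20   # [B, T, 3], scaled to [-20, 20]
        acceleration = self.acceleration_head(h)            # [B, T, 3]
        quaternion = self.quaternion_head(h)                # [B, T, 4]
        angular_velocity = self.angular_velocity_head(h)    # [B, T, 3]
        phase = torch.sigmoid(self.phase_head(h))           # [B, T, 1], in [0, 1]
        style = self.style_head(h)                          # [B, T, 8]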

        # Normalize quaternion to unit norm
        quaternion = F.normalize(quaternion, dim=-1)

        # Concatenate
        output = torch.cat([
            position,          # 0:3
            velocity,          # 3:6
            acceleration,      # 6:9
            quaternion,        # 9:13
            angular_velocity,  # 13:16
            phase,             # 16:17
            style,             # 17:25
        ], dim=-1)

        return output


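class ResidualBlock(nn.Module):
    """Pre-norm residual MLP block: x + MLP(LayerNorm(x)), with a 4x hidden expansion.

    Uses LayerNorm and GELU unconditionally; DecoderConfig.use_layer_norm and
    DecoderConfig.activation are not consulted here.
    """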

    def __init__(self, hidden_dim: int, dropout: float = 0.1):
        super().__init__()
        self.layers = nn.Sequential(
            nn.LayerNorm(hidden_dim),
            nn.Linear(hidden_dim, hidden_dim * 4),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim * 4, hidden_dim),
            nn.Dropout(dropout),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + self.layers(x)
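With the default DecoderConfig, the decoder as written has roughly 1.6M parameters, which the section title rounds to 2M. The count can be reproduced by hand, assuming the ResidualBlock shown above (LayerNorm with a learnable scale and shift, followed by a 4x-expanded two-layer MLP) and the seven output heads. The helper names are illustrative.

```python
def linear_params(fan_in: int, fan_out: int) -> int:
    """Weights plus biases of nn.Linear(fan_in, fan_out)."""
    return fan_in * fan_out + fan_out


def decoder_param_count(
    input_dim: int = 25,
    hidden_dim: int = 256,
    num_layers: int = 3,
    head_dims: tuple = (3, 3, 3, 4, 3, 1, 8),  # pos, vel, accel, quat, ang. vel, phase, style
) -> int:
    """Parameter count of MotionDecoder as listed above."""
    input_proj = linear_params(input_dim, hidden_dim)
    block = (
        2 * hidden_dim                                 # LayerNorm weight + bias
        + linear_params(hidden_dim, 4 * hidden_dim)    # expansion
        + linear_params(4 * hidden_dim, hidden_dim)    # projection
    )
    heads = sum(linear_params(hidden_dim, d) for d in head_dims)
    return input_proj + num_layers * block + heads


print(decoder_param_count())  # 1591321
```

In PyTorch the same figure should come out of sum(p.numel() for p in MotionDecoder(DecoderConfig()).parameters()). Almost all of it (about 1.58M) sits in the residual blocks, so num_layers and hidden_dim are the settings that move the total.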

---

4. Audio Feature Extraction

4.1 Feature Specifications

python
AUDIO_FEATURE_CONFIG = {
    # Sampling
    "sample_rate": 44100,
    "hop_length": 1470,        # 44100 / 30 = 1470 (30 fps)
    "n_fft": 2048,

    # Mel spectrogram
    "n_mels": 128,
    "fmin": 20.0,
    "fmax": 8000.0,

    # MFCC
    "n_mfcc": 20,

    # Chroma
    "n_chroma": 12,

    # Output dimensions
    "total_dim": 163,          # 128 + 12 + 20 + 1 + 1 + 1
}
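The hop length ties audio frames to motion frames: 44100 / 30 = 1470 samples per frame. With librosa's default centered framing (center=True, even n_fft), the features computed on centered frames (mel, MFCC, RMS, spectral centroid) produce 1 + floor(num_samples / hop_length) frames. A 4-second clip, which corresponds to the 120 target_frames used by the training dataset at 30 fps, therefore yields 121 frames and is truncated during alignment. A small arithmetic sketch (constant and function names are illustrative):

```python
SAMPLE_RATE = 44_100
FPS = 30
HOP_LENGTH = SAMPLE_RATE // FPS  # 1470 samples per motion frame

# Channel layout of the stacked conditioning tensor (Section 4.3)
FEATURE_DIMS = {"mel": 128, "chroma": 12, "mfcc": 20, "rms": 1, "spectral_centroid": 1, "onset_strength": 1}


def num_centered_frames(num_samples: int, hop_length: int = HOP_LENGTH) -> int:
    """Frames produced by librosa's centered framing (even n_fft, center=True)."""
    return 1 + num_samples // hop_length


print(HOP_LENGTH, sum(FEATURE_DIMS.values()))  # 1470 163
print(num_centered_frames(4 * SAMPLE_RATE))    # 121
```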

4.2 Feature Extraction Pipeline

python
from typing import Dict

import numpy as np


def extract_audio_features(
    audio_path: str,
    target_fps: float = 30.0,
) -> Dict[str, np.ndarray]:
    """
    Extract audio features aligned to target frame rate.

    Args:
        audio_path: Path to audio file
        target_fps: Target frame rate for features

    Returns:
        Dictionary with all audio features
    """
    import librosa

    # Load audio
    y, sr = librosa.load(audio_path, sr=44100, mono=True)

    # Compute hop length for target fps
    hop_length = int(sr / target_fps)  # 1470 for 30fps

    # Mel spectrogram [128, T]
    mel = librosa.feature.melspectrogram(
        y=y, sr=sr, n_mels=128, hop_length=hop_length,
        fmin=20.0, fmax=8000.0
    )
    mel_db = librosa.power_to_db(mel, ref=np.max)
    mel_norm = (mel_db - mel_db.min()) / (mel_db.max() - mel_db.min() + 1e-8)

    # Chroma [12, T]
    chroma = librosa.feature.chroma_cqt(
        y=y, sr=sr, hop_length=hop_length, n_chroma=12
    )

    # MFCC [20, T]
    mfcc = librosa.feature.mfcc(
        y=y, sr=sr, n_mfcc=20, hop_length=hop_length
    )
    mfcc_norm = (mfcc - mfcc.mean(axis=1, keepdims=True)) / (mfcc.std(axis=1, keepdims=True) + 1e-8)

    # RMS energy [T]
    rms = librosa.feature.rms(y=y, hop_length=hop_length)[0]
    rms_norm = rms / (rms.max() + 1e-8)

    # Spectral centroid [T]
    centroid = librosa.feature.spectral_centroid(y=y, sr=sr, hop_length=hop_length)[0]
    centroid_norm = (centroid - centroid.min()) / (centroid.max() - centroid.min() + 1e-8)

    # Onset strength [T]
    onset = librosa.onset.onset_strength(y=y, sr=sr, hop_length=hop_length)
    onset_norm = onset / (onset.max() + 1e-8)

    # Beat detection
    tempo, beat_frames = librosa.beat.beat_track(y=y, sr=sr, hop_length=hop_length)
    beat_times = librosa.frames_to_time(beat_frames, sr=sr, hop_length=hop_length)

    return {
        "mel": mel_norm.astype(np.float32),           # [128, T]
        "chroma": chroma.astype(np.float32),          # [12, T]
        "mfcc": mfcc_norm.astype(np.float32),         # [20, T]
        "rms": rms_norm.astype(np.float32),           # [T]
        "spectral_centroid": centroid_norm.astype(np.float32),  # [T]
        "onset_strength": onset_norm.astype(np.float32),  # [T]
        # Recent librosa versions may return tempo as a 1-element array
        "tempo": float(np.atleast_1d(tempo)[0]),
        "beat_times": beat_times.astype(np.float32),
    }

4.3 Feature Stacking

python
def stack_audio_features(features: Dict[str, np.ndarray]) -> np.ndarray:
    """
    Stack audio features into single conditioning tensor.

    Args:
        features: Dictionary of audio features

    Returns:
        [163, T] stacked features
    """
    T = features["mel"].shape[1]

    # Add a channel axis to the 1-D per-frame features: [T] -> [1, T]
    rms = features["rms"][np.newaxis, :]                    # [1, T]
    centroid = features["spectral_centroid"][np.newaxis, :] # [1, T]
    onset = features["onset_strength"][np.newaxis, :]       # [1, T]

    # Stack all features
    stacked = np.concatenate([
        features["mel"],      # [128, T]
        features["chroma"],   # [12, T]
        features["mfcc"],     # [20, T]
        rms,                  # [1, T]
        centroid,             # [1, T]
        onset,                # [1, T]
    ], axis=0)  # [163, T]

    return stacked.astype(np.float32)
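stack_audio_features relies on every feature having exactly the same number of frames; np.concatenate raises a ValueError if any stream is even one frame longer than the others. A defensive variant, sketched below under the assumption that trailing frames can be dropped safely, truncates all streams to the shortest one before stacking. The name stack_audio_features_safe is hypothetical.

```python
from typing import Dict

import numpy as np

# Channel order of the [163, T] conditioning tensor (same as stack_audio_features)
FEATURE_ORDER = ["mel", "chroma", "mfcc", "rms", "spectral_centroid", "onset_strength"]


def stack_audio_features_safe(features: Dict[str, np.ndarray]) -> np.ndarray:
    """Stack per-frame features into [163, T], truncating every stream to the shortest."""
    # np.atleast_2d turns a 1-D per-frame feature [T] into a single channel [1, T]
    blocks = [np.atleast_2d(features[name]) for name in FEATURE_ORDER]
    T = min(block.shape[1] for block in blocks)
    stacked = np.concatenate([block[:, :T] for block in blocks], axis=0)
    return stacked.astype(np.float32)
```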

---

5. Diffusion Process Mathematics

5.1 Forward Process

The forward diffusion process gradually adds Gaussian noise to data:

q(x_t | x_0) = N(x_t; √(ᾱ_t) x_0, (1 - ᾱ_t) I)

Where:
- α_t = 1 - β_t (signal retention)
- ᾱ_t = ∏_{s=1}^{t} α_s (cumulative signal retention)
- β_t = noise schedule at timestep t

Closed-form sampling:
x_t = √(ᾱ_t) x_0 + √(1 - ᾱ_t) ε,  ε ~ N(0, I)
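The closed form follows from composing the one-step transitions q(x_t | x_{t-1}) = N(x_t; √(α_t) x_{t-1}, β_t I). For a fixed x_0 the mean is scaled by √(ᾱ_t), and the variance obeys v_t = α_t v_{t-1} + β_t, whose solution is 1 - ᾱ_t. The quick numerical check below is plain Python and assumes a linear β schedule (the schedule is not specified here). It uses 0-indexed timesteps as in the code, so index t has had t + 1 noising steps applied.

```python
import math


def linear_betas(timesteps: int = 1000, beta_start: float = 1e-4, beta_end: float = 0.02) -> list:
    """Assumed linear β schedule; CC-MotionGen's actual schedule may differ."""
    step = (beta_end - beta_start) / (timesteps - 1)
    return [beta_start + i * step for i in range(timesteps)]


def iterated_variance(betas: list, t: int) -> float:
    """Var[x_t | x_0] from applying x_s = √α_s x_{s-1} + √β_s ε_s for s = 0..t."""
    var = 0.0
    for beta in betas[: t + 1]:
        var = (1.0 - beta) * var + beta
    return var


def closed_form_variance(betas: list, t: int) -> float:
    """1 - ᾱ_t, with ᾱ_t = ∏_{s<=t} (1 - β_s)."""
    return 1.0 - math.prod(1.0 - beta for beta in betas[: t + 1])


betas = linear_betas()
for t in (0, 10, 500, 999):
    assert math.isclose(iterated_variance(betas, t), closed_form_variance(betas, t), rel_tol=1e-9)
```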

5.2 Reverse Process

The reverse process learns to denoise:

p_θ(x_{t-1} | x_t) = N(x_{t-1}; μ_θ(x_t, t), Σ_θ(x_t, t))

Posterior mean (when predicting ε):
μ_θ(x_t, t) = 1/√(α_t) (x_t - β_t/√(1 - ᾱ_t) ε_θ(x_t, t))

Posterior variance (fixed, not learned): Σ_t = σ_t² I, with either
σ_t² = β̃_t = β_t (1 - ᾱ_{t-1}) / (1 - ᾱ_t)  (the variance of q(x_{t-1} | x_t, x_0))  or  σ_t² = β_t

5.3 Training Objective

L = E_{t, x_0, ε} [ ||ε - ε_θ(x_t, t)||² ]

Where:
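- t ~ Uniform({1, …, T}), where T is the number of diffusion steps (the code samples 0-indexed timesteps from [0, timesteps - 1])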
- x_0 ~ q(x_0) (data distribution)
- ε ~ N(0, I)
- x_t = √(ᾱ_t) x_0 + √(1 - ᾱ_t) ε

5.4 DDIM Sampling

Sampling on a subsequence of timesteps, deterministic when η = 0:

x_{t-1} = √(ᾱ_{t-1}) f_θ(x_t, t) + √(1 - ᾱ_{t-1} - σ_t²) ε_θ(x_t, t) + σ_t ε_t

Where:
- f_θ(x_t, t) = (x_t - √(1 - ᾱ_t) ε_θ(x_t, t)) / √(ᾱ_t)  (predicted x_0)
- σ_t = η √((1 - ᾱ_{t-1})/(1 - ᾱ_t)) √(1 - ᾱ_t/ᾱ_{t-1})
- η = 0: deterministic, η = 1: DDPM equivalent
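The update can be checked numerically. With η = 0 and a perfect noise prediction, one DDIM step lands exactly on the forward-process point √(ᾱ_{t-1}) x_0 + √(1 - ᾱ_{t-1}) ε. With η = 1 over consecutive timesteps, σ_t² reduces to the posterior variance β_t (1 - ᾱ_{t-1}) / (1 - ᾱ_t) from Section 5.2. The NumPy sketch below uses illustrative function names and passes ᾱ values in directly rather than reading them from a schedule.

```python
import numpy as np


def ddim_sigma(abar_t: float, abar_prev: float, eta: float) -> float:
    """σ_t = η √((1 - ᾱ_{t-1}) / (1 - ᾱ_t)) √(1 - ᾱ_t / ᾱ_{t-1})."""
    return eta * np.sqrt((1.0 - abar_prev) / (1.0 - abar_t)) * np.sqrt(1.0 - abar_t / abar_prev)


def ddim_step(x_t, eps_pred, abar_t, abar_prev, eta=0.0, rng=None):
    """One DDIM update x_t -> x_{t-1}, following the three terms of the equation above."""
    x0_pred = (x_t - np.sqrt(1.0 - abar_t) * eps_pred) / np.sqrt(abar_t)  # f_θ(x_t, t)
    sigma = ddim_sigma(abar_t, abar_prev, eta)
    dir_xt = np.sqrt(max(1.0 - abar_prev - sigma**2, 0.0)) * eps_pred
    if sigma > 0:
        rng = np.random.default_rng() if rng is None else rng
        noise = rng.standard_normal(np.shape(x_t))
    else:
        noise = 0.0
    return np.sqrt(abar_prev) * x0_pred + dir_xt + sigma * noise


rng = np.random.default_rng(0)
x0, eps = rng.standard_normal(25), rng.standard_normal(25)
abar_t, abar_prev = 0.3, 0.5
x_t = np.sqrt(abar_t) * x0 + np.sqrt(1 - abar_t) * eps

# η = 0 with a perfect noise prediction lands on the forward-process point for ᾱ_{t-1}
x_prev = ddim_step(x_t, eps, abar_t, abar_prev, eta=0.0)
print(np.allclose(x_prev, np.sqrt(abar_prev) * x0 + np.sqrt(1 - abar_prev) * eps))  # True
```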

---

6. Training Pipeline

6.1 Dataset Implementation

python
class MotionPhraseDataset(torch.utils.data.Dataset):
    """
    Dataset for loading motion-audio pairs from GCS.

    Features:
    - Lazy loading from Google Cloud Storage
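    - In-memory caching for repeated access (per process: with num_workers > 0 each DataLoader worker holds its own cache, discarded at epoch end unless persistent_workers=True)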
    - Automatic fallback for corrupted samples
    - Frame alignment and padding
    """

    def __init__(
        self,
        gcs_bucket: str = "cc-music-library",
        gcs_prefix: str = "motionphrase",
        num_phrases: Optional[int] = None,
        target_frames: int = 120,
        cache_in_memory: bool = True,
    ):
        self.gcs_bucket = gcs_bucket
        self.gcs_prefix = gcs_prefix
        self.target_frames = target_frames
        self.cache_in_memory = cache_in_memory

        # List available phrases
        self.phrase_ids = self._list_phrases()
        if num_phrases is not None:
            self.phrase_ids = self.phrase_ids[:num_phrases]

        # Cache
        self._cache: Dict[int, Tuple[torch.Tensor, torch.Tensor]] = {}

    def __len__(self) -> int:
        return len(self.phrase_ids)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Returns:
            motion: [25, T] motion trajectory
            audio_cond: [163, T] audio conditioning
        """
        if idx in self._cache:
            return self._cache[idx]

        try:
            phrase_id = self.phrase_ids[idx]

            # Load audio features
            audio_features = self._load_audio_features(phrase_id)
            audio_cond = self._stack_features(audio_features)

            # Load motion
            motion = self._load_motion(phrase_id)

            # Align to target frames
            motion, audio_cond = self._align_frames(motion, audio_cond)

            # Convert to tensors
            motion = torch.from_numpy(motion.T).float()      # [25, T]
            audio_cond = torch.from_numpy(audio_cond).float() # [163, T]

            if self.cache_in_memory:
                self._cache[idx] = (motion, audio_cond)

            return motion, audio_cond

        except Exception as e:
            logger.warning(f"Failed to load {idx}, using fallback: {e}")
            return self._get_fallback(idx)

    def _align_frames(
        self,
        motion: np.ndarray,
        audio_cond: np.ndarray,
    ) -> Tuple[np.ndarray, np.ndarray]:
        """Align motion and audio to same length."""
        T_motion = motion.shape[0]
        T_audio = audio_cond.shape[1]
        T = min(T_motion, T_audio, self.target_frames)

        motion = motion[:T]
        audio_cond = audio_cond[:, :T]

        # Pad if needed
        if T < self.target_frames:
            pad_motion = np.zeros((self.target_frames - T, 25), dtype=np.float32)
            motion = np.concatenate([motion, pad_motion], axis=0)

            pad_audio = np.zeros((163, self.target_frames - T), dtype=np.float32)
            audio_cond = np.concatenate([audio_cond, pad_audio], axis=1)

        return motion, audio_cond
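_align_frames zero-pads short phrases, and the training loss in Section 6.2 then averages over padded frames exactly as it does over real ones. If that is undesirable, one option (an assumption, not something the listing above does) is to return a validity mask alongside the padded arrays and average the loss only over real frames. The NumPy sketch below uses hypothetical function names. In PyTorch, the same idea is F.mse_loss(pred, target, reduction="none") multiplied by the mask before averaging.

```python
import numpy as np


def align_with_mask(motion: np.ndarray, audio_cond: np.ndarray, target_frames: int = 120):
    """Truncate/zero-pad motion [T, 25] and audio [163, T] to target_frames, plus a [target_frames] mask."""
    T = min(motion.shape[0], audio_cond.shape[1], target_frames)
    out_motion = np.zeros((target_frames, motion.shape[1]), dtype=np.float32)
    out_audio = np.zeros((audio_cond.shape[0], target_frames), dtype=np.float32)
    out_motion[:T] = motion[:T]
    out_audio[:, :T] = audio_cond[:, :T]
    mask = np.zeros(target_frames, dtype=np.float32)
    mask[:T] = 1.0  # 1 = real frame, 0 = padding
    return out_motion, out_audio, mask


def masked_mse(pred: np.ndarray, target: np.ndarray, mask: np.ndarray) -> float:
    """Mean squared error over real frames only; pred/target are [D, T], mask is [T]."""
    squared = (pred - target) ** 2 * mask[None, :]
    return float(squared.sum() / max(mask.sum() * pred.shape[0], 1.0))


motion, audio, mask = align_with_mask(np.ones((100, 25)), np.ones((163, 130)))
print(motion.shape, audio.shape, int(mask.sum()))  # (120, 25) (163, 120) 100
```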

6.2 Training Loop

python
class Trainer:
    """Main training orchestrator."""

    def __init__(
        self,
        model: GaussianDiffusion,
        optimizer: torch.optim.Optimizer,
        scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
        config: Optional[TrainingConfig] = None,
    ):
        self.model = model
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.config = config or TrainingConfig()

        # EMA model
        self.ema = ExponentialMovingAverage(
            model.parameters(),
            decay=self.config.ema_decay,
        )

        # Metrics
        self.train_losses = []
        self.val_losses = []

    def train_epoch(
        self,
        dataloader: torch.utils.data.DataLoader,
        epoch: int,
    ) -> Dict[str, float]:
        """Train for one epoch."""
        self.model.train()
        epoch_loss = 0.0
        num_batches = 0

        pbar = tqdm(dataloader, desc=f"Epoch {epoch}")
        for motion, audio_cond in pbar:
            motion = motion.to(self.config.device)
            audio_cond = audio_cond.to(self.config.device)

            # Sample timesteps
            B = motion.shape[0]
            t = torch.randint(
                0, self.model.config.timesteps, (B,),
                device=self.config.device
            )

            # Sample noise
            noise = torch.randn_like(motion)

            # Forward diffusion
            x_t = self.model.q_sample(motion, t, noise)

            # Predict noise
            noise_pred = self.model.model(x_t, t, audio_cond)

            # Compute loss
            loss = F.mse_loss(noise_pred, noise)

            # Backprop
            self.optimizer.zero_grad()
            loss.backward()

            # Gradient clipping
            if self.config.gradient_clip > 0:
                torch.nn.utils.clip_grad_norm_(
                    self.model.parameters(),
                    self.config.gradient_clip,
                )

            self.optimizer.step()

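            # Update EMA every `ema_update_every` optimizer steps (counted within the epoch)
            num_batches += 1
            if num_batches % self.config.ema_update_every == 0:
                self.ema.update()

            epoch_loss += loss.item()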
            pbar.set_postfix(loss=loss.item())

        if self.scheduler is not None:
            self.scheduler.step()

        return {"loss": epoch_loss / num_batches}

    def validate(
        self,
        dataloader: torch.utils.data.DataLoader,
    ) -> Dict[str, float]:
        """Validate model."""
        self.model.eval()
        val_loss = 0.0
        num_batches = 0

        with torch.no_grad():
            for motion, audio_cond in dataloader:
                motion = motion.to(self.config.device)
                audio_cond = audio_cond.to(self.config.device)

                B = motion.shape[0]
                t = torch.randint(0, self.model.config.timesteps, (B,), device=self.config.device)
                noise = torch.randn_like(motion)

                x_t = self.model.q_sample(motion, t, noise)
                noise_pred = self.model.model(x_t, t, audio_cond)

                loss = F.mse_loss(noise_pred, noise)
                val_loss += loss.item()
                num_batches += 1

        return {"val_loss": val_loss / num_batches}
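TrainingConfig pairs ema_decay = 0.9999 with ema_update_every = 10. An EMA with decay d effectively averages over about 1/(1 - d) updates. Here that is roughly 10,000 updates, or about 100,000 optimizer steps if the EMA is refreshed only every 10 steps, a figure worth comparing against the total number of training steps. The plain-Python sketch below implements a step-gated EMA over a scalar. The class and function names are hypothetical, and a real trainer would apply the same update to every parameter tensor.

```python
class StepGatedEMA:
    """Hypothetical scalar EMA refreshed once every `update_every` optimizer steps."""

    def __init__(self, value: float, decay: float = 0.9999, update_every: int = 10):
        self.shadow = value
        self.decay = decay
        self.update_every = update_every
        self.step = 0

    def maybe_update(self, value: float) -> bool:
        """Call once per optimizer step; returns True when the shadow value was refreshed."""
        self.step += 1
        if self.step % self.update_every != 0:
            return False
        self.shadow = self.decay * self.shadow + (1.0 - self.decay) * value
        return True


def effective_horizon_steps(decay: float, update_every: int) -> float:
    """Rough number of optimizer steps the EMA averages over: update_every / (1 - decay)."""
    return update_every / (1.0 - decay)


ema = StepGatedEMA(0.0, decay=0.5, update_every=2)
print([ema.maybe_update(1.0) for _ in range(4)], ema.shadow)  # [False, True, False, True] 0.75
print(round(effective_horizon_steps(0.9999, 10)))  # 100000
```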

6.3 Training Configuration

python
from dataclasses import dataclass
from pathlib import Path
from typing import Tuple


@dataclass
class TrainingConfig:
    """Training configuration."""

    # Optimization
    learning_rate: float = 1e-4
    weight_decay: float = 1e-4
    betas: Tuple[float, float] = (0.9, 0.999)
    eps: float = 1e-8
    gradient_clip: float = 1.0

    # Schedule
    num_epochs: int = 100
    warmup_steps: int = 1000
    lr_schedule: str = "cosine"  # "constant", "cosine", "linear"

    # Batch
    batch_size: int = 32
    num_workers: int = 4

    # EMA
    ema_decay: float = 0.9999
    ema_update_every: int = 10

    # Checkpointing
    save_every: int = 10
    validate_every: int = 5
    checkpoint_dir: Path = Path("checkpoints")

    # Hardware
    device: str = "cuda"
    mixed_precision: bool = True
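warmup_steps and lr_schedule are listed here, but the excerpt does not show how the scheduler is built. Trainer.train_epoch also steps its scheduler once per epoch, while warmup_steps is counted in steps. One common construction, sketched here as an assumption, is a per-step multiplier with linear warmup followed by cosine, linear, or constant decay. lr_multiplier and total_steps are hypothetical names.

```python
import math


def lr_multiplier(
    step: int,
    warmup_steps: int = 1000,
    total_steps: int = 100_000,
    schedule: str = "cosine",
) -> float:
    """Factor applied to TrainingConfig.learning_rate at a given optimizer step."""
    if warmup_steps > 0 and step < warmup_steps:
        return (step + 1) / warmup_steps  # linear warmup
    if schedule == "constant":
        return 1.0
    progress = min(1.0, (step - warmup_steps) / max(1, total_steps - warmup_steps))
    if schedule == "linear":
        return 1.0 - progress
    if schedule == "cosine":
        return 0.5 * (1.0 + math.cos(math.pi * progress))
    raise ValueError(f"unknown lr_schedule: {schedule!r}")


print(lr_multiplier(0), lr_multiplier(999), lr_multiplier(1000))  # 0.001 1.0 1.0
```

It could be wired in with torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda step: lr_multiplier(step, total_steps=num_epochs * len(dataloader))), calling scheduler.step() after every optimizer.step() rather than once per epoch.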

---

7. End-to-End Fine-tuning

7.1 Motivation and Theory

The base diffusion model learns a latent representation optimized for noise prediction, not motion semantics. E2E fine-tuning addresses this by:

1. Joint Training: Backpropagating through both diffusion and decoder
2. Temporal Losses: Explicitly penalizing physical inconsistencies
3. Differential LR: Lower LR for pre-trained diffusion, higher for decoder

┌─────────────────────────────────────────────────────────────────────────────┐
│                         E2E FINE-TUNING ARCHITECTURE                        │
├─────────────────────────────────────────────────────────────────────────────┤
│                                                                              │
│  Ground Truth Motion ────────────────────────────────────────────────────┐  │
│       [B, 25, T]                                                         │  │
│           │                                                              │  │
│           ▼                                                              │  │
│  ┌─────────────────────┐                                                 │  │
│  │   Add Noise (t)     │                                                 │  │
│  │   x_t = √ᾱ_t·x_0    │                                                 │  │
│  │       + √(1-ᾱ_t)·ε  │                                                 │  │
│  └──────────┬──────────┘                                                 │  │
│             │                                                            │  │
│             ▼                                                            │  │
│  ┌──────────────────────────────────────────┐                           │  │
│  │           UNet1D (116M)                   │                           │  │
│  │   ε_θ(x_t, t, audio_cond) → noise_pred   │◄─── Audio Conditioning    │  │
│  └──────────────────┬───────────────────────┘                           │  │
│                     │                                                    │  │
│                     ▼                                                    │  │
│  ┌─────────────────────────────────────────────────────────────────────┐│  │
│  │                    LOSS 1: Diffusion Loss                           ││  │
│  │                    L_diff = MSE(noise_pred, ε)                      ││  │
│  └─────────────────────────────────────────────────────────────────────┘│  │
│                     │                                                    │  │
│                     ▼                                                    │  │
│  ┌──────────────────────────────────────────┐                           │  │
│  │        DDIM Sample (5 steps)             │                           │  │
│  │   WITH GRADIENTS (differentiable)        │◄─── Audio Conditioning    │  │
│  │   x_0_pred = DDIM(noise, audio_cond)     │                           │  │
│  └──────────────────┬───────────────────────┘                           │  │
│                     │                                                    │  │
│                     ▼                                                    │  │
│  ┌──────────────────────────────────────────┐                           │  │
│  │        MotionDecoder (2M)                │                           │  │
│  │   decoded = Decoder(x_0_pred)            │                           │  │
│  └──────────────────┬───────────────────────┘                           │  │
│                     │                                                    │  │
│        ┌────────────┴────────────────────────────────────┐              │  │
│        │                                                  │              │  │
│        ▼                                                  ▼              ▼  │
│  ┌───────────────────────────┐    ┌──────────────────────────────────────┐│
│  │ LOSS 2: Decoder Recon     │    │ LOSS 3: Temporal Coherence           ││
│  │ L_pos = MSE(pos, gt_pos)  │    │ L_vel = MSE(vel, d(pos)/dt)          ││
│  │ L_quat = MSE(quat, gt_q)  │    │ L_acc = MSE(acc, d(vel)/dt)          ││
│  └───────────────────────────┘    │ L_jerk = mean(|d(acc)/dt|²)          ││
│                                   └──────────────────────────────────────┘│
│                                                                            │
│  ┌────────────────────────────────────────────────────────────────────────┐│
│  │                        TOTAL LOSS                                      ││
│  │  L = w_diff·L_diff + w_dec·(L_pos + L_quat) + w_temp·(L_vel + L_acc)  ││
│  └────────────────────────────────────────────────────────────────────────┘│
│                                                                              │
│                                   │                                          │
│                                   ▼                                          │
│  ┌────────────────────────────────────────────────────────────────────────┐ │
│  │                         BACKPROPAGATION                                │ │
│  │   Gradients flow through: Decoder ← DDIM ← UNet ← Input               │ │
│  │   Differential LR: UNet (1e-6), Decoder (5e-5)                        │ │
│  └────────────────────────────────────────────────────────────────────────┘ │
│                                                                              │
└──────────────────────────────────────────────────────────────────────────────┘
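The differential learning rates and the weighted total can be expressed with two PyTorch parameter groups. The minimal sketch below assumes AdamW, since TrainingConfig lists betas and weight decay but does not name the optimizer. The default weights and learning rates are taken from the example command in Section 7.4. It is not stated whether the trainer sums every term returned by TemporalCoherenceLoss or only the velocity and acceleration terms drawn in the diagram, so the sketch simply sums whatever dictionary it is given. Function names are hypothetical.

```python
from typing import Dict

import torch
import torch.nn as nn


def build_e2e_optimizer(
    unet: nn.Module,
    decoder: nn.Module,
    lr_diffusion: float = 1e-6,
    lr_decoder: float = 5e-5,
    weight_decay: float = 1e-4,
) -> torch.optim.Optimizer:
    """AdamW with one parameter group per sub-network (differential learning rates)."""
    return torch.optim.AdamW(
        [
            {"params": unet.parameters(), "lr": lr_diffusion},
            {"params": decoder.parameters(), "lr": lr_decoder},
        ],
        weight_decay=weight_decay,
    )


def total_e2e_loss(
    l_diff: torch.Tensor,
    decoder_losses: Dict[str, torch.Tensor],
    temporal_losses: Dict[str, torch.Tensor],
    w_diff: float = 0.1,
    w_dec: float = 1.0,
    w_temp: float = 10.0,
) -> torch.Tensor:
    """w_diff·L_diff + w_dec·Σ decoder terms + w_temp·Σ temporal terms."""
    return (
        w_diff * l_diff
        + w_dec * sum(decoder_losses.values())
        + w_temp * sum(temporal_losses.values())
    )


# Toy stand-ins for the UNet and decoder
unet, decoder = nn.Linear(4, 4), nn.Linear(4, 2)
optimizer = build_e2e_optimizer(unet, decoder)
print([group["lr"] for group in optimizer.param_groups])  # [1e-06, 5e-05]
```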

7.2 Temporal Coherence Loss Implementation

python
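from dataclasses import dataclass
from typing import Dict

import torch
import torch.nn as nn
import torch.nn.functional as F


@dataclass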
class TemporalLossConfig:
    """Configuration for temporal coherence losses."""

    # Loss weights (aggressive for coherence)
    velocity_weight: float = 10.0
    acceleration_weight: float = 5.0
    jerk_weight: float = 1.0
    quat_delta_weight: float = 2.0
    phase_monotonic_weight: float = 1.0

    # Normalization thresholds
    target_jerk: float = 50000.0
    target_vel_coherence: float = 5.0
    target_accel_coherence: float = 100.0
    target_quat_delta: float = 0.95


class TemporalCoherenceLoss(nn.Module):
    """
    Computes temporal coherence losses for motion.

    Ensures that:
    1. Explicit velocity matches position derivative
    2. Explicit acceleration matches velocity derivative
    3. Jerk (rate of change of acceleration) stays small relative to target_jerk
    4. Quaternion changes are smooth
    5. Phase is monotonically increasing
    """

    def __init__(self, config: TemporalLossConfig):
        super().__init__()
        self.config = config
        self.fps = 30.0

    def forward(
        self,
        decoded: torch.Tensor,  # [B, T, 25]
    ) -> Dict[str, torch.Tensor]:
        """
        Compute all temporal coherence losses.

        Returns dictionary with individual loss values.
        """
        losses = {}

        # Extract components
        pos = decoded[:, :, 0:3]
        vel = decoded[:, :, 3:6]
        accel = decoded[:, :, 6:9]
        quat = decoded[:, :, 9:13]
        phase = decoded[:, :, 16]

        # 1. JERK LOSS
        # Jerk = d(acceleration)/dt, should be bounded for smooth motion
        jerk = (accel[:, 1:] - accel[:, :-1]) * self.fps
        jerk_magnitude = jerk.pow(2).sum(dim=-1).sqrt()  # L2 norm per frame

        # Normalize by target and compute MSE
        jerk_normalized = jerk_magnitude / self.config.target_jerk
        jerk_loss = jerk_normalized.pow(2).mean()
        losses['jerk'] = jerk_loss * self.config.jerk_weight

        # 2. VELOCITY COHERENCE
        # Explicit velocity should match position derivative
        derived_velocity = (pos[:, 1:] - pos[:, :-1]) * self.fps
        vel_error = (vel[:, :-1] - derived_velocity).pow(2).mean()
        vel_normalized = vel_error / (self.config.target_vel_coherence ** 2)
        losses['velocity_coherence'] = vel_normalized * self.config.velocity_weight

        # 3. ACCELERATION COHERENCE
        # Explicit acceleration should match velocity derivative
        derived_accel = (vel[:, 1:] - vel[:, :-1]) * self.fps
        accel_error = (accel[:, :-1] - derived_accel).pow(2).mean()
        accel_normalized = accel_error / (self.config.target_accel_coherence ** 2)
        losses['accel_coherence'] = accel_normalized * self.config.acceleration_weight

        # 4. QUATERNION DELTA
        # Quaternion should change smoothly (high dot product between consecutive)
        quat_dot = (quat[:, :-1] * quat[:, 1:]).sum(dim=-1).abs()
        quat_delta_loss = F.relu(self.config.target_quat_delta - quat_dot).mean()
        losses['quat_delta'] = quat_delta_loss * self.config.quat_delta_weight

        # 5. PHASE MONOTONICITY
        # Phase should increase (or wrap around at bar boundaries)
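        phase_delta = phase[:, 1:] - phase[:, :-1]
        # A drop larger than 0.5 is a wrap (1 → 0 at a bar boundary) and is allowed;
        # any smaller backward step violates monotonicity.
        is_wrap = phase_delta < -0.5
        backward_step = F.relu(-phase_delta)
        phase_violations = torch.where(
            is_wrap, torch.zeros_like(backward_step), backward_step
        ).sum(dim=-1)
        phase_loss = phase_violations.mean()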
        losses['phase_monotonic'] = phase_loss * self.config.phase_monotonic_weight

        return losses
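The phase rule described in the comments can be checked in isolation. Under that rule, small backward steps are violations, while a drop of more than 0.5 counts as a wrap at a bar boundary and is allowed. The NumPy sketch below uses a hypothetical function name and operates on a [B, T] phase array.

```python
import numpy as np


def phase_violation(phase: np.ndarray, wrap_threshold: float = 0.5) -> np.ndarray:
    """Per-sequence sum of backward phase steps, ignoring wrap-arounds.

    phase: [B, T] array with values in [0, 1].
    """
    delta = phase[:, 1:] - phase[:, :-1]
    is_wrap = delta < -wrap_threshold      # e.g. 0.95 -> 0.05 at a bar boundary
    backward = np.maximum(-delta, 0.0)     # magnitude of any backward step
    return np.where(is_wrap, 0.0, backward).sum(axis=-1)


phase = np.array([
    [0.10, 0.50, 0.95, 0.05, 0.40],  # monotone apart from one wrap
    [0.10, 0.50, 0.40, 0.60, 0.70],  # one small backward step of 0.1
])
print(phase_violation(phase))  # first row 0.0, second row ≈ 0.1
```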

7.3 DDIM with Gradients

python
def _ddim_sample_with_grad(
    self,
    shape: Tuple[int, ...],
    audio_cond: torch.Tensor,
    num_steps: int = 5,
) -> torch.Tensor:
    """
    DDIM sampling with gradient flow for E2E training.

    Unlike standard DDIM which uses torch.no_grad(), this version
    maintains gradients for backpropagation through the sampling process.

    Args:
        shape: (B, D, T) output shape
        audio_cond: [B, 163, T] audio conditioning
        num_steps: Number of DDIM steps (fewer for training efficiency)

    Returns:
        x_0: [B, D, T] generated samples (with gradients)
    """
    device = audio_cond.device
    B = shape[0]

    # Create timestep schedule (evenly spaced)
    timesteps = torch.linspace(
        self.diffusion.config.timesteps - 1, 0, num_steps,
        device=device
    ).long()

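    # Start from pure noise. The noise itself needs no gradient: gradients reach the
    # UNet's parameters through every model call in the loop below.
    x = torch.randn(shape, device=device)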

    # DDIM sampling loop
    for i in range(len(timesteps)):
        t = timesteps[i].expand(B)

        # Predict noise (with gradients)
        noise_pred = self.diffusion.model(x, t, audio_cond)

        # Get schedule values
        alpha_cumprod = self.diffusion.alphas_cumprod[t][:, None, None]
        alpha_cumprod_prev = (
            self.diffusion.alphas_cumprod[timesteps[i+1]]
            if i < len(timesteps) - 1
            else torch.ones_like(alpha_cumprod)
        )

        # Predict x_0
        x_0_pred = (x - torch.sqrt(1 - alpha_cumprod) * noise_pred) / torch.sqrt(alpha_cumprod)

        # DDIM update (deterministic, eta=0)
        x = (
            torch.sqrt(alpha_cumprod_prev) * x_0_pred +
            torch.sqrt(1 - alpha_cumprod_prev) * noise_pred
        )

    return x
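Every DDIM step stays in the autograd graph, so the loss gradient reaches the noise predictor through all num_steps evaluations. Activation memory therefore grows roughly linearly with num_steps, which is one reason to keep the step count small during fine-tuning. The toy sketch below replaces the UNet with a hypothetical linear noise predictor and uses a toy ᾱ schedule. It also shows that the starting noise does not need requires_grad=True, because gradients to the model's parameters flow regardless.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-ins: a linear "noise predictor" in place of the UNet and a
# toy decreasing ᾱ schedule in place of diffusion.alphas_cumprod.
eps_model = nn.Linear(8, 8)
alphas_cumprod = torch.linspace(0.999, 0.05, 1000)


def ddim_unrolled(x: torch.Tensor, timesteps: list) -> torch.Tensor:
    """Deterministic DDIM (eta = 0) without torch.no_grad(), so gradients reach eps_model."""
    for i, t in enumerate(timesteps):
        a_t = alphas_cumprod[t]
        a_prev = alphas_cumprod[timesteps[i + 1]] if i + 1 < len(timesteps) else torch.tensor(1.0)
        eps = eps_model(x)
        x0_pred = (x - torch.sqrt(1 - a_t) * eps) / torch.sqrt(a_t)
        x = torch.sqrt(a_prev) * x0_pred + torch.sqrt(1 - a_prev) * eps
    return x


x_T = torch.randn(4, 8)  # plain noise, requires_grad is left False
out = ddim_unrolled(x_T, [999, 749, 499, 249, 0])
out.pow(2).mean().backward()
print(eps_model.weight.grad is not None)  # True
```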

7.4 E2E Training Script Usage

bash
# Basic E2E fine-tuning
python -m cc_motiongen.scripts.finetune_e2e \
    --epochs 50 \
    --num-phrases 200 \
    --batch-size 8

# Full configuration
python -m cc_motiongen.scripts.finetune_e2e \
    --epochs 50 \
    --num-phrases 1000 \
    --batch-size 8 \
    --lr-diffusion 1e-6 \
    --lr-decoder 5e-5 \
    --diffusion-weight 0.1 \
    --decoder-weight 1.0 \
    --temporal-weight 10.0 \
    --patience 10 \
    --checkpoint outputs/cc_motiongen/checkpoints/best.pt \
    --decoder-checkpoint outputs/cc_motiongen/decoder/decoder_real_best.pt

7.5 Training Results Summary

┌─────────────────────────────────────────────────────────────────────────────┐
│                      E2E FINE-TUNING RESULTS (50 EPOCHS)                     │
├──────────────────────┬──────────────┬──────────────┬────────────────────────┤
│ Metric               │ Epoch 1      │ Epoch 50     │ Improvement            │
├──────────────────────┼──────────────┼──────────────┼────────────────────────┤
│ Total Loss           │ 4548.15      │ 57.75        │ -98.7%                 │
│ Velocity Coherence   │ 22.49        │ 0.0000       │ -100.0%                │
│ Accel Coherence      │ 187.04       │ 5.70         │ -96.9%                 │
│ Jerk                 │ 0.99         │ 0.023        │ -97.7%                 │
│ Quaternion Delta     │ 0.21         │ 0.002        │ -99.0%                 │
│ Decoder Position     │ 1.70         │ 0.41         │ -76.0%                 │
│ Decoder Quaternion   │ 0.15         │ 0.046        │ -69.3%                 │
├──────────────────────┼──────────────┼──────────────┼────────────────────────┤
│ Quaternion Pass Rate │ 10%          │ 100%         │ +90 points             │
└──────────────────────┴──────────────┴──────────────┴────────────────────────┘

---

8. Inference & Sampling

8.1 DDIM Sampler

python
class DDIMSampler:
    """
    Denoising Diffusion Implicit Models sampler.

    Enables faster sampling by skipping timesteps (e.g., 20 model evaluations
    instead of 1000); updates are deterministic when eta = 0.
    """

    def __init__(
        self,
        diffusion: GaussianDiffusion,
        num_steps: int = 20,
        eta: float = 0.0,
    ):
        self.diffusion = diffusion
        self.num_steps = num_steps
        self.eta = eta

    @torch.no_grad()
    def sample(
        self,
        shape: Tuple[int, ...],
        audio_cond: torch.Tensor,
        progress: bool = True,
    ) -> torch.Tensor:
        """
        Generate samples using DDIM.

        Args:
            shape: (B, D, T) output shape
            audio_cond: [B, 163, T] audio conditioning
            progress: Show progress bar

        Returns:
            samples: [B, D, T] generated motion
        """
        device = audio_cond.device
        B = shape[0]

        # Create timestep schedule
        timesteps = self._get_timesteps()

        # Start from noise
        x = torch.randn(shape, device=device)

        # Sampling loop
        iterator = tqdm(timesteps, desc="DDIM Sampling") if progress else timesteps
        for i, t in enumerate(iterator):
            t_batch = torch.full((B,), t, device=device, dtype=torch.long)

            # Next timestep; -1 on the final step means "jump to the clean sample" (ᾱ = 1)
            t_next = timesteps[i + 1] if i < len(timesteps) - 1 else -1

            # DDIM step
            x = self._ddim_step(x, t_batch, t_next, audio_cond)

        return x

    def _get_timesteps(self) -> List[int]:
        """Generate evenly spaced timesteps."""
        total = self.diffusion.config.timesteps
        step_size = total // self.num_steps
        timesteps = list(range(total - 1, -1, -step_size))[:self.num_steps]
        return timesteps

    def _ddim_step(
        self,
        x: torch.Tensor,
        t: torch.Tensor,
        t_next: int,
        audio_cond: torch.Tensor,
    ) -> torch.Tensor:
        """Single DDIM denoising step."""
        # Predict noise
        noise_pred = self.diffusion.model(x, t, audio_cond)

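        # Cumulative alphas (ᾱ) at the current and next timestep; ᾱ = 1 past the end
        alpha = self.diffusion.alphas_cumprod[t[0]]
        alpha_next = (
            self.diffusion.alphas_cumprod[t_next]
            if t_next >= 0
            else torch.ones((), device=x.device, dtype=alpha.dtype)
        )

        # Predict x_0
        x_0 = (x - torch.sqrt(1 - alpha) * noise_pred) / torch.sqrt(alpha)

        # Clip if configured (same range as GaussianDiffusion.p_sample)
        if self.diffusion.config.clip_denoised:
            clip = self.diffusion.config.clip_range
            x_0 = x_0.clamp(-clip, clip)

        # Compute sigma for stochasticity (zero when eta = 0)
        sigma = self.eta * torch.sqrt((1 - alpha_next) / (1 - alpha)) * torch.sqrt(1 - alpha / alpha_next)

        # Direction pointing to x_t (clamped to avoid sqrt of tiny negative round-off)
        dir_xt = torch.sqrt((1 - alpha_next - sigma ** 2).clamp(min=0.0)) * noise_pred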

        # Random noise (if eta > 0)
        noise = torch.randn_like(x) if sigma > 0 else 0

        # DDIM update
        x_next = torch.sqrt(alpha_next) * x_0 + dir_xt + sigma * noise

        return x_next
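To see what `_ddim_step` computes, here is a minimal NumPy sketch of the same update. It is driven by an "oracle" noise predictor that returns the true ε, an assumption made purely for illustration. `ddim_step`, `ddim_sample_with_oracle`, and the linear β schedule (1e-4 to 0.02 over 1000 steps) are hypothetical stand-ins, not the project's configuration.

With η = 0 and an exact ε, every step lands on the same noising trajectory. A final step that targets ᾱ = 1 then returns x₀ exactly. Targeting ᾱ₀ (`alphas_cumprod[0]`) instead leaves a residual of √(1 − ᾱ₀)·ε.

```python
import numpy as np

def ddim_step(x, eps_pred, a_t, a_next, eta=0.0, rng=None):
    """One DDIM update from ᾱ_t to ᾱ_next (same math as _ddim_step, without clipping)."""
    x0_pred = (x - np.sqrt(1.0 - a_t) * eps_pred) / np.sqrt(a_t)
    sigma = eta * np.sqrt((1.0 - a_next) / (1.0 - a_t)) * np.sqrt(1.0 - a_t / a_next)
    dir_xt = np.sqrt(1.0 - a_next - sigma ** 2) * eps_pred
    noise = rng.standard_normal(x.shape) if sigma > 0 else 0.0
    return np.sqrt(a_next) * x0_pred + dir_xt + sigma * noise

def ddim_sample_with_oracle(x0, eps, alphas_cumprod, timesteps, final_alpha=1.0):
    """Deterministic DDIM (eta = 0) where the 'model' always returns the true noise eps."""
    a_start = alphas_cumprod[timesteps[0]]
    # Forward-noise x0 to the first (noisiest) timestep
    x = np.sqrt(a_start) * x0 + np.sqrt(1.0 - a_start) * eps
    for i, t in enumerate(timesteps):
        is_last = i == len(timesteps) - 1
        a_next = final_alpha if is_last else alphas_cumprod[timesteps[i + 1]]
        x = ddim_step(x, eps, alphas_cumprod[t], a_next)
    return x

# Hypothetical schedule; timesteps follow the _get_timesteps rule for 1000 -> 20 steps
betas = np.linspace(1e-4, 0.02, 1000)
alphas_cumprod = np.cumprod(1.0 - betas)
timesteps = list(range(999, -1, -50))[:20]

rng = np.random.default_rng(0)
x0 = rng.uniform(-1, 1, size=(25, 120))   # [D, T] "clean" motion
eps = rng.standard_normal(size=(25, 120))
restored = ddim_sample_with_oracle(x0, eps, alphas_cumprod, timesteps)
```

The oracle only isolates the sampler arithmetic. With a learned ε_θ, the intermediate x̂₀ predictions differ from step to step, and this is where clipping (`clip_denoised`) and η come into play.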

8.2 Full Inference Pipeline

python
class MotionGenerator:
    """
    Complete inference pipeline for motion generation.
    """

    def __init__(
        self,
        diffusion_checkpoint: str,
        decoder_checkpoint: str,
        device: str = "cuda",
    ):
        self.device = device

        # Load models
        self.diffusion = self._load_diffusion(diffusion_checkpoint)
        self.decoder = self._load_decoder(decoder_checkpoint)
        self.postprocessor = MotionPostProcessorTorch()

        # Sampler
        self.sampler = DDIMSampler(self.diffusion, num_steps=20)

    def generate(
        self,
        audio_features: Union[np.ndarray, torch.Tensor],
        num_samples: int = 1,
        guidance_scale: float = 1.0,
    ) -> np.ndarray:
        """
        Generate motion from audio features.

        Args:
            audio_features: [163, T] audio conditioning
            num_samples: Number of samples to generate
            guidance_scale: Classifier-free guidance scale (1.0 = no guidance); accepted but not yet applied by this method

        Returns:
            motion: [num_samples, T, 25] generated motion
        """
        # Prepare conditioning
        if isinstance(audio_features, np.ndarray):
            audio_features = torch.from_numpy(audio_features).float()

        audio_cond = audio_features.unsqueeze(0).expand(num_samples, -1, -1)
        audio_cond = audio_cond.to(self.device)

        T = audio_cond.shape[2]

        # Sample from diffusion
        with torch.no_grad():
            # Generate raw samples
            raw_samples = self.sampler.sample(
                shape=(num_samples, 25, T),
                audio_cond=audio_cond,
                progress=True,
            )

            # Decode
            samples_td = raw_samples.permute(0, 2, 1)  # [B, T, D]
            decoded = self.decoder(samples_td)

            # Post-process
            decoded_dt = decoded.permute(0, 2, 1)  # [B, D, T]
            processed = self.postprocessor.process(decoded_dt)

            # Convert to numpy
            motion = processed.permute(0, 2, 1).cpu().numpy()

        return motion

    def generate_from_audio(
        self,
        audio_path: str,
        **kwargs,
    ) -> np.ndarray:
        """Generate motion from audio file."""
        # Extract features
        features = extract_audio_features(audio_path)
        audio_cond = stack_audio_features(features)

        return self.generate(audio_cond, **kwargs)
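The layout changes inside `generate` are easy to get wrong, so here is a framework-free sketch that traces them with stand-in callables. `sample_fn`, `decode_fn`, and `postprocess_fn` are hypothetical placeholders for the sampler, decoder, and post-processor, and NumPy `transpose` plays the role of `permute`. The sampler and the Torch post-processor work channels-first ([B, D, T]), while the decoder and the returned array are time-major ([B, T, D]).

```python
import numpy as np

D_MOTION = 25  # motion channels per frame
D_AUDIO = 163  # audio conditioning channels per frame

def generate_layout(audio_features, num_samples, sample_fn, decode_fn, postprocess_fn):
    """Trace MotionGenerator.generate's layouts; returns [num_samples, T, 25]."""
    assert audio_features.shape[0] == D_AUDIO, "expected [163, T] audio features"
    # [163, T] -> [B, 163, T] as a broadcast view, like torch's expand()
    audio_cond = np.broadcast_to(audio_features[None], (num_samples,) + audio_features.shape)
    T = audio_cond.shape[2]
    raw = sample_fn((num_samples, D_MOTION, T), audio_cond)  # [B, D, T]
    decoded = decode_fn(raw.transpose(0, 2, 1))              # decoder consumes [B, T, D]
    processed = postprocess_fn(decoded.transpose(0, 2, 1))   # post-processor consumes [B, D, T]
    return processed.transpose(0, 2, 1)                      # returned as [B, T, D]
```

In the Torch version, `expand` returns a view rather than a copy. This is fine here because the conditioning is only read, never written in place.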

8.3 Classifier-Free Guidance

python
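def sample_with_cfg(
    self,
    shape: Tuple[int, ...],
    audio_cond: torch.Tensor,
    guidance_scale: float = 3.0,
    num_steps: int = 20,
) -> torch.Tensor:
    """
    Sample with classifier-free guidance.

    Uses unconditional + conditional prediction:
    ε_guided = ε_uncond + scale * (ε_cond - ε_uncond)

    Args:
        shape: Output shape [B, D, T]
        audio_cond: Audio conditioning [B, 163, T]
        guidance_scale: CFG scale (1.0 = no guidance)
        num_steps: DDIM steps

    Returns:
        Guided samples [B, D, T]
    """
    device = audio_cond.device
    B = shape[0]

    # Unconditional conditioning: zeros as the null embedding. This is only meaningful
    # if the model was trained with conditioning dropout (zeros or a learned null embedding).
    uncond = torch.zeros_like(audio_cond)

    # Evenly spaced timesteps for num_steps (same rule as _get_timesteps)
    total = self.diffusion.config.timesteps
    step_size = total // num_steps
    timesteps = list(range(total - 1, -1, -step_size))[:num_steps]

    x = torch.randn(shape, device=device)

    for i, t in enumerate(timesteps):
        t_batch = torch.full((B,), t, device=device, dtype=torch.long)
        t_next = timesteps[i + 1] if i < len(timesteps) - 1 else 0

        # Conditional and unconditional predictions
        noise_cond = self.diffusion.model(x, t_batch, audio_cond)
        noise_uncond = self.diffusion.model(x, t_batch, uncond)

        # Guided prediction
        noise_guided = noise_uncond + guidance_scale * (noise_cond - noise_uncond)

        # DDIM step with guided noise
        x = self._ddim_step_with_noise(x, t_batch, t_next, noise_guided)

    return x


def _ddim_step_with_noise(
    self,
    x: torch.Tensor,
    t: torch.Tensor,
    t_next: int,
    noise_pred: torch.Tensor,
) -> torch.Tensor:
    """DDIM update from a precomputed noise estimate (the _ddim_step math without the model call)."""
    alpha = self.diffusion.alphas_cumprod[t[0]]
    # t_next == 0 marks the final step: target ᾱ = 1 (fully denoised)
    alpha_next = self.diffusion.alphas_cumprod[t_next] if t_next > 0 else alpha.new_tensor(1.0)

    # Predict x_0 and optionally clip it
    x_0 = (x - torch.sqrt(1 - alpha) * noise_pred) / torch.sqrt(alpha)
    if self.diffusion.config.clip_denoised:
        x_0 = x_0.clamp(-1, 1)

    # Stochasticity (eta), direction to x_t, and the update
    sigma = self.eta * torch.sqrt((1 - alpha_next) / (1 - alpha)) * torch.sqrt(1 - alpha / alpha_next)
    dir_xt = torch.sqrt(1 - alpha_next - sigma ** 2) * noise_pred
    noise = torch.randn_like(x) if sigma > 0 else 0
    return torch.sqrt(alpha_next) * x_0 + dir_xt + sigma * noise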
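The guidance formula can be checked in isolation. The NumPy sketch below uses `toy_eps_model`, a hypothetical stand-in for `self.diffusion.model`, and combines the two predictions exactly as `sample_with_cfg` does. It differs in one respect: the conditional and unconditional passes run as a single forward pass over a doubled batch of size 2B, a common variant that replaces two model calls with one. Zero features act as a meaningful null condition only if the model saw them during training. Classifier-free guidance training is listed under v0.4.0 in the roadmap.

```python
import numpy as np

def toy_eps_model(x, t, cond):
    """Hypothetical stand-in for diffusion.model: x [B, 25, T], t [B], cond [B, 163, T]."""
    return 0.1 * x + 1e-3 * t[:, None, None] + cond.mean(axis=1, keepdims=True)

def guided_noise(model, x, t, cond, guidance_scale):
    """eps_uncond + scale * (eps_cond - eps_uncond), from one forward pass on a 2B batch."""
    B = x.shape[0]
    uncond = np.zeros_like(cond)  # zero features as the null condition, as in sample_with_cfg
    eps = model(
        np.concatenate([x, x]),
        np.concatenate([t, t]),
        np.concatenate([cond, uncond]),
    )
    eps_cond, eps_uncond = eps[:B], eps[B:]
    return eps_uncond + guidance_scale * (eps_cond - eps_uncond)
```

Batching is equivalent to two separate calls only when the model has no batch-coupled layers at inference time. For example, BatchNorm must be in eval mode. A scale of 1.0 reproduces the conditional prediction, and 0.0 reproduces the unconditional one.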

---

9. Post-Processing Pipeline

9.1 Configuration

python
@dataclass
class PostProcessConfig:
    """Post-processing configuration."""

    # Quaternion normalization
    normalize_quaternions: bool = True
    quat_eps: float = 1e-8

    # Quaternion smoothing
    smooth_quaternions: bool = False
    quat_smooth_window: int = 5

    # Phase handling
    clamp_phase: bool = True
    phase_min: float = 0.0
    phase_max: float = 1.0
    enforce_phase_monotonic: bool = False

    # Value clamping
    clamp_values: bool = True
    max_position: float = 100.0
    max_velocity: float = 50.0
    max_acceleration: float = 100.0
    max_angular_velocity: float = 10.0

    # Coherence enforcement (post-hoc)
    enforce_velocity_coherence: bool = False
    enforce_accel_coherence: bool = False

    # Temporal smoothing
    apply_smoothing: bool = False
    smoothing_kernel_size: int = 3
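Two of these options, `enforce_phase_monotonic` and `apply_smoothing`/`smoothing_kernel_size`, are not exercised by the processing code shown in 9.2. Below is one plausible reading of each, sketched in NumPy under stated assumptions:

- Phase is treated as a progress value that should never decrease, not a cyclic phase that wraps from 1 back to 0.
- Smoothing is a centered moving average with edge padding.

Both helper names are hypothetical.

```python
import numpy as np

def enforce_phase_monotonic(phase):
    """Non-decreasing phase: each frame takes the running maximum so far."""
    return np.maximum.accumulate(phase)

def temporal_smooth(motion, kernel_size=3):
    """Centered moving average over time for a [T, D] array, edge-padded so T is preserved."""
    if kernel_size % 2 == 0:
        raise ValueError("kernel_size must be odd to keep the window centered")
    pad = kernel_size // 2
    padded = np.pad(motion, ((pad, pad), (0, 0)), mode="edge")
    # Prefix sums give every window sum at once: sum(padded[i:i+k]) = csum[i+k] - csum[i]
    csum = np.cumsum(np.vstack([np.zeros((1, motion.shape[1])), padded]), axis=0)
    return (csum[kernel_size:] - csum[:-kernel_size]) / kernel_size
```

If smoothing is applied to the quaternion columns (9:13), they need to be re-normalized afterwards, because an average of unit quaternions is generally not unit length.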

9.2 Processing Steps

python
class MotionPostProcessor:
    """
    Post-processes diffusion model outputs (NumPy, [T, 25] layout).

    Processing order:
    1. Value clamping (bound extreme values)
    2. Quaternion normalization (unit norm)
    3. Quaternion smoothing (optional)
    4. Phase clamping [0, 1]
    5. Coherence enforcement (optional)
    6. Temporal smoothing (optional)

    The excerpt below implements steps 1, 2, 4 and the velocity part of step 5.
    """

    def __init__(self, config: Optional[PostProcessConfig] = None):
        self.config = config or PostProcessConfig()

    def process(self, motion: np.ndarray) -> np.ndarray:
        """Process a [T, 25] motion trajectory and return a new array."""
        cfg = self.config
        motion = motion.copy()

        # 1. Value clamping (bounds taken from PostProcessConfig)
        if cfg.clamp_values:
            motion[:, 0:3] = np.clip(motion[:, 0:3], -cfg.max_position, cfg.max_position)        # position
            motion[:, 3:6] = np.clip(motion[:, 3:6], -cfg.max_velocity, cfg.max_velocity)        # velocity
            motion[:, 6:9] = np.clip(motion[:, 6:9], -cfg.max_acceleration, cfg.max_acceleration)  # acceleration
            motion[:, 13:16] = np.clip(
                motion[:, 13:16], -cfg.max_angular_velocity, cfg.max_angular_velocity
            )  # angular velocity

        # 2. Quaternion normalization (w, x, y, z in columns 9:13)
        if cfg.normalize_quaternions:
            quats = motion[:, 9:13]
            norms = np.linalg.norm(quats, axis=1, keepdims=True)
            norms = np.maximum(norms, cfg.quat_eps)
            motion[:, 9:13] = quats / norms

            # Ensure w >= 0 (hemisphere consistency; q and -q encode the same rotation)
            w_negative = motion[:, 9] < 0
            motion[w_negative, 9:13] *= -1

        # 4. Phase clamping
        if cfg.clamp_phase:
            motion[:, 16] = np.clip(motion[:, 16], cfg.phase_min, cfg.phase_max)

        # 5. Optional velocity coherence: blend predicted velocity 50/50 with the
        #    backward finite difference of position at 30 fps
        if cfg.enforce_velocity_coherence:
            pos = motion[:, 0:3]
            vel_derived = np.zeros_like(pos)
            vel_derived[1:] = (pos[1:] - pos[:-1]) * 30.0
            vel_derived[0] = vel_derived[1]
            motion[:, 3:6] = 0.5 * motion[:, 3:6] + 0.5 * vel_derived

        return motion
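The processing order lists quaternion smoothing (`smooth_quaternions`, `quat_smooth_window`) as step 3, but the code above does not show it. Here is a minimal sketch of one way to do it, assuming [T, 4] quaternions in (w, x, y, z) order. It takes two steps:

1. Flip signs so consecutive frames stay in the same hemisphere. q and −q are the same rotation, and forcing w ≥ 0 per frame introduces a sign flip wherever w crosses zero, which would make a plain average cancel out.
2. Average over a centered window and renormalize.

Averaging and renormalizing is an approximation that holds up only when the rotations inside the window are close together. The function names are hypothetical.

```python
import numpy as np

def make_sign_continuous(quats):
    """Flip signs so each quaternion is in the same hemisphere as its predecessor."""
    q = quats.copy()
    for t in range(1, len(q)):
        if np.dot(q[t], q[t - 1]) < 0:
            q[t] = -q[t]
    return q

def smooth_quaternions(quats, window=5, eps=1e-8):
    """Centered moving average of [T, 4] quaternions followed by renormalization."""
    if window % 2 == 0:
        raise ValueError("window must be odd so the output keeps T frames")
    q = make_sign_continuous(quats)
    pad = window // 2
    padded = np.pad(q, ((pad, pad), (0, 0)), mode="edge")
    kernel = np.ones(window) / window
    smoothed = np.stack(
        [np.convolve(padded[:, c], kernel, mode="valid") for c in range(4)], axis=1
    )
    norms = np.linalg.norm(smoothed, axis=1, keepdims=True)
    return smoothed / np.maximum(norms, eps)
```

If the w ≥ 0 convention is still wanted downstream, it can be re-applied after smoothing.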

---

10. Validation & Sanity Checks

10.1 Sanity Check Suite

python
@dataclass
class SanityThresholds:
    """Thresholds for sanity checks."""

    # Quaternion
    quaternion_norm_min: float = 0.95
    quaternion_norm_max: float = 1.05

    # Phase
    phase_min: float = 0.0
    phase_max: float = 1.0

    # Temporal coherence
    max_jerk: float = 50000.0
    max_velocity_error: float = 5.0
    max_acceleration_error: float = 100.0


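@dataclass
class CheckResult:
    """Result of a single sanity check."""
    name: str
    passed: bool
    value: Union[float, Tuple[float, float]]      # (min, max) for range checks such as phase
    threshold: Union[float, Tuple[float, float]]
    message: str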


class SanityChecker:
    """
    Comprehensive sanity checking for motion trajectories.
    """

    def __init__(self, thresholds: Optional[SanityThresholds] = None):
        self.thresholds = thresholds or SanityThresholds()
        self.fps = 30.0

    def check_all(
        self,
        trajectory: MotionTrajectory,
    ) -> Tuple[bool, Dict[str, CheckResult]]:
        """
        Run all sanity checks on a trajectory.

        Returns:
            passed: Whether all checks passed
            results: Dictionary of individual check results
        """
        motion = trajectory.data  # [T, 25]
        results = {}

        # 1. Quaternion normalization
        results['quaternion'] = self._check_quaternion(motion)

        # 2. Phase range
        results['phase'] = self._check_phase(motion)

        # 3. Jerk bound
        results['jerk'] = self._check_jerk(motion)

        # 4. Velocity coherence
        results['velocity_coherence'] = self._check_velocity_coherence(motion)

        # 5. Acceleration coherence
        results['accel_coherence'] = self._check_accel_coherence(motion)

        passed = all(r.passed for r in results.values())
        return passed, results

    def _check_quaternion(self, motion: np.ndarray) -> CheckResult:
        """Check quaternion normalization."""
        quats = motion[:, 9:13]
        norms = np.linalg.norm(quats, axis=1)
        mean_norm = np.mean(norms)

        passed = (
            self.thresholds.quaternion_norm_min <= mean_norm <=
            self.thresholds.quaternion_norm_max
        )

        return CheckResult(
            name="quaternion",
            passed=passed,
            value=mean_norm,
            threshold=1.0,
            message=f"Mean quaternion norm: {mean_norm:.4f}",
        )

    def _check_phase(self, motion: np.ndarray) -> CheckResult:
        """Check phase range [0, 1]."""
        phase = motion[:, 16]
        min_phase = phase.min()
        max_phase = phase.max()

        passed = (
            min_phase >= self.thresholds.phase_min and
            max_phase <= self.thresholds.phase_max
        )

        return CheckResult(
            name="phase",
            passed=passed,
            value=(float(min_phase), float(max_phase)),
            threshold=(self.thresholds.phase_min, self.thresholds.phase_max),
            message=f"Phase range: [{min_phase:.4f}, {max_phase:.4f}]",
        )

    def _check_jerk(self, motion: np.ndarray) -> CheckResult:
        """Check jerk (acceleration smoothness)."""
        accel = motion[:, 6:9]
        jerk = np.diff(accel, axis=0) * self.fps
        jerk_magnitude = np.linalg.norm(jerk, axis=1)
        mean_jerk = np.mean(jerk_magnitude)

        passed = mean_jerk < self.thresholds.max_jerk

        return CheckResult(
            name="jerk",
            passed=passed,
            value=mean_jerk,
            threshold=self.thresholds.max_jerk,
            message=f"Mean jerk: {mean_jerk:.2f} (threshold: {self.thresholds.max_jerk})",
        )

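    def _check_velocity_coherence(self, motion: np.ndarray) -> CheckResult:
        """Check velocity matches position derivative."""
        pos = motion[:, 0:3]
        vel = motion[:, 3:6]

        # Forward difference: vel[t] is compared with (pos[t+1] - pos[t]) * fps.
        # The coherence step in MotionPostProcessor uses a backward difference
        # ((pos[t] - pos[t-1]) * 30), so the two conventions are one frame apart.
        derived_vel = np.diff(pos, axis=0) * self.fps
        vel_error = np.abs(vel[:-1] - derived_vel).mean()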

        passed = vel_error < self.thresholds.max_velocity_error

        return CheckResult(
            name="velocity_coherence",
            passed=passed,
            value=vel_error,
            threshold=self.thresholds.max_velocity_error,
            message=f"Velocity coherence error: {vel_error:.4f}",
        )

    def _check_accel_coherence(self, motion: np.ndarray) -> CheckResult:
        """Check acceleration matches velocity derivative."""
        vel = motion[:, 3:6]
        accel = motion[:, 6:9]

        derived_accel = np.diff(vel, axis=0) * self.fps
        accel_error = np.abs(accel[:-1] - derived_accel).mean()

        passed = accel_error < self.thresholds.max_acceleration_error

        return CheckResult(
            name="accel_coherence",
            passed=passed,
            value=accel_error,
            threshold=self.thresholds.max_acceleration_error,
            message=f"Acceleration coherence error: {accel_error:.4f}",
        )
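The coherence checks compare predicted channels against finite differences at 30 fps, and the frame alignment matters. `_check_velocity_coherence` uses a forward difference: vel[t] is compared with pos[t+1] − pos[t]. The coherence step in `MotionPostProcessor.process` builds a backward difference instead.

The self-contained NumPy sketch below reproduces the checker's metric. The helper names and the synthetic 0.5 Hz sway are hypothetical. A trajectory built with the forward convention scores exactly zero error. The same motion built with the backward convention scores a small but non-zero error, on the order of |acceleration| × Δt.

```python
import numpy as np

FPS = 30.0

def velocity_coherence_error(motion, fps=FPS):
    """Same metric as SanityChecker._check_velocity_coherence on a [T, 25] array."""
    derived_vel = np.diff(motion[:, 0:3], axis=0) * fps
    return float(np.abs(motion[:-1, 3:6] - derived_vel).mean())

def make_sway(T=120, fps=FPS, backward=False):
    """Synthetic [T, 25] trajectory: 0.5 Hz sway on x, velocity from finite differences."""
    t = np.arange(T) / fps
    motion = np.zeros((T, 25))
    motion[:, 0] = np.sin(2 * np.pi * 0.5 * t)   # position x
    d = np.diff(motion[:, 0:3], axis=0) * fps    # (pos[t+1] - pos[t]) * fps
    if backward:                                 # vel[t] = (pos[t] - pos[t-1]) * fps
        motion[1:, 3:6] = d
        motion[0, 3:6] = d[0]
    else:                                        # vel[t] = (pos[t+1] - pos[t]) * fps
        motion[:-1, 3:6] = d
        motion[-1, 3:6] = d[-1]
    motion[:, 9] = 1.0                           # identity quaternion (w, x, y, z)
    motion[:, 16] = np.linspace(0.0, 1.0, T)     # phase in [0, 1]
    return motion
```

For smooth motion the one-frame offset stays well under the default `max_velocity_error` of 5.0. For jerky motion, where acceleration is large, it can contribute noticeably to the error.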

10.2 Running Evaluation

bash
# Quick evaluation (20 samples, 10 DDIM steps)
python -m cc_motiongen.scripts.evaluate_sanity --quick

# Full evaluation
python -m cc_motiongen.scripts.evaluate_sanity \
    --num-samples 100 \
    --num-frames 120 \
    --num-steps 20

# Evaluate specific checkpoint
python -m cc_motiongen.scripts.evaluate_sanity \
    --diffusion-checkpoint outputs/cc_motiongen/checkpoints/e2e_best.pt \
    --decoder-checkpoint outputs/cc_motiongen/decoder/e2e_decoder_best.pt \
    --output-dir results/evaluation/

---

11. Configuration Reference

11.1 Main Configuration

python
# cc_motiongen/config.py

@dataclass
class MotionGenConfig:
    """Master configuration for CC-MotionGen."""

    # Model configs
    unet: UNetConfig = field(default_factory=UNetConfig)
    diffusion: DiffusionConfig = field(default_factory=DiffusionConfig)
    decoder: DecoderConfig = field(default_factory=DecoderConfig)

    # Training
    training: TrainingConfig = field(default_factory=TrainingConfig)

    # Data
    data_dir: str = "gs://cc-music-library/motionphrase"  # str, not Path: pathlib would collapse "gs://" to "gs:/"
    output_dir: Path = Path("outputs/cc_motiongen")
    num_phrases: Optional[int] = None

    # Hardware
    device: str = "auto"
    seed: int = 42

    def resolved_device(self) -> str:
        """Auto-detect best available device."""
        if self.device != "auto":
            return self.device

        if torch.cuda.is_available():
            return "cuda"
        elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
            return "mps"
        return "cpu"


def get_config() -> MotionGenConfig:
    """Get default configuration."""
    return MotionGenConfig()

11.2 Environment Variables

bash
# Required for GCS access
export GOOGLE_APPLICATION_CREDENTIALS=/path/to/service-account.json

# Optional configuration
export MOTIONGEN_DEVICE=cuda
export MOTIONGEN_OUTPUT_DIR=outputs/cc_motiongen
export MOTIONGEN_DATA_DIR=gs://cc-music-library/motionphrase

# Checkpoint paths
export MOTIONGEN_DIFFUSION_CHECKPOINT=outputs/cc_motiongen/checkpoints/e2e_best.pt
export MOTIONGEN_DECODER_CHECKPOINT=outputs/cc_motiongen/decoder/e2e_decoder_best.pt
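The document does not show how these variables reach `MotionGenConfig`. Below is a minimal sketch of one way to wire them up, assuming a plain mapping from each `MOTIONGEN_*` variable to a config field. `RuntimeSettings`, `ENV_TO_FIELD`, and `settings_from_env` are hypothetical, and the checkpoint variables are left out.

The sketch keeps the data location as a plain string. `pathlib` collapses repeated slashes, so `Path("gs://cc-music-library/motionphrase")` would become `gs:/cc-music-library/motionphrase`.

```python
import os
from dataclasses import dataclass, replace
from pathlib import PurePosixPath
from typing import Mapping, Optional

@dataclass(frozen=True)
class RuntimeSettings:
    """Hypothetical subset of MotionGenConfig fields that the env vars map onto."""
    device: str = "auto"
    output_dir: str = "outputs/cc_motiongen"
    data_dir: str = "gs://cc-music-library/motionphrase"  # str keeps the gs:// scheme intact

# Hypothetical mapping from environment variable to settings field
ENV_TO_FIELD = {
    "MOTIONGEN_DEVICE": "device",
    "MOTIONGEN_OUTPUT_DIR": "output_dir",
    "MOTIONGEN_DATA_DIR": "data_dir",
}

def settings_from_env(
    base: Optional[RuntimeSettings] = None,
    environ: Optional[Mapping[str, str]] = None,
) -> RuntimeSettings:
    """Return a copy of `base` with any MOTIONGEN_* variables applied on top."""
    base = base if base is not None else RuntimeSettings()
    environ = os.environ if environ is None else environ
    overrides = {name: environ[var] for var, name in ENV_TO_FIELD.items() if var in environ}
    return replace(base, **overrides)

# Why data_dir is a str: pathlib normalizes "//" inside a path
assert str(PurePosixPath("gs://cc-music-library/motionphrase")) == "gs:/cc-music-library/motionphrase"
```

Passing `environ` explicitly keeps the function testable without touching the real process environment.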

---

12. API Reference

12.1 Core Classes

python
# Model classes
from cc_motiongen.model.unet import UNet1D, UNetConfig
from cc_motiongen.model.diffusion import GaussianDiffusion, DiffusionConfig
from cc_motiongen.model.decoder import MotionDecoder, DecoderConfig

# Training
from cc_motiongen.training.trainer import Trainer
from cc_motiongen.training.losses import TemporalCoherenceLoss

# Inference
from cc_motiongen.inference.sampler import DDIMSampler
from cc_motiongen.inference.postprocess import MotionPostProcessor, PostProcessConfig

# Validation
from cc_motiongen.validation.sanity import SanityChecker, SanityThresholds

# Data
from cc_motiongen.data.dataset import MotionPhraseDataset

# Types
from cc_motiongen.types import MotionTrajectory, AudioCondition

12.2 Quick Start

python
import torch
from cc_motiongen.model.diffusion import GaussianDiffusion
from cc_motiongen.model.decoder import MotionDecoder
from cc_motiongen.inference.sampler import DDIMSampler
from cc_motiongen.inference.postprocess import MotionPostProcessorTorch

# Load models
diffusion = GaussianDiffusion.from_checkpoint("checkpoints/e2e_best.pt")
decoder = MotionDecoder.from_checkpoint("decoder/e2e_decoder_best.pt")

# Create sampler and post-processor
sampler = DDIMSampler(diffusion, num_steps=20)
postprocessor = MotionPostProcessorTorch()

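# Generate motion
# Placeholder conditioning for a shape check only: the model does not generalize to
# random noise (see 15.1), so use real features from extract_audio_features /
# stack_audio_features for meaningful output.
audio_cond = torch.randn(1, 163, 120)  # [B, C, T]
with torch.no_grad():
    raw = sampler.sample((1, 25, 120), audio_cond)               # [B, D, T]
    decoded = decoder(raw.permute(0, 2, 1))                       # decoder takes [B, T, D]
    processed = postprocessor.process(decoded.permute(0, 2, 1))   # post-processor takes [B, D, T]

motion = processed.permute(0, 2, 1).cpu().numpy()  # [1, T, 25]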

---

13. Performance & Benchmarks

13.1 Training Performance

┌─────────────────────────────────────────────────────────────────────────────┐
│                         TRAINING BENCHMARKS                                  │
├──────────────────────────┬──────────────────────────────────────────────────┤
│ Hardware                 │ NVIDIA A100 40GB                                 │
│ Batch Size               │ 32                                               │
│ Sequence Length          │ 120 frames (4 seconds)                           │
│ Mixed Precision          │ FP16                                             │
├──────────────────────────┼──────────────────────────────────────────────────┤
│ Training Speed           │ ~2.5 iterations/second                           │
│ Memory Usage             │ ~24 GB                                           │
│ Time per Epoch           │ ~35 minutes (5242 samples)                       │
│ Full Training            │ ~58 hours (100 epochs)                           │
└──────────────────────────┴──────────────────────────────────────────────────┘

13.2 E2E Fine-tuning Performance

┌─────────────────────────────────────────────────────────────────────────────┐
│                      E2E FINE-TUNING BENCHMARKS                              │
├──────────────────────────┬──────────────────────────────────────────────────┤
│ Hardware                 │ Apple M2 Max (MPS)                               │
│ Batch Size               │ 8                                                │
│ DDIM Steps (training)    │ 5                                                │
│ Phrases per Epoch        │ 200                                              │
├──────────────────────────┼──────────────────────────────────────────────────┤
│ Time per Batch           │ ~3-4 seconds                                     │
│ Time per Epoch           │ ~3-4 minutes                                     │
│ Full Training (50 epochs)│ ~2.5 hours                                       │
│ Memory Usage             │ ~8 GB                                            │
└──────────────────────────┴──────────────────────────────────────────────────┘
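The epoch-level figures in 13.1 and 13.2 can be cross-checked against the per-iteration figures with a few lines of arithmetic. The sketch uses only numbers from the two tables.

- **13.1:** ~35 minutes per epoch at ~2.5 it/s implies about 5,242 iterations, i.e. one iteration per sample. At batch size 32, an epoch of 5,242 samples would be about 164 iterations, or roughly a minute.
- **13.2:** 200 phrases at batch size 8 is 25 batches, about 75 to 100 s at 3 to 4 s per batch.

The tables do not say what the timings counted, so they alone cannot settle which figure is off.

```python
import math

def iterations_per_epoch(num_samples, batch_size):
    """Optimizer steps for one pass over the data (last partial batch included)."""
    return math.ceil(num_samples / batch_size)

def epoch_minutes(num_iterations, seconds_per_iteration):
    """Wall-clock minutes for an epoch at a fixed time per iteration."""
    return num_iterations * seconds_per_iteration / 60.0

# 13.1: 5242 samples, batch size 32, ~2.5 iterations/second
iters_131 = iterations_per_epoch(5242, 32)                 # 164 iterations
minutes_if_batched = epoch_minutes(iters_131, 1 / 2.5)     # ~1.1 min
minutes_if_per_sample = epoch_minutes(5242, 1 / 2.5)       # ~35 min, as in the table
hours_100_epochs = 100 * minutes_if_per_sample / 60.0      # ~58 h, as in the table

# 13.2: 200 phrases, batch size 8, ~3-4 s per batch
iters_132 = iterations_per_epoch(200, 8)                   # 25 batches
minutes_132 = (epoch_minutes(iters_132, 3.0), epoch_minutes(iters_132, 4.0))  # 1.25 to ~1.67 min
```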

13.3 Inference Performance

┌─────────────────────────────────────────────────────────────────────────────┐
│                         INFERENCE BENCHMARKS                                 │
├─────────────────────┬────────────────┬────────────────┬─────────────────────┤
│ DDIM Steps          │ Time (A100)    │ Time (M2 Max)  │ Time (CPU)          │
├─────────────────────┼────────────────┼────────────────┼─────────────────────┤
│ 5 steps             │ 0.3s           │ 0.8s           │ 5.2s                │
│ 10 steps            │ 0.5s           │ 1.4s           │ 9.8s                │
│ 20 steps            │ 0.9s           │ 2.6s           │ 18.5s               │
│ 50 steps            │ 2.1s           │ 6.2s           │ 45.3s               │
├─────────────────────┼────────────────┴────────────────┴─────────────────────┤
│ Batch Size          │ 1 sample, 120 frames (4 seconds of motion)            │
│ Decoder             │ +0.02s (negligible overhead)                          │
│ Post-processing     │ +0.01s (negligible overhead)                          │
└─────────────────────┴───────────────────────────────────────────────────────┘
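Each row in 13.3 generates 4 seconds of motion, so the timings translate directly into a real-time factor: generation time ÷ motion duration, where values below 1.0 mean faster than real time. The spread between 5 and 50 steps also gives an approximate per-step cost. The snippet below only re-derives these from the table values, and the linear per-step model is a simplifying assumption.

```python
MOTION_SECONDS = 120 / 30  # 120 frames at 30 fps = 4.0 s

# Seconds to generate 4 s of motion, copied from the 13.3 table
TIMINGS = {
    "A100":   {5: 0.3, 10: 0.5, 20: 0.9, 50: 2.1},
    "M2 Max": {5: 0.8, 10: 1.4, 20: 2.6, 50: 6.2},
    "CPU":    {5: 5.2, 10: 9.8, 20: 18.5, 50: 45.3},
}

def real_time_factor(device, steps):
    """Generation time divided by motion duration (< 1.0 is faster than real time)."""
    return TIMINGS[device][steps] / MOTION_SECONDS

def seconds_per_step(device):
    """Approximate marginal cost of one DDIM step, from the 5- and 50-step rows."""
    t = TIMINGS[device]
    return (t[50] - t[5]) / (50 - 5)
```

At 20 steps, the A100 (about 0.23×) and the M2 Max (about 0.65×) both generate faster than real time, while the CPU (about 4.6×) does not.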

---

14. Troubleshooting Guide

14.1 Common Issues

Issue: CUDA Out of Memory

python
# Symptoms
RuntimeError: CUDA out of memory

# Solutions
1. Reduce batch size
   --batch-size 16  # or 8

2. Enable gradient checkpointing
   config.unet.use_checkpoint = True

3. Use mixed precision
   with torch.cuda.amp.autocast():
       ...

4. Reduce sequence length
   --target-frames 60  # instead of 120

Issue: NaN Loss During Training

python
# Symptoms
Loss becomes NaN after several epochs

# Solutions
1. Reduce learning rate
   --learning-rate 5e-5

2. Enable gradient clipping
   --gradient-clip 1.0

3. Check for data issues
   - Verify audio features are normalized
   - Check for inf/nan in dataset

4. Use stable loss scaling
   scaler = torch.cuda.amp.GradScaler()

Issue: Poor Generation Quality

python
# Symptoms
Generated motion looks random or jittery

# Diagnosis
1. Check training loss converged
2. Verify checkpoint loaded correctly
3. Test with more DDIM steps

# Solutions
1. Train longer (more epochs)
2. Use more DDIM steps for inference
   sampler = DDIMSampler(diffusion, num_steps=50)

3. Apply post-processing
   config = PostProcessConfig(
       enforce_velocity_coherence=True,
       apply_smoothing=True,
   )

Issue: Velocity Coherence Failure

python
# Symptoms
Sanity check fails on velocity_coherence
Velocity range shows exactly [-20, 20] (clamped)

# Root Cause
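Positions change by more per frame than the bounded velocity head (±20, from tanh(...) * 20) can represent, so predicted velocity saturates and no longer matches the position derivative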

# Solutions
1. Train with position smoothness loss
   pos_delta = pos[:, 1:] - pos[:, :-1]
   smoothness_loss = F.relu(pos_delta.abs() - 0.5).mean()

2. Increase velocity head output range
   velocity = torch.tanh(self.velocity_head(h)) * 50  # instead of 20

3. Use coherence enforcement in post-processing
   config.enforce_velocity_coherence = True
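One way to read the 0.5 threshold in solution 1: at 30 fps, a per-frame position step of 0.5 corresponds to 15 units/s. That sits inside the ±20 range of the current `tanh(...) * 20` velocity head, so motion that satisfies the penalty stays representable by the velocity channel. This rationale is an interpretation, not something the section states.

The NumPy mirror of the loss below makes the hinge behavior concrete. It uses a hypothetical name, `position_smoothness_penalty`, and assumes positions shaped [B, T, 3].

```python
import numpy as np

FPS = 30.0
VELOCITY_LIMIT = 20.0  # range of tanh(self.velocity_head(h)) * 20
MAX_POS_DELTA = 0.5    # hinge threshold from the smoothness loss above

def position_smoothness_penalty(pos, max_delta=MAX_POS_DELTA):
    """NumPy mirror of F.relu(pos_delta.abs() - max_delta).mean() for pos of shape [B, T, 3]."""
    pos_delta = pos[:, 1:] - pos[:, :-1]
    return float(np.maximum(np.abs(pos_delta) - max_delta, 0.0).mean())

def constant_speed_path(speed, T=5, fps=FPS):
    """[1, T, 3] positions moving along x at `speed` units/s."""
    pos = np.zeros((1, T, 3))
    pos[0, :, 0] = speed * np.arange(T) / fps
    return pos

implied_speed = MAX_POS_DELTA * FPS  # 15.0 units/s per 0.5-unit frame step
```

Motion at 10 units/s incurs no penalty. Motion at 30 units/s (1.0 per frame) is penalized by 0.5 on the moving axis for every frame pair.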

14.2 Debugging Tools

python
# Visualize training progress
from typing import Dict, List

import matplotlib.pyplot as plt
import numpy as np

def plot_losses(loss_history: Dict[str, List[float]]):
    fig, axes = plt.subplots(2, 3, figsize=(15, 10))

    metrics = ['total', 'diffusion', 'decoder_pos',
               'velocity_coherence', 'accel_coherence', 'jerk']

    for ax, metric in zip(axes.flat, metrics):
        if metric in loss_history:
            ax.plot(loss_history[metric])
            ax.set_title(metric)
            ax.set_xlabel('Epoch')
            ax.set_ylabel('Loss')
            ax.set_yscale('log')

    plt.tight_layout()
    plt.savefig('training_progress.png')


# Visualize motion output
def visualize_motion(motion: np.ndarray, output_path: str):
    """Create visualization of motion trajectory."""
    T, D = motion.shape

    fig, axes = plt.subplots(3, 2, figsize=(12, 10))

    # Position
    axes[0, 0].plot(motion[:, 0:3])
    axes[0, 0].set_title('Position (x, y, z)')
    axes[0, 0].legend(['x', 'y', 'z'])

    # Velocity
    axes[0, 1].plot(motion[:, 3:6])
    axes[0, 1].set_title('Velocity')

    # Quaternion
    axes[1, 0].plot(motion[:, 9:13])
    axes[1, 0].set_title('Quaternion (w, x, y, z)')

    # Phase
    axes[1, 1].plot(motion[:, 16])
    axes[1, 1].set_title('Phase')

    # Quaternion norm
    quat_norm = np.linalg.norm(motion[:, 9:13], axis=1)
    axes[2, 0].plot(quat_norm)
    axes[2, 0].axhline(y=1.0, color='r', linestyle='--')
    axes[2, 0].set_title('Quaternion Norm')

    # Jerk
    accel = motion[:, 6:9]
    jerk = np.diff(accel, axis=0) * 30
    jerk_mag = np.linalg.norm(jerk, axis=1)
    axes[2, 1].plot(jerk_mag)
    axes[2, 1].axhline(y=50000, color='r', linestyle='--')
    axes[2, 1].set_title('Jerk Magnitude')

    plt.tight_layout()
    plt.savefig(output_path)

---

15. Known Issues & Roadmap

15.1 Current Limitations

| Issue | Status | Description | Workaround |
|---|---|---|---|
| Generalization Gap | Known | Model trained on real audio doesn't generalize to random noise | Evaluate with real audio features |
| Velocity Clamping | Known | Decoder outputs bounded velocities despite large position changes | Add position smoothness loss |
| Memory Usage | Known | E2E training is memory-intensive due to DDIM gradients | Use fewer DDIM steps (5) |
| GCS Latency | Known | Training slows due to GCS data loading | Enable caching or download locally |

15.2 Roadmap

┌─────────────────────────────────────────────────────────────────────────────┐
│                              DEVELOPMENT ROADMAP                             │
├─────────────────────────────────────────────────────────────────────────────┤
│                                                                              │
│  v0.3.0 (Planned)                                                           │
│  ├── [ ] Add conditioning augmentation (noise injection during training)    │
│  ├── [ ] Implement position smoothness loss                                 │
│  ├── [ ] Add real audio feature evaluation pipeline                         │
│  └── [ ] Optimize GCS data loading with local caching                       │
│                                                                              │
│  v0.4.0 (Future)                                                            │
│  ├── [ ] Implement classifier-free guidance training                        │
│  ├── [ ] Add style control (conditioning on style vectors)                  │
│  ├── [ ] Multi-GPU distributed training support                             │
│  └── [ ] ONNX export for production deployment                              │
│                                                                              │
│  v1.0.0 (Release)                                                           │
│  ├── [ ] Achieve 80%+ sanity check pass rate                               │
│  ├── [ ] Real-time inference optimization                                   │
│  ├── [ ] Comprehensive documentation and tutorials                          │
│  └── [ ] Integration with Comp-Core pipeline                                │
│                                                                              │
└─────────────────────────────────────────────────────────────────────────────┘

---

Appendix A: Mathematical Notation

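| Symbol | Description |
|---|---|
| $x_0$ | Clean data sample |
| $x_t$ | Noisy sample at timestep $t$ |
| $\epsilon$ | Gaussian noise, $\epsilon \sim \mathcal{N}(0, I)$ |
| $\alpha_t$ | Per-step signal retention, $\alpha_t = 1 - \beta_t$ |
| $\bar{\alpha}_t$ | Cumulative signal retention, $\bar{\alpha}_t = \prod_{s=1}^{t} \alpha_s$ |
| $\beta_t$ | Noise schedule: variance of the noise added at step $t$ |
| $\epsilon_\theta$ | Learned noise predictor |
| $\mu_\theta$ | Learned posterior mean |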
---

Appendix B: File Checksums

Checkpoint Files (after E2E training):
├── e2e_best.pt          SHA256: [computed at runtime]
├── e2e_decoder_best.pt  SHA256: [computed at runtime]
└── training_log.json    SHA256: [computed at runtime]

Model Sizes:
├── UNet1D:        ~464 MB (116M params × 4 bytes, FP32)
├── Decoder:       ~8 MB (2M params × 4 bytes, FP32)
└── Total:         ~472 MB

---

Document generated by CC-MotionGen Development Team
For questions or issues, contact: [email]

Source Anchor

Comp-Core/core/ml/cc-ml/cc_motiongen/TECHNICAL_DOCUMENTATION.md
