Video Style Transfer: Coherence vs Quality

Shrek styled with La Muse

Have you ever wondered how those mesmerizing AI-generated art videos maintain their consistency frame by frame? Let me share our journey implementing various video style transfer algorithms during our project at NTNU.

The Mathematics Behind Video Style Transfer

Video style transfer extends beyond simply applying style transfer to individual frames. The key challenge lies in maintaining temporal coherence while preserving artistic style.

Content and Style Representation

The core idea involves separating content and style representations using convolutional neural networks. For a given frame $I_t$, we extract content features $F_c$ and style features $F_s$. The optimization objective combines content loss $L_c$ and style loss $L_s$:

$$L_{total} = \alpha L_c + \beta L_s + \gamma L_{temporal}$$

where $\alpha$, $\beta$, and $\gamma$ are weighting factors for the content, style, and temporal terms, respectively.
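
In the classic Gatys-style formulation, both feature sets come from a pretrained CNN such as VGG: the content term is a feature reconstruction loss and the style term matches Gram matrices of the feature maps. Below is a minimal sketch of those two per-frame terms; the helper names and the single-feature-map simplification are ours, and we later swap the Gram-based style term for a Wasserstein-based one.

import torch
from torch import Tensor

def gram_matrix(feat: Tensor) -> Tensor:
    # feat: (B, C, H, W) feature map -> (B, C, C) Gram matrix of channel correlations
    B, C, H, W = feat.shape
    f = feat.view(B, C, H * W)
    return f @ f.transpose(1, 2) / (C * H * W)

def content_loss(feat_out: Tensor, feat_content: Tensor) -> Tensor:
    # Feature reconstruction loss between stylized output and content frame
    return torch.mean((feat_out - feat_content) ** 2)

def style_loss(feat_out: Tensor, feat_style: Tensor) -> Tensor:
    # Match second-order feature statistics via Gram matrices
    return torch.mean((gram_matrix(feat_out) - gram_matrix(feat_style)) ** 2)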

Temporal Coherence

The temporal loss $L_{temporal}$ enforces consistency between consecutive stylized frames. Given frames $I_{t-1}$ and $I_t$ and the optical flow $w_t$ between them, we penalize the difference between the current stylized frame and the previous stylized frame warped along the flow:

$$L_{temporal} = \sum_{i,j} \left\lVert S(I_t)_{i,j} - \mathcal{W}_{w_t}\!\left(S(I_{t-1})\right)_{i,j} \right\rVert$$

where $S(I)$ is the stylized output and $\mathcal{W}_{w_t}$ warps a frame according to $w_t$.

Here’s our implementation of the temporal loss function:

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor


def temporal_loss(
    prev_frame: Tensor,
    curr_frame: Tensor,
    flow: Tensor,
    reduction: str = 'mean'
) -> Tensor:
    """
    Compute temporal consistency loss between consecutive frames.

    Args:
        prev_frame: Previous stylized frame tensor of shape (B, C, H, W)
        curr_frame: Current stylized frame tensor of shape (B, C, H, W)
        flow: Optical flow tensor of shape (B, 2, H, W)
        reduction: Reduction method for loss computation ('mean' or 'sum')

    Returns:
        Scalar tensor containing temporal consistency loss

    Raises:
        ValueError: If reduction method is not 'mean' or 'sum'
    """
    # Warp the previous stylized frame along the flow, then penalize
    # per-pixel deviations from the current stylized frame
    warped_prev = warp_frame(prev_frame, flow)
    diff = torch.abs(curr_frame - warped_prev)

    if reduction == 'mean':
        return torch.mean(diff)
    elif reduction == 'sum':
        return torch.sum(diff)
    else:
        raise ValueError(f"Unsupported reduction method: {reduction}")


def warp_frame(frame: Tensor, flow: Tensor) -> Tensor:
    """
    Warp frame according to optical flow using grid sampling.

    Args:
        frame: Input frame tensor of shape (B, C, H, W)
        flow: Optical flow tensor of shape (B, 2, H, W), in pixels

    Returns:
        Warped frame tensor of same shape as input frame
    """
    B, C, H, W = frame.shape

    # Create the base sampling grid of pixel coordinates
    grid_y, grid_x = torch.meshgrid(
        torch.arange(H, device=frame.device, dtype=frame.dtype),
        torch.arange(W, device=frame.device, dtype=frame.dtype),
        indexing='ij'
    )

    # Displace the grid by the flow and normalize coordinates to [-1, 1],
    # producing a sampling grid of shape (B, H, W, 2) as grid_sample expects
    flow_grid = torch.stack([
        2 * (grid_x + flow[:, 0]) / (W - 1) - 1,
        2 * (grid_y + flow[:, 1]) / (H - 1) - 1
    ], dim=-1)

    # Perform bilinear grid sampling
    return F.grid_sample(
        frame,
        flow_grid,
        mode='bilinear',
        padding_mode='border',
        align_corners=True
    )
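
As a usage sketch: given two consecutive input frames `prev_raw` and `curr_raw`, plus hypothetical `stylize` and `compute_flow` helpers (neither is shown here), the loss is evaluated per frame pair.

# Hypothetical helpers, not defined in this post:
#   stylize(frame)      -> stylized frame of shape (B, 3, H, W)
#   compute_flow(a, b)  -> forward optical flow from a to b, in pixels, shape (B, 2, H, W)
flow = compute_flow(prev_raw, curr_raw)
loss_t = temporal_loss(stylize(prev_raw), stylize(curr_raw), flow)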

Our Implementation Approaches

In our project, we explored multiple algorithms:

Gatys Method

The naive approach applies style transfer frame by frame. While simple, it produces noticeable flickering.
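
As a sketch of what this per-frame baseline looks like (assuming a pretrained feature extractor `vgg_features` that returns a single feature map, the `content_loss`/`style_loss` helpers sketched earlier, and an illustrative style weight; none of this reflects our exact configuration):

def gatys_stylize_frame(frame: Tensor, style_img: Tensor, steps: int = 300) -> Tensor:
    # Optimize the output pixels directly, starting from the content frame
    output = frame.clone().requires_grad_(True)
    optimizer = torch.optim.LBFGS([output], max_iter=steps)

    target_c = vgg_features(frame).detach()
    target_s = vgg_features(style_img).detach()

    def closure():
        optimizer.zero_grad()
        feats = vgg_features(output)
        # The 1e3 style weight is purely illustrative
        loss = content_loss(feats, target_c) + 1e3 * style_loss(feats, target_s)
        loss.backward()
        return loss

    optimizer.step(closure)
    return output.detach()

# Naive video baseline: every frame is stylized independently, hence the flicker
stylized = [gatys_stylize_frame(f, style_img) for f in frames]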

Ruder’s Temporal Constraint

We implemented a temporal constraint using DeepFlow for motion estimation. This significantly improved temporal coherence by penalizing large deviations between adjacent frames.
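
Ruder et al. additionally down-weight pixels where the warped previous frame is unreliable, such as disocclusions and motion boundaries. A minimal masked variant of the `temporal_loss` above, with the per-pixel weights assumed to be given (the disocclusion detection itself is not shown), might look like:

def masked_temporal_loss(prev_frame: Tensor, curr_frame: Tensor,
                         flow: Tensor, mask: Tensor) -> Tensor:
    # mask: per-pixel weights of shape (B, 1, H, W); 0 where the warped previous
    # frame is unreliable (disocclusions, motion boundaries), 1 elsewhere
    warped_prev = warp_frame(prev_frame, flow)
    diff = torch.abs(curr_frame - warped_prev)
    # Average over valid pixels and channels only
    return (mask * diff).sum() / (mask.sum() * diff.shape[1]).clamp(min=1.0)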

Instance Normalization for Speed

To accelerate the optimization process, we incorporated instance normalization:

class InstanceNorm(nn.Module):
    """
    Instance Normalization layer for style transfer.

    This layer normalizes feature maps independently across spatial dimensions
    and applies learnable affine transformation parameters.

    Attributes:
        scale: Learnable scaling parameter
        shift: Learnable shifting parameter
        eps: Small constant for numerical stability
    """
    def __init__(self, dim: int, eps: float = 1e-8) -> None:
        super().__init__()
        self.scale = nn.Parameter(torch.ones(dim))
        self.shift = nn.Parameter(torch.zeros(dim))
        self.eps = eps

    def forward(self, x: Tensor) -> Tensor:
        """
        Apply instance normalization to input tensor.

        Args:
            x: Input tensor of shape (batch_size, channels, height, width)

        Returns:
            Normalized and transformed output tensor of same shape as input
        """
        mean = x.mean(dim=(2, 3), keepdim=True)
        std = x.std(dim=(2, 3), keepdim=True) + self.eps
        return self.scale[None, :, None, None] * (x - mean) / std + self.shift[None, :, None, None]
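
With this layer in place, a typical block of the feed-forward transform network is just convolution, instance normalization, and ReLU. The block below is a sketch; the layer sizes are illustrative rather than our actual architecture.

class ConvInstReLU(nn.Module):
    """Convolution followed by instance normalization and ReLU (sketch)."""
    def __init__(self, in_ch: int, out_ch: int, kernel: int = 3, stride: int = 1) -> None:
        super().__init__()
        self.conv = nn.Conv2d(in_ch, out_ch, kernel, stride, padding=kernel // 2)
        self.norm = InstanceNorm(out_ch)

    def forward(self, x: Tensor) -> Tensor:
        return F.relu(self.norm(self.conv(x)))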

Feature Distribution Matching

For style representation, we adopted the Wasserstein metric to measure feature distribution differences. The Wasserstein distance between distributions $P$ and $Q$ is defined as:

$$W(P,Q) = \inf_{\gamma\in\Pi(P,Q)}\ \mathbb{E}_{(x,y)\sim\gamma}\left[\lVert x-y\rVert\right]$$

where $\Pi(P,Q)$ is the set of all joint distributions whose marginals are $P$ and $Q$.
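
Evaluating this infimum directly over high-dimensional feature distributions is intractable. One common surrogate, shown here as a sketch rather than the exact estimator from our code, is the sliced Wasserstein distance: flatten both feature maps into equal-sized point clouds of shape (N, C), project them onto random directions, and average the resulting 1-D Wasserstein distances.

def sliced_wasserstein(feat_a: Tensor, feat_b: Tensor, n_proj: int = 64) -> Tensor:
    # feat_a, feat_b: feature point clouds of shape (N, C), same N for both
    C = feat_a.shape[1]

    # Random unit-length projection directions of shape (C, n_proj)
    proj = torch.randn(C, n_proj, device=feat_a.device)
    proj = proj / proj.norm(dim=0, keepdim=True)

    # In 1-D, optimal transport simply matches sorted samples
    proj_a, _ = torch.sort(feat_a @ proj, dim=0)
    proj_b, _ = torch.sort(feat_b @ proj, dim=0)

    return torch.abs(proj_a - proj_b).mean()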

Results and Performance

Our experiments showed that combining Johnson’s feed-forward network with instance normalization provided the best balance between speed and quality. The temporal constraint from Ruder’s method effectively eliminated flickering artifacts while preserving artistic style.

The implementation is available in our repository, with different approaches organized in separate modules for experimentation.

Through this project, we’ve demonstrated that effective video style transfer requires careful consideration of both artistic quality and temporal consistency. The combination of proper feature representation, temporal constraints, and efficient optimization techniques makes real-time video style transfer possible.