在 Python 中用 LUT

3D LUT 能够封装任何全局处理的色彩变换算法,以极高的效率和较高的精度应用到图像上。在开发过程中多少会接触到 LUT,如何在 Python 中更好地使用 LUT 呢?

TL;DR:如果没有特殊需求,请使用 OpenColorIO;如果不希望引入新的依赖,则推荐使用 colour-science 或手动的 Numpy 实现。

在 Python 中,有很多的方法来使用 3D LUT,它们各有特点,适用于不同的场景。包括但不限于:

方案插值方式实现方式
colour-science三线性/四面体Python(Numpy)
Pillow三线性C
OpenColorIO三线性/四面体C++
PyTorch三线性C++/GPU

这些方案的 Python 文件可以在这里找到:Github

手动实现

最灵活的一集,在各种 AI 的帮助下,完全可以在几分钟内从零实现 LUT 的应用算法。

最基本的依赖是 Numpy,用于向量化的处理图像和网格化的数据,Numpy 已经足够处理三线性和四面体插值,还可以引入 SciPy 来实现更复杂的插值算法。

手动实现时,最好能够选取下面任一方法作为测试基准,以验证实现的正确性。

Colour-Science

colour-science 是一个功能全面的色彩科学库,覆盖了色彩科学的方方面面,也包括了读取和使用 LUT。它支持多种插值方式,但在性能上不如专为图像处理优化的库,本质上和用 Numpy 手动实现是一样的。

import colour

# 读取 LUT
lut = colour.io.read_LUT("your_lut.cube")
# LUT 需应用在 [0, 1] 范围的图像上
image_float = np.random.rand(800, 800, 3)
# 默认插值方式是三线性
output_trilinear = lut.apply(image_float)
# 使用四面体插值应用 LUT
output_tetrahedral = lut.apply(
    image_float, 
    interpolator=colour.algebra.table_interpolation_tetrahedral,
    )

Pillow

Pillow 是 Python 中最流行的图像处理库。Pillow 的底层是 C 语言实现的,具有不错的性能,但它只支持三线性插值。另外,需要先将图像转换成 8 Bit,不能高精度的处理浮点图像。

from PIL import Image, ImageFilter

# 加载 Color3DLUT(需要手动解析 CUBE 文件)
size, table = read_cube_file("your_lut.cube")
# size 是 LUT 的阶数
# table 是一个长度为 size^3 * 3 的列表,包含了 LUT 的 RGB 输出值
lut_filter = ImageFilter.Color3DLUT(size, table)
# 加载图像并应用 LUT
img = Image.open("your_image.jpg")
output_img = img.filter(lut_filter)

OpenColorIO

OpenColorIO (OCIO) 是一个工业级的色彩管理库,广泛应用于影视后期制作中。OCIO 支持多种插值方式,包括三线性和四面体插值。本身由 C++ 实现,由 Pybind11 绑定到 Python,性能非常出色,API 相对复杂,适合作为基准实现。

uv add opencolorio
import PyOpenColorIO as OCIO

config = OCIO.Config.CreateRaw()
lut_transform = OCIO.FileTransform("your_lut.cube")
# 设置插值方式为四面体
lut_transform.setInterpolation(OCIO.INTERP_TETRAHEDRAL)
# 设置插值方式为三线性
lut_transform.setInterpolation(OCIO.INTERP_LINEAR)
# 获取处理器
processor = config.getProcessor(lut_transform)
cpu_processor = processor.getDefaultCPUProcessor()
# 原地应用变换(作用于浮点图像)
image_float = np.random.rand(800, 800, 3)
output_image = image_float.copy()
cpu_processor.applyRGB(output_image)

PyTorch

PyTorch 是一个流行的深度学习框架,提供了强大的 GPU 加速能力。通过使用 grid_sample 函数,可以实现高效的三线性插值,但需要将图像和 LUT 都转换成 PyTorch 张量,这部分转换比较复杂,请查看具体的代码实现

import torch
import torch.nn.functional as F

out_tensor = F.grid_sample(
    lut_tensor,      # 形状 (1, 3, N, N, N)
    grid_tensor,     # 形状 (1, 1, H, W, 3) 且值在 [-1, 1] 范围
    mode="bilinear", 
    padding_mode="border", 
    align_corners=True
)

性能与精度

为了测试不同方案在实际生产中的处理速度,使用了一张 3840x2160 分辨率的浮点图像,分别在不同的实现方案下测定了应用 LUT 的耗时。

耗时

测试结果如下,时间单位是毫秒,处理器是 M1 Pro。

实现插值方式17 Steps33 Steps65 Steps
colour-science三线性1493.31423.71381.1
colour-science四面体2106.32091.92088.7
Pillow三线性153.9157.0166.3
OpenColorIO三线性87.285.386.9
OpenColorIO四面体49.750.549.5
PyTorch (CPU)三线性262.2262.5264.4
PyTorch (GPU/MPS)三线性31.120.922.3

增加 LUT 的阶数不会影响到插值和查询的次数,OCIO 的实现非常高效,Pytorch 的 GPU 效果比较显著,已经非常接近原生性能。

精度

除了 Pillow 以外,别的方法都是在浮点图像上直接处理的,以 colour 的实现为基准,各种库在同一种插值算法下的差异都在 1e-8 到 1e-9 的数量级,可以视为完全一致。

Pillow 需要将图像先转为 8 Bit,相比 colour,在一个基准测试图像上,有 2% 的像素出现了一个码值的误差,其余完全相同。受制于 8 Bit 的输入和输出精度,不推荐用于 LUT 工作流程。