Git: https://github.com/jack813/mlx-chroma
KEY OBSERVATIONS
- Model Parameter Change: Chroma’s model has 643 parameters, down from Flux’s 780.
- Negative Prompt Impact: Chroma introduces negative prompts, which result in roughly double the computation compared to Flux.
RESULT COMPARISON
Parameter Settings
- CFG_SKIP_STEPS = 0
- CFG = 4
- WIDTH = HEIGHT = 1024
- STEPS = 28
- Same prompt, same seed
- MacBook Pro M4 Max, 128 GB
| Model Combination | Memory Usage | Inference Time (mm:ss) |
|---|---|---|
| MLX_BF16 (Chroma_BF16 + T5_BF16) | 28.74 GB | 7:26 |
| ComfyUI (Chroma_BF16 + T5_FP8) | 30 GB | 10:48 |
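For context, the settings above can be collected into a single configuration along these lines (a hypothetical sketch; the key names are illustrative and not necessarily the actual mlx-chroma API):

```python
# Hypothetical settings dict mirroring the comparison configuration above;
# the key names are illustrative, not the actual mlx-chroma API.
settings = {
    "cfg": 4.0,
    "cfg_skip_steps": 0,   # CFG_SKIP_STEPS
    "width": 1024,
    "height": 1024,
    "steps": 28,
    "seed": 0,             # any fixed value; it only needs to match across runs
}
```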
Comparison of Results with Different CFG and skip-cfg-steps Settings
ABOUT THE MODEL
CHROMA: OPEN-SOURCE, UNCENSORED, AND BUILT FOR THE COMMUNITY
STEPS TO MIGRATE CHROMA TO MLX
Approach
Our porting process relied on the Flux inference implementation available in the mlx-examples repository for MLX, and we referred to the source code from both ComfyUI and the Flow library developed by the original author to guide the migration. I used a dump-anchor data alignment technique to systematically compare intermediate computation results at each step, ensuring the ported version maintains functional parity with the original.
Step 1. Model Configuration and Import
Since Chroma is based on the Flux model but omits the Clip computation, referring to the Flux implementation in mlx-examples made the process much easier. The original project already includes the import and configuration code for the Flux, T5, Clip, and VAE models. Our task was mainly to replace Flux’s configuration and parameters with those of the Chroma model and to remove all Clip-related code.
Chroma Model Configuration
```python
ChromaParams(
    in_channels=64,
    out_channels=64,
    context_in_dim=4096,
    hidden_size=3072,
    mlp_ratio=4.0,
    num_heads=24,
    depth=19,
    depth_single_blocks=38,
    axes_dim=[16, 56, 56],
    theta=10_000,
    patch_size=2,
    qkv_bias=True,
    in_dim=64,
    out_dim=3072,
    hidden_dim=5120,
    n_layers=5,
)
```
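As a quick consistency check on these values: hidden_size / num_heads = 3072 / 24 = 128, which matches sum(axes_dim) = 16 + 56 + 56 = 128, the per-head dimension split across the three rotary-embedding axes.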
```python
self.distilled_guidance_layer = Approximator(
    in_dim=self.in_dim,
    hidden_dim=self.hidden_dim,
    out_dim=self.out_dim,
    n_layers=self.n_layers,
)
```
```python
class Approximator(nn.Module):
    def __init__(self, in_dim: int, out_dim: int, hidden_dim: int, n_layers=5):
        super().__init__()
        self.in_proj = nn.Linear(in_dim, hidden_dim, True)
        self.layers = [MLPEmbedder(hidden_dim, hidden_dim) for _ in range(n_layers)]
        self.norms = [nn.RMSNorm(hidden_dim, eps=1e-6) for _ in range(n_layers)]
        self.out_proj = nn.Linear(hidden_dim, out_dim)

    def __call__(self, x: mx.array) -> mx.array:
        dtype = x.dtype
        x = self.in_proj(x)
        for layer, norm in zip(self.layers, self.norms):
            x = x + layer(norm(x))
        x = self.out_proj(x)
        x = x.astype(dtype)
        return x
```
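The Approximator reuses MLPEmbedder from the Flux code, which is not shown above. For completeness, a minimal MLX version along the lines of the mlx-examples Flux implementation looks like this (a sketch; the actual definition lives in the repo):

```python
import mlx.core as mx
import mlx.nn as nn

class MLPEmbedder(nn.Module):
    """Two-layer MLP with a SiLU activation, following the Flux layer of the same name."""

    def __init__(self, in_dim: int, hidden_dim: int):
        super().__init__()
        self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True)
        self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True)

    def __call__(self, x: mx.array) -> mx.array:
        return self.out_layer(nn.silu(self.in_layer(x)))
```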
As a result of these changes, the model’s parameter count decreased from 780 to 643, so the model is smaller than Flux, which should lead to lower memory usage and, in theory, faster inference.
Step 2. Forward Propagation
In Flux, the modules time_in, vector_in, and guidance_in are used to generate the modulation vectors that control and guide the model within the MM Block (Double Block) and Single Block. Therefore, new conditional controls need to be created for Chroma.
Construction of Modulation Vectors
```python
batch_size = img.shape[0]
mod_index_length = 344

distill_timestep = timestep_embedding(mx.stop_gradient(timesteps), 16).astype(self.dtype)
distil_guidance = timestep_embedding(mx.stop_gradient(guidance), 16).astype(self.dtype)

modulation_index = timestep_embedding(mx.arange(mod_index_length, dtype=mx.float32), 32).astype(self.dtype)
modulation_index = mx.broadcast_to(modulation_index, (batch_size, mod_index_length, 32)).astype(self.dtype)

timestep_guidance = mx.concatenate([distill_timestep, distil_guidance], axis=1).astype(self.dtype)
timestep_guidance = mx.broadcast_to(timestep_guidance, (batch_size, mod_index_length, 32)).astype(self.dtype)  # (B, 344, 32)

input_vec = mx.concatenate([timestep_guidance, modulation_index], axis=-1).astype(self.dtype)  # (B, 344, 64)
mod_vectors = self.distilled_guidance_layer(input_vec).astype(self.dtype)  # (B, 344, 3072)
```
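The constant mod_index_length = 344 follows directly from the modulation layout that get_modulations (below) slices into: 3 rows for each of the 38 single blocks, 6 rows each for the image and text streams of the 19 double blocks, and 2 rows for the final layer.

```python
# How mod_index_length = 344 follows from the Chroma configuration:
# 3 rows per single block, 6 + 6 rows per double block (img + txt), 2 for the final layer.
depth_single_blocks = 38
depth = 19
mod_index_length = 3 * depth_single_blocks + (6 + 6) * depth + 2
assert mod_index_length == 344
```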
```python
def get_modulations(self, tensor: mx.array, block_type: str, *, idx: int = 0):
    # This function slices up the modulations tensor, which has the following layout:
    #   single     : num_single_blocks * 3 elements
    #   double_img : num_double_blocks * 6 elements
    #   double_txt : num_double_blocks * 6 elements
    #   final      : 2 elements
    if block_type == "final":
        return (tensor[:, -2:-1, :], tensor[:, -1:, :])

    single_block_count = self.params.depth_single_blocks
    double_block_count = self.params.depth
    offset = 3 * idx

    if block_type == "single":
        return ChromaModulationOut.from_offset(tensor, offset)

    # Double block modulations are 6 elements, so we double 3 * idx.
    offset *= 2
    if block_type in {"double_img", "double_txt"}:
        # Advance past the single block modulations.
        offset += 3 * single_block_count
        if block_type == "double_txt":
            # Advance past the double block img modulations.
            offset += 6 * double_block_count
        return (
            ChromaModulationOut.from_offset(tensor, offset),
            ChromaModulationOut.from_offset(tensor, offset + 3),
        )

    raise ValueError("Bad block_type")
```
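ChromaModulationOut itself is not shown above. Based on the three-rows-per-modulation layout, its from_offset helper presumably slices shift, scale, and gate from consecutive rows, roughly like this (a sketch, not the repo’s exact definition):

```python
from dataclasses import dataclass
import mlx.core as mx

@dataclass
class ChromaModulationOut:
    shift: mx.array
    scale: mx.array
    gate: mx.array

    @classmethod
    def from_offset(cls, tensor: mx.array, offset: int = 0) -> "ChromaModulationOut":
        # Three consecutive rows of the (B, 344, 3072) modulation tensor
        # become the shift / scale / gate vectors for one modulation site.
        return cls(
            shift=tensor[:, offset : offset + 1, :],
            scale=tensor[:, offset + 1 : offset + 2, :],
            gate=tensor[:, offset + 2 : offset + 3, :],
        )
```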
Note the explicit dtype in the modulation-index construction above, `mx.arange(mod_index_length, dtype=mx.float32)`: MLX’s arange defaults to an integer dtype for integer arguments, which (as noted in the summary below) can silently introduce significant bugs in the embedding computation.
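A quick illustration of the pitfall:

```python
import mlx.core as mx

# With integer arguments, arange infers an integer dtype; the sinusoidal
# timestep embedding needs a floating-point index, so request it explicitly.
print(mx.arange(344).dtype)                    # int32
print(mx.arange(344, dtype=mx.float32).dtype)  # float32
```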
Step 3. Precision Comparison
Due to differences between MLX and PyTorch in both algorithmic implementations and code semantics, small discrepancies can accumulate and potentially cause inference to fail. To detect such issues, we compared intermediate results with those from ComfyUI to identify bugs in the computation pipeline.
Specifically, we kept the input data consistent by using dumped values from ComfyUI, including img, txt, img_ids, txt_ids, and timesteps. At each computation step, we dumped the outputs from both the MLX implementation and ComfyUI. By calculating the mean absolute error (MAE) at each step, we were able to evaluate the precision of the MLX implementation.
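In practice this boils down to dumping the same intermediate tensor from both pipelines and comparing them; a minimal sketch (the file names are hypothetical) looks like this:

```python
import numpy as np

# Hypothetical dump files: the same intermediate tensor (e.g. the output of the
# first double block) saved from ComfyUI and from the MLX port for identical inputs.
ref = np.load("comfyui_double_block_0.npy").astype(np.float32)
out = np.load("mlx_double_block_0.npy").astype(np.float32)

mae = np.abs(ref - out).mean()
print(f"MAE: {mae:.6e}")
```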
Step 4. Generation of Timesteps
As mentioned earlier, the generation of modulation vectors depends on two parameters: timesteps and guidance. Since guidance is set to 0, the entire modulation process relies solely on the input timesteps, which are produced by the sampler.
We followed the implementation provided by the Chroma author in the original Flow project:
```python
import math
from typing import Callable

import mlx.core as mx

def time_shift(mu: float, sigma: float, t: mx.array) -> mx.array:
    return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)

def get_lin_function(
    x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15
) -> Callable[[float], float]:
    m = (y2 - y1) / (x2 - x1)
    b = y1 - m * x1
    return lambda x: m * x + b

def timesteps(self, num_steps, image_sequence_length, start: float = 1, stop: float = 0):
    t = mx.linspace(start, stop, num_steps + 1)
    mu = self.get_lin_function(y1=self._base_shift, y2=self._max_shift)(image_sequence_length)
    t = self.time_shift(mu, 1.0, t)
    return t.tolist()
```
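As a rough sanity check (assuming image_sequence_length is the patchified latent token count and the default y1 = 0.5, y2 = 1.15 above): a 1024×1024 image with 8× VAE downsampling and patch_size = 2 yields (1024 / 8 / 2)² = 4096 tokens, which lands exactly on x2, so mu = 1.15 and, for example, the midpoint t = 0.5 is shifted to exp(1.15) / (exp(1.15) + 1) ≈ 0.76, pushing more of the schedule toward high noise levels.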
Step 5. CFG
Eventually, I discovered that setting the CFG parameter to 1 in ComfyUI also resulted in severely color-shifted images—very similar to the outputs from the MLX implementation. This led me to examine the CFG-related logic in the author’s Flow project.
It turned out that Chroma introduces a negative conditioning path in its sampling loop:
```python
pred = model(
    img=img,
    img_ids=img_ids,
    txt=txt,
    txt_ids=txt_ids,
    txt_mask=txt_mask,
    timesteps=t_vec,
    guidance=guidance_vec,
)

# Disable CFG for the first N steps before using CFG.
if step_count < first_n_steps_without_cfg or first_n_steps_without_cfg == -1:
    img = img.to(pred) + (t_prev - t_curr) * pred
else:
    pred_neg = model(
        img=img,
        img_ids=img_ids,
        txt=neg_txt,
        txt_ids=neg_txt_ids,
        txt_mask=neg_txt_mask,
        timesteps=t_vec,
        guidance=guidance_vec,
    )
    pred_cfg = pred_neg + (pred - pred_neg) * cfg
    img = img + (t_prev - t_curr) * pred_cfg
```
Flux’s sampling, by contrast, does not include any negative conditioning, which makes it significantly faster.
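Note that pred_cfg = pred_neg + cfg * (pred - pred_neg) is the standard classifier-free guidance extrapolation: with cfg = 1 it collapses to the positive prediction alone, and every guided step requires two full forward passes, which is exactly what produces the roughly doubled computation noted in the key observations above.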
Porting Chroma to MLX was a rewarding yet challenging process. While the model architecture is simpler than Flux in terms of parameter count, the introduction of negative conditioning brought unexpected complexity—both in computation logic and performance implications.
Through detailed debugging, step-by-step comparison with ComfyUI, and careful inspection of the Flow codebase, I uncovered several subtle but critical issues:
- MLX-specific behaviors, such as the default int dtype in arange, can silently introduce significant bugs.
- Numerical drift accumulates across steps, which emphasizes the importance of precision (e.g., using float32 instead of bf16) in critical computations.
- Chroma’s negative conditioning greatly increases inference time, offsetting its lighter parameter count compared to Flux.