MIGRATING CHROMA TO MLX

Git: https://github.com/jack813/mlx-chroma

KEY OBSERVATIONS

  1. Model Parameter Change: Chroma’s checkpoint contains 643 parameter tensors, reduced from Flux’s 780.
  2. Negative Prompt Impact: Chroma introduces negative prompts, which result in roughly double the computation compared to Flux.

RESULT COMPARISON

The following table compares the memory usage and generation time of inference in the two frameworks, using identical prompts and seeds.

Parameter Settings

  • CFG_SKIP_STEPS = 0
  • CFG = 4
  • WIDTH = HEIGHT = 1024
  • STEPS = 28
  • Same prompt, same seed
  • Macbook Pro M4 Max 128GB

Model Combination                  | Memory Usage | Inference Time (mm:ss)
MLX_BF16 (Chroma_BF16 + T5_BF16)   | 28.74 GB     | 7:26
ComfyUI (Chroma_BF16 + T5_FP8)     | 30 GB        | 10:48

Comparison of Results with Different CFG and skip-cfg-steps Settings



ABOUT THE MODEL

CHROMA: OPEN-SOURCE, UNCENSORED, AND BUILT FOR THE COMMUNITY
Chroma is an 8.9B parameter model based on FLUX.1-schnell (technical report coming soon!). It’s fully Apache 2.0 licensed, so anyone can use, modify, and build on top of it, with no corporate gatekeeping. The model is still training right now, and I’d love to hear your thoughts! Your input and feedback are really appreciated. (From the model’s Hugging Face page.)

STEPS TO MIGRATE CHROMA TO MLX

Approach

Our porting process relied on the Flux inference implementation available in the mlx-examples repository for MLX. We referred to the source code from both ComfyUI and the Flow library developed by the original author to guide the migration.

I utilized a dump-anchor data alignment technique to systematically compare intermediate computation results at each step, ensuring the ported version maintains functional parity with the original.

Step 1. Model Configuration and Import

Since Chroma is based on the Flux model but drops the CLIP computation, referring to the Flux implementation in mlx-examples made the process much easier. The original project already includes the import and configuration code for the Flux, T5, CLIP, and VAE models. Our task was mainly to replace Flux’s configuration and parameters with those of the Chroma model and remove all CLIP-related code.

Chroma Model Configuration

ChromaParams(
    in_channels=64,
    out_channels=64,
    context_in_dim=4096,
    hidden_size=3072,
    mlp_ratio=4.0,
    num_heads=24,
    depth=19,
    depth_single_blocks=38,
    axes_dim=[16, 56, 56],
    theta=10_000,
    patch_size=2,
    qkv_bias=True,
    in_dim=64,
    out_dim=3072,
    hidden_dim=5120,
    n_layers=5,
)
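For reference, these fields can be collected in a small container mirroring FluxParams from mlx-examples. The sketch below is only an assumption about how such a dataclass might be defined, not necessarily the exact definition in the repository:

from dataclasses import dataclass
from typing import List


@dataclass
class ChromaParams:
    in_channels: int
    out_channels: int
    context_in_dim: int
    hidden_size: int
    mlp_ratio: float
    num_heads: int
    depth: int                 # number of double (MM) blocks
    depth_single_blocks: int   # number of single blocks
    axes_dim: List[int]        # RoPE dimensions per axis
    theta: int
    patch_size: int
    qkv_bias: bool
    in_dim: int                # input width of the distilled guidance layer
    out_dim: int               # output width of the distilled guidance layer
    hidden_dim: int
    n_layers: int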
The main difference in the model parameters is that the three modules time_in, vector_in, and guidance_in present in Flux were removed and replaced with a single "distilled_guidance_layer":
self.distilled_guidance_layer = Approximator(
    in_dim=self.in_dim,
    hidden_dim=self.hidden_dim,
    out_dim=self.out_dim,
    n_layers=self.n_layers,
)
                
class Approximator(nn.Module):
    def __init__(self, in_dim: int, out_dim: int, hidden_dim: int, n_layers=5):
        super().__init__()
        self.in_proj = nn.Linear(in_dim, hidden_dim, True)
        self.layers = [MLPEmbedder(hidden_dim, hidden_dim) for _ in range(n_layers)]
        self.norms = [nn.RMSNorm(hidden_dim, eps=1e-6) for _ in range(n_layers)]
        self.out_proj = nn.Linear(hidden_dim, out_dim)

    def __call__(self, x: mx.array) -> mx.array:
        dtype = x.dtype
        x = self.in_proj(x)

        # Residual MLP stack: each block is pre-normalized with RMSNorm.
        for layer, norm in zip(self.layers, self.norms):
            x = x + layer(norm(x))

        x = self.out_proj(x)
        x = x.astype(dtype)  # cast back to the input dtype
        return x
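As a quick sanity check, the distilled guidance layer maps the concatenated (B, 344, 64) timestep/guidance/index embedding to (B, 344, 3072) modulation rows. A minimal shape test might look like the following; it assumes MLPEmbedder is importable from the Flux port and is purely illustrative:

import mlx.core as mx

# Hypothetical shape check: 64 = 32 (timestep + guidance embedding)
#                              + 32 (modulation index embedding)
approx = Approximator(in_dim=64, out_dim=3072, hidden_dim=5120, n_layers=5)
x = mx.zeros((1, 344, 64), dtype=mx.bfloat16)
print(approx(x).shape)  # (1, 344, 3072)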
As a result of these changes, the number of parameter tensors in the checkpoint decreased from 780 to 643. The model is therefore smaller than Flux, which should lead to lower memory usage and, in theory, faster inference.


Step 2. Forward Propagation

In Flux, the modules time_in, vector_in, and guidance_in generate the modulation vectors that control and guide the model inside the MM blocks (double blocks) and single blocks. Since Chroma removes these modules, the conditioning controls have to be constructed anew.

Construction of Modulation Vectors

batch_size = img.shape[0]
mod_index_length = 344  # 38 * 3 + 19 * 6 * 2 + 2 modulation rows (see below)

# 16-dim sinusoidal embeddings of the timestep and the (zeroed) guidance value.
distill_timestep = timestep_embedding(mx.stop_gradient(timesteps), 16).astype(self.dtype)
distil_guidance = timestep_embedding(mx.stop_gradient(guidance), 16).astype(self.dtype)

# 32-dim embedding of each modulation row index; the dtype must be float32 here.
modulation_index = timestep_embedding(mx.arange(mod_index_length, dtype=mx.float32), 32).astype(self.dtype)
modulation_index = mx.broadcast_to(modulation_index, (batch_size, mod_index_length, 32)).astype(self.dtype)

# (B, 32) -> (B, 1, 32) -> (B, 344, 32): every modulation row sees the same
# timestep/guidance embedding (the explicit axis matches the unsqueeze in the
# reference implementation and keeps this correct for batch sizes > 1).
timestep_guidance = mx.concatenate([distill_timestep, distil_guidance], axis=1).astype(self.dtype)
timestep_guidance = mx.broadcast_to(timestep_guidance[:, None, :], (batch_size, mod_index_length, 32)).astype(self.dtype)

input_vec = mx.concatenate([timestep_guidance, modulation_index], axis=-1).astype(self.dtype)  # (B, 344, 64)
mod_vectors = self.distilled_guidance_layer(input_vec).astype(self.dtype)  # (B, 344, 3072)

The 344 modulation rows are then sliced out per block by a helper with the following layout:

def get_modulations(self, tensor: mx.array, block_type: str, *, idx: int = 0):
    # This function slices up the modulations tensor which has the following layout:
    #   single     : num_single_blocks * 3 elements
    #   double_img : num_double_blocks * 6 elements
    #   double_txt : num_double_blocks * 6 elements
    #   final      : 2 elements
    if block_type == "final":
        return (tensor[:, -2:-1, :], tensor[:, -1:, :])
    single_block_count = self.params.depth_single_blocks
    double_block_count = self.params.depth
    offset = 3 * idx
    if block_type == "single":
        return ChromaModulationOut.from_offset(tensor, offset)
    # Double block modulations are 6 elements, so we double 3 * idx.
    offset *= 2
    if block_type in {"double_img", "double_txt"}:
        # Advance past the single block modulations.
        offset += 3 * single_block_count
        if block_type == "double_txt":
            # Advance past the double block img modulations.
            offset += 6 * double_block_count
        return (
            ChromaModulationOut.from_offset(tensor, offset),
            ChromaModulationOut.from_offset(tensor, offset + 3),
        )
    raise ValueError("Bad block_type")
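This layout also explains the magic number 344 used when constructing the modulation vectors: with depth_single_blocks = 38, depth = 19, and 2 rows for the final layer, the row count works out exactly. A quick check using only the configuration values above:

depth_single_blocks = 38  # single blocks, 3 modulation rows each
depth = 19                # double blocks, 6 rows for img and 6 for txt each
final_rows = 2            # final layer shift/scale

mod_index_length = depth_single_blocks * 3 + depth * 6 * 2 + final_rows
assert mod_index_length == 344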
The biggest bug was caused by this line (shown here in its already-fixed form):
mx.arange(mod_index_length, dtype=mx.float32)
In MLX, arange defaults to an integer type when dtype is not explicitly specified. Without the explicit dtype, the embeddings constructed for the 344 modulation rows end up identical for every row except the first, leading to serious calculation errors; it took me a long time to track this bug down. Specifying dtype=mx.bfloat16 is not enough either: the precision is too low and image generation fails, so only float32 works correctly. Furthermore, the guidance parameter needs to be set to 0 for Chroma, whereas Flux typically uses a default value of 4. This was another bug I discovered through the dumped debugging data.
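A minimal reproduction of the dtype pitfall (the arange behavior shown is standard MLX; the rest is just illustration):

import mlx.core as mx

# With an integer bound and no dtype, arange returns int32.
print(mx.arange(344).dtype)                    # int32
# The modulation index must be built in float32 before the sinusoidal embedding.
print(mx.arange(344, dtype=mx.float32).dtype)  # float32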

Step 3. Precision Comparison
Due to differences between MLX and PyTorch in both algorithmic implementations and code semantics, small discrepancies can accumulate and potentially cause inference to fail. To detect such issues, we compared intermediate results with those from ComfyUI to identify bugs in the computation pipeline.

Specifically, we kept the input data consistent by using dumped values from ComfyUI, including img, txt, img_ids, txt_ids, and timesteps. At each computation step, we dumped the outputs from both the MLX implementation and ComfyUI. By calculating the mean absolute error (MAE) at each step, we were able to evaluate the precision of the MLX implementation.
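As a rough illustration of this per-step comparison (the file names and dump format here are hypothetical; the actual repository may organize the dumps differently):

import numpy as np

def mae(mlx_dump: str, comfy_dump: str) -> float:
    # Both sides dump the same intermediate tensor (e.g. the output of a
    # double block at a given denoising step) as .npy files.
    a = np.load(mlx_dump).astype(np.float32)
    b = np.load(comfy_dump).astype(np.float32)
    return float(np.mean(np.abs(a - b)))

# Example: compare the output of double block 0 at step 0.
print(mae("dumps/mlx/step00_double_block_00.npy",
          "dumps/comfy/step00_double_block_00.npy"))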

The results show that although the average relative error tends to increase step by step, it eventually drops back to a lower level after passing through the final layer.


However, we observed that as the number of steps increases, the error in the final output image—after passing through the Final Layer—also becomes more noticeable. This indicates that minimizing cumulative numerical drift across layers is essential, and serves as an important direction for improving the MLX implementation’s precision and stability.

Step 4. Generation of Timesteps

As mentioned earlier, the generation of modulation vectors depends on two parameters: timesteps and guidance. Since guidance is set to 0, the entire modulation process relies solely on the input timesteps, which are produced by the sampler.

We followed the implementation provided by the Chroma author in the original Flow project to replicate the timesteps generation logic.

import math
from typing import Callable

import mlx.core as mx


def time_shift(mu: float, sigma: float, t: mx.array) -> mx.array:
    return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)


def get_lin_function(
    x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15
) -> Callable[[float], float]:
    m = (y2 - y1) / (x2 - x1)
    b = y1 - m * x1
    return lambda x: m * x + b


def timesteps(self, num_steps, image_sequence_length, start: float = 1, stop: float = 0):
    t = mx.linspace(start, stop, num_steps + 1)
    # Shift the schedule toward t = 1 for longer image sequences.
    mu = get_lin_function(y1=self._base_shift, y2=self._max_shift)(image_sequence_length)
    t = time_shift(mu, 1.0, t)
    return t.tolist()
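For example, at 1024x1024 with an 8x VAE downscale and patch size 2, the latent sequence length is (1024 / 8 / 2)^2 = 4096 tokens. A hypothetical call for the 28-step benchmark setting would then look like this (sampler stands for whatever object exposes the method above):

# Hypothetical usage for the benchmark settings (28 steps, 1024x1024):
schedule = sampler.timesteps(num_steps=28, image_sequence_length=4096)
# schedule is a decreasing list of 29 values from 1.0 down to 0.0;
# consecutive pairs (t_curr, t_prev) drive each denoising update.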

Although the key model parameters and computation processes were mostly aligned, the results were still unsatisfactory. The generated images showed visible issues such as blurred contours, color shifts, or artifacts. In some cases, even when decoding data directly dumped from ComfyUI using the VAE, the output images appeared inverted in color or contained significant noise.

Similarly, feeding data generated by the MLX implementation back into ComfyUI failed to produce clear images. This issue blocked my progress for several hours. At one point, I suspected that the problem might lie in the T5 token handling or the VAE implementation. However, after thorough investigation, I concluded that neither component was responsible for the degraded output.

Step 5. CFG

Eventually, I discovered that setting the CFG parameter to 1 in ComfyUI also resulted in severely color-shifted images—very similar to the outputs from the MLX implementation. This led me to examine the CFG-related logic in the author’s Flow project.

It turned out that Chroma introduces a negative conditioning path that needs to be computed separately. Then, the results from both the positive and negative conditions are combined through CFG weighting. This step is essential, and missing or miscomputing it can lead to distorted outputs.

pred = model(
    img=img,
    img_ids=img_ids,
    txt=txt,
    txt_ids=txt_ids,
    txt_mask=txt_mask,
    timesteps=t_vec,
    guidance=guidance_vec,
)
# Disable CFG for the first N steps (a value of -1 skips CFG entirely).
if step_count < first_n_steps_without_cfg or first_n_steps_without_cfg == -1:
    img = img.to(pred) + (t_prev - t_curr) * pred
else:
    # Second forward pass with the negative conditioning.
    pred_neg = model(
        img=img,
        img_ids=img_ids,
        txt=neg_txt,
        txt_ids=neg_txt_ids,
        txt_mask=neg_txt_mask,
        timesteps=t_vec,
        guidance=guidance_vec,
    )

    # CFG weighting: move from the negative prediction toward the positive one.
    pred_cfg = pred_neg + (pred - pred_neg) * cfg
    img = img + (t_prev - t_curr) * pred_cfg
Eventually, after adding the negative conditioning computation, the model began generating clear images. However, the inference time nearly doubled as a result. In contrast, the original Flux generation process does not include any negative conditioning, which makes it significantly faster.

This explains why, despite having fewer parameters than Flux, the Chroma model does not show a significant improvement in generation speed or memory usage.

CONCLUSION

Porting Chroma to MLX was a rewarding yet challenging process. While the model architecture is simpler than Flux in terms of parameter count, the introduction of negative conditioning brought unexpected complexity—both in computation logic and performance implications.

Through detailed debugging, step-by-step comparison with ComfyUI, and careful inspection of the Flow codebase, I uncovered several subtle but critical issues:

  • MLX-specific behaviors, such as the default int dtype in arange, can silently introduce significant bugs.

  • Numerical drift accumulates across steps, which emphasizes the importance of precision (e.g., using float32 instead of bf16) in critical computations.

  • Chroma’s negative conditioning greatly increases inference time, offsetting its lighter parameter size compared to Flux.

These findings highlight that aligning architecture alone is not enough—behavioral fidelity across frameworks also depends on sampling strategies, condition handling, and numerical precision.

