MLX-Chroma + LoRA

These days, LoRA is arguably the most cost-effective way to fine-tune image generation models.

GitHub: https://github.com/jack813/mlx-chroma

Demo LoRA: jack813liu/style_scaramouche_chroma_lora

Training tool: ostris/ai-toolkit

Dataset: svjack/Genshin_Impact_Scaramouche_Images_Captioned_Ghibli_Studio_style

ai-toolkit automatically downloads the latest Chroma model from Hugging Face for training (note: this is NOT the detail-calibrated version).

I'm using a MacBook. Its unified memory helps avoid out-of-memory errors, but its GPU still can't match the training performance of NVIDIA's dedicated GPUs, so I trained the model on RunPod instead.

You can find the setup and training instructions for RunPod here:

👉 Training in RunPod - ai-toolkit GitHub

I followed the recommended configuration and trained on an NVIDIA A40 GPU.

  • Training steps: 2550

  • Total time: ~5 hours

  • Average iteration time: ~6 seconds/it

  • Image generation time: 1 minute 55 seconds

From my personal experience, an RTX 4090 also handles the training just fine: thanks to ai-toolkit's quantization, the model fits comfortably within 24 GB of VRAM, and the 4090 is significantly faster than the A40.

$$
\text{MaxTrainSteps} = \frac{\text{NumberOfSamples} \times \text{Repeats}}{\text{BatchSize}} \times \text{Epochs}
$$

The dataset contains 102 samples, with repeats left at the default value of 1 and a batch size of 1.

This means one epoch equals 102 steps, so training for 40 epochs would require 4080 steps in total.

I stopped training at 2550 steps (25 epochs) because the validation images already showed the dataset's style clearly. In fact, the stylistic features were already quite apparent after just 510 steps (5 epochs).
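The step arithmetic above, including the 2550- and 510-step checkpoints, can be checked with a small sketch. The function names are illustrative (not part of ai-toolkit), and the integer division assumes the sample count divides evenly by the batch size, as it does with a batch size of 1:

```python
def steps_per_epoch(num_samples: int, repeats: int = 1, batch_size: int = 1) -> int:
    # Assumes num_samples * repeats divides evenly by batch_size
    # (true here, since the batch size is 1); trainers may round differently.
    return (num_samples * repeats) // batch_size


def max_train_steps(num_samples: int, repeats: int, batch_size: int, epochs: int) -> int:
    # MaxTrainSteps = (NumberOfSamples * Repeats / BatchSize) * Epochs
    return steps_per_epoch(num_samples, repeats, batch_size) * epochs


per_epoch = steps_per_epoch(102, repeats=1, batch_size=1)
print(per_epoch)                              # 102
print(max_train_steps(102, 1, 1, 40))         # 4080
print(2550 // per_epoch, 510 // per_epoch)    # 25 5 (epochs completed)
```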

Porting MLX-Chroma to Support LoRA

When porting MLX-Chroma to support LoRA, there are two key steps you need to handle:

  1. Load the LoRA model parameters

  2. Merge the LoRA weights into the base model for inference
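For a linear layer with base weight W, step 2 relies on the standard LoRA identity. Assuming the usual formulation, where the adapter stores A ∈ ℝ^(r×d_in) and B ∈ ℝ^(d_out×r), and s is a scaling factor whose value (often derived from a LoRA alpha) depends on the trainer:

$$
y = Wx + s\,B(Ax) = \left(W + s\,BA\right)x,
\qquad A \in \mathbb{R}^{r \times d_{\text{in}}},\; B \in \mathbb{R}^{d_{\text{out}} \times r}
$$

Merging therefore replaces W with W + sBA once, after which inference runs at the cost of the base model alone.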

The Flux example in mlx-examples already supports applying LoRA adapters, so only minimal modifications are needed to make it work with Chroma.

In the flux code under mlx-examples, LoRA parameters are loaded via the load_adapter function.

This step requires:

  • Determining the rank of the LoRA model

  • Converting the parameters into the format expected by the base model

LoRA files produced by ai-toolkit do not explicitly specify the rank in their config.

You can either set the rank manually or infer it from the shapes of the loaded LoRA weights:

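import mlx.core as mx
weights = mx.load(adapter_file)  # adapter_file: path to the LoRA .safetensors file
# lora_A is stored as (rank, in_features) and lora_B as (out_features, rank)
a = weights["diffusion_model.double_blocks.0.img_mlp.2.lora_A.weight"]
b = weights["diffusion_model.double_blocks.0.img_mlp.2.lora_B.weight"]
rank = a.shape[0] if a.shape[0] == b.shape[1] else a.shape[1]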
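The snippet above reads the rank from one hard-coded layer. A slightly more defensive variant, sketched below, checks every lora_A/lora_B pair and fails loudly if they disagree. The function name is hypothetical, the keys are assumed to follow the PEFT-style ".lora_A.weight"/".lora_B.weight" naming seen in ai-toolkit files, and NumPy arrays stand in for mx.array (only .shape is used, so either works):

```python
import numpy as np


def infer_lora_rank(weights: dict) -> int:
    """Infer the LoRA rank from every lora_A/lora_B pair and check that they agree.

    Assumes PEFT-style keys where lora_A has shape (rank, in_features)
    and lora_B has shape (out_features, rank).
    """
    ranks = set()
    for key, a in weights.items():
        if not key.endswith(".lora_A.weight"):
            continue
        b = weights[key.replace(".lora_A.weight", ".lora_B.weight")]
        # Same fallback as the original snippet: use a.shape[1] if stored transposed.
        ranks.add(a.shape[0] if a.shape[0] == b.shape[1] else a.shape[1])
    if len(ranks) != 1:
        raise ValueError(f"expected a single LoRA rank, found {sorted(ranks)}")
    return ranks.pop()


# Toy example: rank 4 on a 16 -> 32 linear layer.
toy = {
    "diffusion_model.double_blocks.0.img_mlp.2.lora_A.weight": np.zeros((4, 16)),
    "diffusion_model.double_blocks.0.img_mlp.2.lora_B.weight": np.zeros((32, 4)),
}
print(infer_lora_rank(toy))  # 4
```

Raising on mixed ranks is a deliberate choice here: a single rank value cannot describe an adapter whose layers use different ranks, so such files would need per-layer handling.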

Renaming Parameters to Match the Model

When loading the LoRA weights, we need to align the parameter names with those expected by the model. Here’s an example of how to do that:

new_weights = {}
for k, v in weights.items():
    # Remove the "diffusion_model." prefix if it exists
    if k.startswith("diffusion_model."):
        new_k = k[len("diffusion_model."):]
    else:
        new_k = k
    # MLX's nn.Sequential keeps its children under ".layers" (img_mlp.2 -> img_mlp.layers.2)
    new_k = new_k.replace(".txt_mlp", ".txt_mlp.layers")
    new_k = new_k.replace(".img_mlp", ".img_mlp.layers")
    # Rename the PEFT-style LoRA tensors to the attribute names used in lora.py
    new_k = new_k.replace(".lora_A.weight", ".lora_a")
    new_k = new_k.replace(".lora_B.weight", ".lora_b")
    new_weights[new_k] = v

# Load weights into the model
chroma.flow.load_weights(list(new_weights.items()), strict=False)

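The strict=False flag tells the loader not to require an exact match between the provided weights and the model's parameters; with strict=True, missing or unexpected keys raise an error.

You can set it to True during debugging to check for any mismatches.

For LoRA, however, it usually has to stay False: the adapter file contains only the low-rank lora_a / lora_b tensors for the adapted layers, not the full set of base model parameters, so a strict load would report every base weight as missing.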
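Because strict=False also hides genuine naming mistakes, a quick diff of key names can show which LoRA tensors have no matching parameter in the model. The helper below is hypothetical and works on plain key lists; in MLX, the model-side keys could be collected with `[k for k, _ in mlx.utils.tree_flatten(model.parameters())]`. The toy key names are illustrative only:

```python
def diff_keys(lora_keys, model_keys):
    """Split renamed LoRA keys into those the model has and those it doesn't."""
    model_keys = set(model_keys)
    matched = sorted(k for k in lora_keys if k in model_keys)
    unmatched = sorted(k for k in lora_keys if k not in model_keys)
    return matched, unmatched


# Illustrative keys: one LoRA key was renamed correctly, one missed the ".layers" step.
model_keys = [
    "double_blocks.0.img_mlp.layers.2.linear.weight",
    "double_blocks.0.img_mlp.layers.2.lora_a",
    "double_blocks.0.img_mlp.layers.2.lora_b",
]
lora_keys = ["double_blocks.0.img_mlp.layers.2.lora_a", "double_blocks.0.img_mlp.2.lora_b"]
matched, unmatched = diff_keys(lora_keys, model_keys)
print(unmatched)  # ['double_blocks.0.img_mlp.2.lora_b']
```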

Adjusting Parameter Shapes in lora.py

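In lora.py, the key modification is adjusting tensor shapes inside the __call__ method.

ai-toolkit exports lora_A as (rank, in_features) and lora_B as (out_features, rank), the PyTorch nn.Linear layout, while the original MLX LoRA layer computes x @ lora_a @ lora_b and therefore expects (in_features, rank) and (rank, out_features). Transposing both LoRA weights makes the shapes line up: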

z = (self.dropout(x) @ self.lora_a.T) @ self.lora_b.T
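A quick NumPy check (NumPy standing in for mlx.core; the sizes and scale = 1.0 are arbitrary toy values) confirms that the transposed forward pass yields (batch, out_features) outputs and gives the same result as merging the adapter into the base weight as W + scale · (B @ A). Dropout is omitted because it is inactive at inference:

```python
import numpy as np

rng = np.random.default_rng(0)
batch, d_in, d_out, rank, scale = 2, 16, 32, 4, 1.0

x = rng.standard_normal((batch, d_in))
W = rng.standard_normal((d_out, d_in))        # base linear weight, (out, in) layout
lora_A = rng.standard_normal((rank, d_in))    # as exported: (rank, in_features)
lora_B = rng.standard_normal((d_out, rank))   # as exported: (out_features, rank)

# Forward pass with the transposes from the modified __call__ (dropout omitted).
z = (x @ lora_A.T) @ lora_B.T
y = x @ W.T + scale * z
print(z.shape)  # (2, 32)

# Merging the adapter into the base weight gives the same output.
W_merged = W + scale * (lora_B @ lora_A)
print(np.allclose(y, x @ W_merged.T))  # True
```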

Note on Compatibility with Different Training Tools

Depending on which tool was used to train the LoRA model, some customization may be required.

Different tools may export weights in slightly different formats or naming schemes.

A good long-term solution is to implement a more flexible loader that can handle multiple naming conventions or infer them automatically.
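As a starting point for such a loader, a table-driven key normalizer could look like the sketch below. The rule tables are hypothetical: the diffusion_model. prefix and lora_A/lora_B suffixes come from the ai-toolkit files discussed above, while the transformer. prefix and lora_down/lora_up suffixes are examples of conventions some other trainers use and have not been checked against Chroma LoRAs:

```python
import re

# Hypothetical rule tables: examples, not an exhaustive or verified list.
PREFIXES = ("diffusion_model.", "transformer.")
SUFFIXES = {
    ".lora_A.weight": ".lora_a",
    ".lora_B.weight": ".lora_b",
    ".lora_down.weight": ".lora_a",
    ".lora_up.weight": ".lora_b",
}
# MLX's nn.Sequential keeps children under ".layers": img_mlp.2 -> img_mlp.layers.2
MLP_INDEX = re.compile(r"\.(img_mlp|txt_mlp)\.(\d+)\.")


def normalize_key(key: str) -> str:
    # Strip the first matching prefix, if any.
    for prefix in PREFIXES:
        if key.startswith(prefix):
            key = key[len(prefix):]
            break
    # Map the LoRA tensor suffix to the names used in lora.py.
    for old, new in SUFFIXES.items():
        if key.endswith(old):
            key = key[: -len(old)] + new
            break
    # Insert ".layers" only before a numeric Sequential index.
    return MLP_INDEX.sub(r".\1.layers.\2.", key)


print(normalize_key("diffusion_model.double_blocks.0.img_mlp.2.lora_A.weight"))
# double_blocks.0.img_mlp.layers.2.lora_a
```

Unlike plain str.replace calls, the regex only inserts .layers before a numeric index, so keys that are already normalized pass through unchanged. Other conventions (for example, fully underscore-joined module paths or separate alpha tensors) would need additional rules.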

Using LoRA in MLX-Chroma

In txt2image.py from MLX-Chroma, a new --adapter argument has been added to allow specifying the path to a LoRA model for inference.

Here’s an example command:

python txt2image.py \
 "In the style of Scaramouche, this is a digital anime-style drawing by artist @koguru. The character, with short, dark blue hair and large, expressive purple eyes, is holding a small, angry cat. The cat has a red face and is visibly upset, with steam coming from its head. The character wears a traditional white and purple kimono. The background is plain white, making the characters stand out. The overall mood is humorous and light-hearted, with the cat's angry expression contrasting the character's calm demeanor." \
 --image-size 512x512 \
 --cfg 4 \
 --adapter /Users/.../lora/style_scaramouche_chroma_lora_v1_000002040.safetensors
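The flag itself is ordinary command-line plumbing. As a purely hypothetical sketch (not MLX-Chroma's actual txt2image.py; the parse_image_size helper and reading "512x512" as WIDTHxHEIGHT are assumptions), the options used in the command above could be parsed with argparse like this:

```python
import argparse


def parse_image_size(text: str) -> tuple:
    # "512x512" -> (512, 512); the WIDTHxHEIGHT order is an assumption.
    width, height = (int(v) for v in text.lower().split("x"))
    return width, height


def build_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser(description="Chroma text-to-image (sketch)")
    parser.add_argument("prompt")
    parser.add_argument("--image-size", type=parse_image_size)
    parser.add_argument("--cfg", type=float)
    parser.add_argument("--adapter", type=str, default=None,
                        help="path to a LoRA .safetensors file")
    return parser


args = build_parser().parse_args(
    ["a cat", "--image-size", "512x512", "--cfg", "4", "--adapter", "lora.safetensors"]
)
print(args.image_size, args.cfg, args.adapter)  # (512, 512) 4.0 lora.safetensors
```

In this sketch, leaving out --adapter gives args.adapter = None, which a script could treat as "run the base model without a LoRA".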

To make it more user-friendly, I’ve also provided a Gradio-based UI. You can launch it with:

python app.py

Demo

Seed 17240

Prompt

photorealistic beautiful girl as Diablo 2 sorceress sexy cosplay, attractive face with confident expression, 21 years old, long dark brown hair, slender and large breasts showing cleavage, fitted emerald green fantasy outfit with lots of cutouts, shorter robe showing more leg, form-fitting corset emphasizing silhouette, bare midriff, ornate golden armor pieces, decorative belt with gemstones, holding wooden staff with glowing purple orb, confident alluring pose, cinematic lighting, high resolution photography style, detailed costume textures, form-fitting outfit, fitted clothing, body-conscious design, natural lighthing.

Negative Prompt

extra fingers, 6 fingers, multi hand, extra hand, painting, colored pencil, cel shading, oekaki, toony, multiple styles, watermark, transparent background, border, body horror, disembodied head, brain, bloodshot eyes, surreal, comic, skull head, fan character, anatomically inaccurate, deformed, disfigured, ugly, anatomical nonsense, duplicate, bad composition, smooth skin, shiny skin, glistening, render, fake, clay, HDR, out of focus, blurry, unclear, diffuse, foggy, censored, glossy body, white spots, glitch, heavy makeup, disfigured belly button, modern clothing, casual wear, bright lighting, cartoonish, low quality, distorted proportions, modern accessories, contemporary hairstyles, cartoon, anime, drawn, painted, illustrated, stylized, game art, digital art, 2d, flat, artificial skin, plastic skin, doll-like

Style Scaramouche Lora

