Or: how I spent twelve hours building a workaround for a flag that already existed


I run a four-node DGX Spark cluster for local DFIR inference. The biggest model on it is Intel's Qwen3.5-397B-A17B in INT4 AutoRound — about 200 GB of weights, fits comfortably across four 128 GiB unified-memory boxes when sharded TP=4. Until last week that cluster ran at 26 tok/s with Conch as the GPTQ kernel. Marlin should have been faster — the published reference number from sonusflow's earlier vLLM patch was 37 tok/s on the same hardware — but every attempt I made to run Marlin at TP=4 produced models that loaded cleanly, accepted prompts, and silently emitted garbage.

This is what was actually wrong, the dead end I spent a day on, and the two-line fix that ended up at 41.5 tok/s.

The setup

Baseline before this work: mods/install-conch active, no Marlin patch, 26 tok/s with coherent output. Good enough to use, slow enough to be irritating.

The problem

The Qwen3-Next architecture has a small but architecturally critical projection inside its gated delta network — in_proj_ba. It produces two num_v_heads-wide tensors (b and a) that gate the attention computation downstream. For Qwen3.5-397B these are 64 wide each, packed as a MergedColumnParallelLinear with output_sizes=[64, 64].

At TP=4, that splits to 32 outputs per rank.

vLLM's gptq_marlin path requires output_size_per_partition % 64 == 0 (GPTQ_MARLIN_MIN_THREAD_N = 64). 32 fails the check. So the moment you try to run Marlin at TP=4 on this model, this one layer can't be served by the kernel. From vllm/model_executor/layers/quantization/utils/marlin_utils.py:170:

if output_size_per_partition % GPTQ_MARLIN_MIN_THREAD_N != 0:
    raise ...
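
The arithmetic behind that failure is easy to check outside vLLM. This is a minimal sketch, assuming the merged layer is split evenly across ranks. The helper names per_rank_output and marlin_accepts are illustrative, not vLLM functions; only the constant mirrors the check quoted above.

```python
# Illustrative helpers only; the constant mirrors the vLLM check quoted above.
GPTQ_MARLIN_MIN_THREAD_N = 64


def per_rank_output(output_sizes: list[int], tp_size: int) -> int:
    """Output columns each rank holds when the merged layer is sharded evenly."""
    return sum(size // tp_size for size in output_sizes)


def marlin_accepts(output_size_per_partition: int) -> bool:
    """True when the per-rank width is a multiple of Marlin's minimum tile width."""
    return output_size_per_partition % GPTQ_MARLIN_MIN_THREAD_N == 0


for tp_size in (1, 2, 4):
    width = per_rank_output([64, 64], tp_size)
    print(f"TP={tp_size}: {width} outputs per rank, Marlin OK: {marlin_accepts(width)}")
```

At TP=1 and TP=2 the per-rank width is 128 and 64, both multiples of 64. Only TP=4 drops to 32 and fails.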

sonusflow's older patch worked around this by hand-rolling a replacement that bypassed the constraint. That patch was written against pre-March-5 vLLM and no longer applies cleanly — vLLM has since refactored in_proj_ba creation into a create_ba_proj factory method and routed inference through a gdn_in_proj torch custom op. The original landing points are gone.

The wrong turn

My first instinct was to write a wrapper class that made in_proj_ba look like a MergedColumnParallelLinear to the outside world but used two separate ReplicatedLinear modules underneath, each with output=64 (which passes min_thread_n) and full TP replication.

Six iterations in, the cluster loaded. Marlin was selected ("Using MarlinLinearKernel for GPTQMarlinLinearMethod" showed up in the logs). Throughput jumped to ~40 tok/s. And the model emitted strings of tokens that looked coherent but were complete nonsense.

I added file logging to every layer's b/a forward pass on every rank:

b_full norm=0.0000  EXACTLY
a_full norm=0.0000  EXACTLY
output  [0.0, 0.0, 0.0, 0.0, 0.0]

Zero across all four ranks, every layer. Marlin was happily multiplying hidden states by garbage weights and producing exact zeros. The model's gated delta net was structurally present but numerically dead; what came out was the residual path leaking through the rest of the architecture, enough to produce token-shaped output but no actual reasoning.
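
The check itself is simple enough to sketch without a GPU. This is a stand-in, not the original logging code: plain Python lists replace tensors, and l2_norm and find_dead_projections are hypothetical names. It shows the idea of recording one norm per (rank, layer) and flagging any projection whose output is exactly zero.

```python
import math


def l2_norm(values: list[float]) -> float:
    """Euclidean norm of a flat list of activations."""
    return math.sqrt(sum(v * v for v in values))


def find_dead_projections(
    samples: dict[tuple[int, int], list[float]],
) -> list[tuple[int, int]]:
    """Return the (rank, layer) keys whose activations are exactly zero."""
    return sorted(key for key, values in samples.items() if l2_norm(values) == 0.0)


# Hypothetical captured outputs, keyed by (tp_rank, layer_index).
samples = {
    (0, 0): [0.0, 0.0, 0.0, 0.0],
    (0, 1): [0.12, -0.03, 0.4, 0.0],
    (1, 0): [0.0, 0.0, 0.0, 0.0],
}
for rank, layer in find_dead_projections(samples):
    print(f"rank={rank} layer={layer}: norm is exactly 0.0")
```

Testing for exactly 0.0 rather than "small" is deliberate here: the symptom described above was an exact zero on every rank, not a value lost in numerical noise.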

Diagnostic that broke the kernel hypothesis: switching from Marlin to Conch with the same wrapper produced the same zeros. The bug wasn't kernel-specific; AutoRound INT4 weights weren't being loaded correctly into a ReplicatedLinear constructed outside the standard linear factories. Shapes looked correct: qweight was in Marlin's repacked (256, 128) format and scales had the right dims (32, 64). The actual values were wrong in some way I couldn't surface without GPU-side numerical dumps.

After about twelve hours of this I gave up, restored the Conch baseline, tagged the image, and went to sleep.

The fix

Fresh eyes the next morning. I went back to read the linear module definitions in vllm/model_executor/layers/linear.py properly instead of patching around them. And there, in MergedColumnParallelLinear.__init__, was a parameter I had completely missed:

def __init__(
    self,
    input_size: int,
    output_sizes: list[int],
    bias: bool = True,
    ...
    *,
    return_bias: bool = True,
    disable_tp: bool = False,  # ← here
):

Docstring:

disable_tp: If true, all weights matrix won't be sharded, this layer will be treated as a "Replicated" MergedLinear.

vLLM had built exactly the thing I had spent twelve hours trying to construct by hand. The wrapper was unnecessary. The custom weight loader was unnecessary. The bespoke shard/concat logic was unnecessary. One keyword argument.

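With disable_tp=True, MergedColumnParallelLinear keeps the full, unsharded weight on every rank, so in_proj_ba has output=128 per rank and passes the Marlin check.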

The only remaining concern was that downstream code expected the sharded shape — each rank receives [batch, 32] and unpacks via chunk(2) into [batch, 16] for b and a. With disable_tp=True it now receives the full [batch, 128]. Trivial: slice it back down to the local TP rank's chunk before returning, so the rest of the pipeline sees the shape it always saw.
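
The index arithmetic is worth checking before touching the model. Below is a sketch with a plain Python list standing in for one row of the [batch, 128] tensor. The name slice_ba_to_local is reused here for a standalone illustrative function, and the target layout (each rank holding its slice of b followed by its slice of a) is an assumption based on the description above rather than something read from vLLM's source.

```python
def slice_ba_to_local(ba_full, output_sizes, tp_rank, tp_size):
    """Cut one replicated row [b | a] down to the columns a given TP rank would own."""
    offset = output_sizes[0]               # where the a block starts
    chunk_b = output_sizes[0] // tp_size   # b columns per rank
    chunk_a = output_sizes[1] // tp_size   # a columns per rank
    b_local = ba_full[tp_rank * chunk_b:(tp_rank + 1) * chunk_b]
    a_local = ba_full[offset + tp_rank * chunk_a:offset + (tp_rank + 1) * chunk_a]
    return b_local + a_local


output_sizes = [64, 64]
tp_size = 4
# Label every column so the slices are easy to read: b0..b63 then a0..a63.
row = [f"b{i}" for i in range(64)] + [f"a{i}" for i in range(64)]

for rank in range(tp_size):
    local = slice_ba_to_local(row, output_sizes, rank, tp_size)
    half = len(local) // 2
    print(rank, len(local), local[0], local[half - 1], local[half], local[-1])
```

Each rank ends up with 32 columns, the first 16 from b and the last 16 from a, so a downstream chunk(2) still splits them into b and a.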

The whole patch is ~30 lines added to Qwen3_5GatedDeltaNet in vllm/model_executor/models/qwen3_5.py:

def create_ba_proj(self, hidden_size, num_v_heads, quant_config, prefix):
    return MergedColumnParallelLinear(
        input_size=hidden_size,
        output_sizes=[num_v_heads] * 2,
        bias=False,
        quant_config=quant_config,
        prefix=prefix,
        disable_tp=True,  # replicate, output=128 per rank, Marlin happy
    )

def _slice_ba_to_local(self, ba_full):
    """Slice replicated [batch, 128] back to local TP chunk [batch, 32]."""
    from vllm.distributed import (
        get_tensor_model_parallel_rank,
        get_tensor_model_parallel_world_size,
    )
    tp_rank = get_tensor_model_parallel_rank()
    tp_size = get_tensor_model_parallel_world_size()
    output_sizes = self.in_proj_ba.output_sizes
    offset = output_sizes[0]  # start of the a block in the merged output
    chunk_b = output_sizes[0] // tp_size
    chunk_a = output_sizes[1] // tp_size
    b_local = ba_full[..., tp_rank * chunk_b : (tp_rank + 1) * chunk_b]
    a_local = ba_full[..., offset + tp_rank * chunk_a : offset + (tp_rank + 1) * chunk_a]
    # Re-pack as [b_local | a_local] so downstream chunk(2) still works.
    return torch.cat([b_local, a_local], dim=-1)

def _forward_in_proj(self, hidden_states):
    """Override parent to slice replicated ba output."""
    from vllm.utils.multi_stream_utils import maybe_execute_in_parallel
    projected_qkvz, projected_ba_full = maybe_execute_in_parallel(
        lambda: self.in_proj_qkvz(hidden_states)[0],
        lambda: self.in_proj_ba(hidden_states)[0],
        self.events[0], self.events[1], self.aux_stream,
    )
    return projected_qkvz, self._slice_ba_to_local(projected_ba_full)

Plus the same slice helper applied in the LoRA forward path for completeness.

Memory cost of replicating in_proj_ba across four ranks: ~264 KB per layer, ~20 MB total per rank across the model. Negligible.
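
The per-layer figure can be reproduced with a back-of-envelope calculation. This sketch assumes a hidden size of 4096 and a quantization group size of 128 (both inferred from the (32, 64) scale shape mentioned earlier, not read from the model config), 4-bit packed weights, and 16-bit scales. Zero-points and other small buffers are ignored, and int4_linear_bytes is a hypothetical helper.

```python
def int4_linear_bytes(in_features: int, out_features: int, group_size: int) -> int:
    """Approximate storage for an INT4 linear layer: packed weights plus 16-bit scales."""
    weight_bytes = in_features * out_features // 2                  # two 4-bit values per byte
    scale_bytes = (in_features // group_size) * out_features * 2    # one 2-byte scale per group and column
    return weight_bytes + scale_bytes


# Assumed dimensions, not taken from the model config.
HIDDEN_SIZE = 4096
GROUP_SIZE = 128
BA_OUTPUT = 128  # b and a together, unsharded

per_layer = int4_linear_bytes(HIDDEN_SIZE, BA_OUTPUT, GROUP_SIZE)
print(f"in_proj_ba per layer: {per_layer} bytes = {per_layer / 1024:.0f} KiB")
```

Under those assumptions the full replicated layer comes to 264 KiB, in line with the per-layer figure above. Multiplying by the number of gated delta net layers gives the per-rank total.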

Results

Config                                    | Throughput                | Notes
TP=2 Marlin                               | ~27 tok/s                 | only fits because of UMA
TP=4 Conch                                | 26 tok/s                  | working, the baseline
TP=4 Marlin + my ReplicatedLinear wrapper | 0 tok/s of actual content | model emits coherent-looking gibberish
TP=4 Marlin + sonusflow's old patch       | 37 tok/s                  | reference target, no longer applies
TP=4 Marlin + disable_tp=True             | 41.5 tok/s                | +60% over Conch, beats reference
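
For anyone who wants to check the percentages, the relative gains follow directly from the throughput numbers above. A tiny sketch; speedup_pct is just an illustrative helper.

```python
def speedup_pct(new_tok_s: float, old_tok_s: float) -> float:
    """Relative throughput gain in percent."""
    return (new_tok_s / old_tok_s - 1.0) * 100.0


print(f"vs Conch baseline:  {speedup_pct(41.5, 26.0):+.1f}%")
print(f"vs reference patch: {speedup_pct(41.5, 37.0):+.1f}%")
```

That works out to roughly +60% over Conch and about +12% over the 37 tok/s reference.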

Coherence verified: capital of France returns "Paris", 2+2 returns 4 (after 760 chars of thinking-mode reasoning that charmingly included a George Orwell 1984 detour about forced falsehood), longer-form technical explanations are clean.

Reproducing this

If you're on the same hardware (or close to it) and want to skip the twelve hours, here's everything you need.

Recipe

recipes/4x-spark-cluster/qwen3.5-397b-int4-autoround-local_4.yaml:

recipe_version: "1"
name: Qwen3.5-397B-INT4-Autoround
description: Qwen3.5-397B-INT4-Autoround on 4x DGX Spark, TP=4 Marlin via disable_tp=True
model: Intel/Qwen3.5-397B-A17B-int4-AutoRound
cluster_only: true
container: vllm-node-tf5
mods:
  - mods/fix-qwen3.5-chat-template
  - mods/fix-qwen3.5-autoround
  - mods/drop-caches
  - mods/fix-qwen35-tp4-marlin
defaults:
  port: 8000
  host: 0.0.0.0
  tensor_parallel: 4
  gpu_memory_utilization_gb: 85
  max_model_len: 262144
  max_num_batched_tokens: 32768
env:
  PYTORCH_CUDA_ALLOC_CONF: "expandable_segments:True"
  VLLM_MARLIN_USE_ATOMIC_ADD: 1
  NCCL_IB_GID_INDEX: "3"
  NCCL_DEBUG: "INFO"
  NCCL_IB_TIMEOUT: "30"
  NCCL_IB_RETRY_CNT: "7"
command: |
  vllm serve Intel/Qwen3.5-397B-A17B-int4-AutoRound \
    --max-model-len {max_model_len} \
    --max-num-seqs 8 \
    --kv-cache-dtype fp8 \
    --gpu-memory-utilization-gb {gpu_memory_utilization_gb} \
    --port {port} \
    --host {host} \
    --enable-prefix-caching \
    --enable-auto-tool-choice \
    --tool-call-parser qwen3_coder \
    --reasoning-parser qwen3 \
    --chat-template unsloth.jinja \
    --load-format instanttensor \
    --max-num-batched-tokens {max_num_batched_tokens} \
    --trust-remote-code \
    -tp {tensor_parallel} \
    --quantization gptq \
    --distributed-executor-backend ray

A few of the values matter and are worth flagging:

The mod

mods/fix-qwen35-tp4-marlin/run.sh:

#!/bin/bash
set -e
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
echo "[fix-qwen35-tp4-marlin] Applying disable_tp=True patch to qwen3_5.py..."
patch -p0 -d /usr/local/lib/python3.12/dist-packages/vllm < "${SCRIPT_DIR}/qwen3_5.patch" \
  || echo "[fix-qwen35-tp4-marlin] Patch not applicable, skipping..."
echo "[fix-qwen35-tp4-marlin] Done."

mods/fix-qwen35-tp4-marlin/qwen3_5.patch:

--- model_executor/models/qwen3_5.py
+++ model_executor/models/qwen3_5.py
@@ -191,14 +191,45 @@
         # Qwen3.5 has separate in_proj_b and in_proj_a weights in the
         # checkpoint, which are loaded into the fused in_proj_ba parameter
         # via stacked_params_mapping with shard_id 0 and 1 respectively.
+        # disable_tp=True: replicate across TP ranks (output=128 satisfies
+        # Marlin min_thread_n=64; sharded TP=4 would give 32 which fails).
         return MergedColumnParallelLinear(
             input_size=hidden_size,
             output_sizes=[num_v_heads] * 2,
             bias=False,
             quant_config=quant_config,
             prefix=prefix,
+            disable_tp=True,
         )

+    def _slice_ba_to_local(self, ba_full):
+        """Slice replicated ba output [batch, 128] to local TP chunk [batch, 32]."""
+        from vllm.distributed import (
+            get_tensor_model_parallel_rank,
+            get_tensor_model_parallel_world_size,
+        )
+        tp_rank = get_tensor_model_parallel_rank()
+        tp_size = get_tensor_model_parallel_world_size()
+        output_sizes = self.in_proj_ba.output_sizes
+        offset = output_sizes[0]
+        chunk_b = output_sizes[0] // tp_size
+        chunk_a = output_sizes[1] // tp_size
+        b_local = ba_full[..., tp_rank * chunk_b:(tp_rank + 1) * chunk_b]
+        a_local = ba_full[..., offset + tp_rank * chunk_a:offset + (tp_rank + 1) * chunk_a]
+        return torch.cat([b_local, a_local], dim=-1)
+
+    def _forward_in_proj(self, hidden_states):
+        """Override parent to slice replicated ba output."""
+        from vllm.utils.multi_stream_utils import maybe_execute_in_parallel
+        projected_states_qkvz, projected_states_ba_full = maybe_execute_in_parallel(
+            lambda: self.in_proj_qkvz(hidden_states)[0],
+            lambda: self.in_proj_ba(hidden_states)[0],
+            self.events[0],
+            self.events[1],
+            self.aux_stream,
+        )
+        return projected_states_qkvz, self._slice_ba_to_local(projected_states_ba_full)
+
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -200,7 +231,8 @@
         if hasattr(self, "in_proj_qkv"):
             # LoRA path: separate in_proj_qkv and in_proj_z
             mixed_qkv, _ = self.in_proj_qkv(hidden_states)
-            ba, _ = self.in_proj_ba(hidden_states)
+            ba_full, _ = self.in_proj_ba(hidden_states)
+            ba = self._slice_ba_to_local(ba_full)
             z, _ = self.in_proj_z(hidden_states)
         else:
             mixed_qkvz, ba = torch.ops.vllm.gdn_in_proj(

Bring it up

# Verify the patch will apply cleanly
docker run --rm -v ~/spark-vllm-docker/mods/fix-qwen35-tp4-marlin:/mod vllm-node-tf5 \
  bash -c "patch -p0 --dry-run -d /usr/local/lib/python3.12/dist-packages/vllm < /mod/qwen3_5.patch"

# Tear down any previous instances
for ip in 192.168.177.11 192.168.177.12 192.168.177.13 192.168.177.14; do
  ssh $USER@$ip 'docker rm -f vllm_node 2>/dev/null' &
done; wait

# Launch
cd ~/spark-vllm-docker
./run-recipe.sh recipes/4x-spark-cluster/qwen3.5-397b-int4-autoround-local_4.yaml
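
Because run.sh prints a message and carries on when the patch does not apply, it may be worth confirming that the installed file actually contains the change. A small sketch, meant to be run inside the container after the mod has run: patch_applied is a hypothetical helper, the markers are strings the patch above adds, and the path is the one used in run.sh.

```python
from pathlib import Path

# Path used by run.sh above; adjust if your image installs vLLM elsewhere.
QWEN35_PATH = Path(
    "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/models/qwen3_5.py"
)
# Strings added by the patch.
MARKERS = ("disable_tp=True", "def _slice_ba_to_local")


def patch_applied(source: str) -> bool:
    """True when every marker from the patch is present in the source text."""
    return all(marker in source for marker in MARKERS)


if QWEN35_PATH.exists():
    print("patched" if patch_applied(QWEN35_PATH.read_text()) else "NOT patched")
else:
    print(f"{QWEN35_PATH} not found; run this inside the container")
```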

Watch for Using MarlinLinearKernel for GPTQMarlinLinearMethod in the launch logs to confirm the kernel selection, then sanity-test:

curl -s http://<head-node-ip>:8000/v1/completions \
  -H "Content-Type: application/json" \
  -d '{"model":"Intel/Qwen3.5-397B-A17B-int4-AutoRound","prompt":"The capital of France is","max_tokens":10}' \
  | jq -r '.choices[0].text'

If you get " Paris" you're done. If you get an empty echo of the prompt, the patch didn't apply or the recipe is still pulling Conch.
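
The same check can be scripted for repeated use. This is a sketch using only the Python standard library against the /v1/completions endpoint from the curl command above. The names build_payload, first_completion, and query are illustrative, and the host is a placeholder you need to fill in.

```python
import json
import urllib.request

MODEL = "Intel/Qwen3.5-397B-A17B-int4-AutoRound"


def build_payload(prompt: str, max_tokens: int = 10) -> bytes:
    """JSON body matching the curl example above."""
    return json.dumps(
        {"model": MODEL, "prompt": prompt, "max_tokens": max_tokens}
    ).encode()


def first_completion(response_body: bytes) -> str:
    """Pull the generated text out of a completions response."""
    return json.loads(response_body)["choices"][0]["text"]


def query(base_url: str, prompt: str) -> str:
    """POST the prompt to <base_url>/v1/completions and return the completion text."""
    request = urllib.request.Request(
        f"{base_url}/v1/completions",
        data=build_payload(prompt),
        headers={"Content-Type": "application/json"},
    )
    with urllib.request.urlopen(request, timeout=60) as response:
        return first_completion(response.read())


# Example (needs a running server):
# print(repr(query("http://<head-node-ip>:8000", "The capital of France is")))
```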

Takeaways

Two things worth writing down.

First, the meta-lesson. Read the framework before patching the framework. I assumed the tool I needed didn't exist and started building it. Twelve hours and one tagged docker image later, the tool was a keyword argument with a one-line docstring. The pattern recurs in agentic coding too — when something feels like it should be supported but isn't surfacing, search the source before you write the workaround. grep -rn "disable_tp\|replicated" vllm/model_executor/layers/ would have found this in two minutes.

Second, the specific lesson for vLLM users. MergedColumnParallelLinear(..., disable_tp=True) is the correct pattern any time a layer's per-rank output dim falls below Marlin's tile minimum at high TP. The check that triggers this — output_size_per_partition % 64 == 0 — bites Qwen3-Next family models specifically because of their small per-head gating projections, but it's a general phenomenon. If you ever see Marlin select cleanly at high TP and then produce zeros, this is the first place to look.
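
In the spirit of the first lesson, one quick way to ask a class what it already supports is to inspect its constructor. A sketch with a stand-in class: accepts_kwarg is a hypothetical helper, StandInLinear exists only for the demonstration, and the commented-out import assumes vLLM is installed at the module path quoted earlier.

```python
import inspect


def accepts_kwarg(cls: type, name: str) -> bool:
    """True when cls.__init__ declares a parameter called `name`."""
    return name in inspect.signature(cls.__init__).parameters


class StandInLinear:
    """Stand-in with the same keyword as the vLLM layer, for demonstration only."""

    def __init__(self, input_size: int, output_sizes: list[int], *, disable_tp: bool = False):
        self.input_size = input_size
        self.output_sizes = output_sizes
        self.disable_tp = disable_tp


print(accepts_kwarg(StandInLinear, "disable_tp"))

# Against a real install (assumes vLLM is importable):
# from vllm.model_executor.layers.linear import MergedColumnParallelLinear
# print(accepts_kwarg(MergedColumnParallelLinear, "disable_tp"))
```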

The patch is published as a spark-vllm-docker mod at mods/fix-qwen35-tp4-marlin/. Recipe diff is one line — drop mods/install-conch, add mods/fix-qwen35-tp4-marlin. Tested on vLLM 0.18.1rc1 build dated 2026-03-24, on a 4-node ASUS Ascent GX10 cluster, against Intel/Qwen3.5-397B-A17B-int4-AutoRound.

If you're running the same stack and seeing the same zero-output behavior, this should land you on Marlin in about ten minutes.