Reproducing DeepTFUS

An open reproduction of a 3D AI model that predicts how focused ultrasound travels through the skull.

Predicted (right) vs ground-truth (left) ultrasound focus inside the head. Skull gray, transducer red, focal cloud blue.

1. Reproducing paper results

The paper does two things: it releases TFUScapes, a 3D dataset of simulated tFUS pressure fields through MRI-derived skulls, and it proposes DeepTFUS, a 3D U-Net that predicts those fields from a CT volume and the transducer position. At the time of writing, however, neither weights nor training code has been released, and §3.2 leaves several implementation details unspecified[1], so this reproduction has to fill in some gaps.

Three metrics we care about

The paper reports three numbers on a 597-sample held-out test set. Lower is better for all three.

relative_l2

Is the field shape right?

Total error between the predicted and ground-truth 3D pressure fields, normalized by ground-truth magnitude.

41.4% ± 8.6 paper mean ± std

focal_position_error_mm

Where does the focus land?

Distance in mm between where the model places the peak pressure and where the ground-truth simulator places it.

2.89 ± 2.14 mm paper mean ± std

max_pressure_error

How strong is the focus?

Relative error in peak pressure intensity.

19.9% ± 15.8 paper mean ± std
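A minimal NumPy sketch of the three metrics as described above. It assumes both fields are 3D arrays on the same grid with isotropic voxel spacing `voxel_mm` (a hypothetical parameter; the dataset's spacing isn't stated here) and that max_pressure_error means |max p̂ − max p| / max p. It is an illustration, not the paper's or this project's eval.py.

```python
import numpy as np


def relative_l2(pred: np.ndarray, gt: np.ndarray) -> float:
    """Whole-volume error ||pred - gt||_2, normalized by ||gt||_2."""
    return float(np.linalg.norm(pred - gt) / np.linalg.norm(gt))


def focal_position_error_mm(pred: np.ndarray, gt: np.ndarray, voxel_mm: float = 1.0) -> float:
    """Euclidean distance in mm between the peak-pressure voxels of pred and gt."""
    p_peak = np.array(np.unravel_index(np.argmax(pred), pred.shape), dtype=float)
    g_peak = np.array(np.unravel_index(np.argmax(gt), gt.shape), dtype=float)
    return float(np.linalg.norm(p_peak - g_peak) * voxel_mm)


def max_pressure_error(pred: np.ndarray, gt: np.ndarray) -> float:
    """Relative error of the peak pressure (assumed definition)."""
    return float(abs(pred.max() - gt.max()) / gt.max())


def gaussian_blob(shape, center, sigma, amp=1.0):
    """Toy focal spot for trying the metrics out."""
    grid = np.indices(shape).astype(float)
    d2 = sum((g - c) ** 2 for g, c in zip(grid, center))
    return amp * np.exp(-d2 / (2 * sigma**2))


gt = gaussian_blob((32, 32, 32), (16, 16, 16), sigma=2.0)
pred = gaussian_blob((32, 32, 32), (16, 19, 20), sigma=2.0, amp=0.8)  # shifted, weaker focus
print(relative_l2(pred, gt), focal_position_error_mm(pred, gt), max_pressure_error(pred, gt))
```

Here the predicted peak sits 3 and 4 voxels off along two axes, so the focal error is 5 voxels (5 mm at 1 mm spacing), and the 20% weaker peak gives a max_pressure_error of 0.2.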

My base reproduction vs the paper

I trained for 50 epochs from scratch using exactly the paper’s composite loss, with two small architecture deviations to fit the memory budget the paper specifies (“batch size of 4 on a single A100”)[2]. Training took about 11 hours on a single H100 80GB (1686 train / 200 val / 597 test, batch size 4, pure bf16, peak ~69 GiB of GPU memory).

| variant | relative_l2 | focal_position_error_mm | max_pressure_error |
|---|---|---|---|
| DeepTFUS (paper) | 41.4% ± 8.6 | **2.89 ± 2.14 mm** | 19.9% ± 15.8 |
| DeepTFUS-tiny (paper) | 41.0% | 2.95 mm | **19.6%** |
| base (mine) | **38.4% ± 7.8** | 6.49 ± 4.58 mm | 22.5% ± 11.6 |
Reproduction vs the paper on the 597-sample test set, mean ± per-sample std (lower better; best per column in bold). DeepTFUS-tiny is the paper’s ablation row, not a published smaller model.

relative_l2 slightly improves on the paper (38.4% vs 41.4%) and max_pressure_error is within the paper’s spread (22.5% vs 19.9% ± 15.8), but focal_position_error_mm did not reproduce: my model’s focus lands 6.49 mm off on average, versus the paper’s 2.89 mm. (Alongside these three, I also evaluated the three TUSNet metrics that apply to DeepTFUS: focal pressure error, focal IoU at FWHM, and inference time. The full breakdown is in the appendix.)

Visually:

Qualitative results from the test set

Columns: ground truth, model prediction.

5th percentile placement (1.2 mm off)

25th percentile placement (3.4 mm off)

75th percentile placement (8.7 mm off)

95th percentile placement (14.7 mm off)

Four held-out test cases, from best focal placement (top) to worst (bottom): the 5th, 25th, 75th, and 95th percentiles of focal position error on the test set. The number in parentheses is how far (in mm) the model’s predicted focal spot lands from the ground-truth simulator’s.

I suspect the focal-position gap is a capacity issue, i.e., my 3.4M-parameter model isn’t large enough. The base model plateaued early (weighted MSE flatlined around epoch 30 and stayed there for the final 20 epochs) without overfitting (the train↔val gap stayed flat), which points to underfitting. The next training run (if there is one) should probably double base_width from 16 to 32.
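To see why doubling base_width is a big step, here is a sketch that counts the parameters of a hypothetical plain 3D U-Net (two 3×3×3 convs per level, channels doubling per level, transposed-conv upsampling, concatenated skips). It is not the reproduction’s architecture: it omits the cross-attention and DynamicConv layers, and at base_width=16 it counts only ~1.4M parameters versus the reproduction’s ~3.4M. Its deepest level does have 16·2³ = 128 channels, matching the 128-d bottleneck discussed in section 3. Because each conv weight scales with c_in·c_out, doubling the width roughly quadruples the count.

```python
def conv3d_params(c_in: int, c_out: int, k: int = 3) -> int:
    """Weights plus biases of a dense k x k x k 3D convolution (or transposed convolution)."""
    return k**3 * c_in * c_out + c_out


def plain_unet3d_params(base_width: int, depth: int = 4, in_ch: int = 1, out_ch: int = 1) -> int:
    """Parameter count of a hypothetical plain 3D U-Net.

    Two 3x3x3 convs per level, channels doubling per level, 2x2x2 transposed-conv
    upsampling, concatenated skips, and a 1x1x1 output head. No attention, FiLM,
    or DynamicConv layers are included.
    """
    widths = [base_width * 2**i for i in range(depth)]  # e.g. 16, 32, 64, 128
    total, c_prev = 0, in_ch
    for w in widths:  # encoder levels; the last one is the bottleneck
        total += conv3d_params(c_prev, w) + conv3d_params(w, w)
        c_prev = w
    for w_deep, w in zip(widths[:0:-1], widths[-2::-1]):  # decoder, deepest level first
        total += conv3d_params(w_deep, w, k=2)  # upsample w_deep -> w channels
        total += conv3d_params(2 * w, w) + conv3d_params(w, w)  # after concatenating the skip
    total += conv3d_params(widths[0], out_ch, k=1)  # output head
    return total


for bw in (16, 32):
    print(bw, plain_unet3d_params(bw))
print(plain_unet3d_params(32) / plain_unet3d_params(16))  # just under 4
```

If the real model’s extra layers scale the same way (an assumption), base_width=32 would land somewhere around 4 × 3.4M ≈ 13.6M parameters.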

Validation metrics over base training

Validation metrics across the 50-epoch base reproduction (n = 200 validation samples). Dashed line is the paper’s published test result on the same metric.

2. Closing the focal-position gap

Because this was a quick learning sprint, it seemed more reasonable to continue training the base checkpoint for another 5 to 10 epochs (especially since it had not yet overfit) with a modified loss, and to see how each variant moved the metrics. Each fine-tune below continues from the base checkpoint with a change to the loss function.

Can we add an explicit position loss without degrading the other metrics?

The paper’s loss

Training minimizes a composite loss with two terms:

$$\mathcal{L} = \mathcal{L}_{\text{weighted}} + \lambda\, \mathcal{L}_{\text{grad}}, \qquad \lambda = 0.1$$

$\mathcal{L}_{\text{weighted}}$ is a spatially weighted MSE that puts most of its mass on the voxels near the focal spot, since that’s the only region that physically matters[3]. In practice the weight runs from ~0.3 in background voxels to ~10 at the focal peak; without it, the model could reach a low whole-volume MSE by predicting near-zero pressure everywhere (the focal spot is <1% of the voxels).

$\mathcal{L}_{\text{grad}}$ is a gradient-consistency term that asks the predicted field’s spatial derivatives to match the ground truth’s along each axis, encouraging clean focal-zone boundaries instead of blurry blobs.
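A minimal NumPy sketch of the composite loss, under stated assumptions: the focal weights follow note [3] (α = 5, computed from the ground truth and normalized to mean 1); both fields are in an assumed per-sample log normalization log(1 + p / max p), which puts the peak at log 2 as note [3] describes; and $\mathcal{L}_{\text{grad}}$ is taken to be the MSE between forward differences along each axis, since its exact norm isn’t spelled out here. A real training loop would use an autodiff framework rather than NumPy.

```python
import numpy as np

ALPHA = 5.0        # exponent constant from note [3]
LAMBDA_GRAD = 0.1  # weight on the gradient term, as in the paper's loss


def log_normalize(p: np.ndarray) -> np.ndarray:
    """Assumed per-sample normalization log(1 + p / max p); the peak maps to log 2."""
    return np.log1p(p / p.max())


def focal_weights(gt_log: np.ndarray, alpha: float = ALPHA) -> np.ndarray:
    """w(v) = exp(alpha * (P(v) - max P)), rescaled so the mean weight is 1 (note [3])."""
    w = np.exp(alpha * (gt_log - gt_log.max()))
    return w / w.mean()


def weighted_mse(pred_log: np.ndarray, gt_log: np.ndarray) -> float:
    return float(np.mean(focal_weights(gt_log) * (pred_log - gt_log) ** 2))


def grad_loss(pred_log: np.ndarray, gt_log: np.ndarray) -> float:
    """Assumed form: MSE between forward differences of pred and gt, averaged over axes."""
    per_axis = [
        np.mean((np.diff(pred_log, axis=a) - np.diff(gt_log, axis=a)) ** 2)
        for a in range(gt_log.ndim)
    ]
    return float(np.mean(per_axis))


def composite_loss(pred_log: np.ndarray, gt_log: np.ndarray, lam: float = LAMBDA_GRAD) -> float:
    return weighted_mse(pred_log, gt_log) + lam * grad_loss(pred_log, gt_log)


gt = np.zeros((8, 8, 8))
gt[4, 4, 4] = 1.0  # a single hot voxel
gt_log = log_normalize(gt)
w = focal_weights(gt_log)
print(w[4, 4, 4] / w[0, 0, 0])  # peak-to-background weight ratio exp(5 * log 2), i.e. about 32
print(composite_loss(np.zeros_like(gt_log), gt_log))  # the all-zero prediction still pays a price
```

The peak-to-background ratio of about 32 is roughly consistent with the ~0.3 to ~10 weight range quoted above.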

Every fine-tune below is a one- or two-line change to this recipe.

The five variants

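Each starts from the base checkpoint and changes only the loss (plus, for D, the learning rate). Below, rel_l2, focal_mm, and max_p are shorthand for relative_l2, focal_position_error_mm, and max_pressure_error.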

The dominant idea is a soft-argmax L1 focal-position term that supervises the location of the predicted hot spot directly:

$$\mathcal{L}_{\text{focal}} = \bigl\lVert \text{soft-argmax}_\tau(\hat P_{\text{norm}}) - \arg\max(P_{\text{norm}}) \bigr\rVert_1$$

soft-argmax is a temperature-controlled, differentiable expectation of voxel coordinates under a softmax of the predicted pressure; smaller $\tau$ gives sharper peaks, closer to the true argmax[4]. The fine-tunes vary $\lambda_{\text{focal}}$, $\tau$, and whether the paper’s gradient term stays on as an anchor.
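A NumPy sketch of $\mathcal{L}_{\text{focal}}$ following the soft-argmax definition in note [4]: a softmax over the flattened pressure at temperature τ weights every voxel’s (i, j, k) coordinate. Distances here are in voxels, and the target is the hard argmax of the ground truth, which needs no gradient. The function names are hypothetical.

```python
import numpy as np


def soft_argmax(p_norm: np.ndarray, tau: float) -> np.ndarray:
    """Expected voxel coordinate under softmax(p_norm / tau); p_norm is pressure in [0, 1]."""
    logits = p_norm.ravel() / tau
    probs = np.exp(logits - logits.max())  # subtracting the max only helps numerical stability
    probs /= probs.sum()
    coords = np.indices(p_norm.shape).reshape(p_norm.ndim, -1)  # (3, n_voxels), same order as ravel
    return coords @ probs


def focal_loss(pred_norm: np.ndarray, gt_norm: np.ndarray, tau: float) -> float:
    """L1 distance, in voxels, from the soft-argmax of the prediction to the hard argmax of gt."""
    target = np.array(np.unravel_index(np.argmax(gt_norm), gt_norm.shape), dtype=float)
    return float(np.abs(soft_argmax(pred_norm, tau) - target).sum())


hot = np.zeros((8, 8, 8))
hot[2, 3, 4] = 1.0
print(soft_argmax(hot, tau=0.02))  # essentially (2, 3, 4)
print(soft_argmax(hot, tau=1.0))   # pulled toward the grid center (3.5, 3.5, 3.5)
```

With a single hot voxel, τ = 0.02 recovers its coordinate almost exactly, while τ = 1 drifts to the center of the 8³ grid: the centroid bias note [4] mentions.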

Variant A: soft-argmax L1 (mild)

$$\mathcal{L}_A = \mathcal{L}_{\text{weighted}} + 0.1\,\mathcal{L}_{\text{grad}} + \underbrace{10^{-5}\,\mathcal{L}_{\text{focal}}}_{\text{new, } \tau = 0.05}$$

Conservative test: does a soft-argmax term move focal_mm at all without breaking rel_l2?

Finding. Modest position improvement (focal_mm 6.49 → 5.60 mm mean, ~14% better than base), and the only soft-argmax-only variant (A–D) where max_p actually improves over the base run (22.5% → 20.4%). Not bad!

Variant B: soft-argmax L1 (cranked, anchor dropped)

$$\mathcal{L}_B = \mathcal{L}_{\text{weighted}} + \underbrace{0}_{\text{anchor off}}\cdot\mathcal{L}_{\text{grad}} + \underbrace{5 \times 10^{-5}}_{5\times \text{A}}\,\mathcal{L}_{\text{focal}}, \quad \tau = 0.03$$

5× the focal-term weight, a sharper $\tau$, and the gradient anchor dropped (testing the hypothesis that $\mathcal{L}_{\text{grad}}$ was diluting the focal signal).

Finding. Stronger position win (focal_mm 5.06 mm mean), but max_p degrades to 24.0% and stray off-target hot spots multiply (probably because dropping $\mathcal{L}_{\text{grad}}$ removed the term that pulls the predicted field’s spatial derivatives toward the ground truth’s).

Variant C: soft-argmax L1 (cranked, anchor restored)

$$\mathcal{L}_C = \mathcal{L}_{\text{weighted}} + \underbrace{0.1\,\mathcal{L}_{\text{grad}}}_{\text{anchor restored}} + 5 \times 10^{-5}\,\mathcal{L}_{\text{focal}}$$

Identical to B, except that the $\mathcal{L}_{\text{grad}}$ anchor is added back.

Finding. Effectively tied with B on the three paper-canonical metrics (focal_mm 5.11 vs 5.06 mm; rel_l2 and max_p essentially identical), but slightly better on one of the additional TUSNet metrics (median focal_pressure_error 0.496 vs 0.502; focal_iou_fwhm is tied at 0.136).

Variant D: soft-argmax L1 (extreme, anchor restored)

$$\mathcal{L}_D = \mathcal{L}_{\text{weighted}} + 0.1\,\mathcal{L}_{\text{grad}} + \underbrace{1.5 \times 10^{-4}}_{3\times \text{C},\,15\times \text{A}}\,\mathcal{L}_{\text{focal}}, \quad \tau = 0.02$$

The “position over everything” point on the trade-off curve: the sharpest $\tau$, the largest $\lambda_{\text{focal}}$, and 3× the learning rate to push past C’s plateau.

Finding. Best focal_mm of any variant (4.19 ± 2.93 mm mean, ~35% better than base), but predictably the other metrics degrade: rel_l2 climbs to 42.2% (the only variant worse than the paper’s 41.4%) and max_p climbs to 28.3%. (Somewhat of a trivial result?)

Variant E: Dice on the −6 dB iso-volume

$$\mathcal{L}_E = \mathcal{L}_C + \underbrace{3 \times 10^{-3}\,\mathcal{L}_{\text{Dice}}}_{\text{new}}$$

Variants A–D all landed the peak in roughly the right place but predicted a focal blob too wide around it: soft-argmax only constrains where the centroid sits, so a tight concentrated peak and a fat diffuse blob with the same centroid satisfy it equally well, and the optimizer prefers the fat one (smoother predictions are easier to fit elsewhere).

$\mathcal{L}_{\text{Dice}}$ goes after the shape directly. It compares two 3D blobs: the focal region of the predicted field thresholded at half its peak pressure, and the corresponding region of the ground-truth field. The Dice coefficient measures how well those two blobs overlap; a soft, differentiable version is used here so it can drive gradients. Effectively, the new term tells the model which voxels should and shouldn’t be in the focal region, not just where its center should sit.
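A NumPy sketch of a soft Dice loss on the −6 dB region (half the peak amplitude is −6 dB because 20·log10(0.5) ≈ −6.02). Each field is normalized by its own peak, and the hard threshold p ≥ 0.5·max p is replaced with a sigmoid so gradients can flow; the sigmoid width `SHARPNESS` and the toy `blob` helper are hypothetical choices, not values from this post.

```python
import numpy as np

HALF_MAX = 0.5    # -6 dB in amplitude: 20 * log10(0.5) is about -6.02 dB
SHARPNESS = 0.05  # hypothetical width of the soft threshold, in units of normalized pressure


def soft_mask(p: np.ndarray, sharpness: float = SHARPNESS) -> np.ndarray:
    """Differentiable stand-in for (p >= 0.5 * max p): a sigmoid around the half-max level."""
    x = (p / p.max() - HALF_MAX) / sharpness
    return 1.0 / (1.0 + np.exp(-x))


def soft_dice_loss(pred: np.ndarray, gt: np.ndarray, eps: float = 1e-6) -> float:
    """1 - soft Dice overlap between the predicted and ground-truth half-max regions."""
    a, b = soft_mask(pred), soft_mask(gt)
    dice = (2.0 * np.sum(a * b) + eps) / (np.sum(a) + np.sum(b) + eps)
    return float(1.0 - dice)


def blob(sigma: float, n: int = 24) -> np.ndarray:
    """Toy isotropic focal spot centered in an n^3 grid."""
    grid = np.indices((n, n, n)).astype(float)
    d2 = sum((g - n / 2) ** 2 for g in grid)
    return np.exp(-d2 / (2 * sigma**2))


gt = blob(2.0)
print(soft_dice_loss(blob(2.5), gt), soft_dice_loss(blob(4.0), gt))  # the wider blob scores worse
```

The example mirrors the failure mode described above: blobs with the same center but a wider prediction get a clearly worse Dice loss, which a centroid-only term like soft-argmax cannot see.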

Finding. Brings max_p to 12.9% (well below the paper’s 19.9%), and is the only variant that tightens the predicted focal region rather than letting it spread. Trades back ~0.2 mm of focal_mm vs C, probably because the Dice and soft-argmax terms compete for the same gradient mass.

Validation curves over fine-tuning

Before the final test-set numbers, here’s how each variant evolved on the 200-sample validation set during the fine-tune epochs:

Validation metrics over fine-tune training

Validation metrics across the 5 fine-tunes, epochs 0 to 8 (n = 200 validation samples). Lower is better on every panel.

A plateaus quickly on focal_mm because its position term is gentle; D drives focal_mm down hardest, but its rel_l2 climbs past the paper’s 41.4%; E’s Dice term drives max_p down almost immediately, and it stays there.

Final results

| variant | relative_l2 | focal_position_error_mm | max_pressure_error |
|---|---|---|---|
| DeepTFUS (paper) | 41.4% ± 8.6 | **2.89 ± 2.14 mm** | 19.9% ± 15.8 |
| base (mine) | **38.4% ± 7.8** | 6.49 ± 4.58 mm | 22.5% ± 11.6 |
| variant A | 38.9% | 5.60 mm | 20.4% |
| variant B | 38.8% ± 7.7 | 5.06 ± 3.57 mm | 24.0% ± 10.6 |
| variant C | 38.8% ± 7.7 | 5.11 ± 3.76 mm | 23.9% ± 10.6 |
| variant D | 42.2% ± 8.2 | 4.19 ± 2.93 mm | 28.3% ± 10.5 |
| variant E | 40.1% ± 8.1 | 5.32 ± 3.44 mm | **12.9% ± 9.5** |
The 5 fine-tunes vs base & paper on the 597-sample test set, mean ± per-sample std (lower better; best per column in bold).

Three takeaways across the variants:

  • D closes the most of the focal-position gap, at 4.19 mm mean (vs the paper’s 2.89 mm), but only by overshooting the rel_l2 budget. So D isn’t a clean reproduction; it’s the “position at any cost” corner. C is the strongest in-budget variant at 5.11 mm (effectively tied with B’s 5.06 mm, but with the paper’s gradient anchor kept).
  • E (Dice) crushes max_p, pulling it from the base run’s 22.5% all the way down to 12.9% (a wide margin below the paper’s 19.9%). The Dice term explicitly penalizes the predicted half-max region for spreading wider than the ground truth’s, which the soft-argmax variants couldn’t do. In exchange, E gives back ~0.2 mm of focal_mm relative to C.
  • The position metric still misses the paper: the best variant’s mean is 1.45× the paper’s. As discussed in section 1, this is probably a combination of my model being too small and some wrong guesses about the architecture.

All six checkpoints (the base + five fine-tunes) are public on HuggingFace, browsable as the DeepTFUS reproduction collection. For the full per-variant numbers across the three paper-canonical metrics and the three TUSNet metrics that apply to DeepTFUS, see the appendix.

3. Analysis

These are quick experiments I had Claude run to help me understand the model better; I felt they were worth sharing.

What does the bottleneck encode?

The trained model squeezes its full 256³ input (a CT volume of the head plus a transducer point cloud) into a 128-channel feature map on a 32³ grid at its deepest layer (the “bottleneck”). For the most part (skip connections aside), the decoder’s downstream tasks (reconstructing the pressure field, placing the focal spot) depend on this bottleneck. So what’s in there?

Linear probes give the cheapest answer. I trained tiny linear models to predict four candidate factors (subject identity, transducer azimuth, transducer aperture, and predicted focal depth from the skull) from the bottleneck, mean-pooled over space into a 128-d vector.

What is recoverable from the bottleneck

For each input attribute, can a tiny linear model recover it from the 128-d bottleneck? If yes, the bottleneck encodes that attribute; if not, it has thrown that information away. Five-fold cross-validated logistic / Ridge probes on the bottleneck features (mean-pooled over the 32³ spatial grid). Linear probes are a lower bound on encoding (non-linear readouts could do better), so “encoded” here means “encoded enough that a linear model can recover it.”

So the bottleneck is a fairly complete joint summary of at least pose, geometry, focal depth, and anatomy.
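A sketch of the regression probes in plain NumPy: mean-pool the bottleneck over space, fit closed-form ridge regression, and report held-out R² averaged over five folds. The array layout (N, 128, 32, 32, 32), the ridge strength `alpha`, and the synthetic data are assumptions; the subject-ID probe would use a logistic classifier instead, which is omitted here.

```python
import numpy as np


def pool_bottleneck(feats: np.ndarray) -> np.ndarray:
    """(N, C, D, H, W) bottleneck activations -> (N, C) by averaging over space."""
    return feats.mean(axis=(2, 3, 4))


def ridge_fit(X: np.ndarray, y: np.ndarray, alpha: float = 1.0):
    """Closed-form ridge regression with an unpenalized intercept."""
    x_mean, y_mean = X.mean(axis=0), y.mean()
    Xc = X - x_mean
    w = np.linalg.solve(Xc.T @ Xc + alpha * np.eye(X.shape[1]), Xc.T @ (y - y_mean))
    return w, y_mean - x_mean @ w


def cv_r2(X: np.ndarray, y: np.ndarray, n_folds: int = 5, alpha: float = 1.0, seed: int = 0) -> float:
    """Mean held-out R^2 of a linear probe across n_folds folds."""
    idx = np.random.default_rng(seed).permutation(len(y))
    scores = []
    for test in np.array_split(idx, n_folds):
        train = np.setdiff1d(idx, test)
        w, b = ridge_fit(X[train], y[train], alpha)
        resid = y[test] - (X[test] @ w + b)
        scores.append(1.0 - np.sum(resid**2) / np.sum((y[test] - y[test].mean()) ** 2))
    return float(np.mean(scores))


# Synthetic stand-in: 200 pooled 16-d feature vectors; one target is linear in them, one is noise.
rng = np.random.default_rng(1)
X = rng.normal(size=(200, 16))
y_encoded = X @ rng.normal(size=16) + 0.1 * rng.normal(size=200)
y_random = rng.normal(size=200)
print(cv_r2(X, y_encoded), cv_r2(X, y_random))
```

A high held-out R² means the attribute is linearly recoverable; a score near or below zero means the probe does no better than predicting the fold mean.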

Is the bottleneck smooth in transducer pose?

Two reasons this matters:

  1. Generalization: a smooth pose-to-feature mapping is what lets the model interpolate to unseen poses instead of memorizing per-pose patterns.

  2. Inverse problems: an eventual application (given a desired focal spot, where do I put the bowl?) is naturally solved by gradient descent on the placement, and that only works well if the output changes smoothly with placement all the way through the bottleneck.

A direct check: rotate the bowl in small (~1.6°) steps around the head with everything else frozen (same patient, same bowl geometry, same CT), re-run the trained model at each step, and watch how a few of the deepest channels respond.
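The sweep itself is a rigid rotation of the transducer geometry about a point. A sketch using Rodrigues’ rotation formula, assuming the transducer is an (N, 3) point cloud in mm and that the rotation center and axis are picked by hand (both hypothetical here). 35 frames spanning 54° gives steps of 54/34 ≈ 1.59°, matching the ~1.6° above.

```python
import numpy as np


def rotation_matrix(axis, angle_rad: float) -> np.ndarray:
    """Rodrigues' formula: rotation by angle_rad about the given axis (normalized here)."""
    a = np.asarray(axis, dtype=float)
    a = a / np.linalg.norm(a)
    k = np.array([[0.0, -a[2], a[1]],
                  [a[2], 0.0, -a[0]],
                  [-a[1], a[0], 0.0]])  # cross-product matrix of the axis
    return np.eye(3) + np.sin(angle_rad) * k + (1.0 - np.cos(angle_rad)) * (k @ k)


def sweep_placements(points: np.ndarray, center, axis,
                     total_deg: float = 54.0, n_frames: int = 35) -> np.ndarray:
    """Rigidly rotate an (N, 3) transducer point cloud about `center`; returns (n_frames, N, 3)."""
    center = np.asarray(center, dtype=float)
    angles = np.deg2rad(np.linspace(0.0, total_deg, n_frames))
    rel = points - center
    return np.stack([rel @ rotation_matrix(axis, t).T + center for t in angles])


# Hypothetical two-point "transducer" 10-12 mm from the origin, swept about the z axis.
bowl = np.array([[10.0, 0.0, 0.0], [12.0, 1.0, 0.0]])
frames = sweep_placements(bowl, center=[0.0, 0.0, 0.0], axis=[0.0, 0.0, 1.0])
print(frames.shape, 54.0 / (35 - 1))  # 35 frames, ~1.59 degrees apart
```

Each frame would then go through the trained model with the CT held fixed, and the mean-pooled bottleneck channels would be plotted against the sweep angle.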

Sliding transducer sweep

Predicted pressure field with rotated transducer (red)

8 of 128 bottleneck channels (mean-pooled over space) vs rotation angle

35 placements built by rotating one transducer ~54° around one held-out subject’s head in even ~1.6° steps. CT and bowl geometry are frozen across all frames; only the rotation changes.

The channel curves are reasonably smooth across the sweep, with no folds or abrupt jumps, which is good! This is weak qualitative evidence that gradient-based pose optimization should at least be feasible. Again, this is more of a sanity check that comes with what I hope is a nice view into the model :)

Update (5.14.26)

The authors got back to me with the §3.2 details I’d flagged in note 1, and confirmed I may share them here.

DeepTFUS is ~38M parameters; DeepTFUS-tiny is ~10M. My reproduction at base_width=16 is ~3.4M, about 1/3 the size of DeepTFUS-tiny and ~1/11 the size of the full model. I have not trained a larger model yet, but this supports my suspicion that the reproduction is capacity-bound (i.e., the model is too small).

Acknowledgments

Deep thanks to the DeepTFUS authors (Vinkle Srivastav, Juliette Puel, Jonathan Vappou, Elijah Van Houten, Paolo Cabras, and Nicolas Padoy) for releasing the TFUScapes dataset and for getting back to me with the architecture details. Full credit for the dataset and the architecture goes to them; this post is a reproduction attempt.

Appendix

Full per-variant test metrics

The DeepTFUS paper reports three test-set metrics (the ones tabulated in sections 1 and 2). Throughout the project I also tracked the three metrics from the TUSNet evaluation suite (Naftchi-Ardebili et al., 2024, the earlier deep-learning model for this same transcranial-ultrasound prediction task) that also apply to DeepTFUS: focal_pressure_error (how closely the predicted pressure at the ground-truth focal location matches the simulator’s), focal_iou_fwhm (how well the predicted half-max region overlaps the ground truth’s), and inference_latency_s. TUSNet’s remaining metrics either don’t apply to DeepTFUS (phase aberration correction) or restate the three paper-canonical ones (focal positioning, peak pressure).

These matter because clinical deployment of focused ultrasound depends on field properties beyond the peak’s location and amplitude. The fine-tuning section discusses these effects qualitatively; the underlying numbers are collected here rather than in the main text.
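For concreteness, a NumPy sketch of the two field-shape metrics as described above. The exact TUSNet definitions may differ: here focal_pressure_error is assumed to be the relative error at the ground-truth peak voxel, and each field’s FWHM region is assumed to be thresholded at half of its own peak.

```python
import numpy as np


def focal_pressure_error(pred: np.ndarray, gt: np.ndarray) -> float:
    """Relative error of the predicted pressure at the ground-truth peak voxel (assumed form)."""
    focus = np.unravel_index(np.argmax(gt), gt.shape)
    return float(abs(pred[focus] - gt[focus]) / gt[focus])


def focal_iou_fwhm(pred: np.ndarray, gt: np.ndarray) -> float:
    """IoU of the half-max regions, each field thresholded at half of its own peak."""
    a = pred >= 0.5 * pred.max()
    b = gt >= 0.5 * gt.max()
    return float(np.logical_and(a, b).sum() / np.logical_or(a, b).sum())


gt = np.zeros((6, 6, 6))
gt[2:4, 2:4, 2:4] = 1.0           # 8-voxel focal region
shifted = np.zeros_like(gt)
shifted[2:4, 2:4, 3:5] = 0.5      # same size, shifted one voxel along the last axis
print(focal_pressure_error(0.9 * gt, gt), focal_iou_fwhm(0.9 * gt, gt))  # 10% low, perfect overlap
print(focal_iou_fwhm(shifted, gt))  # 4 shared voxels out of 12, i.e. 1/3
```

Note that focal_iou_fwhm ignores amplitude entirely (the shifted field is at half strength), which is why it complements the peak-based metrics.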

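| metric | stat | paper | base | A | B | C | D | E |
|---|---|---|---|---|---|---|---|---|
| relative_l2 | mean ± std | 0.414 ± 0.086 | **0.384 ± 0.078** | 0.389 | 0.388 ± 0.077 | 0.388 ± 0.077 | 0.422 ± 0.082 | 0.401 ± 0.081 |
| relative_l2 | median | 0.394 | **0.369** | 0.372 | 0.372 | 0.372 | 0.404 | n/a |
| focal_position_error_mm | mean ± std | 2.89 ± 2.14 | 6.49 ± 4.58 | 5.60 | 5.06 ± 3.57 | 5.11 ± 3.76 | **4.19 ± 2.93** | 5.32 ± 3.44 |
| focal_position_error_mm | median | 2.45 | 5.15 | 4.64 | 4.18 | 4.15 | **3.61** | 4.39 |
| max_pressure_error | mean ± std | 0.199 ± 0.158 | 0.225 ± 0.116 | 0.204 | 0.240 ± 0.106 | 0.239 ± 0.106 | 0.283 ± 0.105 | **0.129 ± 0.095** |
| max_pressure_error | median | 0.166 | 0.217 | 0.200 | 0.239 | 0.239 | 0.287 | **0.110** |
| focal_pressure_error | median | n/a | 0.528 | 0.487 | 0.502 | 0.496 | 0.475 | **0.421** |
| focal_iou_fwhm | median | n/a | 0.143 | 0.148 | 0.136 | 0.136 | 0.121 | **0.152** |
| inference_latency_s | median | 11.4 (RTX 4090) | 0.233 | 0.232 | 0.233 | 0.232 | 0.233 | 0.232 |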
Per-sample eval.py aggregates on the 597-sample test set. The first three metrics are the paper-canonical ones; the last three (focal_pressure_error, focal_iou_fwhm, inference_latency_s) are the TUSNet metrics that also apply to DeepTFUS. Lower is better on all rows except focal_iou_fwhm (higher is better) and inference_latency_s (informational). Bold cells are the project best per row. Variant A has no std values and variant E has no relative_l2 median because I shut down my H100 before realizing I had not saved those per-sample metrics.

Notes

[1]

The paper does not specify:

  • Total parameter count of the full DeepTFUS model (and its smaller “DeepTFUS-tiny” variant). I have no grounding for whether base_width should be 16 or 64 or something else entirely.
  • U-Net depth and base channel width. Everything in §3.2 is described qualitatively (“four encoder stages”, “multi-scale features”) without numbers.
  • The number n of Fourier encoding frequencies used in the transducer positional encoding. The paper says “Fourier features” with no count.
  • Cross-attention head count and per-head dimension.
  • Dynamic convolution kernel size and number of generated kernels per layer.

I emailed the authors asking for these and have not heard back at the time of writing. I went with what fit the 80GB of memory: base_width=16 (~3.4M params, at the small end of plausible for a 3D U-Net at this scale), depth 4, 8 Fourier frequencies, 4 cross-attention heads, dynamic-conv kernel size 3.
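For reference, one common form of Fourier positional encoding, NeRF-style sin/cos features at octave-spaced frequencies, is sketched below with the 8 frequencies used here. The paper only says “Fourier features”, so this specific form, and the assumption that coordinates are normalized to roughly [−1, 1], are guesses rather than the paper’s spec.

```python
import numpy as np


def fourier_features(xyz, n_freqs: int = 8) -> np.ndarray:
    """NeRF-style encoding: sin and cos of (2^k * pi * x) for k < n_freqs, per coordinate.

    xyz: (..., 3) coordinates, assumed normalized to roughly [-1, 1].
    Returns (..., 3 * 2 * n_freqs) features.
    """
    xyz = np.asarray(xyz, dtype=float)
    freqs = np.pi * 2.0 ** np.arange(n_freqs)                        # (n_freqs,)
    angles = xyz[..., :, None] * freqs                                # (..., 3, n_freqs)
    feats = np.concatenate([np.sin(angles), np.cos(angles)], axis=-1)  # (..., 3, 2 * n_freqs)
    return feats.reshape(*xyz.shape[:-1], -1)


points = np.array([[0.1, -0.4, 0.7], [0.0, 0.0, 0.0]])  # hypothetical transducer points
print(fourier_features(points).shape)  # (2, 48)
```

With 8 frequencies, each 3D point becomes a 48-d vector; the highest frequency resolves coordinate changes of roughly 1/256 of the normalized range.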

Update: the authors got back with the full spec; see the Update section.

[2]

The two simplifications relative to the paper-text architecture, both for fitting on one 80GB GPU (batch size 4):

  • One-way cross-attention instead of bidirectional. The paper specifies “two multi-head attention blocks”: one where the CT volume queries the transducer features, and one where the transducer queries the CT. The first direction is degenerate given how the paper encodes the transducer (as a single pooled token): the softmax over a single key is identically 1, so that block collapses to a learned broadcast of a projection of the transducer vector, which is the same function the encoder’s DynamicConv layers already serve. Dropping it saves memory at no architectural cost.
  • No FiLM in the decoder. Paper §3.2 puts FiLM modulation in the decoding path. I leave it off, partly to fit memory and partly because the paper’s own ablation table (Table 1, “No FiLM” row) shows lower max_pressure_error than full DeepTFUS, with the other metrics within one standard deviation.
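The degeneracy argument in the first bullet is easy to check numerically. Below is a generic single-head scaled dot-product cross-attention in NumPy (not the paper’s block; each head of a multi-head block behaves the same way), with hypothetical random features: with many CT-voxel queries and one transducer key, every attention weight is exactly 1, so each voxel receives the same projected transducer vector.

```python
import numpy as np


def cross_attention(queries, keys, values, w_q, w_k, w_v):
    """Single-head scaled dot-product cross-attention (no output projection)."""
    q, k, v = queries @ w_q, keys @ w_k, values @ w_v
    scores = q @ k.T / np.sqrt(q.shape[-1])        # (n_queries, n_keys)
    scores -= scores.max(axis=-1, keepdims=True)    # numerical stability
    attn = np.exp(scores)
    attn /= attn.sum(axis=-1, keepdims=True)        # softmax over the keys
    return attn @ v, attn


rng = np.random.default_rng(0)
ct_tokens = rng.normal(size=(100, 16))   # hypothetical CT voxel features (queries)
transducer = rng.normal(size=(1, 16))    # one pooled transducer token (the only key/value)
w_q, w_k, w_v = (rng.normal(size=(16, 8)) for _ in range(3))
out, attn = cross_attention(ct_tokens, transducer, transducer, w_q, w_k, w_v)
print(np.unique(attn), np.allclose(out, transducer @ w_v))  # every weight is 1; output is one broadcast vector
```

The query and key projections drop out entirely, so the block reduces to `transducer @ w_v` copied to every voxel.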

(There’s a third memory-driven choice: I apply cross-attention only at the bottleneck and the deeper encoder levels, not at the full 256³ input. The paper says “each encoder level”, but full-resolution cross-attention blows past 80 GB no matter what else you do.)

[3]

The exponent constant is $\alpha = 5$, and the weights are normalized per sample so that $\mathbb{E}_v[w(v)] = 1$:

$$w(v) = \frac{\exp\!\bigl(\alpha\,(P(v) - \max_{v'} P(v'))\bigr)}{\mathbb{E}_v\!\bigl[\exp\!\bigl(\alpha\,(P(v) - \max_{v'} P(v'))\bigr)\bigr]}$$

This is paper Eq. 5. The $\max$ subtraction is just for numerical stability (it cancels between numerator and denominator); it shifts the exponent into $(-\infty, 0]$ before exponentiation. Both pressures live in a per-sample log-normalized space, so the peak is around $\log 2 \approx 0.69$.

[4]

Concretely, with $\hat P_{\text{norm}}$ the un-log-transformed predicted pressure in $[0, 1]$:

$$\text{soft-argmax}_\tau(\hat P) = \sum_v \text{softmax}\!\bigl(\hat P(v)/\tau\bigr) \cdot v$$

It’s a temperature-controlled, differentiable expectation of the voxel coordinate under a softmax of pressure. Smaller $\tau$ gives a sharper peak in the softmax, closer to the true argmax; larger $\tau$ gives softer support, biased toward the volume centroid. The hard argmax, $\arg\max_v \hat P(v)$, isn’t differentiable, so we can’t backprop through it.
