39#include <torch/torch.h>
42#include "itkTimeProbe.h"
72 Loss(
bool isLossNormalized)
74 if (!isLossNormalized)
104 updateValue(torch::Tensor & fixedOutput, torch::Tensor & movingOutput) = 0;
107 torch::Tensor & movingOutput,
108 torch::Tensor & jacobian,
109 torch::Tensor & nonZeroJacobianIndices)
113 nonZeroJacobianIndices.flatten(),
116 virtual torch::Tensor
121 m_Derivative.index_add_(0, nonZeroJacobianIndices.flatten(), jacobian.flatten());
134 virtual torch::Tensor
185 std::unique_ptr<Loss>
193 throw std::runtime_error(
"Error: Unknown loss function " + name);
222 updateValue(torch::Tensor & fixedOutput, torch::Tensor & movingOutput)
override
225 this->
m_Value += (fixedOutput - movingOutput).abs().mean(1).sum().item<
double>();
232 torch::Tensor diffOutput = fixedOutput - movingOutput;
233 this->
m_Value += diffOutput.abs().mean(1).sum().item<
double>();
234 return -torch::sign(diffOutput) / fixedOutput.size(1);
252 updateValue(torch::Tensor & fixedOutput, torch::Tensor & movingOutput)
override
255 this->
m_Value += (fixedOutput - movingOutput).pow(2).mean(1).sum().item<
double>();
262 torch::Tensor diffOutput = fixedOutput - movingOutput;
263 this->
m_Value += diffOutput.pow(2).mean(1).sum().item<
double>();
264 return -2 * diffOutput / fixedOutput.size(1);
285 updateValue(torch::Tensor & fixedOutput, torch::Tensor & movingOutput)
override
288 torch::Tensor intersectionSum = (fixedOutput * movingOutput).sum(1);
289 torch::Tensor unionSum = (fixedOutput + movingOutput).sum(1);
292 torch::Tensor isEmpty = (unionSum == 0);
301 torch::Tensor unionSumSafe = unionSum + isEmpty.to(unionSum.scalar_type());
304 torch::Tensor dice = 2.0 * intersectionSum / unionSumSafe;
308 dice.masked_fill_(isEmpty, 1.0);
311 this->
m_Value -= dice.sum().item<
double>();
319 torch::Tensor intersectionSum = (fixedOutput * movingOutput).sum(1);
320 torch::Tensor unionSum = (fixedOutput + movingOutput).sum(1);
323 torch::Tensor isEmpty = (unionSum == 0);
326 torch::Tensor unionSumSafe = unionSum + isEmpty.to(unionSum.scalar_type());
329 torch::Tensor dice = 2.0 * intersectionSum / unionSumSafe;
332 dice.masked_fill_(isEmpty, 1.0);
333 this->
m_Value -= dice.sum().item<
double>();
337 torch::Tensor grad = -2.0 * (fixedOutput * unionSumSafe.unsqueeze(-1) - intersectionSum.unsqueeze(-1)) /
338 (unionSumSafe * unionSumSafe).unsqueeze(-1);
341 grad.masked_fill_(isEmpty.unsqueeze(-1), 0.0);
369 updateValue(torch::Tensor & fixedOutput, torch::Tensor & movingOutput)
override
372 torch::Tensor dotProduct = (fixedOutput * movingOutput).sum(1);
373 torch::Tensor normFixed = torch::norm(fixedOutput, 2, 1);
374 torch::Tensor normMoving = torch::norm(movingOutput, 2, 1);
375 torch::Tensor cosine = dotProduct / (normFixed * normMoving);
376 torch::Tensor expL1 = torch::exp(-
m_Lambda * (fixedOutput - movingOutput).abs());
377 this->
m_Value -= (cosine.unsqueeze(-1) * expL1).mean(1).sum().item<
double>();
384 torch::Tensor diffOutput = fixedOutput - movingOutput;
385 torch::Tensor dotProduct = (fixedOutput * movingOutput).sum(1);
386 torch::Tensor normFixed = torch::norm(fixedOutput, 2, 1);
387 torch::Tensor normMoving = torch::norm(movingOutput, 2, 1);
388 torch::Tensor v = (normFixed * normMoving);
390 torch::Tensor cosine = dotProduct / (v);
391 torch::Tensor expL1 = torch::exp(-
m_Lambda * (fixedOutput - movingOutput).abs());
393 torch::Tensor dCosine = -(fixedOutput / v.unsqueeze(-1) -
394 (fixedOutput * movingOutput * movingOutput) / (v * normMoving.pow(2)).unsqueeze(-1));
395 torch::Tensor dexpL1 = -torch::sign(diffOutput) * expL1 / fixedOutput.size(1);
396 this->
m_Value -= (cosine.unsqueeze(-1) * expL1).mean(1).sum().item<
double>();
397 return dCosine * dexpL1 + cosine.unsqueeze(-1) * dexpL1;
416 updateValue(torch::Tensor & fixedOutput, torch::Tensor & movingOutput)
override
419 torch::Tensor dotProduct = (fixedOutput * movingOutput).sum(1);
420 torch::Tensor normFixed = torch::norm(fixedOutput, 2, 1);
421 torch::Tensor normMoving = torch::norm(movingOutput, 2, 1);
422 this->
m_Value -= (dotProduct / (normFixed * normMoving)).sum().item<
double>();
429 torch::Tensor dotProduct = (fixedOutput * movingOutput).sum(1);
430 torch::Tensor normFixed = torch::norm(fixedOutput, 2, 1);
431 torch::Tensor normMoving = torch::norm(movingOutput, 2, 1);
432 torch::Tensor v = (normFixed * normMoving);
433 this->
m_Value -= (dotProduct / v).sum().item<
double>();
434 return -(fixedOutput / v.unsqueeze(-1) -
435 (fixedOutput * movingOutput * movingOutput) / (v * normMoving.pow(2)).unsqueeze(-1));
453 updateValue(torch::Tensor & fixedOutput, torch::Tensor & movingOutput)
override
456 this->
m_Value -= (fixedOutput * movingOutput).sum(1).sum().item<
double>();
463 this->
m_Value -= (fixedOutput * movingOutput).sum(1).sum().item<
double>();
495 m_Sff = torch::zeros({ output.size(1) }, output.options());
496 m_Smm = torch::zeros({ output.size(1) }, output.options());
497 m_Sfm = torch::zeros({ output.size(1) }, output.options());
498 m_Sf = torch::zeros({ output.size(1) }, output.options());
499 m_Sm = torch::zeros({ output.size(1) }, output.options());
505 updateValue(torch::Tensor & fixedOutput, torch::Tensor & movingOutput)
override
508 m_Sff += (fixedOutput * fixedOutput).sum(0);
509 m_Smm += (movingOutput * movingOutput).sum(0);
510 m_Sfm += (fixedOutput * movingOutput).sum(0);
511 m_Sf += fixedOutput.sum(0);
512 m_Sm += movingOutput.sum(0);
517 torch::Tensor & movingOutput,
518 torch::Tensor & jacobian,
519 torch::Tensor & nonZeroJacobianIndices)
override
531 1, nonZeroJacobianIndices.flatten(), (fixedOutput.unsqueeze(-1) * jacobian).permute({ 1, 0, 2 }).flatten(1, 2));
533 1, nonZeroJacobianIndices.flatten(), (movingOutput.unsqueeze(-1) * jacobian).permute({ 1, 0, 2 }).flatten(1, 2));
534 m_Sdm.index_add_(1, nonZeroJacobianIndices.flatten(), (jacobian).permute({ 1, 0, 2 }).flatten(1, 2));
546 const double N = fixedOutput.size(0);
547 torch::Tensor sff = (fixedOutput * fixedOutput).sum(0);
548 torch::Tensor smm = (movingOutput * movingOutput).sum(0);
549 torch::Tensor sfm = (fixedOutput * movingOutput).sum(0);
550 torch::Tensor sf = fixedOutput.sum(0);
551 torch::Tensor sm = movingOutput.sum(0);
559 torch::Tensor u = sfm - (sf * sm / N);
560 torch::Tensor v = torch::sqrt(sff - sf * sf / N) * torch::sqrt(smm - sm * sm / N);
562 torch::Tensor u_p = fixedOutput - sf.unsqueeze(0) / N;
563 return -((u_p - u.unsqueeze(0) * (movingOutput - sm.unsqueeze(0) / N) / (smm - sm * sm / N).unsqueeze(0)) /
576 return -(u / v).mean().item<
double>();
600 const auto * nccOther =
dynamic_cast<const NCC *
>(&other);
603 m_Sff += nccOther->m_Sff;
604 m_Smm += nccOther->m_Smm;
605 m_Sfm += nccOther->m_Sfm;
606 m_Sf += nccOther->m_Sf;
607 m_Sm += nccOther->m_Sm;
610 m_Sfdm += nccOther->m_Sfdm;
611 m_Smdm += nccOther->m_Smdm;
612 m_Sdm += nccOther->m_Sdm;
torch::Tensor updateValueAndGetGradientModulator(torch::Tensor &fixedOutput, torch::Tensor &movingOutput) override
void updateValue(torch::Tensor &fixedOutput, torch::Tensor &movingOutput) override
torch::Tensor updateValueAndGetGradientModulator(torch::Tensor &fixedOutput, torch::Tensor &movingOutput) override
void updateValue(torch::Tensor &fixedOutput, torch::Tensor &movingOutput) override
void updateValue(torch::Tensor &fixedOutput, torch::Tensor &movingOutput) override
torch::Tensor updateValueAndGetGradientModulator(torch::Tensor &fixedOutput, torch::Tensor &movingOutput) override
void updateValue(torch::Tensor &fixedOutput, torch::Tensor &movingOutput) override
torch::Tensor updateValueAndGetGradientModulator(torch::Tensor &fixedOutput, torch::Tensor &movingOutput) override
void updateValue(torch::Tensor &fixedOutput, torch::Tensor &movingOutput) override
torch::Tensor updateValueAndGetGradientModulator(torch::Tensor &fixedOutput, torch::Tensor &movingOutput) override
void updateValue(torch::Tensor &fixedOutput, torch::Tensor &movingOutput) override
torch::Tensor updateValueAndGetGradientModulator(torch::Tensor &fixedOutput, torch::Tensor &movingOutput) override
Singleton factory to register and create Loss instances by string name.
std::function< std::unique_ptr< Loss >()> CreatorFunc
std::unordered_map< std::string, CreatorFunc > factoryMap
static LossFactory & Instance()
void RegisterLoss(const std::string &name, CreatorFunc creator)
std::unique_ptr< Loss > Create(const std::string &name)
Loss(bool isLossNormalized)
virtual torch::Tensor GetDerivative(double N) const
virtual Loss & operator+=(const Loss &other)
virtual void updateValueAndDerivativeInStaticMode(torch::Tensor &fixedOutput, torch::Tensor &movingOutput, torch::Tensor &jacobian, torch::Tensor &nonZeroJacobianIndices)
virtual void initialize(torch::Tensor &output)
void setNumberOfParameters(int numberOfParameters)
torch::Tensor m_Derivative
virtual torch::Tensor updateValueAndGetGradientModulator(torch::Tensor &fixedOutput, torch::Tensor &movingOutput)=0
virtual double GetValue(double N) const
virtual void updateValue(torch::Tensor &fixedOutput, torch::Tensor &movingOutput)=0
void updateDerivativeInJacobianMode(torch::Tensor &jacobian, torch::Tensor &nonZeroJacobianIndices)
Normalized Cross Correlation loss over feature vectors.
void updateValueAndDerivativeInStaticMode(torch::Tensor &fixedOutput, torch::Tensor &movingOutput, torch::Tensor &jacobian, torch::Tensor &nonZeroJacobianIndices) override
double GetValue(double N) const override
NCC & operator+=(const Loss &other) override
torch::Tensor updateValueAndGetGradientModulator(torch::Tensor &fixedOutput, torch::Tensor &movingOutput) override
void initialize(torch::Tensor &output) override
torch::Tensor GetDerivative(double N) const override
void updateValue(torch::Tensor &fixedOutput, torch::Tensor &movingOutput) override
RegisterLoss(const std::string &name)
RegisterLoss< L2 > MSE_reg("L2")
RegisterLoss< L1Cosine > L1CosineReg("L1Cosine")
RegisterLoss< DotProduct > DotProductReg("DotProduct")
RegisterLoss< Cosine > CosineReg("Cosine")
RegisterLoss< Dice > Dice_reg("Dice")
RegisterLoss< NCC > NCC_reg("NCC")
RegisterLoss< L1 > L1_reg("L1")