网上看到普遍的答案是这个
class TVLoss(nn.Module):
    """Anisotropic total-variation (TV) loss with exponent beta = 2.

    Penalizes squared differences between vertically and horizontally
    adjacent pixels, averaged per element and per batch sample, to
    encourage spatial smoothness.

    Args:
        TVLoss_weight: scalar multiplier applied to the final loss.
    """

    def __init__(self, TVLoss_weight=1):
        super(TVLoss, self).__init__()
        self.TVLoss_weight = TVLoss_weight

    def forward(self, x):
        """Compute the TV loss of ``x``.

        Args:
            x: 4-D tensor indexed as (batch, channels, height, width) —
               assumed from the slicing pattern below.

        Returns:
            0-dim tensor: the weighted TV loss.
        """
        batch_size = x.size()[0]
        h_x = x.size()[2]
        w_x = x.size()[3]
        # Per-sample element counts of the difference tensors. Guard with
        # max(..., 1): if a spatial dimension is 1 the count is 0 and the
        # division below would yield NaN (0/0); with the guard that
        # direction simply contributes 0.
        count_h = max(self._tensor_size(x[:, :, 1:, :]), 1)
        count_w = max(self._tensor_size(x[:, :, :, 1:]), 1)
        # Sum of squared differences between vertically (h) and
        # horizontally (w) adjacent pixels — this is beta = 2, fixed.
        h_tv = torch.pow(x[:, :, 1:, :] - x[:, :, :h_x - 1, :], 2).sum()
        w_tv = torch.pow(x[:, :, :, 1:] - x[:, :, :, :w_x - 1], 2).sum()
        # NOTE: the constant factor 2 is a scale convention carried over
        # from the original implementation; it is equivalent to doubling
        # TVLoss_weight and does not change the gradient direction.
        return self.TVLoss_weight * 2 * (h_tv / count_h + w_tv / count_w) / batch_size

    def _tensor_size(self, t):
        # Elements per sample: channels * height * width (batch excluded).
        return t.size()[1] * t.size()[2] * t.size()[3]
这个实现固定了 β=2，且不支持修改。按照这篇文章给出的公式 https://blog.csdn.net/yexiaogu1104/article/details/88395475 ，指数是 β/2，当 β=2 时指数为 1，也就是不需要再做开方等额外运算。那么最后 return 这里为什么写成 self.TVLoss_weight*2*(...)，这个乘 2 是从哪里来的呢？