Liu Shouda coder

mxnet loss

2017-03-07

This post was written on 2017-03-07. MXNet's loss handling is a bit opaque, so this post sorts it out. The MXNet discussion forum says the loss layers will be refactored soon; I will update the post after the refactor.

This post covers MXNet losses, backpropagation, custom loss functions, and metrics. It starts with MXNet's center loss as an example and gradually widens the scope.

mxnet center loss

Custom op

The custom op is defined as follows; for details, see here.

```python
class CenterLoss(mx.operator.CustomOp):
    def __init__(self, ctx, shapes, dtypes, num_class, alpha, scale=1.0):
        if not len(shapes[0]) == 2:
            raise ValueError('dim for input_data should be 2 for CenterLoss')

        self.alpha = alpha
        self.batch_size = shapes[0][0]
        self.num_class = num_class
        self.scale = scale

    def forward(self, is_train, req, in_data, out_data, aux):
        labels = in_data[1].asnumpy()
        diff = aux[0]
        center = aux[1]

        # store x_i - c_yi
        for i in range(self.batch_size):
            diff[i] = in_data[0][i] - center[int(labels[i])]

        loss = mx.nd.sum(mx.nd.square(diff)) / self.batch_size / 2
        self.assign(out_data[0], req[0], loss)

    def backward(self, req, out_grad, in_data, out_data, in_grad, aux):
        diff = aux[0]
        center = aux[1]
        sum_ = aux[2]

        # back grad is just scale * (x_i - c_yi)
        grad_scale = float(self.scale)
        self.assign(in_grad[0], req[0], diff * grad_scale)

        # update the center
        labels = in_data[1].asnumpy()
        label_occur = dict()
        for i, label in enumerate(labels):
            label_occur.setdefault(int(label), []).append(i)

        for label, sample_index in label_occur.items():
            sum_[:] = 0
            for i in sample_index:
                sum_ = sum_ + diff[i]
            delta_c = sum_ / (1 + len(sample_index))
            center[label] += self.alpha * delta_c
```

In forward, out_data is computed from in_data.

In backward, the input gradient in_grad is computed from the input data in_data and the output gradient out_grad.

Both forward and backward use the assign helper:

```python
def assign(self, dst, req, src):
    """Helper function for assigning into dst depending on requirements."""
    if req == 'null':
        return
    elif req == 'write' or req == 'inplace':
        dst[:] = src
    elif req == 'add':
        dst[:] += src
```

It simply writes src into dst according to the operation type requested by req.

aux[0], aux[1], and aux[2] hold the diff, center, and sum buffers needed to compute the center loss.

The forward pass records diff; the backward pass accumulates the per-label sum of diff and updates that label's center. (The sum buffer arguably does not need to live in aux at all.)
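
For reference, a sketch of the math these buffers implement (notation mine; the code folds constant factors such as $1/m$ into the scale argument): the loss, its gradient, and the center update used above are

$$L_C = \frac{1}{2m}\sum_{i=1}^{m}\lVert x_i - c_{y_i}\rVert_2^2,\qquad \frac{\partial L_C}{\partial x_i} \propto x_i - c_{y_i},$$

$$\Delta c_j = \frac{\sum_{i=1}^{m}\mathbb{1}\{y_i = j\}\,(x_i - c_j)}{1 + \sum_{i=1}^{m}\mathbb{1}\{y_i = j\}},\qquad c_j \leftarrow c_j + \alpha\,\Delta c_j,$$

which matches the delta_c computation and the center[label] += alpha * delta_c update in backward.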

Design notes, from the original blog post:

  • The centerloss operator is a bit special: the per-class centers it holds are neither in_data nor out_data, nor are they ordinary weights (they are not updated from grad; they have their own update rule), so in the end they became auxiliary states. For multi-GPU training this is not quite right: MXNet uses data parallelism, so each card holds a full copy of the parameters, and because the center parameter of centerloss is an aux variable, each card ends up with its own, slightly different set of centers. I have not found a better way to handle this, but since the feature-extraction network is synchronized, the centers do not drift far apart and the final training results are essentially the same. Of course I only compared on the small MNIST dataset; if you have a better approach, please tell me :)
  • Finally, because aux variables are used, a list_auxiliary_states function must be added to declare them, and infer_shape must return the correct sizes. Every aux name carries a 'bias' suffix, because MXNet chooses the initializer based on the name, and 'bias' parameters are zero-initialized. Another pitfall: diff_bias must have the same shape as in_data, but do not infer it from in_shape. If the batch size is 100, diff_bias should have shape (100, 2); but with two GPUs, the next infer_shape on each card yields (50, 2). Since MXNet first allocates the memory once on the CPU and then copies the parameters over, the shape mismatch raises an error. That is why the per-card sample count is computed directly and passed in as a parameter (see the sketch after this list).
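
A minimal sketch of that workaround (the driver code and names such as num_gpus are hypothetical, not from the original example): compute the per-card sample count up front and pass it in as the batchsize parameter, so diff_bias is declared with the per-card shape from the start.

```python
# hypothetical driver code for two GPUs and a global batch of 100
num_gpus = 2
batch_size = 100
ctx = [mx.gpu(i) for i in range(num_gpus)]

per_card_batch = batch_size // num_gpus       # 50 samples end up on each card
mlp = get_symbol(batchsize=per_card_batch)    # diff_bias is declared as (50, 2)
```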

A custom op also needs the following boilerplate, which describes the op's inputs, outputs, and shape information.

```python
@mx.operator.register("centerloss")
class CenterLossProp(mx.operator.CustomOpProp):
    def __init__(self, num_class, alpha, scale=1.0, batchsize=64):
        super(CenterLossProp, self).__init__(need_top_grad=False)

        # convert it to numbers
        self.num_class = int(num_class)
        self.alpha = float(alpha)
        self.scale = float(scale)
        self.batchsize = int(batchsize)

    def list_arguments(self):
        return ['data', 'label']

    def list_outputs(self):
        return ['output']

    def list_auxiliary_states(self):
        # call them 'bias' for zero initialization
        return ['diff_bias', 'center_bias', 'sum_bias']

    def infer_shape(self, in_shape):
        data_shape = in_shape[0]
        label_shape = (in_shape[0][0],)

        # store diff, same shape as input batch
        diff_shape = [self.batchsize, data_shape[1]]

        # store the center of each class, should be (num_class, d)
        center_shape = [self.num_class, diff_shape[1]]

        # computation buf
        sum_shape = [diff_shape[1], ]

        output_shape = [1, ]
        return [data_shape, label_shape], [output_shape], [diff_shape, center_shape, sum_shape]

    def create_operator(self, ctx, shapes, dtypes):
        return CenterLoss(ctx, shapes, dtypes, self.num_class, self.alpha, self.scale)
```

Network construction

```python
def get_symbol(batchsize=64):
    data = mx.symbol.Variable('data')
    softmax_label = mx.symbol.Variable('softmax_label')
    center_label = mx.symbol.Variable('center_label')
    # ... (intermediate layers omitted)
    fc2 = mx.symbol.FullyConnected(data=embedding, num_hidden=10, name='fc2')

    ce_loss = mx.symbol.SoftmaxOutput(data=fc2, label=softmax_label, name='softmax')
    center_loss_ = mx.symbol.Custom(data=fc2, label=center_label, name='center_loss_', op_type='centerloss',
                                    num_class=10, alpha=0.5, scale=0.01, batchsize=batchsize)
    center_loss = mx.symbol.MakeLoss(name='center_loss', data=center_loss_)

    mlp = mx.symbol.Group([ce_loss, center_loss])
    return mlp
```

Here the softmax loss and the center loss are grouped together with Group, and this combined objective is what gets optimized.
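
Written out, the grouped objective is roughly

$$L = L_{\text{softmax}} + \lambda\, L_C,$$

where $\lambda$ corresponds to the scale=0.01 passed to the custom op (in the code it is applied to the center-loss gradient rather than to the loss value itself).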

The same label is used for both softmax_label and center_label, so the data loader needs to be adapted:

```python
@property
def provide_label(self):
    provide_label = self.data_iter.provide_label[0]
    return [('softmax_label', provide_label[1]),
            ('center_label', provide_label[1])]
```
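
A minimal sketch of such a wrapper iterator (the scaffolding is mine; only the provide_label property above comes from the original example):

```python
class CenterLossIter(mx.io.DataIter):
    """Wraps an existing iterator and exposes its label under two names."""

    def __init__(self, data_iter):
        super(CenterLossIter, self).__init__()
        self.data_iter = data_iter
        self.batch_size = data_iter.batch_size

    @property
    def provide_data(self):
        return self.data_iter.provide_data

    @property
    def provide_label(self):
        provide_label = self.data_iter.provide_label[0]
        return [('softmax_label', provide_label[1]),
                ('center_label', provide_label[1])]

    def reset(self):
        self.data_iter.reset()

    def next(self):
        batch = self.data_iter.next()
        # duplicate the label list so it feeds both softmax_label and center_label
        return mx.io.DataBatch(data=batch.data,
                               label=batch.label + batch.label,
                               pad=batch.pad, index=batch.index)
```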

metric

To monitor the loss values during training, define metrics:

```python
class Accuracy(mx.metric.EvalMetric):
    def __init__(self, num=None):
        super(Accuracy, self).__init__('accuracy', num)

    def update(self, labels, preds):
        mx.metric.check_label_shapes(labels, preds)
        if self.num is not None:
            assert len(labels) == self.num

        pred_label = mx.nd.argmax_channel(preds[0]).asnumpy().astype('int32')
        label = labels[0].asnumpy().astype('int32')

        mx.metric.check_label_shapes(label, pred_label)

        self.sum_metric += (pred_label.flat == label.flat).sum()
        self.num_inst += len(pred_label.flat)


# define a metric for the center loss
class CenterLossMetric(mx.metric.EvalMetric):
    def __init__(self):
        super(CenterLossMetric, self).__init__('center_loss')

    def update(self, labels, preds):
        self.sum_metric += preds[1].asnumpy()[0]
        self.num_inst += 1
```

Combine the metrics:

```python
eval_metrics = mx.metric.CompositeEvalMetric()
eval_metrics.add(Accuracy())
eval_metrics.add(CenterLossMetric())
```
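
A hypothetical training call (the module setup, iterator names, and hyper-parameters are mine, not from the post), just to show where the composite metric plugs in:

```python
model = mx.mod.Module(symbol=get_symbol(batchsize), context=mx.gpu(0),
                      data_names=['data'],
                      label_names=['softmax_label', 'center_label'])
model.fit(train_iter,
          eval_data=val_iter,
          eval_metric=eval_metrics,   # the CompositeEvalMetric built above
          optimizer='sgd',
          optimizer_params={'learning_rate': 0.01, 'wd': 0.0005},
          num_epoch=20)
```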

Group

Group combines multiple loss functions.

```python
def get_loss(gram, content):
    gram_loss = []
    for i in range(len(gram.list_outputs())):
        gvar = mx.sym.Variable("target_gram_%d" % i)
        gram_loss.append(mx.sym.sum(mx.sym.square(gvar - gram[i])))
    cvar = mx.sym.Variable("target_content")
    content_loss = mx.sym.sum(mx.sym.square(cvar - content))
    return mx.sym.Group(gram_loss), content_loss
```

MakeLoss

```python
ce_loss = mx.symbol.SoftmaxOutput(data=fc2, label=softmax_label, name='softmax')
center_loss_ = mx.symbol.Custom(data=fc2, label=center_label, name='center_loss_', op_type='centerloss',
                                num_class=10, alpha=0.5, scale=0.01, batchsize=batchsize)
center_loss = mx.symbol.MakeLoss(name='center_loss', data=center_loss_)
mlp = mx.symbol.Group([ce_loss, center_loss])
```
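
As an aside (my example, not from the original post): for simple losses, MakeLoss can also wrap an ordinary symbolic expression directly, with no CustomOp at all. MakeLoss treats its input as the loss value, so during backward the gradient entering that expression is simply grad_scale.

```python
# hypothetical L2-style loss built purely from symbols
pred = mx.sym.Variable('pred')
target = mx.sym.Variable('target')
l2_loss = mx.sym.MakeLoss(mx.sym.sum(mx.sym.square(pred - target)), grad_scale=1.0)
```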

output layer

  • SoftmaxOutput

    $\text{out}[i,:] = \text{softmax}(\text{data}[i,:])$, where $\text{softmax}(x) = \left[\ldots,\ \frac{\exp(x[j])}{\exp(x[0]) + \cdots + \exp(x[k-1])},\ \ldots\right]$

    Softmax with logit loss. In the forward pass, the softmax output is returned. In the backward pass, the logit loss, also called cross-entropy loss, is added.

  • LinearRegressionOutput

  • LogisticRegressionOutput

  • MAERegressionOutput

  • SVMOutput

With these output layers, backward() can be called directly without passing in any gradient. For losses that are not output layers, as in the neural-style example, the gram_diff and content_diff have to be passed back through backward() manually:

```python
gram_executors[j].forward(is_train=True)
gram_diff[j] = gram_executors[j].outputs[0] - target_grams[j]
gram_executors[j].backward(gram_diff[j])
gram_grad[j][i:i+1] = gram_executors[j].grad_dict['gram_data'] / batch_size
```

```python
# note: the gradient is computed manually here
content_grad = alpha * (desc_executor.outputs[len(style_layer)] - target_content) / layer_size / batch_size
desc_executor.backward(gram_grad + [content_grad])
gene_executor.backward(desc_executor.grad_dict['data'] + tv_grad_executor.outputs[0])
for i, var in enumerate(gene_executor.grad_dict):
    if var != 'data':
        optimizer.update(i, gene_executor.arg_dict[var], gene_executor.grad_dict[var], optim_states[i])
```

backward computes the gradients; update is what actually updates the parameters from those gradients. The update rule lives in the optimizer.

Below, MXNet's softmax output and regression output serve as examples of output layers. In these output layers the loss value is never actually computed; only its gradient is.
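
For softmax with cross-entropy this works because the gradient has a closed form that never needs the loss value: with $p = \text{softmax}(z)$ and true class $y$,

$$\frac{\partial L}{\partial z_j} = p_j - \mathbb{1}\{j = y\},$$

which is exactly what SoftmaxGrad below computes.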

softmax output

softmax_output-inl.h:

```cpp
template<typename xpu, typename DType>
class SoftmaxOutputOp : public Operator {
  virtual void Forward() {
    // ...
    Softmax(out, data);
  }

  virtual void Backward() {
    // ...
    SoftmaxGrad(grad, out, label);
  }
};
```

mshadow/tensor_cpu-inl.h:

```cpp
template<typename DType>
inline void Softmax(Tensor<cpu, 1, DType> dst,
                    const Tensor<cpu, 1, DType> &energy) {
  DType mmax = energy[0];
  for (index_t x = 1; x < dst.size(0); ++x) {
    if (mmax < energy[x]) mmax = energy[x];
  }
  DType sum = DType(0.0f);
  for (index_t x = 0; x < dst.size(0); ++x) {
    dst[x] = std::exp(energy[x] - mmax);
    sum += dst[x];
  }
  for (index_t x = 0; x < dst.size(0); ++x) {
    dst[x] /= sum;
  }
}

template<typename DType>
inline void SoftmaxGrad(Tensor<cpu, 2, DType> dst,
                        const Tensor<cpu, 2, DType> &src,
                        const Tensor<cpu, 1, DType> &label) {
  #pragma omp parallel for
  for (openmp_index_t y = 0; y < dst.size(0); ++y) {
    const index_t k = static_cast<int>(label[y]);
    for (index_t x = 0; x < dst.size(1); ++x) {
      if (x == k) {
        dst[y][k] = src[y][k] - 1.0f;
      } else {
        dst[y][x] = src[y][x];
      }
    }
  }
}
```

regression output

regression_output-inl.h:

```cpp
template<typename xpu, typename ForwardOp, typename BackwardOp>
class RegressionOutputOp : public Operator {
  virtual void Forward() {
    // ...
    Assign(out, req[reg_enum::kOut], F<ForwardOp>(data));
  }

  virtual void Backward() {
    // ...
    // the loss itself is never computed; the matching backward computation is enough
    Assign(grad, req[reg_enum::kData], param_.grad_scale / num_output *
           F<BackwardOp>(out, reshape(label, grad.shape_)));
  }
};
```

regression_output.cc:

```cpp
template<>
Operator *CreateRegressionOutputOp<cpu>(reg_enum::RegressionOutputType type,
                                        RegressionOutputParam param) {
  switch (type) {
    case reg_enum::kLinear:
      // the regression variants only differ in their forward/backward ops
      return new RegressionOutputOp<cpu, mshadow::op::identity, mshadow::op::minus>(param);
    case reg_enum::kLogistic:
      return new RegressionOutputOp<cpu, mshadow_op::sigmoid, mshadow::op::minus>(param);
    case reg_enum::kMAE:
      return new RegressionOutputOp<cpu, mshadow::op::identity, mshadow_op::minus_sign>(param);
    default:
      LOG(FATAL) << "unknown activation type " << type;
  }
  return nullptr;
}
```
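
For clarity, the standard losses behind the three forward/backward pairs (my notation, with prediction $o$ and label $y$):

$$\text{linear:}\quad L = \tfrac{1}{2}(o-y)^2 \;\Rightarrow\; \frac{\partial L}{\partial o} = o - y \quad (\texttt{op::minus})$$

$$\text{logistic:}\quad o = \sigma(z),\; L = -\bigl[y\log o + (1-y)\log(1-o)\bigr] \;\Rightarrow\; \frac{\partial L}{\partial z} = o - y$$

$$\text{MAE:}\quad L = \lvert o-y\rvert \;\Rightarrow\; \frac{\partial L}{\partial o} = \operatorname{sign}(o-y) \quad (\texttt{minus\_sign})$$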

metric output

As the center-loss metric shows, computing the loss and reporting it as a metric are independent; the two have to be implemented separately.

With loss layers such as LogisticRegressionOutput, backward can be called immediately after forward, regardless of whether update_metric is ever called.
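
A minimal hypothetical example of this forward-then-backward pattern (shapes and names are mine):

```python
import mxnet as mx

x = mx.sym.Variable('data')
y = mx.sym.Variable('label')
fc = mx.sym.FullyConnected(data=x, num_hidden=1, name='fc')
lro = mx.sym.LogisticRegressionOutput(data=fc, label=y, name='lro')

# bind with concrete shapes; arrays are allocated (and zero-initialized)
exe = lro.simple_bind(ctx=mx.cpu(), data=(8, 4), label=(8, 1))
exe.forward(is_train=True)
exe.backward()   # no out_grad needed: the output layer supplies the loss gradient itself
grad = exe.grad_dict['fc_weight'].asnumpy()
```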

Reporting through a metric is not mandatory; the loss can also be printed by hand:

```python
# note: the loss is computed by hand here, without a metric
loss[len(style_layer)] += alpha * np.sum(np.square((desc_executor.outputs[len(style_layer)] - target_content).asnumpy() / np.sqrt(layer_size))) / batch_size
if epoch % 20 == 0:
    print 'loss', sum(loss), np.array(loss)
```

L1, L2 norm

smooth_l1:

```python
cls_score = mx.symbol.FullyConnected(name='cls_score', data=drop7, num_hidden=num_classes)
cls_prob = mx.symbol.SoftmaxOutput(name='cls_prob', data=cls_score, label=label, normalization='batch')

# bounding box regression
bbox_pred = mx.symbol.FullyConnected(name='bbox_pred', data=drop7, num_hidden=num_classes * 4)
bbox_loss_ = bbox_weight * mx.symbol.smooth_l1(name='bbox_loss_', scalar=1.0, data=(bbox_pred - bbox_target))
bbox_loss = mx.sym.MakeLoss(name='bbox_loss', data=bbox_loss_, grad_scale=1.0 / config.TRAIN.BATCH_ROIS)

cls_prob = mx.symbol.Reshape(data=cls_prob, shape=(config.TRAIN.BATCH_IMAGES, -1, num_classes), name='cls_prob_reshape')
bbox_loss = mx.symbol.Reshape(data=bbox_loss, shape=(config.TRAIN.BATCH_IMAGES, -1, 4 * num_classes), name='bbox_loss_reshape')

group = mx.symbol.Group([cls_prob, bbox_loss])
```

As shown above, multiple loss functions can be combined with Group.

L2

L2 regularization is implemented via weight decay:

backward first computes the plain gradient; the update then applies the formula below to add the regularization. (My own take...)

```python
weight[:] += -lr * (grad + weight_decay * weight)
```
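
In MXNet this corresponds to the optimizer's wd argument; a rough sketch of the equivalence for plain SGD (ignoring rescale_grad, gradient clipping, and momentum):

```python
import mxnet as mx

lr, wd = 0.01, 0.0005
opt = mx.optimizer.SGD(learning_rate=lr, wd=wd, momentum=0.0)

# what one update amounts to for a single parameter (simplified)
def sgd_step(weight, grad):
    weight[:] += -lr * (grad + wd * weight)
```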
