
Magic methods in OneFlow: __setitem__ as an example

A magic method is any Python method whose name begins and ends with a double underscore ("__"). For a detailed introduction, see the guide to Python magic methods on PyCoder's Weekly CN.
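
For example, the statement t[k] = v is syntactic sugar that the interpreter rewrites as type(t).__setitem__(t, k, v). A minimal sketch in plain Python (the class name is illustrative):

class LoggingDict:
    """A tiny container that shows when __setitem__ is invoked."""

    def __init__(self):
        self._data = {}

    def __setitem__(self, key, value):
        print(f"__setitem__ called with key={key!r}, value={value!r}")
        self._data[key] = value

d = LoggingDict()
d["a"] = 1                          # prints: __setitem__ called with key='a', value=1
LoggingDict.__setitem__(d, "b", 2)  # the explicit form of the same call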

C++

class TensorSetItemFunctor {
 public:
  TensorSetItemFunctor() {}
  Maybe<void> operator()(const std::shared_ptr<one::Tensor>& x, const TensorIndex& index,
                         const std::shared_ptr<one::Tensor>& value) const {
    std::vector<detail::Slice> slice_indices;
    TensorTuple tensor_indices;
    std::vector<int64_t> expand_dims;
    std::vector<int64_t> target_dims;
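    // PrepareSliceIndices normalizes the TensorIndex into one slice per axis,
    // collects any tensor indices, and records newly inserted dims
    // (expand_dims) plus the shape of the indexed region (target_dims).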
    JUST(PrepareSliceIndices(index, *(x->shape()), &slice_indices, &tensor_indices, &expand_dims,
                             &target_dims));
    if (expand_dims.size()) {
      slice_indices = *JUST(RemoveExpandDimSlice(slice_indices, expand_dims));
    }
    int64_t ndims = x->shape()->NumAxes();
    CHECK_EQ_OR_RETURN(slice_indices.size(), ndims)
        << Error::RuntimeError() << "Failed to prepare slice indices.";
    // Combined indexing is not supported yet.
    if (!tensor_indices.empty()) {
      CHECK_OR_RETURN(tensor_indices.size() == ndims
                      && std::all_of(tensor_indices.begin(), tensor_indices.end(),
                                     [](const std::shared_ptr<Tensor>& index) { return index; }))
          << Error::RuntimeError()
          << "Combining indexing is not support for tensor setitem currently";
    }

    Shape target_shape(DimVector(target_dims.begin(), target_dims.end()));
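    // The indexed region is empty: nothing to write.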
    if (target_shape.Count(0) == 0) { return Maybe<void>::Ok(); }

    const auto& value_shape = value->shape();
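    // value may carry extra leading axes; each must be 1 so that value can be
    // broadcast to the target shape.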
    bool matched = [&]() {
      for (int i = 0; i < value_shape->NumAxes() - target_shape.NumAxes(); ++i) {
        if (value_shape->At(i) != 1) { return false; }
      }
      return true;
    }();
    CHECK_OR_RETURN(matched) << Error::RuntimeError() << "The tensor size mismatch. Target sizes: "
                             << target_shape.ToString()
                             << ", value sizes: " << value_shape->ToString();
    std::shared_ptr<one::Tensor> value_tensor(value);
    // TODO: replace reshape with unsqueeze once the view mechanism is ready.
    // After this loop, every 0-dim scalar tensor index becomes a 1-D tensor of size 1.
    for (auto& tensor : tensor_indices) {
      if (tensor->ndim() == 0) { tensor = JUST(functional::Reshape(tensor, Shape({1}))); }
    }
    if (tensor_indices.size() == ndims) {  // advanced indexing
      std::shared_ptr<Tensor> indices = JUST(functional::Stack(tensor_indices, 0));
      if (indices->shape()->elem_cnt() == 0) { return Maybe<void>::Ok(); }
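      // Stack yields indices of shape (ndims, N); scatter_nd expects one
      // (N, ndims) row per updated element, hence the transpose.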
      indices = JUST(functional::Transpose(indices, {1, 0}));
      value_tensor = JUST(functional::Expand(value_tensor, {indices->shape()->At(0)}));
      JUST(functional::TensorScatterNdUpdate(x, indices, value_tensor, /*inplace=*/true));
    } else {                              // slice update
      if (target_shape.NumAxes() != 0 &&  // NOLINT
          /*need_expand=*/value_shape->Count(0) != target_shape.Count(0)) {
        // Remove the beginning redundant 1-dimensions.
        if (value_shape->NumAxes() > target_shape.NumAxes()) {
          int64_t start_axis = value_shape->NumAxes() - target_shape.NumAxes();
          const auto& shape = JUST(value_shape->Slice(start_axis, value_shape->NumAxes()));
          value_tensor = JUST(Reshape(value, *shape));
        }
        value_tensor = JUST(Expand(value_tensor, target_shape));
      }
      std::vector<int64_t> start(ndims), end(ndims), step(ndims);
      DimVector slice_dims(ndims);
      for (int i = 0; i < ndims; ++i) {
        const auto& slice = slice_indices.at(i);
        start[i] = slice.start();
        end[i] = slice.end();
        step[i] = slice.step();
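        // Ceiling division: the number of elements the slice keeps on axis i.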
        slice_dims[i] = (end[i] - start[i] + step[i] - 1) / step[i];
      }
      Shape slice_shape(slice_dims);
      if (slice_shape != *(value_tensor->shape())) {
        value_tensor = JUST(Reshape(value_tensor, slice_shape));
      }
      if (x->is_local()) {
        JUST(SliceUpdate(x, value_tensor, start, end, step, /*inplace=*/true));
      } else {
        if (x->requires_grad() && autograd::GradMode::is_enabled()) {
          return Error::RuntimeError() << "Backward is not supported for consistent tensor "
                                          "setitem, please use oneflow.no_grad() to disable "
                                          "autograd for now. We will fix this problem soon.";
        }
        JUST(LogicalSliceAssign(x, value_tensor, start, end, step));
      }
    }
    return Maybe<void>::Ok();
  }
};

ONEFLOW_FUNCTION_LIBRARY(m) {
  m.add_functor<impl::TensorSetItemFunctor>("TensorSetItem");
}
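
From Python, the two branches of the functor can be exercised directly. When every axis is indexed by a tensor, the stacked indices are dispatched to TensorScatterNdUpdate; any other pattern is normalized to per-axis slices and updated in place. A minimal usage sketch, assuming a local OneFlow installation:

import oneflow as flow

x = flow.zeros(3, 4)

# Slice-update branch: the scalar is wrapped into a tensor, broadcast to the
# slice shape, and written in place.
x[0:2, 1:3] = 1.0

# Advanced-indexing branch: both axes are indexed by tensors, so the indices
# are stacked, transposed, and dispatched to TensorScatterNdUpdate.
rows = flow.tensor([0, 2])
cols = flow.tensor([1, 3])
x[rows, cols] = 5.0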

YAML

- name: "tensor_setitem"
  signature: "Void (Tensor x, TensorIndex index, Tensor value) => TensorSetItem"
  bind_python: True
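
The code generator consumes this entry: it emits a dispatch stub for the TensorSetItem functor and, because bind_python is True, exposes it to Python as flow._C.tensor_setitem, which is exactly the call that _setitem below forwards to. A minimal direct-call sketch (normally the call goes through __setitem__; the slice object is converted to a TensorIndex by the binding layer):

import oneflow as flow

x = flow.zeros(4)
v = flow.ones(2)
flow._C.tensor_setitem(x, slice(1, 3), v)  # same effect as x[1:3] = v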

Python

def _setitem(self, key, value):
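    """Python half of Tensor.__setitem__: coerce value into a tensor matching
    self (placement/SBP for global tensors, device for local ones), then
    dispatch to the C++ TensorSetItem functor."""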
    if self.is_global:
        if isinstance(value, (int, float)):
            value = flow._C.global_constant(
                [1],
                value,
                dtype=self.dtype,
                placement=self.placement,
                sbp=flow.sbp.broadcast,
            )
        else:
            if value.is_global:
                value = value.to_global(sbp=flow.sbp.broadcast)
                # TODO: remove these lines after asymmetric boxing is ready
                local_tensor = value.to_local()
                if local_tensor.nelement() == 0:
                    local_tensor = flow.zeros(*value.shape)
                value = local_tensor.to_global(self.placement, sbp=flow.sbp.broadcast)
            else:
                value = value.to_global(self.placement, sbp=flow.sbp.broadcast)
    else:
        if isinstance(value, (int, float)):
            value = flow._C.constant([1], value, dtype=self.dtype, device=self.device)
        else:
            value = value.to(device=self.device)

    flow._C.tensor_setitem(self, key, value)
    return self

def RegisterMethods():
    Tensor.__setitem__ = _setitem
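
Once RegisterMethods() has run (OneFlow invokes it during import), ordinary subscript assignment exercises the whole chain above. A short end-to-end sketch on a local tensor:

import oneflow as flow

x = flow.zeros(3)
x[1] = 2.5                 # scalar path: flow._C.constant -> tensor_setitem
assert x[1].numpy() == 2.5

x[0:2] = flow.ones(2)      # tensor path: moved to x's device, then slice update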