...
Code Block | ||||
---|---|---|---|---|
| ||||
MXNET_REGISTER_API("_npi.zeros") .set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { // Part1: populate NodeAttrs // Part1: populate NodeAttrs using namespace runtime; const nnvm::Op* op = Op::Get("_npi_zeros"); nnvm::NodeAttrs attrs; op::InitOpParam param; if (args[0].type_code() == kDLInt) { param.shape = TShape(1, args[0].operator int64_t()); } else { param.shape = TShape(args[0].operator ObjectRef()); } if (args[1].type_code() == kNull) { param.dtype = mshadow::kFloat32; } else { param.dtype = String2MXNetTypeWithBool(args[1].operator std::string()); } attrs.parsed = std::move(param); attrs.op = op; SetAttrDict<op::InitOpParam>(&attrs); if (args[2].type_code() != kNull) { attrs.dict["ctx"] = args[2].operator std::string(); } // Part2: invoke int num_outputs = 0; auto ndoutputs = Invoke(op, &attrs, 0, nullptr, &num_outputs, nullptr); *ret = ndoutputs[0]; }); |
...