#pragma once

#include <ATen/ATen.h>
#include <ATen/core/ivalue.h>
#include <ATen/core/List.h>
#include <ATen/native/quantized/cpu/fbgemm_utils.h>
#include <ATen/native/quantized/cpu/QnnpackUtils.h>
#include <ATen/native/quantized/cpu/OnednnUtils.h>
#include <c10/util/irange.h>

#include <tuple>

/* Convolution prepacked parameters serialization.
 *
 * Version 1
 *
 * - Fields:
 *  1. weight
 *  2. bias
 *  3. stride x kSpatialDim
 *  4. padding x kSpatialDim
 *  5. dilation x kSpatialDim
 *  6. groups
 *
 * Version 2
 *
 * - Fields:
 *  0. version (string)
 *  1. list of non-optional tensors
 *    0: packed parameters (int16_t)
 *      - kSpatialDim
 *      - stride x kSpatialDim
 *      - padding x kSpatialDim
 *      - dilation x kSpatialDim
 *      - output_padding x kSpatialDim
 *      - groups
 *      - transpose (0 or 1)
 *    1: weight
 *  2. list of optional tensors
 *    0: bias
 *
 * Version 3
 *
 * - Fields:
 *  0. version (int64_t)
 *  1. list of int64_t configuration values
 *    - kSpatialDim
 *    - stride x kSpatialDim
 *    - padding x kSpatialDim
 *    - dilation x kSpatialDim
 *    - output_padding x kSpatialDim
 *    - groups
 *    - flags (bitmask)
 *      - (1 << 0) transpose (1 = yes)
 *  2. list of optional tensors
 *    0: None (helps with type inference)
 *    1: weight (this must be present)
 *    2: bias
 */

using ConvParamsSerializationTypeV2 = std::tuple<
    // version, for versions 2 and up
    std::string,
    // non-optional tensors
    std::vector<at::Tensor>,
    // optional tensors
    std::vector<c10::optional<at::Tensor>>>;

using ConvParamsSerializationTypeV3 = std::tuple<
    // version, int for versions 3 and up
    int64_t,
    // configuration values
    std::vector<int64_t>,
    // optional tensors
    std::vector<c10::optional<at::Tensor>>>;

// Parses any historical conv packed params format into
// the current format.
template <uint32_t kSpatialDim>
ConvParamsSerializationTypeV3 parse_conv_serialized_state(c10::IValue v) {

  // determine the version based on IValue contents
  int version = -1;
  if (v.isTuple()) {
    const auto& elements = v.toTupleRef().elements();
    if (elements.size() > 0) {
      auto firstElement = elements[0];
      if (firstElement.isTensor()) {
        version = 1;
      } else if (firstElement.isString()) {
        std::string version_str = firstElement.toStringRef();
        // note: not parsing the string to automatically handle bad
        // inputs
        if (version_str == "2") {
          version = 2;
        }
      } else if (firstElement.isInt()) {
        auto raw_version = firstElement.toInt();
        if (raw_version == 3) {
          version = 3;
        }
      }
    }
  }
  TORCH_INTERNAL_ASSERT(version != -1, "Unable to parse serialization version");

  if (version == 1) {
    // version 1 - convert to version 3 manually

    const auto& elements = v.toTupleRef().elements();

    at::Tensor weight = elements[0].toTensor();
    c10::optional<at::Tensor> bias = elements[1].toOptional<at::Tensor>();
    torch::List<at::Tensor> stride_x_kSpatialDim = elements[2].toTensorList();
    torch::List<at::Tensor> padding_x_kSpatialDim = elements[3].toTensorList();
    torch::List<at::Tensor> dilation_x_kSpatialDim = elements[4].toTensorList();
    at::Tensor groups = elements[5].toTensor();

    std::vector<at::Tensor> non_optional;
    std::vector<c10::optional<at::Tensor>> optional;

    std::vector<int64_t> config_vals;
    config_vals.push_back(kSpatialDim);
    for (const auto i : c10::irange(stride_x_kSpatialDim.size())) {
      auto stride = stride_x_kSpatialDim.get(i);
      config_vals.push_back(stride[0].item<int64_t>());
    }
    for (const auto i : c10::irange(padding_x_kSpatialDim.size())) {
      auto padding = padding_x_kSpatialDim.get(i);
      config_vals.push_back(padding[0].item<int64_t>());
    }
    for (const auto i : c10::irange(dilation_x_kSpatialDim.size())) {
      auto dilation = dilation_x_kSpatialDim.get(i);
      config_vals.push_back(dilation[0].item<int64_t>());
    }
    // output_padding does not exist in v1, so we fill in a default value
    for (const auto i : c10::irange(kSpatialDim)) {
      (void)i; // Suppress unused variable warning
      config_vals.push_back(0);
    }
    config_vals.push_back(groups[0].item<int64_t>());
    // transpose does not exist in v1, so we fill in a default value
    config_vals.push_back(0);

    std::vector<c10::optional<at::Tensor>> tensors;
    tensors.emplace_back();
    tensors.emplace_back(weight);
    tensors.emplace_back(bias);

    int64_t version = 3;
    return std::tie(version, config_vals, tensors);
  } else if (version == 2) {
    // version 2
    const auto& elements = v.toTupleRef().elements();
    std::vector<at::Tensor> non_optional = elements[1].toTensorList().vec();
    std::vector<c10::optional<at::Tensor>> optional;

    if (elements[2].isTensorList()) {
      for (const auto& elem : elements[2].toTensorList()) {
        optional.emplace_back(static_cast<at::Tensor>(elem));
      }
    } else {
      for (const auto& elem : elements[2].toList()) {
        optional.emplace_back(
            static_cast<c10::IValue>(elem).toOptional<at::Tensor>());
      }
    }

    auto config_a = non_optional[0].accessor<int16_t, 1>();
    std::vector<int64_t> config_vals;
    config_vals.reserve(config_a.size(0));
    for (const auto i : c10::irange(config_a.size(0))) {
      config_vals.emplace_back(config_a[i]);
    }

    auto weight = non_optional[1];
    auto bias = optional[0];

    std::vector<c10::optional<at::Tensor>> tensors;
    tensors.emplace_back();
    tensors.emplace_back(weight);
    tensors.emplace_back(bias);

    int64_t version = 3;
    return std::tie(version, config_vals, tensors);
  } else if (version == 3) {
    return v.to<ConvParamsSerializationTypeV3>();
  } else {
    TORCH_INTERNAL_ASSERT(false, "Unexpected serialized qconv version: ",
        version);
  }
}
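// Illustrative sketch: a hypothetical __setstate__ path for a quantized
// Conv2d module might route any archived payload through the parser above
// before repacking. The names `archived` and `repacked` are placeholders for
// this example only; deserialize_conv() is defined further down in this
// header.
//
//   c10::IValue archived = ...;  // tuple produced by a v1, v2, or v3 archive
//   // Normalize the payload into the current (v3) in-memory layout.
//   ConvParamsSerializationTypeV3 state =
//       parse_conv_serialized_state<2>(archived);
//   // Repack the weights for whichever quantization engine is active.
//   c10::intrusive_ptr<ConvPackedParamsBase<2>> repacked =
//       deserialize_conv<2>(state);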
#define QCONV_SERIALIZATION_VERSION 2

#if QCONV_SERIALIZATION_VERSION == 2
using ConvParamsSerializationType = ConvParamsSerializationTypeV2;

template <uint32_t kSpatialDim>
ConvParamsSerializationTypeV2 serialize_conv(
    const c10::intrusive_ptr<ConvPackedParamsBase<kSpatialDim>>& params) {

  std::string version = "2";
  std::vector<at::Tensor> non_optional;
  std::vector<c10::optional<at::Tensor>> optional;

  // create a packed int16_t tensor for conv params
  std::vector<int16_t> params_vec;
  params_vec.push_back(kSpatialDim);
  auto stride = params->stride().vec();
  params_vec.insert(params_vec.end(), stride.begin(), stride.end());
  auto padding = params->padding().vec();
  params_vec.insert(params_vec.end(), padding.begin(), padding.end());
  auto dilation = params->dilation().vec();
  params_vec.insert(params_vec.end(), dilation.begin(), dilation.end());
  auto output_padding = params->output_padding().vec();
  params_vec.insert(params_vec.end(), output_padding.begin(),
                    output_padding.end());
  params_vec.push_back(params->groups());
  params_vec.push_back(params->transpose());
  int64_t vec_size = params_vec.size();
  at::Tensor params_tensor = at::from_blob(
      params_vec.data(), {vec_size},
      at::TensorOptions().dtype(at::kShort))
      // clone to retain ownership of the data
      .clone();

  at::Tensor weight;
  c10::optional<at::Tensor> bias;
  std::tie(weight, bias) = params->unpack();

  non_optional.emplace_back(std::move(params_tensor));
  non_optional.emplace_back(std::move(weight));
  optional.emplace_back(std::move(bias));

  return std::tie(version, non_optional, optional);
}

#elif QCONV_SERIALIZATION_VERSION == 3
using ConvParamsSerializationType = ConvParamsSerializationTypeV3;

template <uint32_t kSpatialDim>
ConvParamsSerializationTypeV3 serialize_conv(
    const c10::intrusive_ptr<ConvPackedParamsBase<kSpatialDim>>& params) {
  std::vector<int64_t> config_vals;
  config_vals.push_back(kSpatialDim);
  auto stride = params->stride().vec();
  config_vals.insert(config_vals.end(), stride.begin(), stride.end());
  auto padding = params->padding().vec();
  config_vals.insert(config_vals.end(), padding.begin(), padding.end());
  auto dilation = params->dilation().vec();
  config_vals.insert(config_vals.end(), dilation.begin(), dilation.end());
  auto output_padding = params->output_padding().vec();
  config_vals.insert(config_vals.end(), output_padding.begin(),
                     output_padding.end());
  config_vals.push_back(params->groups());
  config_vals.push_back(params->transpose());

  at::Tensor weight;
  c10::optional<at::Tensor> bias;
  std::tie(weight, bias) = params->unpack();

  std::vector<c10::optional<at::Tensor>> tensors;
  tensors.emplace_back();
  tensors.emplace_back(weight);
  tensors.emplace_back(bias);

  int64_t version = 3;
  return std::tie(version, config_vals, tensors);
}

#else
#error "Invalid qconv serialization version."
#endif
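// Worked example (values chosen only for illustration): with
// QCONV_SERIALIZATION_VERSION == 2, a non-transposed Conv2d with
// stride {2, 2}, padding {1, 1}, dilation {1, 1}, output_padding {0, 0} and
// groups = 1 serializes to
//
//   ("2",                            // version string
//    [int16 tensor {2,               //   kSpatialDim
//                   2, 2,            //   stride
//                   1, 1,            //   padding
//                   1, 1,            //   dilation
//                   0, 0,            //   output_padding
//                   1,               //   groups
//                   0},              //   transpose
//     weight],                       // quantized weight
//    [bias])                         // optional bias
//
// Under version 3 the same integers go directly into config_vals, with the
// final entry interpreted as a flags word whose bit (1 << 0) is transpose,
// and the optional-tensor list holds {None, weight, bias}.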
template <uint32_t kSpatialDim>
c10::intrusive_ptr<ConvPackedParamsBase<kSpatialDim>> deserialize_conv(
    ConvParamsSerializationTypeV3 state) {
  int64_t version;
  std::vector<int64_t> config_vals;
  std::vector<c10::optional<at::Tensor>> tensors;

  std::tie(version, config_vals, tensors) = state;
  TORCH_INTERNAL_ASSERT(version == 3, "Unexpected serialized qconv version: ",
      version);

  TORCH_CHECK(tensors.size() == 3, "Wrong number of tensors: ", tensors.size());
  c10::optional<at::Tensor> weight = tensors[1];
  c10::optional<at::Tensor> bias = tensors[2];
  TORCH_INTERNAL_ASSERT(weight,
      "Weight should always be present in serialized qconv.");

  torch::List<int64_t> stride, padding, output_padding, dilation;
  // skip kSpatialDim
  int idx = 1;
  for (const auto i : c10::irange(kSpatialDim)) {
    (void)i; // Suppress unused variable warning
    stride.emplace_back(config_vals.at(idx));
    idx++;
  }
  for (const auto i : c10::irange(kSpatialDim)) {
    (void)i; // Suppress unused variable warning
    padding.emplace_back(config_vals.at(idx));
    idx++;
  }
  for (const auto i : c10::irange(kSpatialDim)) {
    (void)i; // Suppress unused variable warning
    dilation.emplace_back(config_vals.at(idx));
    idx++;
  }
  for (const auto i : c10::irange(kSpatialDim)) {
    (void)i; // Suppress unused variable warning
    TORCH_INTERNAL_ASSERT(idx < static_cast<int64_t>(config_vals.size()),
        "Unexpected index = ", idx, " for config_vals of size ",
        config_vals.size());
    output_padding.emplace_back(config_vals.at(idx));
    idx++;
  }
  int64_t groups = config_vals.at(idx);
  idx++;
  int64_t flags = config_vals.at(idx);
  idx++;
  TORCH_INTERNAL_ASSERT(idx == static_cast<int64_t>(config_vals.size()),
      "Unexpected length of config_vals, expected ",
      idx,
      " got ",
      config_vals.size());

  bool transpose = flags & (1 << 0);

  int64_t other_flags = flags & ~(1 << 0);
  TORCH_INTERNAL_ASSERT(other_flags == 0,
      "Unexpected flags set in ", flags, ".");

  auto& ctx = at::globalContext();

#ifdef USE_FBGEMM
  if (ctx.qEngine() == at::QEngine::X86) {
#if AT_MKLDNN_ENABLED()
    bool use_onednn = onednn_utils::should_use_onednn_quant(
        weight.value(), transpose, groups, output_padding);
    if (use_onednn) {
      return PackedConvWeightsOnednn<kSpatialDim>::prepack(
          weight.value(),
          bias,
          stride,
          padding,
          output_padding,
          dilation,
          groups,
          transpose);
    }
#endif
    return PackedConvWeight<kSpatialDim>::prepack(
        weight.value(),
        bias,
        stride,
        padding,
        output_padding,
        dilation,
        groups,
        transpose);
  } // x86
#endif

#ifdef USE_FBGEMM
  if (ctx.qEngine() == at::QEngine::FBGEMM) {
    return PackedConvWeight<kSpatialDim>::prepack(
        weight.value(),
        bias,
        stride,
        padding,
        output_padding,
        dilation,
        groups,
        transpose);
  }
#endif // USE_FBGEMM

#ifdef USE_PYTORCH_QNNPACK
  if (ctx.qEngine() == at::QEngine::QNNPACK) {
    TORCH_CHECK(
        kSpatialDim == 2,
        "prepack/__setstate__: QNNPACK only supports Conv2d "
        "now.");
    return PackedConvWeightsQnnp<kSpatialDim>::prepack(
        weight.value(),
        bias,
        stride,
        padding,
        output_padding,
        dilation,
        groups,
        transpose);
  }
#endif // USE_PYTORCH_QNNPACK

#if AT_MKLDNN_ENABLED()
  if (ctx.qEngine() == at::QEngine::ONEDNN) {
    return PackedConvWeightsOnednn<kSpatialDim>::prepack(
        weight.value(),
        bias,
        stride,
        padding,
        output_padding,
        dilation,
        groups,
        transpose);
  }
#endif // AT_MKLDNN_ENABLED()

  TORCH_CHECK(
      false,
      "Didn't find engine when deserializing ConvPackedParams: ",
      toString(ctx.qEngine()));
}
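// Usage note (illustrative): which prepack() branch above runs is decided by
// the global context at load time, not by anything stored in the archive. A
// hypothetical caller that needs QNNPACK packing would therefore select the
// engine before deserializing, e.g.
//
//   at::globalContext().setQEngine(at::QEngine::QNNPACK);
//   auto packed = deserialize_conv<2>(state);  // packs via
//                                              // PackedConvWeightsQnnp<2>
//
// where `state` is a ConvParamsSerializationTypeV3 obtained from
// parse_conv_serialized_state().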