#pragma once #include #include struct LinearPackedParamsBase : public torch::jit::CustomClassHolder { virtual at::Tensor apply( at::Tensor input, double output_scale, int64_t output_zero_point) = 0; virtual at::Tensor apply_relu( at::Tensor input, double output_scale, int64_t output_zero_point) = 0; // out variant of LinearPackedParamsBase::apply virtual at::Tensor& apply_out( const at::Tensor& /*input*/, double /*output_scale*/, int64_t /*output_zero_point*/, at::Tensor& output) { throw std::runtime_error( "apply_out is not implemented for this packed " "parameter type"); return output; } virtual at::Tensor& apply_relu_out( const at::Tensor& /*input*/, double /*output_scale*/, int64_t /*output_zero_point*/, at::Tensor& output) { throw std::runtime_error( "apply_relu_out is not implemented for this packed " "parameter type"); return output; } virtual at::Tensor apply_dynamic( at::Tensor input, bool reduce_range = false) = 0; virtual at::Tensor apply_dynamic_relu( at::Tensor input, bool reduce_range = false) = 0; virtual at::Tensor& apply_dynamic_out( const at::Tensor& /* input */, at::Tensor& output, bool /* reduce_range */) { throw std::runtime_error( "apply_dynamic_out is not implemented for this packed " "parameter type"); return output; } virtual at::Tensor& apply_dynamic_relu_out( const at::Tensor& /* input */, at::Tensor& output, bool /* reduce_range */) { throw std::runtime_error( "apply_dynamic_relu_out is not implemented for this packed " "parameter type"); return output; } virtual std::tuple> unpack() = 0; virtual c10::optional bias() = 0; virtual void set_bias(c10::optional /*bias*/) { throw std::runtime_error( "set_bias is not implemented for this packed " "parameter type"); } }; template struct ConvPackedParamsBase : public torch::jit::CustomClassHolder { virtual at::Tensor apply( const at::Tensor& input, double output_scale, int64_t output_zero_point) = 0; virtual at::Tensor apply_relu( const at::Tensor& input, double output_scale, int64_t output_zero_point) = 0; virtual at::Tensor apply_dynamic( const at::Tensor& input, bool reduce_range) = 0; virtual std::tuple> unpack() = 0; virtual torch::List stride() const = 0; virtual torch::List padding() const = 0; virtual torch::List output_padding() const = 0; virtual torch::List dilation() const = 0; virtual int64_t groups() const = 0; virtual bool transpose() const = 0; };