diff options
author | Guo, Yejun <yejun.guo@intel.com> | 2019-04-25 10:14:42 +0800 |
---|---|---|
committer | Pedro Arthur <bygrandao@gmail.com> | 2019-05-08 12:33:00 -0300 |
commit | c636dc9819ebab1a84237cc017a6a3d35ebc9cdc (patch) | |
tree | 39fd943e649cb1185f25ccce6e7be193448ba23c | |
parent | 25c1cd909fa6c8b6b778dc24192dc3ec780324b0 (diff) | |
download | ffmpeg-c636dc9819ebab1a84237cc017a6a3d35ebc9cdc.tar.gz |
libavfilter/dnn: add more data type support for dnn model input
currently, only float is supported as model input, actually, there
are other data types, this patch adds uint8.
Signed-off-by: Guo, Yejun <yejun.guo@intel.com>
Signed-off-by: Pedro Arthur <bygrandao@gmail.com>
-rw-r--r-- | libavfilter/dnn_backend_native.c | 4 | ||||
-rw-r--r-- | libavfilter/dnn_backend_tf.c | 28 | ||||
-rw-r--r-- | libavfilter/dnn_interface.h | 10 | ||||
-rw-r--r-- | libavfilter/vf_sr.c | 4 |
4 files changed, 39 insertions, 7 deletions
diff --git a/libavfilter/dnn_backend_native.c b/libavfilter/dnn_backend_native.c index 8a83c63c73..06fbdf368b 100644 --- a/libavfilter/dnn_backend_native.c +++ b/libavfilter/dnn_backend_native.c @@ -24,8 +24,9 @@ */ #include "dnn_backend_native.h" +#include "libavutil/avassert.h" -static DNNReturnType set_input_output_native(void *model, DNNData *input, const char *input_name, const char **output_names, uint32_t nb_output) +static DNNReturnType set_input_output_native(void *model, DNNInputData *input, const char *input_name, const char **output_names, uint32_t nb_output) { ConvolutionalNetwork *network = (ConvolutionalNetwork *)model; InputParams *input_params; @@ -45,6 +46,7 @@ static DNNReturnType set_input_output_native(void *model, DNNData *input, const if (input->data){ av_freep(&input->data); } + av_assert0(input->dt == DNN_FLOAT); network->layers[0].output = input->data = av_malloc(cur_height * cur_width * cur_channels * sizeof(float)); if (!network->layers[0].output){ return DNN_ERROR; diff --git a/libavfilter/dnn_backend_tf.c b/libavfilter/dnn_backend_tf.c index ca6472d445..ba959ae3a2 100644 --- a/libavfilter/dnn_backend_tf.c +++ b/libavfilter/dnn_backend_tf.c @@ -79,10 +79,31 @@ static TF_Buffer *read_graph(const char *model_filename) return graph_buf; } -static DNNReturnType set_input_output_tf(void *model, DNNData *input, const char *input_name, const char **output_names, uint32_t nb_output) +static TF_Tensor *allocate_input_tensor(const DNNInputData *input) { - TFModel *tf_model = (TFModel *)model; + TF_DataType dt; + size_t size; int64_t input_dims[] = {1, input->height, input->width, input->channels}; + switch (input->dt) { + case DNN_FLOAT: + dt = TF_FLOAT; + size = sizeof(float); + break; + case DNN_UINT8: + dt = TF_UINT8; + size = sizeof(char); + break; + default: + av_assert0(!"should not reach here"); + } + + return TF_AllocateTensor(dt, input_dims, 4, + input_dims[1] * input_dims[2] * input_dims[3] * size); +} + +static DNNReturnType set_input_output_tf(void *model, DNNInputData *input, const char *input_name, const char **output_names, uint32_t nb_output) +{ + TFModel *tf_model = (TFModel *)model; TF_SessionOptions *sess_opts; const TF_Operation *init_op = TF_GraphOperationByName(tf_model->graph, "init"); @@ -95,8 +116,7 @@ static DNNReturnType set_input_output_tf(void *model, DNNData *input, const char if (tf_model->input_tensor){ TF_DeleteTensor(tf_model->input_tensor); } - tf_model->input_tensor = TF_AllocateTensor(TF_FLOAT, input_dims, 4, - input_dims[1] * input_dims[2] * input_dims[3] * sizeof(float)); + tf_model->input_tensor = allocate_input_tensor(input); if (!tf_model->input_tensor){ return DNN_ERROR; } diff --git a/libavfilter/dnn_interface.h b/libavfilter/dnn_interface.h index 73d226ec91..c24df0e961 100644 --- a/libavfilter/dnn_interface.h +++ b/libavfilter/dnn_interface.h @@ -32,6 +32,14 @@ typedef enum {DNN_SUCCESS, DNN_ERROR} DNNReturnType; typedef enum {DNN_NATIVE, DNN_TF} DNNBackendType; +typedef enum {DNN_FLOAT, DNN_UINT8} DNNDataType; + +typedef struct DNNInputData{ + void *data; + DNNDataType dt; + int width, height, channels; +} DNNInputData; + typedef struct DNNData{ float *data; int width, height, channels; @@ -42,7 +50,7 @@ typedef struct DNNModel{ void *model; // Sets model input and output. // Should be called at least once before model execution. - DNNReturnType (*set_input_output)(void *model, DNNData *input, const char *input_name, const char **output_names, uint32_t nb_output); + DNNReturnType (*set_input_output)(void *model, DNNInputData *input, const char *input_name, const char **output_names, uint32_t nb_output); } DNNModel; // Stores pointers to functions for loading, executing, freeing DNN models for one of the backends. diff --git a/libavfilter/vf_sr.c b/libavfilter/vf_sr.c index 0145511d11..65baf5f901 100644 --- a/libavfilter/vf_sr.c +++ b/libavfilter/vf_sr.c @@ -40,7 +40,8 @@ typedef struct SRContext { DNNBackendType backend_type; DNNModule *dnn_module; DNNModel *model; - DNNData input, output; + DNNInputData input; + DNNData output; int scale_factor; struct SwsContext *sws_contexts[3]; int sws_slice_h, sws_input_linesize, sws_output_linesize; @@ -86,6 +87,7 @@ static av_cold int init(AVFilterContext *context) return AVERROR(EIO); } + sr_context->input.dt = DNN_FLOAT; sr_context->sws_contexts[0] = NULL; sr_context->sws_contexts[1] = NULL; sr_context->sws_contexts[2] = NULL; |