aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorGuo, Yejun <yejun.guo@intel.com>2019-04-25 10:14:08 +0800
committerPedro Arthur <bygrandao@gmail.com>2019-05-08 12:33:00 -0300
commit05f86f05bb5060492dd3ff22c23628e4e4334a1e (patch)
treef69992eccf80d239053bfb9c4c020ae65f504213
parent05aec8bb13cc1d698f76c6972a23521a3fba5596 (diff)
downloadffmpeg-05f86f05bb5060492dd3ff22c23628e4e4334a1e.tar.gz
libavfilter/dnn: remove limit for the name of DNN model input/output
remove the requirment that the name of DNN model input/output should be "x"/"y", Signed-off-by: Guo, Yejun <yejun.guo@intel.com> Signed-off-by: Pedro Arthur <bygrandao@gmail.com>
-rw-r--r--libavfilter/dnn_backend_native.c2
-rw-r--r--libavfilter/dnn_backend_tf.c10
-rw-r--r--libavfilter/dnn_interface.h2
-rw-r--r--libavfilter/vf_sr.c4
4 files changed, 9 insertions, 9 deletions
diff --git a/libavfilter/dnn_backend_native.c b/libavfilter/dnn_backend_native.c
index 70d857f5f2..fe4311693a 100644
--- a/libavfilter/dnn_backend_native.c
+++ b/libavfilter/dnn_backend_native.c
@@ -25,7 +25,7 @@
#include "dnn_backend_native.h"
-static DNNReturnType set_input_output_native(void *model, DNNData *input, DNNData *output)
+static DNNReturnType set_input_output_native(void *model, DNNData *input, const char *input_name, DNNData *output, const char *output_name)
{
ConvolutionalNetwork *network = (ConvolutionalNetwork *)model;
InputParams *input_params;
diff --git a/libavfilter/dnn_backend_tf.c b/libavfilter/dnn_backend_tf.c
index 9e0c127e77..a838907d98 100644
--- a/libavfilter/dnn_backend_tf.c
+++ b/libavfilter/dnn_backend_tf.c
@@ -76,7 +76,7 @@ static TF_Buffer *read_graph(const char *model_filename)
return graph_buf;
}
-static DNNReturnType set_input_output_tf(void *model, DNNData *input, DNNData *output)
+static DNNReturnType set_input_output_tf(void *model, DNNData *input, const char *input_name, DNNData *output, const char *output_name)
{
TFModel *tf_model = (TFModel *)model;
int64_t input_dims[] = {1, input->height, input->width, input->channels};
@@ -84,8 +84,8 @@ static DNNReturnType set_input_output_tf(void *model, DNNData *input, DNNData *o
const TF_Operation *init_op = TF_GraphOperationByName(tf_model->graph, "init");
TF_Tensor *output_tensor;
- // Input operation should be named 'x'
- tf_model->input.oper = TF_GraphOperationByName(tf_model->graph, "x");
+ // Input operation
+ tf_model->input.oper = TF_GraphOperationByName(tf_model->graph, input_name);
if (!tf_model->input.oper){
return DNN_ERROR;
}
@@ -100,8 +100,8 @@ static DNNReturnType set_input_output_tf(void *model, DNNData *input, DNNData *o
}
input->data = (float *)TF_TensorData(tf_model->input_tensor);
- // Output operation should be named 'y'
- tf_model->output.oper = TF_GraphOperationByName(tf_model->graph, "y");
+ // Output operation
+ tf_model->output.oper = TF_GraphOperationByName(tf_model->graph, output_name);
if (!tf_model->output.oper){
return DNN_ERROR;
}
diff --git a/libavfilter/dnn_interface.h b/libavfilter/dnn_interface.h
index e3673438b6..0390e39b99 100644
--- a/libavfilter/dnn_interface.h
+++ b/libavfilter/dnn_interface.h
@@ -40,7 +40,7 @@ typedef struct DNNModel{
void *model;
// Sets model input and output, while allocating additional memory for intermediate calculations.
// Should be called at least once before model execution.
- DNNReturnType (*set_input_output)(void *model, DNNData *input, DNNData *output);
+ DNNReturnType (*set_input_output)(void *model, DNNData *input, const char *input_name, DNNData *output, const char *output_name);
} DNNModel;
// Stores pointers to functions for loading, executing, freeing DNN models for one of the backends.
diff --git a/libavfilter/vf_sr.c b/libavfilter/vf_sr.c
index 728c440299..0c048e03a5 100644
--- a/libavfilter/vf_sr.c
+++ b/libavfilter/vf_sr.c
@@ -121,7 +121,7 @@ static int config_props(AVFilterLink *inlink)
sr_context->input.height = inlink->h * sr_context->scale_factor;
sr_context->input.channels = 1;
- result = (sr_context->model->set_input_output)(sr_context->model->model, &sr_context->input, &sr_context->output);
+ result = (sr_context->model->set_input_output)(sr_context->model->model, &sr_context->input, "x", &sr_context->output, "y");
if (result != DNN_SUCCESS){
av_log(context, AV_LOG_ERROR, "could not set input and output for the model\n");
return AVERROR(EIO);
@@ -130,7 +130,7 @@ static int config_props(AVFilterLink *inlink)
if (sr_context->input.height != sr_context->output.height || sr_context->input.width != sr_context->output.width){
sr_context->input.width = inlink->w;
sr_context->input.height = inlink->h;
- result = (sr_context->model->set_input_output)(sr_context->model->model, &sr_context->input, &sr_context->output);
+ result = (sr_context->model->set_input_output)(sr_context->model->model, &sr_context->input, "x", &sr_context->output, "y");
if (result != DNN_SUCCESS){
av_log(context, AV_LOG_ERROR, "could not set input and output for the model\n");
return AVERROR(EIO);