You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 
 

152 lines
3.2 KiB

// Oh boy, why am I about to do this....
#pragma once
#ifndef NETWORK_H
#define NETWORK_H
#include "define_inc.h"
#include "darknet.h"
#include "image.h"
#include "layer.h"
#include "data.h"
#include "tree.h"
/*
 * Learning-rate schedule selector; the value is stored in `network.policy`.
 * Each enumerator carries an explicit value so the integer encoding (0..6)
 * stays fixed even if the list is later reordered -- relevant in case any
 * code stores or compares the raw integer.
 */
typedef enum {
    CONSTANT = 0,
    STEP     = 1,
    EXP      = 2,
    POLY     = 3,
    STEPS    = 4,
    SIG      = 5,
    RANDOM   = 6
} learning_rate_policy;
/*
 * Top-level network object: the layer array together with the training
 * hyper-parameters, input settings and buffers that the `network *`
 * functions declared in this header operate on.
 *
 * NOTE(review): field order fixes the struct layout. Every translation unit
 * must agree on it (and any other definition of `struct network`, e.g. in
 * darknet.h -- TODO confirm there is none), so do not reorder or insert
 * fields casually.
 */
typedef struct network {
int n; /* presumably the number of entries in `layers` -- TODO confirm */
int batch;
size_t *seen; /* held by pointer, so presumably shared and updated in place -- confirm */
int *t;
float epoch;
int subdivisions;
layer *layers; /* layer array; length presumably `n` */
float *output;
/* Learning-rate schedule: `policy` presumably selects which of the
 * following parameters are consulted -- confirm against get_current_rate(). */
learning_rate_policy policy;
float learning_rate;
float momentum;
float decay;
float gamma;
float scale;
float power;
int time_steps;
int step;
int max_batches;
float *scales; /* presumably paired element-wise with `steps` -- confirm */
int *steps;
int num_steps; /* presumably the length of `steps` / `scales` -- confirm */
int burn_in;
/* Presumably an Adam-optimizer switch and its beta1/beta2/epsilon -- confirm. */
int adam;
float B1;
float B2;
float eps;
/* Presumably input/output sizing (h, w, c: input height/width/channels) -- confirm. */
int inputs;
int outputs;
int truths;
int notruth;
int h, w, c;
/* Presumably input augmentation / preprocessing settings -- confirm. */
int max_crop;
int min_crop;
float max_ratio;
float min_ratio;
int center;
float angle;
float aspect;
float exposure;
float saturation;
float hue;
int random;
int gpu_index;
tree *hierarchy;
/* Presumably host-side working buffers and training state -- confirm. */
float *input;
float *truth;
float *delta;
float *workspace;
int train;
int index;
float *cost;
float clip;
#ifdef GPU
/* These members exist only when GPU is defined, so sizeof(network) differs
 * between GPU and CPU builds. Presumably device-side counterparts of
 * input/truth/delta/output -- confirm. */
float *input_gpu;
float *truth_gpu;
float *delta_gpu;
float *output_gpu;
#endif
} network;
/*
 * Network-level API. Implementations live in other translation units; only
 * the prototypes are declared here. Each function is declared exactly once
 * so the declarations cannot drift apart.
 */
#ifdef GPU
/* Only declared in builds where GPU is defined. */
void pull_network_output(network *net);
#endif
void compare_networks(network *n1, network *n2, data d);
char *get_layer_string(LAYER_TYPE a);
network *make_network(int n);
void forward_network(network *net);
void backward_network(network *net);
void update_network(network *net);
network *load_network(char *cfg, char *weights, int clear);
load_args get_base_args(network *net);
float network_accuracy_multi(network *net, data d, int n);
int get_predicted_class_network(network *net);
void print_network(network *net);
int resize_network(network *net, int w, int h);
void calc_network_cost(network *net);
float train_network_sgd(network *net, data d, int n);
float *network_accuracies(network *net, data d, int n);
float train_network_datum(network *net);
network *parse_network_cfg(char *filename);
void save_weights(network *net, char *filename);
void load_weights(network *net, char *filename);
void save_weights_upto(network *net, char *filename, int cutoff);
void load_weights_upto(network *net, char *filename, int start, int cutoff);
void free_network(network *net);
void set_batch_network(network *net, int b);
void set_temp_network(network *net, float t);
float get_current_rate(network *net);
size_t get_current_batch(network *net);
image get_network_image_layer(network *net, int i);
layer get_network_output_layer(network *net);
void top_predictions(network *net, int n, int *index);
float network_accuracy(network *net, data d);
void visualize_network(network *net);
matrix network_predict_data(network *net, data test);
image get_network_image(network *net);
float *network_predict(network *net, float *input);
int network_width(network *net);
int network_height(network *net);
float *network_predict_image(network *net, image im);
detection *get_network_boxes(network *net, int w, int h, float thresh, float hier, int *map, int relative, int *num);
void reset_network_state(network *net, int b);
float train_network(network *net, data d);
#endif