00001
00002
00003 #ifndef NB_H
00004 #define NB_H
00005
00006 #include "../hash/nps_vocab.h"
00007 #include "../imat/imat_csr.h"
00008 #include "../error/error.h"
00009
00010
00011
00012
/*
 * Count accumulator filled by nb_add_counts and consumed by nb_ceate_model.
 * Allocated with nb_counts_new and released with nb_counts_del.
 *
 * NOTE(review): the meaning/layout of the two tables is not fixed by this
 * header; the descriptions below are inferred from the field names and
 * should be confirmed against the implementation.
 */
typedef struct {
    long nCat;                /* presumably the number of categories        */
    long nFeat;               /* presumably the number of features          */
    double *restrict cat;     /* per-category totals (assumed nCat entries) */
    double *restrict cat_feat;/* category x feature totals (assumed
                                 nCat * nFeat entries) — confirm layout     */
} nb_counts_t;
00019
00020 extern nb_counts_t *nb_counts_new(long,long);
00021 extern void nb_counts_del(nb_counts_t *);
00022
00023
00024 extern void nb_add_counts(nb_counts_t *, nps_imat_csr_t *);
00025
00026
/*
 * Smoothing schemes. Laplace is pinned to 0 explicitly so its value stays
 * stable if further schemes are appended later.
 * NOTE(review): presumably add-one smoothing, selected through the `int`
 * argument of nb_ceate_model — confirm in the implementation.
 */
enum nb_smooth_t {
    Laplace = 0
};
00028
00029
00030 typedef struct {
00031 long nCat;
00032 long nFeat;
00033 bio_t *cat;
00034 bio_t *cat_feat;
00035 } nb_model_t;
00036
00037 extern long nb_model_new(long,long);
00038 extern void nb_model_del(nb_model_t *);
00039 extern nb_model_t *nb_ceate_model(nb_counts_t *,int);
00040 extern long nb_classify(nb_model_t *, char, s8_t *, double *, long, double *);
00041
00042 extern int nb_model_save(nb_model_t *, NPS_VOC_T * , NPS_VOC_T * , NPS_VOC_T * ,const char*, nps_error_t *err);
00043 extern nb_model_t* nb_model_load(nb_model_t **, NPS_VOC_T **, NPS_VOC_T ** ,
00044 const char *,
00045 char ignore_marginals,
00046 nps_error_t *err);
00047
00048 #endif