00001
00043
00044
00045
00046
00047
00048
00049
00050 #include <sent/stddefs.h>
00051 #include <sent/htk_hmm.h>
00052 #include <sent/htk_param.h>
00053 #include <sent/hmm.h>
00054 #include <sent/gprune.h>
00055 #include "globalvars.h"
00056
00057
00058
00059
00060
00061
00062
00063
00064
00065
00066
00067
00068
00069
00070
00071
00072
00073
00074
00075
00076
00077
00078
00079
00080
00081
00082
00083
00084
00085
00086
00087
00088
00089
00090
00091
00092
00093
00094
00095
00096
00097
00098
00099
00100
00101
00102
00103
00104
00105
00106 static LOGPROB *backmax;
00107 static int backmax_num;
00108
00109 static boolean *mixcalced;
00110
00115 static void
00116 init_backmax()
00117 {
00118 int i;
00119 for(i=0;i<backmax_num;i++) backmax[i] = 0;
00120 }
00121
00127
00128
00129
00130
00131
00132
00133
00134
00135 static void
00136 make_backmax()
00137 {
00138 int i;
00139 backmax[backmax_num-1] = 0.0;
00140
00141 for(i=backmax_num-2;i>=0;i--) {
00142 backmax[i] += backmax[i+1];
00143 }
00144
00145
00146
00147 }
00148
00161 static LOGPROB
00162 compute_g_heu_updating(HTK_HMM_Dens *binfo)
00163 {
00164 VECT tmp, x, sum = 0.0;
00165 VECT *mean;
00166 VECT *var;
00167 VECT *bm = backmax;
00168 VECT *vec = OP_vec;
00169 short veclen = OP_veclen;
00170
00171 if (binfo == NULL) return(LOG_ZERO);
00172 mean = binfo->mean;
00173 var = binfo->var->vec;
00174
00175 tmp = 0.0;
00176 for (; veclen > 0; veclen--) {
00177 x = *(vec++) - *(mean++);
00178 tmp = x * x * *(var++);
00179 sum += tmp;
00180 if ( *bm < tmp) *bm = tmp;
00181 bm++;
00182 }
00183 return((sum + binfo->gconst) * -0.5);
00184 }
00185
00199 static LOGPROB
00200 compute_g_heu_pruning(HTK_HMM_Dens *binfo, LOGPROB thres)
00201 {
00202 VECT tmp, x;
00203 VECT *mean;
00204 VECT *var;
00205 VECT *bm = backmax;
00206 VECT *vec = OP_vec;
00207 short veclen = OP_veclen;
00208 LOGPROB fthres;
00209
00210 if (binfo == NULL) return(LOG_ZERO);
00211 mean = binfo->mean;
00212 var = binfo->var->vec;
00213 fthres = thres * (-2.0);
00214
00215 tmp = 0.0;
00216 bm++;
00217 for (; veclen > 0; veclen--) {
00218 x = *(vec++) - *(mean++);
00219 tmp += x * x * *(var++);
00220 if ( tmp + *bm > fthres) {
00221 return LOG_ZERO;
00222 }
00223 bm++;
00224 }
00225 return((tmp + binfo->gconst) * -0.5);
00226 }
00227
00228
00234 boolean
00235 gprune_heu_init()
00236 {
00237 int i;
00238
00239 OP_calced_maxnum = OP_hmminfo->maxmixturenum;
00240 OP_calced_score = (LOGPROB *)mymalloc(sizeof(LOGPROB) * OP_gprune_num);
00241 OP_calced_id = (int *)mymalloc(sizeof(int) * OP_gprune_num);
00242 mixcalced = (boolean *)mymalloc(sizeof(int) * OP_calced_maxnum);
00243 for(i=0;i<OP_calced_maxnum;i++) mixcalced[i] = FALSE;
00244 backmax_num = OP_hmminfo->opt.vec_size + 1;
00245 backmax = (LOGPROB *)mymalloc(sizeof(LOGPROB) * backmax_num);
00246
00247 return TRUE;
00248 }
00249
00254 void
00255 gprune_heu_free()
00256 {
00257 free(OP_calced_score);
00258 free(OP_calced_id);
00259 free(mixcalced);
00260 free(backmax);
00261 }
00262
00287 void
00288 gprune_heu(HTK_HMM_Dens **g, int gnum, int *last_id)
00289 {
00290 int i, j, num = 0;
00291 LOGPROB score, thres;
00292
00293 if (last_id != NULL) {
00294
00295 init_backmax();
00296
00297 for (j=0; j<OP_gprune_num; j++) {
00298 i = last_id[j];
00299 score = compute_g_heu_updating(g[i]);
00300 num = cache_push(i, score, num);
00301 mixcalced[i] = TRUE;
00302 }
00303
00304 make_backmax();
00305
00306 thres = OP_calced_score[num-1];
00307 for (i = 0; i < gnum; i++) {
00308
00309 if (mixcalced[i]) {
00310 mixcalced[i] = FALSE;
00311 continue;
00312 }
00313
00314 score = compute_g_heu_pruning(g[i], thres);
00315 if (score > LOG_ZERO) {
00316 num = cache_push(i, score, num);
00317 thres = OP_calced_score[num-1];
00318 }
00319 }
00320 } else {
00321
00322
00323 thres = LOG_ZERO;
00324 for (i = 0; i < gnum; i++) {
00325 if (num < OP_gprune_num) {
00326 score = compute_g_base(g[i]);
00327 } else {
00328 score = compute_g_safe(g[i], thres);
00329 if (score <= thres) continue;
00330 }
00331 num = cache_push(i, score, num);
00332 thres = OP_calced_score[num-1];
00333 }
00334 }
00335 OP_calced_num = num;
00336 }