00001
00017
00018
00019
00020
00021
00022
00023
00024 #include <sent/stddefs.h>
00025 #include <sent/htk_hmm.h>
00026 #include <sent/htk_param.h>
00027 #include <sent/hmm.h>
00028 #include <sent/gprune.h>
00029 #include "globalvars.h"
00030
00031
/*
 * Compile-time selection of the scoring strategy used by compute_gs_scores().
 *
 *   GS_MAX_PROB : score a state by its single best mixture component.
 *   LAST_BEST   : seed that search with the component that won on the
 *                 previous call for the same GS set (kept in last_max_id[]).
 *   BEAM        : per-dimension beam pruning around the seed component,
 *                 widened by BEAM_OFFSET.  Unlike the safe pruning it may
 *                 discard the true best component, so it is disabled here.
 */
#define GS_MAX_PROB
#define LAST_BEST
#undef BEAM
#define BEAM_OFFSET 10.0

/* Option dependencies: BEAM needs LAST_BEST, and LAST_BEST needs GS_MAX_PROB. */
#if defined(BEAM) && !defined(LAST_BEST)
#define LAST_BEST
#endif
#if defined(LAST_BEST) && !defined(GS_MAX_PROB)
#define GS_MAX_PROB
#endif

/* Number of GS sets the per-set cache below was sized for. */
static int my_gsset_num;
/* Per-set index of the best component found on the last call; -1 = no hint. */
static int *last_max_id;
#ifdef BEAM
/* Per-dimension thresholds on the accumulated weighted squared distance. */
static VECT *dimthres;
/* Length of dimthres[] (the feature vector size). */
static int dimthres_num;
#endif
00051
00052
00059 void
00060 gms_gprune_init(HTK_HMM_INFO *hmminfo, int gsset_num)
00061 {
00062 my_gsset_num = gsset_num;
00063 last_max_id = (int *)mymalloc(sizeof(int) * gsset_num);
00064 #ifdef BEAM
00065 dimthres_num = hmminfo->opt.vec_size;
00066 dimthres = (LOGPROB *)mymalloc(sizeof(LOGPROB) * dimthres_num);
00067 #endif
00068 }
00069
00074 void
00075 gms_gprune_prepare()
00076 {
00077 int i;
00078 for(i=0;i<my_gsset_num;i++) {
00079 last_max_id[i] = -1;
00080 }
00081 }
00082
00087 void
00088 gms_gprune_free()
00089 {
00090 free(last_max_id);
00091 #ifdef BEAM
00092 free(dimthres);
00093 #endif
00094 }
00095
00096
00097
00098
00107 static LOGPROB
00108 calc_contprob_with_safe_pruning(HTK_HMM_Dens *binfo, LOGPROB thres)
00109 {
00110 LOGPROB tmp, x;
00111 VECT *mean;
00112 VECT *var;
00113 LOGPROB fthres = thres * (-2.0);
00114 VECT *vec = OP_vec;
00115 short veclen = OP_veclen;
00116
00117 if (binfo == NULL) return(LOG_ZERO);
00118 mean = binfo->mean;
00119 var = binfo->var->vec;
00120
00121 tmp = binfo->gconst;
00122 for (; veclen > 0; veclen--) {
00123 x = *(vec++) - *(mean++);
00124 tmp += x * x * *(var++);
00125 if ( tmp > fthres) {
00126 return LOG_ZERO;
00127 }
00128 }
00129 return(tmp / -2.0);
00130 }
00131
#ifdef BEAM

/*
 * Beam pruning, pass 1: evaluate the seed component and record thresholds.
 *
 * Computes the component's log-likelihood without pruning.  For every
 * dimension d it also raises dimthres[d] to at least the weighted squared
 * distance accumulated over dimensions 0..d (gconst not included).
 * compute_g_max() zeroes dimthres[] before this call and adds BEAM_OFFSET
 * after it.
 * NOTE(review): for a NULL binfo this returns LOG_ZERO and leaves
 * dimthres[] untouched (i.e. only BEAM_OFFSET wide) -- confirm that is
 * acceptable for callers.
 */
static LOGPROB
calc_contprob_with_beam_pruning_pre(HTK_HMM_Dens *binfo)
{
  LOGPROB tmp, x;
  VECT *mean;
  VECT *var;
  VECT *th = dimthres;
  VECT *vec = OP_vec;
  short veclen = OP_veclen;

  if (binfo == NULL) return(LOG_ZERO);
  mean = binfo->mean;
  var = binfo->var->vec;

  tmp = 0.0;
  for (; veclen > 0; veclen--) {
    x = *(vec++) - *(mean++);
    tmp += x * x * *(var++);
    if (*th < tmp) *th = tmp;   /* keep the widest partial distance seen */
    th++;
  }
  return((tmp + binfo->gconst) / -2.0);
}

/*
 * Beam pruning, pass 2: evaluate another component against dimthres[].
 *
 * The partial distance after dimension d is compared with dimthres[d].  As
 * soon as it exceeds that threshold, the component is abandoned and LOG_ZERO
 * is returned.  This is not a safe bound: a component that starts badly
 * but finishes well can be discarded.
 *
 * The pointers use VECT, the declared type of binfo->mean, binfo->var->vec
 * and dimthres, instead of LOGPROB.  This keeps the assignments well-typed
 * even if the two typedefs differ, and matches the pre pass.
 */
static LOGPROB
calc_contprob_with_beam_pruning_post(HTK_HMM_Dens *binfo)
{
  LOGPROB tmp, x;
  VECT *mean;
  VECT *var;
  VECT *th = dimthres;
  VECT *vec = OP_vec;
  short veclen = OP_veclen;

  if (binfo == NULL) return(LOG_ZERO);
  mean = binfo->mean;
  var = binfo->var->vec;

  tmp = 0.0;
  for (; veclen > 0; veclen--) {
    x = *(vec++) - *(mean++);
    tmp += x * x * *(var++);
    if (tmp > *(th++)) {
      return LOG_ZERO;
    }
  }
  return((tmp + binfo->gconst) / -2.0);
}

#endif
00200
00201 #ifdef LAST_BEST
00202
00213 static LOGPROB
00214 compute_g_max(HTK_HMM_State *stateinfo, int last_maxi, int *maxi_ret)
00215 {
00216 int i, maxi;
00217 LOGPROB prob;
00218 LOGPROB maxprob = LOG_ZERO;
00219
00220 if (last_maxi != -1) {
00221 maxi = last_maxi;
00222 #ifdef BEAM
00223
00224 for(i=0;i<dimthres_num;i++) dimthres[i] = 0.0;
00225
00226 maxprob = calc_contprob_with_beam_pruning_pre(stateinfo->b[maxi]);
00227
00228 for(i=0;i<dimthres_num;i++) dimthres[i] += BEAM_OFFSET;
00229 #else
00230 maxprob = calc_contprob_with_safe_pruning(stateinfo->b[maxi], LOG_ZERO);
00231 #endif
00232 for (i = stateinfo->mix_num - 1; i >= 0; i--) {
00233 if (i == last_maxi) continue;
00234 #ifdef BEAM
00235 prob = calc_contprob_with_beam_pruning_post(stateinfo->b[i]);
00236 #else
00237 prob = calc_contprob_with_safe_pruning(stateinfo->b[i], maxprob);
00238 #endif
00239 if (prob > maxprob) {
00240 maxprob = prob;
00241 maxi = i;
00242 }
00243 }
00244 *maxi_ret = maxi;
00245 } else {
00246 maxi = stateinfo->mix_num - 1;
00247 maxprob = calc_contprob_with_safe_pruning(stateinfo->b[maxi], LOG_ZERO);
00248 i = maxi - 1;
00249 for (; i >= 0; i--) {
00250 prob = calc_contprob_with_safe_pruning(stateinfo->b[i], maxprob);
00251 if (prob > maxprob) {
00252 maxprob = prob;
00253 maxi = i;
00254 }
00255 }
00256 *maxi_ret = maxi;
00257 }
00258
00259 return((maxprob + stateinfo->bweight[maxi]) * INV_LOG_TEN);
00260 }
00261
00262 #else
00263
00272 static LOGPROB
00273 compute_g_max(HTK_HMM_State *stateinfo)
00274 {
00275 int i, maxi;
00276 LOGPROB prob;
00277 LOGPROB maxprob = LOG_ZERO;
00278
00279 i = maxi = stateinfo->mix_num - 1;
00280 for (; i >= 0; i--) {
00281 prob = calc_contprob_with_safe_pruning(stateinfo->b[i], maxprob);
00282 if (prob > maxprob) {
00283 maxprob = prob;
00284 maxi = i;
00285 }
00286 }
00287 return((maxprob + stateinfo->bweight[maxi]) * INV_LOG_TEN);
00288 }
00289 #endif
00290
00291
00292
00293
00294
00305 void
00306 compute_gs_scores(GS_SET *gsset, int gsset_num, LOGPROB *scores_ret)
00307 {
00308 int i;
00309 #ifdef LAST_BEST
00310 int max_id;
00311 #endif
00312
00313 for (i=0;i<gsset_num;i++) {
00314 #ifdef GS_MAX_PROB
00315 #ifdef LAST_BEST
00316
00317 scores_ret[i] = compute_g_max(gsset[i].state, last_max_id[i], &max_id);
00318 last_max_id[i] = max_id;
00319 #else
00320 scores_ret[i] = compute_g_max(gsset[i].state);
00321 #endif
00322 #else
00323
00324 scores_ret[i] = compute_g_base(gsset[i].state);
00325 #endif
00326
00327 }
00328
00329 }