00001
00051
00052
00053
00054
00055
00056
00057
00058 #include <sent/stddefs.h>
00059 #include <sent/speech.h>
00060 #include <sent/htk_hmm.h>
00061 #include <sent/htk_param.h>
00062 #include <sent/hmm.h>
00063 #include <sent/gprune.h>
00064 #include "globalvars.h"
00065
00066
00067
00068 static int statenum;
00069 static LOGPROB **outprob_cache = NULL;
00070 static int allocframenum;
00071 static int allocblock;
00072 static BMALLOC_BASE *croot;
00073 static LOGPROB *last_cache;
00074 #define LOG_UNDEF (LOG_ZERO - 1)
00075
00076
00081 boolean
00082 outprob_cache_init()
00083 {
00084 statenum = OP_hmminfo->totalstatenum;
00085 outprob_cache = NULL;
00086 allocframenum = 0;
00087 allocblock = OUTPROB_CACHE_PERIOD;
00088 OP_time = -1;
00089 croot = NULL;
00090 return TRUE;
00091 }
00092
00099 boolean
00100 outprob_cache_prepare()
00101 {
00102 int s,t;
00103
00104
00105 for (t = 0; t < allocframenum; t++) {
00106 for (s = 0; s < statenum; s++) {
00107 outprob_cache[t][s] = LOG_UNDEF;
00108 }
00109 }
00110
00111 return TRUE;
00112 }
00113
00119 static void
00120 outprob_cache_extend(int reqframe)
00121 {
00122 int newnum;
00123 int size;
00124 int t, s;
00125 LOGPROB *tmpp;
00126
00127
00128 if (reqframe < allocframenum) return;
00129
00130
00131 newnum = reqframe + 1;
00132 if (newnum < allocframenum + allocblock) newnum = allocframenum + allocblock;
00133 size = (newnum - allocframenum) * statenum;
00134
00135
00136 if (outprob_cache == NULL) {
00137 outprob_cache = (LOGPROB **)mymalloc(sizeof(LOGPROB *) * newnum);
00138 } else {
00139 outprob_cache = (LOGPROB **)myrealloc(outprob_cache, sizeof(LOGPROB *) * newnum);
00140 }
00141 tmpp = (LOGPROB *)mybmalloc2(sizeof(LOGPROB) * size, &croot);
00142
00143 for(t = allocframenum; t < newnum; t++) {
00144 outprob_cache[t] = &(tmpp[(t - allocframenum) * statenum]);
00145 for (s = 0; s < statenum; s++) {
00146 outprob_cache[t][s] = LOG_UNDEF;
00147 }
00148 }
00149
00150
00151 allocframenum = newnum;
00152 }
00153
00158 void
00159 outprob_cache_free()
00160 {
00161 if (croot != NULL) mybfree2(&croot);
00162 if (outprob_cache != NULL) free(outprob_cache);
00163 }
00164
00165
00179 LOGPROB
00180 outprob_state(
00181 int t,
00182 HTK_HMM_State *stateinfo,
00183 HTK_Param *param)
00184 {
00185 LOGPROB outp;
00186
00187
00188 OP_state = stateinfo;
00189 OP_state_id = stateinfo->id;
00190 OP_param = param;
00191 if (OP_time != t) {
00192 OP_last_time = OP_time;
00193 OP_time = t;
00194 OP_vec = param->parvec[t];
00195 OP_veclen = param->veclen;
00196
00197 outprob_cache_extend(t);
00198 last_cache = outprob_cache[t];
00199 }
00200
00201
00202 if ((outp = last_cache[OP_state_id]) == LOG_UNDEF) {
00203 outp = last_cache[OP_state_id] = calc_outprob_state();
00204 }
00205 return(outp);
00206 }
00207
00208 static LOGPROB *maxprobs;
00209 static int maxn;
00210
00216 void
00217 outprob_cd_nbest_init(int num)
00218 {
00219 maxprobs = (LOGPROB *)mymalloc(sizeof(LOGPROB) * num);
00220 maxn = num;
00221 }
00222
00227 void
00228 outprob_cd_nbest_free()
00229 {
00230 free(maxprobs);
00231 }
00232
00242 static LOGPROB
00243 outprob_cd_nbest(int t, CD_State_Set *lset, HTK_Param *param)
00244 {
00245 LOGPROB prob;
00246 int i, k, n;
00247
00248 n = 0;
00249 for(i=0;i<lset->num;i++) {
00250 prob = outprob_state(t, lset->s[i], param);
00251
00252 if (prob <= LOG_ZERO) continue;
00253 if (n == 0 || prob <= maxprobs[n-1]) {
00254 if (n == maxn) continue;
00255 maxprobs[n] = prob;
00256 n++;
00257 } else {
00258 for(k=0; k<n; k++) {
00259 if (prob > maxprobs[k]) {
00260 memmove(&(maxprobs[k+1]), &(maxprobs[k]),
00261 sizeof(LOGPROB) * (n - k - ( (n == maxn) ? 1 : 0)));
00262 maxprobs[k] = prob;
00263 break;
00264 }
00265 }
00266 if (n < maxn) n++;
00267 }
00268 }
00269 prob = 0.0;
00270 for(i=0;i<n;i++) {
00271
00272 prob += maxprobs[i];
00273 }
00274 return(prob/(float)n);
00275 }
00276
00286 static LOGPROB
00287 outprob_cd_max(int t, CD_State_Set *lset, HTK_Param *param)
00288 {
00289 LOGPROB maxprob, prob;
00290 int i;
00291 maxprob = LOG_ZERO;
00292 for(i=0;i<lset->num;i++) {
00293 prob = outprob_state(t, lset->s[i], param);
00294 if (maxprob < prob) maxprob = prob;
00295 }
00296 return(maxprob);
00297 }
00298
00308 static LOGPROB
00309 outprob_cd_avg(int t, CD_State_Set *lset, HTK_Param *param)
00310 {
00311 LOGPROB sum, p;
00312 int i,j;
00313 sum = 0.0;
00314 j = 0;
00315 for(i=0;i<lset->num;i++) {
00316 p = outprob_state(t, lset->s[i], param);
00317 if (p > LOG_ZERO) {
00318 sum += p;
00319 j++;
00320 }
00321 }
00322 return(sum/(float)j);
00323 }
00324
00334 LOGPROB
00335 outprob_cd(int t, CD_State_Set *lset, HTK_Param *param)
00336 {
00337 LOGPROB ret;
00338
00339
00340 switch(OP_hmminfo->cdset_method) {
00341 case IWCD_AVG:
00342 ret = outprob_cd_avg(t, lset, param);
00343 break;
00344 case IWCD_MAX:
00345 ret = outprob_cd_max(t, lset, param);
00346 break;
00347 case IWCD_NBEST:
00348 ret = outprob_cd_nbest(t, lset, param);
00349 break;
00350 default:
00351 j_error("unknown cdhmm method!\n");
00352 ret = 0;
00353 break;
00354 }
00355 return(ret);
00356 }
00357
00358
00368 LOGPROB
00369 outprob(int t, HMM_STATE *hmmstate, HTK_Param *param)
00370 {
00371 if (hmmstate->is_pseudo_state) {
00372 return(outprob_cd(t, hmmstate->out.cdset, param));
00373 } else {
00374 return(outprob_state(t, hmmstate->out.state, param));
00375 }
00376 }