@@ -181,6 +181,7 @@ def transcribe(self,
181181 n_processors : Optional [int ] = None ,
182182 new_segment_callback : Optional [Callable [[Segment ], None ]] = None ,
183183 abort_callback : Optional [Callable [[], bool ]] = None ,
184+ extract_probability : bool = False ,
184185 ** params ) -> List [Segment ]:
185186 """
186187 Transcribes the media provided as input and returns list of `Segment` objects.
@@ -205,9 +206,6 @@ def transcribe(self,
205206 raise FileNotFoundError (media )
206207 audio = self ._load_audio (media )
207208
208- # Handle extract_probability parameter
209- self .extract_probability = params .pop ('extract_probability' , False )
210-
211209 # update params if any
212210 self ._set_params (params )
213211
@@ -224,7 +222,7 @@ def transcribe(self,
224222 # run inference
225223 start_time = time ()
226224 logger .info ("Transcribing ..." )
227- res = self ._transcribe (audio , n_processors = n_processors )
225+ res = self ._transcribe (audio , n_processors = n_processors , extract_probability = extract_probability )
228226 end_time = time ()
229227 logger .info (f"Inference time: { end_time - start_time :.3f} s" )
230228 return res
@@ -402,12 +400,14 @@ def _set_params(self, kwargs: dict) -> None:
402400 for param , value in normalized .items ():
403401 setattr (self ._params , param , value )
404402
405- def _transcribe (self , audio : np .ndarray , n_processors : Optional [int ] = None ):
403+ def _transcribe (self , audio : np .ndarray , n_processors : Optional [int ] = None , extract_probability : bool = False ):
406404 """
407405 Private method to call the whisper.cpp/whisper_full function
408406
409407 :param audio: numpy array of audio data
410408 :param n_processors: if not None, it will run whisper.cpp/whisper_full_parallel with n_processors
409+ :param extract_probability: If True, calculates the geometric mean of token probabilities for each segment,
410+ providing a confidence score interpretable as a probability in [0, 1].
411411 :return:
412412 """
413413
@@ -416,7 +416,7 @@ def _transcribe(self, audio: np.ndarray, n_processors: Optional[int] = None):
416416 else :
417417 pw .whisper_full (self ._ctx , self ._params , audio , audio .size )
418418 n = pw .whisper_full_n_segments (self ._ctx )
419- res = Model ._get_segments (self ._ctx , 0 , n , self . extract_probability )
419+ res = Model ._get_segments (self ._ctx , 0 , n , extract_probability )
420420 return res
421421
422422
@@ -528,4 +528,4 @@ def __del__(self):
528528 :return: None
529529 """
530530 if self ._ctx is not None :
531- pw .whisper_free (self ._ctx )
531+ pw .whisper_free (self ._ctx )
0 commit comments