Skip to content

Commit 294e1e1

Browse files
authored
Merge pull request #169 from RomanValov/main
Dont expose extract_probability as C API parameter
2 parents c668bd6 + d8e774a commit 294e1e1

2 files changed

Lines changed: 7 additions & 13 deletions

File tree

pywhispercpp/constants.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -302,12 +302,6 @@
302302
'options': None,
303303
'default': {"beam_size": -1, "patience": -1.0}
304304
},
305-
'extract_probability': {
306-
'type': bool,
307-
'description': 'calculate the geometric mean of token probabilities for each segment.',
308-
'options': None,
309-
'default': False
310-
},
311305
'vad': {
312306
'type': bool,
313307
'description': 'Enable VAD',

pywhispercpp/model.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,7 @@ def transcribe(self,
181181
n_processors: Optional[int] = None,
182182
new_segment_callback: Optional[Callable[[Segment], None]] = None,
183183
abort_callback: Optional[Callable[[], bool]] = None,
184+
extract_probability: bool = False,
184185
**params) -> List[Segment]:
185186
"""
186187
Transcribes the media provided as input and returns list of `Segment` objects.
@@ -205,9 +206,6 @@ def transcribe(self,
205206
raise FileNotFoundError(media)
206207
audio = self._load_audio(media)
207208

208-
# Handle extract_probability parameter
209-
self.extract_probability = params.pop('extract_probability', False)
210-
211209
# update params if any
212210
self._set_params(params)
213211

@@ -224,7 +222,7 @@ def transcribe(self,
224222
# run inference
225223
start_time = time()
226224
logger.info("Transcribing ...")
227-
res = self._transcribe(audio, n_processors=n_processors)
225+
res = self._transcribe(audio, n_processors=n_processors, extract_probability=extract_probability)
228226
end_time = time()
229227
logger.info(f"Inference time: {end_time - start_time:.3f} s")
230228
return res
@@ -402,12 +400,14 @@ def _set_params(self, kwargs: dict) -> None:
402400
for param, value in normalized.items():
403401
setattr(self._params, param, value)
404402

405-
def _transcribe(self, audio: np.ndarray, n_processors: Optional[int] = None):
403+
def _transcribe(self, audio: np.ndarray, n_processors: Optional[int] = None, extract_probability: bool = False):
406404
"""
407405
Private method to call the whisper.cpp/whisper_full function
408406
409407
:param audio: numpy array of audio data
410408
:param n_processors: if not None, it will run whisper.cpp/whisper_full_parallel with n_processors
409+
:param extract_probability: If True, calculates the geometric mean of token probabilities for each segment,
410+
providing a confidence score interpretable as a probability in [0, 1].
411411
:return:
412412
"""
413413

@@ -416,7 +416,7 @@ def _transcribe(self, audio: np.ndarray, n_processors: Optional[int] = None):
416416
else:
417417
pw.whisper_full(self._ctx, self._params, audio, audio.size)
418418
n = pw.whisper_full_n_segments(self._ctx)
419-
res = Model._get_segments(self._ctx, 0, n, self.extract_probability)
419+
res = Model._get_segments(self._ctx, 0, n, extract_probability)
420420
return res
421421

422422

@@ -528,4 +528,4 @@ def __del__(self):
528528
:return: None
529529
"""
530530
if self._ctx is not None:
531-
pw.whisper_free(self._ctx)
531+
pw.whisper_free(self._ctx)

0 commit comments

Comments
 (0)