MoneyPrinterTurbo/app/services/subtitle.py
2025-07-08 13:29:54 +08:00

371 lines
12 KiB
Python

import json
import os.path
import re
from timeit import default_timer as timer
from faster_whisper import WhisperModel
from loguru import logger
from app.config import config
from app.utils import utils
# Whisper settings taken from the application config. Each lookup falls back
# to the literal default shown here when the key is missing.
model_size = config.whisper.get("model_size", "large-v3")
device = config.whisper.get("device", "cpu")
compute_type = config.whisper.get("compute_type", "int8")
# Shared WhisperModel instance. It stays None until create() loads it on first
# use, and is reused by later create() calls.
model = None
def create(audio_file, subtitle_file: str = ""):
    """
    Transcribe ``audio_file`` with faster-whisper and write the result as SRT.

    The module-level ``model`` is loaded on the first call and reused on later
    calls. Word-level timestamps are requested, and the words of each segment
    are regrouped into subtitle lines that end wherever
    ``utils.str_contains_punctuation`` reports punctuation in a word.

    Args:
        audio_file: Path of the audio file handed to ``model.transcribe``.
        subtitle_file: Destination path of the SRT file. When empty, the
            output is written to ``f"{audio_file}.srt"``.

    Returns:
        None. If the model cannot be loaded, an error is logged and the
        function returns early without writing any file.
    """
    global model
    if not model:
        # Prefer a local model directory that contains model.bin. Otherwise
        # hand the bare model size to WhisperModel (presumably it then fetches
        # the model itself - the error text below mentions network issues).
        model_path = f"{utils.root_dir()}/models/whisper-{model_size}"
        model_bin_file = f"{model_path}/model.bin"
        if not os.path.isdir(model_path) or not os.path.isfile(model_bin_file):
            model_path = model_size

        logger.info(
            f"loading model: {model_path}, device: {device}, compute_type: {compute_type}"
        )
        try:
            model = WhisperModel(
                model_size_or_path=model_path, device=device, compute_type=compute_type
            )
        except Exception as e:
            logger.error(
                f"failed to load model: {e} \n\n"
                f"********************************************\n"
                f"this may be caused by network issue. \n"
                f"please download the model manually and put it in the 'models' folder. \n"
                f"see [README.md FAQ](https://github.com/harry0703/MoneyPrinterTurbo) for more details.\n"
                f"********************************************\n\n"
            )
            return None

    # NOTE(review): this is logged before the default below is applied, so the
    # message shows an empty path when the caller passed no subtitle_file.
    logger.info(f"start, output file: {subtitle_file}")
    if not subtitle_file:
        subtitle_file = f"{audio_file}.srt"

    segments, info = model.transcribe(
        audio_file,
        beam_size=5,
        word_timestamps=True,
        vad_filter=True,
        vad_parameters=dict(min_silence_duration_ms=500),
    )

    logger.info(
        f"detected language: '{info.language}', probability: {info.language_probability:.2f}"
    )

    start = timer()
    subtitles = []

    def recognized(seg_text, seg_start, seg_end):
        # Record one subtitle line; text that is blank after stripping is dropped.
        seg_text = seg_text.strip()
        if not seg_text:
            return

        msg = "[%.2fs -> %.2fs] %s" % (seg_start, seg_end, seg_text)
        logger.debug(msg)

        subtitles.append(
            {"msg": seg_text, "start_time": seg_start, "end_time": seg_end}
        )

    for segment in segments:
        words_idx = 0
        # NOTE(review): len() is taken before the truthiness check below, so
        # this assumes segment.words is never None - confirm.
        words_len = len(segment.words)

        seg_start = 0
        seg_end = 0
        seg_text = ""

        if segment.words:
            # False while no word of the current subtitle line has been seen yet.
            is_segmented = False
            for word in segment.words:
                if not is_segmented:
                    # First word of a new line fixes the line's start time.
                    seg_start = word.start
                    is_segmented = True

                seg_end = word.end
                # Accumulate the word; if it contains punctuation, the
                # sentence is broken here.
                seg_text += word.word

                if utils.str_contains_punctuation(word.word):
                    # Drop the last character (presumably the punctuation mark;
                    # assumes it is the final character of the word).
                    seg_text = seg_text[:-1]
                    if not seg_text:
                        # NOTE(review): this continue skips `words_idx += 1`
                        # and leaves is_segmented set.
                        continue

                    recognized(seg_text, seg_start, seg_end)

                    is_segmented = False
                    seg_text = ""

                # NOTE(review): both assignments below store values that were
                # already assigned earlier in this iteration; they look redundant.
                if words_idx == 0 and segment.start < word.start:
                    seg_start = word.start
                if words_idx == (words_len - 1) and segment.end > word.end:
                    seg_end = word.end
                words_idx += 1

        # Emit whatever text is left after the last punctuation of the segment.
        if not seg_text:
            continue

        recognized(seg_text, seg_start, seg_end)

    end = timer()

    diff = end - start
    logger.info(f"complete, elapsed: {diff:.2f} s")

    # Number the non-empty subtitles from 1 and render each with utils.text_to_srt.
    idx = 1
    lines = []
    for subtitle in subtitles:
        text = subtitle.get("msg")
        if text:
            lines.append(
                utils.text_to_srt(
                    idx, text, subtitle.get("start_time"), subtitle.get("end_time")
                )
            )
            idx += 1

    sub = "\n".join(lines) + "\n"
    with open(subtitle_file, "w", encoding="utf-8") as f:
        f.write(sub)
    logger.info(f"subtitle file created: {subtitle_file}")
def file_to_subtitles(filename):
    """
    Parse an SRT file into a list of ``(index, time_range, text)`` tuples.

    A line containing an SRT-style timestamp starts a cue, the following
    non-blank lines are its text, and a blank line ends it. A cue that is
    still open at end of file is emitted as well, so a file without a
    trailing blank line does not lose its last subtitle.

    Args:
        filename: Path of the SRT file, read as UTF-8.

    Returns:
        A list of ``(index, time_range, text)`` tuples with ``index`` counting
        from 1 and both strings stripped. An empty list is returned when
        ``filename`` is empty or is not an existing file.
    """
    if not filename or not os.path.isfile(filename):
        return []

    # Compiled once rather than re-evaluated for every line of the file.
    timestamp_re = re.compile(r"([0-9]*:[0-9]*:[0-9]*,[0-9]*)")

    times_texts = []
    current_times = None
    current_text = ""
    index = 0
    with open(filename, "r", encoding="utf-8") as f:
        for line in f:
            if timestamp_re.search(line):
                current_times = line
            elif line.strip() == "" and current_times:
                index += 1
                times_texts.append((index, current_times.strip(), current_text.strip()))
                current_times, current_text = None, ""
            elif current_times:
                current_text += line

    # The final cue has no terminating blank line when the file ends right
    # after its text; flush it instead of dropping it.
    if current_times:
        index += 1
        times_texts.append((index, current_times.strip(), current_text.strip()))

    return times_texts
def levenshtein_distance(s1, s2):
    """
    Return the Levenshtein edit distance between two sequences.

    The distance is the minimum number of single-element insertions,
    deletions and substitutions that turn one sequence into the other. Only
    two rows of the dynamic-programming table are kept at a time.
    """
    # Walk the longer sequence in the outer loop so the rows kept in memory
    # have the length of the shorter one.
    if len(s1) >= len(s2):
        outer, inner = s1, s2
    else:
        outer, inner = s2, s1

    if len(inner) == 0:
        return len(outer)

    above = list(range(len(inner) + 1))
    for row_number, outer_item in enumerate(outer, start=1):
        row = [row_number]
        for col, inner_item in enumerate(inner):
            replace_cost = above[col] + (0 if outer_item == inner_item else 1)
            insert_cost = above[col + 1] + 1
            delete_cost = row[col] + 1
            row.append(min(insert_cost, delete_cost, replace_cost))
        above = row

    return above[-1]
def similarity(a, b):
    """
    Return a case-insensitive similarity score for two strings.

    The score is ``1 - distance / max_length``, where ``distance`` is the
    Levenshtein distance between the lower-cased strings and ``max_length``
    is the length of the longer lower-cased string. Identical strings score
    1.0.

    Args:
        a: First string.
        b: Second string.

    Returns:
        A float; 1.0 when both strings are empty.
    """
    a_lower = a.lower()
    b_lower = b.lower()
    # Normalise by the lengths the distance was actually computed on;
    # lower-casing can change the length of some Unicode strings.
    max_length = max(len(a_lower), len(b_lower))
    if max_length == 0:
        # Two empty strings are identical; this also avoids dividing by zero.
        return 1.0
    distance = levenshtein_distance(a_lower, b_lower)
    return 1 - (distance / max_length)
def correct(subtitle_file, video_script):
    """
    Align the cues of an SRT file with the lines of the original script.

    The script is split with ``utils.split_string_by_punctuations`` and walked
    in parallel with the cues read from ``subtitle_file``. A cue whose text
    equals the script line is kept as is. Otherwise following cues are merged
    in for as long as that raises the similarity to the script line, and the
    merged time range is paired with the script line's text. Script lines
    left over once the cues run out are appended as well. The file is
    rewritten only if at least one cue was changed or added.

    Args:
        subtitle_file: Path of the SRT file to read and, if needed, overwrite.
        video_script: Full script text the subtitles should match.
    """
    subtitle_items = file_to_subtitles(subtitle_file)
    script_lines = utils.split_string_by_punctuations(video_script)

    rebuilt_items = []
    changed = False
    script_pos = 0
    cue_pos = 0
    script_total = len(script_lines)
    cue_total = len(subtitle_items)

    while script_pos < script_total and cue_pos < cue_total:
        wanted_text = script_lines[script_pos].strip()
        cue_text = subtitle_items[cue_pos][2].strip()

        # Exact match: keep the cue untouched.
        if wanted_text == cue_text:
            rebuilt_items.append(subtitle_items[cue_pos])
            script_pos += 1
            cue_pos += 1
            continue

        time_range = subtitle_items[cue_pos][1].split(" --> ")
        start_time = time_range[0]
        end_time = time_range[1]

        # Greedily absorb following cues while doing so improves the match.
        merged_text = cue_text
        lookahead = cue_pos + 1
        while lookahead < cue_total:
            candidate = subtitle_items[lookahead][2].strip()
            improves = similarity(
                wanted_text, merged_text + " " + candidate
            ) > similarity(wanted_text, merged_text)
            if not improves:
                break
            merged_text += " " + candidate
            end_time = subtitle_items[lookahead][1].split(" --> ")[1]
            lookahead += 1

        # Either way the script line replaces the recognised text; only the
        # log message differs.
        if similarity(wanted_text, merged_text) > 0.8:
            logger.warning(
                f"Merged/Corrected - Script: {wanted_text}, Subtitle: {merged_text}"
            )
        else:
            logger.warning(
                f"Mismatch - Script: {wanted_text}, Subtitle: {merged_text}"
            )
        rebuilt_items.append(
            (len(rebuilt_items) + 1, f"{start_time} --> {end_time}", wanted_text)
        )
        changed = True

        script_pos += 1
        cue_pos = lookahead

    # Script lines that remain after the pairing loop has finished.
    for leftover_pos in range(script_pos, script_total):
        leftover_line = script_lines[leftover_pos]
        logger.warning(f"Extra script line: {leftover_line}")
        if cue_pos < cue_total:
            # Reuse the timing of the next unused cue.
            timing = subtitle_items[cue_pos][1]
            cue_pos += 1
        else:
            timing = "00:00:00,000 --> 00:00:00,000"
        rebuilt_items.append((len(rebuilt_items) + 1, timing, leftover_line))
        changed = True

    if not changed:
        logger.success("Subtitle is correct")
        return

    with open(subtitle_file, "w", encoding="utf-8") as fd:
        for number, item in enumerate(rebuilt_items, start=1):
            fd.write(f"{number}\n{item[1]}\n{item[2]}\n\n")
    logger.info("Subtitle corrected")
def _srt_time_to_ms(t_str):
    """Convert an SRT timestamp such as ``"01:02:03,004"`` to integer milliseconds."""
    h, m, s_ms = t_str.split(":")
    s, ms = s_ms.split(",")
    return ((int(h) * 60 + int(m)) * 60 + int(s)) * 1000 + int(ms)


def _ms_to_srt_time(total_ms):
    """Format integer milliseconds as an SRT timestamp ``HH:MM:SS,mmm``."""
    hours, remainder = divmod(total_ms, 3600 * 1000)
    minutes, remainder = divmod(remainder, 60 * 1000)
    seconds, millis = divmod(remainder, 1000)
    return f"{hours:02d}:{minutes:02d}:{seconds:02d},{millis:03d}"


def combine_srt_files(srt_files: list, output_file: str):
    """
    Combines multiple SRT files into a single file, adjusting timestamps sequentially.

    The files are processed in the given order. Every cue of a file is shifted
    by the end time of the material combined before it, so the start and end
    offsets inside each file (including gaps between cues) are preserved.
    Cues are renumbered consecutively from 1. All arithmetic is done in
    integer milliseconds so that no timestamp is off by one millisecond due to
    floating-point truncation.

    Missing files are skipped with a warning. A file that cannot be read or
    parsed is logged as an error; cues taken from it before the failure are
    kept, and processing continues with the next file.

    Args:
        srt_files: Paths of the SRT files to combine, in playback order.
        output_file: Path of the combined SRT file, written as UTF-8.
    """
    logger.info(f"Combining {len(srt_files)} SRT files into {output_file}")
    combined_subtitles = []
    timeline_end_ms = 0
    entry_index = 1

    for srt_file in srt_files:
        if not os.path.exists(srt_file):
            logger.warning(f"SRT file not found, skipping: {srt_file}")
            continue

        # Every cue of this file is shifted by the end of what came before it.
        offset_ms = timeline_end_ms
        try:
            with open(srt_file, "r", encoding="utf-8") as f:
                content = f.read()

            for entry in re.split(r"\n\s*\n", content.strip()):
                if not entry.strip():
                    continue

                lines = entry.split("\n")
                # A usable cue needs an index line, a time line and some text.
                if len(lines) < 3:
                    continue

                start_time_str, end_time_str = lines[1].split(" --> ")
                new_start_ms = offset_ms + _srt_time_to_ms(start_time_str)
                new_end_ms = offset_ms + _srt_time_to_ms(end_time_str)

                text = "\n".join(lines[2:])
                combined_subtitles.append(
                    f"{entry_index}\n"
                    f"{_ms_to_srt_time(new_start_ms)} --> {_ms_to_srt_time(new_end_ms)}\n"
                    f"{text}"
                )
                entry_index += 1
                timeline_end_ms = max(timeline_end_ms, new_end_ms)
        except (OSError, ValueError) as e:
            # OSError: unreadable file. ValueError: bad encoding or a malformed
            # time line. Best effort: report it and go on with the next file.
            logger.error(f"Error processing SRT file {srt_file}: {e}")

    # Write combined SRT to output file
    with open(output_file, "w", encoding="utf-8") as f:
        f.write("\n\n".join(combined_subtitles) + "\n\n")

    logger.success(f"Successfully combined SRT files into {output_file}")
if __name__ == "__main__":
    # Manual smoke test against an existing task directory: parse the stored
    # subtitle file, correct it against the stored script, then transcribe the
    # audio again into a separate test file.
    task_id = "c12fd1e6-4b0a-4d65-a075-c87abe35a072"
    task_dir = utils.task_dir(task_id)
    subtitle_file = f"{task_dir}/subtitle.srt"
    audio_file = f"{task_dir}/audio.mp3"

    subtitles = file_to_subtitles(subtitle_file)
    print(subtitles)

    script_file = f"{task_dir}/script.json"
    # Read as UTF-8 explicitly, like every other file in this module, so a
    # non-ASCII script does not depend on the platform's default encoding.
    with open(script_file, "r", encoding="utf-8") as f:
        script_content = f.read()
    s = json.loads(script_content)
    script = s.get("script")

    correct(subtitle_file, script)

    subtitle_file = f"{task_dir}/subtitle-test.srt"
    create(audio_file, subtitle_file)