diff --git a/.gitignore b/.gitignore index c24424b..5799ee3 100644 --- a/.gitignore +++ b/.gitignore @@ -133,4 +133,8 @@ dmypy.json # Pyre type checker .pyre/ .username -.idea \ No newline at end of file +.idea +output/ +test_feats/ +logs/ +metadata/ \ No newline at end of file diff --git a/app/config.py b/app/config.py index 576c077..3198241 100644 --- a/app/config.py +++ b/app/config.py @@ -81,6 +81,22 @@ def GITHUB_INSTALLATION_ID(self): return self._get_env_variable('GITHUB_INSTALLATION_ID', "To use GitHub App integration, you need to define a 'GITHUB_INSTALLATION_ID' in your .env file") + @property + def LLM_PROVIDER(self): + return os.getenv('LLM_PROVIDER', 'openai') + + @property + def OPENAI_API_KEY(self): + return self._get_env_variable('OPENAI_API_KEY') + + @property + def GOOGLE_API_KEY(self): + return self._get_env_variable('GOOGLE_API_KEY') + + @property + def CLAUDE_API_KEY(self): + return self._get_env_variable('CLAUDE_API_KEY') + # Initialize the Settings class and expose an instance settings = Settings() diff --git a/app/exporters.py b/app/exporters.py index e8ef544..ccfd15e 100644 --- a/app/exporters.py +++ b/app/exporters.py @@ -198,7 +198,7 @@ def _create_with_metadata(self, transcript: Transcript, **kwargs) -> str: Args: transcript: The transcript to export - **kwargs: Additional parameters like review_flag + **kwargs: Additional parameters like review_flag and content_key Returns: The complete Markdown content with metadata @@ -215,6 +215,13 @@ def increase_indent(self, flow=False, indentless=False): # Get metadata from the source metadata = transcript.source.to_json() + # Determine which content to use + content_key = kwargs.get("content_key", "corrected_text") + content = transcript.outputs.get(content_key, transcript.outputs.get("raw")) + + if content is None: + raise Exception(f"No transcript content found for key '{content_key}' or 'raw'") + # Add or modify specific fields if self.transcript_by: review_flag = kwargs.get("review_flag", "") @@ -312,6 +319,8 @@ def export(self, transcript: Transcript, **kwargs) -> str: Args: transcript: The transcript to export add_timestamp: Whether to add a timestamp to the filename (default: False) + content_key: The key in transcript.outputs to use for the content (default: "raw") + suffix: A suffix to add to the filename (e.g., "_raw") **kwargs: Additional parameters (unused) Returns: @@ -319,11 +328,17 @@ def export(self, transcript: Transcript, **kwargs) -> str: """ self.logger.debug("Exporting transcript to plain text...") - if transcript.outputs["raw"] is None: - raise Exception("No transcript content found") + content_key = kwargs.get("content_key", "raw") + content = transcript.outputs.get(content_key) + if content is None and content_key == "summary": + content = transcript.summary + + if content is None: + raise Exception(f"No content found for key: {content_key}") # Get parameters add_timestamp = kwargs.get("add_timestamp", False) + suffix = kwargs.get("suffix", "") # Get output directory output_dir = self.get_output_path(transcript) @@ -331,13 +346,13 @@ def export(self, transcript: Transcript, **kwargs) -> str: # Construct file path file_path = self.construct_file_path( directory=output_dir, - filename=transcript.title, + filename=f"{transcript.title}{suffix}", file_type="txt", include_timestamp=add_timestamp, ) # Write to file - result_path = self.write_to_file(transcript.outputs["raw"], file_path) + result_path = self.write_to_file(content, file_path) self.logger.info(f"(exporter) Text file written to: {result_path}") return result_path diff --git a/app/github_api_handler.py b/app/github_api_handler.py index e5afb5d..127fb53 100644 --- a/app/github_api_handler.py +++ b/app/github_api_handler.py @@ -93,13 +93,17 @@ def create_branch(self, repo_type, branch_name, sha): response = self._make_request('POST', url, json=data) return response.json() - def create_or_update_file(self, repo_type, file_path, content, commit_message, branch): + def create_or_update_file(self, repo_type, file_path, content, commit_message, branch, get_sha=False): url = f"https://api.github.com/repos/{self.repos[repo_type]['owner']}/{self.repos[repo_type]['name']}/contents/{quote(file_path)}" data = { "message": commit_message, "content": base64.b64encode(content.encode()).decode(), "branch": branch } + if get_sha: + response = self._make_request('GET', url + f'?ref={branch}') + data['sha'] = response.json()['sha'] + response = self._make_request('PUT', url, json=data) return response.json() @@ -114,23 +118,34 @@ def create_pull_request(self, repo_type, title, head, base, body): response = self._make_request('POST', url, json=data) return response.json() - def push_transcripts(self, transcripts: list[Transcript]) -> str | None: + def push_transcripts(self, transcripts: list[Transcript], markdown_exporter) -> str | None: try: default_branch = self.get_default_branch('transcripts') branch_sha = self.get_branch_sha('transcripts', default_branch) - branch_name = f"transcripts-{''.join(random.choices('0123456789', k=6))}" + branch_name = f"transcripts-{'' .join(random.choices('0123456789', k=6))}" self.create_branch('transcripts', branch_name, branch_sha) for transcript in transcripts: - if transcript.outputs and transcript.outputs['markdown']: - with open(transcript.outputs['markdown'], 'r') as file: - content = file.read() + # First commit: Raw transcript + raw_content = markdown_exporter._create_with_metadata(transcript, content_key='raw') + self.create_or_update_file( + 'transcripts', + transcript.output_path_with_title + ".md", + raw_content, + f'ai(transcript): "{transcript.title}" (raw)', + branch_name + ) + + # Second commit: Corrected transcript + if transcript.outputs.get('corrected_text'): + corrected_content = markdown_exporter._create_with_metadata(transcript, content_key='corrected_text') self.create_or_update_file( 'transcripts', - transcript.output_path_with_title, - content, - f'ai(transcript): "{transcript.title}" ({transcript.source.loc})', - branch_name + transcript.output_path_with_title + ".md", + corrected_content, + f'ai(transcript): "{transcript.title}" (corrected)', + branch_name, + get_sha=True # We need the SHA of the file to update it ) pr = self.create_pull_request( diff --git a/app/services/correction.py b/app/services/correction.py new file mode 100644 index 0000000..7501f91 --- /dev/null +++ b/app/services/correction.py @@ -0,0 +1,129 @@ +from app.transcript import Transcript +from app.logging import get_logger +from app.services.global_tag_manager import GlobalTagManager +import openai +from app.config import settings + +logger = get_logger() + +class CorrectionService: + def __init__(self, provider='openai', model='gpt-4o'): + self.provider = provider + self.model = model + self.tag_manager = GlobalTagManager() + if self.provider == 'openai': + self.client = openai + self.client.api_key = settings.OPENAI_API_KEY + else: + raise ValueError(f"Unsupported LLM provider: {provider}") + + def process(self, transcript: Transcript, **kwargs): + logger.info(f"Correcting transcript with {self.provider}...") + keywords = kwargs.get('keywords', []) + + metadata = transcript.source.to_json() + global_context = self.tag_manager.get_correction_context() + + prompt = self._build_enhanced_prompt(transcript.outputs['raw'], keywords, metadata, global_context) + + response = self.client.chat.completions.create( + model=self.model, + messages=[{"role": "user", "content": prompt}] + ) + corrected_text = response.choices[0].message.content + + transcript.outputs['corrected_text'] = corrected_text + logger.info("Correction complete.") + + def _build_enhanced_prompt(self, text, keywords, metadata, global_context): + prompt = ( + "You are a transcript correction specialist with expertise in Bitcoin and blockchain terminology.\n\n" + "The following transcript was generated by automatic speech recognition (ASR). Your task is to " + "correct ONLY the obvious mistakes while keeping the transcript as close to the original as possible.\n\n" + "DO NOT:\n" + "- Rephrase or rewrite sentences\n" + "- Change the speaker's style or tone\n" + "- Add or remove content\n" + "- Make major structural changes\n\n" + "DO:\n" + "- Fix spelling errors and typos\n" + "- Correct misheard words using context\n" + "- Fix technical terminology and proper names\n" + "- Maintain the exact same flow and structure\n\n" + "--- Current Video Metadata ---\n" + ) + + if metadata.get('title'): + prompt += f"Video Title: {metadata['title']}\n" + if metadata.get('speakers'): + prompt += f"Speakers: {', '.join(metadata['speakers'])}\n" + if metadata.get('tags'): + prompt += f"Video Tags: {', '.join(metadata['tags'])}\n" + if metadata.get('categories'): + prompt += f"Categories: {', '.join(metadata['categories'])}\n" + if metadata.get('youtube', {}).get('description'): + description = metadata['youtube']['description'][:200] + "..." if len(metadata['youtube']['description']) > 200 else metadata['youtube']['description'] + prompt += f"Description: {description}\n" + + video_count = global_context.get('video_count', 0) + prompt += f"\n--- Global Bitcoin Knowledge Base (From {video_count} Transcripts) ---\n" + + if global_context.get('frequent_tags'): + frequent_tags = global_context['frequent_tags'][:15] + prompt += f"Most Common Topics: {', '.join(frequent_tags)}\n" + + if global_context.get('technical_terms'): + tech_terms = global_context['technical_terms'][:20] + prompt += f"Technical Terms to Recognize: {', '.join(tech_terms)}\n" + + if global_context.get('project_names'): + projects = global_context['project_names'][:15] + prompt += f"Bitcoin Projects/Tools: {', '.join(projects)}\n" + + if global_context.get('common_speakers'): + speakers = global_context['common_speakers'][:10] + prompt += f"Frequent Speakers: {', '.join(speakers)}\n" + + if global_context.get('common_categories'): + categories = global_context['common_categories'][:8] + prompt += f"Common Content Categories: {', '.join(categories)}\n" + + if global_context.get('expertise_areas'): + areas = global_context['expertise_areas'][:8] + prompt += f"Domain Expertise Areas: {', '.join(areas)}\n" + + if global_context.get('domain_context'): + prompt += f"Primary Domain Focus: {global_context['domain_context']}\n" + + prompt += "\n--- Focus Areas for Correction ---\n" + prompt += "Using the metadata and global knowledge, focus on correcting:\n" + prompt += "1. Technical terms (ensure proper spelling and capitalization)\n" + prompt += "2. Speaker names and project names (match known variations)\n" + prompt += "3. Common ASR mishears (but, bit, big -> Bitcoin when context suggests it)\n" + prompt += "4. Homophones and similar-sounding words in Bitcoin context\n" + prompt += "5. Obvious typos and spelling mistakes\n\n" + prompt += "IMPORTANT: Make minimal changes - only fix clear errors, don't improve the text.\n" + + if global_context.get('tag_variations'): + variations = global_context['tag_variations'] + if variations: + prompt += "\n--- Common Term Variations ---\n" + for base_term, variants in list(variations.items())[:5]: + prompt += f"{base_term}: {', '.join(variants)}\n" + + if keywords: + prompt += ( + "\n--- Additional Priority Keywords ---\n" + "Pay special attention to these terms and ensure correct spelling/formatting:\n- " + ) + prompt += "\n- ".join(keywords) + + prompt += f"\n\n--- Transcript Start ---\n\n{text.strip()}\n\n--- Transcript End ---\n\n" + prompt += "Return ONLY the corrected transcript. Make minimal changes - fix only obvious errors while " + prompt += "preserving the original wording, sentence structure, and speaker's natural expression." + + return prompt + + def _build_prompt(self, text, keywords, metadata): + """Legacy method for backward compatibility""" + return self._build_enhanced_prompt(text, keywords, metadata, {}) diff --git a/app/services/global_tag_manager.py b/app/services/global_tag_manager.py new file mode 100644 index 0000000..3047ea9 --- /dev/null +++ b/app/services/global_tag_manager.py @@ -0,0 +1,314 @@ +import json +import os +from datetime import datetime, timezone +from collections import defaultdict +from typing import Dict, List, Any +from app.config import settings +from app.logging import get_logger +import re + +logger = get_logger() + +class GlobalTagManager: + """ + Manages a global dictionary of tags and terminology from all processed videos + to enhance correction accuracy across the entire corpus. + """ + + def __init__(self, metadata_dir=None): + self.metadata_dir = metadata_dir or settings.TSTBTC_METADATA_DIR or "metadata/" + self.dict_file = os.path.join(self.metadata_dir, "global_tag_dictionary.json") + self.tag_dict = self._load_dictionary() + + def _load_dictionary(self) -> Dict[str, Any]: + """Load existing global tag dictionary or create new one""" + if os.path.exists(self.dict_file): + try: + with open(self.dict_file, 'r', encoding='utf-8') as f: + return json.load(f) + except (json.JSONDecodeError, IOError) as e: + logger.warning(f"Failed to load global tag dictionary: {e}. Creating new one.") + + return self._create_new_dictionary() + + def _create_new_dictionary(self) -> Dict[str, Any]: + """Create a new global tag dictionary structure""" + return { + "last_updated": datetime.now(timezone.utc).isoformat(), + "version": "1.0", + "tags": {}, + "technical_terms": [], + "speaker_context": { + "common_speakers": [], + "expertise_areas": [] + }, + "project_names": [], + "categories": {}, + "video_count": 0, + "common_words": {} + } + + def _save_dictionary(self): + """Save the global tag dictionary to file""" + try: + os.makedirs(os.path.dirname(self.dict_file), exist_ok=True) + self.tag_dict["last_updated"] = datetime.now(timezone.utc).isoformat() + + with open(self.dict_file, 'w', encoding='utf-8') as f: + json.dump(self.tag_dict, f, indent=4, ensure_ascii=False) + + logger.debug(f"Global tag dictionary saved to {self.dict_file}") + except IOError as e: + logger.error(f"Failed to save global tag dictionary: {e}") + + def update_from_transcript(self, transcript): + """Update global dictionary with new transcript's metadata""" + try: + metadata = transcript.source.to_json() + + manual_tags = metadata.get('tags', []) + youtube_metadata = metadata.get('youtube', {}) + youtube_tags = youtube_metadata.get('tags', []) if youtube_metadata else [] + categories = metadata.get('categories', []) + speakers = metadata.get('speakers', []) + title = metadata.get('title', '') + description = youtube_metadata.get('description', '') if youtube_metadata else '' + + self.tag_dict["video_count"] = self.tag_dict.get("video_count", 0) + 1 + + all_tags = manual_tags + youtube_tags + categories + + for tag in all_tags: + if tag and isinstance(tag, str): + self._update_tag_entry(tag.strip()) + + for category in categories: + if category and isinstance(category, str): + self._update_category_frequency(category.strip()) + + text_content = f"{title} {description}".lower() + self._extract_technical_terms_dynamically(text_content, all_tags) + + for speaker in speakers: + if speaker and isinstance(speaker, str): + self._update_speaker_context(speaker.strip()) + + self._identify_project_names_dynamically(text_content, all_tags) + + self._update_expertise_areas(categories + all_tags) + + self._save_dictionary() + logger.info(f"Updated global tag dictionary with transcript: {title}") + + except Exception as e: + logger.error(f"Failed to update global tag dictionary: {e}") + + def _update_tag_entry(self, tag: str): + """Update or create entry for a tag in the dictionary""" + tag_lower = tag.lower() + tags_dict = self.tag_dict["tags"] + + if tag_lower in tags_dict: + tags_dict[tag_lower]["frequency"] += 1 + tags_dict[tag_lower]["last_seen"] = datetime.now(timezone.utc).isoformat() + + if tag not in tags_dict[tag_lower]["variations"]: + tags_dict[tag_lower]["variations"].append(tag) + else: + tags_dict[tag_lower] = { + "frequency": 1, + "variations": [tag], + "context": self._infer_context(tag_lower), + "first_seen": datetime.now(timezone.utc).isoformat(), + "last_seen": datetime.now(timezone.utc).isoformat() + } + + def _update_speaker_context(self, speaker: str): + """Update speaker information in the global context""" + speaker_context = self.tag_dict["speaker_context"] + if speaker not in speaker_context["common_speakers"]: + speaker_context["common_speakers"].append(speaker) + if len(speaker_context["common_speakers"]) > 50: + speaker_context["common_speakers"] = speaker_context["common_speakers"][-50:] + + def _update_category_frequency(self, category: str): + """Track category frequencies""" + categories_dict = self.tag_dict.get("categories", {}) + category_lower = category.lower() + categories_dict[category_lower] = categories_dict.get(category_lower, 0) + 1 + self.tag_dict["categories"] = categories_dict + + def _extract_technical_terms_dynamically(self, text_content: str, tags: List[str]): + """Dynamically extract technical terms from content and tags""" + technical_terms = self.tag_dict.get("technical_terms", []) + + bitcoin_indicators = ['bitcoin', 'btc', 'blockchain', 'crypto', 'lightning', 'ln'] + is_bitcoin_content = any(indicator in text_content or + any(indicator in tag.lower() for tag in tags) + for indicator in bitcoin_indicators) + + if is_bitcoin_content: + for tag in tags: + if tag and len(tag) > 3 and not tag.isdigit(): + tag_lower = tag.lower() + technical_indicators = ['network', 'protocol', 'script', 'sig', 'key', 'hash', + 'node', 'chain', 'block', 'tx', 'vault', 'channel'] + if (any(indicator in tag_lower for indicator in technical_indicators) or + tag_lower.startswith(('op_', 'bip', 'bolt')) or + tag_lower in ['taproot', 'segwit', 'multisig', 'htlc']): + if tag_lower not in technical_terms: + technical_terms.append(tag_lower) + + self.tag_dict["technical_terms"] = technical_terms + + def _identify_project_names_dynamically(self, text_content: str, tags: List[str]): + """Dynamically identify project names from content and tags""" + project_names = self.tag_dict.get("project_names", []) + + for tag in tags: + if tag and len(tag) > 2: + if (tag[0].isupper() or + 'core' in tag.lower() or 'lightning' in tag.lower() or + any(pattern in tag.lower() for pattern in ['btc', 'lightning', 'wallet', 'pay'])): + if tag not in project_names: + project_names.append(tag) + + self.tag_dict["project_names"] = project_names + + def _update_expertise_areas(self, tags_and_categories: List[str]): + """Update expertise areas based on categories and tags""" + expertise_areas = self.tag_dict.get("speaker_context", {}).get("expertise_areas", []) + + expertise_mapping = { + 'development': ['development', 'dev', 'programming', 'coding', 'technical'], + 'podcast': ['podcast', 'interview', 'discussion'], + 'conference': ['conference', 'talk', 'presentation', 'summit'], + 'education': ['education', 'tutorial', 'learning', 'teaching'], + 'mining': ['mining', 'miner', 'hashrate', 'pool'], + 'security': ['security', 'privacy', 'cryptography', 'audit'], + 'payments': ['payments', 'lightning', 'channel', 'transaction'], + 'trading': ['trading', 'exchange', 'market', 'price'] + } + + for item in tags_and_categories: + if item and isinstance(item, str): + item_lower = item.lower() + for area, keywords in expertise_mapping.items(): + if any(keyword in item_lower for keyword in keywords): + if area not in expertise_areas: + expertise_areas.append(area) + + speaker_context = self.tag_dict.get("speaker_context", {}) + speaker_context["expertise_areas"] = expertise_areas + self.tag_dict["speaker_context"] = speaker_context + + def _infer_context(self, tag: str) -> str: + """Infer context category for a tag""" + development_terms = ["script", "bdk", "core", "node", "api", "rpc", "development"] + payment_terms = ["lightning", "payment", "channel", "invoice", "bolt"] + security_terms = ["multisig", "signature", "key", "seed", "private", "security"] + mining_terms = ["mining", "hash", "difficulty", "block", "pow"] + + if any(term in tag for term in development_terms): + return "development" + elif any(term in tag for term in payment_terms): + return "payments" + elif any(term in tag for term in security_terms): + return "security" + elif any(term in tag for term in mining_terms): + return "mining" + else: + return "general" + + def get_correction_context(self) -> Dict[str, Any]: + """Get enriched context for correction prompts""" + tags_dict = self.tag_dict.get("tags", {}) + + frequent_tags = sorted( + tags_dict.items(), + key=lambda x: x[1]["frequency"], + reverse=True + )[:30] + + categories_dict = self.tag_dict.get("categories", {}) + common_categories = sorted( + categories_dict.items(), + key=lambda x: x[1], + reverse=True + )[:10] + + return { + 'frequent_tags': [tag for tag, _ in frequent_tags], + 'tag_variations': self._get_tag_variations(), + 'technical_terms': self.tag_dict.get('technical_terms', []), + 'project_names': self.tag_dict.get('project_names', []), + 'common_speakers': self.tag_dict.get('speaker_context', {}).get('common_speakers', [])[:20], + 'common_categories': [cat for cat, _ in common_categories], + 'expertise_areas': self.tag_dict.get('speaker_context', {}).get('expertise_areas', []), + 'domain_context': self._build_domain_context(), + 'video_count': self.tag_dict.get('video_count', 0) + } + + def _get_tag_variations(self) -> Dict[str, List[str]]: + """Get mapping of tags to their variations""" + variations = {} + for tag, data in self.tag_dict.get("tags", {}).items(): + if len(data["variations"]) > 1: + variations[tag] = data["variations"] + return variations + + def _build_domain_context(self) -> str: + """Build a domain context string for correction prompts based on actual data""" + context_parts = [] + + expertise_areas = self.tag_dict.get("speaker_context", {}).get("expertise_areas", []) + categories = self.tag_dict.get("categories", {}) + + if categories: + top_categories = sorted(categories.items(), key=lambda x: x[1], reverse=True)[:3] + for category, _ in top_categories: + if category in ["development", "technical"]: + context_parts.append("Bitcoin development and technical implementation") + elif category in ["education", "tutorial"]: + context_parts.append("Bitcoin education and learning") + elif category in ["podcast", "interview"]: + context_parts.append("Bitcoin discussion and interviews") + elif category in ["conference", "presentation"]: + context_parts.append("Bitcoin conferences and presentations") + + for area in expertise_areas[:3]: + if area == "payments" and "payment" not in " ".join(context_parts).lower(): + context_parts.append("Bitcoin payments and Lightning Network") + elif area == "security" and "security" not in " ".join(context_parts).lower(): + context_parts.append("Bitcoin security and cryptography") + elif area == "mining" and "mining" not in " ".join(context_parts).lower(): + context_parts.append("Bitcoin mining and network") + + return ", ".join(context_parts) or "Bitcoin and blockchain technology" + + def get_statistics(self) -> Dict[str, Any]: + """Get statistics about the global tag dictionary""" + tags_dict = self.tag_dict.get("tags", {}) + categories_dict = self.tag_dict.get("categories", {}) + + return { + "videos_processed": self.tag_dict.get("video_count", 0), + "total_unique_tags": len(tags_dict), + "total_tag_occurrences": sum(data["frequency"] for data in tags_dict.values()), + "technical_terms_count": len(self.tag_dict.get("technical_terms", [])), + "project_names_count": len(self.tag_dict.get("project_names", [])), + "speakers_count": len(self.tag_dict.get("speaker_context", {}).get("common_speakers", [])), + "categories_count": len(categories_dict), + "expertise_areas_count": len(self.tag_dict.get("speaker_context", {}).get("expertise_areas", [])), + "last_updated": self.tag_dict.get("last_updated"), + "most_frequent_tags": sorted( + tags_dict.items(), + key=lambda x: x[1]["frequency"], + reverse=True + )[:10], + "most_common_categories": sorted( + categories_dict.items(), + key=lambda x: x[1], + reverse=True + )[:5] if categories_dict else [] + } \ No newline at end of file diff --git a/app/services/summarizer.py b/app/services/summarizer.py new file mode 100644 index 0000000..0f6eca2 --- /dev/null +++ b/app/services/summarizer.py @@ -0,0 +1,30 @@ +from app.transcript import Transcript +from app.logging import get_logger +import openai +from app.config import settings + +logger = get_logger() + +class SummarizerService: + def __init__(self, provider='openai', model='gpt-4o'): + self.provider = provider + self.model = model + if self.provider == 'openai': + self.client = openai + self.client.api_key = settings.OPENAI_API_KEY + else: + raise ValueError(f"Unsupported LLM provider: {provider}") + + def process(self, transcript: Transcript, **kwargs): + logger.info(f"Summarizing transcript with {self.provider}...") + text_to_summarize = transcript.outputs.get('corrected_text', transcript.outputs['raw']) + + prompt = f"""Please summarize the following text.\n---\n{text_to_summarize}""" + + response = self.client.chat.completions.create( + model=self.model, + messages=[{"role": "user", "content": prompt}] + ) + summary = response.choices[0].message.content + transcript.summary = summary + logger.info("Summarization complete.") diff --git a/app/transcript.py b/app/transcript.py index f29f08e..46e1efe 100644 --- a/app/transcript.py +++ b/app/transcript.py @@ -72,6 +72,10 @@ def title(self): def summary(self): return self.source.summary + @summary.setter + def summary(self, value): + self.source.summary = value + def __str__(self): excluded_fields = ['test_mode', 'logger'] fields = {key: value for key, value in self.__dict__.items() @@ -339,6 +343,33 @@ def download_video_metadata(self): def process(self, working_dir): """Process video""" + + def download_video(): + """Helper method to download a YouTube video and return its absolute path""" + # sanity checks + if self.local: + raise Exception(f"{self.source_file} is a local file") + try: + self.logger.debug(f"Downloading video: {self.source_file}") + ydl_opts = { + "format": 'worstvideo+worstaudio/worst', + "outtmpl": os.path.join(working_dir, "videoFile.%(ext)s"), + "nopart": True, + } + with yt_dlp.YoutubeDL(ydl_opts) as ytdl: + ytdl.download([self.source_file]) + + for ext in ["mp4", "mkv", "webm"]: + output_file = os.path.join(working_dir, f"videoFile.{ext}") + if os.path.exists(output_file): + return os.path.abspath(output_file) + raise Exception("Downloaded file not found in expected formats.") + + return os.path.abspath(output_file) + except Exception as e: + self.logger.error(e) + raise Exception(f"Error downloading video: {e}") + try: self.logger.debug(f"Video processing: '{self.source_file}'") media_processor = MediaProcessor() @@ -426,4 +457,4 @@ def __config_source(self): self.entries.append(source) else: self.logger.warning( - f"Invalid source for '{entry.title}'. '{enclosure.type}' not supported for RSS feeds, source skipped.") + f"Invalid source for '{entry.title}'. '{enclosure.type}' not supported for RSS feeds, source skipped.") \ No newline at end of file diff --git a/app/transcription.py b/app/transcription.py index 2c7004f..0c8cf66 100644 --- a/app/transcription.py +++ b/app/transcription.py @@ -14,6 +14,9 @@ from app.data_fetcher import DataFetcher from app.github_api_handler import GitHubAPIHandler from app.exporters import ExporterFactory, TranscriptExporter +from app.services.correction import CorrectionService +from app.services.summarizer import SummarizerService +from app.services.global_tag_manager import GlobalTagManager class Transcription: @@ -36,9 +39,13 @@ def __init__( batch_preprocessing_output=False, needs_review=False, include_metadata=True, + correct=False, + llm_provider="openai", + llm_correction_model="gpt-4o", + llm_summary_model="gpt-4o", ): self.nocleanup = nocleanup - self.status = "idle" # Can be "idle", "in_progress", or "completed" + self.status = "idle" self.test_mode = test_mode self.logger = get_logger() self.tmp_dir = ( @@ -46,25 +53,26 @@ def __init__( ) self.transcript_by = self.__configure_username(username) - # during testing we need to create the markdown for validation purposes self.markdown = markdown or test_mode self.include_metadata = include_metadata self.metadata_writer = DataWriter( self.__configure_tstbtc_metadata_dir() ) + + # Initialize global tag manager + self.tag_manager = GlobalTagManager(self.__configure_tstbtc_metadata_dir()) - # Create exporters for transcript output formats - export_config = { - "markdown": self.markdown, - "text_output": text_output, - "json": json, - "model_output_dir": model_output_dir, - } self.exporters: dict[ str, TranscriptExporter ] = ExporterFactory.create_exporters( - config=export_config, transcript_by=self.transcript_by + config={ + "markdown": self.markdown, + "text_output": text_output, + "json": json, + "model_output_dir": model_output_dir, + }, + transcript_by=self.transcript_by, ) self.model_output_dir = model_output_dir @@ -74,8 +82,12 @@ def __init__( self.github_handler = GitHubAPIHandler() self.review_flag = self.__configure_review_flag(needs_review) - # @TODO: use ExporterFactory instead of `metadata_writer` for - # services metadata output + self.processing_services = [] + if correct: + self.processing_services.append(CorrectionService(provider=llm_provider, model=llm_correction_model)) + if summarize: + self.processing_services.append(SummarizerService(provider=llm_provider, model=llm_summary_model)) + if deepgram: self.service = services.Deepgram( summarize, diarize, upload, self.metadata_writer @@ -86,7 +98,7 @@ def __init__( self.transcripts: list[Transcript] = [] self.existing_media = None self.preprocessing_output = [] if batch_preprocessing_output else None - self.data_fetcher = DataFetcher(settings.BTC_TRANSCRIPTS_URL) + self.data_fetcher = DataFetcher(base_url="http://btctranscripts.com") self.logger.debug(f"Temp directory: {self.tmp_dir}") @@ -204,13 +216,20 @@ def _new_transcript_from_source(self, source: Source): # Keep preprocessing outputs for later use self.preprocessing_output.append(source.to_json()) # Initialize new transcript from source - self.transcripts.append( - Transcript( - source=source, - test_mode=self.test_mode, - metadata_file=metadata_file, - ) + transcript = Transcript( + source=source, + test_mode=self.test_mode, + metadata_file=metadata_file, ) + + # Update global tag dictionary with new transcript metadata + try: + self.tag_manager.update_from_transcript(transcript) + self.logger.debug(f"Updated global tag dictionary with transcript: {source.title}") + except Exception as e: + self.logger.warning(f"Failed to update global tag dictionary: {e}") + + self.transcripts.append(transcript) def add_transcription_source( self, @@ -418,6 +437,7 @@ def start(self, test_transcript=None): self.service.transcribe(transcript) transcript.status = "completed" self.postprocess(transcript) + self.export(transcript) self.status = "completed" if self.github: @@ -431,7 +451,12 @@ def push_to_github(self, transcripts: list[Transcript]): if not self.github_handler: return - pr_url_transcripts = self.github_handler.push_transcripts(transcripts) + markdown_exporter = self.exporters.get("markdown") + if not markdown_exporter: + self.logger.error("Markdown exporter not configured, cannot push to GitHub.") + return + + pr_url_transcripts = self.github_handler.push_transcripts(transcripts, markdown_exporter) if pr_url_transcripts: self.logger.info( f"transcripts: Pull request created: {pr_url_transcripts}" @@ -479,33 +504,30 @@ def write_to_markdown_file(self, transcript: Transcript): raise Exception(f"Error writing to markdown file: {e}") def postprocess(self, transcript: Transcript) -> None: - """ - Process the transcript to produce output files in the configured formats. - This updated method uses exporters when available, but maintains compatibility - with the existing code. - """ - try: - # Handle markdown output - if self.markdown or self.github_handler: - transcript.outputs["markdown"] = self.write_to_markdown_file( - transcript, - ) - - if "text" in self.exporters: - try: - transcript.outputs["text"] = self.exporters["text"].export( - transcript, add_timestamp=False - ) - except Exception as e: - self.logger.warning(f"Text exporter failed: {e}") - - if "json" in self.exporters: - transcript.outputs["json"] = self.exporters["json"].export( - transcript - ) + for service in self.processing_services: + service.process(transcript) + + def export(self, transcript: Transcript): + """Exports the transcript to the configured formats.""" + text_exporter = self.exporters.get("text") + if text_exporter: + # Save raw, corrected, and summary files + if transcript.outputs.get("raw"): + text_exporter.export(transcript, add_timestamp=False, content_key="raw", suffix="_raw") + if transcript.outputs.get("corrected_text"): + text_exporter.export(transcript, add_timestamp=False, content_key="corrected_text", suffix="_corrected") + if transcript.summary: + text_exporter.export(transcript, add_timestamp=False, content_key="summary", suffix="_summary") + + if self.markdown or self.github_handler: + transcript.outputs["markdown"] = self.write_to_markdown_file( + transcript, + ) - except Exception as e: - raise Exception(f"Error with postprocessing: {e}") from e + if "json" in self.exporters: + transcript.outputs["json"] = self.exporters["json"].export( + transcript + ) def clean_up(self): self.logger.debug("Cleaning up...") diff --git a/config.ini.example b/config.ini.example index a3244bb..bc3eafd 100644 --- a/config.ini.example +++ b/config.ini.example @@ -6,6 +6,9 @@ github = False save_to_markdown = True needs_review = False one_sentence_per_line = True +llm_provider = openai +llm_correction_model = gpt-4o +llm_summary_model = gpt-4o [development] verbose_logging = True diff --git a/requirements.txt b/requirements.txt index 97bd365..7822410 100644 --- a/requirements.txt +++ b/requirements.txt @@ -8,7 +8,7 @@ requests==2.32.3 setuptools==45.2.0 static_ffmpeg==2.3 ffmpeg-python==0.2.0 -yt-dlp==2025.2.19 +yt-dlp>=2025.2.19 deepgram-sdk==2.4.0 boto3==1.26.143 black==23.3.0 @@ -20,4 +20,6 @@ librosa==0.10.2.post1 fastapi==0.111.0 PyJWT==2.9.0 cryptography==43.0.1 -psutil==7.0.0 \ No newline at end of file +psutil==7.0.0 +openai==1.98.0 +pytest-asyncio==0.21.0 \ No newline at end of file diff --git a/routes/transcription.py b/routes/transcription.py index 2d867f6..30d554a 100644 --- a/routes/transcription.py +++ b/routes/transcription.py @@ -112,6 +112,8 @@ async def add_to_queue( cutoff_date: Optional[str] = Form(None), source: Optional[str] = Form(None), source_file: Optional[UploadFile] = File(None), + correct: bool = Form(False), + llm_provider: str = Form("openai"), ): temp_file_path = None try: @@ -130,6 +132,8 @@ async def add_to_queue( include_metadata=not no_metadata, text_output=text, needs_review=needs_review, + correct=correct, + llm_provider=llm_provider, ) if source_file: with tempfile.NamedTemporaryFile(delete=False) as tmp: diff --git a/test/exporters/test_text.py b/test/exporters/test_text.py index ce22a67..94cdfba 100644 --- a/test/exporters/test_text.py +++ b/test/exporters/test_text.py @@ -60,7 +60,7 @@ def test_error_handling_no_content(self, text_exporter, mock_transcript): text_exporter.export(mock_transcript) # Check the error message - assert "No transcript content found" in str(excinfo.value) + assert "No content found for key: raw" in str(excinfo.value) def test_output_directory_creation(self, temp_dir, mock_transcript): """Test that the exporter creates output directories as needed""" diff --git a/test/integration/test_transcription_exporters.py b/test/integration/test_transcription_exporters.py index 9c2afeb..fb068d3 100644 --- a/test/integration/test_transcription_exporters.py +++ b/test/integration/test_transcription_exporters.py @@ -82,18 +82,18 @@ def test_write_to_markdown_file( # Check the result assert result == "/path/to/exported/markdown.md" - def test_postprocess_with_markdown( + def test_export_with_markdown( self, transcription_with_exporters, mock_transcript ): - """Test postprocess with markdown output""" + """Test export with markdown output""" # Mock the write_to_markdown_file method to avoid calling the exporter directly transcription_with_exporters.write_to_markdown_file = mock.MagicMock() transcription_with_exporters.write_to_markdown_file.return_value = ( "/path/to/exported/markdown.md" ) - # Call postprocess - transcription_with_exporters.postprocess(mock_transcript) + # Call export + transcription_with_exporters.export(mock_transcript) # Check that write_to_markdown_file was called transcription_with_exporters.write_to_markdown_file.assert_called_once() @@ -104,13 +104,13 @@ def test_postprocess_with_markdown( == "/path/to/exported/markdown.md" ) - def test_postprocess_with_text( + def test_export_with_text( self, transcription_with_exporters, mock_transcript, mock_transcription_deps, ): - """Test postprocess with text output""" + """Test export with text output""" # Get the mock exporters exporters = mock_transcription_deps[ "ExporterFactory" @@ -126,23 +126,17 @@ def test_postprocess_with_text( "/path/to/exported/markdown.md" ) - # Call postprocess - transcription_with_exporters.postprocess(mock_transcript) + # Call export + transcription_with_exporters.export(mock_transcript) # Check that the text exporter was called - text_exporter.export.assert_called_once() + text_exporter.export.assert_called() assert text_exporter.export.call_args[1]["add_timestamp"] is False - # Check that the output was stored in the transcript - assert ( - mock_transcript.outputs["text"] - == "/path/to/exported/transcript.txt" - ) - - def test_postprocess_with_json( + def test_export_with_json( self, transcription_with_exporters, mock_transcript, mock_transcription_deps ): - """Test postprocess with JSON output""" + """Test export with JSON output""" # Get the mock exporters exporters = mock_transcription_deps[ "ExporterFactory" @@ -158,8 +152,8 @@ def test_postprocess_with_json( "/path/to/exported/markdown.md" ) - # Call postprocess - transcription_with_exporters.postprocess(mock_transcript) + # Call export + transcription_with_exporters.export(mock_transcript) # Check that the json exporter was called json_exporter.export.assert_called_once() @@ -168,4 +162,68 @@ def test_postprocess_with_json( assert ( mock_transcript.outputs["json"] == "/path/to/exported/transcript.json" - ) \ No newline at end of file + ) + + def test_export_with_all_outputs( + self, + transcription_with_exporters, + mock_transcript, + mock_transcription_deps, + ): + """Test export with all outputs enabled""" + # Get mock exporters + exporters = mock_transcription_deps[ + "ExporterFactory" + ].create_exporters.return_value + text_exporter = exporters["text"] + json_exporter = exporters["json"] + + # Set up return values + text_exporter.export.return_value = "/path/to/text.txt" + json_exporter.export.return_value = "/path/to/json.json" + + # Mock write_to_markdown_file + transcription_with_exporters.write_to_markdown_file = mock.MagicMock() + transcription_with_exporters.write_to_markdown_file.return_value = ( + "/path/to/markdown.md" + ) + + # Call export + transcription_with_exporters.export(mock_transcript) + + # Check that all exporters were called + text_exporter.export.assert_called() + json_exporter.export.assert_called_once() + transcription_with_exporters.write_to_markdown_file.assert_called_once() + + # Check that all outputs are stored + assert mock_transcript.outputs["json"] == "/path/to/json.json" + assert mock_transcript.outputs["markdown"] == "/path/to/markdown.md" + + def test_export_no_outputs( + self, patched_transcription, mock_transcript, mock_transcription_deps + ): + """Test export with no outputs enabled""" + # Create a Transcription instance with all export options disabled + transcription = Transcription( + markdown=False, text_output=False, json=False, username="Test User" + ) + transcription.exporters.clear() + + # Clear the mock transcript's outputs and add back the raw output + mock_transcript.outputs.clear() + mock_transcript.outputs['raw'] = 'test transcript' + + # Mock write_to_markdown_file + transcription.write_to_markdown_file = mock.MagicMock() + + # Call export + transcription.export(mock_transcript) + + # Check that no exporters were called + transcription.write_to_markdown_file.assert_not_called() + + # Check that no outputs were stored + assert "text" not in mock_transcript.outputs + assert "json" not in mock_transcript.outputs + assert "markdown" not in mock_transcript.outputs diff --git a/transcriber.py b/transcriber.py index a0b23c1..c6335d0 100644 --- a/transcriber.py +++ b/transcriber.py @@ -117,7 +117,7 @@ def print_help(ctx, param, value): is_flag=True, default=settings.config.getboolean("summarize", False), show_default=True, - help="Summarize the transcript [only available with deepgram]", + help="Summarize the transcript using the configured LLM provider.", ) cutoff_date = click.option( "--cutoff-date", @@ -210,6 +210,19 @@ def print_help(ctx, param, value): help="Supply this flag to enable verbose logging", ) +correct_transcript = click.option( + "--correct", + is_flag=True, + default=settings.config.getboolean("correct", False), + help="Correct the transcript using the configured LLM provider.", +) +llm_provider = click.option( + "--llm-provider", + type=click.Choice(["openai", "google", "claude"]), + default=settings.config.get("llm_provider", "openai"), + help="LLM provider for correction and summarization.", +) + add_loc = click.option( "--loc", default="misc", @@ -279,6 +292,8 @@ def print_help(ctx, param, value): @nocleanup @verbose_logging @auto_start_server +@correct_transcript +@llm_provider def transcribe( source: str, loc: str, @@ -303,7 +318,8 @@ def transcribe( no_metadata: bool, needs_review: bool, cutoff_date: str, - nocheck: bool, + correct: bool, + llm_provider: str, ) -> None: """Transcribe the provided sources. Suported sources include: \n - YouTube videos and playlists\n @@ -342,7 +358,8 @@ def transcribe( "include_metadata": not no_metadata, "needs_review": needs_review, "cutoff_date": cutoff_date, - "nocheck": nocheck, + "correct": correct, + "llm_provider": llm_provider, } try: queue_response = api_client.add_to_queue(data, source)