Source code for src.editor.editor

from src.llm.client import RoccoClient
from src.llm.schemas import EditorOutput, EvaluatorOutput, Citation, EditingSession
from src.common.utils import _build_rubric, _build_evaluation_text
from src.retriever.retriever import VectorStoreManager
from src.prompts.loader import load_prompt, render
from datetime import datetime
from typing import List, Dict, Any, Optional, Union
from langchain_core.documents import Document
import json
from pathlib import Path


class DescriptionEditor:
    """Improves dataset descriptions.

    The editor builds a prompt from a draft description, its evaluation,
    optional retrieved context and the conversation so far, sends it to the
    model, and parses the JSON reply into an ``EditorOutput``.

    Instance state:
        model: client whose ``send_prompt(prompt)`` returns the raw reply.
        rubric: rubric dict rendered into the prompt and stored in sessions.
        vector_store_manager: optional store queried via ``similarity_search``.
        use_rag: when False, no context is ever retrieved.
        top_k_context: ``k`` passed to each similarity search.
        conversation_history: list of ``{"role", "content"[, "rationale"]}``
            dicts appended to by ``enhance``.
        session_creation_time: ISO timestamp taken when the editor was built.
        original_description: first draft passed to ``enhance`` (or None).
        current_description: latest successfully parsed suggestion (or None).
    """

    # Search query used for each rubric criterion (1-based position in
    # ``rubric_breakdown``). ``None`` means no query is issued for it.
    _CRITERION_QUERIES = {
        1: "data description data summary data overview",
        2: "research goals objectives study purpose motivation",
        3: "sample material porous media type lithology",
        4: "research problem research question hypothesis",
        5: "applications reuse reproducibility validation machine learning simulation",
        6: "methodology experimental setup x-ray imaging technique data collection image acquisition scanning image processing",
        7: "dataset structure organization file data format contents",
        8: "quality control validation verification calibration inspection",
        9: None,
        10: "keywords relevant concepts domain-specific terminology nomenclature",
    }

    def __init__(
        self,
        model: RoccoClient,
        rubric: Dict,
        vector_store_manager: Optional[VectorStoreManager] = None,
        use_rag: bool = True,
        top_k_context: int = 5,
    ):
        """Store the collaborators and start an empty editing session."""
        self.model = model
        self.rubric = rubric
        self.vector_store_manager = vector_store_manager
        self.use_rag = use_rag
        self.top_k_context = top_k_context
        self.conversation_history: List[Dict[str, str]] = []
        self.session_creation_time = datetime.now().isoformat()
        self.original_description = None
        self.current_description = None
        # TODO: Add logging

    def _build_session(self) -> EditingSession:
        """Snapshot the current editor state as an ``EditingSession``."""
        return EditingSession(
            created_at=self.session_creation_time,
            original_description=self.original_description,
            current_description=self.current_description,
            conversation_history=self.conversation_history,
            rubric=self.rubric,
            config={
                "use_rag": self.use_rag,
                "top_k_context": self.top_k_context,
            },
        )

    def save_session(self, filepath: Path) -> None:
        """Save the current session to a file.

        Args:
            filepath: destination path; missing parent directories are
                created. The file is written as UTF-8 JSON.
        """
        session = self._build_session()

        # Ensure directory exists
        filepath = Path(filepath)
        filepath.parent.mkdir(parents=True, exist_ok=True)

        # Save using Pydantic's model_dump_json
        with open(filepath, "w", encoding="utf-8") as f:
            f.write(session.model_dump_json(indent=2))

        print(f"Session saved to {filepath}")
        print(session.get_summary())

    def load_session(self, filepath: Path) -> None:
        """Load a session from a file and restore the editor state from it.

        Args:
            filepath: path of a JSON file previously written by
                ``save_session``.

        Raises:
            FileNotFoundError: if ``filepath`` does not exist.
        """
        filepath = Path(filepath)
        if not filepath.exists():
            raise FileNotFoundError(f"Session file not found: {filepath}")

        # Load and validate using Pydantic schema
        with open(filepath, "r", encoding="utf-8") as f:
            session_data = json.load(f)
        session = EditingSession(**session_data)

        # Restore session state
        self.conversation_history = session.conversation_history
        self.session_creation_time = session.created_at
        self.original_description = session.original_description
        self.current_description = session.current_description

        # Restore config; keys absent from the file keep their current values
        if session.config:
            self.use_rag = session.config.get("use_rag", self.use_rag)
            self.top_k_context = session.config.get(
                "top_k_context", self.top_k_context
            )

        print(f"Session loaded from {filepath}")
        print(session.get_summary())

    def get_session_summary(self) -> str:
        """Get a summary of the current session.

        Builds the same session snapshot that ``save_session`` writes and
        returns its ``get_summary()`` text.
        """
        # The editor keeps its state in plain attributes; there is no
        # ``session_metadata`` dict, so build the snapshot from those.
        return self._build_session().get_summary()

    def retrieve_context(self, query: Optional[str] = None) -> List[Document]:
        """Retrieve relevant context from related papers.

        Args:
            query: text passed to the vector store's ``similarity_search``.

        Returns:
            The search results, or an empty list when RAG is disabled, no
            vector store is configured, or the search raises.
        """
        if not self.use_rag or self.vector_store_manager is None:
            return []

        try:
            return self.vector_store_manager.similarity_search(
                query, k=self.top_k_context
            )
        except Exception as e:
            # Best effort: a failed retrieval must not abort the edit.
            print(f"Error retrieving context: {str(e)}")
            return []

    def generate_search_query(
        self, draft_evaluation: EvaluatorOutput, query_all: bool = True
    ) -> List[str]:
        """Generate search queries based on evaluation feedback.

        Args:
            draft_evaluation: evaluation whose ``rubric_breakdown`` items
                expose a ``score``; only read when ``query_all`` is False.
            query_all: when True, return the query of every criterion that
                has one; when False, only of criteria scoring ``<= 0.5``.

        Returns:
            The selected queries in criterion order (possibly empty).
        """
        criterion_queries = self._CRITERION_QUERIES

        if query_all:
            queries = [query for query in criterion_queries.values() if query]
        else:
            queries = []
            for criterion_id, rubric_item in enumerate(
                draft_evaluation.rubric_breakdown, 1
            ):
                # TODO Handle different score weights
                if rubric_item.score <= 0.5:
                    query = criterion_queries.get(criterion_id)
                    if query:
                        queries.append(query)

        if not queries:
            print("No queries generated - skipping context retrieval.")

        return queries

    @staticmethod
    def _format_history(messages: Optional[List[Dict[str, str]]]) -> str:
        """Render conversation turns as the prompt's history section."""
        if not messages:
            return ""

        parts = ["\n## CONVERSATION HISTORY:\n"]
        for message in messages:
            role = message["role"]
            if role == "user":
                parts.append(f"\n**User Feedback:**\n{message['content']}\n")
            elif role == "assistant":
                parts.append(f"\n**Previous Version:**\n{message['content']}\n")
                if "rationale" in message:
                    parts.append(f"**Rationale:** {message['rationale']}\n")
        return "".join(parts)

    @staticmethod
    def _format_context(
        context: Optional[Union[List[Document], List[str]]],
    ) -> str:
        """Render context chunks with numbered headers and source metadata."""
        if not context:
            return ""

        formatted_chunks = []
        for chunk_num, item in enumerate(context, 1):
            if isinstance(item, Document):
                # Format Document with citation-ready metadata
                metadata_lines = [f"[CONTEXT_CHUNK_{chunk_num}]"]

                # Add source metadata if available
                if item.metadata:
                    doc_title = item.metadata.get("doc_title", "unknown")
                    page = item.metadata.get("page")
                    chunk_idx = item.metadata.get("chunk_index")

                    source_info = f"Source: {doc_title}"
                    if page is not None:
                        source_info += f", Page {page}"
                    if chunk_idx is not None:
                        source_info += f", Chunk {chunk_idx}"
                    metadata_lines.append(source_info)

                metadata_lines.append("---")
                formatted_chunks.append(
                    "\n".join(metadata_lines) + "\n" + item.page_content
                )
            else:
                # String fallback (for context_override)
                formatted_chunks.append(f"[CONTEXT_CHUNK_{chunk_num}]\n---\n{item}")

        return "\n\n---\n\n".join(formatted_chunks)

    def build_prompt(
        self,
        draft_text: str,
        draft_evaluation: EvaluatorOutput,
        context: Optional[Union[List[Document], List[str]]] = None,
        user_feedback: Optional[str] = None,
        history_override: Optional[List[Dict[str, str]]] = None,
    ) -> str:
        """Prepare prompt for improving the draft.

        Args:
            draft_text: description to improve.
            draft_evaluation: evaluation rendered as feedback text.
            context: retrieved ``Document`` objects or pre-formatted strings.
            user_feedback: latest user instruction; its presence switches
                the prompt mode from ``"initial"`` to ``"refinement"``.
            history_override: history to render instead of
                ``self.conversation_history`` (an empty list is honoured).

        Returns:
            The rendered ``"user"`` template of the ``"editor"`` prompt.
        """
        rubric_str = _build_rubric(self.rubric)
        eval_feedback = _build_evaluation_text(draft_evaluation)

        # Build conversation history
        history_source = (
            history_override
            if history_override is not None
            else self.conversation_history
        )
        history = self._format_history(history_source)

        # Determine mode (initial vs refinement)
        mode = "refinement" if user_feedback else "initial"

        # Format context with separators and citation metadata
        context_str = self._format_context(context)

        # Load prompt template and render
        prompt_data = load_prompt("editor")
        return render(
            prompt_data["user"],
            mode=mode,
            context_str=context_str,
            rubric_str=rubric_str,
            original_description=draft_text,
            evaluation_feedback=eval_feedback,
            history=history,
            user_feedback=user_feedback or "",
        )

    def _gather_context(
        self,
        draft_evaluation: EvaluatorOutput,
        retrieve_context: bool,
        context_override: Optional[List[str]],
        query_all_criterion: bool,
    ):
        """Return ``(context, context_metadata)`` for one ``enhance`` call.

        ``context_metadata`` is only populated for retrieved documents; an
        override or disabled retrieval yields an empty metadata list.
        """
        context_metadata: List[Dict[str, Any]] = []

        if context_override:
            # context_override is still a list of formatted strings
            return context_override, context_metadata

        if not (retrieve_context and self.use_rag):
            return [], context_metadata

        queries = self.generate_search_query(
            draft_evaluation, query_all=query_all_criterion
        )

        context: List[Document] = []
        seen_content = set()
        for query in queries:
            for doc in self.retrieve_context(query):
                # Different queries may hit the same chunk; keep it once.
                if doc.page_content in seen_content:
                    continue
                seen_content.add(doc.page_content)
                context.append(doc)
                context_metadata.append(
                    {
                        "doc_title": doc.metadata.get("doc_title", "unknown"),
                        "page": doc.metadata.get("page"),
                        "chunk_index": doc.metadata.get("chunk_index"),
                        "snippet": doc.page_content[:120],
                    }
                )
        return context, context_metadata

    @staticmethod
    def _parse_citations(entry: Dict[str, Any]) -> List[Citation]:
        """Build ``Citation`` objects from one parsed response entry."""
        # ``or []`` also tolerates an explicit ``"citations": null``.
        return [
            Citation(
                statement=cit["statement"],
                source=cit["source"],
                quote=cit["quote"],
                doc_title=cit.get("doc_title"),
                page=cit.get("page"),
                chunk_index=cit.get("chunk_index"),
            )
            for cit in entry.get("citations") or []
        ]

    def enhance(
        self,
        draft_text: str,
        draft_evaluation: EvaluatorOutput,
        retrieve_context: bool = True,
        context_override: Optional[List[str]] = None,
        query_all_criterion: bool = True,
        user_feedback: Optional[str] = None,
        history_override: Optional[List[Dict[str, str]]] = None,
    ) -> EditorOutput:
        """Improve the description draft using evaluation feedback and optional context from papers.

        Args:
            draft_text: description to improve; remembered as the original
                description on the first call.
            draft_evaluation: evaluation of ``draft_text``.
            retrieve_context: allow retrieval (also requires ``use_rag``).
            context_override: non-empty list of strings used as context
                instead of retrieval.
            query_all_criterion: forwarded as ``query_all`` to
                ``generate_search_query``.
            user_feedback: optional user instruction; appended to the
                conversation history before the prompt is built.
            history_override: forwarded to ``build_prompt``.

        Returns:
            The parsed suggestion. If the reply cannot be parsed, an
            ``EditorOutput`` whose ``suggested_text`` is an error message and
            whose ``rationale`` carries the exception text; in that case the
            history and ``current_description`` get no assistant update.
        """
        if self.original_description is None:
            self.original_description = draft_text

        if user_feedback:
            self.conversation_history.append({"role": "user", "content": user_feedback})

        context, context_metadata = self._gather_context(
            draft_evaluation, retrieve_context, context_override, query_all_criterion
        )

        prompt = self.build_prompt(
            draft_text,
            draft_evaluation,
            context,
            user_feedback=user_feedback,
            history_override=history_override,
        )

        raw_resp = self.model.send_prompt(prompt)

        try:
            data = json.loads(raw_resp.strip())
            entry = data["updated_description"][0]

            updated_desc = EditorOutput(
                original_text=draft_text,
                suggested_text=entry["updated_description"],
                rationale=entry["rationale"],
                citation=self._parse_citations(entry),
                context_used=context_metadata,
            )

            self.conversation_history.append(
                {
                    "role": "assistant",
                    "content": updated_desc.suggested_text,
                    "rationale": updated_desc.rationale,
                }
            )
            self.current_description = updated_desc.suggested_text

        except Exception as e:
            # Deliberately broad: malformed JSON, missing keys and schema
            # validation failures all degrade to the same error result.
            print(e)
            updated_desc = EditorOutput(
                original_text=draft_text,
                suggested_text="Could not parse the response as JSON. Please check the prompt format.",
                rationale=f"Parsing error: {str(e)}",
                citation=[],
                context_used=context_metadata,
            )

        return updated_desc

    def print_enhancement_result(self, editor_output: EditorOutput) -> None:
        """Utility to print enhancement results."""
        print(f"Original Description:\n{editor_output.original_text}\n")
        print(f"Enhanced Description:\n{editor_output.suggested_text}\n")
        print(f"Justifications:\n {editor_output.rationale}")
        print("Citations:\n")
        for item in editor_output.citation:
            print(f"Statement: {item.statement}")
            print(f"Source: {item.source}")
            print(f"Source Quote: {item.quote}\n")

    def reset_conversation_history(self) -> None:
        """Clear all stored conversation turns, starting a fresh refinement session."""
        self.conversation_history = []