"""Basic RAG database"""
from dataclasses import dataclass
from typing import Sequence, Optional
from hashlib import sha1
import numpy as np
from pathlib import Path
import warnings
import sqlite3
from mistral_tools.utils import RateLimiter
from mistral_tools.embeddings import EmbeddingModel
warnings.filterwarnings("ignore", category=DeprecationWarning, module="faiss")
import faiss # noqa: E402
@dataclass
class TextChunk():
    """A text chunk with metadata"""
    text_chunk: str
    file_path: str
    start_line: int
    end_line: int
    # hash of the source file; mainly used to check whether the file has
    # changed when updating the index
    file_hash: Optional[bytes] = None
    def to_str(self):
        """Pretty print the text chunk"""
        return (f"----{self.file_path}: l.{self.start_line} "
                f"to l.{self.end_line}-----\n{self.text_chunk}") 
 
class RAGDatabase():
    """Simple RAG database
    implemented as
    1. a sqlite database with a single table with columns
        id, text_chunk, embedding, file_path, start_line, end_line, file_sha
    2. a faiss index
    I ended up re-coding this because I was not able to find a RAG database
    that was both simple enough (no server needed, no huge framework) 
    and flexible enough.
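
    Example:
        Illustrative usage sketch (the paths, the query and api_key below
        are placeholders)::

            db = RAGDatabase(Path("rag.sqlite"), Path("rag.faiss"))
            db.generate_index(list(Path("docs").glob("**/*.md")), api_key=api_key)
            chunks, scores = db.query("how do I configure logging?",
                                      n_results=3, api_key=api_key)
            for chunk in chunks:
                print(chunk.to_str())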
    """
    db_path: Path
    index_path: Path
    index: faiss.Index
    db: sqlite3.Connection
    rate_limit: RateLimiter
    max_id: int
    model: str
    max_n_tokens: int
    def __init__(self, db_path: Path, index_path: Path,
                 rate_limit: RateLimiter|float = 1.1, model="mistral-embed",
                 max_n_tokens=16384):
        """Open (or create) the sqlite database and the faiss index"""
        self.db_path = db_path
        self.index_path = index_path
        self.db = sqlite3.connect(db_path)
        self.db.execute("CREATE TABLE IF NOT EXISTS rag "
            "(id INTEGER PRIMARY KEY, text_chunk TEXT, embedding NULLABLE BLOB, "
            "file_path TEXT, start_line INTEGER, end_line INTEGER, file_sha BLOB)")
        self.db.commit()
        self.model = model
        self.max_n_tokens = max_n_tokens
        max_id, = self.db.execute("SELECT MAX(id) FROM rag").fetchone()
        max_id = max_id if max_id is not None else 0
        self.max_id = max_id
        self.rate_limit = (
                rate_limit if isinstance(rate_limit, RateLimiter) 
                else RateLimiter(rate_limit))
        if index_path.exists():
            self.index = faiss.read_index(str(index_path))
        else:
            inner_index = faiss.IndexFlatL2(1024) # mistral-embed vectors are 1024-dimensional
            # ivf_index = faiss.IndexIVFFlat(inner_index, 1024, n_cells)
            # ^^^ TODO: implement switching to ivf when the index gets large
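            # IndexIDMap lets vectors be added and removed under arbitrary ids
            # (the sqlite row ids); a bare IndexFlatL2 only handles sequential ids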
            self.index = faiss.IndexIDMap(inner_index)
    def insert_db(self, chunk: TextChunk, *,  id=None,  embedding,
                  do_commit=True, add_to_index=False):
        """Insert a text chunk into the sqlite database and the index"""
        cursor = self.db.cursor()
        if id is None: 
            id = self.max_id + 1
            self.max_id = id
        cursor.execute("INSERT INTO rag VALUES(?, ?, ?, ?, ?, ?, ?)", 
                (id, chunk.text_chunk, 
                 embedding, str(chunk.file_path), chunk.start_line, 
                                chunk.end_line, chunk.file_hash))
        if embedding is not None and add_to_index:
            # faiss expects a 2D float32 array and a matching int64 id array
            self.index.add_with_ids(np.asarray(embedding, dtype=np.float32)
                                    .reshape(1, -1), np.array([id], dtype=np.int64))
        if do_commit: self.commit()
        return id 
    @staticmethod
    def get_chunks(file, *, chunk_size=25, overlap=5, filename, hash=None):
        """Cut a file into chunks
        Args:
            file: a Path or bytes object
            chunk_size: the size of the chunks
            overlap: the overlap between the chunks
            filename: the filename
            hash: the hash of the file (optional)
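
        Example:
            With the defaults (chunk_size=25, overlap=5) consecutive chunks
            start 20 lines apart, so a 60-line file yields chunks covering
            lines[0:25], lines[20:45] and lines[40:60]. The file name below
            is a placeholder::

                chunks = list(RAGDatabase.get_chunks(Path("notes.txt"),
                                                     filename="notes.txt"))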
        """
        chunk_limits = chunk_size - overlap # stride: distance between consecutive chunk starts
        try: 
            if isinstance(file, Path): lines = file.read_text().splitlines()
            else: lines = file.decode().splitlines()
        except UnicodeDecodeError:
            return # skip non-text files (the generator then yields nothing)
        n_lines = len(lines)
        
        chunk_starts = range(0, n_lines, chunk_limits)
        for start in chunk_starts:
            end = min(n_lines, start + chunk_size)
            yield TextChunk("\n".join(lines[start:end]), filename, start, end, 
                            file_hash=hash) 
    @classmethod
    def get_all_chunks(cls, files: Sequence[Path|bytes], *, chunk_size=25, 
                       overlap=5, file_paths: Sequence[str]|None = None, 
                       file_shas_to_skip = None):
        """Cut a list of files into chunks
        Args:
            files: the files
            chunk_size: the size of the chunks
            overlap: the overlap between the chunks
            file_paths: the filenames (Optional: if not provided, 
                                       and the files are Path objects
                                       the filenames will be the paths)
            file_shas_to_skip: the file hashes to skip
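
        Example:
            Files can also be passed as raw bytes, in which case file_paths
            should be given explicitly (the names below are placeholders)::

                chunks = list(RAGDatabase.get_all_chunks(
                        [b"some file content"], file_paths=["inline.txt"]))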
        """
        file_paths = file_paths or [str(file) for file in files]
        if file_shas_to_skip is None: file_shas_to_skip = set()
        else: file_shas_to_skip = set(file_shas_to_skip)
        for file, file_path in zip(files, file_paths):
            if isinstance(file, Path): hash = sha1(file.read_bytes()).digest()
            else: hash = sha1(file).digest()
            if hash in file_shas_to_skip: continue
            new_chunks = list(cls.get_chunks(file, chunk_size=chunk_size, 
                                overlap=overlap, filename=file_path, hash=hash))
            yield from new_chunks  
    def generate_index(self, files: Sequence[Path|bytes], *, api_key, chunk_size=25,
                       overlap=5, file_paths: Sequence[str]|None = None, 
                       file_shas_to_skip = None):
        """Generate the index from a list of files"""
        all_chunks = list(self.get_all_chunks(files, chunk_size=chunk_size, 
                                              overlap=overlap, file_paths=file_paths, 
                                              file_shas_to_skip=file_shas_to_skip))
        embedding_model = EmbeddingModel(api_key=api_key, model=self.model,
                                         max_n_tokens=self.max_n_tokens)
        embeddings, embeddings_too_long_filter = (
            embedding_model.get_embeddings_batched(
                [c.text_chunk for c in all_chunks]))
        if embeddings is None:
            return # no embeddings to add
        ids = []
        for chunk, embedding, embedding_filter in zip(
                all_chunks, embeddings, embeddings_too_long_filter):
            id = self.insert_db(chunk=chunk, 
                                embedding=embedding if not embedding_filter else None,
                                do_commit=False)
            ids.append(id)
        ids = np.array(ids)
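        # a flat L2 index needs no training (is_trained is always True); this
        # branch only matters if the index is later switched to an IVF variant
        # (see the TODO in __init__)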
        if not self.index.is_trained:
            self.index.train(embeddings[~embeddings_too_long_filter]) # type: ignore
        self.index.add_with_ids(embeddings[~embeddings_too_long_filter], 
                                ids[~embeddings_too_long_filter]) # type: ignore
        self.commit() 
    def commit(self):
        """Write changes to disk
        do a database commit, and write the index
        """
        self.db.commit()
        faiss.write_index(self.index, str(self.index_path)) 
    def get_chunk_by_id(self, id):
        """Get a chunk from database by its id"""
        if not isinstance(id, int): id = int(id)
        res = self.db.execute("SELECT text_chunk, file_path, start_line, "
                              "end_line, file_sha FROM rag WHERE id = ?", 
                              (id,)).fetchone()
        if res is None: return None
        return TextChunk(*res) 
    def query(self, query, n_results=5, *, api_key):
        """Do a Knn search on the index"""
        embedding_model = EmbeddingModel(
                api_key=api_key, model=self.model, max_n_tokens=self.max_n_tokens, 
                rate_limit=self.rate_limit)
        query_embedding, _ = embedding_model.get_embeddings_batched([query])
        if query_embedding is None:
            raise RuntimeError("Query too long")
        scores, ids = self.index.search(query_embedding, k=n_results) # type: ignore
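        # faiss pads ids with -1 when the index holds fewer than k vectors;
        # get_chunk_by_id then returns None for those entries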
        chunks = [self.get_chunk_by_id(id) for id in ids[0]]
        return chunks, scores[0] 
        
    def update_index(self, files: Sequence[Path|bytes], *, api_key, chunk_size=25, 
                     overlap=5, file_paths: Sequence[str]|None = None):
        """Update the index from a list of files.
        Like generate_index, but preemptively checks which files have changed, 
        and only updates those.
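
        Example:
            Illustrative re-indexing after some files changed on disk
            (the directory and api_key are placeholders)::

                db.update_index(list(Path("docs").glob("**/*.md")),
                                api_key=api_key)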
        """
        files_shas_in_db = self.db.execute(
                "SELECT DISTINCT file_sha FROM rag").fetchall()
        files_shas_in_db = set(sha for sha, in files_shas_in_db)
        new_shas = []
        for file in files:
            if isinstance(file, Path): hash = sha1(file.read_bytes()).digest()
            else: hash = sha1(file).digest()
            new_shas.append(hash)
        new_shas = set(new_shas)
        # shas present in the database that no longer correspond to any of the
        # given files: their chunks must be removed from the db and the index
        deleted_shas = tuple(files_shas_in_db - new_shas)
        self.db.execute("CREATE TEMP TABLE deleted_shas (sha BLOB)")
        self.db.executemany("INSERT INTO deleted_shas VALUES (?)", 
                            [(sha,) for sha in deleted_shas] )
        self.db.commit()
        deleted_shas_ids = self.db.execute(
                "SELECT id FROM rag INNER JOIN deleted_shas "
                "ON rag.file_sha = deleted_shas.sha").fetchall()
        self.db.execute("DELETE FROM rag WHERE file_sha IN deleted_shas")
        self.db.execute("DROP TABLE deleted_shas")
        self.index.remove_ids(np.array([id for id, in deleted_shas_ids],
                                       dtype=np.int64))
        # self.commit() # only commit after the full update is done
        # to avoid failing with a transient state on disk
        self.generate_index(files, api_key=api_key, chunk_size=chunk_size, 
                            overlap=overlap, file_paths=file_paths, 
                            file_shas_to_skip=files_shas_in_db)