Skip to content

vector_memory_collection

VectorMemoryCollection

Source code in cat/memory/vector_memory_collection.py
class VectorMemoryCollection:
    def __init__(
        self,
        client: Any,
        collection_name: str,
        embedder_name: str,
        embedder_size: int,
    ):
        # Set attributes (metadata on the embedder are useful because it may change at runtime)
        self.client = client
        self.collection_name = collection_name
        self.embedder_name = embedder_name
        self.embedder_size = embedder_size

        # Check if memory collection exists also in vectorDB, otherwise create it
        self.create_db_collection_if_not_exists()

        # Check db collection vector size is same as embedder size
        self.check_embedding_size()

        # log collection info
        log.debug(f"Collection {self.collection_name}:")
        log.debug(self.client.get_collection(self.collection_name))

    def check_embedding_size(self):
        # having the same size does not necessarily imply being the same embedder
        # having vectors with the same size but from diffent embedder in the same vector space is wrong
        same_size = (
            self.client.get_collection(self.collection_name).config.params.vectors.size
            == self.embedder_size
        )
        alias = self.embedder_name + "_" + self.collection_name
        if (
            alias
            == self.client.get_collection_aliases(self.collection_name)
            .aliases[0]
            .alias_name
            and same_size
        ):
            log.debug(f'Collection "{self.collection_name}" has the same embedder')
        else:
            log.warning(f'Collection "{self.collection_name}" has different embedder')
            # Memory snapshot saving can be turned off in the .env file with:
            # SAVE_MEMORY_SNAPSHOTS=false
            if get_env("CCAT_SAVE_MEMORY_SNAPSHOTS") == "true":
                # dump collection on disk before deleting
                self.save_dump()
                log.info(f"Dump '{self.collection_name}' completed")

            self.client.delete_collection(self.collection_name)
            log.warning(f"Collection '{self.collection_name}' deleted")
            self.create_collection()

    def create_db_collection_if_not_exists(self):
        # is collection present in DB?
        collections_response = self.client.get_collections()
        for c in collections_response.collections:
            if c.name == self.collection_name:
                # collection exists. Do nothing
                log.info(
                    f"Collection '{self.collection_name}' already present in vector store"
                )
                return

        self.create_collection()

    # create collection
    def create_collection(self):
        log.warning(f"Creating collection '{self.collection_name}' ...")
        self.client.recreate_collection(
            collection_name=self.collection_name,
            vectors_config=VectorParams(
                size=self.embedder_size, distance=Distance.COSINE
            ),
            # hybrid mode: original vector on Disk, quantized vector in RAM
            optimizers_config=OptimizersConfigDiff(memmap_threshold=20000),
            quantization_config=ScalarQuantization(
                scalar=ScalarQuantizationConfig(
                    type=ScalarType.INT8, quantile=0.95, always_ram=True
                )
            ),
            # shard_number=3,
        )

        self.client.update_collection_aliases(
            change_aliases_operations=[
                CreateAliasOperation(
                    create_alias=CreateAlias(
                        collection_name=self.collection_name,
                        alias_name=self.embedder_name + "_" + self.collection_name,
                    )
                )
            ]
        )

    # adapted from https://github.com/langchain-ai/langchain/blob/bfc12a4a7644cfc4d832cc4023086a7a5374f46a/libs/langchain/langchain/vectorstores/qdrant.py#L1965
    def _qdrant_filter_from_dict(self, filter: dict) -> Filter:
        if not filter:
            return None

        return Filter(
            must=[
                condition
                for key, value in filter.items()
                for condition in self._build_condition(key, value)
            ]
        )

    # adapted from https://github.com/langchain-ai/langchain/blob/bfc12a4a7644cfc4d832cc4023086a7a5374f46a/libs/langchain/langchain/vectorstores/qdrant.py#L1941
    def _build_condition(self, key: str, value: Any) -> List[FieldCondition]:
        out = []

        if isinstance(value, dict):
            for _key, value in value.items():
                out.extend(self._build_condition(f"{key}.{_key}", value))
        elif isinstance(value, list):
            for _value in value:
                if isinstance(_value, dict):
                    out.extend(self._build_condition(f"{key}[]", _value))
                else:
                    out.extend(self._build_condition(f"{key}", _value))
        else:
            out.append(
                FieldCondition(
                    key=f"metadata.{key}",
                    match=MatchValue(value=value),
                )
            )

        return out

    def add_point(
        self,
        content: str,
        vector: Iterable,
        metadata: dict = None,
        id: Optional[str] = None,
        **kwargs: Any,
    ) -> List[str]:
        """Add a point (and its metadata) to the vectorstore.

        Args:
            content: original text.
            vector: Embedding vector.
            metadata: Optional metadata dict associated with the text.
            id:
                Optional id to associate with the point. Id has to be a uuid-like string.

        Returns:
            Point id as saved into the vectorstore.
        """

        # TODO: may be adapted to upload batches of points as langchain does.
        # Not necessary now as the bottleneck is the embedder
        point = PointStruct(
            id=id or uuid.uuid4().hex,
            payload={
                "page_content": content,
                "metadata": metadata,
            },
            vector=vector,
        )

        update_status = self.client.upsert(
            collection_name=self.collection_name, points=[point], **kwargs
        )

        if update_status.status == "completed":
            # returnign stored point
            return point # TODOV2 return internal MemoryPoint
        else:
            return None

    def delete_points_by_metadata_filter(self, metadata=None):
        res = self.client.delete(
            collection_name=self.collection_name,
            points_selector=self._qdrant_filter_from_dict(metadata),
        )
        return res

    # delete point in collection
    def delete_points(self, points_ids):
        res = self.client.delete(
            collection_name=self.collection_name,
            points_selector=points_ids,
        )
        return res

    # retrieve similar memories from embedding
    def recall_memories_from_embedding(
        self, embedding, metadata=None, k=5, threshold=None
    ):
        # retrieve memories
        memories = self.client.search(
            collection_name=self.collection_name,
            query_vector=embedding,
            query_filter=self._qdrant_filter_from_dict(metadata),
            with_payload=True,
            with_vectors=True,
            limit=k,
            score_threshold=threshold,
            search_params=SearchParams(
                quantization=QuantizationSearchParams(
                    ignore=False,
                    rescore=True,
                    oversampling=2.0,  # Available as of v1.3.0
                )
            ),
        )

        # convert Qdrant points to langchain.Document
        langchain_documents_from_points = []
        for m in memories:
            langchain_documents_from_points.append(
                (
                    Document(
                        page_content=m.payload.get("page_content"),
                        metadata=m.payload.get("metadata") or {},
                    ),
                    m.score,
                    m.vector,
                    m.id,
                )
            )

        # we'll move out of langchain conventions soon and have our own cat Document
        # for doc, score, vector in langchain_documents_from_points:
        #    doc.lc_kwargs = None

        return langchain_documents_from_points

    # retrieve all the points in the collection
    def get_all_points(self):
        # retrieving the points
        all_points, _ = self.client.scroll(
            collection_name=self.collection_name,
            with_vectors=True,
            limit=10000,  # yeah, good for now dear :*
        )

        return all_points

    def db_is_remote(self):
        return isinstance(self.client._client, QdrantRemote)

    # dump collection on disk before deleting
    def save_dump(self, folder="dormouse/"):
        # only do snapshotting if using remote Qdrant
        if not self.db_is_remote():
            return

        host = self.client._client._host
        port = self.client._client._port

        if os.path.isdir(folder):
            log.info("Directory dormouse exists")
        else:
            log.warning("Directory dormouse does NOT exists, creating it.")
            os.mkdir(folder)

        self.snapshot_info = self.client.create_snapshot(
            collection_name=self.collection_name
        )
        snapshot_url_in = (
            "http://"
            + str(host)
            + ":"
            + str(port)
            + "/collections/"
            + self.collection_name
            + "/snapshots/"
            + self.snapshot_info.name
        )
        snapshot_url_out = folder + self.snapshot_info.name
        # rename snapshots for a easyer restore in the future
        alias = (
            self.client.get_collection_aliases(self.collection_name)
            .aliases[0]
            .alias_name
        )
        response = requests.get(snapshot_url_in)
        open(snapshot_url_out, "wb").write(response.content)
        new_name = folder + alias.replace("/", "-") + ".snapshot"
        os.rename(snapshot_url_out, new_name)
        for s in self.client.list_snapshots(self.collection_name):
            self.client.delete_snapshot(
                collection_name=self.collection_name, snapshot_name=s.name
            )
        log.warning(f'Dump "{new_name}" completed')

add_point(content, vector, metadata=None, id=None, **kwargs)

Add a point (and its metadata) to the vectorstore.

Args: content: original text. vector: Embedding vector. metadata: Optional metadata dict associated with the text. id: Optional id to associate with the point. Id has to be a uuid-like string.

Returns: Point id as saved into the vectorstore.

Source code in cat/memory/vector_memory_collection.py
def add_point(
    self,
    content: str,
    vector: Iterable,
    metadata: dict = None,
    id: Optional[str] = None,
    **kwargs: Any,
) -> List[str]:
    """Add a point (and its metadata) to the vectorstore.

    Args:
        content: original text.
        vector: Embedding vector.
        metadata: Optional metadata dict associated with the text.
        id:
            Optional id to associate with the point. Id has to be a uuid-like string.

    Returns:
        Point id as saved into the vectorstore.
    """

    # TODO: may be adapted to upload batches of points as langchain does.
    # Not necessary now as the bottleneck is the embedder
    point = PointStruct(
        id=id or uuid.uuid4().hex,
        payload={
            "page_content": content,
            "metadata": metadata,
        },
        vector=vector,
    )

    update_status = self.client.upsert(
        collection_name=self.collection_name, points=[point], **kwargs
    )

    if update_status.status == "completed":
        # returnign stored point
        return point # TODOV2 return internal MemoryPoint
    else:
        return None