Spaces:
Runtime error
| from langchain_community.vectorstores import FAISS | |
| from langchain_huggingface import HuggingFaceEmbeddings | |
| from smolagents import Tool | |
| import modal | |
| from .app import app | |
| from .image import image | |
| from .volume import volume | |
class RemoteObjectDetectionModelRetrieverModalApp:
    """Answer object-detection model queries from a FAISS vector store.

    Each indexed document is expected to carry ``model_id`` and
    ``model_labels`` metadata; ``forward`` turns the nearest documents into a
    ``{model_id: model_labels}`` mapping.

    NOTE(review): no ``@app.cls`` / ``@modal.enter`` / ``@modal.method``
    decorators are visible on this class, yet it is looked up with
    ``modal.Cls.from_name`` and invoked via ``forward.remote`` elsewhere in
    this file. Confirm it is actually registered with the Modal app
    (presumably with a GPU, since ``setup`` requests ``device="cuda"``).
    """

    def setup(self):
        """Load the pre-built FAISS index from ``/volume/vector_store``.

        Query embeddings use ``all-MiniLM-L6-v2`` on CUDA with normalized
        outputs. These settings presumably have to match the ones used when
        the index was built — TODO confirm against the indexing script.
        """
        embeddings = HuggingFaceEmbeddings(
            model_name="all-MiniLM-L6-v2",
            model_kwargs={"device": "cuda"},
            encode_kwargs={"normalize_embeddings": True},
            show_progress=True,
        )
        # load_local unpickles the stored docstore, which LangChain only does
        # when explicitly allowed. This is acceptable only if the volume
        # contents are trusted (presumably written by this project's own
        # indexing step — verify); never point it at untrusted data.
        self.vector_store = FAISS.load_local(
            folder_path="/volume/vector_store",
            embeddings=embeddings,
            index_name="object_detection_models_faiss_index",
            allow_dangerous_deserialization=True,
        )

    def forward(self, query: str, k: int = 7) -> dict:
        """Return the models whose indexed documents best match ``query``.

        Args:
            query: Free-text description of the object class to detect.
            k: Number of nearest documents to retrieve; defaults to 7.

        Returns:
            A dict mapping each retrieved document's ``model_id`` metadata to
            its ``model_labels`` metadata. If several retrieved documents
            share a ``model_id``, the one returned last by the search wins,
            so the dict can hold fewer than ``k`` entries.
        """
        docs = self.vector_store.similarity_search(query, k=k)
        return {doc.metadata["model_id"]: doc.metadata["model_labels"] for doc in docs}
class RemoteObjectDetectionModelRetrieverTool(Tool):
    """smolagents tool that forwards model-retrieval queries to Modal.

    The retrieval itself runs in ``RemoteObjectDetectionModelRetrieverModalApp``
    on Modal; this class only resolves that deployed class by name and
    relays the query to its ``forward`` method.
    """

    name = "object_detection_model_retriever"
    description = """
For a given class of objects, retrieve the models that can detect that class.
The query is a string that describes the class of objects the model needs to detect.
The output is a dictionary with the model id as the key and the labels that the model can detect as the value.
"""
    inputs = {
        "query": {
            "type": "string",
            "description": "The class of objects the model needs to detect.",
        }
    }
    output_type = "object"

    def __init__(self):
        super().__init__()
        # Look up the deployed class by (app name, class name) rather than
        # referencing the local class directly, so calls go to the Modal
        # deployment instead of running in this process.
        self.tool_class = modal.Cls.from_name(
            app.name, RemoteObjectDetectionModelRetrieverModalApp.__name__
        )

    def forward(self, query: str) -> dict:
        """Run the remote retriever for ``query`` and return its result.

        Args:
            query: Description of the object class the model must detect.

        Returns:
            The remote ``forward`` result: per the tool description, a dict
            mapping model id to the labels that model can detect. This matches
            ``output_type = "object"``; the previous ``-> str`` annotation
            did not.

        Raises:
            AssertionError: if ``query`` is not a ``str`` (note that the check
                is skipped when Python runs with ``-O``).
        """
        assert isinstance(query, str), "Your search query must be a string"
        remote_instance = self.tool_class()
        return remote_instance.forward.remote(query)