"""
Blog Semantic Search
====================

讀取 blog_embeddings.json,讓使用者輸入(或貼上)任意文字,
用同一個 embedding 模型算出向量,列出最相關的 20 篇文章。

需求:
- ollama 已啟動且已 pull qwen3-embedding:8b
- pip install numpy requests
- 已經跑過 blog_embeddings.py 產生 blog_embeddings.json
"""
import json

import numpy as np
import requests
# ============================================================
# Configuration
# ============================================================

# Embeddings file produced by blog_embeddings.py.
EMBEDDINGS_FILE = "./blog_embeddings.json"
# Ollama embeddings endpoint (default local install).
OLLAMA_URL = "http://localhost:11434/api/embed"
# Must match the model that generated EMBEDDINGS_FILE, or the vector
# spaces will not be comparable.
OLLAMA_MODEL = "qwen3-embedding:8b"

# ============================================================
def get_embedding(text: str) -> list[float]:
    """Fetch the embedding vector for *text* via the Ollama HTTP API.

    Args:
        text: Text to embed.

    Returns:
        The embedding vector as a list of floats.

    Raises:
        requests.HTTPError: If the API responds with an error status.
        requests.Timeout: If the server does not answer within the timeout.
    """
    resp = requests.post(
        OLLAMA_URL,
        json={
            "model": OLLAMA_MODEL,
            "input": text,
        },
        # requests has no default timeout; without one a stuck server
        # would hang this call forever.  Embedding can be slow on CPU,
        # so allow a generous window.
        timeout=120,
    )
    resp.raise_for_status()
    # /api/embed returns {"embeddings": [[...]]} even for a single input,
    # so take the first (and only) vector.
    return resp.json()["embeddings"][0]
def cosine_similarity(a: np.ndarray, b: np.ndarray) -> float:
    """Return the cosine similarity between vectors *a* and *b*.

    Returns 0.0 when either vector has zero norm, instead of dividing
    by zero.
    """
    denom = float(np.linalg.norm(a) * np.linalg.norm(b))
    if denom == 0.0:
        return 0.0
    return float(np.dot(a, b) / denom)
def format_row(rank: int, sim: float, title: str, slug: str) -> str:
    """Format one result line: rank, similarity score, bar chart, title, slug."""
    # Similarity 1.0 maps to a bar of 30 full blocks.
    bar = "█" * int(sim * 30)
    return f" {rank:3d}. {sim:.4f} {bar} {title} ({slug})"
def read_query() -> str:
    """Read the search text from stdin, supporting multi-line pastes.

    Input ends after two consecutive empty lines (pressing Enter twice).
    The first terminating empty line is appended to the buffer but is
    removed again by the final strip().

    Returns:
        The entered text with surrounding whitespace stripped; an empty
        string when nothing was entered.
    """
    print("請輸入搜尋文字(連按兩次 Enter 結束):")
    lines: list[str] = []
    consecutive_blank = 0
    while True:
        # NOTE(review): input() raises EOFError on closed stdin; that is
        # unhandled here, same as in the original — confirm acceptable.
        line = input()
        if line == "":
            consecutive_blank += 1
            if consecutive_blank >= 2:
                break
        else:
            consecutive_blank = 0
        lines.append(line)
    return "\n".join(lines).strip()
def main() -> None:
    """Interactive loop: load precomputed embeddings, then search repeatedly.

    Exits when the user submits an empty query.
    """
    # Load the article embeddings produced by blog_embeddings.py.
    with open(EMBEDDINGS_FILE, encoding="utf-8") as f:
        data = json.load(f)

    if not data:
        print("找不到任何文章資料。")
        return

    print(f"已載入 {len(data)} 篇文章的 embeddings\n")

    # Convert once to numpy arrays so per-query scoring is fast.
    vectors = [np.array(item["embedding"]) for item in data]

    while True:
        query = read_query()
        if not query:
            print("沒有輸入,掰掰!")
            break

        print(f"\n正在產生 embedding(模型:{OLLAMA_MODEL})...")
        query_vec = np.array(get_embedding(query))

        # Score every article against the query vector.
        similarities = [
            {
                "slug": item["slug"],
                "title": item.get("title", "(無標題)"),
                "similarity": cosine_similarity(query_vec, vec),
            }
            for item, vec in zip(data, vectors)
        ]
        similarities.sort(key=lambda x: x["similarity"], reverse=True)

        top = similarities[:20]
        print("\n🔍 與輸入文字最相關的 20 篇:\n")
        for rank, s in enumerate(top, 1):
            print(format_row(rank, s["similarity"], s["title"], s["slug"]))

        print("\n" + "─" * 60 + "\n")
# Script entry point: only run the interactive loop when executed directly.
if __name__ == "__main__":
    main()