# blog-embedding/blog_search.py
"""
Blog Semantic Search
====================
讀取 blog_embeddings.json讓使用者輸入或貼上任意文字
用同一個 embedding 模型算出向量列出最相關的 20 篇文章
需求
- ollama 已啟動且已 pull qwen3-embedding:8b
- pip install numpy requests
- 已經跑過 blog_embeddings.py 產生 blog_embeddings.json
"""
import json
import numpy as np
import requests
# ============================================================
# Configuration
# ============================================================
EMBEDDINGS_FILE = "./blog_embeddings.json"  # precomputed article embeddings (output of blog_embeddings.py)
OLLAMA_URL = "http://localhost:11434/api/embed"  # local Ollama embedding endpoint
OLLAMA_MODEL = "qwen3-embedding:8b"  # should match the model used to build the index
# ============================================================
def get_embedding(text: str, timeout: float = 60.0) -> list[float]:
    """Fetch the embedding vector for *text* via the Ollama embed API.

    Args:
        text: Input text to embed.
        timeout: Seconds to wait for the HTTP request; without one,
            requests.post can block forever if the server stalls.

    Returns:
        The embedding vector for the (single) input.

    Raises:
        requests.HTTPError: If the API responds with an error status.
        requests.Timeout: If the server does not answer within *timeout*.
    """
    resp = requests.post(
        OLLAMA_URL,
        json={"model": OLLAMA_MODEL, "input": text},
        timeout=timeout,
    )
    resp.raise_for_status()
    # /api/embed returns a batch; we sent one input, so take element 0.
    return resp.json()["embeddings"][0]
def cosine_similarity(a: np.ndarray, b: np.ndarray) -> float:
    """Return the cosine similarity of *a* and *b* (0.0 if either has zero norm)."""
    denom = np.linalg.norm(a) * np.linalg.norm(b)
    if denom == 0:
        # Avoid division by zero for degenerate (all-zero) vectors.
        return 0.0
    return float(np.dot(a, b) / denom)
def format_row(rank: int, sim: float, title: str, slug: str) -> str:
    """Format one ranked search result as a single display line.

    Args:
        rank: 1-based rank of the result.
        sim: Cosine similarity in [0, 1], rendered as a proportional bar.
        title: Article title.
        slug: Article slug, shown in parentheses.

    Returns:
        A formatted line: rank, similarity, bar, title, slug.
    """
    # Bug fix: the bar glyph was lost in transcription ("" * n is always ""),
    # which made the bar invisible. "█" repeated up to 30 chars at sim == 1.0.
    bar = "█" * int(sim * 30)
    return f" {rank:3d}. {sim:.4f} {bar} {title} ({slug})"
def read_query() -> str:
    """Read a multi-line query from stdin; two consecutive Enters end input.

    Returns:
        The stripped query text; "" when the user enters nothing or the
        input stream ends (Ctrl-D / piped input), which the caller treats
        as "quit".
    """
    print("請輸入搜尋文字(連按兩次 Enter 結束):")
    lines: list[str] = []
    empty_count = 0
    while True:
        try:
            line = input()
        except EOFError:
            # Bug fix: without this, EOF (Ctrl-D or exhausted piped input)
            # crashed the program with an unhandled EOFError.
            break
        if line == "":
            empty_count += 1
            if empty_count >= 2:
                break
            # Keep single blank lines so multi-paragraph pastes survive.
            lines.append(line)
        else:
            empty_count = 0
            lines.append(line)
    return "\n".join(lines).strip()
def main():
    """Interactive loop: load stored embeddings, embed each query, print top matches.

    Loads EMBEDDINGS_FILE once, then repeatedly reads a query from stdin,
    embeds it via Ollama, and prints the highest-similarity articles.
    An empty query exits the loop.
    """
    # Load precomputed article embeddings.
    with open(EMBEDDINGS_FILE, encoding="utf-8") as f:
        data = json.load(f)
    if not data:
        print("找不到任何文章資料。")
        return
    print(f"已載入 {len(data)} 篇文章的 embeddings\n")
    # Convert to numpy arrays once so per-query scoring is fast.
    vectors = [np.array(item["embedding"]) for item in data]
    # Bind top-k once so the slice and the message cannot drift apart
    # (it was hard-coded as 20 in two separate places).
    top_k = 20
    while True:
        query = read_query()
        if not query:
            print("沒有輸入,掰掰!")
            break
        print(f"\n正在產生 embedding模型{OLLAMA_MODEL}...")
        query_vec = np.array(get_embedding(query))
        # Score every article against the query.
        similarities = [
            {
                "slug": item["slug"],
                "title": item.get("title", "(無標題)"),
                "similarity": cosine_similarity(query_vec, vectors[i]),
            }
            for i, item in enumerate(data)
        ]
        similarities.sort(key=lambda x: x["similarity"], reverse=True)
        top = similarities[:top_k]
        print(f"\n🔍 與輸入文字最相關的 {top_k} 篇:\n")
        for rank, s in enumerate(top, 1):
            print(format_row(rank, s["similarity"], s["title"], s["slug"]))
        # Bug fix: the separator glyph was lost in transcription ("" * 60 is
        # always ""), so no divider was printed between queries.
        print("\n" + "=" * 60 + "\n")
if __name__ == "__main__":
main()