上傳檔案到「/」
This commit is contained in:
115
blog_search.py
Normal file
115
blog_search.py
Normal file
@@ -0,0 +1,115 @@
|
||||
"""
|
||||
Blog Semantic Search
|
||||
====================
|
||||
讀取 blog_embeddings.json,讓使用者輸入(或貼上)任意文字,
|
||||
用同一個 embedding 模型算出向量,列出最相關的 20 篇文章。
|
||||
|
||||
需求:
|
||||
- ollama 已啟動且已 pull qwen3-embedding:8b
|
||||
- pip install numpy requests
|
||||
- 已經跑過 blog_embeddings.py 產生 blog_embeddings.json
|
||||
"""
|
||||
|
||||
import json
|
||||
import numpy as np
|
||||
import requests
|
||||
|
||||
# ============================================================
# Configuration
# ============================================================

# Pre-computed article embeddings produced by blog_embeddings.py.
EMBEDDINGS_FILE = "./blog_embeddings.json"
# Ollama embeddings endpoint; assumes a local Ollama server on the default port.
OLLAMA_URL = "http://localhost:11434/api/embed"
# Must match the model used to build EMBEDDINGS_FILE, or similarity
# scores across the two vector spaces are meaningless.
OLLAMA_MODEL = "qwen3-embedding:8b"

# ============================================================
||||
def get_embedding(text: str) -> list[float]:
    """Return the embedding vector for *text* via the Ollama embed API.

    Args:
        text: Arbitrary text to embed (single input; the API returns a
            batch, so we take the first element).

    Returns:
        The embedding as a list of floats.

    Raises:
        requests.HTTPError: on a non-2xx response.
        requests.Timeout: if the server does not answer within the timeout.
    """
    resp = requests.post(
        OLLAMA_URL,
        json={
            "model": OLLAMA_MODEL,
            "input": text,
        },
        # Without a timeout, a wedged Ollama server would hang this
        # program forever. Embedding long inputs on CPU can be slow,
        # so the limit is generous.
        timeout=300,
    )
    resp.raise_for_status()
    return resp.json()["embeddings"][0]
|
||||
|
||||
|
||||
def cosine_similarity(a: np.ndarray, b: np.ndarray) -> float:
    """Return the cosine similarity between vectors *a* and *b*.

    Yields 0.0 when either vector has zero length, which sidesteps a
    division by zero for degenerate embeddings.
    """
    scale = float(np.linalg.norm(a)) * float(np.linalg.norm(b))
    if scale == 0:
        return 0.0
    return float(np.dot(a, b)) / scale
|
||||
|
||||
|
||||
def format_row(rank: int, sim: float, title: str, slug: str) -> str:
    """Format one ranked search hit as a single display line.

    The bar is proportional to the similarity: 30 full blocks == 1.0.
    """
    filled = int(sim * 30)
    return "  {:3d}. {:.4f} {} {} ({})".format(rank, sim, "█" * filled, title, slug)
|
||||
|
||||
|
||||
def read_query() -> str:
    """Read a (possibly multi-line) query from stdin.

    Input ends when the user presses Enter twice in a row; the first
    blank line is kept in the buffer but stripped from the final result.
    """
    print("請輸入搜尋文字(連按兩次 Enter 結束):")
    collected: list[str] = []
    consecutive_blanks = 0
    while True:
        entry = input()
        if entry:
            consecutive_blanks = 0
        else:
            consecutive_blanks += 1
            if consecutive_blanks >= 2:
                break
        collected.append(entry)
    return "\n".join(collected).strip()
|
||||
|
||||
|
||||
def main():
    """Interactive loop: load the embedding index, then rank articles per query."""
    with open(EMBEDDINGS_FILE, encoding="utf-8") as fh:
        articles = json.load(fh)

    if not articles:
        print("找不到任何文章資料。")
        return

    print(f"已載入 {len(articles)} 篇文章的 embeddings\n")

    # Convert to numpy arrays once up front so every query only pays
    # for the dot products, not for re-parsing the JSON lists.
    article_vectors = [np.array(entry["embedding"]) for entry in articles]

    while True:
        query = read_query()
        if not query:
            print("沒有輸入,掰掰!")
            break

        print(f"\n正在產生 embedding(模型:{OLLAMA_MODEL})...")
        query_vector = np.array(get_embedding(query))

        # Score every article against the query vector.
        scored = [
            {
                "slug": entry["slug"],
                "title": entry.get("title", "(無標題)"),
                "similarity": cosine_similarity(query_vector, vec),
            }
            for entry, vec in zip(articles, article_vectors)
        ]
        scored.sort(key=lambda row: row["similarity"], reverse=True)

        print("\n🔍 與輸入文字最相關的 20 篇:\n")
        for position, row in enumerate(scored[:20], 1):
            print(format_row(position, row["similarity"], row["title"], row["slug"]))

        print("\n" + "─" * 60 + "\n")


if __name__ == "__main__":
    main()
|
||||
Reference in New Issue
Block a user