
import codecs
import csv
import datetime
import io
from io import StringIO
import pandas as pd
from flask import (
Blueprint,
current_app,
flash,
redirect,
render_template,
request,
send_file,
session,
url_for,
)
from sqlalchemy import asc, desc
from .db import db
from .models import PaperMetadata, ScheduleConfig, VolumeConfig
bp = Blueprint("main", __name__)
@bp.route("/")
def index():
return render_template("index.html")
REQUIRED_COLUMNS = {"alternative_id", "journal", "doi", "issn", "title"}
@bp.route("/upload", methods=["GET", "POST"])
def upload():
    """Upload a CSV of paper metadata and import it into the database.

    GET renders the upload form. POST expects a multipart "file" field plus
    optional "delimiter" (default ",") and "duplicate_strategy" ("skip" or
    "update", default "skip"). Rows missing title/doi/issn are recorded as
    errors; the full error list is kept in the session so it can be fetched
    via the error-log download route.
    """
    if request.method != "POST":
        return render_template("upload.html")

    file = request.files.get("file")
    delimiter = request.form.get("delimiter", ",")
    duplicate_strategy = request.form.get("duplicate_strategy", "skip")
    if not file:
        return render_template("upload.html", error="No file selected.")

    try:
        stream = codecs.iterdecode(file.stream, "utf-8")
        content = "".join(stream)
        df = pd.read_csv(StringIO(content), delimiter=delimiter)
    except Exception as e:
        return render_template("upload.html", error=f"Failed to read CSV file: {e}")

    missing = REQUIRED_COLUMNS - set(df.columns)
    if missing:
        return render_template(
            "upload.html", error=f"Missing required columns: {', '.join(missing)}"
        )

    def clean(val):
        """Normalize pandas NaN/NA to None so optional DB columns stay NULL.

        Previously float NaN values from pandas were written verbatim into
        optional string columns.
        """
        return None if pd.isna(val) else val

    def parse_date(val):
        """Parse a value to a date; None for NaN or unparseable input."""
        if pd.isna(val):
            return None
        try:
            return pd.to_datetime(val).date()
        except Exception:
            return None

    # Import statistics.
    added_count = 0
    skipped_count = 0
    updated_count = 0
    error_count = 0
    errors = []

    for index, row in df.iterrows():
        try:
            # DOI is used both as the dedup key and for error reporting.
            doi = str(row.get("doi", "N/A"))

            # Validate required per-row fields before touching the DB.
            for field in ("title", "doi", "issn"):
                if pd.isna(row.get(field)) or not str(row.get(field)).strip():
                    raise ValueError(f"Missing required field: {field}")

            existing = PaperMetadata.query.filter_by(doi=doi).first()
            if existing:
                if duplicate_strategy == "update":
                    # Overwrite the existing record with the new row's values.
                    existing.title = row["title"]
                    existing.alt_id = clean(row.get("alternative_id"))
                    existing.issn = row["issn"]
                    existing.journal = clean(row.get("journal"))
                    existing.type = clean(row.get("type"))
                    existing.language = clean(row.get("language"))
                    existing.published_online = parse_date(row.get("published_online"))
                    updated_count += 1
                else:
                    # duplicate_strategy == "skip" (default): leave record as-is.
                    skipped_count += 1
                    continue
            else:
                metadata = PaperMetadata(
                    title=row["title"],
                    doi=doi,
                    alt_id=clean(row.get("alternative_id")),
                    issn=row["issn"],
                    journal=clean(row.get("journal")),
                    type=clean(row.get("type")),
                    language=clean(row.get("language")),
                    published_online=parse_date(row.get("published_online")),
                    status="New",
                    file_path=None,
                    error_msg=None,
                )
                db.session.add(metadata)
                added_count += 1
        except Exception as e:
            error_count += 1
            errors.append(
                {
                    # +2: 0-based index plus the header line -> 1-based file row.
                    "row": index + 2,
                    "doi": clean(row.get("doi")) or "N/A",
                    "error": str(e),
                }
            )
            continue  # Skip this row and continue with the next

    try:
        db.session.commit()
    except Exception as e:
        db.session.rollback()
        return render_template(
            "upload.html", error=f"Failed to save data to database: {e}"
        )

    # Show at most 5 error samples inline; keep the full list in the session
    # so the user can download the complete error log.
    error_samples = errors[:5] if errors else []
    error_message = None
    if errors:
        error_message = f"Encountered {len(errors)} errors. First 5 shown below."
        error_csv = StringIO()
        writer = csv.DictWriter(error_csv, fieldnames=["row", "doi", "error"])
        writer.writeheader()
        writer.writerows(errors)
        session["error_data"] = error_csv.getvalue()

    return render_template(
        "upload.html",
        success=f"File processed! Added: {added_count}, Updated: {updated_count}, Skipped: {skipped_count}, Errors: {error_count}",
        error_message=error_message,
        error_samples=error_samples,
    )
@bp.route("/download_error_log")
def download_error_log():
    """Send the most recent upload error log (stored in the session) as CSV.

    Redirects back to the upload page if no error log is available.
    """
    error_data = session.get("error_data")
    if not error_data:
        flash("No error data available.")
        return redirect(url_for("main.upload"))
    # send_file needs a binary stream; a StringIO here fails Flask's binary
    # read (the export route below already uses BytesIO for the same reason).
    buffer = io.BytesIO(error_data.encode("utf-8"))
    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    return send_file(
        buffer,
        mimetype="text/csv",
        as_attachment=True,
        download_name=f"upload_errors_{timestamp}.csv",
    )
@bp.route("/papers")
def list_papers():
    """Render a paginated, filterable, sortable list of papers.

    Query-string parameters: page, status, created_from/created_to,
    updated_from/updated_to (YYYY-MM-DD), sort_by (column name, default
    created_at), sort_dir ("asc" or "desc", default "desc").
    """
    page = request.args.get("page", 1, type=int)
    per_page = 50

    # Filter / sort parameters.
    status = request.args.get("status")
    created_from = request.args.get("created_from")
    created_to = request.args.get("created_to")
    updated_from = request.args.get("updated_from")
    updated_to = request.args.get("updated_to")
    sort_by = request.args.get("sort_by", "created_at")
    sort_dir = request.args.get("sort_dir", "desc")

    query = PaperMetadata.query
    if status:
        query = query.filter(PaperMetadata.status == status)

    def parse_date(val):
        """Parse 'YYYY-MM-DD' to a datetime; None for missing/invalid input."""
        try:
            return datetime.datetime.strptime(val, "%Y-%m-%d")
        except (ValueError, TypeError):
            return None

    one_day = datetime.timedelta(days=1)
    if created_from := parse_date(created_from):
        query = query.filter(PaperMetadata.created_at >= created_from)
    if created_to := parse_date(created_to):
        # Make the upper bound cover the whole end day; a plain <= against
        # midnight excluded every record created later that same day.
        query = query.filter(PaperMetadata.created_at < created_to + one_day)
    if updated_from := parse_date(updated_from):
        query = query.filter(PaperMetadata.updated_at >= updated_from)
    if updated_to := parse_date(updated_to):
        query = query.filter(PaperMetadata.updated_at < updated_to + one_day)

    # Only sort by real table columns; arbitrary names would otherwise
    # resolve via getattr to non-column class attributes.
    if sort_by in PaperMetadata.__table__.columns:
        sort_col = getattr(PaperMetadata, sort_by)
    else:
        sort_col = PaperMetadata.created_at
    sort_func = desc if sort_dir == "desc" else asc
    query = query.order_by(sort_func(sort_col))

    pagination = query.paginate(page=page, per_page=per_page, error_out=False)

    # Summary statistics for the page header (unfiltered totals).
    total_papers = PaperMetadata.query.count()
    status_counts = dict(
        db.session.query(PaperMetadata.status, db.func.count(PaperMetadata.status))
        .group_by(PaperMetadata.status)
        .all()
    )

    return render_template(
        "papers.html",
        papers=pagination.items,
        pagination=pagination,
        total_papers=total_papers,
        status_counts=status_counts,
        sort_by=sort_by,
        sort_dir=sort_dir,
    )
@bp.route("/papers/export")
def export_papers():
    """Export papers as a CSV download, honoring the list view's filters.

    Accepts the same query-string parameters as the papers list (status,
    created/updated date ranges, sort_by, sort_dir).
    """
    status = request.args.get("status")
    created_from = request.args.get("created_from")
    created_to = request.args.get("created_to")
    updated_from = request.args.get("updated_from")
    updated_to = request.args.get("updated_to")
    sort_by = request.args.get("sort_by", "created_at")
    sort_dir = request.args.get("sort_dir", "desc")

    query = PaperMetadata.query
    if status:
        query = query.filter(PaperMetadata.status == status)

    def parse_date(val):
        """Parse 'YYYY-MM-DD' to a datetime; None for missing/invalid input."""
        try:
            return datetime.datetime.strptime(val, "%Y-%m-%d")
        except Exception:
            return None

    # These filters and the sort order were previously parsed but never
    # applied, so exports ignored the view's filtering; apply them here.
    if created_from := parse_date(created_from):
        query = query.filter(PaperMetadata.created_at >= created_from)
    if created_to := parse_date(created_to):
        query = query.filter(PaperMetadata.created_at <= created_to)
    if updated_from := parse_date(updated_from):
        query = query.filter(PaperMetadata.updated_at >= updated_from)
    if updated_to := parse_date(updated_to):
        query = query.filter(PaperMetadata.updated_at <= updated_to)

    sort_col = getattr(PaperMetadata, sort_by, PaperMetadata.created_at)
    sort_func = desc if sort_dir == "desc" else asc
    query = query.order_by(sort_func(sort_col))

    # Build the CSV in memory, then ship it as a binary attachment.
    output = io.StringIO()
    writer = csv.writer(output)
    writer.writerow(
        ["ID", "Title", "Journal", "DOI", "ISSN", "Status", "Created At", "Updated At"]
    )
    for paper in query:
        writer.writerow(
            [
                paper.id,
                paper.title,
                getattr(paper, "journal", ""),
                paper.doi,
                paper.issn,
                paper.status,
                paper.created_at,
                paper.updated_at,
            ]
        )
    output.seek(0)
    return send_file(
        io.BytesIO(output.read().encode("utf-8")),
        mimetype="text/csv",
        as_attachment=True,
        download_name="papers.csv",
    )
@bp.route("/papers/<int:paper_id>/detail")
def paper_detail(paper_id):
paper = PaperMetadata.query.get_or_404(paper_id)
return render_template("partials/paper_detail_modal.html", paper=paper)
@bp.route("/schedule", methods=["GET", "POST"])
def schedule():
    """Display and update the hourly schedule weights and the total volume.

    A POST containing "total_volume" updates the volume config; any other
    POST is treated as a schedule update and must supply a weight for every
    hour_0..hour_23 field. GET (and every POST) renders the schedule page.
    """
    if request.method == "POST":
        try:
            if "total_volume" in request.form:
                # Volume update branch.
                try:
                    new_volume = float(request.form.get("total_volume", 0))
                    # NOTE(review): the message says 1-1000 but the check
                    # accepts any value > 0 — confirm the intended lower bound.
                    if new_volume <= 0 or new_volume > 1000:
                        raise ValueError("Volume must be between 1 and 1000")

                    volume_config = VolumeConfig.query.first()
                    if not volume_config:
                        volume_config = VolumeConfig(volume=new_volume)
                        db.session.add(volume_config)
                    else:
                        volume_config.volume = new_volume

                    db.session.commit()
                    flash("Volume updated successfully!", "success")
                except ValueError as e:
                    db.session.rollback()
                    flash(f"Error updating volume: {str(e)}", "error")
            else:
                # Schedule update: validate all 24 weights before writing
                # anything, so a bad field leaves the DB untouched.
                for hour in range(24):
                    key = f"hour_{hour}"
                    if key not in request.form:
                        raise ValueError(f"Missing data for hour {hour}")
                    try:
                        weight = float(request.form.get(key, 0))
                    except ValueError:
                        raise ValueError(f"Invalid weight value for hour {hour}")
                    # Range check outside the try: previously the range error
                    # was caught by the except above and its specific message
                    # was replaced by the generic "Invalid weight value" one.
                    if weight < 0 or weight > 5:
                        raise ValueError(
                            f"Weight for hour {hour} must be between 0 and 5"
                        )

                # Validation passed; persist all 24 weights.
                for hour in range(24):
                    key = f"hour_{hour}"
                    weight = float(request.form.get(key, 0))
                    config = ScheduleConfig.query.get(hour)
                    if config:
                        config.weight = weight
                    else:
                        db.session.add(ScheduleConfig(hour=hour, weight=weight))

                db.session.commit()
                flash("Schedule updated successfully!", "success")
        except ValueError as e:
            db.session.rollback()
            flash(f"Error updating schedule: {str(e)}", "error")

    schedule = {
        sc.hour: sc.weight
        for sc in ScheduleConfig.query.order_by(ScheduleConfig.hour).all()
    }
    volume_config = VolumeConfig.query.first()
    return render_template(
        "schedule.html",
        schedule=schedule,
        # Guard against a missing VolumeConfig row; previously this raised
        # AttributeError ('NoneType' has no 'volume') on a fresh database.
        volume=volume_config.volume if volume_config else None,
        app_title="PaperScraper",
    )
@bp.route("/logs")
def logs():
return render_template("logs.html", app_title="PaperScraper")
@bp.route("/about")
def about():
return render_template("about.html", app_title="PaperScraper")