"""Upload functionality for paper metadata."""
import codecs
import csv
import datetime
from io import StringIO
import json
import pandas as pd
from flask import (
Blueprint,
flash,
jsonify,
redirect,
render_template,
request,
send_file,
session,
url_for,
current_app
)
from ..db import db
from ..models import PaperMetadata, ActivityLog
from ..celery import celery # Import the celery instance directly
from ..defaults import DUPLICATE_STRATEGIES
bp = Blueprint("upload", __name__)
REQUIRED_COLUMNS = {"alternative_id", "journal", "doi", "issn", "title"}
CHUNK_SIZE = 100 # Number of rows to process per batch
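# Example of the CSV header the importer expects (illustrative only; the optional
# published_online column, if present, is parsed as YYYY-MM-DD):
#
#   alternative_id,journal,doi,issn,title,published_online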


def parse_date(date_str):
    """Parse a YYYY-MM-DD date string into a datetime object, or return None."""
    if not date_str or pd.isna(date_str):
        return None
    try:
        return datetime.datetime.strptime(str(date_str), "%Y-%m-%d")
    except (ValueError, TypeError):
        return None
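# Illustrative behaviour of parse_date:
#   parse_date("2024-05-01")  -> datetime.datetime(2024, 5, 1, 0, 0)
#   parse_date("05/01/2024")  -> None  (unrecognised format)
#   parse_date(None)          -> None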
@bp.route("/", methods=["GET", "POST"])
def upload():
if request.method == "POST":
file = request.files.get("file")
delimiter = request.form.get("delimiter", ",")
duplicate_strategy = request.form.get("duplicate_strategy", "skip")
if not file:
return jsonify({"error": "No file selected."})
stream = codecs.iterdecode(file.stream, "utf-8")
content = "".join(stream)
# Trigger the Celery task
task = process_csv.delay(content, delimiter, duplicate_strategy)
return jsonify({"task_id": task.id})
return render_template("upload.html.jinja", duplicate_strategies=DUPLICATE_STRATEGIES)


@celery.task(bind=True)
def process_csv(self, file_content, delimiter, duplicate_strategy="skip"):
    """Process CSV file and import paper metadata."""
    # With the ContextTask in place, we're already inside an app context
    added_count = skipped_count = updated_count = error_count = 0
    errors = []
    skipped_records = []  # Track why individual records were skipped

    try:
        # Log the start of the import using the ActivityLog model
        ActivityLog.log_import_activity(
            action="start_csv_import",
            status="processing",
            description=f"Starting CSV import with strategy: {duplicate_strategy}",
            file_size=len(file_content),
            delimiter=delimiter,
        )

        # Set initial progress percentage
        self.update_state(state="PROGRESS", meta={"progress": 10})

        # Read the CSV in chunks; count the chunks first so progress can be reported
        csv_buffer = StringIO(file_content)
        total_chunks = len(
            list(pd.read_csv(csv_buffer, delimiter=delimiter, chunksize=CHUNK_SIZE))
        )
        csv_buffer.seek(0)

        # Process each chunk of rows
        for chunk_idx, chunk in enumerate(
            pd.read_csv(csv_buffer, delimiter=delimiter, chunksize=CHUNK_SIZE)
        ):
            for index, row in chunk.iterrows():
                try:
                    doi = str(row.get("doi", "N/A"))

                    # Validate required fields
                    if (
                        pd.isna(row.get("title"))
                        or pd.isna(row.get("doi"))
                        or pd.isna(row.get("issn"))
                    ):
                        raise ValueError("Missing required fields")

                    # Try finding an existing record based on DOI
                    existing = db.session.query(PaperMetadata).filter_by(doi=doi).first()
                    if existing:
                        if duplicate_strategy == "update":
                            existing.title = row["title"]
                            existing.alt_id = row.get("alternative_id")
                            existing.issn = row["issn"]
                            existing.journal = row.get("journal")
                            existing.published_online = parse_date(row.get("published_online"))
                            updated_count += 1
                        else:
                            # Track why this record was skipped
                            skipped_records.append({
                                "row": index + 2,  # +2 accounts for the header row and 1-based line numbers
                                "doi": doi,
                                "reason": f"Duplicate DOI found and strategy is '{duplicate_strategy}'",
                            })
                            skipped_count += 1
                            continue
                    else:
                        metadata = PaperMetadata(
                            title=row["title"],
                            doi=doi,
                            alt_id=row.get("alternative_id"),
                            issn=row["issn"],
                            journal=row.get("journal"),
                            published_online=parse_date(row.get("published_online")),
                            status="New",
                        )
                        db.session.add(metadata)
                        added_count += 1
                except Exception as e:
                    error_count += 1
                    errors.append({
                        "row": index + 2,
                        "doi": row.get("doi", "N/A"),
                        "error": str(e),
                    })

            # Commit the chunk and start fresh for the next one
            db.session.commit()

            # Log periodic progress every 5 chunks
            if (chunk_idx + 1) % 5 == 0:
                ActivityLog.log_import_activity(
                    action="import_progress",
                    status="processing",
                    description=f"Processed {chunk_idx + 1}/{total_chunks} chunks",
                    current_stats={
                        "added": added_count,
                        "updated": updated_count,
                        "skipped": skipped_count,
                        "errors": error_count,
                    },
                )

            # Scale progress between 10% and 90%; guard against an empty file
            progress = min(90, 10 + int((chunk_idx + 1) * 80 / max(total_chunks, 1)))
            self.update_state(state="PROGRESS", meta={"progress": progress})

        # Final progress update and completion log
        self.update_state(state="PROGRESS", meta={"progress": 100})
        ActivityLog.log_import_activity(
            action="complete_csv_import",
            status="success",
            description="CSV import completed",
            stats={
                "added": added_count,
                "updated": updated_count,
                "skipped": skipped_count,
                "errors": error_count,
            },
        )
    except Exception as e:
        db.session.rollback()
        ActivityLog.log_error(
            error_message="CSV import failed",
            exception=e,
            severity="error",
            source="upload.process_csv",
        )
        return {"error": str(e), "progress": 0}
    finally:
        db.session.remove()

    # If there were errors, store an error CSV for a later download
    if errors:
        try:
            error_csv = StringIO()
            writer = csv.DictWriter(error_csv, fieldnames=["row", "doi", "error"])
            writer.writeheader()
            writer.writerows(errors)
            ActivityLog.log_import_activity(
                action="import_errors",
                status="error",
                description=f"Import completed with {error_count} errors",
                error_csv=error_csv.getvalue(),
                task_id=self.request.id,
                error_count=error_count,
            )
        except Exception:
            # Do not fail the task if error logging fails
            pass

    # Include skipped-record information in the task result
    return {
        "added": added_count,
        "updated": updated_count,
        "skipped": skipped_count,
        "skipped_records": skipped_records[:5],  # Include up to 5 examples
        "skipped_reason_summary": (
            "Records were skipped because they already exist in the database. "
            "Use 'update' strategy to update them."
        ),
        "errors": errors[:5],
        "error_count": error_count,
        "task_id": self.request.id,
    }
@bp.route("/task_status/<task_id>")
def task_status(task_id):
"""Get status of background task."""
task = celery.AsyncResult(task_id)
if task.state == "PENDING":
response = {"state": task.state, "progress": 0}
elif task.state == "PROGRESS":
response = {
"state": task.state,
"progress": task.info.get("progress", 0)
}
elif task.state == "SUCCESS":
response = {
"state": task.state,
"result": task.result
}
else: # FAILURE, REVOKED, etc.
response = {
"state": task.state,
"error": str(task.info) if task.info else "Unknown error"
}
return jsonify(response)
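# Illustrative responses from /task_status/<task_id>:
#   while running:  {"state": "PROGRESS", "progress": 45}
#   when finished:  {"state": "SUCCESS", "result": {"added": 120, "updated": 3, ...}}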
@bp.route("/download_error_log/<task_id>")
def download_error_log(task_id):
# Find the most recent error log for this task
error_log = ActivityLog.query.filter(
ActivityLog.action == "import_errors",
ActivityLog.extra_data.like(f'%"{task_id}"%') # Search in JSON
).order_by(ActivityLog.timestamp.desc()).first()
if not error_log:
flash("No error data available.")
return redirect(url_for("upload.upload"))
# Get the CSV data from extra_data
extra_data = error_log.get_extra_data()
error_csv = extra_data.get("error_csv")
if not error_csv:
flash("Error data format is invalid.")
return redirect(url_for("upload.upload"))
buffer = StringIO(error_csv)
return send_file(
buffer,
mimetype="text/csv",
as_attachment=True,
download_name=f"upload_errors_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}.csv"
)