
Academic Metrics

Active Development

Overview

Academic Metrics is an AI-powered toolkit for analyzing and classifying academic research publications using LLMs and automated data collection. It is open source and fully documented, sees over 10K downloads/month on PyPI, and can be installed via pip install academic-metrics.

The system provides comprehensive functionality for collecting, classifying, and analyzing academic publications. It features Crossref API integration for institutional affiliation-based data collection, smart web scraping for enhanced data completeness, and automated Digital Object Identifier (DOI) processing. Using LLM-powered analysis, it can classify research into NSF PhD focus areas, extract themes and methodologies from abstracts, and generate detailed analytics at article, author, and category levels.
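
A minimal programmatic sketch of running the pipeline. The constructor arguments mirror the CLI code shown later on this page, but the import path is an assumption; check the package documentation for the canonical entry point.

import os

# Hypothetical import path; see the package docs for the canonical one.
from academic_metrics.runners import PipelineRunner

runner = PipelineRunner(
    ai_api_key=os.environ["OPENAI_API_KEY"],
    crossref_affiliation="Salisbury University",
    data_from_year=2024,
    data_to_year=2024,
    mongodb_uri=os.environ["MONGODB_URI"],
)
runner.run_pipeline()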

Key Features

  • Automated data collection from multiple sources with smart fusion
  • AI-powered classification using fine-tuned LLM prompts
  • Comprehensive analytics for citations, author statistics, and research trends
  • Flexible data storage options including MongoDB, JSON, and Excel exports
  • Robust processing pipeline with async handling and rate limiting
  • Developer-friendly tools with modular design and extensive documentation

Technical Architecture

The system is built with scalability and maintainability in mind, utilizing sophisticated software design patterns and modern AI technologies:

Pipeline Orchestration

The core pipeline implements a robust orchestration system that:

  • Manages end-to-end processing of academic publication data
  • Handles data collection, classification, and storage in MongoDB
  • Implements intelligent retry mechanisms with configurable thresholds, including rate limiting and logit-bias token banning
  • Features comprehensive logging and error handling
  • Includes data validation and deduplication via a custom MinHash algorithm (a minimal sketch follows this list)
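
A minimal MinHash sketch for near-duplicate detection over word shingles. This is illustrative only: the package's custom MinHash implementation may differ in shingling, hashing scheme, and threshold.

import hashlib
from typing import List, Set

def shingles(text: str, k: int = 3) -> Set[str]:
    # Break the text into overlapping k-word shingles.
    words = text.lower().split()
    return {" ".join(words[i : i + k]) for i in range(max(1, len(words) - k + 1))}

def minhash_signature(text: str, num_hashes: int = 64) -> List[int]:
    # Salt each hash with its index to simulate independent hash functions,
    # keeping the minimum hash per function as one signature component.
    shingle_set = shingles(text)
    return [
        min(int(hashlib.md5(f"{seed}:{s}".encode()).hexdigest(), 16) for s in shingle_set)
        for seed in range(num_hashes)
    ]

def estimated_jaccard(sig_a: List[int], sig_b: List[int]) -> float:
    # The fraction of matching signature slots estimates Jaccard similarity.
    return sum(a == b for a, b in zip(sig_a, sig_b)) / len(sig_a)

sig_a = minhash_signature("We propose a transformer model for abstract classification.")
sig_b = minhash_signature("We propose a novel transformer model for abstract classification.")
is_duplicate = estimated_jaccard(sig_a, sig_b) > 0.9  # threshold is illustrative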

Classification System

The classification system employs a three-stage process (a compressed sketch follows the list):

  1. Pre-classification:

    • Method extraction and analysis
    • Sentence structure analysis
    • Abstract summarization
  2. Classification:

    • Recursive taxonomy traversal
    • Multi-level category classification
    • Validation with retry mechanisms
  3. Theme Recognition:

    • Key theme identification
    • Concept extraction
    • Higher temperature creative analysis
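
A compressed sketch of the three-stage flow, assuming a generic run_llm(prompt_name, variables) helper; the real pipeline chains these stages through ChainManager instances with validated Pydantic outputs (see the classify() method further down this page).

from typing import Any, Callable, Dict, List

def classify_publication(
    abstract: str,
    top_categories: List[str],
    run_llm: Callable[[str, Dict[str, Any]], Dict[str, Any]],
) -> Dict[str, Any]:
    # Stage 1: pre-classification enriches the abstract with extracted
    # methods, sentence analysis, and a structured summary.
    enriched = run_llm("pre_classification", {"abstract": abstract})

    # Stage 2: classify against the current taxonomy level; the real system
    # recurses into each chosen category's subcategories.
    chosen = run_llm("classification", {**enriched, "categories": top_categories})

    # Stage 3: theme recognition over the classified abstract, run at a
    # higher temperature for more creative analysis.
    themes = run_llm("theme_recognition", {**enriched, **chosen})

    return {"categories": chosen, "themes": themes}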

Strategy Pattern Implementation

The system uses several advanced design patterns:

  • Strategy pattern for flexible attribute extraction
  • Factory pattern with registry for strategy management
  • Decorator pattern for extensible behavior
  • Registry pattern for strategy registration

Command Line Interface

Academic Metrics provides a powerful CLI (an example invocation follows this list) that supports:

  • Flexible date range processing with year and month granularity
  • Multiple LLM model selection for different processing stages
  • Excel report generation
  • MongoDB integration with configurable connection options
  • Test mode for development and validation
  • Environment variable management for secure credential handling
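
A hypothetical invocation; the entry-point name here is illustrative, but the flags match the argument parser in the CLI code at the bottom of this page:

python -m academic_metrics \
    --from-year 2019 --to-year 2024 \
    --from-month 1 --to-month 12 \
    --crossref-affiliation "Salisbury University" \
    --db-name Site_Data \
    --as-excel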

To aid the LLM development workflow, I built an early version of ChainComposer, which has since been published as its own package.

This tool is useful for analyzing research trends, tracking departmental output, or conducting institutional research assessments.

Code Snippets

Core Pipeline Orchestration

View full source code →
class PipelineRunner:
    """Orchestrates the academic metrics data processing pipeline.
 
    This class manages the end-to-end process of collecting, processing, and storing
    academic publication data. It handles data collection from Crossref, classification
    of publications, generation of statistics, and storage in MongoDB.
 
    Attributes:
        SAVE_OFFLINE_KWARGS (SaveOfflineKwargs): Default configuration for offline processing.
        logger (logging.Logger): Pipeline-wide logger instance.
        ai_api_key (str): API key for AI services.
        db_name (str): Name of the MongoDB database.
        mongodb_uri (str): URI for MongoDB connection.
        db (DatabaseWrapper): Database interface instance.
        scraper (Scraper): Web scraping utility instance.
        crossref_wrapper (CrossrefWrapper): Crossref API interface instance.
        taxonomy (Taxonomy): Publication taxonomy utility.
        warning_manager (WarningManager): Warning logging utility.
        strategy_factory (StrategyFactory): Strategy pattern factory.
        utilities (Utilities): General utility functions.
        classification_orchestrator (ClassificationOrchestrator): Publication classifier.
        dataclass_factory (DataClassFactory): Data class creation utility.
        category_processor (CategoryProcessor): Category statistics processor.
        faculty_postprocessor (FacultyPostprocessor): Faculty data processor.
        department_postprocessor (DepartmentPostprocessor): Department data processor.
        debug (bool): Debug mode flag.
 
    Methods:
        run_pipeline: Executes the main data processing pipeline.
        _create_taxonomy: Creates a new Taxonomy instance.
        _create_classifier_factory: Creates a new ClassifierFactory instance.
        _create_warning_manager: Creates a new WarningManager instance.
        _create_strategy_factory: Creates a new StrategyFactory instance.
        _create_utilities_instance: Creates a new Utilities instance.
        _create_classification_orchestrator: Creates a new ClassificationOrchestrator.
        _create_orchestrator: Creates a new CategoryDataOrchestrator.
        _get_acf_func: Returns the abstract classifier factory function.
        _validate_api_key: Validates the provided API key.
        _make_files: Creates split files from input files.
        _load_files: Loads and returns data from split files.
        _create_dataclass_factory: Creates a new DataClassFactory instance.
        _create_crossref_wrapper: Creates a new CrossrefWrapper instance.
        _create_category_processor: Creates a new CategoryProcessor instance.
        _create_faculty_postprocessor: Creates a new FacultyPostprocessor instance.
        _create_scraper: Creates a new Scraper instance.
        _create_db: Creates a new DatabaseWrapper instance.
        _encode_affiliation: URL encodes an affiliation string.
    """
 
    def run_pipeline(
        self,
        save_offline_kwargs: SaveOfflineKwargs = SAVE_OFFLINE_KWARGS,
        test_filtering: bool | None = False,
        save_to_db: bool | None = True,
    ):
        """Execute the main data processing pipeline.
 
        This method orchestrates the entire pipeline process:
        1. Retrieves existing DOIs from database
        2. Collects new publication data from Crossref
        3. Filters out duplicate articles
        4. Runs AI classification on publications
        5. Processes and generates category statistics
        6. Saves processed data to MongoDB
 
        Args:
            save_offline_kwargs (SaveOfflineKwargs, optional): Configuration for offline processing.
                Defaults to SAVE_OFFLINE_KWARGS.
                - offline: Whether to run in offline mode
                - run_crossref_before_file_load: Run Crossref before loading files
                - make_files: Generate new split files
                - extend: Extend existing data
            test_filtering (bool | None, optional): If True, print the filtering results
                and return early without running classification. Defaults to False.
            save_to_db (bool | None, optional): Whether to read existing DOIs from and save
                processed data to MongoDB. Defaults to True.
 
        Raises:
            Exception: If there are errors in data processing or database operations.
        """
        self.logger.info("Running pipeline...")
 
        # Get the existing DOIs from the database
        # so that we don't process duplicates
        self.logger.info("Getting existing DOIs from database...")
        existing_dois: List[str] = []
        if save_to_db:
            existing_dois: List[str] = self.db.get_dois()
        self.logger.info(f"Found {len(existing_dois)} existing DOIs in database")
 
        # Get data from crossref for the school and date range
        self.logger.info("Getting data from Crossref...")
        data: List[Dict[str, Any]] = []
        if save_offline_kwargs["offline"]:
            if save_offline_kwargs["run_crossref_before_file_load"]:
                data: List[Dict[str, Any]] = self.crossref_wrapper.run_all_process()
            if save_offline_kwargs["make_files"]:
                self._make_files()
            data: List[Dict[str, Any]] = self._load_files()
        else:
            # Fetch raw data from Crossref api for the year range
            # and get out the result list containing the raw data.
            data: List[Dict[str, Any]] = (
                self.crossref_wrapper.run_afetch_yrange().get_result_list()
            )
        self.logger.info(
            "Filtering out articles whose DOIs are already in the db or those that are not found..."
        )
        # Then filter out articles whose DOIs are already
        # in the db or those that are not found.
        already_existing_count: int = 0
        filtered_data: List[Dict[str, Any]] = []
        for article in data:
            # Get the DOI out of the article item
            attribute_results: List[str] = self.utilities.get_attributes(
                article, [AttributeTypes.CROSSREF_DOI]
            )
            # Unpack the DOI from the dict returned by get_attributes
            doi = (
                attribute_results[AttributeTypes.CROSSREF_DOI][1]
                if attribute_results[AttributeTypes.CROSSREF_DOI][0]
                else None
            )
            # Only keep articles that have a DOI and aren't already in the database
            if doi is not None:
                if doi not in existing_dois:
                    filtered_data.append(article)
                else:
                    already_existing_count += 1
            else:
                self.logger.warning(f"Article with no DOI: {article}")
                continue
 
        self.logger.info(f"Filtered out {already_existing_count}/{len(data)} articles")
        self.logger.info(f"Articles to process: {len(filtered_data)}")
        self.logger.info("Initial filtering complete")
 
        if len(filtered_data) == 0:
            self.logger.info("No articles to process")
            return
 
        # Then set data to filtered data so we don't
        # keep the raw data floating in memory.
        data: List[Dict[str, Any]] = filtered_data
 
        # Now run final processing to have `Scraper` fetch missing abstracts.
        # Reset the result list in `CrossrefWrapper` so it doesn't
        # run on the original raw data, and instead runs on the filtered data.
        self.logger.info("Resetting CrossrefWrapper result list...")
        self.crossref_wrapper.result = data
        self.logger.info("CrossrefWrapper result list reset successfully")
 
        # Run the final processing to fetch missing abstracts
        # and get out the final data.
        # Again, we don't want to keep the raw data floating in memory,
        # so we reassign `data` to the result list returned by `.get_result_list()`.
        self.logger.info("Running final processing to fetch missing abstracts...")
        data = self.crossref_wrapper.final_data_process().get_result_list()
        self.logger.info("Final processing complete")
 
        if len(data) == 0:
            self.logger.info(
                "None of the remaining articles have abstracts or none could be retrieved"
            )
            return
 
        if test_filtering:
            print(f"\n\nFiltered out {already_existing_count} articles\n\n")
            print(
                f"\n\nFILTERED DATA VAR CONTENTS:\n{json.dumps(filtered_data, indent=4)}\n\n"
            )
            print(f"\n\nDATA VAR CONTENTS:\n{data}\n\n")
            return
 
        self.logger.info(f"\n\nDATA: {data}\n\n")
 
        if self.debug:
            print(f"There are {len(data)} articles to process.")
            response: str = input("Would you like to slice the data? (y/n)")
            if response == "y":
                res: str = input("How many articles would you like to process?")
                data = data[: int(res)]
                self.logger.info(f"\n\nSLICED DATA:\n{data}\n\n")
 
        # Run classification on all data
        # comment out to run without AI for testing
        self.logger.info("Running classification...")
        data = self.classification_orchestrator.run_classification(
            data,
            pre_classification_model=self.pre_classification_model,
            classification_model=self.classification_model,
            theme_model=self.theme_model,
        )
        self.logger.info("Classification complete")
 
        with open("classified_data.json", "w") as file:
            json.dump(data, file, indent=4)
 
        # Process classified data and generate category statistics
        self.logger.info(
            "Processing classified data and generating category statistics..."
        )
        category_orchestrator: CategoryDataOrchestrator = self._create_orchestrator(
            data=data,
            extend=save_offline_kwargs["extend"],
        )
        category_orchestrator.run_orchestrator()
        self.logger.info("Category statistics processing complete")
 
        # Get all the processed data from CategoryDataOrchestrator
        self.logger.info("Getting final data...")
 
        self.logger.info("Getting final category data...")
        category_data: List[Dict[str, Any]] = (
            category_orchestrator.get_final_category_data()
        )
        self.logger.info("Final category data retrieved successfully")
 
        self.logger.info("Getting final faculty data...")
        # faculty_data = self.category_orchestrator.get_final_faculty_data()
        article_data: List[Dict[str, Any]] = (
            category_orchestrator.get_final_article_data()
        )
        self.logger.info("Final article data retrieved successfully")
 
        self.logger.info("Getting final global faculty data...")
        global_faculty_data: List[Dict[str, Any]] = (
            category_orchestrator.get_final_global_faculty_data()
        )
        self.logger.info("Final global faculty data retrieved successfully")
 
        if save_to_db:
            self.logger.info("Attempting to save data to database...")
            try:
                self.db.insert_categories(category_data)
                self.logger.info(
                    f"""Successfully inserted {len(category_data)} categories into database"""
                )
            except Exception as e:
                self.logger.error(f"Error saving to database: {e}")
 
            try:
                self.db.insert_articles(article_data)
                self.logger.info(
                    f"""Successfully inserted {len(article_data)} articles into database"""
                )
            except Exception as e:
                self.logger.error(f"Error saving to database: {e}")
 
            try:
                self.db.insert_faculty(global_faculty_data)
                self.logger.info(
                    f"""Successfully inserted {len(global_faculty_data)} faculty into database"""
                )
            except Exception as e:
                self.logger.error(f"Error saving to database: {e}")

Sophisticated Strategy Retrieval

This retrieval system uses several patterns:

  • Strategy pattern for input-dependent behavior
  • Factory pattern for strategy retrieval
  • Decorator pattern for extensibility
  • Registry pattern for strategy registration

Interface for Attribute Extraction

View full source code →
class Utilities:
    """A class containing various utility methods for processing and analyzing academic data.
 
    Attributes:
        strategy_factory (StrategyFactory): An instance of the StrategyFactory class.
        warning_manager (WarningManager): An instance of the WarningManager class.
 
    Methods:
        get_attributes(self, data, attributes):
            Extracts specified attributes from the data and returns them in a dictionary.
        crossref_file_splitter(self, *, path_to_file, split_files_dir_path):
            Splits a crossref file into individual entries and creates a separate file for each entry in the specified output directory.
        make_files(self, *, path_to_file: str, split_files_dir_path: str):
            Splits a document into individual entries and creates a separate file for each entry in the specified output directory.
    """
 
    def get_attributes(
        self, 
        data: Dict[str, Any], 
        attributes: List[AttributeTypes]
    ) -> Dict[AttributeTypes, Tuple[bool, Any]]:
        """Extracts specified attributes from the article entry and returns them in a dictionary.
        It also warns about missing or invalid attributes.
 
        Parameters:
            data (Dict[str, Any]): The article entry data to extract attributes from.
            attributes (List[AttributeTypes]): The attribute types to extract, e.g., [AttributeTypes.TITLE, AttributeTypes.CROSSREF_DOI].
 
        Returns:
            dict: A dictionary where keys are attribute names and values are tuples.
                  Each tuple contains a boolean indicating success or failure of extraction,
                  and the extracted attribute value or None.
 
        Raises:
            ValueError: If no extraction strategy is registered for a requested attribute type.
        """
        attribute_results: Dict[AttributeTypes, Tuple[bool, Any]] = {}
        for attribute in attributes:
            extraction_strategy: AttributeExtractionStrategy = (
                self.strategy_factory.get_strategy(attribute, self.warning_manager)
            )
            attribute_results[attribute] = extraction_strategy.extract_attribute(data)
        return attribute_results
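
A short usage sketch, assuming article is a Crossref-shaped dict and utilities is a configured Utilities instance; the tuple unpacking mirrors the DOI filtering in run_pipeline above.

results = utilities.get_attributes(
    article, [AttributeTypes.CROSSREF_DOI, AttributeTypes.TITLE]
)
found, doi = results[AttributeTypes.CROSSREF_DOI]
if found:
    print(f"DOI: {doi}")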

Factory with Registry Pattern

View source code →
from __future__ import annotations
 
import logging
import os
from typing import TYPE_CHECKING
 
from academic_metrics.configs import (
    configure_logging,
    DEBUG,
)
from academic_metrics.enums import AttributeTypes
from academic_metrics.utils import WarningManager
 
if TYPE_CHECKING:
    from academic_metrics.strategies import AttributeExtractionStrategy
 
 
class StrategyFactory:
    """
    A factory class for managing and retrieving attribute extraction strategies.
 
    This class provides a mechanism to register and retrieve different strategies for extracting attributes from data entries. It uses a dictionary to map attribute types to their corresponding strategy classes, allowing for flexible and dynamic strategy management.
 
    Attributes:
        _strategies (dict): A class-level dictionary that maps attribute types to their corresponding strategy classes.
 
    Methods:
        register_strategy(*attribute_types): Registers a strategy class for one or more attribute types.
        get_strategy(attribute_type, warning_manager): Retrieves the strategy class for a given attribute type and initializes it with a warning manager.
 
    Usage:
        - Register a strategy with the factory:
          StrategyFactory.register_strategy(AttributeTypes.TITLE)(TitleExtractionStrategy)
        - Add the corresponding member to AttributeTypes in enums.py
        - Retrieve a strategy from the factory:
          get_attributes() in utilities.py uses this factory to look up the strategy for a given attribute type
    """
 
    _strategies = {}
 
    def __init__(self):
        """Initializes the StrategyFactory."""
        self.logger = configure_logging(
            module_name=__name__,
            log_file_name="strategy_factory",
            log_level=DEBUG,
        )
 
    @classmethod
    def register_strategy(
        cls,
        *attribute_types: AttributeTypes,
    ):
        """
        Registers a strategy class for one or more attribute types.
 
        This method is used to associate a strategy class with specific attribute types. The strategy class
        is stored in the _strategies dictionary, allowing it to be retrieved later based on the attribute type.
 
        Args:
            *attribute_types (AttributeTypes): One or more attribute types to associate with the strategy class.
 
        Returns:
            function: A decorator function that registers the strategy class.
        """
 
        def decorator(strategy_class):
            for attribute_type in attribute_types:
                cls._strategies[attribute_type] = strategy_class
            return strategy_class
 
        return decorator
 
    @classmethod
    def get_strategy(
        cls, attribute_type: AttributeTypes, warning_manager: WarningManager
    ):
        """
        Retrieves the strategy class for a given attribute type and initializes it with a warning manager.
 
        This method looks up the strategy class associated with the specified attribute type in the _strategies
        dictionary. If a strategy class is found, it is instantiated with the provided warning manager and returned.
 
        Args:
            attribute_type (AttributeTypes):
            - The attribute type for which to retrieve the strategy class.
 
            warning_manager (WarningManager):
            - An instance of WarningManager to be passed to the strategy class.
 
        Returns:
            strategy (AttributeExtractionStrategy):
            - An instance of the strategy class associated with the specified attribute type.
 
        Raises:
            ValueError:
            - If no strategy is found for the specified attribute type.
        """
        strategy_class: AttributeExtractionStrategy = cls._strategies.get(
            attribute_type
        )
        if not strategy_class:
            raise ValueError(f"No strategy found for attribute type: {attribute_type}")
        return strategy_class(warning_manager)

Strategy for Digital Object Identifier (DOI) Extraction

View full source code →
@StrategyFactory.register_strategy(AttributeTypes.CROSSREF_DOI)
class CrossrefDOIExtractionStrategy(AttributeExtractionStrategy):
    """A strategy for extracting the DOI from a Crossref entry.
 
    This class implements the AttributeExtractionStrategy for DOI extraction specifically from Crossref JSON data. It focuses on retrieving the DOI associated with a publication.
    """
 
    def __init__(self, warning_manager: WarningManager):
        """Initializes the CrossrefDOIExtractionStrategy.
 
        This constructor sets up the strategy with a warning manager.
 
        Args:
            warning_manager (WarningManager): An instance of WarningManager for handling extraction warnings.
        """
        super().__init__(warning_manager=warning_manager)
        self.logger = configure_logging(
            module_name=__name__,
            log_file_name="crossref_doi_extraction_strategy",
            log_level=DEBUG,
        )
 
    def extract_attribute(
        self, 
        entry_text: dict
    ) -> tuple[bool, str]:
        """Extracts the DOI from the Crossref entry.
 
        Args:
            entry_text (dict): The Crossref JSON data containing the publication information.
 
        Returns:
            tuple[bool, str]: A tuple containing:
            
                - A boolean indicating success (True) if DOI is found, False otherwise.
                
                - A string representing the DOI, or None if not found.
        """
        doi = entry_text.get("DOI")
        if doi:
            return (True, doi)
        self.log_extraction_warning(
            attribute_class_name=self.__class__.__name__,
            warning_message="Attribute: 'Crossref_DOI' was not found in the entry",
            entry_id=entry_text,
        )
        return (False, None)
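
Because the decorator above registered the strategy, get_strategy can resolve it from the enum alone. A small sketch, assuming WarningManager takes no required constructor arguments and using an illustrative DOI:

strategy = StrategyFactory.get_strategy(AttributeTypes.CROSSREF_DOI, WarningManager())
found, doi = strategy.extract_attribute({"DOI": "10.1000/example"})  # -> (True, "10.1000/example")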

Classifier System

The snippets below show the two methods that orchestrate most of the classification process. To see the full class, follow the links provided.

Classify

This method orchestrates the complete pipeline, running three stages in order:

  • Pre-classification
  • Classification
  • Theme Extraction

View full source code →
    def classify(self) -> Self:
        """Orchestrates the complete classification pipeline for all abstracts.
 
        This method manages the end-to-end processing of all abstracts present in the
        doi_to_abstract_dict dictionary through three stages: pre-classification,
        classification, and theme recognition.
 
        Args:
            None
 
        Returns:
            Self: Returns self for method chaining.
                Type: :class:`academic_metrics.AI.AbstractClassifier.AbstractClassifier`
 
        Notes:
            Pipeline Stages:
            - Pre-classification:
                - Method extraction: Identifies research methods and techniques
                - Sentence analysis: Analyzes abstract structure and components
                - Summarization: Generates structured abstract summary
 
            - Classification:
                - Uses enriched data from pre-classification
                - Recursively classifies through taxonomy levels
                - Validates and retries invalid classifications
 
            - Theme Recognition:
                - Processes classified abstracts
                - Identifies key themes and concepts
                - Uses higher temperature for creative analysis
 
            State Updates:
            - classification_results: Nested defaultdict structure:
            {
                "doi1": {
                    "top_category1": {
                        "mid_category1": ["low1", "low2"],
                        "mid_category2": ["low3", "low4"]
                    },
                    "themes": ["theme1", "theme2"]
                }
            }
            - raw_classification_outputs: List of raw outputs from classification
            - raw_theme_outputs: Dictionary mapping DOIs to theme analysis results
 
            Processing Details:
            - Processes abstracts sequentially
            - Requires initialized chain managers
            - Updates multiple result stores
            - Maintains logging throughout process
            - Chains data between processing stages
        """
        # Track total abstracts for progress logging
        n_abstracts: int = len(self.doi_to_abstract_dict.keys())
 
        # Process each abstract through the complete pipeline
        for i, (doi, abstract) in enumerate(self.doi_to_abstract_dict.items()):
            # Log progress and abstract details for monitoring
            self.logger.info(f"Processing abstract {i+1} of {n_abstracts}")
            self.logger.info(f"Current DOI: {doi}")
            self.logger.info(
                f"Current abstract:\n{abstract[:10]}...{abstract[-10:]}\n\n"
            )
 
            #######################
            # 1. Pre-classification
            #######################
 
            # Initialize initial prompt variables used in the system and human prompts for the pre-classification chain layers
            initial_prompt_variables: Dict[str, Any] = {
                "abstract": abstract,
                "METHOD_JSON_FORMAT": METHOD_JSON_FORMAT,
                "METHOD_EXTRACTION_CORRECT_EXAMPLE_JSON": METHOD_EXTRACTION_CORRECT_EXAMPLE_JSON,
                "METHOD_EXTRACTION_INCORRECT_EXAMPLE_JSON": METHOD_EXTRACTION_INCORRECT_EXAMPLE_JSON,
                "SENTENCE_ANALYSIS_JSON_EXAMPLE": SENTENCE_ANALYSIS_JSON_EXAMPLE,
                "SUMMARY_JSON_STRUCTURE": SUMMARY_JSON_STRUCTURE,
                "extra_context": self.extra_context,
            }
 
            # Execute pre-classification chain (method extraction -> sentence analysis -> summarization)
            self.pre_classification_chain_manager.run(
                prompt_variables_dict=initial_prompt_variables
            )
 
            # Call this ChainManager instance's (pre_classification_chain_manager) get_chain_variables()
            # method to get the current chain variables, which include all initial_prompt_variables
            # plus the output of each chain layer. Each new item's key matches that
            # layer's output_passthrough_key_name value.
            prompt_variables: Dict[str, Any] = (
                self.pre_classification_chain_manager.get_chain_variables()
            )
            method_extraction_output: Dict[str, Any] = prompt_variables.get(
                "method_json_output", {}
            )
            self.logger.debug(f"Method extraction output: {method_extraction_output}")
            sentence_analysis_output: Dict[str, Any] = prompt_variables.get(
                "sentence_analysis_output", {}
            )
            self.logger.debug(f"Sentence analysis output: {sentence_analysis_output}")
            summary_output: Dict[str, Any] = prompt_variables.get(
                "abstract_summary_output", {}
            )
            self.logger.debug(f"Summary output: {summary_output}")
 
            ######################
            # 2. Classification
            ######################
 
            # Update the prompt variables by adding classification-specific variables.
            # Start with top-level categories - recursive classification will handle lower levels.
            prompt_variables.update(
                {
                    "categories": self.taxonomy.get_top_categories(),
                    "CLASSIFICATION_JSON_FORMAT": CLASSIFICATION_JSON_FORMAT,
                    "TAXONOMY_EXAMPLE": TAXONOMY_EXAMPLE,
                }
            )
 
            # Execute recursive classification through taxonomy levels
            self.classify_abstract(
                abstract=abstract,
                doi=doi,
                prompt_variables=prompt_variables,
            )
 
            ######################
            # 3. Theme Recognition
            ######################
 
            # Get updated variables after classification.
            # Details:
            #   Once classify_abstract returns it will have classified the abstract into top categories
            #   then recursively classified mid and low level categories within each classified top category
            #   so now this abstract has been classified into all relevant categories and subcategories within the taxonomy.
            #   Given this, we can now process the themes for this abstract.
            #   Like before fetch this ChainManager (classification_chain_manager this time) instance's chain variables and update them:
            prompt_variables: Dict[str, Any] = (
                self.classification_chain_manager.get_chain_variables()
            )
 
            # Add in the theme recognition specific variables
            # The only one not already present in prompt_variables which is present as a placeholder
            # in the theme_recognition_system_prompt is THEME_RECOGNITION_JSON_FORMAT, so we add that in.
            # Then update the categories key with the categories from the classification results.
            prompt_variables.update(
                {
                    "THEME_RECOGNITION_JSON_FORMAT": THEME_RECOGNITION_JSON_FORMAT,
                    "categories": self._get_classification_results_by_doi(doi),
                }
            )
 
            # Execute theme recognition on the current abstract.
            # Details:
            #   Here we store the result directly, since we want this raw output in the
            #   raw theme outputs dictionary. We don't need to pull out prompt_variables again,
            #   as we can extract the themes straight from theme_results. Remember, before we
            #   had to pull out prompt_variables so all variables would propagate through to
            #   future chains that weren't on the same ChainManager instance.
            theme_results: Dict[str, Any] = self.theme_chain_manager.run(
                prompt_variables_dict=prompt_variables
            ).get("theme_output", {})
 
            theme_results = ThemeAnalysis(**theme_results)
 
            # Store raw theme results
            self.raw_theme_outputs[doi] = theme_results.model_dump()
 
            # Update final classification results with themes
            # Details:
            #   Guarded by an if statement so a missing DOI doesn't kill a live service.
            #   This shouldn't happen; if it does, a more explicit and detailed error should
            #   have been thrown much earlier. With little context available here, raising
            #   an error isn't particularly helpful, so we log the error instead.
            if doi in self.classification_results:
                self.classification_results[doi]["themes"] = theme_results.themes
            else:
                # Log error if DOI missing from results (as mentioned before, this shouldn't happen in normal operation, but just in case)
                self.logger.error(
                    f"DOI not found in classification results: {doi}, class results: {self.classification_results}"
                )
 
        return self
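
A hypothetical chained usage; constructor arguments beyond the abstract mapping are elided, since the full class requires the taxonomy and chain-manager dependencies shown in the linked source.

classifier = AbstractClassifier(
    doi_to_abstract_dict={"10.1000/example": "We propose ..."},
    # ... taxonomy, API key, model settings, etc. (see full source)
)
results = classifier.classify().classification_results
themes = classifier.raw_theme_outputs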

Classifying an Abstract through all levels of the Taxonomy

View full source code →
    def classify_abstract(
        self,
        abstract: str,
        doi: str,
        prompt_variables: Dict[str, Any],
        level: str | None = "top",
        parent_category: str | None = None,
        current_dict: Dict[str, Any] | None = None,
    ) -> None:
        """Recursively classifies an abstract through the taxonomy hierarchy.
 
        This method implements a depth-first traversal of the taxonomy tree, classifying
        the abstract at each level and recursively processing subcategories. It maintains
        state using a nested defaultdict structure that mirrors the taxonomy hierarchy.
 
        Args:
            abstract (str): The text of the abstract to classify.
                Type: str
            doi (str): The DOI identifier for the abstract.
                Type: str
            prompt_variables (Dict[str, Any]): Variables required for classification.
                Type: Dict[str, Any]
                Pre-classification requirements:
                - method_json_output: Method extraction results
                - sentence_analysis_output: Sentence analysis results
                - abstract_summary_output: Abstract summary
                Classification requirements:
                - abstract: The abstract text
                - categories: Available categories for current level
                - CLASSIFICATION_JSON_FORMAT: Format specification
                - TAXONOMY_EXAMPLE: Example classifications
            level (str | None): Current taxonomy level ("top", "mid", or "low").
                Type: str | None
                Defaults to "top".
            parent_category (str | None): The parent category from previous level.
                Type: str | None
                Defaults to None.
            current_dict (Dict[str, Any] | None): Current position in classification results.
                Type: Dict[str, Any] | None
                Defaults to None.
 
        Returns:
            None
 
        Raises:
            ValueError: If classification fails validation after max retries.
            Exception: If any other error occurs during classification.
 
        Notes:
            - Pre-classification must run method extraction, sentence analysis, and summarization
            - Top level classification processes into top categories then recursively into subcategories
            - Mid level classification processes into mid categories under parent then into low categories
            - Low level classification appends results to parent mid category's list
            - Validates all classified categories against taxonomy
            - Retries classification up to max_classification_retries times
            - On final retry, bans invalid categories to force valid results
        """
        self.logger.info(f"Classifying abstract at {level} level")
 
        # Start at the top level of our defaultdict if not passed in
        if current_dict is None:
            current_dict = self.classification_results[doi]
 
        try:
            classification_output: Dict[str, Any] = (
                self.classification_chain_manager.run(
                    prompt_variables_dict=prompt_variables
                ).get("classification_output", {})
            )
            self.logger.debug(f"Raw classification output: {classification_output}")
 
            # Use **kwargs to unpack the dictionary into keyword arguments for the Pydantic model.
            # '**classification_output' will fill in the values for the keys in the Pydantic model
            # even if there are more keys present in the output which are not part of the pydantic model.
            # This is critical as the outputs here will have all prompt variables from the ones passed to run()
            # as well as the output of the chain layer.
            classification_output: ClassificationOutput = ClassificationOutput(
                **classification_output
            )
            self.raw_classification_outputs.append(classification_output.model_dump())
 
            # Extract out just the classified categories from the classification output.
            # When the level is top and mid these extracted categories will be used to recursively classify child categories
            # When the level is low these extracted categories will be used to update the current mid category's list of low categories
            classified_categories: List[str] = self.extract_classified_categories(
                classification_output
            )
            self.logger.info(
                f"Classified categories at {level} level: {classified_categories}"
            )
 
            # Validate categories before proceeding
            retry_count: int = 0
            while not all(
                self.is_valid_category(category, level)
                for category in classified_categories
            ):
                # Find the invalid categories
                invalid_categories: List[str] = [
                    category
                    for category in classified_categories
                    if not self.is_valid_category(category, level)
                ]
                if retry_count >= self.max_classification_retries:
                    raise ValueError(
                        f"Failed to get valid category after {self.max_classification_retries} retries. Invalid categories at {level} level. "
                        f"Invalid categories: {invalid_categories}"
                    )
                self.logger.warning(
                    f"Invalid categories at {level} level, retry {retry_count + 1} "
                    f"Invalid categories: {invalid_categories}"
                )
 
                # Only set banned words on the final retry.
                # This is done as words may be split into multiple tokens
                # leading to pieces of words being banned rather than the entire word.
                # This could lead to conflict with actual valid categories, and lead
                # the LLM to not classify into categories that it would otherwise.
                # This is done as a last resort to try and elicit valid categories.
                if retry_count == self.max_classification_retries - 1:
                    self.logger.warning("Final retry - attempting with token banning")
                    self.banned_categories.extend(invalid_categories)
                    self.classification_chain_manager.set_words_to_ban(
                        self.banned_categories
                    )
 
                # Increment retry count
                retry_count += 1
 
                # Retry classification at this level
                classification_output = self.classification_chain_manager.run(
                    prompt_variables_dict=prompt_variables
                ).get("classification_output", {})
 
                # Update the classification output with the new output
                classification_output = ClassificationOutput(**classification_output)
 
                # Update the classified categories with the new output
                classified_categories = self.extract_classified_categories(
                    classification_output
                )
 
            self.logger.info(
                f"Classified categories at {level} level: {classified_categories}"
            )
 
            result: Dict[str, Any] = {}
 
            for category in classified_categories:
                if level == "top":
                    # Get the mid categories for the current top category
                    subcategories: List[str] = self.taxonomy.get_mid_categories(
                        category
                    )
 
                    # Set the next level to mid so the recursive call will classify the mid categories extracted above
                    next_level: str = "mid"
 
                    # Move to this category's dictionary in the defaultdict
                    next_dict: Dict[str, Any] = current_dict[category]
 
                elif level == "mid":
                    # Get the low categories for the current mid category
                    subcategories: List[str] = self.taxonomy.get_low_categories(
                        parent_category, category
                    )
 
                    # Set the next level to low so the recursive call will classify the low categories extracted above
                    next_level: str = "low"
 
                    # Move to this category's dictionary in the defaultdict
                    next_dict: Dict[str, Any] = current_dict[category]
 
                elif level == "low":
                    # Append the low category to the parent (mid) category's list
                    current_dict.append(category)
                    continue
 
                if subcategories:
                    # Update prompt variables with new subcategories
                    prompt_variables.update(
                        {
                            "categories": subcategories,
                        }
                    )
 
                    # Recursively classify the subcategories
                    result[category] = self.classify_abstract(
                        abstract=abstract,
                        doi=doi,
                        prompt_variables=prompt_variables,
                        level=next_level,
                        parent_category=category,
                        current_dict=next_dict,
                    )
 
        except Exception as e:
            self.logger.error(
                f"Error during classification at {level} level:\n"
                f"DOI: {doi}\n"
                f"Current category: {category if 'category' in locals() else 'N/A'}\n"
                f"Parent category: {parent_category}\n"
                f"Exception: {str(e)}\n"
                f"Traceback: {traceback.format_exc()}"
            )
            raise e

Command Line Interface (CLI)

View full source code →
def main(
    openai_api_key_env_var_name: str | None = "OPENAI_API_KEY",
    mongodb_uri_env_var_name: str | None = "MONGODB_URI",
):
    if openai_api_key_env_var_name is None:
        raise ValueError("openai_api_key_env_var_name cannot be None")
 
    if mongodb_uri_env_var_name is None:
        raise ValueError("mongodb_uri_env_var_name cannot be None")
 
    import argparse
    from dotenv import load_dotenv
 
    load_dotenv()
    ai_api_key = os.getenv(openai_api_key_env_var_name)
    mongodb_uri = os.getenv(mongodb_uri_env_var_name)
 
    if ai_api_key is None:
        raise ValueError(
            f"\n\nError: {openai_api_key_env_var_name} environment variable not found."
            "\n\nPlease set the environment variable and try again."
            "\nIf you are unsure how to set an environment variable, or you do not have an OpenAI API key,"
            "\nplease refer to the README.md file for more information:"
            "\nhttps://github.com/spencerpresley/COSC425-DATA"
        )
 
    if mongodb_uri is None:
        raise ValueError(
            f"\n\nError: {mongodb_uri_env_var_name} environment variable not found."
            "\n\nPlease set the environment variable and try again."
            "\nIf you are unsure how to set an environment variable, or you do not have a MongoDB URI,"
            "\nplease refer to the README.md file for more information:"
            "\nhttps://github.com/spencerpresley/COSC425-DATA"
        )
 
    # Create argument parser
    parser = argparse.ArgumentParser(description="Run the academic metrics pipeline")
 
    parser.add_argument(
        "--test-run",
        action="store_true",
        help="Run in test mode using local MongoDB",
    )
    parser.add_argument(
        "--pre-classification-model",
        type=str,
        default="gpt-4o-mini",
        choices=["gpt-4o-mini", "gpt-4o"],
        help="Valid pre-classification-model's are 'gpt-4o-mini' or 'gpt-4o'",
    )
    parser.add_argument(
        "--classification-model",
        type=str,
        default="gpt-4o-mini",
        choices=["gpt-4o-mini", "gpt-4o"],
        help="Valid classification-model's are 'gpt-4o-mini' or 'gpt-4o'",
    )
    parser.add_argument(
        "--theme-model",
        type=str,
        default="gpt-4o-mini",
        choices=["gpt-4o-mini", "gpt-4o"],
        help="Valid theme-model's are 'gpt-4o-mini' or 'gpt-4o'",
    )
 
    parser.add_argument(
        "--from-year",
        type=int,
        default=2024,
        required=True,
        help="Starting year for data collection (e.g., 2019)",
    )
 
    parser.add_argument(
        "--to-year",
        type=int,
        default=2024,
        required=True,
        help="Ending year for data collection (e.g., 2024)",
    )
 
    parser.add_argument(
        "--from-month",
        type=int,
        default=1,
        choices=range(1, 13),
        help="Starting month (1-12, default: 1)",
    )
 
    parser.add_argument(
        "--to-month",
        type=int,
        default=12,
        choices=range(1, 13),
        help="Ending month (1-12, default: 12)",
    )
 
    parser.add_argument(
        "--as-excel",
        action="store_true",
        help="Save data to excel files. This is an additional action, it doesn't remove other data output types.",
    )
 
    parser.add_argument(
        "--db-name",
        type=str,
        default="Site_Data",
        help="Name of the database to use",
    )
 
    parser.add_argument(
        "--crossref-affiliation",
        type=str,
        required=True,
        help="The affiliation to use for the Crossref API",
    )
 
    args = parser.parse_args()
 
    # Configure logging
    logger = configure_logging(__name__, "main", log_level=logging.DEBUG)
 
    pre_classification_model = args.pre_classification_model
    classification_model = args.classification_model
    theme_model = args.theme_model
 
    # Track processed year/month pairs up front so the final log below
    # works in both test and production modes.
    processed_dict = defaultdict(list)

    if args.test_run:
        # Load local mongodb url
        logger.info("Running in test mode using local MongoDB...")
        mongodb_uri = os.getenv("LOCAL_MONGODB_URL")
        pipeline = PipelineRunner(
            ai_api_key=ai_api_key,
            crossref_affiliation="Salisbury University",
            data_from_year=2024,
            data_to_year=2024,
            mongodb_uri=mongodb_uri,
        )
 
        # Execute test run
        pipeline.test_run()
    else:
        # Normal pipeline execution
        logger.info(f"Running in production mode...")
        mongodb_uri = os.getenv(mongodb_uri_env_var_name)
        years = [str(year) for year in range(args.from_year, args.to_year + 1)]
        months = [str(month) for month in range(args.from_month, args.to_month + 1)]
 
 
        for year in years:
            for month in months:
                pipeline_runner = PipelineRunner(
                    ai_api_key=ai_api_key,
                    crossref_affiliation=args.crossref_affiliation,
                    data_from_month=int(month),
                    data_to_month=int(month),
                    data_from_year=int(year),
                    data_to_year=int(year),
                    mongodb_uri=mongodb_uri,
                    db_name=args.db_name,
                    pre_classification_model=pre_classification_model,
                    classification_model=classification_model,
                    theme_model=theme_model,
                )
                pipeline_runner.run_pipeline()
                processed_dict[year].append(month)
                logger.info(f"Processed year: {year}, month: {month}")
 
    logger.info(f"Processed data: {json.dumps(processed_dict, indent=4)}")
 
    if args.as_excel:
        logger.info("Creating Excel files...")
        db = DatabaseWrapper(db_name=args.db_name, mongo_uri=mongodb_uri)
 
        get_excel_report(db)
 
        logger.info("Excel files created successfully")