Category: AI

  • A Guide on Estimating Long-Term Effects in A/B Tests

    A Guide on Estimating Long-Term Effects in A/B Tests

    Kseniia Baidina

    Addressing the complexity of identifying and measuring long-term effects in online experiments

    Photo by Isaac Smith on Unsplash

Imagine you’re an analyst at an online store. You and your team aim to understand how offering free delivery will affect the number of orders on the platform, so you decide to run an A/B test. The test group enjoys free delivery, while the control group pays the regular delivery fee. In the initial days of the experiment, you’ll observe more people completing orders after adding items to their carts. But the real impact is long-term — users in the test group are more likely to return for future shopping on your platform because they know you offer free delivery.

    In essence, what’s the key takeaway from this example? The impact of free delivery on orders tends to increase gradually. Testing it for only a short period might mean you miss the whole story, and this is a challenge we aim to address in this article.

    Understanding Why Long-Term and Short-Term Effects May Differ

    Overall, there could be multiple reasons why short-term effects of the experiment differ from long-term effects [1]:

    Heterogeneous Treatment Effect

    • The impact of the experiment may vary for frequent and occasional users of the product. In the short run, frequent users might disproportionately influence the experiment’s outcome, introducing bias to the average treatment effect.

    User Learning

    • Novelty Effect — picture this: you introduce a new gamification mechanic to your product. Initially, users are curious, but this effect tends to decrease over time.
    • Primacy Effect — think about when Facebook changed its ranking algorithm from chronological to recommendation-based. Initially, there might be a drop in time spent in the feed as users can’t find what they expect, leading to frustration. Over time, however, engagement is likely to recover as users get used to the new algorithm and discover interesting posts.

    In this article, our focus will be on addressing two questions:

    How to identify and test whether the long-term impact of the experiment differs from the short-term?

    How to estimate the long-term effect when running the experiment for a sufficiently long period isn’t possible?

    Methods for Identifying Trends in Long-Term Effects

    Visualization

    The initial step is to observe how the difference between the test and control groups changes over time. If you notice a pattern like this, you will have to dive into the details to grasp the long-term effect.

    Illustration from Sadeghi et al. (2021) [2]
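    As a quick aside, such a plot takes only a few lines to produce. Below is a minimal matplotlib sketch with synthetic data standing in for your experiment’s daily group means and the standard error of their difference; none of these numbers come from the paper.

    import numpy as np
    import matplotlib.pyplot as plt

    rng = np.random.default_rng(42)
    n_days = 60

    # Synthetic daily group means and standard error of their difference,
    # just to make the sketch runnable; replace with your experiment's data.
    daily_test = 1.00 + 0.002 * np.arange(n_days) + rng.normal(0, 0.01, n_days)
    daily_ctrl = 1.00 + rng.normal(0, 0.01, n_days)
    se_diff = np.full(n_days, 0.015)

    diff = daily_test - daily_ctrl  # daily treatment effect
    days = np.arange(n_days)

    plt.plot(days, diff, label="test minus control")
    plt.fill_between(days, diff - 1.96 * se_diff, diff + 1.96 * se_diff, alpha=0.3, label="95% CI")
    plt.axhline(0, color="grey", linestyle="--")
    plt.xlabel("Experiment day")
    plt.ylabel("Treatment effect")
    plt.legend()
    plt.show()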

    It might also be tempting to plot the experiment’s effect based not only on the experiment day but also on the number of days from the first exposure.

    Illustration from Sadeghi et al. (2021) [2]

    However, there are several pitfalls when you look at the number of days from the first exposure:

    • Engaged Users Bias: The right side of the chart might over-represent more engaged users. The observed pattern might not be due to user learning but to heterogeneous treatment effects: the impact on highly engaged users could differ from the effect on occasional users.
    • Selective Sampling Issue: We could decide to focus solely on highly engaged users and observe how their effect evolves over time. However, this subset may not accurately represent the entire user base.
    • Decreasing User Numbers: There may be only a few users who have a substantial number of days since the first exposure (the right part of the graph). This widens the confidence intervals, making it tricky to draw dependable conclusions.

    The visual method for identifying long-term effects in an experiment is quite straightforward, and it’s always a good starting point to observe the difference in effects over time. However, this approach lacks rigor; you might also consider formally testing the presence of long-term effects. We’ll explore that in the next part.

    Ladder Experiment Assignment [2]

    The concept behind this approach is as follows: before initiating the experiment, we categorize users into k cohorts and incrementally introduce them to the experiment. For instance, if we divide users into 4 cohorts, k_1 is the control group, k_2 receives the treatment from week 1, k_3 from week 2, and k_4 from week 3.

    Illustration from Sadeghi et al. (2021) [2]
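    Operationally, the staggered assignment itself is simple; what is hard is the coordination around it. A minimal sketch of deterministic cohort hashing follows; the cohort count and hashing scheme are illustrative assumptions, not the paper’s implementation.

    import hashlib

    def assign_cohort(user_id: str, k: int = 4) -> int:
        """Hash a user into one of k cohorts (1..k), deterministically.

        Cohort 1 stays in control for the whole experiment; cohort i (i >= 2)
        starts receiving the treatment in week i - 1.
        """
        digest = hashlib.md5(f"ladder-exp:{user_id}".encode()).hexdigest()
        return int(digest, 16) % k + 1

    print(assign_cohort("user_42"))  # e.g. 3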

    The user-learning rate can be estimated by comparing the treatment effects from various time periods.

    Illustration from Sadeghi et al. (2021) [2]

    For instance, if you aim to estimate user learning in week 4, you would compare values T4_5 and T4_2.

    The challenges with this approach are quite evident. Firstly, it introduces extra operational complexity to the experiment design. Secondly, a substantial number of users are needed to divide them into different cohorts and still attain reasonable statistical significance. Thirdly, one has to anticipate beforehand that long-term effects may differ and plan to run the experiment in this more complicated setting.

    Difference-in-Difference [2]

    This approach is a simplified version of the previous one. We split the experiment into two (or more generally, into k) time periods and compare the treatment effect in the first period with the treatment effect in the k-th period.

    Illustration from Sadeghi et al. (2021) [2]

    In this approach, a vital question is how to estimate the variance of the estimate to make conclusions about statistical significance. The authors suggest the following formula (for details, refer to the article):

    Illustration from Sadeghi et al. (2021) [2]

    σ² — the variance of each experimental unit within each time window

    ρ — the correlation of the metric for each experimental unit in two time windows
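    As a rough illustration (not the exact estimator from Sadeghi et al.), one simple way to operationalize this is to difference each user’s metric across the two windows, which implicitly handles the within-user correlation ρ, and then compare the per-user changes between treatment and control:

    import numpy as np
    from scipy import stats

    def did_test(treat_w1, treat_w2, ctrl_w1, ctrl_w2):
        """Per-user metric arrays for window 1 and window 2, aligned by user."""
        # Within-user change between the two windows; differencing per user
        # accounts for the correlation of the metric across windows.
        treat_delta = np.asarray(treat_w2) - np.asarray(treat_w1)
        ctrl_delta = np.asarray(ctrl_w2) - np.asarray(ctrl_w1)

        # Difference-in-differences: how much the treatment effect changed over time
        did = treat_delta.mean() - ctrl_delta.mean()
        t_stat, p_value = stats.ttest_ind(treat_delta, ctrl_delta, equal_var=False)
        return did, p_value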

    Random vs. Constant Treatment Assignment [3]

    This is another extension of the ladder experiment assignment. In this approach, the pool of users is divided into three groups: C — the control group, E — the group that receives treatment throughout the experiment, and E1 — the group in which users are assigned to treatment each day with probability p. As a result, each user in the E1 group receives treatment only on a few days, preventing user learning. Now, how do we estimate user learning? Let’s introduce E1_d — the subset of users from E1 exposed to treatment on day d. The user-learning rate is then determined by the difference between E and E1_d.

    User “Unlearning” [3]

    This approach enables us to assess both the existence of user learning and the duration of this learning. The concept is quite elegant: it posits that users learn at the same rate as they “unlearn.” The idea is as follows: turn off the experiment and observe how the test and control groups converge over time. As both groups will receive the same treatment post-experiment, any changes in their behavior will occur because of the different treatments during the experiment period.

    This approach helps us measure the period required for users to “forget” about the experiment, and we assume that this forgetting period will be equivalent to the time users take to learn during the feature roll-out.

    This method has two significant drawbacks: firstly, it requires a considerable amount of time to analyze user learning. Initially, you run an experiment for an extended period to allow users to “learn,” and then you must deactivate the experiment and wait for them to “unlearn.” This process can be time-consuming. Secondly, you need to deactivate the experimental feature, which businesses may be hesitant to do.

    Methods for Assessing the Long-Term Effects [4]

    You’ve successfully established the existence of user learning in your experiment, and it’s clear that the long-term results are likely to differ from what you observe in the short term. Now, the question is how to predict these long-term results without running the experiment for weeks or even months.

    One approach is to attempt to predict the long-run outcome Y using short-term data. The simplest method is to use lags of Y itself; such models are referred to as “auto-surrogate” models. Suppose you want to predict the experiment’s result after two months but currently have only two weeks of data. In this scenario, you can train a linear regression (or any other) model:

    Illustration from Zhang et al. (2023) [5]

    μ_i is the average daily outcome for user i over two months

    Y_i,t is the value of the metric for user i on day t (t ranges from 1 to 14 in our case)

    In that case, the long-term treatment effect is determined by the difference in predicted values of the metric for the test and control groups using surrogate models.

    Illustration from Zhang et al. (2023) [5]

    Where N_a represents the number of users in the experiment group, and N_0 represents the number of users in the control group.

    There appears to be an inconsistency here: we aim to predict μ (the long-term outcome of the experiment), but to train the model we need observed values of μ. So, how do we obtain the model? There are two approaches (a short sketch using the first one follows the list):

    • Using pre-experiment data: We can train a model using two months of pre-experiment data for the same users.
    • Similar experiments: We can select a “gold standard” experiment from the same product domain that ran for two months and use it to train the model.
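    Using the first of those two options (pre-experiment data), a minimal sketch with scikit-learn might look like the following. All arrays are synthetic stand-ins; in practice the feature matrix would hold each user’s first 14 daily values of Y.

    import numpy as np
    from sklearn.linear_model import LinearRegression

    rng = np.random.default_rng(0)

    # Pre-experiment period: 14 daily values of Y per user, plus the realized
    # two-month average daily outcome for the same users (synthetic stand-ins).
    X_pre = rng.poisson(2.0, size=(10_000, 14)).astype(float)
    y_pre = X_pre.mean(axis=1) + rng.normal(0, 0.1, size=10_000)

    surrogate = LinearRegression().fit(X_pre, y_pre)

    # First 14 days of the current experiment for test and control users.
    X_treat = rng.poisson(2.1, size=(5_000, 14)).astype(float)
    X_ctrl = rng.poisson(2.0, size=(5_000, 14)).astype(float)

    # Estimated long-term treatment effect: difference in predicted two-month outcomes.
    effect = surrogate.predict(X_treat).mean() - surrogate.predict(X_ctrl).mean()
    print(effect)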

    In their article, Netflix validated this approach using 200 experiments and concluded that surrogate index models are consistent with long-term measurements in 95% of experiments [5].

    Conclusion

    We’ve learned a lot, so let’s summarize it. Short-term experiment results often differ from the long-term due to factors like heterogeneous treatment effects or user learning. There are several approaches to detect this difference, with the most straightforward being:

    • Visual Approach: Simply observing the difference between the test and control over time. However, this method lacks rigor.
    • Difference-in-Difference: Comparing the difference in the test and control at the beginning and after some time of the experiment.

    If you suspect user learning in your experiment, the ideal approach is to extend the experiment until the treatment effect stabilizes. However, this may not always be feasible due to technical (e.g., short-lived cookies) or business restrictions. In such cases, you can predict the long-term effect using auto-surrogate models, forecasting the long-term outcome of the experiment on Y using lags of Y.

    Thank you for taking the time to read this article. I would love to hear your thoughts, so please feel free to share any comments or questions you may have.

    References

    1. N. Larsen, J. Stallrich, S. Sengupta, A. Deng, R. Kohavi, N. T. Stevens, Statistical Challenges in Online Controlled Experiments: A Review of A/B Testing Methodology (2023), https://arxiv.org/pdf/2212.11366.pdf
    2. S. Sadeghi, S. Gupta, S. Gramatovici, J. Lu, H. Ai, R. Zhang, Novelty and Primacy: A Long-Term Estimator for Online Experiments (2021), https://arxiv.org/pdf/2102.12893.pdf
    3. H. Hohnhold, D. O’Brien, D. Tang, Focusing on the Long-term: It’s Good for Users and Business (2015), https://static.googleusercontent.com/media/research.google.com/en//pubs/archive/43887.pdf
    4. S. Athey, R. Chetty, G. W. Imbens, H. Kang, The Surrogate Index: Combining Short-Term Proxies to Estimate Long-Term Treatment Effects More Rapidly and Precisely (2019), https://www.nber.org/system/files/working_papers/w26463/w26463.pdf
    5. V. Zhang, M. Zhao, A. Le, M. Dimakopoulou, N. Kallus, Evaluating the Surrogate Index as a Decision-Making Tool Using 200 A/B Tests at Netflix (2023), https://arxiv.org/pdf/2311.11922.pdf



  • Designing and Deploying a Machine Learning Python Application (Part 2)

    Designing and Deploying a Machine Learning Python Application (Part 2)

    Noah Haglund

    You don’t have to be Atlas to get your model into the cloud

    Image by Midjourney

    Now that we have our trained Detectron2 model (see Part 1), let’s deploy it as a part of an application to provide its inferencing abilities to others.

    Even though Parts 1 and 2 of this series use Detectron2 for Object Detection, no matter the machine learning library you are using (Detectron2, YOLO, PyTorch, TensorFlow, etc.) and no matter your use case (Computer Vision, Natural Language Processing, Deep Learning, etc.), the topics discussed here concerning model deployment will be useful for anyone developing ML processes.

    Although the fields of Data Science and Computer Science overlap in many ways, training and deploying an ML model combines the two: those concerned with developing an efficient and accurate model are not typically the ones trying to deploy it, and vice versa. Someone more CS-oriented, meanwhile, may not have the understanding of ML or its associated libraries to determine whether application bottlenecks could be fixed with configurations to the ML process or rather in the backend and hosting services.

    In order to aid you in your quest to deploy an application that utilizes ML, this article will begin by discussing (1) high-level CS design concepts that can help DS folks make decisions in order to balance load and mitigate bottlenecks, and (2) low-level design, by walking through deploying a Detectron2 inferencing process using the Python web framework Django, an API using Django Rest Framework, the distributed task queue Celery, Docker, Heroku, and AWS S3.

    For following along with this article, it will be helpful to have in advance:

    • Strong Python Knowledge
    • Understanding of Django, Django Rest Framework, Docker, Celery, and AWS
    • Familiarity with Heroku

    High Level Design

    In order to dig into the high level design, let’s discuss a couple key problems and potential solutions.

    Problem 1: Memory

    The saved ML model from Part 1, titled model_final.pth, will start off at ~325MB. With more training data, the model will increase in size, with models trained on large datasets (100,000+ annotated images) increasing to ~800MB. Additionally, an application based on (1) a Python runtime, (2) Detectron2, (3) large dependencies such as Torch, and (4) a Django web framework will utilize ~150MB of memory on deployment.

    So at minimum, we are looking at ~475MB of memory utilized right off the bat.

    We could load the Detectron2 model only when the ML process needs to run, but this would still mean that our application would eat up ~475MB eventually. If you have a tight budget and are unable to vertically scale your application, memory now becomes a substantial limitation on many hosting platforms. For example, Heroku offers containers to run applications, termed “dynos”, that start with 512MB RAM on base payment plans, begin writing to disk beyond the 512MB threshold, and crash and restart the dyno at 250% utilization (1280MB).

    On the topic of memory, Detectron2 inferencing will cause spikes in memory usage depending on the amount of objects detected in an image, so it is important to ensure memory is available during this process.

    For those of you trying to speed up inferencing, but are cautious of memory constraints, batch inferencing will be of no help here either. As noted by one of the contributors to the Detectron2 repo, with batch inferencing:

    N images use N times more memory than 1 image…You can predict on N images one by one in a loop instead.
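    In practice that just means a plain loop. Assuming predictor is a Detectron2 DefaultPredictor and images is a list of BGR numpy arrays (both hypothetical names here), sequential inference keeps peak memory roughly constant:

    # Predict one image at a time instead of batching, trading speed for a flat memory profile
    outputs = []
    for img in images:
        outputs.append(predictor(img))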

    Overall, this summarizes problem #1:

    running a long ML process as part of an application will most likely be memory-intensive, due to the size of the model, the ML dependencies, and the inferencing process.

    Problem 2: Time

    A deployed application that incorporates ML will likely need to be designed to manage a long-running process.

    Using the example of an application that uses Detectron2, the model would be sent an image as input and output inference coordinates. With one image, inference may only take a few seconds, but say for instance we are processing a long PDF document with one image per page (as per the training data in Part 1), this could take a while.

    During this process, Detectron2 inferencing would be either CPU or GPU bound, depending on your configurations. See the below Python code block to change this (CPU is entirely fine for inferencing, however, GPU/Cuda is necessary for training as mentioned in Part 1):

    from detectron2.config import get_cfg
    cfg = get_cfg()
    cfg.MODEL.DEVICE = "cpu" #or "cuda"

    Additionally, saving images after inferencing, say to AWS S3 for example, would introduce I/O bound processes. Altogether, this could serve to clog up the backend, which introduces problem #2:

    single-threaded Python applications will not process additional HTTP requests, concurrently or otherwise, while running a process.

    Problem 3: Scale

    When considering the horizontal scalability of a Python application, it is important to note that Python (assuming it is compiled/interpreted by CPython) suffers from the limitations of the Global Interpreter Lock (GIL), which allows only one thread to hold control of the Python interpreter. Thus, the usual multithreading paradigm doesn’t map cleanly onto Python: applications can still implement multithreading, using web servers such as Gunicorn, but the threads run concurrently rather than in parallel.

    I know all of this sounds fairly abstract, perhaps especially for the Data Science folks, so let me provide an example to illustrate this problem.

    Imagine you are your application, and your hardware (your brain) is processing two requests: cleaning the counter and texting on your phone. With two arms to do this, you are now a multithreaded Python application, doing both simultaneously. But you’re not actually thinking about both at the exact same time; you start your hand in a cleaning motion, then switch your attention to your phone to look at what you are typing, then look back at the counter to make sure you didn’t miss a spot.

    In actuality, you are processing these tasks concurrently.

    The GIL functions in the same way, processing one thread at a time but switching between them for concurrency. This means that multithreading a Python application is still useful for running background or I/O bound-oriented tasks, such as downloading a file, while the main execution’s thread is still running. To take the analogy this far, your background task of cleaning the counter (i.e. downloading a file) continues to happen while you are thinking about texting, but you still need to change your focus back to your cleaning hand in order to process the next step.

    This “change in focus” may not seem like a big deal when concurrently processing multiple requests, but when you need to handle hundreds of requests simultaneously, suddenly this becomes a limiting factor for large scale applications that need to be adequately responsive to end users.

    Thus, we have problem #3:

    the GIL prevents multithreading from being a good scalability solution for Python applications.

    Solutions

    Now that we have identified key problems, let’s discuss a few potential solutions.

    The aforementioned problems are ordered in terms of importance, as we need to manage memory first and foremost (problem #1) to ensure the application doesn’t crash, then leave room for the app to process more than one request at a time (problem #2) while still ensuring our means of simultaneous request handling is effective at scale (problem #3).

    So, let’s jump right into addressing problem #1.

    Depending on the hosting platform, we will need to be fully aware of the configurations available in order to scale. As we will be using Heroku, feel free to check out the guidance on dyno scaling. Without having to vertically scale up your dyno, we can scale out by adding another process. For instance, with the Basic dyno type, a developer is able to deploy both a web process and a worker process on the same dyno. A few reasons this is useful:

    • This enables a means of multiprocessing.
    • The dyno resources are now duplicated, meaning each process has a 512MB RAM threshold.
    • Cost wise, we are looking at $7 per month per process (so $14 a month with both a web and worker process). Much cheaper than vertically scaling the dyno to get more RAM, with $50 a month per dyno if you want to increase the 512MB allocation to 1024MB.

    Hopping back to the previous analogy of cleaning the counter and texting on your phone, instead of threading yourself further by adding additional arms to your body, we can now have two people (multiprocessing in parallel) to perform the separate tasks. We are scaling out by increasing workload diversity as opposed to scaling up, in turn saving us money.

    Okay, but with two separate processes, what’s the difference?

    Using Django, our web process will be initialized with:

    python manage.py runserver

    And using a distributed task queue, such as Celery, the worker will be initialized with:

    celery -A <DJANGO_APP_NAME_HERE> worker

    As intended by Heroku, the web process is the server for our core web framework and the worker process is intended for queuing libraries, cron jobs, or other work performed in the background. Both represent an instance of the deployed application, so will be running at ~150MB given the core dependencies and runtime. However, we can ensure that the worker is the only process that runs the ML tasks, saving the web process from using ~325MB+ in RAM. This has multiple benefits:

    • Memory usage, although still high for the worker, will be distributed to a node outside of the system, ensuring any problems encountered during the execution of an ML task can be handled and monitored separately from the web process. This helps to mitigate problem #1.
    • The newly found means of parallelism ensures that the web process can still respond to requests during a long-running ML task, helping to address problem #2.
    • We are preparing for scale by implementing a means of multiprocessing, helping to address problem #3.

    As we haven’t quite solved the key problems, let’s dig in just a bit further before getting into the low-level nitty-gritty. As stated by Heroku:

    Web applications that process incoming HTTP requests concurrently make much more efficient use of dyno resources than web applications that only process one request at a time. Because of this, we recommend using web servers that support concurrent request processing whenever developing and running production services.

    The Django and Flask web frameworks feature convenient built-in web servers, but these blocking servers only process a single request at a time. If you deploy with one of these servers on Heroku, your dyno resources will be underutilized and your application will feel unresponsive.

    We are already ahead of the game by utilizing worker multiprocessing for the ML task, but can take this a step further by using Gunicorn:

    Gunicorn is a pure-Python HTTP server for WSGI applications. It allows you to run any Python application concurrently by running multiple Python processes within a single dyno. It provides a perfect balance of performance, flexibility, and configuration simplicity.

    Okay, awesome, now we can utilize even more processes, but there’s a catch: each new Gunicorn worker process will represent a copy of the application, meaning that they too will utilize the base ~150MB RAM in addition to the Heroku process. So, say we pip install gunicorn and now initialize the Heroku web process with the following command:

    gunicorn <DJANGO_APP_NAME_HERE>.wsgi:application --workers=2 --bind=0.0.0.0:$PORT

    The base ~150MB RAM in the web process turns into ~300MB RAM (base memory usage multiplied by the number of Gunicorn workers).

    While being cautious of the limitations to multithreading a Python application, we can add threads to workers as well using:

    gunicorn <DJANGO_APP_NAME_HERE>.wsgi:application --threads=2 --worker-class=gthread --bind=0.0.0.0:$PORT

    Even with problem #3, we can still find a use for threads, as we want to ensure our web process is capable of processing more than one request at a time while being careful of the application’s memory footprint. Here, our threads could process miniscule requests while ensuring the ML task is distributed elsewhere.

    Either way, by utilizing gunicorn workers, threads, or both, we are setting our Python application up to process more than one request at a time. We’ve more or less solved problem #2 by incorporating various ways to implement concurrency and/or parallel task handling while ensuring our application’s critical ML task doesn’t rely on potential pitfalls, such as multithreading, setting us up for scale and getting to the root of problem #3.

    Okay, so what about that tricky problem #1? At the end of the day, ML processes will typically end up taxing the hardware in one way or another, whether that be memory, CPU, and/or GPU. However, by using a distributed system, our ML task remains integrally linked to the main web process yet is handled in parallel via a Celery worker. We can track the start and end of the ML task via the chosen Celery broker, as well as review metrics in a more isolated manner. Here, tuning the Celery and Heroku worker process configurations is up to you, but this is an excellent starting point for integrating a long-running, memory-intensive ML process into your application.

    Low Level Design and Setup

    Now that we’ve had a chance to really dig in and get a high level picture of the system we are building, let’s put it together and focus on the specifics.

    For your convenience, here is the repo I will be mentioning in this section.

    First we will begin by setting up Django and Django Rest Framework, with installation guides here and here respectively. All requirements for this app can be found in the repo’s requirements.txt file (and Detectron2 and Torch will be built from Python wheels specified in the Dockerfile, in order to keep the Docker image size small).

    The next part will be setting up the Django app, configuring the backend to save to AWS S3, and exposing an endpoint using DRF, so if you are already comfortable doing this, feel free to skip ahead and go straight to the ML Task Setup and Deployment section.

    Django Setup

    Go ahead and create a folder for the Django project and cd into it. Activate the virtual/conda env you are using, ensure Detectron2 is installed as per the installation instructions in Part 1, and install the requirements as well.

    Issue the following command in a terminal:

    django-admin startproject mltutorial

    This will create a Django project root directory titled “mltutorial”. Go ahead and cd into it to find a manage.py file and a mltutorial sub directory (which is the actual Python package for your project).

    mltutorial/
        manage.py
        mltutorial/
            __init__.py
            settings.py
            urls.py
            asgi.py
            wsgi.py

    Open settings.py and add ‘rest_framework’, ‘celery’, and ‘storages’ (needed for boto3/AWS) in the INSTALLED_APPS list to register those packages with the Django project.

    In the root dir, let’s create an app which will house the core functionality of our backend. Issue another terminal command:

    python manage.py startapp docreader

    This will create an app in the root dir called docreader.

    Let’s also create a file in docreader titled mltask.py. In it, define a simple function for testing our setup that takes in a variable, file_path, and prints it out:

    def mltask(file_path):
        return print(file_path)

    Now getting to structure, Django apps use the Model View Controller (MVC) design pattern, defining the Model in models.py, the View in views.py, and the Controller in Django Templates and urls.py. Using Django Rest Framework, we will include serialization in this pipeline, which provides a way of serializing and deserializing native Python data structures into representations such as JSON. Thus, the application logic for exposing an endpoint is as follows:

    Database ← → models.py ← → serializers.py ← → views.py ← → urls.py

    In docreader/models.py, write the following:

    from django.db import models
    from django.dispatch import receiver
    from .mltask import mltask
    from django.db.models.signals import (
        post_save
    )

    class Document(models.Model):
        title = models.CharField(max_length=200)
        file = models.FileField(blank=False, null=False)

    @receiver(post_save, sender=Document)
    def user_created_handler(sender, instance, *args, **kwargs):
        mltask(str(instance.file.file))

    This sets up a Document model that will require a title and file for each entry saved in the database. Once an entry is saved, the @receiver decorator picks up the post_save signal, meaning that the specified model, Document, was saved in the database. user_created_handler() then takes the saved instance’s file field and passes it to what will become our Machine Learning function.

    Anytime changes are made to models.py, you will need to run the following two commands:

    python manage.py makemigrations
    python manage.py migrate

    Moving forward, create a serializers.py file in docreader, allowing for the serialization and deserialization of the Document’s title and file fields. Write in it:

    from rest_framework import serializers
    from .models import Document

    class DocumentSerializer(serializers.ModelSerializer):
        class Meta:
            model = Document
            fields = [
                'title',
                'file'
            ]

    Next in views.py, where we can define our CRUD operations, let’s define the ability to create, as well as list, Document entries using generic views (which essentially allows you to quickly write views using an abstraction of common view patterns):

    from django.shortcuts import render
    from rest_framework import generics
    from .models import Document
    from .serializers import DocumentSerializer

    class DocumentListCreateAPIView(
        generics.ListCreateAPIView):

        queryset = Document.objects.all()
        serializer_class = DocumentSerializer

    Finally, update urls.py in mltutorial:

    from django.contrib import admin
    from django.urls import path, include

    urlpatterns = [
        path("admin/", admin.site.urls),
        path('api/', include('docreader.urls')),
    ]

    And create urls.py in docreader app dir and write:

    from django.urls import path

    from . import views

    urlpatterns = [
        path('create/', views.DocumentListCreateAPIView.as_view(), name='document-list'),
    ]

    Now we are all set up to save a Document entry, with title and file fields, at the /api/create/ endpoint, which will call mltask() post save! So, let’s test this out.

    To help visualize testing, let’s register our Document model with the Django admin interface, so we can see when a new entry has been created.

    In docreader/admin.py write:

    from django.contrib import admin
    from .models import Document

    admin.site.register(Document)

    Create a user that can login to the Django admin interface using:

    python manage.py createsuperuser

    Now, let’s test the endpoint we exposed.

    To do this without a frontend, run the Django server and go to Postman. Send the following POST request with a PDF file attached:
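    If you would rather script the request than use Postman, an equivalent call with the requests library might look like this (the file name, title, and local port are placeholders):

    import requests

    with open("sample.pdf", "rb") as f:
        response = requests.post(
            "http://127.0.0.1:8000/api/create/",
            data={"title": "My first document"},
            files={"file": f},
        )

    print(response.status_code, response.json())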

    If we check our Django logs, we should see the file path printed out, as specified in the post save mltask() function call.

    AWS Setup

    You will notice that the PDF was saved to the project’s root dir. Let’s ensure any media is instead saved to AWS S3, getting our app ready for deployment.

    Go to the S3 console (and create an account and get your account’s Access and Secret keys if you haven’t already). Create a new bucket; here we will be titling it ‘djangomltest’. Update the permissions to ensure the bucket is public for testing (and revert back, as needed, for production).

    Now, let’s configure Django to work with AWS.

    Add your model_final.pth, trained in Part 1, into the docreader dir. Create a .env file in the root dir and write the following:

    AWS_ACCESS_KEY_ID = <Add your Access Key Here>
    AWS_SECRET_ACCESS_KEY = <Add your Secret Key Here>
    AWS_STORAGE_BUCKET_NAME = 'djangomltest'

    MODEL_PATH = './docreader/model_final.pth'

    Update settings.py to include AWS configurations:

    import os
    from dotenv import load_dotenv, find_dotenv
    load_dotenv(find_dotenv())

    # AWS
    AWS_ACCESS_KEY_ID = os.environ['AWS_ACCESS_KEY_ID']
    AWS_SECRET_ACCESS_KEY = os.environ['AWS_SECRET_ACCESS_KEY']
    AWS_STORAGE_BUCKET_NAME = os.environ['AWS_STORAGE_BUCKET_NAME']

    #AWS Config
    AWS_DEFAULT_ACL = 'public-read'
    AWS_S3_CUSTOM_DOMAIN = f'{AWS_STORAGE_BUCKET_NAME}.s3.amazonaws.com'
    AWS_S3_OBJECT_PARAMETERS = {'CacheControl': 'max-age=86400'}

    #Boto3
    STATICFILES_STORAGE = 'mltutorial.storage_backends.StaticStorage'
    DEFAULT_FILE_STORAGE = 'mltutorial.storage_backends.PublicMediaStorage'

    #AWS URLs
    STATIC_URL = f'https://{AWS_S3_CUSTOM_DOMAIN}/static/'
    MEDIA_URL = f'https://{AWS_S3_CUSTOM_DOMAIN}/media/'

    Additionally, with AWS serving our static and media files, you will want to run the following command in order to serve static assets to the admin interface using S3:

    python manage.py collectstatic

    If we run the server again, our admin should appear the same as how it would with our static files served locally.

    Once again, let’s run the Django server and test the endpoint to make sure the file is now saved to S3.

    ML Task Setup and Deployment

    With Django and AWS properly configured, let’s set up our ML process in mltask.py. As the file is long, see the repo here for reference (with comments added in to help with understanding the various code blocks).

    What’s important to see is that Detectron2 is imported and the model is loaded only when the function is called. Here, we will call the function only through a Celery task, ensuring the memory used during inferencing will be isolated to the Heroku worker process.
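    For orientation, here is a rough sketch of that lazy-loading pattern. This is not the repo’s actual mltask.py; the base config, class count, and single-image handling are simplifying assumptions and should mirror how you trained in Part 1.

    def mltask(file_path):
        # Heavy imports happen inside the function so the web process never loads
        # Detectron2/Torch into memory; only the Celery worker pays that cost.
        import cv2
        from detectron2 import model_zoo
        from detectron2.config import get_cfg
        from detectron2.engine import DefaultPredictor

        cfg = get_cfg()
        # Placeholder base config; use the same architecture and NUM_CLASSES as in Part 1.
        cfg.merge_from_file(model_zoo.get_config_file("COCO-Detection/faster_rcnn_R_50_FPN_3x.yaml"))
        cfg.MODEL.ROI_HEADS.NUM_CLASSES = 1  # assumption: adjust to your training setup
        cfg.MODEL.WEIGHTS = "./docreader/model_final.pth"
        cfg.MODEL.DEVICE = "cpu"  # CPU is fine for inference, as noted earlier

        predictor = DefaultPredictor(cfg)
        image = cv2.imread(file_path)
        return predictor(image)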

    So finally, let’s setup Celery and then deploy to Heroku.

    In mltutorial/__init__.py write:

    from .celery import app as celery_app
    __all__ = ('celery_app',)

    Create celery.py in the mltutorial dir and write:

    import os

    from celery import Celery

    # Set the default Django settings module for the 'celery' program.
    os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'mltutorial.settings')

    # We will specify Broker_URL on Heroku
    app = Celery('mltutorial', broker=os.environ['CLOUDAMQP_URL'])

    # Using a string here means the worker doesn't have to serialize
    # the configuration object to child processes.
    # - namespace='CELERY' means all celery-related configuration keys
    # should have a `CELERY_` prefix.
    app.config_from_object('django.conf:settings', namespace='CELERY')

    # Load task modules from all registered Django apps.
    app.autodiscover_tasks()

    @app.task(bind=True, ignore_result=True)
    def debug_task(self):
        print(f'Request: {self.request!r}')

    Lastly, make a tasks.py in docreader and write:

    from celery import shared_task
    from .mltask import mltask

    @shared_task
    def ml_celery_task(file_path):
        mltask(file_path)
        return "DONE"

    This Celery task, ml_celery_task(), should now be imported into models.py and used with the post save signal instead of the mltask function pulled directly from mltask.py. Update the post_save signal block to the following:

    @receiver(post_save, sender=Document)
    def user_created_handler(sender, instance, *args, **kwargs):
        ml_celery_task.delay(str(instance.file.file))

    And to test Celery, let’s deploy!

    In the root project dir, include a Dockerfile and heroku.yml file, both specified in the repo. Most importantly, editing the heroku.yml commands will allow you to configure the gunicorn web process and the Celery worker process, which can aid in further mitigating potential problems.

    Make a Heroku account, create a new app called “mlapp”, and gitignore the .env file. Then initialize git in the project’s root dir and change the Heroku app’s stack to container (in order to deploy using Docker):

    $ heroku login
    $ git init
    $ heroku git:remote -a mlapp
    $ git add .
    $ git commit -m "initial heroku commit"
    $ heroku stack:set container
    $ git push heroku master

    Once pushed, we just need to add our env variables into the Heroku app.

    Go to settings in the online interface, scroll down to Config Vars, click Reveal Config Vars, and add each line listed in the .env file.

    You may have noticed there was a CLOUDAMQP_URL variable specified in celery.py. We need to provision a Celery Broker on Heroku, for which there are a variety of options. I will be using CloudAMQP which has a free tier. Go ahead and add this to your app. Once added, the CLOUDAMQP_URL environment variable will be included automatically in the Config Vars.

    Finally, let’s test the final product.

    To monitor requests, run:

    $ heroku logs --tail

    Issue another Postman POST request to the Heroku app’s url at the /api/create/ endpoint. You will see the POST request come through, Celery receive the task, load the model, and start running pages:

    We will continue to see the “Running for page…” until the end of the process and you can check the AWS S3 bucket as it runs.

    Congrats! You’ve now deployed and run a Python backend that uses Machine Learning as part of a distributed task queue running in parallel to the main web process!

    As mentioned, you will want to adjust the heroku.yml commands to incorporate Gunicorn threads and/or worker processes and fine-tune Celery. For further learning, here’s a great article on configuring Gunicorn to meet your app’s needs, one for digging into Celery for production, and another for exploring Celery worker pools, in order to help with properly managing your resources.

    Happy coding!

    Unless otherwise noted, all images used in this article are by the author



  • Understanding Direct Preference Optimization

    Understanding Direct Preference Optimization

    Matthew Gunton

    A look at the “Direct Preference Optimization: Your Language Model is Secretly a Reward Model” paper and its findings

    Image by the Author via DALL-E

    This blog post was inspired by a discussion I recently had with some friends about the Direct Preference Optimization (DPO) paper. The discussion was lively and went over many important topics in LLMs and Machine Learning in general. Below is an expansion on some of those ideas and the concepts discussed in the paper.

    Direct Preference Optimization (DPO) has become the way that new foundation models are fine-tuned. Famously Mixtral 8x7B, the Sparse Mixture of Experts model created by Mistral, was able to reach LLaMa 70B levels of performance with significantly fewer parameters by using DPO. Naturally, this success has led many in the community to begin fine-tuning their own models with DPO.

    Let’s dive into what exactly DPO is and how we got here.

    High Level Discussion

    Let’s begin by setting out what fine-tuning should do at a high level. Once you have pre-trained a model to have strong generative capacities, you typically want to control its output somehow. Whether that be optimizing it to respond in dialogue as a chat-bot or to respond in code rather than English, the goal here is to take an LLM that is already functional and find a way to be more selective with its output. As this is machine learning, the way we show it the right behavior is with data.

    There are some key terms here I’ll define before we start diving into the technicals:

    Loss Function — a function we use as a guide to optimize performance of our model. This is chosen based on what has been found to be effective

    KL Divergence — stands for Kullback–Leibler divergence, which is a way to measure the difference between two probability distributions. To learn more about this, there is a wonderful post by Aparna Dhinakaran on the topic.

    Policy — an abstraction that describes how a neural network will make decisions. Put differently, if a neural network is trained 3 times, each run will produce a different policy, whose performances you can compare.

    The Status Quo before DPO (PPO)

    Before DPO, we used to have to train an entirely separate model to help us fine-tune, typically called the reward model or RLHF model. We would sample completions from our LLM and then have the reward model give us a score for each completion. The idea here was simple: humans are expensive to have evaluate your LLM’s outputs, but the quality of your LLM will ultimately be determined by humans. To keep costs down and quality high, you would train the reward model to approximate human feedback. The reward model then guides fine-tuning via a reinforcement learning algorithm, Proximal Policy Optimization (PPO), and the whole method lives or dies based on the strength of your reward model.

    Figure 1 from the paper showing how PPO works

    The Math behind PPO

    To find the ideal reward model, we assume human preferences are more probabilistic than deterministic, so we can represent this symbolically in the Bradley-Terry model like below.

    Equation 1 from the paper
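    For readers without the figure handy, the Bradley-Terry formulation is roughly the following (reconstructed from the paper; notation may differ slightly from the figure):

    $$p^*(y_1 \succ y_2 \mid x) = \frac{\exp\big(r^*(x, y_1)\big)}{\exp\big(r^*(x, y_1)\big) + \exp\big(r^*(x, y_2)\big)}$$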

    Going variable by variable, p* means that this is the optimal probability distribution, or the one the model should treat as the source of truth. y₁ and y₂ are two completions from the model that we are going to compare, and x is the prompt given to the LLM. r* means that the reward function is optimal; put another way, to train the model to approximate the optimal probability distribution, you give it the rewards from the optimal reward function.

    Nevertheless, the perfect probability distribution of human preference is difficult, if not impossible, to know. For this reason, we focus on the reward model, so we need a way to estimate r*. In machine learning, we often use loss minimization to estimate complex quantities. If we have access to training data that shows us what human preferences truly are, and thus would give scores that are part of the p* distribution, then we can use those samples to train the reward model like below:

    Equation 2 from the paper
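    That loss is, roughly, a binary cross-entropy over preference pairs (again reconstructed from the paper; σ is the logistic function):

    $$\mathcal{L}_R(r_\phi, \mathcal{D}) = -\,\mathbb{E}_{(x, y_w, y_l) \sim \mathcal{D}}\Big[\log \sigma\big(r_\phi(x, y_w) - r_\phi(x, y_l)\big)\Big]$$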

    Here rϕ is the reward model we are training, D is the set of samples we are training on, yw is the preferred completion, and yl is the dispreferred completion. The authors have chosen to frame the problem as a binary classification problem, for reasons we will see later on; for now, just remember this is why we have yw and yl.

    Once we have optimized our reward model, we use it to fine-tune the LLM based on the difference between the old policy (π ref) and the new policy (π θ). Importantly, we apply a KL-divergence penalty to prevent the model from shifting too much.

    Why don’t we want it shifting too much? Remember the model is already mostly functional, and it has taken quite a lot of compute resources to reach this level. Consequently, we want to make sure the model retains many of the good traits it currently has while we focus on having it follow instructions better.

    Equation 3 from the paper
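    In words, this objective maximizes expected reward while penalizing divergence from the reference policy, roughly as follows (reconstructed from the paper; β controls the strength of the KL penalty):

    $$\max_{\pi_\theta}\; \mathbb{E}_{x \sim \mathcal{D},\, y \sim \pi_\theta(y \mid x)}\big[r_\phi(x, y)\big] \;-\; \beta\, \mathbb{D}_{\mathrm{KL}}\big[\pi_\theta(y \mid x)\,\|\,\pi_{\mathrm{ref}}(y \mid x)\big]$$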

    While the above methodology is effective — LLaMa 2, for instance, was fine-tuned this way — it has one major weakness: it requires training an entirely separate model, which is costly and requires huge amounts of additional data.

    How does DPO improve on this?

    DPO removes the need for the reward model altogether! This allows us to avoid training a costly separate reward model, and, incidentally, it has been found that DPO requires a lot less data than PPO to work just as well.

    Figure 1 from the paper showing a high level of how DPO works

    The Math behind DPO

    The major leap stems from the KL constraint we placed on ourselves in equation 3. By adding this constraint, we can actually derive the ideal policy that maximizes the KL-constrained reward objective. The algebra is shown below:

    Appendix A.1 from the paper showing how we can maximize a KL Divergence Bound Rewards Model

    For our purposes, the most important point to take away is that we now have the below equation for a policy π r, such that the reward function r is easily solved for.

    Equation 4 from the paper

    Naturally, we immediately solve for r

    Equation 5 from the paper
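    Roughly, equations 4 and 5 state the following, where Z(x) is a partition function that depends only on the prompt (reconstructed from the paper; notation may differ slightly):

    $$\pi_r(y \mid x) = \frac{1}{Z(x)}\, \pi_{\mathrm{ref}}(y \mid x)\, \exp\!\Big(\tfrac{1}{\beta}\, r(x, y)\Big)$$

    $$r(x, y) = \beta \log \frac{\pi_r(y \mid x)}{\pi_{\mathrm{ref}}(y \mid x)} + \beta \log Z(x)$$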

    Returning to our ideal probability distribution equation (equation 1), we can rewrite that so that each instance of r is replaced by equation 5.

    Equation 6 from the paper

    What this has shown is that you don’t need the reward model to optimize the policy to follow the ideal probability distribution of human preferences. Instead, you can work on the policy directly to improve it (hence the name Direct Preference Optimization). We are using the probabilities that your LLM generates for each token to help it fine-tune itself.

    To finish the derivation, we do the same math as we did in equation 3 to come up with our loss optimizing function to optimize for the policy.

    Equation 7 from the paper
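    Written out, the DPO loss is roughly the following (reconstructed from the paper; σ is the logistic function and β the KL-penalty strength):

    $$\mathcal{L}_{\mathrm{DPO}}(\pi_\theta; \pi_{\mathrm{ref}}) = -\,\mathbb{E}_{(x, y_w, y_l) \sim \mathcal{D}}\bigg[\log \sigma\Big(\beta \log \frac{\pi_\theta(y_w \mid x)}{\pi_{\mathrm{ref}}(y_w \mid x)} - \beta \log \frac{\pi_\theta(y_l \mid x)}{\pi_{\mathrm{ref}}(y_l \mid x)}\Big)\bigg]$$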

    That was a lot of algebra, but equation 7 is the most important one to understand, so I’ll break down the key pieces. We now have an equation which compares the policy probabilities of the old policy (π ref) and the new policy (π θ) for a winning completion (yw) and a losing completion (yl). When we compare these, we are optimizing so that yw is bigger, as this would mean that the policy is getting better at giving winning responses than losing responses.
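    In code, that loss is only a few lines. A minimal PyTorch-style sketch, where each argument is a tensor of summed log-probabilities of a completion under the corresponding policy (hypothetical inputs, not the authors’ implementation):

    import torch.nn.functional as F

    def dpo_loss(policy_chosen_logps, policy_rejected_logps,
                 ref_chosen_logps, ref_rejected_logps, beta=0.1):
        # How much more (or less) likely the new policy finds each completion vs. the reference
        chosen_logratios = policy_chosen_logps - ref_chosen_logps
        rejected_logratios = policy_rejected_logps - ref_rejected_logps

        # Push the winning log-ratio above the losing one, scaled by beta
        logits = beta * (chosen_logratios - rejected_logratios)
        return -F.logsigmoid(logits).mean()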

    Consequences

    First, DPO does not require a reward model! You simply need high quality data so that the model has a clear direction of what is good and bad, and it will improve.

    Second, DPO is dynamic. Every time you use new data, it is going to adapt immediately thanks to the way it figures out the right direction to go. Compared to PPO, where you have to retrain your reward model each time you have new data, this is a big win.

    Third, DPO allows you to train a model to avoid certain topics just as much as it will learn to give good answers for others. One way to conceptualize the new loss equation is as a signal that points our training in the right direction. By using both a good and a bad example, we are teaching the model to avoid certain responses as much as we teach it to move towards others. As a large part of fine-tuning involves the model ignoring certain subjects, this feature is very valuable.

    Closing Thoughts

    Figure 2 from the paper showing comparative performance between DPO, PPO, and other methodologies

    Understanding the consequences of DPO’s math makes me more optimistic about the future of LLMs.

    DPO requires less data and compute than PPO, both of which are major contributors to the cost of making your own model. With this cost reduction, more people will be able to fine-tune their own models, potentially giving society access to more specialized LLMs.

    Moreover, as DPO explicitly requires good and bad examples, while PPO only asks for good ones, it is much better at restricting behavior. This means that LLMs can be made far safer, another piece that will allow them to help out society.

    With forces like DPO giving us access to better quality LLMs that can be more easily trained, it is an incredibly exciting time for this field.

    [1] R. Rafailov, et al., Direct Preference Optimization: Your Language Model is Secretly a Reward Model (2023), arXiv

    [2] A. Jiang, et al., Mixtral of Experts (2024), arXiv



  • Spurious Correlations: The Comedy and Drama of Statistics

    Spurious Correlations: The Comedy and Drama of Statistics

    Celia Banks, Ph.D.

    What not to do with statistics

    By Celia Banks, PhD and Paul Boothroyd III


    Introduction

    Since Tyler Vigen coined the term ‘spurious correlations’ for “any random correlations dredged up from silly data” (Vigen, 2014; see Tyler Vigen’s personal website), there have been many articles that pay tribute to the perils and pitfalls of this whimsical tendency to manipulate statistics to make correlation equal causation. See: HBR (2015), Medium (2016), FiveThirtyEight (2016). As data scientists, we are tasked with providing statistical analyses that either accept or reject null hypotheses. We are taught to be ethical in how we source data, extract it, preprocess it, and make statistical assumptions about it. And this is no small matter — global companies rely on the validity and accuracy of our analyses. It is just as important that our work be reproducible. Yet, in spite of all of the ‘good’ that we are taught to practice, there may be that one occasion (or more) where a boss or client will insist that you work the data until it supports the hypothesis and, above all, show how variable y causes variable x when correlated. This is the basis of p-hacking, where you enter territory that is far from supported by ‘good’ practice. In this report, we learn how to conduct fallacious research using spurious correlations. We get to delve into the ‘bad’ with the objective of learning what not to do when you are faced with that inevitable moment to deliver what the boss or client whispers in your ear.

    The objective of this project is to teach you what not to do with statistics.

    We’ll demonstrate the spurious correlation of two unrelated variables. Datasets from two different sources were preprocessed and merged together in order to produce visuals of relationships. Spurious correlations occur when two variables are misleadingly correlated, and it is further assumed that one variable directly affects the other so as to cause a certain outcome. The reason we chose this project idea is that we were interested in ways to manage a client’s expectations of what a data analysis project should produce. Team member Banks has had clients express displeasure with analysis results, and on one occasion she was asked to go back and look at other data sources and opportunities to “help” arrive at the answers they were seeking. Yes, this is p-hacking — in this case, the client insisted that causal relationships existed because they believed the correlations existed to cause an outcome.

    Examples of Spurious Correlations

    Excerpts of Tyler Vigen’s Spurious Correlations. Retrieved February 1, 2024, from Spurious Correlations (tylervigen.com) Reprinted with permission from the author.

    Research Questions Pertinent to this Study

    What are the research questions?

    Why the heck do we need them?

    We’re doing a “bad” analysis, right?

    Research questions are the foundation of the research study. They guide the research process by focusing on specific topics that the researcher will investigate. They are essential because, among other things, they provide focus and clarity, guide the methodology, establish the relevance of the study, help structure the report, and help the researcher evaluate results and interpret findings. In learning how a ‘bad’ analysis is conducted, we addressed the following questions:

    (1) Are the data sources valid (not made up)?

    (2) How were missing values handled?

    (3) How were you able to merge dissimilar datasets?

    (4) What are the response and predictor variables?

    (5) Is the relationship between the response and predictor variables linear?

    (6) Is there a correlation between the response and predictor variables?

    (7) Can we say that there is a causal relationship between the variables?

    (8) What explanation would you provide a client interested in the relationship between these two variables?

    (9) Did you find spurious correlations in the chosen datasets?

    (10) What learning was your takeaway in conducting this project?

    Methodology

    How did we conduct a study about spurious correlations?

    To investigate the presence of spurious correlations between variables, a comprehensive analysis was conducted. The datasets spanned different domains of economic and environmental factors and were collected and affirmed as being from public sources. They contained variables with no apparent causal relationship but which exhibited statistical correlation. The chosen datasets were Apple stock data (the primary dataset) and daily high temperatures in New York City (the secondary dataset). The datasets spanned the period from January 2017 through December 2022.

    Rigorous statistical techniques were used to analyze the data. Pearson correlation coefficients were calculated to quantify the strength and direction of linear relationships between pairs of the variables. To complete this analysis, scatter plots of the 5-year daily high temperatures in New York City, candlestick charting of the 5-year Apple stock trend, and dual-axis charting of the daily high temperatures versus the stock trend were used to visualize the relationship between variables and to identify patterns or trends. The areas this methodology followed were:

    The Data: Source/Extract/Process

    Primary dataset: Apple Stock Price History | Historical AAPL Company Stock Prices | FinancialContent Business Page

    Secondary dataset: New York City daily high temperatures from Jan 2017 to Dec 2022: https://www.extremeweatherwatch.com/cities/new-york/year-{year}

    The data was affirmed as publicly sourced and available for reproducibility. Capturing the data over a period of five years gave a meaningful view of patterns, trends, and linearity. Temperature readings showed seasonal trends. For both temperature and stock, there were troughs and peaks in the data points. Note that temperatures were in Fahrenheit and seasons were initially defined meteorologically; we later switched to the astronomical definition to further manipulate our data and pose stronger spuriousness. While the data could be downloaded as csv or xls files, for this assignment, Python’s Beautiful Soup web-scraping library was used.

    Next, the data was checked for missing values and how many records each dataset contained. The weather data contained date, daily high, and daily low temperature; the Apple stock data contained date, opening price, closing price, volume, stock price, and stock name. To merge the datasets, the date columns needed to be in datetime format. An inner join matched records and discarded non-matching ones. For Apple stock, date and daily closing price represented the columns of interest. For the weather, date and daily high temperature represented the columns of interest.
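    A condensed version of that merge with pandas might look like the following (weather_df and stock_df are assumed to come from the scraping step, and the column names here are our guesses, not necessarily the originals):

    import pandas as pd

    weather_df["date"] = pd.to_datetime(weather_df["date"])
    stock_df["date"] = pd.to_datetime(stock_df["date"])

    merged = (
        stock_df[["date", "close"]]
        .merge(weather_df[["date", "daily_high"]], on="date", how="inner")  # keep only matching dates
        .sort_values("date")
        .reset_index(drop=True)
    )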

    The Data: Manipulation

    From Duarte® Slide Deck

    To do ‘bad’ the right way, you have to massage the data until you find the relationship that you’re looking for…

    Our earlier approach did not quite yield the intended results. So, instead of using the summer 2018 temperatures in five U.S. cities, we pulled five years of daily high temperatures for New York City and Apple stock performance from January 2017 through December 2022. In conducting exploratory analysis, we saw weak correlations across the seasons and years. So, our next step was to redefine the seasons. Instead of meteorological, we chose astronomical. This gave us ‘meaningful’ correlations across seasons.

With the new approach in place, we noticed that merging the datasets was problematic. The date fields differed: for weather, the date was month and day; for stock, the date was in year-month-day format. We addressed this by converting each dataset's date column to datetime. The date columns were also sorted in opposite orders (one chronological, one reverse chronological), which we resolved by sorting both in ascending order.

    Analysis I: Do We Have Spurious Correlation? Can We Prove It?

The spurious nature of the correlations here is shown by shifting from meteorological seasons (Spring: Mar-May, Summer: Jun-Aug, Fall: Sep-Nov, Winter: Dec-Feb), which are based on weather patterns in the northern hemisphere, to astronomical seasons (Spring: Apr-Jun, Summer: Jul-Sep, Fall: Oct-Dec, Winter: Jan-Mar), which are based on Earth’s tilt.
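A minimal sketch of the two groupings (the astronomical mapping mirrors the map_season function used in the code section at the end of this article; the month boundaries are the ones quoted above):

# Map a month number (1-12) to a season under each grouping
def meteorological_season(month):
    if month in (3, 4, 5):
        return 'Spring'
    elif month in (6, 7, 8):
        return 'Summer'
    elif month in (9, 10, 11):
        return 'Fall'
    else:  # Dec, Jan, Feb
        return 'Winter'

def astronomical_season(month):
    if month in (4, 5, 6):
        return 'Spring'
    elif month in (7, 8, 9):
        return 'Summer'
    elif month in (10, 11, 12):
        return 'Fall'
    else:  # Jan, Feb, Mar
        return 'Winter'

print(meteorological_season(3), astronomical_season(3))  # March: Spring vs. Winter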

Once we accomplished the exploration, a key point in our analysis of spurious correlation was to determine whether the variables of interest correlate. Eyeballing the results, Spring 2020 stood out with a correlation of 0.81. We then checked for statistical significance: with a p-value of roughly 1.07 × 10⁻¹⁵, I’d say we have significance!

    Spring 2020 temperatures correlate with Apple stock
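A minimal sketch of that check, reusing the merged_df, 'High (°F)', and 'Close' naming from the code section later in this article:

import scipy.stats as stats

# Subset to the season and year of interest, dropping rows with missing values
spring_2020 = merged_df.loc[
    (merged_df['Date'].dt.year == 2020) & (merged_df['Season'] == 'Spring')
].dropna(subset=['High (°F)', 'Close'])

# Pearson r and its p-value for the null hypothesis of zero correlation
r, p = stats.pearsonr(spring_2020['High (°F)'], spring_2020['Close'])
print(f"r = {r:.2f}, p = {p:.2e}")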

    Analysis II: Additional Statistics to Test the Nature of Spuriousness

If there is truly spurious correlation, we may want to consider whether the correlation equates to causation — that is, does a change in astronomical-season temperature cause Apple stock to fluctuate? We employed further statistical testing to prove or reject the hypothesis that one variable causes the other.

There are numerous statistical tools that test for causality, such as Instrumental Variable (IV) Analysis, Panel Data Analysis, Structural Equation Modelling (SEM), Vector Autoregression (VAR) models, Cointegration Analysis, and Granger Causality. IV analysis accounts for omitted variables in regression analysis; Panel Data Analysis uses fixed-effects and random-effects models; SEM analyzes structural relationships; VAR models capture dynamic multivariate time-series interactions; and Cointegration Analysis determines whether variables move together along a common stochastic trend. We wanted a tool that could finely distinguish between genuine causality and coincidental association, and to achieve this, our choice was Granger Causality.

    Granger Causality

    A Granger test checks whether past values can predict future ones. In our case, we tested whether past daily high temperatures in New York City could predict future values of Apple stock prices.

H₀: Daily high temperatures in New York City do not Granger cause Apple stock price fluctuation.

To conduct the test, we ran through 100 lags to see if any produced a standout p-value. We encountered p-values near 1.0, which meant we could not reject the null hypothesis, and we concluded that there was no evidence of a causal relationship between the variables of interest.

    Granger Causality Test at lags=100
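A minimal sketch of how such a scan over lags can be run and summarized with statsmodels (again reusing the merged_df naming from the code section; the lag count here is illustrative):

from statsmodels.tsa.stattools import grangercausalitytests

# The two aligned numeric columns, with missing values removed
clean = merged_df[['High (°F)', 'Close']].dropna()

max_lag = 10  # illustrative; the article reports scanning up to 100 lags
results = grangercausalitytests(clean, max_lag)  # prints test output per lag

# Collect the ssr F-test p-value at each lag
p_values = {lag: res[0]['ssr_ftest'][1] for lag, res in results.items()}
print(p_values)  # p-values near 1.0 => fail to reject the null of "no Granger causality"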

Analysis III: Statistics to Validate Not Rejecting the Null Hypothesis

The Granger causality test showed the p-values were insignificant for rejecting the null hypothesis. But is that enough? Let’s validate our analysis.

To help mitigate the risk of misinterpreting spuriousness as a genuine causal effect, we performed a cross-correlation analysis in conjunction with the Granger causality test to confirm its finding. Under this approach, if the correlation is spurious, we expect significant cross-correlation at some lags without a consistent causal direction and without Granger causality being present.

    Cross-Correlation Analysis

This method is accomplished by the following steps:

    • Examine the temporal pattern of correlations between the variables;
    • If variable A Granger-causes variable B, significant cross-correlation will occur between variable A and variable B at positive lags;
    • Significant peaks in the cross-correlation at specific lags indicate the time delay between changes in the causal variable and the response.

    Interpretation:

The CCF and lag values show significant positive correlation at certain lags, which confirms that the correlation is spurious. However, as with the Granger causality test, the cross-correlation analysis cannot support the claim that a causal relationship exists between the two variables.
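A minimal sketch of that check (reusing the merged_df naming from the code section; the ±1.96/√N band is the usual large-sample significance bound and is an illustrative addition, not part of the original analysis):

import numpy as np
from statsmodels.tsa.stattools import ccf

clean = merged_df[['High (°F)', 'Close']].dropna()

# Cross-correlation of temperature (leading) against the stock close, at lags 0, 1, 2, ...
ccf_values = ccf(clean['High (°F)'], clean['Close'])

# Approximate 95% significance bound under a white-noise null
bound = 1.96 / np.sqrt(len(clean))

# Lags whose cross-correlation exceeds the bound
significant_lags = np.where(np.abs(ccf_values) > bound)[0]
print(f"{len(significant_lags)} lags exceed ±{bound:.3f}")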

    Wrapup: Key Learnings

    • Spurious correlations are a form of p-hacking. Correlation does not imply causation.
• Even with ‘bad’ data tactics, statistical testing will root out the lack of significance. While there was statistical evidence of a spurious correlation between the variables, causality testing could not support the claim that causality existed in their relationship.
• A study cannot rest on the sole premise that two variables displaying a linear correlation therefore exhibit causality. Instead, other factors that contribute to each variable must be considered.
    • A non-statistical test of whether daily high temperatures in New York City cause Apple stock to fluctuate can be to just consider: If you owned an Apple stock certificate and you placed it in the freezer, would the value of the certificate be impacted by the cold? Similarly, if you placed the certificate outside on a sunny, hot day, would the sun impact the value of the certificate?

    Ethical Considerations: P-Hacking is Not a Valid Analysis

Image by storyset on Freepik (https://www.freepik.com/free-vector/business-people-saying-no-concept-illustration_38687005.htm)

Spurious correlations are not causality. P-hacking may impact your credibility as a data scientist. Be the adult in the room and refuse to participate in bad statistics.

    This study portrayed analysis that involved ‘bad’ statistics. It demonstrated how a data scientist could source, extract and manipulate data in such a way as to statistically show correlation. In the end, statistical testing withstood the challenge and demonstrated that correlation does not equal causality.

Conducting a spurious-correlation study raises ethical questions about using statistics to derive causation between two unrelated variables. It is an example of p-hacking, which exploits statistics in order to achieve a desired outcome. This study was done as academic research to show the absurdity of misusing statistics.

Another area of ethical consideration is the practice of web scraping. Many website owners warn against pulling data from their sites for nefarious or unintended uses. For this reason, sites like Yahoo Finance make stock data downloadable as csv files, and most weather sites let you request time-series datasets of temperature readings. Again, this study is academic research and demonstrates one’s ability to extract data in a nonconventional way.

When faced with a boss or client who compels you to p-hack and offer something like a spurious correlation as proof of causality, explain the implications of the ask and respectfully refuse the project. Whatever your decision, it will have a lasting impact on your credibility as a data scientist.

Dr. Banks is CEO of I-Meta, maker of the patented Spice Chip Technology that provides Big Data analytics for various industries. Mr. Boothroyd, III is a retired Military Analyst. Both are veterans who served honorably in the United States military, and both enjoy discussing spurious correlations. They are members of the same cohort in the University of Michigan School of Information MADS program…Go Blue!

    References

    Aschwanden, Christie. January 2016. You Can’t Trust What You Read About Nutrition. FiveThirtyEight. Retrieved January 24, 2024 from https://fivethirtyeight.com/features/you-cant-trust-what-you-read-about-nutrition/

    Business Management: From the Magazine. June 2015. Beware Spurious Correlations. Harvard Business Review. Retrieved January 24, 2024 from https://hbr.org/2015/06/beware-spurious-correlations

    Extreme Weather Watch. 2017–2023. Retrieved January 24, 2024 from https://www.extremeweatherwatch.com/cities/new-york/year-2017

Financial Content Services, Inc. Apple Stock Price History | Historical AAPL Company Stock Prices | FinancialContent Business Page. Retrieved January 24, 2024 from https://markets.financialcontent.com/stocks/quote/historical?Symbol=537%3A908440&Year=2019&Month=1&Range=12

Plotlygraphs. July 2016. Spurious Correlations. Medium. Retrieved January 24, 2024 from https://plotlygraphs.medium.com/spurious-correlations-56752fcffb69

    Vigen, Tyler. Spurious Correlations. Retrieved February 1, 2024 from https://www.tylervigen.com/spurious-correlations

    Mr. Vigen’s graphs were reprinted with permission from the author received on January 31, 2024.

    Images were licensed from their respective owners.

    Code Section

##########################
# IMPORT LIBRARIES SECTION
##########################
# Web scraping
import requests
from bs4 import BeautifulSoup

# Data handling
import pandas as pd
import numpy as np
import re
from datetime import datetime, timedelta
import calendar  # day-to-datetime conversion helpers

# Visualization
import matplotlib.pyplot as plt
import seaborn as sns  # New York temperature plotting
import plotly.graph_objects as go  # Apple stock candlestick charting
from plotly.subplots import make_subplots
from pandas.plotting import scatter_matrix  # scatterplot matrix

# Statistics
import scipy.stats as stats
from statsmodels.tsa.stattools import ccf  # cross-correlation analysis
from statsmodels.tsa.stattools import grangercausalitytests  # Granger causality
    ##################################################################################
    # EXAMINE THE NEW YORK CITY WEATHER AND APPLE STOCK DATA IN READYING FOR MERGE ...
    ##################################################################################

    # Extract New York City weather data for the years 2017 to 2022 for all 12 months
    # 5-YEAR NEW YORK CITY TEMPERATURE DATA

    # Function to convert 'Day' column to a consistent date format for merging
def convert_nyc_date(day, month_name, year):
    month_num = datetime.strptime(month_name, '%B').month

    # Extract the numeric day using a regular expression
    day_match = re.search(r'\d+', day)
    day_value = int(day_match.group()) if day_match else 1

    date_str = f"{month_num:02d}-{day_value:02d}-{year}"

    try:
        return pd.to_datetime(date_str, format='%m-%d-%Y')
    except ValueError:
        return pd.to_datetime(date_str, errors='coerce')

# Set variables
years = range(2017, 2023)
all_data = []  # Initialize an empty list to store data for all years

# Enter for loop
for year in years:
    url = f'https://www.extremeweatherwatch.com/cities/new-york/year-{year}'
    response = requests.get(url)
    soup = BeautifulSoup(response.text, 'html.parser')

    div_container = soup.find('div', {'class': 'page city-year-page'})

    if div_container:
        select_month = div_container.find('select', {'class': 'form-control url-selector'})

        if select_month:
            monthly_data = []
            for option in select_month.find_all('option'):
                month_name = option.text.strip().lower()

                h5_tag = soup.find('a', {'name': option['value'][1:]}).find_next('h5', {'class': 'mt-4'})

                if h5_tag:
                    responsive_div = h5_tag.find_next('div', {'class': 'responsive'})
                    table = responsive_div.find('table', {'class': 'bordered-table daily-table'})

                    if table:
                        data = []
                        for row in table.find_all('tr')[1:]:
                            cols = row.find_all('td')
                            day = cols[0].text.strip()
                            high_temp = float(cols[1].text.strip())
                            data.append([convert_nyc_date(day, month_name, year), high_temp])

                        monthly_df = pd.DataFrame(data, columns=['Date', 'High (°F)'])
                        monthly_data.append(monthly_df)
                    else:
                        print(f"Table not found for {month_name.capitalize()} {year}")
                else:
                    print(f"h5 tag not found for {month_name.capitalize()} {year}")

            # Concatenate monthly data to form the complete dataframe for the year
            yearly_nyc_df = pd.concat(monthly_data, ignore_index=True)

            # Extract month name from the 'Date' column
            yearly_nyc_df['Month'] = yearly_nyc_df['Date'].dt.strftime('%B')

            # Capitalize the month names
            yearly_nyc_df['Month'] = yearly_nyc_df['Month'].str.capitalize()

            all_data.append(yearly_nyc_df)


    ######################################################################################################
    # Generate a time series plot of the 5-year New York City daily high temperatures
    ######################################################################################################

# Concatenate the data for all years
if all_data:
    combined_df = pd.concat(all_data, ignore_index=True)

    # Create a line plot for each year
    plt.figure(figsize=(12, 6))
    sns.lineplot(data=combined_df, x='Date', y='High (°F)', hue=combined_df['Date'].dt.year)
    plt.title('New York City Daily High Temperature Time Series (2017-2022) - 5-Year Trend', fontsize=18)
    plt.xlabel('Date', fontsize=16)  # Set x-axis label
    plt.ylabel('High Temperature (°F)', fontsize=16)  # Set y-axis label
    plt.legend(title='Year', bbox_to_anchor=(1.05, 1), loc='upper left', fontsize=14)  # Legend outside the plot
    plt.tick_params(axis='both', which='major', labelsize=14)  # Tick font size for both axes
    plt.show()
    # APPLE STOCK CODE

# Set variables
years = range(2017, 2023)
data = []  # Initialize an empty list to store data for all years

# Extract Apple's historical data for the years 2017 to 2022
for year in years:
    url = f'https://markets.financialcontent.com/stocks/quote/historical?Symbol=537%3A908440&Year={year}&Month=12&Range=12'
    response = requests.get(url)
    soup = BeautifulSoup(response.text, 'html.parser')
    table = soup.find('table', {'class': 'quote_detailed_price_table'})

    if table:
        for row in table.find_all('tr')[1:]:
            cols = row.find_all('td')
            date = cols[0].text

            # Check if the year is within the desired range
            if str(year) in date:
                open_price = cols[1].text
                high = cols[2].text
                low = cols[3].text
                close = cols[4].text
                volume = cols[5].text
                change_percent = cols[6].text
                data.append([date, open_price, high, low, close, volume, change_percent])

    # Create a DataFrame from the extracted data
    apple_df = pd.DataFrame(data, columns=['Date', 'Open', 'High', 'Low', 'Close', 'Volume', 'Change(%)'])

    # Verify that DataFrame contains 5-years
    # apple_df.head(50)

    #################################################################
    # Generate a Candlestick charting of the 5-year stock performance
    #################################################################

    new_apple_df = apple_df.copy()

    # Convert Apple 'Date' column to a consistent date format
    new_apple_df['Date'] = pd.to_datetime(new_apple_df['Date'], format='%b %d, %Y')

    # Sort the datasets by 'Date' in ascending order
    new_apple_df = new_apple_df.sort_values('Date')

    # Convert numerical columns to float, handling empty strings
    numeric_cols = ['Open', 'High', 'Low', 'Close', 'Volume', 'Change(%)']
for col in numeric_cols:
    new_apple_df[col] = pd.to_numeric(new_apple_df[col], errors='coerce')

# Create a candlestick chart
fig = go.Figure(data=[go.Candlestick(x=new_apple_df['Date'],
                                     open=new_apple_df['Open'],
                                     high=new_apple_df['High'],
                                     low=new_apple_df['Low'],
                                     close=new_apple_df['Close'])])

# Set the layout
fig.update_layout(
    title='Apple Stock Candlestick Chart',
    xaxis_title='Date',
    yaxis_title='Stock Price',
    xaxis_rangeslider_visible=False,
    font=dict(family="Arial", size=16, color="Black"),
    title_font=dict(family="Arial", size=20, color="Black"),
    xaxis=dict(
        title=dict(text="Date", font=dict(family="Arial", size=18, color="Black")),
        tickfont=dict(family="Arial", size=16, color="Black")
    ),
    yaxis=dict(
        title=dict(text="Stock Price", font=dict(family="Arial", size=18, color="Black")),
        tickfont=dict(family="Arial", size=16, color="Black")
    )
)

    # Show the chart
    fig.show()
    ##########################################
    # MERGE THE NEW_NYC_DF WITH NEW_APPLE_DF
    ##########################################
    # Convert the 'Day' column in New York City combined_df to a consistent date format ...

new_nyc_df = combined_df.copy()

# Fill in any business days missing from the NYC temperature data
# (freq='B' generates business days, i.e., weekdays only)
start_date = new_nyc_df['Date'].min()
end_date = new_nyc_df['Date'].max()
business_dates = pd.date_range(start_date, end_date, freq='B')
missing_dates = business_dates[~business_dates.isin(new_nyc_df['Date'])]
missing_data = pd.DataFrame({'Date': missing_dates, 'High (°F)': None})
new_nyc_df = pd.concat([new_nyc_df, missing_data]).sort_values('Date').reset_index(drop=True)  # Resetting index

new_apple_df = apple_df.copy()

# Convert Apple 'Date' column to a consistent date format
new_apple_df['Date'] = pd.to_datetime(new_apple_df['Date'], format='%b %d, %Y')

# Sort the datasets by 'Date' in ascending order
new_nyc_df = new_nyc_df.sort_values('Date')
new_apple_df = new_apple_df.sort_values('Date')

# Merge the datasets on the 'Date' column
merged_df = pd.merge(new_apple_df, new_nyc_df, on='Date', how='inner')

    # Verify the correct merge -- should merge only NYC temp records that match with Apple stock records by Date
    merged_df
    # Ensure the columns of interest are numeric 
    merged_df['High (°F)'] = pd.to_numeric(merged_df['High (°F)'], errors='coerce')
    merged_df['Close'] = pd.to_numeric(merged_df['Close'], errors='coerce')

    # UPDATED CODE BY PAUL USES ASTRONOMICAL TEMPERATURES

    # CORRELATION HEATMAP OF YEAR-OVER-YEAR
    # DAILY HIGH NYC TEMPERATURES VS.
    # APPLE STOCK 2017-2023

    import pandas as pd
    import numpy as np
    import matplotlib.pyplot as plt
    import seaborn as sns

    # Convert 'Date' to datetime
    merged_df['Date'] = pd.to_datetime(merged_df['Date'])

    # Define a function to map months to seasons
def map_season(month):
    if month in [4, 5, 6]:
        return 'Spring'
    elif month in [7, 8, 9]:
        return 'Summer'
    elif month in [10, 11, 12]:
        return 'Fall'
    else:
        return 'Winter'

    # Extract month from the Date column and map it to seasons
    merged_df['Season'] = merged_df['Date'].dt.month.map(map_season)

    # Extract the years present in the data
    years = merged_df['Date'].dt.year.unique()

    # Create subplots for each combination of year and season
    seasons = ['Spring', 'Summer', 'Fall', 'Winter']

    # Convert 'Close' column to numeric
    merged_df['Close'] = pd.to_numeric(merged_df['Close'], errors='coerce')

    # Create an empty DataFrame to store correlation matrix
    corr_matrix = pd.DataFrame(index=years, columns=seasons)

    # Calculate correlation matrix for each combination of year and season
for year in years:
    year_data = merged_df[merged_df['Date'].dt.year == year]
    for season in seasons:
        data = year_data[year_data['Season'] == season]
        corr = data['High (°F)'].corr(data['Close'])
        corr_matrix.loc[year, season] = corr

    # Plot correlation matrix
    plt.figure(figsize=(10, 6))
    sns.heatmap(corr_matrix.astype(float), annot=True, cmap='coolwarm', fmt=".2f")
    plt.title('Temperature-Stock Correlation', fontsize=18) # Set main title font size
    plt.xlabel('Season', fontsize=16) # Set x-axis label font size
    plt.ylabel('Year', fontsize=16) # Set y-axis label font size
    plt.tick_params(axis='both', which='major', labelsize=14) # Set annotation font size
    plt.tight_layout()
    plt.show()
    #######################
    # STAT ANALYSIS SECTION
    #######################
#############################################################
# GRANGER CAUSALITY TEST
# test whether past values of temperature (or stock prices)
# can predict future values of stock prices (or temperature).
# perform the Granger causality test between 'High (°F)' and
# 'Close' columns in merged_df up to a maximum lag of 100
#############################################################

# Perform Granger causality test
max_lag = 100  # Maximum lag of 100 - Jupyter may time out at higher lags
test_results = grangercausalitytests(merged_df[['High (°F)', 'Close']], max_lag)

    # Interpretation:

# None of the lags gives a significant p-value.
# At alpha = .05 we cannot reject the null hypothesis, that is,
# we cannot conclude that Granger causality exists between daily high
# temperatures in NYC and Apple stock.

    #################################################################
    # CROSS-CORRELATION ANALYSIS
    # calculate the cross-correlation between 'High (°F)' and 'Close'
    # columns in merged_df, and ccf_values will contain the
    # cross-correlation coefficients, while lag_values will
    # contain the corresponding lag values
    #################################################################

# Calculate cross-correlation (drop rows with missing values first)
ccf_df = merged_df[['High (°F)', 'Close']].dropna()
ccf_values = ccf(ccf_df['High (°F)'], ccf_df['Close'])
lag_values = np.arange(len(ccf_df))  # ccf returns correlations at non-negative lags 0..n-1

ccf_values, lag_values

    # Interpretation:
    # Looks like there is strong positive correlation in the variables
    # in latter years and positive correlation in their respective
    # lags. This confirms what our plotting shows us

    ########################################################
    # LOOK AT THE BEST CORRELATION COEFFICIENT - 2020? LET'S
    # EXPLORE FURTHER AND CALCULATE THE p-VALUE AND
    # CONFIDENCE INTERVAL
    ########################################################

    # Get dataframes for specific periods of spurious correlation

    merged_df['year'] = merged_df['Date'].dt.year
best_season_data = merged_df.loc[(merged_df['year'] == 2020) & (merged_df['Season'] == 'Spring')].dropna(subset=['High (°F)', 'Close'])  # drop missing values before correlating

    # Calculate correlation coefficient and p-value
    corr_coeff, p_value = stats.pearsonr(best_season_data['High (°F)'], best_season_data['Close'])
    corr_coeff, p_value

    # Perform bootstrapping to obtain confidence interval
def bootstrap_corr(data, n_bootstrap=1000):
    corr_values = []
    for _ in range(n_bootstrap):
        sample = data.sample(n=len(data), replace=True)
        corr_coeff, _ = stats.pearsonr(sample['High (°F)'], sample['Close'])
        corr_values.append(corr_coeff)
    return np.percentile(corr_values, [2.5, 97.5])  # 95% confidence interval

    confidence_interval = bootstrap_corr(best_season_data)
    confidence_interval
    #####################################################################
    # VISUALIZE RELATIONSHIP BETWEEN APPLE STOCK AND NYC DAILY HIGH TEMPS
    #####################################################################

    # Dual y-axis plotting using twinx() function from matplotlib
    date = merged_df['Date']
    temperature = merged_df['High (°F)']
    stock_close = merged_df['Close']

    # Create a figure and axis
    fig, ax1 = plt.subplots(figsize=(10, 6))

    # Plotting temperature on the left y-axis (ax1)
    color = 'tab:red'
    ax1.set_xlabel('Date', fontsize=16)
    ax1.set_ylabel('Temperature (°F)', color=color, fontsize=16)
    ax1.plot(date, temperature, color=color)
    ax1.tick_params(axis='y', labelcolor=color)

    # Create a secondary y-axis for the stock close prices
    ax2 = ax1.twinx()
    color = 'tab:blue'
    ax2.set_ylabel('Stock Close Price', color=color, fontsize=16)
    ax2.plot(date, stock_close, color=color)
    ax2.tick_params(axis='y', labelcolor=color)

    # Title and show the plot
    plt.title('Apple Stock correlates with New York City Temperature', fontsize=18)
    plt.show()



  • Editing Text in Images with AI

    Editing Text in Images with AI

    Julia Turc

    Research Review for Scene Text Editing: STEFANN, SRNet, TextDiffuser, AnyText and more.

    If you ever tried to change the text in an image, you know it’s not trivial. Preserving the background, textures, and shadows takes a Photoshop license and hard-earned designer skills. In the video below, a Photoshop expert takes 13 minutes to fix a few misspelled characters in a poster that is not even stylistically complex. The good news is — in our relentless pursuit of AGI, humanity is also building AI models that are actually useful in real life. Like the ones that allow us to edit text in images with minimal effort.

    The task of automatically updating the text in an image is formally known as Scene Text Editing (STE). This article describes how STE model architectures have evolved over time and the capabilities they have unlocked. We will also talk about their limitations and the work that remains to be done. Prior familiarity with GANs and Diffusion models will be helpful, but not strictly necessary.

    Disclaimer: I am the cofounder of Storia AI, building an AI copilot for visual editing. This literature review was done as part of developing Textify, a feature that allows users to seamlessly change text in images. While Textify is closed-source, we open-sourced a related library, Detextify, which automatically removes text from a corpus of images.

    Example of Scene Text Editing (STE). The original image (left) was generated via Midjourney. We used Textify to annotate the image (center) and automatically fix the misspelling (right).

    The Task of Scene Text Editing (STE)

    Definition

    Scene Text Editing (STE) is the task of automatically modifying text in images that capture a visual scene (as opposed to images that mainly contain text, such as scanned documents). The goal is to change the text while preserving the original aesthetics (typography, calligraphy, background etc.) without the inevitably expensive human labor.

    Use Cases

Scene Text Editing might seem like a contrived task, but it actually has multiple practical use cases:

    (1) Synthetic data generation for Scene Text Recognition (STR)

    Synthetic image (right) obtained by editing text in the original image (left, from Unsplash). This technique can be used to augment the training set of STR (Scene Text Recognition) models.

    When I started researching this task, I was surprised to discover that Alibaba (an e-commerce platform) and Baidu (a search engine) are consistently publishing research on STE.

    At least in Alibaba’s case, it is likely their research is in support of AMAP, their alternative to Google Maps [source]. In order to map the world, you need a robust text recognition system that can read traffic and street signs in a variety of fonts, under various real-world conditions like occlusions or geometric distortions, potentially in multiple languages.

    In order to build a training set for Scene Text Recognition, one could collect real-world data and have it annotated by humans. But this approach is bottlenecked by human labor, and might not guarantee enough data variety. Instead, synthetic data generation provides a virtually unlimited source of diverse data, with automatic labels.

    (2) Control over AI-generated images

    AI-generated image via Midjourney (left) and corrected via Scene Text Editing.

    AI image generators like Midjourney, Stability and Leonardo have democratized visual asset creation. Small business owners and social media marketers can now create images without the help of an artist or a designer by simply typing a text prompt. However, the text-to-image paradigm lacks the controllability needed for practical assets that go beyond concept art — event posters, advertisements, or social media posts.

Such assets often need to include textual information (a date and time, contact details, or the name of the company). Spelling correctly has historically been difficult for text-to-image models, though there has been recent progress (DeepFloyd IF, Midjourney v6). But even when these models do eventually learn to spell perfectly, the UX constraints of the text-to-image interface remain. It is tedious to describe in words where and how to place a piece of text.

    (3) Automatic localization of visual media

    Movies and games are often localized for various geographies. Sometimes this might entail switching a broccoli for a green pepper, but most times it requires translating the text that is visible on screen. With other aspects of the film and gaming industries getting automated (like dubbing and lip sync), there is no reason for visual text editing to remain manual.

    Timeline of Architectures: from GANs to Diffusion

    The training techniques and model architectures used for Scene Text Editing largely follow the trends of the larger task of image generation.

    The GAN Era (2019–2021)

    GANs (Generative Adversarial Networks) dominated the mid-2010s for image generation tasks. GAN refers to a particular training framework (rather than prescribing a model architecture) that is adversarial in nature. A generator model is trained to capture the data distribution (and thus has the capability to generate new data), while a discriminator is trained to distinguish the output of the generator from real data. The training process is finalized when the discriminator’s guess is as good as a random coin toss. During inference, the discriminator is discarded.
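As a rough illustration of this adversarial setup, here is a minimal PyTorch sketch on toy 1-D data; it is not the architecture of any STE model discussed in this article, and the network sizes and data are purely illustrative:

import torch
import torch.nn as nn

# Toy "real" data: samples from a 1-D Gaussian the generator must learn to imitate
def real_batch(n=64):
    return torch.randn(n, 1) * 2.0 + 3.0

generator = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 1))
discriminator = nn.Sequential(nn.Linear(1, 16), nn.ReLU(), nn.Linear(16, 1), nn.Sigmoid())

g_opt = torch.optim.Adam(generator.parameters(), lr=1e-3)
d_opt = torch.optim.Adam(discriminator.parameters(), lr=1e-3)
bce = nn.BCELoss()

for step in range(1000):
    # Discriminator step: label real data 1, generated data 0
    real = real_batch()
    fake = generator(torch.randn(64, 8)).detach()
    d_loss = bce(discriminator(real), torch.ones(64, 1)) + \
             bce(discriminator(fake), torch.zeros(64, 1))
    d_opt.zero_grad()
    d_loss.backward()
    d_opt.step()

    # Generator step: try to make the discriminator output 1 on generated data
    fake = generator(torch.randn(64, 8))
    g_loss = bce(discriminator(fake), torch.ones(64, 1))
    g_opt.zero_grad()
    g_loss.backward()
    g_opt.step()

# At convergence, the discriminator's output on both real and fake data
# approaches 0.5 -- the "coin toss" mentioned above.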

    GANs are particularly suited for image generation because they can perform unsupervised learning — that is, learn the data distribution without requiring labeled data. Following the general trend of image generation, the initial Scene Text Editing models also leveraged GANs.

    GAN Epoch #1: Character-Level Editing — STEFANN

    STEFANN, recognized as the first work to modify text in scene images, operates at a character level. The character editing problem is broken into two: font adaptation and color adaptation.

    The STEFANN model architecture (source). The character editing task is broken into two: FANnet (Font Adaptation Network) generates a black-and-white target character in the desired shape, and Colornet fills in the appropriate color.

STEFANN builds on prior work in the space of font synthesis (the task of creating new fonts or text styles that closely resemble the ones observed in input data), and adds the constraint that the output needs to blend seamlessly back into the original image. Compared to previous work, STEFANN takes a pure machine learning approach (as opposed to, e.g., explicit geometrical modeling) and does not depend on character recognition to label the source character.

    The STEFANN model architecture is based on CNNs (Convolutional Neural Networks) and decomposes the problem into (1) font adaptation via FANnet — turning a binarized version of the source character into a binarized target character, (2) color adaptation via Colornet — colorizing the output of FANnet to match the rest of the text in the image, and (3) character placement — blending the target character back into the original image using previously-established techniques like inpainting and seam carving. The first two modules are trained with a GAN objective.

    While STEFANN paved the way for Scene Text Editing, it has multiple limitations that restrict its use in practice. It can only operate on one character at a time; changing an entire word requires multiple calls (one per letter) and constrains the target word to have the same length as the source word. Also, the character placement algorithm in step (3) assumes that the characters are non-overlapping.

    GAN Epoch #2: Word-Level Editing — SRNet and 3-Module Networks

    SRNet was the first model to perform scene text editing at the word level. SRNet decomposed the STE task into three (jointly-trained) modules: text conversion, background inpainting and fusion.

    The SRNet model architecture. The three modules decompose the STE problem into smaller building blocks (text conversion, background inpainting and fusion), while being jointly trained. This architecture was largely adopted by follow-up work in the field.

The three jointly-trained modules are:

1. The text conversion module (in blue) takes a programmatic rendering of the target text (“barbarous” in the figure above) and aims to render it in the same typeface as the input word (“introduce”) on a plain background.
    2. The background inpainting module (in green) erases the text from the input image and fills in the gaps to reconstruct the original background.
    3. The fusion module (in orange) pastes the rendered target text onto the background.

    SRNet architecture. All three modules are flavors of Fully Convolutional Networks (FCNs), with the background inpainting module in particular resembling U-Net (an FCN with the specific property that encoder layers are skip-connected to decoder layers of the same size).

    SRNet training. Each module has its own loss, and the network is jointly trained on the sum of losses (LT + LB + LF), where the latter two are trained via GAN. While this modularization is conceptually elegant, it comes with the drawback of requiring paired training data, with supervision for each intermediate step. Realistically, this can only be achieved with artificial data. For each data point, one chooses a random image (from a dataset like COCO), selects two arbitrary words from a dictionary, and renders them with an arbitrary typeface to simulate the “before” and “after” images. As a consequence, the training set doesn’t include any photorealistic examples (though it can somewhat generalize beyond rendered fonts).

    Honorable mentions. SwapText followed the same GAN-based 3-module network approach to Scene Text Editing and proposed improvements to the text conversion module.

    GAN Epoch #3: Self-supervised and Hybrid Networks

    Leap to unsupervised learning. The next leap in STE research was to adopt a self-supervised training approach, where models are trained on unpaired data (i.e., a mere repository of images containing text). To achieve this, one had to remove the label-dependent intermediate losses LT and LB. And due to the design of GANs, the remaining final loss does not require a label either; the model is simply trained on the discriminator’s ability to distinguish between real images and the ones produced by the generator. TextStyleBrush pioneered self-supervised training for STE, while RewriteNet and MOSTEL made the best of both worlds by training in two stages: one supervised (advantage: abundance of synthetic labeled data) and one self-supervised (advantage: realism of natural unlabeled data).

Disentangling text content & style. To remove the intermediate losses, TextStyleBrush and RewriteNet reframe the problem into disentangling text content from text style. To reiterate, the inputs to an STE system are (a) an image with original text, and (b) the desired text — more specifically, a programmatic rendering of the desired text on a white or gray background, with a fixed font like Arial. The goal is to combine the style from (a) with the content from (b). In other words, we complementarily aim to discard the content from (a) and the style of (b). This is why it’s necessary to disentangle the text content from the style in a given image.

    Inference architecture of RewriteNet. The encoder E disentangles text style (circle) from text content (triangle). The style embedding from the original image and content embedding from the text rendering are fed into a generator, which fuses the two into an output image.

    TextStyleBrush and why GANs went out of fashion. While the idea of disentangling text content from style is straightforward, achieving it in practice required complicated architectures. TextStyleBrush, the most prominent paper in this category, used no less than seven jointly-trained subnetworks, a pre-trained typeface classifier, a pre-trained OCR model and multiple losses. Designing such a system must have been expensive, since all of these components require ablation studies to determine their effect. This, coupled with the fact that GANs are notoriously difficult to train (in theory, the generator and discriminator need to reach Nash equilibrium), made STE researchers eager to switch to diffusion models once they proved so apt for image generation.

    The Diffusion Era (2022 — present)

    At the beginning of 2022, the image generation world shifted away from GANs towards Latent Diffusion Models (LDM). A comprehensive explanation of LDMs is out of scope here, but you can refer to The Illustrated Stable Diffusion for an excellent tutorial. Here I will focus on the parts of the LDM architecture that are most relevant to the Scene Text Editing task.

    Diffusion-based Scene Text Editing. In addition to the text embedding passed to the actual diffusion module in a standard text-to-image-model, STE architectures also create embeddings that reflect desired properties of the target text (position, shape, style etc.). Illustration by the author.

    As illustrated above, an LDM-based text-to-image model has three main components: (1) a text encoder — typically CLIP, (2) the actual diffusion module — which converts the text embedding into an image embedding in latent space, and (3) an image decoder — which upscales the latent image into a fully-sized image.

    Scene Text Editing as a Diffusion Inpainting Task

    Text-to-image is not the only paradigm supported by diffusion models. After all, CLIP is equally a text and image encoder, so the embedding passed to the image information creator module can also encode an image. In fact, it can encode any modality, or a concatenation of multiple inputs.

    This is the principle behind inpainting, the task of modifying only a subregion of an input image based on given instructions, in a way that looks coherent with the rest of the image. The image information creator ingests an encoding that captures the input image, the mask of the region to be inpainted, and a textual instruction.
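For intuition, here is a minimal sketch of the generic diffusion-inpainting interface using the Hugging Face diffusers library; this is not any of the STE models discussed below, and the checkpoint name, file names, and prompt are only illustrative (a CUDA GPU is assumed):

import torch
from diffusers import StableDiffusionInpaintPipeline
from PIL import Image

# Load a generic inpainting pipeline (illustrative checkpoint)
pipe = StableDiffusionInpaintPipeline.from_pretrained(
    "runwayml/stable-diffusion-inpainting", torch_dtype=torch.float16
).to("cuda")

# init_image: the original scene; mask_image: white where the edit should happen
init_image = Image.open("poster.png").convert("RGB").resize((512, 512))
mask_image = Image.open("text_region_mask.png").convert("RGB").resize((512, 512))

# The textual instruction conditions the diffusion process alongside the image and mask;
# STE systems add further conditioning (position, glyphs, style) on top of this idea.
result = pipe(
    prompt='a poster with the word "SALE" in bold red letters',
    image=init_image,
    mask_image=mask_image,
).images[0]
result.save("edited.png")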

    Scene Text Editing can be regarded as a specialized form of inpainting. Most of the STE research reduces to the following question: How can we augment the text embedding with additional information about the task (i.e., the original image, the desired text and its positioning, etc.)? Formally, this is known as conditional guidance.

    The research papers that fall into this bucket (TextDiffuser, TextDiffuser 2, GlyphDraw, AnyText, etc.) propose various forms of conditional guidance.

    Positional guidance

    Evidently, there needs to be a way of specifying where to make changes to the original image. This can be a text instruction (e.g. “Change the title at the bottom”), a granular indication of the text line, or more fine-grained positional information for each target character.

    Positional guidance via image masks. One way of indicating the desired text position is via grayscale mask images, which can then be encoded into latent space via CLIP or an alternative image encoder. For instance, the DiffUTE model simply uses a black image with a white strip indicating the desired text location.

    Input to the DiffUTE model. Positional guidance is achieved via the mask m and the masked input xm. These are deterministically rendered based on user input.
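A minimal sketch of how such a positional mask can be constructed (a hypothetical helper using NumPy and Pillow; DiffUTE's actual preprocessing differs):

import numpy as np
from PIL import Image

def make_text_position_mask(width, height, box):
    """Black image with a white strip marking where the new text should go.

    box is (x0, y0, x1, y1) in pixel coordinates -- an illustrative convention.
    """
    mask = np.zeros((height, width), dtype=np.uint8)
    x0, y0, x1, y1 = box
    mask[y0:y1, x0:x1] = 255  # white rectangle over the target text region
    return Image.fromarray(mask)

# e.g., a strip near the bottom of a 512x512 image
mask = make_text_position_mask(512, 512, (64, 420, 448, 470))
mask.save("text_position_mask.png")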

TextDiffuser produces character-level segmentation masks: first, it roughly renders the desired text in the right position (black text in Arial font on a white image), then passes this rendering through a segmenter to obtain a grayscale image with individual bounding boxes for each character. The segmenter is a U-Net model trained separately from the main network on 4M synthetic instances.

    Character-level segmentation mask used by TextDiffuser. The target word (“WORK”) is rendered with a standard font on a white background, then passed through a segmenter (U-Net) to obtain the grayscale mask.

Positional guidance via language modeling. In A Unified Sequence Inference for Vision Tasks, the authors show that large language models (LLMs) can act as effective descriptors of object positions within an image by simply generating numerical tokens. Arguably, this was an unintuitive discovery. Since LLMs learn language based on statistical frequency (i.e., by observing how often tokens occur in the same context), it feels unrealistic to expect them to generate the right numerical tokens. But the massive scale of current LLMs often defies such expectations.

TextDiffuser 2 leverages this discovery in an interesting way. They fine-tune an LLM on a synthetic corpus of <text, OCR detection> pairs, teaching it to generate the top-left and bottom-right coordinates of text bounding boxes, as shown in the figure below. Notably, they decide to generate bounding boxes for text lines (as opposed to characters), giving the image generator more flexibility. They also run an interesting ablation study that uses a single point to encode text position (either the top-left or the center of the box), but observe poorer spelling performance: the model often hallucinates additional characters when not explicitly told where the text should end.

    Architecture of TextDiffuser 2. The language model M1 takes the target text from the user, then splits it into lines and predicts their positions as [x1] [y1] [x2] [y2] tokens. The language model M2 is a fine-tuned version of CLIP that encodes the modified prompt (which includes text lines and their positions) into latent space.

    Glyph guidance

In addition to position, another piece of information that can be fed into the image generator is the shape of the characters. One could argue that shape information is redundant. After all, when we prompt a text-to-image model to generate a flamingo, we generally don’t need to pass any additional information about its long legs or the color of its feathers; the model has presumably learnt these details from the training data. However, in practice, the training sets (such as Stable Diffusion’s LAION-5B) are dominated by natural pictures, in which text is underrepresented (and non-Latin scripts even more so).

    Multiple studies (DiffUTE, GlyphControl, GlyphDraw, GlyphDiffusion, AnyText etc.) attempt to make up for this imbalance via explicit glyph guidance — effectively rendering the glyphs programmatically with a standard font, and then passing an encoding of the rendering to the image generator. Some simply place the glyphs in the center of the additional image, some close to the target positions (reminiscent of ControlNet).

    STE via Diffusion is (Still) Complicated

    While the training process for diffusion models is more stable than GANs, the diffusion architectures for STE in particular are still quite complicated. The figure below shows the AnyText architecture, which includes (1) an auxiliary latent module (including the positional and glyph guidance discussed above), (2) a text embedding module that, among other components, requires a pre-trained OCR module, and (3) the standard diffusion pipeline for image generation. It is hard to argue this is conceptually much simpler than the GAN-based TextStyleBrush.

    The (complex) architecture of AnyText.

    The Future of Scene Text Editing

    When the status quo is too complicated, we have a natural tendency to keep working on it until it converges to a clean solution. In a way, this is what happened to the natural language processing field: computational linguistics theories, grammars, dependency parsing — all collapsed under Transformers, which make a very simple statement: the meaning of a token depends on all others around it. Evidently, Scene Text Editing is miles away from this clarity. Architectures contain many jointly-trained subnetworks, pre-trained components, and require specific training data.

Text-to-image models will inevitably become better at certain aspects of text generation (spelling, typeface diversity, and how crisp the characters look), with the right amount and quality of training data. But controllability will remain a problem for a much longer time. And even when models do eventually learn to follow your instructions to a T, the text-to-image paradigm might still be a subpar user experience: would you rather describe the position, look and feel of a piece of text in excruciating detail, or would you rather just draw an approximate box and choose an inspiration color from a color picker?

    Epilogue: Preventing Abuse

    Generative AI has brought to light many ethical questions, from authorship / copyright / licensing to authenticity and misinformation. While all these loom large in our common psyche and manifest in various abstract ways, the misuses of Scene Text Editing are down-to-earth and obvious — people faking documents.

While building Textify, we’ve seen it all. Some people bump up their follower count in Instagram screenshots. Others increase their running speed in Strava screenshots. And yes, some attempt to fake IDs, credit cards and diplomas. The temporary solution is to build classifiers for certain types of documents and simply refuse to edit them, but, in the long term, the generative AI community needs to invest in automated ways of determining document authenticity, be it a text snippet, an image or a video.

