Fast GeLU approximation

I needed a faster version of GeLU for my application. The following approximation with p=0.544790 works quite well:

0.5*x*(1.0+x/sqrt(p+x*x))

We can rewrite it in terms of y = 0.5 * x as:

y+y*y/sqrt(0.25*p+y*y)

which can be implemented efficiently with AVX2 using _mm256_rsqrt_ps (optionally refined with one Newton–Raphson step for better accuracy).

Regards,
GW

1 Like

Hi @gwiesenekker, you’re right. The rewritten approximation you mentioned, GeLU(x) ≈ y + y*y/sqrt(0.25*p + y*y) with y = 0.5*x, is well suited to a fast AVX2 implementation using _mm256_rsqrt_ps. Thanks!

Hi,

The constant p was determined by minimizing the MSE over the interval [-5, 5]. Below is an implementation in C using AVX2 and FMA intrinsics, together with a unit test (to enable the AVX2 path, compile with e.g. -mavx2 -mfma):

#include <stdio.h>
#include <stdlib.h>
#include <math.h>
#include <smmintrin.h>
#include <immintrin.h>

/* Enable the AVX2/FMA kernel in activation_geluv2 only when the compiler is
 * actually generating AVX2 and FMA code (GCC/Clang: -mavx2 -mfma or
 * -march=native). Without those flags the _mm256_* intrinsics fail to
 * compile/inline, so the portable scalar code is used instead. */
#if defined(__AVX2__) && defined(__FMA__)
#define USE_AVX2_INTRINSICS
#endif

static void activation_geluv2(int n,
                              float *restrict a,
                              float *restrict b)
{
#ifdef USE_AVX2_INTRINSICS
  if ((n < 8) || ((n % 8) != 0))
#endif
  {
    const float p = 0.544790f;

    for (int i = 0; i < n; i++)
      b[i] = 0.5f * a[i] * (1.0f + a[i] / sqrtf(p + a[i] * a[i]));
  }
#ifdef USE_AVX2_INTRINSICS
  else
  {
    const __m256 qp = _mm256_set1_ps(0.25f * 0.544790f);
    const __m256 half = _mm256_set1_ps(0.5f);
    const __m256 three_halfs = _mm256_set1_ps(1.5f);

    for (const float *z = a + n; a < z; a += 8, b += 8)
    {
      __m256 va = _mm256_loadu_ps(a);

      __m256 vy = _mm256_mul_ps(half, va);

      __m256 vy2 = _mm256_mul_ps(vy, vy);

      __m256 vy2qp = _mm256_fmadd_ps(vy, vy, qp);

      __m256 rsqrt = _mm256_rsqrt_ps(vy2qp);

      // needed for Newton-Raphson

      __m256 rsqrt2 = _mm256_mul_ps(rsqrt, rsqrt);

      rsqrt = _mm256_mul_ps(
        rsqrt,
        _mm256_fnmadd_ps(_mm256_mul_ps(half, vy2qp), rsqrt2, three_halfs));

      __m256 vg = _mm256_fmadd_ps(vy2, rsqrt, vy);

      _mm256_storeu_ps(b, vg);
    }
  }
#endif
}

If a and b are 32-byte aligned, you can use _mm256_load_ps and _mm256_store_ps instead of the unaligned _mm256_loadu_ps and _mm256_storeu_ps.

static double gelu(double x)
{
  return (0.5 * x * (1.0 + erf(x / sqrt(2.0))));
}

#define SCALE 1000   /* grid points per unit of x                      */
#define XMAX  5      /* the test sweeps x over [-XMAX, XMAX]           */
#define N     8      /* lanes per activation_geluv2 call (one __m256)  */

/*
 * Unit test: sweep x over [-XMAX, XMAX] in steps of 1/SCALE, compare
 * activation_geluv2 against the erf-based reference gelu(), and print the
 * N - 1 largest absolute errors (largest first) followed by the maximum.
 *
 * Inputs are evaluated N at a time so that every lane of the 8-wide kernel
 * is exercised. The final partial batch is zero-padded, and the padding
 * lanes are ignored.
 */
int main(void)
{
  enum { TOP = N - 1 };          /* number of worst cases reported */

  float top_x[TOP];
  float top_b[TOP];
  double top_e[TOP];
  int ntop = 0;                  /* valid entries in top_*, sorted desc */
  double ex = 0.0;               /* maximum absolute error seen */

  const int first = -XMAX * SCALE;
  const int last = XMAX * SCALE;

  for (int start = first; start <= last; start += N)
  {
    float a[N];
    float b[N];
    int count = last - start + 1;

    if (count > N)
      count = N;

    for (int k = 0; k < N; k++)
      a[k] = (k < count) ? (float) (start + k) / (float) SCALE : 0.0f;

    activation_geluv2(N, a, b);

    for (int k = 0; k < count; k++)
    {
      double err = fabs((double) b[k] - gelu((double) a[k]));

      if (err > ex)
        ex = err;

      /* Insertion into the descending top-TOP list; drops the smallest
       * entry once the list is full. */
      if (ntop < TOP || err > top_e[TOP - 1])
      {
        int j;

        if (ntop < TOP)
          j = ntop++;
        else
          j = TOP - 1;

        while (j > 0 && top_e[j - 1] < err)
        {
          top_x[j] = top_x[j - 1];
          top_b[j] = top_b[j - 1];
          top_e[j] = top_e[j - 1];
          j--;
        }
        top_x[j] = a[k];
        top_b[j] = b[k];
        top_e[j] = err;
      }
    }
  }

  for (int j = 0; j < ntop; j++)
  {
    printf("i=%d a[i]=%.6f gelu(a[i])=%.6f b[i]=%.6f e[i]=%.6f\n",
           j + 1, top_x[j], gelu((double) top_x[j]), top_b[j], top_e[j]);
  }
  printf("ex=%.6f\n", ex);

  return 0;
}

To exercise the 8-lane AVX2 path, the test calls activation_geluv2 with N = 8 elements at a time and reports the largest absolute errors found:

i=1 a[i]=-0.854000 gelu(a[i])=-0.167856 b[i]=-0.103940 e[i]=0.063916
i=2 a[i]=-0.855000 gelu(a[i])=-0.167816 b[i]=-0.103900 e[i]=0.063916
i=3 a[i]=-0.856000 gelu(a[i])=-0.167775 b[i]=-0.103860 e[i]=0.063915
i=4 a[i]=-0.857000 gelu(a[i])=-0.167734 b[i]=-0.103820 e[i]=0.063914
i=5 a[i]=-0.858000 gelu(a[i])=-0.167693 b[i]=-0.103780 e[i]=0.063913
i=6 a[i]=-0.859000 gelu(a[i])=-0.167651 b[i]=-0.103739 e[i]=0.063912
i=7 a[i]=-0.860000 gelu(a[i])=-0.167609 b[i]=-0.103699 e[i]=0.063910
ex=0.063916

Regards,
GW