/****************************************************************************
 * mm/kasan/kasan.c
 *
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.  The
 * ASF licenses this file to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance with the
 * License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.  See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 ****************************************************************************/

/****************************************************************************
 * Included Files
 ****************************************************************************/

#include <nuttx/spinlock.h>

#include <assert.h>
#include <debug.h>
#include <stdbool.h>
#include <stdint.h>

#include "kasan.h"

/****************************************************************************
 * Pre-processor Definitions
 ****************************************************************************/

/* Size of one shadow word in bytes and in bits; the shadow bitmap is an
 * array of uintptr_t words.
 */
#define KASAN_BYTES_PER_WORD (sizeof(uintptr_t))
#define KASAN_BITS_PER_WORD  (KASAN_BYTES_PER_WORD * 8)

/* Word mask with every bit from (start % KASAN_BITS_PER_WORD) upwards set */
#define KASAN_FIRST_WORD_MASK(start) \
  (UINTPTR_MAX << ((start) & (KASAN_BITS_PER_WORD - 1)))
/* Word mask with every bit below (end % KASAN_BITS_PER_WORD) set; all bits
 * are set when end is a multiple of KASAN_BITS_PER_WORD.
 */
#define KASAN_LAST_WORD_MASK(end) \
  (UINTPTR_MAX >> (-(end) & (KASAN_BITS_PER_WORD - 1)))

/* Bytes of registered memory represented by one shadow bit */
#define KASAN_SHADOW_SCALE (sizeof(uintptr_t))

/* KASAN_SHADOW_SIZE: bytes of bitmap computed for 'size' bytes of memory,
 * rounded down to whole shadow words.
 * KASAN_REGION_SIZE: that bitmap plus the struct kasan_region_s header
 * (whose shadow[1] member contributes one more word).
 */
#define KASAN_SHADOW_SIZE(size) \
  (KASAN_BYTES_PER_WORD * ((size) / KASAN_SHADOW_SCALE / KASAN_BITS_PER_WORD))
#define KASAN_REGION_SIZE(size) \
  (sizeof(struct kasan_region_s) + KASAN_SHADOW_SIZE(size))

#ifdef CONFIG_MM_KASAN_GLOBAL

/* Bytes of global data represented by one shadow bit */

#  define KASAN_GLOBAL_SHADOW_SCALE (32)

/* Step from one record of g_globals_region to the following one.  For these
 * records the 'next' member is used as a byte offset counted from the start
 * of the record's shadow array, not as a pointer.  The argument and the
 * whole expansion are parenthesized so the macro is safe inside larger
 * expressions.
 */

#  define KASAN_GLOBAL_NEXT_REGION(region) \
  ((FAR struct kasan_region_s *) \
   ((FAR char *)(region)->shadow + (size_t)(region)->next))

#endif

/* Value stored in g_region_init by kasan_register(); until it is present
 * kasan_mem_to_shadow() returns NULL for every address.
 */
#define KASAN_INIT_VALUE            0xDEADCAFE

/****************************************************************************
 * Private Types
 ****************************************************************************/

/* Bookkeeping record for one checked memory region.
 *
 * kasan_register() places the record at the tail of the region so that
 * KASAN_SHADOW_SIZE(size) bytes of bitmap follow the declared shadow[1]
 * element.  That one declared word is counted by sizeof() and therefore by
 * KASAN_REGION_SIZE(), so it must not be turned into a flexible array
 * member without adjusting the size macros.
 */

struct kasan_region_s
{
  /* Next registered region; in g_globals_region records this member is
   * instead read as a byte offset (see KASAN_GLOBAL_NEXT_REGION) and a zero
   * value ends the walk.
   */

  FAR struct kasan_region_s *next;
  uintptr_t begin;      /* First address covered by the region */
  uintptr_t end;        /* One past the last address covered */
  uintptr_t shadow[1];  /* Bitmap: a set bit marks poisoned memory */
};

/****************************************************************************
 * Private Data
 ****************************************************************************/

/* g_lock:        held while the region list or a shadow bitmap is modified
 *                (kasan_register() and kasan_set_poison()).
 * g_region:      head of the singly linked list of registered regions.
 * g_region_init: equals KASAN_INIT_VALUE once kasan_register() has run.
 */
static spinlock_t g_lock;
static FAR struct kasan_region_s *g_region;
static uint32_t g_region_init;

/****************************************************************************
 * Public Data
 ****************************************************************************/

#ifdef CONFIG_MM_KASAN_GLOBAL
/* Table of struct kasan_region_s records describing global variables,
 * defined outside this file.  kasan_mem_to_shadow() walks it until it
 * reaches a record whose 'next' member is zero.
 */
extern const unsigned char g_globals_region[];
#endif

/****************************************************************************
 * Private Functions
 ****************************************************************************/

/****************************************************************************
 * Name: kasan_mem_to_shadow
 *
 * Description:
 *   Find the shadow word and bit that describe the memory at ptr.
 *
 * Input Parameters:
 *   ptr  - Address to look up.
 *   size - Length of the access; only used to assert that the access stays
 *          inside the region that contains ptr.
 *   bit  - Receives the bit position inside the returned word.  It is left
 *          untouched when NULL is returned.
 *
 * Returned Value:
 *   Pointer to the shadow word, or NULL when nothing has been registered
 *   yet or no region contains ptr.
 *
 ****************************************************************************/

static FAR uintptr_t *kasan_mem_to_shadow(FAR const void *ptr, size_t size,
                                          unsigned int *bit)
{
  FAR struct kasan_region_s *iter;
  uintptr_t target = (uintptr_t)ptr;
  uintptr_t index;

  /* Nothing can be resolved before the first kasan_register() call */

  if (g_region_init != KASAN_INIT_VALUE)
    {
      return NULL;
    }

  /* Registered regions: one shadow bit per KASAN_SHADOW_SCALE bytes */

  iter = g_region;
  while (iter != NULL)
    {
      if (target >= iter->begin && target < iter->end)
        {
          DEBUGASSERT(target + size <= iter->end);
          index = (target - iter->begin) / KASAN_SHADOW_SCALE;
          *bit  = index % KASAN_BITS_PER_WORD;
          return &iter->shadow[index / KASAN_BITS_PER_WORD];
        }

      iter = iter->next;
    }

#ifdef CONFIG_MM_KASAN_GLOBAL

  /* Global records: one shadow bit per KASAN_GLOBAL_SHADOW_SCALE bytes.
   * The walk stops at the first record whose 'next' member is zero.
   */

  iter = (FAR struct kasan_region_s *)g_globals_region;
  while (iter->next)
    {
      if (target >= iter->begin && target < iter->end)
        {
          DEBUGASSERT(target + size <= iter->end);
          index = (target - iter->begin) / KASAN_GLOBAL_SHADOW_SCALE;
          *bit  = index % KASAN_BITS_PER_WORD;
          return &iter->shadow[index / KASAN_BITS_PER_WORD];
        }

      iter = KASAN_GLOBAL_NEXT_REGION(iter);
    }
#endif

  return NULL;
}

/****************************************************************************
 * Name: kasan_report
 *
 * Description:
 *   Log an invalid access with _alert() and then call PANIC().  A static
 *   nesting counter makes any report raised while another report is still
 *   in progress return without doing anything.
 *
 * Input Parameters:
 *   addr           - Address of the offending access.
 *   size           - Size of the access in bytes.
 *   is_write       - true for a store, false for a load.
 *   return_address - Code address printed in the message.
 *
 ****************************************************************************/

static void kasan_report(FAR const void *addr, size_t size,
                         bool is_write,
                         FAR void *return_address)
{
  static int depth;

  depth++;
  if (depth != 1)
    {
      /* Already reporting: undo the increment and stay silent */

      depth--;
      return;
    }

  _alert("kasan detected a %s access error, address at %p,"
         "size is %zu, return address: %p\n",
         is_write ? "write" : "read",
         addr, size, return_address);
  PANIC();

  depth--;
}

/****************************************************************************
 * Name: kasan_is_poisoned
 *
 * Description:
 *   Tell whether an access of size bytes at addr touches poisoned memory.
 *   Only the shadow bit that covers the last byte of the access is
 *   examined.
 *
 * Input Parameters:
 *   addr - Start address of the access.
 *   size - Length of the access in bytes.
 *
 * Returned Value:
 *   true when the last byte lies in a known region and its shadow bit is
 *   set; false otherwise, including for zero-length accesses and for
 *   addresses that kasan_mem_to_shadow() cannot resolve.
 *
 ****************************************************************************/

static bool kasan_is_poisoned(FAR const void *addr, size_t size)
{
  FAR const uintptr_t *p;
  unsigned int bit;

  /* A zero-length access touches no memory.  Without this check the lookup
   * below would probe the byte located before addr.
   */

  if (size == 0)
    {
      return false;
    }

  /* Arithmetic on a void pointer is a compiler extension, so step through
   * a character pointer instead.
   */

  p = kasan_mem_to_shadow((FAR const char *)addr + size - 1, 1, &bit);
  return p != NULL && ((*p >> bit) & 1) != 0;
}

/****************************************************************************
 * Name: kasan_set_poison
 *
 * Description:
 *   Set or clear the shadow bits that cover [addr, addr + size).  One bit
 *   stands for KASAN_SHADOW_SCALE bytes, so size is converted to a bit
 *   count (rounding down) before the bitmap is updated.  The lookup and
 *   the update both run with g_lock held.
 *
 * Input Parameters:
 *   addr     - Start of the range; resolved with kasan_mem_to_shadow().
 *   size     - Length of the range in bytes.
 *   poisoned - true to set the bits, false to clear them.
 *
 ****************************************************************************/

static void kasan_set_poison(FAR const void *addr, size_t size,
                             bool poisoned)
{
  FAR uintptr_t *p;
  unsigned int bit;
  unsigned int nbit;
  uintptr_t mask;
  irqstate_t flags; /* Must match the type spin_lock_irqsave() returns */

  flags = spin_lock_irqsave(&g_lock);

  p = kasan_mem_to_shadow(addr, size, &bit);
  DEBUGASSERT(p != NULL);
  if (p == NULL)
    {
      /* With assertions disabled an unknown address would otherwise lead
       * to a NULL dereference and to a read of the uninitialized 'bit'.
       */

      spin_unlock_irqrestore(&g_lock, flags);
      return;
    }

  /* Bits left in the first word, counted from 'bit', and their mask */

  nbit = KASAN_BITS_PER_WORD - bit % KASAN_BITS_PER_WORD;
  mask = KASAN_FIRST_WORD_MASK(bit);

  size /= KASAN_SHADOW_SCALE;
  while (size >= nbit)
    {
      if (poisoned)
        {
          *p++ |= mask;
        }
      else
        {
          *p++ &= ~mask;
        }

      bit  += nbit;
      size -= nbit;

      /* Every following word is covered from its lowest bit */

      nbit = KASAN_BITS_PER_WORD;
      mask = UINTPTR_MAX;
    }

  if (size)
    {
      /* Partial final word: trim the mask at the end of the range */

      mask &= KASAN_LAST_WORD_MASK(bit + size);
      if (poisoned)
        {
          *p |= mask;
        }
      else
        {
          *p &= ~mask;
        }
    }

  spin_unlock_irqrestore(&g_lock, flags);
}

/****************************************************************************
 * Name: kasan_check_report
 *
 * Description:
 *   Pass the access to kasan_report() when kasan_is_poisoned() flags it;
 *   do nothing otherwise.
 *
 * Input Parameters:
 *   addr           - Start address of the access.
 *   size           - Length of the access in bytes.
 *   is_write       - true for a store, false for a load.
 *   return_address - Code address forwarded to kasan_report().
 *
 ****************************************************************************/

static inline void kasan_check_report(FAR const void *addr, size_t size,
                                      bool is_write,
                                      FAR void *return_address)
{
  if (kasan_is_poisoned(addr, size))
    {
      /* Forward the real access direction so that stores are not reported
       * as reads.
       */

      kasan_report(addr, size, is_write, return_address);
    }
}

/****************************************************************************
 * Public Functions
 ****************************************************************************/

/* Exported functions called from other mm modules */

/* Mark the size bytes starting at addr as poisoned by setting their shadow
 * bits through kasan_set_poison().
 */
void kasan_poison(FAR const void *addr, size_t size)
{
  kasan_set_poison(addr, size, true);
}

/* Mark the size bytes starting at addr as accessible by clearing their
 * shadow bits through kasan_set_poison().
 */
void kasan_unpoison(FAR const void *addr, size_t size)
{
  kasan_set_poison(addr, size, false);
}

/****************************************************************************
 * Name: kasan_register
 *
 * Description:
 *   Register [addr, addr + *size) as a checked region.  The bookkeeping
 *   record and its shadow bitmap are carved out of the tail of the region,
 *   the region is linked at the head of the region list, and the whole
 *   range is poisoned.
 *
 * Input Parameters:
 *   addr - Start of the region.
 *   size - On entry the length of the region in bytes; on return it has
 *          been reduced by KASAN_REGION_SIZE() of the original length.
 *
 ****************************************************************************/

void kasan_register(FAR void *addr, FAR size_t *size)
{
  FAR struct kasan_region_s *region;
  size_t overhead = KASAN_REGION_SIZE(*size);
  irqstate_t flags; /* Must match the type spin_lock_irqsave() returns */

  /* The record occupies the last 'overhead' bytes of the region */

  region = (FAR struct kasan_region_s *)
    ((FAR char *)addr + *size - overhead);

  region->begin = (uintptr_t)addr;
  region->end   = region->begin + *size;

  flags = spin_lock_irqsave(&g_lock);
  region->next  = g_region;
  g_region      = region;
  g_region_init = KASAN_INIT_VALUE;
  spin_unlock_irqrestore(&g_lock, flags);

  /* Poisoning has to follow the list insertion because kasan_set_poison()
   * resolves the address through the region list.
   */

  kasan_poison(addr, *size);
  *size -= overhead;
}

/* Exported functions called from compiler-generated code */

/* Intentionally empty: all four arguments are ignored.  Presumably the
 * symbol is defined only so that references to it resolve.
 */
void __sanitizer_annotate_contiguous_container(FAR const void *beg,
                                               FAR const void *end,
                                               FAR const void *old_mid,
                                               FAR const void *new_mid)
{
  /* Nothing to do; defined to keep the toolchain quiet */
}

/* Intentionally empty: module_name is ignored */
void __asan_before_dynamic_init(FAR const void *module_name)
{
  /* Nothing to do; defined to keep the toolchain quiet */
}

/* Intentionally empty counterpart of __asan_before_dynamic_init() */
void __asan_after_dynamic_init(void)
{
  /* Nothing to do; defined to keep the toolchain quiet */
}

/* Intentionally empty: no action is taken by this implementation */
void __asan_handle_no_return(void)
{
  /* Nothing to do; defined to keep the toolchain quiet */
}

/* Unconditionally report a load of size bytes at addr.  return_address(0)
 * supplies the code address shown in the report.
 */
void __asan_report_load_n_noabort(FAR void *addr, size_t size)
{
  kasan_report(addr, size, false, return_address(0));
}

/* Unconditionally report a store of size bytes at addr.  return_address(0)
 * supplies the code address shown in the report.
 */
void __asan_report_store_n_noabort(FAR void *addr, size_t size)
{
  kasan_report(addr, size, true, return_address(0));
}

/* Check a load of size bytes at addr and report it if the shadow bitmap
 * marks it as poisoned.
 */
void __asan_loadN_noabort(FAR void *addr, size_t size)
{
  kasan_check_report(addr, size, false, return_address(0));
}

/* Check a store of size bytes at addr and report it if the shadow bitmap
 * marks it as poisoned.
 */

void __asan_storeN_noabort(FAR void *addr, size_t size)
{
  kasan_check_report(addr, size, true, return_address(0));
}

/* Check a load of size bytes at addr; the body is the same as that of
 * __asan_loadN_noabort().
 */
void __asan_loadN(FAR void *addr, size_t size)
{
  kasan_check_report(addr, size, false, return_address(0));
}

/* Check a store of size bytes at addr; the body is the same as that of
 * __asan_storeN_noabort().
 */
void __asan_storeN(FAR void *addr, size_t size)
{
  kasan_check_report(addr, size, true, return_address(0));
}

/* Generate the fixed-size entry points for one access size.  'size' is
 * pasted into the function names, so it must be a literal number.  For each
 * size the macro defines:
 *
 *   __asan_report_{load,store}<size>_noabort - report unconditionally
 *   __asan_{load,store}<size>_noabort        - check, then report
 *   __asan_{load,store}<size>                - check, then report
 */
#define DEFINE_ASAN_LOAD_STORE(size) \
  void __asan_report_load##size##_noabort(FAR void *addr) \
  { \
    kasan_report(addr, size, false, return_address(0)); \
  } \
  void __asan_report_store##size##_noabort(FAR void *addr) \
  { \
    kasan_report(addr, size, true, return_address(0)); \
  } \
  void __asan_load##size##_noabort(FAR void *addr) \
  { \
    kasan_check_report(addr, size, false, return_address(0)); \
  } \
  void __asan_store##size##_noabort(FAR void *addr) \
  { \
    kasan_check_report(addr, size, true, return_address(0)); \
  } \
  void __asan_load##size(FAR void *addr) \
  { \
    kasan_check_report(addr, size, false, return_address(0)); \
  } \
  void __asan_store##size(FAR void *addr) \
  { \
    kasan_check_report(addr, size, true, return_address(0)); \
  }

/* Entry points for 1, 2, 4, 8 and 16 byte accesses */
DEFINE_ASAN_LOAD_STORE(1)
DEFINE_ASAN_LOAD_STORE(2)
DEFINE_ASAN_LOAD_STORE(4)
DEFINE_ASAN_LOAD_STORE(8)
DEFINE_ASAN_LOAD_STORE(16)

#ifdef CONFIG_MM_KASAN_GLOBAL
/* Intentionally empty: ptr and size are ignored.  Global variables are
 * looked up through g_globals_region in kasan_mem_to_shadow() instead.
 */
void __asan_register_globals(void *ptr, ssize_t size)
{
  /* Nothing to do; defined to keep the toolchain quiet */
}

/* Intentionally empty: ptr and size are ignored */
void __asan_unregister_globals(void *ptr, ssize_t size)
{
  /* Nothing to do; defined to keep the toolchain quiet */
}
#endif