Technical Summary | Building a Serverless Database on S3 Storage

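1. Background

I have recently been studying ChatGPT-related topics. One application is to use a book as training data for question answering, which requires a vector database.

Converting a book into a vector store is a typical write-once, read-many architecture: the book only needs to be processed once, the result is stored in a vector database, and it can then be reused many times.

Running a dedicated storage server for this would be a maintenance burden later on, so I wondered whether object storage could hold the data instead.

With future expansion in mind, since relational database scenarios will come up in addition to vector database ones, I analyzed whether sqlite could cover this use case.

Why sqlite?

First, sqlite is a mature database that is widely used and runs on every major platform;

Second, sqlite supports extensions, so it can be combined with vector algorithms and extended into a relational vector database;

Finally, sqlite can be extended to serverless database scenarios.

The overall goal is to implement vector database storage with sqlite and eventually store all of the ChatGPT training data on S3, so the topic is split into a series of articles. This article covers running sqlite as a relational database on top of S3.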

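2. VFS

For portability across platforms, sqlite provides a virtual file system (VFS) layer.

File access differs between Unix and Windows, so the VFS exposes a common API for accessing files regardless of the operating system sqlite runs on.

The API includes functions to open, read, write, and close files.

The standard SQLite source code contains built-in VFSes for Unix and Windows, and an alternative VFS can be added at startup or at runtime with the sqlite3_vfs_register() interface.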
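To see the built-in VFS that a custom one would sit alongside, the following minimal sketch asks the SQLite C library for its default VFS through ctypes. The names VfsHeader and default_vfs_name are illustrative and not part of the article's code; the structure mirrors only the leading fields of the C struct sqlite3_vfs, and the function returns None when no shared libsqlite3 can be located.

```python
from ctypes import POINTER, Structure, c_char_p, c_int, c_void_p, cdll
from ctypes.util import find_library


class VfsHeader(Structure):
    # Leading fields of the C struct sqlite3_vfs, in declaration order.
    # Only these are read here, so the function pointers that follow are omitted.
    _fields_ = [
        ('i_version', c_int),
        ('sz_os_file', c_int),
        ('mx_pathname', c_int),
        ('p_next', c_void_p),
        ('z_name', c_char_p),
    ]


def default_vfs_name():
    """Return the name of SQLite's default VFS, or None if libsqlite3 is unavailable."""
    path = find_library('sqlite3')
    if path is None:
        return None
    try:
        lib = cdll.LoadLibrary(path)
    except OSError:
        return None
    lib.sqlite3_vfs_find.restype = POINTER(VfsHeader)
    lib.sqlite3_vfs_find.argtypes = [c_char_p]
    p_vfs = lib.sqlite3_vfs_find(None)  # NULL asks for the default VFS
    if not p_vfs:
        return None
    return p_vfs.contents.z_name.decode()


if __name__ == '__main__':
    print(default_vfs_name())
```

A VFS registered with sqlite3_vfs_register() can be looked up the same way by passing its name instead of None.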


Building on sqlite3_vfs_register, we can therefore combine S3, HTTP, and a sqlite VFS into a write-once, read-many serverless database.
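The write-once, read-many pattern can be tried locally before any S3 or VFS work. The sketch below uses only Python's built-in sqlite3 module: it builds a database file once, then reopens it with the SQLite URI parameters mode=ro and immutable=1. The function names and the small table are made up for illustration; this is a local stand-in for the idea, not the S3-backed implementation.

```python
import sqlite3
import tempfile
from pathlib import Path


def build_database(path):
    """Write phase: create the database file once."""
    con = sqlite3.connect(path)
    con.execute('CREATE TABLE my_table (id INTEGER PRIMARY KEY, my_col_a TEXT)')
    con.executemany('INSERT INTO my_table (my_col_a) VALUES (?)',
                    [(f'text-{i}',) for i in range(100)])
    con.commit()
    con.close()


def open_read_only(path):
    """Read phase: open the finished file read-only and immutable."""
    uri = Path(path).resolve().as_uri() + '?mode=ro&immutable=1'
    return sqlite3.connect(uri, uri=True)


def demo():
    with tempfile.TemporaryDirectory() as tmp:
        path = str(Path(tmp) / 'demo.db')
        build_database(path)
        con = open_read_only(path)
        try:
            count = con.execute('SELECT count(*) FROM my_table').fetchone()[0]
            try:
                con.execute("INSERT INTO my_table (my_col_a) VALUES ('x')")
                write_rejected = False
            except sqlite3.OperationalError:
                # Writes fail because the connection is read-only.
                write_rejected = True
        finally:
            con.close()
        return count, write_rejected
```

Calling demo() returns the row count together with a flag showing that the write attempt was rejected, which is the behaviour a read-only database served from object storage relies on.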

3. Architecture Design

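(1) The client embeds the SQLite engine together with our VFS wrapper layer;

(2) For writes, the data is uploaded directly to S3 object storage through the wrapped file functions;

(3) For reads, the offset into the object storage file is taken from the arguments of the VFS read interface;

# Implement the sqlite VFS interface with custom functions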
self.io_methods = SqliteVFS.make_struct((
    ('i_version', c_int, 1),
    ('x_close', SqliteVFS.x_close_type,
        SqliteVFS.x_close_type(SqliteVFS.x_close)),
    ('x_read', SqliteVFS.x_read_type,
        SqliteVFS.x_read_type(self._x_read())),
    ('x_write', c_void_p, None),
    ('x_truncate', c_void_p, None),
    ('x_sync', c_void_p, None),
    ('x_file_size', SqliteVFS.x_file_size_type,
        SqliteVFS.x_file_size_type(self._x_file_size())),
    ('x_lock', SqliteVFS.x_lock_type,
        SqliteVFS.x_lock_type(SqliteVFS.x_lock)),
    ('x_unlock', SqliteVFS.x_unlock_type,
        SqliteVFS.x_unlock_type(SqliteVFS.x_unlock)),
    ('x_check_reserved_lock', c_void_p, None),
    ('x_file_control', SqliteVFS.x_file_control_type,
        SqliteVFS.x_file_control_type(SqliteVFS.x_file_control)),
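    ('x_sector_size', c_void_p, None),
    ('x_device_characteristics', SqliteVFS.x_device_characteristics_type,
        SqliteVFS.x_device_characteristics_type(SqliteVFS.x_device_characteristics)),
))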
self.file = SqliteVFS.make_struct(
    (('p_methods', POINTER(type(self.io_methods)), pointer(self.io_methods)),))
self.vfs = SqliteVFS.make_struct((
    ('i_version', c_int, 1),
    ('sz_os_file', c_int, sizeof(self.file)),
    ('mx_pathname', c_int, 1024),
    ('p_next', c_void_p, None),
    ('z_name', c_char_p, vfs_name),
    ('p_app_data', c_char_p, None),
    ('x_open', SqliteVFS.x_open_type,
        SqliteVFS.x_open_type(self._x_open())),
    ('x_delete', c_void_p, None),
    ('x_access', SqliteVFS.x_access_type,
        SqliteVFS.x_access_type(SqliteVFS.x_access)),
    ('x_full_pathname', SqliteVFS.x_full_pathname_type,
        SqliteVFS.x_full_pathname_type(SqliteVFS.x_full_pathname)),
    ('x_dl_open', c_void_p, None),
    ('x_dl_error', c_void_p, None),
    ('x_dl_sym', c_void_p, None),
    ('x_dl_close', c_void_p, None),
    ('x_randomness', c_void_p, None),
    ('x_sleep', c_void_p, None),
    ('x_current_time', SqliteVFS.x_current_time_type,
        SqliteVFS.x_current_time_type(SqliteVFS.x_current_time)),
    ('x_get_last_error', c_void_p, None),
))

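(4) The offset is passed to an HTTP wrapper that sets the HTTP Range header and reads the corresponding block of data;

# Called by x_read (as self.make_auth_request) to fetch one byte range over HTTP.
# Excerpt from get_vfs: url, version_id and the credentials come from the enclosing scope.
@contextmanager
def __make_auth_request(i_ofst, i_amt):
    with SqliteS3.make_auth_request(url, 'GET', (('versionId', version_id),), ((
            'range', f'bytes={i_ofst}-{i_ofst + i_amt - 1}'),),
            self.body_hash, s3_region, access_key_id, secret_access_key, session_token) as response:
        yield response

# Implementation of the x_read interface:
def _x_read(self):
    ''' Read data '''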
    def x_read(p_file, p_out, i_amt, i_ofst):
        offset = 0
        try:
            with self.make_auth_request(i_ofst, i_amt) as response:
                for chunk in response.iter_bytes():
                    memmove(p_out + offset, chunk,
                            min(i_amt - offset, len(chunk)))
                    offset += len(chunk)
                    if offset > i_amt:
                        break
        except Exception as e:
            _debug_log("e: " + str(e))
            return SqliteVFS.SQLITE_IOERR
        if offset != i_amt:
            return SqliteVFS.SQLITE_IOERR
        return SqliteVFS.SQLITE_OK
    return x_read

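(5) Finally, rows can be printed through the standard sqlite interface;

(6) The object storage can be S3, Tencent Cloud COS, Alibaba Cloud OSS, a self-hosted MinIO deployment, and so on.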
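Steps (3) and (4) come down to simple offset arithmetic. The sketch below isolates it: byte_range_header builds the same Range value that the request code above sends, and page_read_args shows how a whole-page read maps to an offset, given that SQLite numbers pages from 1. Both function names are illustrative and not part of the article's code.

```python
def byte_range_header(i_ofst, i_amt):
    """Build the HTTP Range header value for i_amt bytes starting at i_ofst."""
    if i_amt <= 0:
        raise ValueError('i_amt must be positive')
    # HTTP byte ranges are inclusive at both ends, hence the "- 1".
    return f'bytes={i_ofst}-{i_ofst + i_amt - 1}'


def page_read_args(page_number, page_size):
    """Offset and amount of the read for one whole database page (pages are numbered from 1)."""
    if page_number < 1:
        raise ValueError('page numbers start at 1')
    return (page_number - 1) * page_size, page_size
```

For example, byte_range_header(*page_read_args(3, 4096)) gives 'bytes=8192-12287', and a read of the 100-byte database header at the start of the file would be byte_range_header(0, 100), that is 'bytes=0-99'.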

4. Performance Test

Test steps:

(1) Write 10,000 rows and store the database in Tencent Cloud COS. Table schema:

CREATE TABLE my_table (
    id integer PRIMARY KEY autoincrement, 
    my_col_a varchar (50), 
    my_col_b varchar (50), 
    createdate datetime default (datetime('now', 'localtime'))
);

(2) Generate 100 random integers to use as my_table id values in the queries:

[randint(0, 10000) for i in range(0, 100)]

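(3) Total time for the 100 queries: 90s

Admittedly this is not fast; the time is dominated by network I/O and Python overhead. It should still be adequate for vector database queries, and the next article will cover extending sqlite into a vector database.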
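Since network I/O dominates, one possible mitigation is to cache range reads in memory. In the code in this article, query() opens a new connection for every statement, so the same header and schema pages are likely fetched repeatedly. The sketch below is an assumption-laden illustration rather than a measured optimisation: make_cached_reader and FakeObjectStore are hypothetical names, and the fake store merely slices an in-memory buffer and counts requests in place of real HTTP calls. Caching by (offset, amount) is only safe because the object is immutable once uploaded.

```python
from functools import lru_cache


def make_cached_reader(fetch_range, max_entries=256):
    """Wrap fetch_range(offset, amount) -> bytes so repeated reads are served from memory."""
    @lru_cache(maxsize=max_entries)
    def cached(offset, amount):
        return fetch_range(offset, amount)
    return cached


class FakeObjectStore:
    """Stand-in for an HTTP range request: slices an in-memory object and counts calls."""

    def __init__(self, data):
        self.data = data
        self.requests = 0

    def fetch_range(self, offset, amount):
        self.requests += 1
        return self.data[offset:offset + amount]


store = FakeObjectStore(bytes(range(256)) * 64)
read = make_cached_reader(store.fetch_range)
first = read(4096, 16)   # goes to the store
second = read(4096, 16)  # served from the cache
```

After the two reads above, store.requests is 1 because the second read never reaches the store. In the real VFS the wrapper would sit around the function that performs the range request inside x_read.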

5. Code Implementation

from contextlib import contextmanager
from ctypes import CFUNCTYPE, POINTER, Structure, create_string_buffer, pointer, memmove
from ctypes import sizeof, addressof, cdll, byref, string_at, c_char_p, c_int, c_double, c_int64, c_void_p, c_char
from ctypes.util import find_library
from functools import partial
from hashlib import sha256
from random import randint
import hmac
from datetime import datetime
from urllib.parse import urlsplit, quote
from uuid import uuid4
import httpx
import hashlib
import urllib.parse
import time

def _debug_log(fmt, args=None):
    pass


def _get_http_client():
    return httpx.Client(transport=httpx.HTTPTransport(retries=3))


class SqliteVFS:
    SQLITE_OK = 0
    SQLITE_IOERR = 10
    SQLITE_NOTFOUND = 12
    SQLITE_ROW = 100
    SQLITE_DONE = 101
    SQLITE_TRANSIENT = -1
    SQLITE_OPEN_READONLY = 0x00000001
    SQLITE_OPEN_NOMUTEX = 0x00008000
    SQLITE_IOCAP_IMMUTABLE = 0x00002000

    def __init__(self, libsqlite3, size, vfs_name, make_auth_request):
        self.size = size
        self.vfs_name = vfs_name
        self.make_auth_request = make_auth_request
        self.io_methods = SqliteVFS.make_struct((
            ('i_version', c_int, 1),
            ('x_close', SqliteVFS.x_close_type,
             SqliteVFS.x_close_type(SqliteVFS.x_close)),
            ('x_read', SqliteVFS.x_read_type,
             SqliteVFS.x_read_type(self._x_read())),
            ('x_write', c_void_p, None),
            ('x_truncate', c_void_p, None),
            ('x_sync', c_void_p, None),
            ('x_file_size', SqliteVFS.x_file_size_type,
             SqliteVFS.x_file_size_type(self._x_file_size())),
            ('x_lock', SqliteVFS.x_lock_type,
             SqliteVFS.x_lock_type(SqliteVFS.x_lock)),
            ('x_unlock', SqliteVFS.x_unlock_type,
             SqliteVFS.x_unlock_type(SqliteVFS.x_unlock)),
            ('x_check_reserved_lock', c_void_p, None),
            ('x_file_control', SqliteVFS.x_file_control_type,
             SqliteVFS.x_file_control_type(SqliteVFS.x_file_control)),
            ('x_sector_size', c_void_p, None),
            ('x_device_characteristics', SqliteVFS.x_device_characteristics_type,
             SqliteVFS.x_device_characteristics_type(SqliteVFS.x_device_characteristics)),
        ))
        self.file = SqliteVFS.make_struct(
            (('p_methods', POINTER(type(self.io_methods)), pointer(self.io_methods)),))
        self.vfs = SqliteVFS.make_struct((
            ('i_version', c_int, 1),
            ('sz_os_file', c_int, sizeof(self.file)),
            ('mx_pathname', c_int, 1024),
            ('p_next', c_void_p, None),
            ('z_name', c_char_p, vfs_name),
            ('p_app_data', c_char_p, None),
            ('x_open', SqliteVFS.x_open_type,
             SqliteVFS.x_open_type(self._x_open())),
            ('x_delete', c_void_p, None),
            ('x_access', SqliteVFS.x_access_type,
             SqliteVFS.x_access_type(SqliteVFS.x_access)),
            ('x_full_pathname', SqliteVFS.x_full_pathname_type,
             SqliteVFS.x_full_pathname_type(SqliteVFS.x_full_pathname)),
            ('x_dl_open', c_void_p, None),
            ('x_dl_error', c_void_p, None),
            ('x_dl_sym', c_void_p, None),
            ('x_dl_close', c_void_p, None),
            ('x_randomness', c_void_p, None),
            ('x_sleep', c_void_p, None),
            ('x_current_time', SqliteVFS.x_current_time_type,
             SqliteVFS.x_current_time_type(SqliteVFS.x_current_time)),
            ('x_get_last_error', c_void_p, None),
        ))

    @staticmethod
    def get_libsqlite3():
        libsqlite3 = cdll.LoadLibrary(find_library('sqlite3'))
        libsqlite3.sqlite3_errstr.restype = c_char_p
        libsqlite3.sqlite3_errmsg.restype = c_char_p
        libsqlite3.sqlite3_column_name.restype = c_char_p
        libsqlite3.sqlite3_column_double.restype = c_double
        libsqlite3.sqlite3_column_int64.restype = c_int64
        libsqlite3.sqlite3_column_blob.restype = c_void_p
        libsqlite3.sqlite3_column_bytes.restype = c_int64
        return libsqlite3

    @staticmethod
    def make_struct(fields):
        class Struct(Structure):
            _fields_ = [(field_name, field_type)
                        for (field_name, field_type, _) in fields]
        return Struct(*tuple(value for (_, _, value) in fields))

    x_open_type = CFUNCTYPE(c_int, c_void_p, c_char_p,
                            c_void_p, c_int, POINTER(c_int))

    def _x_open(self):
        ''' Open the file '''
        def x_open(p_vfs, z_name, p_file, flags, p_out_flags):
            memmove(p_file, addressof(self.file), sizeof(self.file))
            p_out_flags[0] = flags
            return SqliteVFS.SQLITE_OK
        return x_open

    x_close_type = CFUNCTYPE(c_int, c_void_p)

    @staticmethod
    def x_close(p_file):
        return SqliteVFS.SQLITE_OK

    x_read_type = CFUNCTYPE(c_int, c_void_p, c_void_p, c_int, c_int64)

    def _x_read(self):
        ''' Read data '''
        def x_read(p_file, p_out, i_amt, i_ofst):
            offset = 0
            try:
                with self.make_auth_request(i_ofst, i_amt) as response:
                    for chunk in response.iter_bytes():
                        memmove(p_out + offset, chunk,
                                min(i_amt - offset, len(chunk)))
                        offset += len(chunk)
                        if offset > i_amt:
                            break
            except Exception as e:
                _debug_log("e: " + str(e))
                return SqliteVFS.SQLITE_IOERR
            if offset != i_amt:
                return SqliteVFS.SQLITE_IOERR
            return SqliteVFS.SQLITE_OK
        return x_read

    x_file_size_type = CFUNCTYPE(c_int, c_void_p, POINTER(c_int64))

    def _x_file_size(self):
        ''' Report the file size '''
        def x_file_size(p_file, p_size):
            p_size[0] = self.size
            return SqliteVFS.SQLITE_OK
        return x_file_size

    x_lock_type = CFUNCTYPE(c_int, c_void_p, c_int)

    @staticmethod
    def x_lock(p_file, e_lock):
        return SqliteVFS.SQLITE_OK

    x_unlock_type = CFUNCTYPE(c_int, c_void_p, c_int)

    @staticmethod
    def x_unlock(p_file, e_lock):
        return SqliteVFS.SQLITE_OK

    x_file_control_type = CFUNCTYPE(c_int, c_void_p, c_int, c_void_p)

    @staticmethod
    def x_file_control(p_file, op, p_arg):
        return SqliteVFS.SQLITE_NOTFOUND

    x_device_characteristics_type = CFUNCTYPE(c_int, c_void_p)

    @staticmethod
    def x_device_characteristics(p_file):
        return SqliteVFS.SQLITE_IOCAP_IMMUTABLE

    x_access_type = CFUNCTYPE(c_int, c_void_p, c_char_p, c_int, POINTER(c_int))

    @staticmethod
    def x_access(p_vfs, z_name, flags, z_out):
        z_out[0] = 0
        return SqliteVFS.SQLITE_OK

    x_full_pathname_type = CFUNCTYPE(
        c_int, c_void_p, c_char_p, c_int, POINTER(c_char))

    @staticmethod
    def x_full_pathname(p_vfs, z_name, n_out, z_out):
        memmove(z_out, z_name, len(z_name) + 1)
        return SqliteVFS.SQLITE_OK

    x_current_time_type = CFUNCTYPE(c_int, c_void_p, POINTER(c_double))

    @staticmethod
    def x_current_time(p_vfs, c_double_p):
        c_double_p[0] = time.time()/86400.0 + 2440587.5
        return SqliteVFS.SQLITE_OK


class SqliteS3:
    def __init__(self, s3_region, s3_access_key_id, s3_secret_access_key, s3_session_token):
        self.get_credentials = lambda now: (
            s3_region, s3_access_key_id, s3_secret_access_key, s3_session_token)
        self.vfs_name = b's3-' + str(uuid4()).encode()
        self.file_name = b's3-' + str(uuid4()).encode()
        self.body_hash = sha256(b'').hexdigest()
        self.libsqlite3 = SqliteVFS.get_libsqlite3()
        self.bind = {
            type(0): self.libsqlite3.sqlite3_bind_int64,
            type(0.0): self.libsqlite3.sqlite3_bind_double,
            type(''): lambda pp_stmt, i, value: self.libsqlite3.sqlite3_bind_text(pp_stmt, i, value.encode('utf-8'), len(value.encode('utf-8')), SqliteVFS.SQLITE_TRANSIENT),
            type(b''): lambda pp_stmt, i, value: self.libsqlite3.sqlite3_bind_blob(pp_stmt, i, value, len(value), SqliteVFS.SQLITE_TRANSIENT),
            type(None): lambda pp_stmt, i, _: self.libsqlite3.sqlite3_bind_null(pp_stmt, i),
        }
        self.extract = {
            1: self.libsqlite3.sqlite3_column_int64,
            2: self.libsqlite3.sqlite3_column_double,
            3: lambda pp_stmt, i: string_at(
                self.libsqlite3.sqlite3_column_blob(pp_stmt, i),
                self.libsqlite3.sqlite3_column_bytes(pp_stmt, i),
            ).decode(),
            4: lambda pp_stmt, i: string_at(
                self.libsqlite3.sqlite3_column_blob(pp_stmt, i),
                self.libsqlite3.sqlite3_column_bytes(pp_stmt, i),
            ),
            5: lambda pp_stmt, i: None,
        }

    @staticmethod
    def _s3_sigv4_headers(
        now, access_key_id, secret_access_key, region, method, headers_to_sign, params, netloc, body_hash, path
    ):
        def sign(key, msg):
            return hmac.new(key, msg.encode('ascii'), sha256).digest()
        algorithm = 'AWS4-HMAC-SHA256'
        amzdate = now.strftime('%Y%m%dT%H%M%SZ')
        datestamp = amzdate[:8]
        credential_scope = f'{datestamp}/{region}/s3/aws4_request'
        headers = tuple(sorted(headers_to_sign + (
            ('host', netloc),
            ('x-amz-content-sha256', body_hash),
            ('x-amz-date', amzdate),
        )))
        signed_headers = ';'.join(key for key, _ in headers)
        canonical_uri = quote(path, safe='/~')
        if params:
            quoted_params = sorted(
                (quote(key, safe='~'), quote(value, safe='~'))
                for key, value in params
            )
            canonical_querystring = '&'.join(
                f'{key}={value}' for key, value in quoted_params)
        else:
            canonical_querystring = ''
        canonical_headers = ''.join(
            f'{key}:{value}\n' for key, value in headers)
        canonical_request = f'{method}\n{canonical_uri}\n{canonical_querystring}\n' + \
                            f'{canonical_headers}\n{signed_headers}\n{body_hash}'
        string_to_sign = f'{algorithm}\n{amzdate}\n{credential_scope}\n' + \
            sha256(canonical_request.encode('ascii')).hexdigest()
        date_key = sign(
            ('AWS4' + secret_access_key).encode('ascii'), datestamp)
        region_key = sign(date_key, region)
        service_key = sign(region_key, 's3')
        request_key = sign(service_key, 'aws4_request')
        signature = sign(request_key, string_to_sign).hex()
        return (
            ('authorization', (
                f'{algorithm} Credential={access_key_id}/{credential_scope}, '
                f'SignedHeaders={signed_headers}, Signature={signature}')
             ),
        ) + headers

    @staticmethod
    @contextmanager
    def make_auth_request(url, method, params, headers, body_hash, region, access_key_id, secret_access_key, session_token):
        http_client = _get_http_client()
        scheme, netloc, path, _, _ = urlsplit(url)
        now = datetime.utcnow()
        to_auth_headers = headers + (
            (('x-amz-security-token', session_token),
             ) if session_token is not None else ()
        )
        request_headers = SqliteS3._s3_sigv4_headers(
            now, access_key_id, secret_access_key, region, method, to_auth_headers, params, netloc, body_hash, path
        )
        _url = f'{scheme}://{netloc}{path}'
        _debug_log("request_headers: " + str(request_headers))
        _debug_log("_url: " + str(_url))
        with http_client.stream(method, _url, params=params, headers=request_headers) as response:
            response.raise_for_status()
            yield response

    @contextmanager
    def query_multi(self, url):
        with self.get_vfs(url) as vfs:
            yield partial(self.query, vfs)

    @contextmanager
    def get_vfs(self, url):
        s3_region, access_key_id, secret_access_key, session_token = self.get_credentials(
            None)
        with SqliteS3.make_auth_request(url, 'HEAD', (), (), self.body_hash, s3_region, access_key_id, secret_access_key, session_token) as response:
            head_headers = response.headers
            next(response.iter_bytes(), b'')
        try:
            version_id = head_headers['x-amz-version-id']
        except KeyError:
            raise Exception('The bucket must have versioning enabled')
        size = int(head_headers['content-length'])

        @contextmanager
        def __make_auth_request(i_ofst, i_amt):
            with SqliteS3.make_auth_request(url, 'GET', (('versionId', version_id),), ((
                    'range', f'bytes={i_ofst}-{i_ofst + i_amt - 1}'),),
                    self.body_hash, s3_region, access_key_id, secret_access_key, session_token) as response:
                yield response

        sqlite_vfs = SqliteVFS(
            self.libsqlite3, size, self.vfs_name, __make_auth_request)
        res = self.libsqlite3.sqlite3_vfs_register(byref(sqlite_vfs.vfs), 0)
        if res != 0:
            raise Exception(self.libsqlite3.sqlite3_errstr(res).decode())
        try:
            yield sqlite_vfs.vfs
        finally:
            res = self.libsqlite3.sqlite3_vfs_unregister(byref(sqlite_vfs.vfs))
            if res != 0:
                raise Exception(self.libsqlite3.sqlite3_errstr(res).decode())

    @contextmanager
    def get_db(self):
        _debug_log("self.file_name: " + str(self.file_name) +
                   ", self.vfs_name: " + str(self.vfs_name))
        self.db = c_void_p()
        res = self.libsqlite3.sqlite3_open_v2(self.file_name, byref(
            self.db), SqliteVFS.SQLITE_OPEN_READONLY | SqliteVFS.SQLITE_OPEN_NOMUTEX, self.vfs_name)
        if res != 0:
            raise Exception(self.libsqlite3.sqlite3_errstr(res).decode())
        try:
            yield self.db
        finally:
            if self.libsqlite3.sqlite3_close(self.db) != 0:
                raise Exception(
                    self.libsqlite3.sqlite3_errmsg(self.db).decode())

    def _get_pp_stmt(self, statement):
        try:
            return self.statements[statement]
        except KeyError:
            raise Exception('Attempting to use finalized statement') from None

    def _get_pp_stmts(self, sql):
        p_encoded = POINTER(c_char)(create_string_buffer(sql.encode()))
        while True:
            pp_stmt = c_void_p()
            if self.libsqlite3.sqlite3_prepare_v2(self.db, p_encoded, -1, byref(pp_stmt), byref(p_encoded)) != 0:
                raise Exception(
                    self.libsqlite3.sqlite3_errmsg(self.db).decode())
            if not pp_stmt:
                break
            statement = object()
            self.statements[statement] = pp_stmt
            yield partial(self._get_pp_stmt, statement), partial(self._finalize, statement)

    def _finalize(self, statement):
        try:
            pp_stmt = self.statements.pop(statement)
            if self.libsqlite3.sqlite3_finalize(pp_stmt) != 0:
                raise Exception(
                    self.libsqlite3.sqlite3_errmsg(self.db).decode())
        except KeyError:
            return

    @contextmanager
    def get_pp_stmt_getter(self, db):
        self.statements = {}
        try:
            yield self._get_pp_stmts
        finally:
            for statement in self.statements.copy().keys():
                self._finalize(statement)

    def rows(self, get_pp_stmt, columns):
        while True:
            pp_stmt = get_pp_stmt()
            res = self.libsqlite3.sqlite3_step(pp_stmt)
            if res == SqliteVFS.SQLITE_DONE:
                break
            if res != SqliteVFS.SQLITE_ROW:
                raise Exception(
                    self.libsqlite3.sqlite3_errstr(res).decode())
            yield tuple(
                self.extract[self.libsqlite3.sqlite3_column_type(
                    pp_stmt, i)](pp_stmt, i)
                for i in range(0, len(columns))
            )

    @staticmethod
    def zip_first(first_iterable, *iterables, default=()):
        iters = tuple(iter(iterable) for iterable in iterables)
        for value in first_iterable:
            yield (value,) + tuple(next(it, default) for it in iters)

    @contextmanager
    def query(self, vfs, sql, params=(), named_params=()):
        with self.get_db() as db, self.get_pp_stmt_getter(db) as get_pp_stmts:
            _debug_log("vfs: " + str(vfs))
            _debug_log("db: " + str(db))
            _debug_log("get_pp_stmts: " + str(get_pp_stmts))
            for (get_pp_stmt, finalize_stmt), statment_params, statement_named_params in SqliteS3.zip_first(get_pp_stmts(sql), params, named_params):
                try:
                    pp_stmt = get_pp_stmt()
                    for i, param in enumerate(statment_params):
                        if self.bind[type(param)](pp_stmt, i + 1, param) != 0:
                            raise Exception(
                                self.libsqlite3.sqlite3_errmsg(self.db).decode())
                    for param_name, param_value in statement_named_params:
                        index = self.libsqlite3.sqlite3_bind_parameter_index(
                            pp_stmt, param_name.encode('utf-8'))
                        if self.bind[type(param_value)](pp_stmt, index, param_value) != 0:
                            raise Exception(
                                self.libsqlite3.sqlite3_errmsg(self.db).decode())
                    columns = tuple(
                        self.libsqlite3.sqlite3_column_name(
                            pp_stmt, i).decode()
                        for i in range(0, self.libsqlite3.sqlite3_column_count(pp_stmt))
                    )
                    yield columns, self.rows(get_pp_stmt, columns)
                finally:
                    finalize_stmt()

    def create_bucket(self, uri, bucket, region, access_key_id, secret_access_key, session_token=None):
        url = uri + f'/{bucket}/'
        content = b''
        body_hash = hashlib.sha256(content).hexdigest()
        parsed_url = urllib.parse.urlsplit(url)
        now = datetime.utcnow()
        to_auth_headers = (
            (('x-amz-security-token', session_token),
             ) if session_token is not None else ()
        )
        headers = SqliteS3._s3_sigv4_headers(
            now, access_key_id, secret_access_key, region, 'PUT', to_auth_headers, None, parsed_url.netloc, body_hash, parsed_url.path
        )
        response = httpx.put(url, content=content, headers=headers)
        response.raise_for_status()

    def put_object(self, uri, bucket, key, content, region, access_key_id, secret_access_key, session_token=None):
        try:
            self.create_bucket(uri, bucket, region, access_key_id,
                               secret_access_key, session_token)
        except httpx.HTTPError:
            # The bucket may already exist; continue and attempt the upload
            pass
        url = uri + f'/{bucket}/{key}'
        sha = hashlib.sha256()
        length = 0
        for chunk in content():
            length += len(chunk)
            sha.update(chunk)
        body_hash = sha.hexdigest()
        parsed_url = urllib.parse.urlsplit(url)
        now = datetime.utcnow()
        to_auth_headers = (
            (('x-amz-security-token', session_token),
             ) if session_token is not None else ()
        )
        headers = SqliteS3._s3_sigv4_headers(
            now, access_key_id, secret_access_key, region, 'PUT', to_auth_headers, None, parsed_url.netloc, body_hash, parsed_url.path
        ) + ((b'content-length', str(length).encode()),)
        response = httpx.put(url, content=content(), headers=headers)
        response.raise_for_status()

    @contextmanager
    def exec_db(self, sql_list, chunk_size=65536):
        import sqlite3
        import tempfile
        with tempfile.NamedTemporaryFile() as fp:
            with sqlite3.connect(fp.name, isolation_level=None) as con:
                cur = con.cursor()
                cur.execute('BEGIN')
                for sql, params in sql_list:
                    cur.execute(sql, params)
                cur.execute('COMMIT')

            def db():
                with open(fp.name, 'rb') as f:
                    while True:
                        chunk = f.read(chunk_size)
                        if not chunk:
                            break
                        yield chunk
            yield db


if __name__ == "__main__":
    access_key_id = "XXXoRpxpRnyta1C8oDMF0UigOq2xd4FApFP"
    secret_access_key = "XXXB6lBKOL3WlmX3JJ2WjgMo0xe5aqrc"
    uri = "https://cos.ap-guangzhou.myqcloud.com"
    uri_s3name = 'my-1251014631'
    db_name = 'my.db'
    s3 = SqliteS3('us-east-1', access_key_id, secret_access_key, None)
    with s3.exec_db([
        ("CREATE TABLE my_table (id integer PRIMARY KEY autoincrement, my_col_a varchar (50), my_col_b varchar (50), createdate datetime default (datetime('now', 'localtime')));", ()),
    ] + [
        ("INSERT INTO my_table(my_col_a, my_col_b) VALUES " +
            ','.join(["('some-text-a', 'some-text-b')"] * 10000), ())
    ]) as db:
        s3.put_object(uri, uri_s3name, db_name, db,
                      'us-east-1', access_key_id, secret_access_key)

    start_time = time.time()
    random_list = [randint(0, 10000) for i in range(0, 100)]
    with s3.query_multi("/".join([uri, uri_s3name, db_name])) as query:
        for idx in random_list:
            with query("SELECT id, createdate FROM my_table where id = " + str(idx)) as (columns, rows):
                _debug_log("columns: " + str(columns), ", rows: " + str(rows))
                for row in rows:
                    _debug_log("c: ", row[0])
                    print("row: ", row[0], ", ", row[1])
    end_time = time.time()
    print("costtime: {:.2f}".format(end_time - start_time))
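The request signing inside _s3_sigv4_headers is easier to follow when the key derivation is pulled out on its own. The sketch below reproduces only that step, chaining HMAC-SHA256 over the date, region, service and the fixed string aws4_request, and then signing an arbitrary string. The function names and every input value shown are placeholders chosen for illustration, not real credentials or a real request.

```python
import hmac
from hashlib import sha256


def derive_signing_key(secret_access_key, datestamp, region, service='s3'):
    """Derive the SigV4 signing key by chaining HMAC-SHA256 over date, region, service."""
    def sign(key, msg):
        return hmac.new(key, msg.encode('ascii'), sha256).digest()
    date_key = sign(('AWS4' + secret_access_key).encode('ascii'), datestamp)
    region_key = sign(date_key, region)
    service_key = sign(region_key, service)
    return sign(service_key, 'aws4_request')


def sign_string(signing_key, string_to_sign):
    """Return the hex signature that goes into the Authorization header."""
    return hmac.new(signing_key, string_to_sign.encode('ascii'), sha256).hexdigest()


# Placeholder inputs, for illustration only.
key = derive_signing_key('example-secret', '20230101', 'us-east-1')
signature = sign_string(key, 'example string to sign')
```

Because the date and region are mixed into the key, a signature computed for one day or region is not valid for another, which is why the full implementation recomputes the headers for every request.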

Originally published on the WeChat public account 周末程序猿: Technical Summary | Building a Serverless Database on S3 Storage
