-- Copyright (c) 2016-present, Facebook, Inc.
-- All rights reserved.
--
-- This source code is licensed under the BSD-style license found in
-- the LICENSE file in the root directory of this source tree. An
-- additional grant of patent rights can be found in the PATENTS file
-- in the same directory.

{-# LANGUAGE MultiWayIf #-}

-- |
-- Module      : Codec.Compression.Zstd.Streaming
-- Copyright   : (c) 2016-present, Facebook, Inc. All rights reserved.
--
-- License     : BSD3
-- Maintainer  : bryano@fb.com
-- Stability   : experimental
-- Portability : GHC
--
-- Streaming compression and decompression support for zstd.

module Codec.Compression.Zstd.Streaming
    (
      Result(..)
    , compress
    , decompress
    , maxCLevel
    ) where

import Data.Word (Word8)
import Foreign.C.Types (CSize)
import Foreign.ForeignPtr
import Foreign.ForeignPtr.Unsafe (unsafeForeignPtrToPtr)
import Foreign.Marshal.Alloc (finalizerFree, malloc)
import Foreign.Ptr (Ptr, nullPtr, plusPtr)
import Foreign.Storable (poke)

import qualified Data.ByteString as B
import Data.ByteString.Internal (ByteString(..), mallocByteString)

import Codec.Compression.Zstd.FFI hiding (compress, decompress)
import Codec.Compression.Zstd.FFI.Types (peekPos)

-- | The result of a streaming compression or decompression step.
--
-- A driver repeatedly inspects the current 'Result': it feeds input to
-- 'Consume', writes out every 'Produce' payload and then runs its
-- action, and stops at 'Done' or 'Error'.
data Result
  = Produce ByteString (IO Result)
    -- ^ A chunk of transformed data (the contents of one output
    -- buffer), and an action that when executed will yield the next
    -- step in the streaming operation.
    -- The action is ephemeral; you should discard it as soon as you
    -- use it.
  | Consume (ByteString -> IO Result)
    -- ^ Provide the function with more input for the streaming
    -- operation to continue.  This function is ephemeral. You should
    -- call it exactly once, and discard it immediately after you call
    -- it.
    --
    -- To signal the end of a stream of data, supply an 'B.empty'
    -- input.
  | Error String String
    -- ^ An error has occurred.  The first field names the step that
    -- failed (for example @\"compress\"@, @\"initStream\"@,
    -- @\"consumeBlock\"@ or @\"endStream\"@); the second describes the
    -- error.  If an error occurs, the streaming operation cannot
    -- continue.
  | Done ByteString
    -- ^ The streaming operation has ended.  This payload may be
    -- empty. If it is not, it must be written out.
    --
    -- When compressing, a non-empty payload consists of a frame
    -- epilogue, possibly preceded by any data left over from the final
    -- streaming step.  When decompressing, it holds the last of the
    -- decompressed data.

-- Debug rendering only: the continuations carried by 'Produce' and
-- 'Consume' cannot be shown, so they are printed as @_@.
instance Show Result where
    show result = case result of
      Produce chunk _  -> "Produce " ++ show chunk ++ " _"
      Consume _        -> "Consume _"
      Error step descr -> unwords ["Error", show step, show descr]
      Done chunk       -> "Done " ++ show chunk

-- | Begin a streaming compression operation.
--
-- The initial result will be either an 'Error' or a 'Consume'.
compress :: Int
         -- ^ Compression level. Must be >= 1 and <= 'maxCLevel'.
         -> IO Result
compress level
  | level < 1 || level > maxCLevel =
    return (Error "compress" "unsupported compression level")
  | otherwise =
    streaming
      createCStream
      p_freeCStream
      outSize
      (\cs -> initCStream cs (fromIntegral level))
      compressStream
      finish
 where
  -- Size of each output buffer handed to zstd during compression.
  outSize :: Int
  outSize = fromIntegral cstreamOutSize

  -- End of input: flush the frame epilogue with 'endStream'.  The
  -- output position argument is not needed because the write position
  -- is read back from the output descriptor after every 'endStream'
  -- call.
  finish cfp obfp _opos dfp = drain outSize dfp
   where
    cptr = unsafeForeignPtrToPtr cfp
    obuf = unsafeForeignPtrToPtr obfp

    -- Call 'endStream' while the output descriptor points into @buf@,
    -- a buffer of @capacity@ bytes.  A zero result means everything has
    -- been flushed.  A non-zero result sizes the next buffer: emit what
    -- was written so far and repeat with a fresh buffer of that size.
    -- The real capacity is passed to 'shrink' so its keep-vs-copy
    -- decision is made against the buffer actually allocated.
    drain capacity buf =
      check "endStream" (endStream cptr obuf) $ \leftover -> do
        -- cptr/obuf came from unsafeForeignPtrToPtr; keep the owners
        -- alive until the foreign call has returned.
        touchForeignPtr cfp
        touchForeignPtr obfp
        opos1 <- fromIntegral `fmap` peekPos obuf
        if leftover == 0
          then Done `fmap` shrink capacity buf opos1
          else do
            let capacity1 = fromIntegral leftover
            buf1 <- mallocByteString capacity1
            poke obuf (buffer (unsafeForeignPtrToPtr buf1) leftover)
            touchForeignPtr obfp
            bs <- shrink capacity buf opos1
            return (Produce bs (drain capacity1 buf1))

-- | One streaming step.  Given the stream context, the output buffer
-- descriptor and the input buffer descriptor, perform a single zstd call
-- ('compressStream' or 'decompressStream' in this module) and return its
-- raw status code.  Callers pass that code through 'check'.
--
-- The @io@ parameter does not appear on the right-hand side.
type ConsumeBlock ctx io =
  Ptr ctx -> Ptr (Buffer Out) -> Ptr (Buffer In) -> IO CSize

-- | End-of-input handler.  It receives the stream context, the output
-- buffer descriptor, the number of bytes already written to the current
-- output buffer, and that buffer itself, and yields the remaining steps
-- of the stream.
--
-- The @io@ parameter does not appear on the right-hand side.
type Finish ctx io =
  ForeignPtr ctx -> ForeignPtr (Buffer Out) -> Int -> ForeignPtr Word8
  -> IO Result

-- | Drive a zstd streaming operation; shared by 'compress' and
-- 'decompress'.
--
-- Returns 'Error' if stream initialisation fails.  Otherwise it returns
-- a 'Consume' asking for the first chunk of input.  An empty chunk hands
-- control to the @finish@ handler.
streaming :: IO (Ptr ctx)
          -- ^ Allocate the zstd stream context.
          -> FinalizerPtr ctx
          -- ^ Finalizer that releases the context.
          -> Int
          -- ^ Size in bytes of each output buffer.
          -> (Ptr ctx -> IO CSize)
          -- ^ Initialise the freshly allocated context.
          -> ConsumeBlock ctx io
          -- ^ Perform one compress/decompress step.
          -> Finish ctx io
          -- ^ Handle end of input (empty chunk from the caller).
          -> IO Result
streaming createStream freeStream outSize initStream consumeBlock finish = do
  cx <- checkAlloc "createStream" createStream
  cxfp <- newForeignPtr freeStream cx
  check "initStream" (initStream cx) $ \_ -> do
    -- malloc'd input/output buffer descriptors, released by finalizerFree.
    ibfp <- newForeignPtr finalizerFree =<< malloc
    obfp <- newForeignPtr finalizerFree =<< malloc
    dfp <- newOutput obfp
    advanceInput cxfp ibfp obfp 0 dfp
 where
  -- Ask the caller for the next input chunk.  @opos@ is the number of
  -- bytes already written into the current output buffer @dfp@.
  advanceInput cxfp ibfp obfp opos dfp = do
    let prompt (PS fp off len)
          | len == 0 = finish cxfp obfp opos dfp
          | otherwise = do
              withForeignPtr fp $ \sp0 ->
                withForeignPtr ibfp $ \ibuf ->
                  poke ibuf (buffer (sp0 `plusPtr` off) (fromIntegral len))
              -- Carry the real output position forward (it used to be
              -- reset to 0).  If the output buffer is already full it is
              -- emitted first instead of calling zstd with no room left.
              consume cxfp ibfp 0 len obfp opos dfp fp
    return (Consume prompt)

  -- Allocate a fresh outSize-byte output buffer and point the output
  -- descriptor at it.
  newOutput obfp = do
    dfp <- mallocByteString outSize
    withForeignPtr dfp $ \dp ->
      withForeignPtr obfp $ \obuf ->
        poke obuf (buffer dp (fromIntegral outSize))
    return dfp

  -- Feed the current input chunk (ilen bytes, ipos consumed so far) to
  -- zstd.  Each output buffer is emitted as it fills up, and the caller
  -- is asked for more input once the chunk is exhausted.
  consume cxfp ibfp ipos ilen obfp opos dfp fp
    | fromIntegral ipos == ilen = advanceInput cxfp ibfp obfp opos dfp
    | opos == outSize = do
        let go = do
              ndfp <- newOutput obfp
              consume cxfp ibfp ipos ilen obfp 0 ndfp fp
        return (Produce (PS dfp 0 opos) go)
    | otherwise = do
        let obuf = unsafeForeignPtrToPtr obfp
            ibuf = unsafeForeignPtrToPtr ibfp
        check "consumeBlock"
          (withForeignPtr cxfp $ \cptr ->
            -- The input descriptor holds a raw pointer into fp's payload,
            -- so fp must stay alive until zstd has read it.
            consumeBlock cptr obuf ibuf <* touchForeignPtr fp) $ \_ -> do
          opos1 <- fromIntegral `fmap` peekPos obuf
          ipos1 <- peekPos ibuf
          -- obuf/ibuf came from unsafeForeignPtrToPtr; keep owners alive.
          touchForeignPtr obfp
          touchForeignPtr ibfp
          consume cxfp ibfp ipos1 ilen obfp opos1 dfp fp

-- | Begin a streaming decompression operation.
--
-- The initial result will be either an 'Error' or a 'Consume'.
--
-- After the caller signals end of input, any decoded data zstd is still
-- holding is emitted through further 'Produce' steps before 'Done'.
decompress :: IO Result
decompress =
  streaming
    createDStream
    p_freeDStream
    outSize
    initDStream
    decompressStream
    finish
 where
  -- Size of each output buffer handed to zstd during decompression.
  outSize :: Int
  outSize = fromIntegral dstreamOutSize

  -- End of input.  If the current output buffer is not full, zstd has
  -- flushed everything it decoded, so the stream is done.  If it is
  -- exactly full, zstd may still hold decoded data internally.  In that
  -- case, emit the full buffer and drain the remainder instead of
  -- silently dropping it.
  finish cxfp obfp opos dfp
    | opos < outSize = Done `fmap` shrink outSize dfp opos
    | otherwise      = return (Produce (PS dfp 0 opos) (drain cxfp obfp))

  -- Call 'decompressStream' with an empty input and a fresh output
  -- buffer, then decide again via 'finish'.  This loops while zstd keeps
  -- filling whole buffers.
  drain cxfp obfp = do
    ndfp <- mallocByteString outSize
    ibfp <- newForeignPtr finalizerFree =<< malloc
    withForeignPtr ibfp $ \ibuf -> poke ibuf (buffer nullPtr 0)
    withForeignPtr ndfp $ \dp ->
      withForeignPtr obfp $ \obuf ->
        poke obuf (buffer dp (fromIntegral outSize))
    check "consumeBlock"
      (withForeignPtr cxfp $ \cptr ->
         withForeignPtr obfp $ \obuf ->
           withForeignPtr ibfp $ \ibuf ->
             -- obuf points into ndfp; keep it alive across the call.
             decompressStream cptr obuf ibuf <* touchForeignPtr ndfp) $ \_ -> do
      opos1 <- withForeignPtr obfp $ \obuf -> fromIntegral `fmap` peekPos obuf
      finish cxfp obfp opos1 ndfp

-- | Turn the first @opos@ bytes of an output buffer of @capacity@ bytes
-- into a 'B.ByteString'.
--
-- If a significant part of the buffer is unused (at least 1024 bytes, or
-- more than an eighth of @capacity@), the bytes are copied with 'B.copy'.
-- The result then does not keep the whole buffer alive.  Otherwise the
-- result shares @dfp@ directly.  The threshold previously used
-- ``capacity `rem` 8`` (a value in 0..7), which caused a copy for almost
-- any unused space.
shrink :: Int -> ForeignPtr Word8 -> Int -> IO B.ByteString
shrink capacity dfp opos
  | opos == 0 = return B.empty
  | unused >= 1024 || unused > capacity `div` 8 = return (B.copy slice)
  | otherwise = return slice
  where
    unused = capacity - opos
    slice = PS dfp 0 opos

-- | Build a zstd buffer descriptor covering @len@ bytes starting at
-- @base@, with its third field (presumably the current position; TODO
-- confirm against the FFI 'Buffer' definition) set to 0.
buffer :: Ptr a -> CSize -> Buffer io
buffer base len = Buffer base len 0

-- | Run a zstd call that returns a 'CSize' status code.  If 'isError'
-- flags the code, stop with an 'Error' tagged with @name@ and the text
-- from 'getErrorName'.  Otherwise hand the code to @onSuccess@.
check :: String -> IO CSize -> (CSize -> IO Result) -> IO Result
check name act onSuccess = act >>= dispatch
  where
    dispatch status
      | isError status = return (Error name (getErrorName status))
      | otherwise      = onSuccess status